How can I optimize a TensorRT engine for Tensor Cores?

I used the trtexec tool to evaluate my model on Jetson AGX Orin.

The result shows poor execution efficiency.
My model is almost entirely 3x3 convolutions, and its execution efficiency is only 5%.
(I think my model has dense parameters and therefore cannot benefit from sparsity.)

So I want to try optimizing my model.
If possible, I want to increase the execution efficiency to 20%.
Are there any programming tools or libraries for Tensor Cores?

Regards,
hiro

Hi,

How do you measure the efficiency?
Do you get it from the GPU utilization of tegrastats?

To run a model on Tensor Cores, please run inference in FP16 or INT8 mode.
Thanks.

Dear @AastaLLL,

How do you measure the efficiency?
Do you get it from the GPU utilization of tegrastats?

I calculate the efficiency as follows.

Execution efficiency = (The total computational complexity of Convolution) / (processing time * processing speed) * 100
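
For reference, here is the same definition written as a small Python helper (the function and variable names are my own, not from the attached script):

def execution_efficiency_pct(total_tflop, peak_tflops, time_s):
    # Useful convolution work divided by what the Tensor Cores could
    # theoretically complete in the measured time, as a percentage.
    return total_tflop / (peak_tflops * time_s) * 100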

I attached a model and a script file for reproduction.
The following model, which has 3 layers of 3x3 convolution, is defined in test_model0.py:

import torch
import torch.nn as nn

# Three stacked 3x3 convolutions with a constant channel count (n_feats).
class TestModel(nn.Module):
    def __init__(self, n_feats, kernel_size):
        super(TestModel, self).__init__()
        self.conv1 = nn.Conv2d(n_feats, n_feats, kernel_size, padding=kernel_size//2, bias=True)
        self.conv2 = nn.Conv2d(n_feats, n_feats, kernel_size, padding=kernel_size//2, bias=True)
        self.conv3 = nn.Conv2d(n_feats, n_feats, kernel_size, padding=kernel_size//2, bias=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x

And I create the ONNX model as follows:

size_x = 1920
size_y = 1080
kernel_size = 3
n_feats = 16

model = TestModel(n_feats, kernel_size).eval()
ar = torch.randn(1, n_feats, size_x, size_y)  # dummy NCHW input: (1, 16, 1920, 1080)

torch.onnx.export(model, ar, "test_model0.onnx", verbose=False)

The total computational complexity of the convolutions is 0.0287 FP16 tera floating-point operations (TFLOPs), calculated as follows:

computational complexity = 1920 (width) × 1080 (height) × 3 × 3 (kernel) × 2 (multiply-add) × 16 (in_channels) × 16 (out_channels) × 3 (layers) / 10^12 ≈ 0.0287 FP16 TFLOPs
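
The same count as a quick Python check (this just mirrors the hand calculation above; it does not profile the real engine):

# width x height x kernel x (multiply-add) x in_ch x out_ch x layers
flops = 1920 * 1080 * 3 * 3 * 2 * 16 * 16 * 3
print(flops / 1e12)  # -> ~0.0287 FP16 TFLOPs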

When I tested on a Jetson AGX Orin Developer Kit in Jetson AGX Orin 32GB emulation mode,
the peak processing speed of the Tensor Cores is 47.3 FP16 TFLOPS, because I set the NVPModel clock profile to 40W.

Processing speed on Tensor Cores = 54 FP16 TFLOPS × 816 MHz / 930 MHz = 47.3 FP16 TFLOPS

When I run the TRT engine as follows, the GPU Compute Time is around 9 ms.

# creating test_model0.trt
trtexec --buildOnly --fp16 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --precisionConstraints=obey --layerPrecisions=*:fp16 --layerOutputTypes=*:fp16 --sparsity=disable --onnx=test_model0.onnx --saveEngine=test_model0.trt --verbose

# Run test_model0.trt
trtexec --loadEngine=test_model0.trt --verbose

[05/25/2023-18:31:49] [I] GPU Compute Time: min = 8.94312 ms, max = 16.5773 ms, mean = 9.02886 ms, median = 8.97861 ms, percentile(90%) = 8.98486 ms, percentile(95%) = 8.98621 ms, percentile(99%) = 10.574 ms

So I calculate the efficiency (5.46%) as follows.

Execution Efficiency = 0.0287 / (47.3 (TFLOPS) × 0.009 (s)) × 100 = 5.46 (%)

Please use test_model0_exec.sh to reproduce the result.

And please give me any advice on how to optimize this model.

Regards,
hiro

test_model0.py (826 Bytes)
test_model0_exec.sh (404 Bytes)

Hi,

The 3x3 convolution implementation is memory-bound.

It looks like your channel size is only 3.
It’s recommended to use a larger channel count (>128) to achieve better computational efficiency.

Thanks.

Dear AastaLLL,

Thank you for your information.

The 3x3 convolution implementation is memory-bound.
It looks like your channel size is only 3.

My channel size is 16, as shown below:

n_feats = 16

It’s recommended to use a larger channel count (>128) to achieve better computational efficiency.

Do you mean we cannot get high computational efficiency if the channel count is smaller than 128?

Our network doesn’t need big channel counts.
So I want to know why small-channel convolutions cannot achieve good computational efficiency.
Could you explain how the convolution is implemented and why its computational efficiency is low?

Regards,
hiro

Hi,

Sorry for missing that.
However, a channel size of 16 is still small for TensorRT.

In such a use case, the performance is memory-bound.
This means the compute cores spend most of their time accessing data and only a little time computing.
The data-access overhead is large enough to dominate the performance, since the computation itself finishes quite quickly.
The amount of multiplication work scales with the channel count and kernel size.

If the model architecture can be adjusted, it’s recommended to adopt a larger channel size,
for example, changing C=16 with 16 layers to C=64 with 4 layers.

The former may need to fetch the kernel weights from memory 16 times while doing only a 3x3x16 computation per GPU thread each time.
The latter accesses memory only 4 times and can do a 3x3x64 computation per access.
(Please note this is only an example; the real behavior depends on the model and GPU resources.)
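
A rough way to see this is to compare the arithmetic intensity (FLOPs per byte of memory traffic) of the two configurations. The sketch below is a back-of-envelope estimate that assumes each layer reads its input and weights and writes its output to DRAM exactly once, ignoring caching and tiling:

# Approximate arithmetic intensity of one 3x3 FP16 convolution layer
# with C input and C output channels on an H x W feature map.
def intensity(h, w, c, k=3, bytes_per_elem=2):
    flops = h * w * k * k * 2 * c * c                              # 2 = multiply + add
    dram_bytes = (2 * h * w * c + k * k * c * c) * bytes_per_elem  # input + output + weights
    return flops / dram_bytes

for c in (16, 64):
    print(f"C={c}: {intensity(1080, 1920, c):.0f} FLOPs/byte")
# C=16 -> ~72 FLOPs/byte, C=64 -> ~288 FLOPs/byte. The intensity grows
# roughly linearly with C, so small-channel convolutions tend to fall
# below the GPU's compute-to-bandwidth break-even point and stay memory-bound.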

Thanks

Dear @AastaLLL,

Thank you for the information. I will try your advice.

By the way, I ran the same code on a GeForce RTX 3080 10GB and got better computational efficiency.

The computational efficiency on the RTX 3080 is 10.8589%, roughly double that of Jetson (5.46%).

Could you tell me the reason?

I calculated it as follows.

1) The total computational complexity of the convolutions

The total computational complexity of the convolutions is 0.0287 FP16 TFLOPs, calculated as follows:

computational complexity = 1920 (width) × 1080 (height) × 3 × 3 (kernel) × 2 (multiply-add) × 16 (in_channels) × 16 (out_channels) × 3 (layers) / 10^12 ≈ 0.0287 FP16 TFLOPs

2) FP16 sparse TFLOPS of the RTX 3080 Tensor Cores

I took the figure of 238 FP16 sparse Tensor Core TFLOPS for the RTX 3080 from the following PDF:

nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.pdf

3) Compute Time

When I run the TRT engine as follows, the GPU Compute Time is around 1 ms.

# creating test_model0.trt
trtexec --buildOnly --fp16 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --precisionConstraints=obey --layerPrecisions=*:fp16 --layerOutputTypes=*:fp16 --sparsity=disable --onnx=test_model0.onnx --saveEngine=test_model0.trt --verbose

# Run test_model0.trt
trtexec --loadEngine=test_model0.trt --verbose

[06/05/2023-18:08:14] [I] GPU Compute Time: min = 1.11105 ms, max = 1.13452 ms, mean = 1.12111 ms, median = 1.12024 ms, percentile(90%) = 1.12952 ms, percentile(95%) = 1.13147 ms, percentile(99%) = 1.13257 ms

So I calculate the efficiency (10.8589%) as follows.

Computational Efficiency = (total computational complexity of the convolutions) / (processing time × processing speed) × 100
= 0.0287 / (238 (TFLOPS) × 0.00111105 (s)) × 100 = 10.8589 (%)
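
As a quick sanity check, plugging these measurements into the efficiency formula in Python:

# Check of the RTX 3080 efficiency figure quoted above.
tflop = 0.0287         # total convolution work, FP16 TFLOPs
peak_tflops = 238.0    # RTX 3080 FP16 sparse Tensor Core peak
time_s = 1.11105e-3    # measured min GPU Compute Time
print(tflop / (peak_tflops * time_s) * 100)  # -> ~10.85 %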

Regards,
hiro

Hi,

Were the recent experiments also run with the model shared at the beginning?
We want to reproduce this and do some further investigation.

Thanks.

Dear AastaLLL,

Thank you for your reply.
We use the same model.

Note:
We used the following TensorRT versions on each architecture:

  • Jetson AGX Orin 32GB: 8.5.2.2-1+cuda11.4
  • RTX 3080: 8.6.1.6-1+cuda12.0

Regards,
hiro

Hi,

Thanks for your reply.

We are going to reproduce this issue internally.
Will share more information with you later.

Dear @AastaLLL,

Were you able to reproduce this issue?
If you need additional information, please let me know.

Regards,
hiro

Hi,

It seems that this performance result is expected.

The RTX 3080 uses a 320-bit memory interface, while Orin’s is 256-bit.
In a memory-bound use case, this helps the RTX 3080 achieve better performance.

Thanks.

Hi,

Could you also try ‘--inputIOFormats=fp16:hwc8 --outputIOFormats=fp16:hwc8’ to avoid the reformat-layer computation?
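
For example, the build command from earlier would become something like the following (untested here; the output engine name is arbitrary):

# rebuild with HWC8 I/O formats to avoid the extra reformat layers
trtexec --buildOnly --fp16 --inputIOFormats=fp16:hwc8 --outputIOFormats=fp16:hwc8 --precisionConstraints=obey --layerPrecisions=*:fp16 --layerOutputTypes=*:fp16 --sparsity=disable --onnx=test_model0.onnx --saveEngine=test_model0_hwc8.trt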
Thanks.

Dear @AastaLLL,

The RTX 3080 uses a 320-bit memory interface, while Orin’s is 256-bit.
In a memory-bound use case, this helps the RTX 3080 achieve better performance.

Thank you for your information.
Do you mean that memory access speed is the reason for Jetson’s worse performance?

I checked the memory bandwidth figures from the following sources:

Jetson AGX Orin 32GB: 204.8 GB/s
RTX 3080: 760.3 GB/s

Of course, Jetson’s memory bandwidth is lower than the RTX 3080’s, but it is still more than 1/4 of it.

On the other hand, Jetson’s peak compute performance is less than 1/5 of the RTX 3080’s:
Jetson AGX Orin 32GB (40W mode): 47.3 FP16 sparse TFLOPS
RTX 3080: 238 FP16 sparse TFLOPS
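
Computing both ratios explicitly (same figures as above):

bw_ratio = 204.8 / 760.3   # memory bandwidth, Jetson / RTX 3080 -> ~0.27
perf_ratio = 47.3 / 238.0  # peak FP16 TFLOPS, Jetson / RTX 3080 -> ~0.20
print(bw_ratio, perf_ratio)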

If the memory bandwidth ratio were lower than the performance ratio, I could understand memory access speed having a negative impact on the processing time.
This time, however, the memory bandwidth ratio is higher than the performance ratio.

So I cannot understand the negative impact.

Could you tell me the detailed reason?

Regards,
hiro

Hi,

The performance ratio is measured at peak computation.
But small-kernel convolution is memory-bound, so the impact of memory increases.

However, the environments of the 3080 and Orin are quite different.
Many mechanisms can impact performance. For example, the 3080 has its own dedicated memory, while Orin shares its memory with the CPU.

We do observe that small-kernel convolution runs slower.
We are discussing this internally and plan to improve it in a future release.
We will keep you updated.

Thanks.

Dear @AastaLLL,

Thank you for your information.
We are looking forward to the future release.

Regards,
hiro

Dear @AastaLLL,

I would like to ask you one question about this topic.

We are discussing this internally and plan to improve it in a future release.

If NVIDIA releases an improved version, do you have any estimate of the release date?
I want to understand NVIDIA’s development timeline.
Is it half a year, one year, or several years?

Regards,
hiro

Hi,

We need to check this with our internal team.
We will share more information with you.

Thanks.

Hi,

Could you share the input/output channel counts for the 3x3 convolution layers that you are using?
Thanks.

Could you share the input/output channel counts for the 3x3 convolution layers that you are using?

We are using input/output channel counts from 16 to 32.

Regards,
hiro