Performance issues when upgrading to JetPack 5

We are trying to upgrade the OS of our product running on Jetson Xavier NX, from JetPack 4.4.1 to JetPack 5.1.2.

We’ve encountered problematic behavior: the GPU seems to run dramatically slower on the new JetPack.

I’ve tried several different kinds of operations, from simple torch CUDA operations to TensorRT engine inference using torch_tensorrt.

I’ve run each operation 10,000 times and measured the average time in ms per frame.

Here is a table of the results:

It seems that all GPU operations take much more time on JetPack 5…

The JetPack 5 setup, of course, uses the corresponding package versions:

  • Python 3.8 (in JP4: 3.6)
  • PyTorch 2.1.0 with CUDA 11.4 (in JP4: PyTorch 1.8.0 with CUDA 10.2)
  • TensorRT 8.5.2 (in JP4: 7.1.3)
  • torch_tensorrt 1.4.0 (compiled) (in JP4: trtorch 0.3.0)
  • OpenCV 4.6.0 with CUDA 11.4 (compiled) (in JP4: 4.6.0 with CUDA 10.2)

Things I’ve checked:

  1. Clocks seem the same (as far as I can tell from all the data that appears in jtop).
  2. jetson_clocks is running.
  3. Power mode is the same (15W, 6 cores).

Here is a sample script (results are in the last line of the table I’ve attached, about 5% slower on JetPack 5):

import torch
import time


ITERATIONS = 10000


torch_gpu_tensor = torch.ones((10000, 10000), dtype=torch.float32, device=torch.device('cuda'))

torch_gpu_tensor.sum()  # warm-up; the first call is sometimes slower, do not count it
torch.cuda.synchronize()  # make sure the warm-up kernel has finished before starting the timer
start_time = time.time()
for _ in range(ITERATIONS):
    torch_gpu_tensor.sum()
    torch.cuda.current_stream().synchronize()
total_time = time.time() - start_time
print(f"it took {total_time} seconds for {ITERATIONS} iterations, {1000 * total_time / ITERATIONS}ms average for 1 iteration")

Can you please advise us on how to proceed?

Thank you.

Hi,

Due to the security hardening in the upstream 5.10 kernel, a general perf drop is expected in JetPack 5.

Thanks.

How can we disable this security hardening? We cannot work with such latency in our product; we are time-critical.
We actually wanted to benefit from the DLA and the additional features supported in a newer JetPack, and it is absurd that we instead get such poor performance.

Hi,

Unfortunately, this comes from the upstream kernel, so it cannot be removed.
But maybe you can tune a custom 15W power mode to see if it helps,
e.g. lower the CPU clocks and increase the GPU clocks.

Thanks.

Hi,

To reproduce this issue locally, could you share the following with us?

  • The sample script for “gpu post process”
  • The build steps (or wheel) for PyTorch on JetPack 4.4.1 and 5.1.2.
  • The build steps (or wheel) for torch_tensorrt on JetPack 4.4.1 and 5.1.2.
  • The build steps (or wheel) for OpenCV on JetPack 4.4.1 and 5.1.2.

Thanks.

Hi,

Would you mind sharing the info above so we can reproduce this issue internally?
Thanks.

Hi AastaLLL, thank you for helping.

I’m attaching a sample script of inference + GPU post-processing. It is not exactly the model we’ve trained, but a regular mobilenet_v2 that also reproduces the issue on my devices (it is simpler and also does not include any of the company’s IP). I am including the ONNX of the model and the conversion for both platforms (JP4.4.1 and JP5.1.2), plus the wheels for torch, torchvision and trtorch/torch_tensorrt.

Script + wheels can be found here:
https://drive.google.com/file/d/1ozaFlIYFkznzu0tgSqJ4eHsVMNSS3vol/view?usp=sharing

The GPU post-processing seems to be slower by “only” around 10% since I re-converted the models for JP5; I don’t know exactly why.
The inference seems to be around 28% slower on our Jetsons.

Regarding OpenCV, it includes some internal code modifications, so it is a bit more complicated to share; in any case, the sample script does not rely on it.

I hope you can reproduce the latency; otherwise there might be something wrong in the environment setup we’ve created…
For instance, we install JetPack on an NVMe device and replace the GUI with LXDE (in both JetPacks). Can this be related? Are there any other dependencies of torch/TensorRT that might differ in our environment?

Thank you again.

Hi,

Thanks for the source and packages. Here are some findings.
We have tried the engine. Since we did not have torch_tensorrt installed before, we ran the model with the TensorRT binary directly:

$ /usr/src/tensorrt/bin/trtexec --loadEngine=./mobilenet_v2.[ver].engine --iterations=100

With the TensorRT binary, we don’t observe the perf drop issue:

JetPack 4: 5.31199 ms
JetPack 5: 5.26338 ms

Since the TensorRT binary runs the model with fake input kept in fast device memory, the regression might be memory-related.
(For example, there is a known memory regression caused by the security update:
Bad memory performance on JetPack 5.0.2 (5.10.104-tegra))
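
If you would like to check this quickly from Python as well, below is a rough sketch (not an official benchmark) that times a large device-to-device copy with PyTorch. Comparing the number on both JetPacks can show whether raw device memory bandwidth is involved; the buffer size and iteration count are arbitrary.

import torch

N_BYTES = 256 * 1024 * 1024      # 256 MB per buffer, arbitrary
ITERATIONS = 100

src = torch.empty(N_BYTES, dtype=torch.uint8, device='cuda')
dst = torch.empty_like(src)

dst.copy_(src)                   # warm-up
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(ITERATIONS):
    dst.copy_(src)
end.record()
end.synchronize()

elapsed_s = start.elapsed_time(end) / 1000.0
gb_moved = 2 * N_BYTES * ITERATIONS / 1e9    # copy_ reads src and writes dst
print(f"device-to-device: {gb_moved / elapsed_s:.1f} GB/s")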

We will try the wheel and check if we can reproduce the issue locally to gather more info.
Could you also test the same command for the real model you used to make sure there is no degradation on the TensorRT side?

Thanks.

Hi,

We have tested the benchmark script and are able to reproduce the issue.

For post-processing, we see about a 21.4% perf drop.

JetPack-4.4.1 (r32.4.4) + PyTorch 1.8.0

$ python3 ./run_benchmark.py 
execution_time: 55.322532176971436 seconds for 10000 executions, 0.005532253217697143 for 1
gpu_post_process_time: 252.67336797714233 seconds for 10000 executions, 0.025267336797714233 for 1
total_time: 307.99590015411377 seconds for 10000 executions, 0.030799590015411375 for 1

JetPack-5.1.2 (r35.4.1) + PyTorch 2.1.0

$ python3 ./run_benchmark.py 
WARNING: [Torch-TensorRT] - Using an engine plan file across different models of devices is not recommended and is likely to affect performance or even cause errors.
execution_time: 53.42548370361328 seconds for 10000 executions, 0.005342548370361328 for 1
gpu_post_process_time: 306.8492875099182 seconds for 10000 executions, 0.03068492875099182 for 1
total_time: 360.2747712135315 seconds for 10000 executions, 0.03602747712135315 for 1

We will check it further and share more info with you.
However, there is a warning in the JetPack 5 test.
Was the engine for JetPack 5 converted in the same environment, or copied from another environment?

Thanks.

Hi Aasta, thank you for your replies.

I converted the model on the same device on which I then executed the test.
I’ve attached the .onnx file of the model before conversion, so you can also try converting it again on your device (an equivalent build from the ONNX is sketched below).
I sometimes see this warning when converting on one device and copying the engine to another device, even though the devices are theoretically identical, so I’m not sure how this is checked internally…
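
In case it helps reproduce the conversion on your side, an equivalent engine can be built from the attached ONNX with the TensorRT Python API, roughly like the sketch below (TensorRT 8.x API as on JP5; the file names, input precision and workspace size are placeholders, not necessarily the exact settings we use). Building on the target device ties the engine plan to the local GPU and TensorRT version, which should avoid the warning.

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_engine(onnx_path, engine_path):
    builder = trt.Builder(TRT_LOGGER)
    # Explicit-batch network, as required for ONNX models
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_path, 'rb') as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError('failed to parse the ONNX file')

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GB, placeholder
    config.set_flag(trt.BuilderFlag.FP16)                                # placeholder precision

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError('engine build failed')
    with open(engine_path, 'wb') as f:
        f.write(serialized)

# File names are placeholders
build_engine('mobilenet_v2.onnx', 'mobilenet_v2.jp5.engine')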

I’ve tried setting the EMC rate to maximum on JP5, and it seems to improve the performance a little, but not significantly.

Executing mobilenet_v2 using trtexec gives results on my device similar to what you’ve shared. On our modified model, JP5 again runs slower for some reason:


One guess I have is that the difference is due to the fact that our modified model produces a much bigger output, so I will try to create a sample model and see how it behaves.

I will also continue to perform tests on my side and keep you updated on any new evidence that might help us track down the root cause.

Thank you.

Hi,

Just a quick update for you.
We ran the sample with nsys and got the note below:

"The following APIs use PAGEABLE memory which causes asynchronous CUDA memcpy operations to block and be executed synchronously. This leads to low GPU utilization. "

So we modified the bandwidthTest CUDA sample to use PAGEABLE memory + asynchronous copies, but this fails to reproduce the regression.

JetPack-4.4.1 (r32.4.4)

$ ./test 
[CUDA Bandwidth Test] - Starting...
Running on...

 Device 0: Xavier
 Quick Mode

 Host to Device Bandwidth, 1 Device(s)
 PAGEABLE Memory Transfers
   Transfer Size (Bytes)	Bandwidth(GB/s)
   32000000			6.1

 Device to Host Bandwidth, 1 Device(s)
 PAGEABLE Memory Transfers
   Transfer Size (Bytes)	Bandwidth(GB/s)
   32000000			5.5

 Device to Device Bandwidth, 1 Device(s)
 PAGEABLE Memory Transfers
   Transfer Size (Bytes)	Bandwidth(GB/s)
   32000000			46.9

Result = PASS

NOTE: The CUDA Samples are not meant for performance measurements. Results may vary when GPU Boost is enabled.

JetPack-5.1.2 (r35.4.1)

$ ./test 
[CUDA Bandwidth Test] - Starting...
Running on...

 Device 0: Xavier
 Quick Mode

 Host to Device Bandwidth, 1 Device(s)
 PAGEABLE Memory Transfers
   Transfer Size (Bytes)	Bandwidth(GB/s)
   32000000			6.1

 Device to Host Bandwidth, 1 Device(s)
 PAGEABLE Memory Transfers
   Transfer Size (Bytes)	Bandwidth(GB/s)
   32000000			5.5

 Device to Device Bandwidth, 1 Device(s)
 PAGEABLE Memory Transfers
   Transfer Size (Bytes)	Bandwidth(GB/s)
   32000000			46.7

Result = PASS

NOTE: The CUDA Samples are not meant for performance measurements. Results may vary when GPU Boost is enabled.
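
For reference, the pageable vs. pinned behavior that nsys points at can also be observed from PyTorch directly. Below is a rough sketch (buffer size and iteration count are arbitrary): with a pageable host tensor, each non_blocking copy still blocks the CPU, so the issue time is close to the total time; with a pinned tensor, the copies are queued asynchronously and the issue time stays small.

import time
import torch

N_BYTES = 64 * 1024 * 1024       # 64 MB, arbitrary
ITERATIONS = 100

gpu = torch.empty(N_BYTES, dtype=torch.uint8, device='cuda')
pageable = torch.empty(N_BYTES, dtype=torch.uint8)             # regular (pageable) host memory
pinned = torch.empty(N_BYTES, dtype=torch.uint8).pin_memory()  # page-locked host memory

def h2d(host_tensor):
    gpu.copy_(host_tensor, non_blocking=True)    # warm-up
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(ITERATIONS):
        gpu.copy_(host_tensor, non_blocking=True)
    issue = time.time() - start                  # CPU time spent issuing the copies
    torch.cuda.synchronize()
    total = time.time() - start                  # time until all copies finished on the GPU
    return issue, total

for name, tensor in [('pageable', pageable), ('pinned', pinned)]:
    issue, total = h2d(tensor)
    print(f"{name}: issue {1000 * issue:.1f} ms, total {1000 * total:.1f} ms for {ITERATIONS} copies")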

We are still working on locating the root cause. Will keep you updated.
Thanks.

Hi Aasta,
Thank you again for investigating the issue.
Have you found any more leads to the root cause?

Hi,

Not yet.

As our internal benchmarks do not show the regression between JetPack 4 and JetPack 5, we are examining the implementation differences between PyTorch 1.8 and 2.1 to see what might cause the issue.

Currently, we have found that the ‘where’ operation is the main reason for the perf drop (in the post-processing case).
The kernel takes around 25 ms on JetPack 4 but around 31 ms on JetPack 5.

We are now building PyTorch 1.8 on JetPack 5 to see whether the same perf drop reproduces with the exact same kernel.

JetPack 4

$ sudo /usr/local/cuda-10.2/bin/nvprof python3 run_benchmark_standalone.py 
==14210== NVPROF is profiling process 14210, command: python3 run_benchmark_standalone.py
==14210== Warning: Unified Memory Profiling is not supported on the underlying platform. System requirements for unified memory can be found at: http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-requirements
gpu_post_process_time: 2.603844404220581 seconds for 100 executions, 0.02603844404220581 for 1
==14210== Profiling application: python3 run_benchmark_standalone.py
==14210== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   99.37%  2.53165s       101  25.066ms  24.942ms  26.711ms  _ZN2at6native27unrolled_elementwise_kernelIZZZNS0_83_GLOBAL__N__59_tmpxft_00007230_00000000_8_TensorCompare_compute_72_cpp1_ii_da58410617where_kernel_implERNS_14TensorIteratorEN3c1010ScalarTypeEENKUlvE_clEvENKUlvE6_clEvEUlbffE_NS_6detail5ArrayIPcLi4EEE16OffsetCalculatorILi3EjESE_ILi1EjENS0_6memory15LoadWithoutCastENSH_16StoreWithoutCastEEEviT_T0_T1_T2_T3_T4_
                    0.17%  4.4158ms       101  43.721us  43.490us  44.642us  void at::native::_GLOBAL__N__63_tmpxft_000008d0_00000000_8_UpSampleNearest2d_compute_72_cpp1_ii_f539c38f::upsample_nearest2d_out_frame<c10::Half, float>(c10::Half const *, at::native::_GLOBAL__N__63_tmpxft_000008d0_00000000_8_UpSampleNearest2d_compute_72_cpp1_ii_f539c38f::upsample_nearest2d_out_frame<c10::Half, float>*, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, float, float)
                    0.17%  4.3586ms       101  43.154us  42.050us  46.306us  void at::native::vectorized_elementwise_kernel<int=4, at::native::MulScalarFunctor<float, float>, at::detail::Array<char*, int=2>>(int, float, float)
                    0.17%  4.2543ms       101  42.121us  40.898us  50.498us  _ZN2at6native27unrolled_elementwise_kernelIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE0_clEvENKUlvE6_clEvEUlfE_NS_6detail5ArrayIPcLi2EEE23TrivialOffsetCalculatorILi1EjESC_NS0_6memory12LoadWithCastILi1EEENSD_13StoreWithCastEEEviT_T0_T1_T2_T3_T4_
                    0.12%  3.0912ms       101  30.606us  29.761us  32.385us  void at::native::vectorized_elementwise_kernel<int=4, at::native::BUnaryFunctor<at::native::CompareLTFunctor<float>>, at::detail::Array<char*, int=2>>(int, float, at::native::CompareLTFunctor<float>)
                    0.00%  28.481us         1  28.481us  28.481us  28.481us  void at::native::vectorized_elementwise_kernel<int=4, at::native::FillFunctor<float>, at::detail::Array<char*, int=1>>(int, float, at::native::FillFunctor<float>)
                    0.00%  3.1690us         1  3.1690us  3.1690us  3.1690us  void at::native::vectorized_elementwise_kernel<int=4, at::native::FillFunctor<c10::Half>, at::detail::Array<char*, int=1>>(int, c10::Half, at::native::FillFunctor<c10::Half>)
      API calls:   77.15%  8.66442s         3  2.88814s  329.35us  8.58146s  cudaMalloc
                   22.46%  2.52256s       101  24.976ms  24.758ms  26.537ms  cudaStreamSynchronize
                    0.26%  28.697ms       507  56.600us  37.185us  205.16us  cudaLaunchKernel
                    0.12%  13.345ms      5374  2.4830us  1.3120us  83.938us  cudaGetDevice
                    0.01%  711.92us       611  1.1650us     576ns  28.513us  cudaGetLastError
                    0.00%  224.04us        97  2.3090us  1.0880us  28.832us  cuDeviceGetAttribute
                    0.00%  68.258us         1  68.258us  68.258us  68.258us  cudaGetDeviceProperties
                    0.00%  31.840us         2  15.920us  2.4640us  29.376us  cuDeviceGet
                    0.00%  14.496us         2  7.2480us  4.1920us  10.304us  cudaSetDevice
                    0.00%  9.0240us         1  9.0240us  9.0240us  9.0240us  cuDeviceTotalMem
                    0.00%  8.6400us         3  2.8800us  1.7920us  4.1280us  cuDeviceGetCount
                    0.00%  3.8720us         2  1.9360us  1.2800us  2.5920us  cudaGetDeviceCount
                    0.00%  2.0810us         1  2.0810us  2.0810us  2.0810us  cuDeviceGetName
                    0.00%  1.5360us         1  1.5360us  1.5360us  1.5360us  cuDeviceGetUuid

JetPack 5

Time    Total Time    Instances    Avg    Med    Min    Max    StdDev    Name
99.5%    3.147 s    101    31.160 ms    31.154 ms    31.125 ms    31.410 ms    31.412 μs    void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::<unnamed>::where_kernel_impl(at::TensorIterator &)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 7)]::operator ()() const::[lambda(bool, float, float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
0.1%    4.406 ms    101    43.621 μs    43.424 μs    42.432 μs    48.000 μs    911 ns    void at::native::vectorized_elementwise_kernel<(int)4, at::native::AUnaryFunctor<float, float, float, at::native::binary_internal::MulFunctor<float>>, at::detail::Array<char *, (int)2>>(int, T2, T3)
0.1%    4.312 ms    101    42.692 μs    42.624 μs    42.496 μs    45.152 μs    334 ns    void at::native::<unnamed>::upsample_nearest2d_out_frame<c10::Half, &at::native::nearest_neighbor_compute_source_index>(const T1 *, T1 *, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, float, float)
0.1%    4.178 ms    101    41.368 μs    41.152 μs    39.936 μs    52.064 μs    1.473 μs    void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 7)]::operator ()() const::[lambda(float) (instance 1)], at::detail::Array<char *, (int)2>, TrivialOffsetCalculator<(int)1, unsigned int>, TrivialOffsetCalculator<(int)1, unsigned int>, at::native::memory::LoadWithCast<(int)1>, at::native::memory::StoreWithCast<(int)1>>(int, T1, T2, T3, T4, T5, T6)
0.1%    3.118 ms    101    30.873 μs    30.816 μs    30.240 μs    32.384 μs    427 ns    void at::native::vectorized_elementwise_kernel<(int)4, void at::native::compare_scalar_kernel<float>(at::TensorIteratorBase &, at::native::<unnamed>::OpType, T1)::[lambda(float) (instance 1)], at::detail::Array<char *, (int)2>>(int, T2, T3)
0.0%    27.584 μs    1    27.584 μs    27.584 μs    27.584 μs    27.584 μs    0 ns    void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::Array<char *, (int)1>>(int, T2, T3)
0.0%    4.544 μs    1    4.544 μs    4.544 μs    4.544 μs    4.544 μs    0 ns    void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<c10::Half>, at::detail::Array<char *, (int)1>>(int, T2, T3)
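
If you want to check the same kernel on your side, a standalone micro-benchmark of just torch.where is enough to compare the two implementations. Below is a rough sketch (the tensor size here is arbitrary and not the shape used by your post-processing):

import torch

ITERATIONS = 100
SHAPE = (10000, 10000)           # arbitrary, large enough to be memory-bound

cond = torch.rand(SHAPE, device='cuda') < 0.5
a = torch.rand(SHAPE, device='cuda')
b = torch.rand(SHAPE, device='cuda')

torch.where(cond, a, b)          # warm-up
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(ITERATIONS):
    torch.where(cond, a, b)
end.record()
end.synchronize()

print(f"torch.where: {start.elapsed_time(end) / ITERATIONS:.3f} ms per call")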

Thanks.