TensorRT 6 slower than TensorFlow with 3D convolutions and pooling

We have shared a network and scripts to reproduce the issue: https://github.com/ralovich/trt-regression . Is it expected that 3D convolutions and pooling perform slower than TF when offloaded to TRT?
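
For context, the scripts offload the frozen TF graph through the standard TF-TRT converter. A rough sketch of that path under the TF 1.15 API is below; the file name, node name, and parameter values are placeholders, not taken from the repro scripts:

import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Load a frozen GraphDef (placeholder path).
with tf.io.gfile.GFile("frozen_model.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

# Replace TRT-compatible subgraphs with TRTEngineOp nodes.
converter = trt.TrtGraphConverter(
    input_graph_def=graph_def,
    nodes_blacklist=["output"],           # placeholder output node name
    precision_mode="FP32",
    max_workspace_size_bytes=1 << 30,
    is_dynamic_op=True)                   # build engines at runtime
trt_graph = converter.convert()

# Run the converted graph in a normal TF session.
with tf.compat.v1.Session() as sess:
    tf.import_graph_def(trt_graph, name="")
    output = sess.graph.get_tensor_by_name("output:0")  # placeholder tensor name
    # sess.run(output, feed_dict={...})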

See also https://stackoverflow.com/questions/58607849/how-to-avoid-tensorrt-6-0-1-preformance-regression-against-tensorflow-with-3d-co

According to nvprof, TRT 6 spends a substantial amount of time in genericReformat::copyPackedKernel (about 20 s over 560 calls, 28% of GPU time in the TF-TRT run below). This kernel never appears in the pure-TF profile.

TF-TRT

==25530== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   57.27%  40.5793s        16  2.53621s  2.50036s  2.56996s  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=0>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=0>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *)
                   28.14%  19.9421s       560  35.611ms  30.176us  133.93ms  void genericReformat::copyPackedKernel<float, float, bool=0, bool=1, genericReformat::IdentityCoordMapper<int=5>, int=5>(unsigned int, unsigned int, void const *, genericReformat::ArrayN<genericReformat::IdentityCoordMapper<int=5>>, genericReformat::ArrayNWithReducedDivisors<genericReformat::IdentityCoordMapper<int=5>>, genericReformat::ArrayN, int, int, int, float const *, void*, genericReformat::ArrayN, genericReformat::ArrayNWithReducedDivisors, genericReformat::ArrayNWithReducedDivisors, genericReformat::ArrayN, int, int, int, float const , int=5)
                    5.17%  3.66144s       144  25.427ms  4.9889ms  80.370ms  maxwell_scudnn_128x64_stridedB_splitK_interior_nn
                    4.99%  3.53498s        96  36.823ms  3.5502ms  68.941ms  maxwell_scudnn_128x32_stridedB_splitK_interior_nn
                    2.54%  1.80123s        48  37.526ms  33.979ms  43.396ms  void cuPoolingNd::pooling_NCDHW<float>(float const *, cuPoolingNd::pooling_NCDHW<float>*, int, int, nvinfer1::Dims, nvinfer1, cuPoolingNd::pooling_NCDHW<float>*PoolingParameters, cuPoolingNd::pooling_NCDHW<float>*rt::reduced_divisor, cuPoolingNd::pooling_NCDHW<float>*rt, cuPoolingNd::pooling_NCDHW<float>*rt)
                    0.79%  560.12ms       256  2.1880ms  23.040us  5.6507ms  void cuEltwise::eltwise<cuEltwise::SimpleAlgo<float, float>, cuEltwise::Compute<nvinfer1::ElementWiseOperation>>(cuEltwise::LaunchParams)
                    0.36%  254.86ms        64  3.9822ms  2.3963ms  5.6356ms  void cuSliceLayer::naiveSlice<float>(cuSliceLayer::KernelArgs<float>)
                    0.31%  221.89ms       128  1.7335ms  752.77us  2.9854ms  void genericReformat::copyPackedKernel<float, float, bool=1, bool=1, genericReformat::IdentityCoordMapper<int=4>, int=4>(unsigned int, unsigned int, void const *, genericReformat::ArrayN<genericReformat::IdentityCoordMapper<int=4>>, genericReformat::ArrayNWithReducedDivisors<genericReformat::IdentityCoordMapper<int=4>>, genericReformat::ArrayN, int, int, int, float const *, void*, genericReformat::ArrayN, genericReformat::ArrayNWithReducedDivisors, genericReformat::ArrayNWithReducedDivisors, genericReformat::ArrayN, int, int, int, float const , int=4)
                    0.30%  209.33ms       240  872.22us  334.98us  2.7103ms  void cudnn::maxwell::gemm::setOutputKernel<float>(unsigned long, float*, float)
                    0.07%  49.110ms        16  3.0694ms  3.0006ms  3.1154ms  void cuEltwise::eltwise<cuEltwise::StripMineAlgo<float, float>, cuEltwise::Compute<nvinfer1::ElementWiseOperation>>(cuEltwise::LaunchParams)
                    0.03%  20.719ms        16  1.2949ms  1.2537ms  1.4628ms  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=1>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=1>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *)
                    0.02%  13.347ms       240  55.613us  23.104us  122.37us  cudnn::maxwell::gemm::computeOffsetsKernel(cudnn::maxwell::gemm::ComputeOffsetsParams)
                    0.01%  7.6006ms        21  361.93us     640ns  840.90us  [CUDA memcpy HtoD]
                    0.00%  3.0638ms        16  191.48us  191.39us  191.91us  [CUDA memcpy DtoH]
                    0.00%  547.94us        16  34.246us  33.536us  35.968us  void cuScale::scale<float, float, bool=0, cuScale::Mode, bool=0, bool=0, bool=0, int=4, nvinfer1::FusedActType>(float const *, cuScale::scale<float, float, bool=0, cuScale::Mode, bool=0, bool=0, bool=0, int=4, nvinfer1::FusedActType>*, cuScale::KernelParameters<cuScale::scale<float, float, bool=0, cuScale::Mode, bool=0, bool=0, bool=0, int=4, nvinfer1::FusedActType>>, nvinfer1::rt::reduced_divisor, nvinfer1::rt, nvinfer1::rt, nvinfer1::rt, nvinfer1::rt)
                    0.00%  348.74us        16  21.796us  21.248us  22.496us  void Eigen::internal::EigenMetaKernel<Eigen::TensorEvaluator<Eigen::TensorAssignOp<Eigen::TensorMap<Eigen::Tensor<float, int=1, int=1, long>, int=16, Eigen::MakePointer>, Eigen::TensorCwiseBinaryOp<Eigen::internal::scalar_max_op<float const , float const >, Eigen::TensorMap<Eigen::Tensor<float const , int=1, int=1, long>, int=16, Eigen::MakePointer> const , Eigen::TensorCwiseNullaryOp<Eigen::internal::scalar_constant_op<float const >, Eigen::TensorMap<Eigen::Tensor<float const , int=1, int=1, long>, int=16, Eigen::MakePointer> const > const > const > const , Eigen::GpuDevice>, long>(float, int=1)
                    0.00%  343.74us       240  1.4320us  1.1200us  2.1440us  cudnn::maxwell::gemm::computeBOffsetsKernel(cudnn::maxwell::gemm::ComputeBOffsetsParams)
                    0.00%  320.16us       249  1.2850us     544ns  10.272us  [CUDA memset]

TF

==24863== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   74.67%  43.0849s        17  2.53441s  2.50326s  2.57168s  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=0>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=0>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *)
                    6.38%  3.67934s        69  53.324ms  3.5521ms  67.210ms  maxwell_scudnn_128x32_stridedB_splitK_interior_nn
                    5.82%  3.35918s        51  65.866ms  47.210ms  81.543ms  maxwell_scudnn_128x64_stridedB_splitK_interior_nn
                    3.01%  1.73954s      2739  635.10us  335.36us  1.0583ms  maxwell_gcgemm_32x32_tn
                    2.08%  1.20297s       352  3.4175ms  1.0307ms  6.8907ms  void tensorflow::functor::SwapDimension1And2InTensor3UsingTiles<float, int=256, int=32, int=32, bool=0>(float const *, tensorflow::functor::Dimension<int=3>, tensorflow::functor::SwapDimension1And2InTensor3UsingTiles<float, int=256, int=32, int=32, bool=0>*)
                    1.57%  905.20ms        32  28.288ms  24.092ms  33.252ms  void cudnn::detail::pooling_fw_5d_kernel<float, float, cudnn::detail::averpooling_func<float>, int=2, bool=0>(cudnnTensorStruct, float const *, cudnn::detail::pooling_fw_5d_kernel<float, float, cudnn::detail::averpooling_func<float>, int=2, bool=0>, cudnnTensorStruct*, cudnnPoolingStruct, float, cudnnPoolingStruct, int, cudnn::reduced_divisor, float, float, float, float)
                    0.93%  537.60ms       128  4.2000ms  2.9829ms  4.8289ms  sgemm_32x32x32_NN_vec
                    0.85%  489.98ms        19  25.789ms  1.2564ms  246.26ms  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=1>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=1>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *)
                    0.83%  478.17ms       256  1.8679ms  735.43us  5.8379ms  void Eigen::internal::EigenMetaKernel<Eigen::TensorEvaluator<Eigen::TensorAssignOp<Eigen::TensorMap<Eigen::Tensor<float, int=2, int=1, int>, int=16, Eigen::MakePointer>, Eigen::TensorCwiseBinaryOp<Eigen::internal::scalar_sum_op<float, float>, Eigen::TensorMap<Eigen::Tensor<float const , int=2, int=1, int>, int=16, Eigen::MakePointer> const , Eigen::TensorBroadcastingOp<Eigen::array<long, unsigned long=2> const , Eigen::TensorMap<Eigen::Tensor<float const , int=2, int=1, int>, int=16, Eigen::MakePointer> const > const > const > const , Eigen::GpuDevice>, int>(float, int=2)
                    0.82%  473.75ms       272  1.7417ms  20.256us  5.5174ms  void Eigen::internal::EigenMetaKernel<Eigen::TensorEvaluator<Eigen::TensorAssignOp<Eigen::TensorMap<Eigen::Tensor<float, int=1, int=1, long>, int=16, Eigen::MakePointer>, Eigen::TensorCwiseBinaryOp<Eigen::internal::scalar_max_op<float const , float const >, Eigen::TensorMap<Eigen::Tensor<float const , int=1, int=1, long>, int=16, Eigen::MakePointer> const , Eigen::TensorCwiseNullaryOp<Eigen::internal::scalar_constant_op<float const >, Eigen::TensorMap<Eigen::Tensor<float const , int=1, int=1, long>, int=16, Eigen::MakePointer> const > const > const > const , Eigen::GpuDevice>, long>(float, int=1)
                    0.69%  399.66ms        16  24.979ms  24.397ms  25.758ms  void cudnn::detail::pooling_fw_5d_kernel<float, float, cudnn::detail::maxpooling_func<float, cudnnNanPropagation_t=0>, int=0, bool=0>(cudnnTensorStruct, float const *, cudnn::detail::pooling_fw_5d_kernel<float, float, cudnn::detail::maxpooling_func<float, cudnnNanPropagation_t=0>, int=0, bool=0>, cudnnTensorStruct*, cudnnPoolingStruct, float, cudnnPoolingStruct, int, cudnn::reduced_divisor, float, float, float, float)
                    0.66%  381.00ms         4  95.251ms  75.831ms  118.25ms  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=1024, int=6, int=7, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=1024, int=6, int=7, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *)
                    0.41%  236.41ms       128  1.8470ms  859.04us  3.0050ms  void Eigen::internal::EigenMetaKernel<Eigen::TensorEvaluator<Eigen::TensorAssignOp<Eigen::TensorSlicingOp<Eigen::array<int, unsigned long=2> const , Eigen::array<int, unsigned long=2> const , Eigen::TensorMap<Eigen::Tensor<float, int=2, int=1, int>, int=16, Eigen::MakePointer>>, Eigen::TensorMap<Eigen::Tensor<float const , int=2, int=1, int>, int=16, Eigen::MakePointer> const > const , Eigen::GpuDevice>, int>(int, unsigned long=2)
                    0.24%  138.48ms       120  1.1540ms  463.26us  2.6094ms  void cudnn::maxwell::gemm::setOutputKernel<float>(unsigned long, float*, float)
                    0.18%  105.58ms      6172  17.105us  2.5920us  1.4756ms  void transpose_readWrite_alignment_kernel<float2, float2, int=1, bool=0, int=6, int=4, int=4>(cublasTransposeParams<float2>, float2 const *, float2*, float2 const *)
                    0.18%  104.75ms      3082  33.987us  7.9040us  108.74us  void fft3d_c2r_16x16x16<float2, float, float>(float*, float2*, int3, int3, int3, int3, int3, float, float, bool, int, float*, float*)
                    0.17%  98.626ms         1  98.626ms  98.626ms  98.626ms  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *)
                    0.16%  90.844ms       343  264.85us  257.22us  269.41us  maxwell_gcgemm_64x64_tn
                    0.14%  80.336ms        64  1.2553ms  741.60us  1.8466ms  void Eigen::internal::EigenMetaKernel<Eigen::TensorEvaluator<Eigen::TensorAssignOp<Eigen::TensorMap<Eigen::Tensor<float, int=5, int=1, int>, int=16, Eigen::MakePointer>, Eigen::TensorSlicingOp<Eigen::DSizes<int, int=5> const , Eigen::DSizes<int, int=5> const , Eigen::TensorMap<Eigen::Tensor<float const , int=5, int=1, int>, int=16, Eigen::MakePointer> const > const > const , Eigen::GpuDevice>, int>(float, int=5)
                    0.13%  75.111ms      3090  24.307us  5.6320us  737.38us  void fft3d_r2c_16x16x16<float, float, float2>(float2*, float*, int3, int3, int3, int3, int3, bool)
                    0.02%  11.557ms        82  140.94us  131.52us  157.22us  redzone_checker
                    0.02%  10.873ms       120  90.606us  63.616us  131.14us  cudnn::maxwell::gemm::computeOffsetsKernel(cudnn::maxwell::gemm::ComputeOffsetsParams)
                    0.02%  8.6915ms        52  167.14us     640ns  655.20us  [CUDA memcpy HtoD]
                    0.01%  3.1427ms        57  55.135us     768ns  203.33us  [CUDA memcpy DtoH]
                    0.00%  2.7738ms       144  19.262us  2.1120us  44.800us  void tensorflow::functor::ShuffleInTensor3Simple<float, int=2, int=1, int=0, bool=0>(int, float const *, tensorflow::functor::Dimension<int=3>, tensorflow::functor::ShuffleInTensor3Simple<float, int=2, int=1, int=0, bool=0>*)
                    0.00%  449.38us        16  28.086us  27.520us  28.864us  void tensorflow::functor::SwapDimension1And2InTensor3UsingTiles<unsigned int, int=1024, int=2, int=1024, bool=0>(unsigned int const *, tensorflow::functor::Dimension<int=3>, tensorflow::functor::SwapDimension1And2InTensor3UsingTiles<unsigned int, int=1024, int=2, int=1024, bool=0>*)
                    0.00%  378.18us        16  23.636us  23.232us  24.032us  void Eigen::internal::EigenMetaKernel<Eigen::TensorEvaluator<Eigen::TensorAssignOp<Eigen::TensorMap<Eigen::Tensor<float, int=2, int=1, int>, int=16, Eigen::MakePointer>, Eigen::TensorCwiseBinaryOp<Eigen::internal::scalar_sum_op<float, float>, Eigen::TensorBroadcastingOp<Eigen::array<long, unsigned long=2> const , Eigen::TensorMap<Eigen::Tensor<float const , int=2, int=1, int>, int=16, Eigen::MakePointer> const > const , Eigen::TensorMap<Eigen::Tensor<float const , int=2, int=1, int>, int=16, Eigen::MakePointer> const > const > const , Eigen::GpuDevice>, int>(float, int=2)
                    0.00%  224.32us       120  1.8690us  1.5360us  2.1760us  cudnn::maxwell::gemm::computeBOffsetsKernel(cudnn::maxwell::gemm::ComputeBOffsetsParams)
                    0.00%  215.81us       225     959ns     544ns  2.5600us  [CUDA memset]
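
To quantify the difference, the two summaries can also be compared programmatically. A rough sketch, assuming the runs are repeated with nvprof's CSV output (nvprof --csv --log-file trt.csv ... and likewise for the TF run; file names are placeholders):

import csv

def kernel_time_percent(csv_path, name_substring):
    """Sum the Time(%) column over kernels whose demangled name contains name_substring."""
    total = 0.0
    with open(csv_path, newline="") as f:
        header = None
        for row in csv.reader(f):
            if header is None:
                # The summary table starts with a header row containing these columns.
                if "Name" in row and "Time(%)" in row:
                    header = {col: i for i, col in enumerate(row)}
                continue
            try:
                pct = float(row[header["Time(%)"]])
            except (ValueError, IndexError):
                continue  # skip the units row and any non-data lines
            if name_substring in row[header["Name"]]:
                total += pct
    return total

# TRT's layout-reformat kernels vs. TF's own transpose kernels:
print(kernel_time_percent("trt.csv", "genericReformat::copyPackedKernel"))
print(kernel_time_percent("tf.csv", "SwapDimension1And2InTensor3"))

On the numbers above, the reformat kernels account for roughly 28% of GPU time in the TF-TRT run, versus about 2% for TF's transpose kernels.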

Hi Kristof,

We’re looking into this and will let you know when we have some news.

Thanks,
NVIDIA Enterprise Support

Hi Kristof,

Per engineering team:

  1. Can you share verbose TF logs for your repro model, as mentioned here: Accelerating Inference In TF-TRT User Guide :: NVIDIA Deep Learning Frameworks Documentation? (A sketch of one way to enable these logs follows the environment list below.)

  2. Can you share the following environment information? The CUDNN version may be particularly important for 3D convolution performance:

Environment

TensorRT Version:
GPU Type:
Nvidia Driver Version:
CUDA Version:
CUDNN Version:
Python Version (if applicable):
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Operating System + Version:
Baremetal or Container (if container which image + tag):
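
For item 1, a minimal sketch of one way to enable the verbose TF-TRT conversion logs. The environment variables must be set before TensorFlow is imported; TF_CPP_VMODULE is a standard TensorFlow knob, but the module names listed here follow the TF-TRT user guide and should be treated as examples rather than values verified against this exact build:

import os

# Raise the per-module VLOG level for the TF-TRT conversion code paths.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"   # keep INFO messages
os.environ["TF_CPP_VMODULE"] = (
    "segment=2,convert_graph=2,convert_nodes=2,trt_engine_op=2,trt_logger=2")

import tensorflow as tf  # import only after the variables are set
# ... run the conversion; the segmenter's and converter's decisions are now
# printed to stderr and can be captured for the bug report.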

  1. Verbose conversion logs:

https://github.com/ralovich/trt-regression/commit/41b449d859ff21fee255fd1777df3edfc27e57c7

  2. Environment as follows:

TensorRT Version: 6.0.1-1+cuda10.0
GPU Type: GeForce GTX 1080
Nvidia Driver Version: 430.26
CUDA Version: 10.0.130-1
CUDNN Version: 7.6.3.30-1+cuda10.0
Python Version (if applicable): 3.6.7-1~18.04
TensorFlow Version (if applicable): 1.15 branch + tensorflow/tensorflow@4297539
PyTorch Version (if applicable): N/A
Operating System + Version: Ubuntu 18.04.3 LTS
Baremetal or Container (if container which image + tag): Baremetal

Issue being tracked here as well: https://github.com/NVIDIA/TensorRT/issues/153