TensorRT 7 conv3d is not running on Tensor Cores

Description

Hello.
I’m trying to optimize a PyTorch network on a Jetson AGX Xavier 32GB with TensorRT, but I can’t get Conv3d to run on Tensor Cores. I created a very small network containing only a convolution and an activation and expected it to run on Tensor Cores, but it does not. However, if I replace the Conv3d with a Conv2d, I do see the convolution running on Tensor Cores. I run the ONNX model with trtexec and profile it with nvprof using the -m tensor_precision_fu_utilization metric, and I see no Tensor Core utilization. Am I missing something?

I have another post about this in the TensorRT section, but I was advised to post it in the Jetson forum as well.
Original post: TensorRT 7 conv3d is not running on Tensor Cores - #8 by maxim.godovicyn

Environment

TensorRT Version : 7.1.3-1+cuda10.2
GPU Type : Volta
Nvidia Driver Version : Jetpack 4.5.1
CUDA Version : 10.2
CUDNN Version : 8.0
Operating System + Version : 4.9.201-tegra (Jetpack 4.5.1)
Python Version (if applicable) : -
TensorFlow Version (if applicable) :-
PyTorch Version (if applicable) : 1.6 (for onnx model creation)
Baremetal or Container (if container which image + tag) :

Relevant Files

Onnx model:
test.onnx (432.6 KB)

Nvprof log file:
nvprof_test.log (9.5 KB)

Steps To Reproduce

Code to create onnx model:

import torch
from torch import nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv3d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool3d(2, 2)  # defined but not used in forward()

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        return x

image_dims = (1, 64, 64, 64, 64)  # NCDHW: batch, channels, depth, height, width

dummy_input = torch.rand(image_dims, device="cuda")

model = Net().to("cuda")

torch.onnx.export(model, dummy_input, "test.onnx", verbose=True, opset_version=11)

Command to run model and create nvprof log:
sudo -E /usr/local/cuda-10.2/bin/nvprof -m tensor_precision_fu_utilization --log-file nvprof_test.log /usr/src/tensorrt/bin/trtexec --onnx=test.onnx --explicitBatch --dumpProfile --fp16 --verbose --workspace=4096
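As a side note, here is a minimal sketch (assuming the nvprof_test.log produced by the command above, and the "Kernel: ..." plus metric-line layout nvprof prints for this metric) that lists each kernel together with its reported Tensor Core utilization level:

# Hypothetical helper: summarize tensor_precision_fu_utilization per kernel
# from an nvprof metric log (assumed format: a "Kernel: <name>" line followed
# by a metric line ending in a level such as "Idle (0)" or "High (9)").
import re
import sys

def summarize(log_path):
    kernel = None
    with open(log_path) as f:
        for line in f:
            line = line.strip()
            if line.startswith("Kernel:"):
                kernel = line[len("Kernel:"):].strip()
            elif "tensor_precision_fu_utilization" in line and kernel:
                levels = re.findall(r"(?:Idle|Low|Mid|High)\s*\(\d+\)", line)
                print(f"{levels[-1] if levels else '?':>10}  {kernel}")
                kernel = None

if __name__ == "__main__":
    summarize(sys.argv[1] if len(sys.argv) > 1 else "nvprof_test.log")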

Hi,

Since TensorRT 8.0 is available now, would you mind reproducing this with it first?
Thanks.

Hi,

I’ve tried to reproduce it on Xavier with JetPack 4.6 and I see exactly the same problem: the 3D convolution is still not running on Tensor Cores.

nvprof log:
nvprof_test_TRT8.log (35.7 KB)

P.S.: I’ve used the same model.

Hi,

Any update on my issue? Can you confirm that you can reproduce this problem?

Thanks.

Hi,

Based on the cuDNN documentation, Tensor Cores are used when one of the algorithms below is selected:

CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED

However, TensorRT will automatically choose a fast algorithm based on platform and hardware resources.
It’s not guaranteed that TensorRT will always use Tensor Cores for inference.
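For reference, here is a minimal sketch of building the same engine through the TensorRT Python API with FP16 enabled (assuming the tensorrt Python bindings shipped with JetPack are installed). Enabling the FP16 flag is the equivalent of passing --fp16 to trtexec and is what allows TensorRT to consider Tensor Core tactics in the first place:

# Minimal sketch (not the exact build path trtexec uses): build the engine
# from the ONNX model with FP16 enabled so Tensor Core tactics are eligible.
import tensorrt as trt

logger = trt.Logger(trt.Logger.VERBOSE)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open("test.onnx", "rb") as f:
    if not parser.parse(f.read()):
        raise RuntimeError(parser.get_error(0))

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)   # equivalent of trtexec --fp16
config.max_workspace_size = 4 << 30     # equivalent of --workspace=4096

engine = builder.build_engine(network, config)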

Thanks.

Hi,

So, how can I check whether the algorithms you mentioned are actually used, or at least enabled?

I understand that it is not guaranteed that TRT will use Tensor Cores, but when I run trtexec I do not see any “884” instructions in the log, so I think TRT is not even trying to use Tensor Cores. Or maybe the Tensor Core instructions for conv3d are in the log and I just don’t know how to identify them. Could you please check the trtexec log for conv3d Tensor Core instructions?
Log file (TRT8 setup):
trtexec_test.log (38.8 KB)

Thanks.

Hi,

Please enable the cuDNN log and share it with us.

$ export CUDNN_LOGINFO_DBG=1
$ export CUDNN_LOGDEST_DBG=output.log
$ /usr/src/tensorrt/bin/trtexec ...
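
Since the resulting log can be several megabytes, here is a minimal sketch (assuming the output.log path set above) that prints only the convolution-forward calls and the algorithm cuDNN reports for each of them:

# Minimal sketch: filter the cuDNN API log down to the convolution calls
# and the selected forward algorithm (lines containing FWD_ALGO).
with open("output.log") as f:
    for line in f:
        if "cudnnConvolutionForward" in line or "FWD_ALGO" in line:
            print(line.rstrip())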

Thanks.

Hi,

No problem:
output.log (3.9 MB)

Thanks.

Hi,

Based on the API log, TensorRT does use CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM as the convolution algorithm.

We are checking this issue internally.
Will share more information with you later.

I! CuDNN (v8201) function cudnnConvolutionForward() called:
i!     handle: type=cudnnHandle_t; streamId=0x55c6ef49c0;
i!     alpha: type=CUDNN_DATA_FLOAT; val=1.000000;
i!     xDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[1,64,64,64,64];
i!         strideA: type=int; val=[16777216,262144,4096,64,1];
i!     xData: location=dev; addr=0x224ada000;
i!     wDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         vect: type=int; val=0;
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[64,64,3,3,3];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     wData: location=dev; addr=0x21487a000;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=3;
i!         padA: type=int; val=[1,1,1];
i!         strideA: type=int; val=[1,1,1];
i!         dilationA: type=int; val=[1,1,1];
i!         groupCount: type=int; val=1;
i!     algo: type=cudnnConvolutionFwdAlgo_t; val=CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM (1);
i!     workSpace: location=dev; addr=0x22cada000;
i!     workSpaceSizeInBytes: type=size_t; val=72280576;
i!     beta: type=CUDNN_DATA_FLOAT; val=0.000000;
i!     yDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[1,64,64,64,64];
i!         strideA: type=int; val=[16777216,262144,4096,64,1];
i!     yData: location=dev; addr=0x228ada000;
i! Time: 2021-09-06T23:34:06.382896 (0d+0h+0m+7s since start)
i! Process=12660; Thread=12660; GPU=0; Handle=0x55ac94de20; StreamId=0x55c6ef49c0.

Thanks.

Hi,

It seems this support was added in TensorRT v8.0.

We checked your test.onnx model with JetPack 4.6 and do observe Tensor Core usage.
Please give it a check on your side as well:

==31039== NVPROF is profiling process 31039, command: /usr/src/tensorrt/bin/trtexec --onnx=conv3d.onnx --explicitBatch --dumpProfile --fp16 --verbose --workspace=4096
==31039== Profiling application: /usr/src/tensorrt/bin/trtexec --onnx=conv3d.onnx --explicitBatch --dumpProfile --fp16 --verbose --workspace=4096
==31039== Profiling result:
==31039== Metric result:
Invocations                               Metric Name                           Metric Description         Min         Max         Avg
Device "Xavier (0)"
    Kernel: volta_scudnn_128x64_3dconv_fprop_medium_nn_v1
         32           tensor_precision_fu_utilization   Tensor-Precision Function Unit Utilization    Idle (0)    Idle (0)    Idle (0)
    Kernel: void fft3d_c2r_16x16x16<float2, float, __half>(__half*, float2*, int3, int3, int3, int3, int3, float, float, bool, int, __half*, __half*)
       4000           tensor_precision_fu_utilization   Tensor-Precision Function Unit Utilization    Idle (0)    Idle (0)    Idle (0)
    Kernel: void xmma_trt::gemm::kernel<xmma_trt::implicit_gemm::fprop_indexed::Kernel_traits<xmma_trt::Volta_hmma_fp16_traits, xmma_trt::Cta_tile<xmma_trt::Volta<int=0>, int=128, int=256, int=32, int=2, int=4, int=1, int=1>, xmma_trt::implicit_gemm::fprop_indexed::Gmem_tile_a_t<xmma_trt::Volta_hmma_fp16_traits, xmma_trt::Cta_tile<xmma_trt::Volta<int=0>, int=128, int=256, int=32, int=2, int=4, int=1, int=1>, xmma_trt::implicit_gemm::Input_related<int=0, int=0, int=0, bool=0>, int=16, bool=0, xmma_trt::implicit_gemm::fprop_indexed::Gmem_tile_base_a<xmma_trt::Volta_hmma_fp16_traits, xmma_trt::Cta_tile<xmma_trt::Volta<int=0>, int=128, int=256, int=32, int=2, int=4, int=1, int=1>, xmma_trt::implicit_gemm::Input_related<int=0, int=0, int=0, bool=0>, int=16, xmma_trt::Row, int=32, int=128>>, xmma_trt::implicit_gemm::fprop_indexed::Gmem_tile_c_t<xmma_trt::Volta_hmma_fp16_traits, xmma_trt::Cta_tile<xmma_trt::Volta<int=0>, int=128, int=256, int=32, int=2, int=4, int=1, int=1>, int=16, xmma_trt::Fragment_c<xmma_trt::Volta_hmma_fp16_traits, xmma_trt::Cta_tile<xmma_trt::Volta<int=0>, int=128, int=256, int=32, int=2, int=4, int=1, int=1>, bool=0>>, xmma_trt::implicit_gemm::Input_related<int=0, int=0, int=0, bool=0>, int=1>>(xmma_trt::Volta_hmma_fp16_traitsParams)
         16           tensor_precision_fu_utilization   Tensor-Precision Function Unit Utilization    High (9)    High (9)    High (9)
    Kernel: sm70_xmma_fprop_implicit_gemm_f16f16_f16f16_f16_nhwckrsc_nhwc_tilesize64x32x64_stage1_warpsize2x1x2_g1_tensor8x8x4_t3r3s3_kernel_trt
         16           tensor_precision_fu_utilization   Tensor-Precision Function Unit Utilization     Mid (5)     Mid (5)     Mid (5)
    Kernel: sm70_xmma_fprop_implicit_gemm_f16f16_f16f16_f16_nhwckrsc_nhwc_tilesize128x128x64_stage1_warpsize2x2x1_g1_tensor8x8x4_t3r3s3_kernel_trt
         16           tensor_precision_fu_utilization   Tensor-Precision Function Unit Utilization    High (9)    High (9)    High (9)
    Kernel: sm70_xmma_fprop_implicit_gemm_f16f16_f16f16_f16_nhwckrsc_nhwc_tilesize256x128x32_stage1_warpsize4x2x1_g1_tensor8x8x4_t3r3s3_kernel_trt
         16           tensor_precision_fu_utilization   Tensor-Precision Function Unit Utilization    High (9)    High (9)    High (9)
    Kernel: sm70_xmma_fprop_implicit_gemm_f16f16_f16f16_f16_nhwckrsc_nhwc_tilesize128x256x32_stage1_warpsize2x4x1_g1_tensor8x8x4_t3r3s3_kernel_trt
         16           tensor_precision_fu_utilization   Tensor-Precision Function Unit Utilization    High (9)    High (9)    High (9)
...

Thanks.

Hi,

I see the same, but is it OK that
“Kernel: volta_scudnn_128x64_3dconv_fprop_medium_nn_v1”
shows Idle Tensor-Precision Function Unit Utilization? Does that mean conv3d is not running on Tensor Cores? If not, could you please clarify what it does mean?

Thanks.

Hi,

Sorry for the late update.

Could you also share the source code of the conv2d version and the corresponding ONNX model with us?
We want to check it further and compare the models.

Thanks.

Hi,

No problem.
Code to create conv2D model:

import torch
from torch import nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        return x

image_dims = (1, 64, 64, 64)

dummy_input = torch.rand(image_dims, device="cuda")

model = Net().to("cuda")

torch.onnx.export(model, dummy_input, "test.onnx", verbose=True, opset_version=11)

Model file:
conv2D.onnx (144.6 KB)

Thanks.

Thanks.

We are checking this internally.
Will share more information with you later.

Hi,

Sorry for the late reply.
With TensorRT v8.0, conv3d does use Tensor Cores.

Please check the output shared above.
There are lots of xmma_* (Matrix Multiply-Accumulate) kernels, which indicates that the Tensor Core path is used.
The volta_scudnn_128x64_3dconv_fprop_medium_nn_v1 kernel sits on top of the xmma kernels.
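
As a quick way to spot this yourself, here is a minimal sketch (assuming the trtexec_test.log shared above, or any trtexec --verbose / nvprof log) that pulls out the kernel and tactic names that typically indicate Tensor Core usage on Volta (xmma, hmma, and 884 patterns):

# Minimal sketch: scan a log for kernel/tactic names that hint at
# Tensor Core usage on Volta (xmma_*, hmma, or *884* patterns).
import re

pattern = re.compile(r"xmma|hmma|884", re.IGNORECASE)

with open("trtexec_test.log") as f:
    hits = {line.strip() for line in f if pattern.search(line)}

for line in sorted(hits):
    print(line)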

Thanks.

Hi,

Thanks for your answer, now I get it.
