Conv3D does not use Tensor Cores

Description

It appears that TensorRT does not select kernels that use Tensor Cores for Conv3D. I tried both running an ONNX model containing a single Conv3D and constructing the network definition directly with the TensorRT API.
As a result, it is slower than the same operation in PyTorch.
I believe I followed all recommendations for 3D convolutions: all dimensions are multiples of 8.
The equivalent Conv2D does select a Tensor Core-enabled kernel.
I tested this on an RTX 2080 Ti and on a T4.
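
For the 1×1×1 kernel used in the script below, the Conv3D is mathematically a per-voxel matrix multiply, so it can be rewritten as a 1×1 Conv2D by folding depth into height ((N, C, D, H, W) -> (N, C, D*H, W)). The NumPy-only sketch below checks that equivalence with small made-up shapes (the real case is N=1, C=K=64, D=H=W=64). It is an assumption that this is a usable workaround: it only holds for 1×1×1 kernels with unit stride and no padding, and whether TensorRT then picks a Tensor Core kernel for the reshaped Conv2D is not verified here.

```python
import numpy as np

# Small stand-in shapes; the thread uses N=1, C=K=64, D=H=W=64.
N, C, K, D, H, W = 1, 8, 16, 4, 5, 6
rng = np.random.default_rng(0)
x = rng.standard_normal((N, C, D, H, W)).astype(np.float32)
w3d = rng.standard_normal((K, C, 1, 1, 1)).astype(np.float32)
b = rng.standard_normal(K).astype(np.float32)


def conv3d_1x1x1(x, w, b):
    # Reference 1x1x1 Conv3D: a K x C matrix applied to the C-vector at every voxel.
    return np.einsum("kc,ncdhw->nkdhw", w[:, :, 0, 0, 0], x) + b[None, :, None, None, None]


def conv2d_1x1(x, w, b):
    # Reference 1x1 Conv2D on an NCHW input.
    return np.einsum("kc,nchw->nkhw", w[:, :, 0, 0], x) + b[None, :, None, None]


y3d = conv3d_1x1x1(x, w3d, b)

# Fold depth into height, run the 1x1 Conv2D, then unfold the result.
x4d = x.reshape(N, C, D * H, W)
y2d = conv2d_1x1(x4d, w3d.reshape(K, C, 1, 1), b).reshape(N, K, D, H, W)

print(np.allclose(y3d, y2d, atol=1e-5))
```

The reshape is free for a contiguous NCDHW tensor because D and H are adjacent dimensions, so no data has to be moved.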

Environment

TensorRT Version: 7.1.3.4
GPU Type: RTX 2080 Ti / T4
Nvidia Driver Version: 440.33.01
CUDA Version: 10.2
CUDNN Version: 8.0.1
Operating System + Version: CentOS 7.7.1908
Python Version (if applicable): 3.6.8
Baremetal or Container (if container which image + tag): Baremetal

Relevant Files

See below.

Steps To Reproduce

Run the following script with python3.

import time
import numpy as np
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)

if __name__ == "__main__":

    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
        EXPLICIT_BATCH
    ) as network:
        input_tensor = network.add_input(
            name="input_image", dtype=trt.float32, shape=(1, 64, 64, 64, 64)
        )

        # Add a convolution layer
        conv1_w = np.ones((64, 64, 1, 1, 1), dtype=np.float32)
        conv1_b = np.ones(64, dtype=np.float32)
        conv1 = network.add_convolution_nd(
            input=input_tensor,
            num_output_maps=64,
            kernel_shape=(1, 1, 1),
            kernel=conv1_w,
            bias=conv1_b,
        )

        conv1.get_output(0).name = "output_featuremap"
        network.mark_output(conv1.get_output(0))

        # build engine
        builder.max_workspace_size = 8500000000  # This determines the amount of memory available to the builder when building an optimized engine and should generally be set as high as possible.
        builder.fp16_mode = True
        # builder.strict_type_constraints = True
        builder.min_find_iterations = 10
        builder.average_find_iterations = 10
        with builder.build_cuda_engine(network) as engine:
            with open("to_mount/toy/toy.engine", "wb") as f:
                f.write(engine.serialize())

            # Inference
            # Determine dimensions and create page-locked memory buffers (i.e. won't be swapped to disk) to hold host inputs/outputs.
            h_bindings = []
            d_bindings = []
            import pycuda.autoinit
            import pycuda.driver as cuda

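            for i in range(engine.num_bindings):
                # Size each host buffer with the binding's own data type. Both bindings
                # are FP32 here even with fp16_mode, because TensorRT reformats to and
                # from FP16 internally; allocating np.float16 would halve the buffers.
                dtype = trt.nptype(engine.get_binding_dtype(i))
                h_bindings.append(
                    cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(i)), dtype=dtype)
                )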
                # Allocate device memory for inputs and outputs.
                d_bindings.append(cuda.mem_alloc(h_bindings[-1].nbytes))
            # Create a stream in which to copy inputs/outputs and run inference.
            stream = cuda.Stream()

            with engine.create_execution_context() as context:

                print("Starting inference")

                # Transfer input data to the GPU.
                for i in range(engine.num_bindings):
                    if engine.binding_is_input(i):
                        cuda.memcpy_htod_async(d_bindings[i], h_bindings[i], stream)
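                # Wait for the input copies so that only inference is timed.
                stream.synchronize()
                # Run inference (explicit-batch engines use execute_async_v2).
                tstart = time.time()
                context.execute_async_v2(
                    bindings=[int(d) for d in d_bindings], stream_handle=stream.handle
                )
                # Synchronize the stream
                stream.synchronize()
                # Report the wall-clock time of this single run.
                print(f"Total Run Time: {time.time()-tstart}")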
                # Transfer predictions back from the GPU.
                for i in range(engine.num_bindings):
                    if not engine.binding_is_input(i):
                        cuda.memcpy_dtoh_async(h_bindings[i], d_bindings[i], stream)
                # Synchronize the stream
                stream.synchronize()
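
The script times a single, first call, which can include one-time overhead, so it is not directly comparable with a warmed-up PyTorch measurement. A minimal timing sketch with untimed warm-up calls and repeated runs; `benchmark` and `run_once` are hypothetical names, and `run_once` is assumed to block until the GPU work has finished (for example a lambda that calls `context.execute_async_v2(...)` followed by `stream.synchronize()`).

```python
import statistics
import time
from typing import Callable, List


def benchmark(run_once: Callable[[], None], warmup: int = 5, iters: int = 20) -> List[float]:
    """Return per-call wall-clock times in seconds for `run_once`.

    `run_once` is a hypothetical callable that must not return before the
    work it launches is complete, otherwise the timings are meaningless.
    """
    for _ in range(warmup):
        run_once()  # untimed: absorbs one-time setup cost
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        run_once()
        times.append(time.perf_counter() - t0)
    return times


if __name__ == "__main__":
    # Stand-in CPU workload so the sketch runs without a GPU.
    samples = benchmark(lambda: sum(range(10_000)), warmup=2, iters=10)
    print(f"median {statistics.median(samples) * 1e3:.3f} ms over {len(samples)} runs")
```

Reporting the median rather than a single value makes the comparison less sensitive to outliers from clock ramp-up or other processes.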

Profiling the script with Nsight Compute (nv-nsight-cu) yields:

==PROF== Connected to process 27360 (/usr/bin/python3.6)
Starting inference
==PROF== Profiling “tensor_elementwise_kernel” - 1: 0%…50%…100% - 13 passes
==PROF== Profiling “implicit_convolveNd_sgemm” - 2: 0%…50%…100% - 13 passes
==PROF== Profiling “op_generic_tensor_kernel” - 3: 0%…50%…100% - 13 passes
==PROF== Profiling “nchwTonchw” - 4: 0%…50%…100% - 13 passes
Total Run Time: 9.15361213684082
==PROF== Disconnected from process 27360
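
The kernel names in this output are the quickest evidence: `implicit_convolveNd_sgemm` suggests a single-precision GEMM path rather than a Tensor Core kernel. A rough sketch for scanning such profiler output, assuming (as a heuristic, not documented behavior) that Tensor Core kernels on Volta/Turing carry substrings such as `884`, `1688` or `hmma` in their names. `classify_kernels` and `TENSOR_CORE_HINTS` are hypothetical names, and the hint list may miss kernels.

```python
import re

# Heuristic substrings often seen in Tensor Core kernel names on Volta/Turing.
# This is an assumption, not an official naming contract.
TENSOR_CORE_HINTS = ("884", "1688", "hmma")

# Accept both straight and typographic quotes around the kernel name.
PROF_LINE = re.compile(r'==PROF== Profiling ["“]([^"”]+)["”]')


def classify_kernels(profiler_output: str) -> dict:
    """Map each profiled kernel name to True if its name hints at Tensor Core use."""
    result = {}
    for name in PROF_LINE.findall(profiler_output):
        lowered = name.lower()
        result[name] = any(hint in lowered for hint in TENSOR_CORE_HINTS)
    return result


log = """==PROF== Profiling "tensor_elementwise_kernel" - 1: 13 passes
==PROF== Profiling "implicit_convolveNd_sgemm" - 2: 13 passes"""
print(classify_kernels(log))
```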

Hi @Klamue,
Could you please share your model so that I can help you better?
Thanks!

Hi @AakankshaS,

The model is built in the script I provided above; it is a minimal example of a model that uses Conv3D.

Thanks

Hi @Klamue,

If you can provide the engine file you are using in the script, I should be able to help you better.
Thanks!

Hi,

Sure. Here is the serialized engine (please unzip it first): toy.engine.zip (967 Bytes)

Here is also the verbose log output from the RTX 2080 Ti system:

[TensorRT] WARNING: Setting layouts of network and plugin input/output tensors to linear, as 3D operators are found and 3D non-linear IO formats are not supported, yet.
[TensorRT] VERBOSE: Applying generic optimizations to the graph for inference.
[TensorRT] VERBOSE: Original: 1 layers
[TensorRT] VERBOSE: After dead-layer removal: 1 layers
[TensorRT] VERBOSE: After Myelin optimization: 1 layers
[TensorRT] VERBOSE: After scale fusion: 1 layers
[TensorRT] VERBOSE: After vertical fusions: 1 layers
[TensorRT] VERBOSE: After final dead-layer removal: 1 layers
[TensorRT] VERBOSE: After tensor merging: 1 layers
[TensorRT] VERBOSE: After concat removal: 1 layers
[TensorRT] VERBOSE: Graph construction and optimization completed in 0.0017621 seconds.
[TensorRT] VERBOSE: Constructing optimization profile number 0 [1/1].
[TensorRT] VERBOSE: --------------- Timing Runner: <reformat> (Reformat)
[TensorRT] VERBOSE: Tactic: 1002 time 0.248995
[TensorRT] VERBOSE: Tactic: 0 time 0.272621
[TensorRT] VERBOSE: Fastest Tactic: 1002 Time: 0.248995
[TensorRT] VERBOSE: *************** Autotuning format combination: Float(1,64,4096,262144,16777216) -> Float(1,64,4096,262144,16777216) ***************
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Convolution] (FusedConvActConvolution)
[TensorRT] VERBOSE: FusedConvActConvolution has no valid tactics for this config, skipping
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Convolution] (CaskConvolution)
[TensorRT] VERBOSE: CaskConvolution has no valid tactics for this config, skipping
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Convolution] (CudaConvolution)
[TensorRT] VERBOSE: Tactic: 0 time 1.2455
[TensorRT] VERBOSE: Tactic: 5 time 11.1889
[TensorRT] VERBOSE: Tactic: 57 time 1.05417
[TensorRT] VERBOSE: Fastest Tactic: 57 Time: 1.05417
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Convolution] (CudaDepthwiseConvolution)
[TensorRT] VERBOSE: CudaDepthwiseConvolution has no valid tactics for this config, skipping
[TensorRT] VERBOSE: >>>>>>>>>>>>>>> Chose Runner Type: CudaConvolution Tactic: 57
[TensorRT] VERBOSE: 
[TensorRT] VERBOSE: *************** Autotuning format combination: Half(1,64,4096,262144,16777216) -> Half(1,64,4096,262144,16777216) ***************
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Convolution] (FusedConvActConvolution)
[TensorRT] VERBOSE: FusedConvActConvolution has no valid tactics for this config, skipping
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Convolution] (CaskConvolution)
[TensorRT] VERBOSE: CaskConvolution has no valid tactics for this config, skipping
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Convolution] (CudaConvolution)
[TensorRT] VERBOSE: Tactic: 0 time 0.559293
[TensorRT] VERBOSE: Tactic: 5 time 11.0986
[TensorRT] VERBOSE: Tactic: 57 time 0.561216
[TensorRT] VERBOSE: Fastest Tactic: 0 Time: 0.559293
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Convolution] (CudaDepthwiseConvolution)
[TensorRT] VERBOSE: CudaDepthwiseConvolution has no valid tactics for this config, skipping
[TensorRT] VERBOSE: >>>>>>>>>>>>>>> Chose Runner Type: CudaConvolution Tactic: 0
[TensorRT] VERBOSE: 
[TensorRT] VERBOSE: --------------- Timing Runner: <reformat> (Reformat)
[TensorRT] VERBOSE: Tactic: 1002 time 0.269802
[TensorRT] VERBOSE: Tactic: 0 time 0.196784
[TensorRT] VERBOSE: Fastest Tactic: 0 Time: 0.196784
[TensorRT] VERBOSE: Adding reformat layer: (Unnamed Layer* 0) [Convolution] reformatted input 0 (input_image) from Float(1,64,4096,262144,16777216) to Half(1,64,4096,262144,16777216)
[TensorRT] VERBOSE: Adding reformat layer: (Unnamed Layer* 0) [Convolution] output to be reformatted 0 (output_featuremap) from Float(1,64,4096,262144,16777216) to Half(1,64,4096,262144,16777216)
[TensorRT] VERBOSE: Formats and tactics selection completed in 3.89279 seconds.
[TensorRT] VERBOSE: After reformat layers: 3 layers
[TensorRT] VERBOSE: Block size 8500000256
[TensorRT] VERBOSE: Block size 33554432
[TensorRT] VERBOSE: Block size 33554432
[TensorRT] VERBOSE: Total Activation Memory: 8567109120
[TensorRT] INFO: Detected 1 inputs and 1 output network tensors.
[TensorRT] VERBOSE: Layer: (Unnamed Layer* 0) [Convolution] input reformatter 0 Weights: 0 HostPersistent: 0 DevicePersistent: 0
[TensorRT] VERBOSE: Layer: (Unnamed Layer* 0) [Convolution] Weights: 8192 HostPersistent: 8 DevicePersistent: 0
[TensorRT] VERBOSE: Layer: (Unnamed Layer* 0) [Convolution] output reformatter 0 Weights: 0 HostPersistent: 0 DevicePersistent: 0
[TensorRT] VERBOSE: Total Host Persistent Memory: 8
[TensorRT] VERBOSE: Total Device Persistent Memory: 0
[TensorRT] VERBOSE: Total Weight Memory: 8192
[TensorRT] VERBOSE: Builder timing cache: created 4 entries, 0 hit(s)
[TensorRT] VERBOSE: Engine generation completed in 4.67907 seconds.
[TensorRT] VERBOSE: Engine Layer Information:
[TensorRT] VERBOSE: Layer(Reformat): (Unnamed Layer* 0) [Convolution] input reformatter 0, Tactic: 1002, input_image[Float(64,64,64,64)] -> (Unnamed Layer* 0) [Convolution] reformatted input 0[Half(64,64,64,64)]
[TensorRT] VERBOSE: Layer(Convolution): (Unnamed Layer* 0) [Convolution], Tactic: 0, (Unnamed Layer* 0) [Convolution] reformatted input 0[Half(64,64,64,64)] -> (Unnamed Layer* 0) [Convolution] output to be reformatted 0[Half(64,64,64,64)]
[TensorRT] VERBOSE: Layer(Reformat): (Unnamed Layer* 0) [Convolution] output reformatter 0, Tactic: 0, (Unnamed Layer* 0) [Convolution] output to be reformatted 0[Half(64,64,64,64)] -> output_featuremap[Float(64,64,64,64)]
Starting inference
Total Run Time: 0.47873854637145996
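
Some numbers in this log can be cross-checked with simple arithmetic. A sketch using only values printed above: the format `Float(1,64,4096,262144,16777216)` appears to list the element strides of a linear (plain NCDHW) layout, innermost dimension first, which fits the warning that TensorRT falls back to linear layouts for 3D operators. Each 33554432-byte block is exactly one FP16 tensor of shape 1×64×64×64×64, and the total activation memory is the sum of the three blocks, the first being slightly larger than the `max_workspace_size` of 8500000000 set in the script. Reading the two small blocks as the reformatted input and the pre-reformat convolution output is an assumption.

```python
import operator
from functools import reduce

shape = (1, 64, 64, 64, 64)  # N, C, D, H, W of input_image and output_featuremap


def contiguous_strides(shape):
    """Element strides of a linear (row-major) tensor, innermost dimension first."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(strides)


# Compare with the format string "Float(1,64,4096,262144,16777216)" in the log.
print(contiguous_strides(shape))  # (1, 64, 4096, 262144, 16777216)

# One FP16 tensor of this shape at 2 bytes per element; compare with "Block size 33554432".
num_elements = reduce(operator.mul, shape, 1)
fp16_bytes = num_elements * 2
print(fp16_bytes)  # 33554432

# "Total Activation Memory" equals the sum of the three reported block sizes.
blocks = [8500000256, 33554432, 33554432]
print(sum(blocks))  # 8567109120
```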

Thanks!

Could you reproduce or solve this issue on your side? Any progress on it?

Hi @Klamue,
The team is checking into this.
We will update you.
Thanks!

I’m very interested in this too.

I’m very interested in it, too! I have exactly the same problem with my 3D convolution network: model_export.onnx - Google Drive (I can’t attach it because of the file size limit)