Build engine error when using a PointNet-like structure with TensorRT 8.0.1.6

Description

When we take two 1x1 tensors produced by global average/max pooling and apply 1x1 convolutions to them, TensorRT 8.0.1 raises an error. TensorRT 7.1.3.4 has no problem.
[TensorRT] ERROR: 2: [pointWiseV2Helpers.h::createTensorDesc::296] Error Code 2: Internal Error (Assertion tensor.extent.d[j] == 1 failed.)

This bug has no simple workaround; please fix it as quickly as possible.
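To make the pattern concrete, here is a minimal NumPy sketch of the kind of PointNet-like head described above. Global average and max pooling reduce an (N, C, H, W) feature map to two (N, C, 1, 1) tensors, each branch gets a 1x1 convolution, and the branches are summed (mirroring the elementwise SUM in the repro script). The function name `pointnet_like_head`, the shapes, and the choice to sum the branches are assumptions for illustration; the sketch only models the math on the CPU and does not call TensorRT. It shows that, on a 1x1 spatial tensor, a 1x1 convolution is just a per-sample matrix multiply over channels.

```python
import numpy as np


def pointnet_like_head(features, w_avg, w_max):
    """Hypothetical head: global avg/max pool to (N, C, 1, 1), 1x1 conv each branch, then sum."""
    avg = features.mean(axis=(2, 3), keepdims=True)  # (N, C, 1, 1)
    mx = features.max(axis=(2, 3), keepdims=True)    # (N, C, 1, 1)

    def conv1x1(x, w):
        # A bias-free 1x1 convolution with weights (K, C, 1, 1) is a matmul over the channel axis.
        return np.einsum("kc,nchw->nkhw", w[:, :, 0, 0], x)

    return conv1x1(avg, w_avg) + conv1x1(mx, w_max)  # (N, K, 1, 1)


rng = np.random.default_rng(0)
feats = rng.standard_normal((2, 16, 8, 8)).astype(np.float32)
w_avg = rng.standard_normal((32, 16, 1, 1)).astype(np.float32)
w_max = rng.standard_normal((32, 16, 1, 1)).astype(np.float32)
out = pointnet_like_head(feats, w_avg, w_max)
print(out.shape)  # (2, 32, 1, 1)
```

The NxCx1x1 tensors feeding the convolutions here are the shape class the reporter says triggers the assertion.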

Environment

TensorRT Version: 8.0.1.6
GPU Type: GTX 1080
Nvidia Driver Version: 460.80
CUDA Version: 11.2
CUDNN Version: 8.2.1
Operating System + Version: Ubuntu 20.04
Python Version (if applicable): 3.8
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):

Steps To Reproduce

Run the following script in samples/python/network_api_pytorch_mnist:

import numpy as np

import pycuda.driver as cuda
import pycuda.autoinit

import tensorrt as trt

import sys, os
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import common

# You can set the logger severity higher to suppress messages (or lower to display more messages).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

class ModelData(object):
    INPUT_NAME = "data"
    INPUT_SHAPE = (16, 1, 1)
    OUTPUT_NAME = "prob"
    OUTPUT_SIZE = 10
    DTYPE = trt.float32

def pointwise_fusion_bug(network, weights):
    # Configure the network layers based on the weights provided.
    input_tensor = network.add_input(name=ModelData.INPUT_NAME, dtype=ModelData.DTYPE, shape=ModelData.INPUT_SHAPE)
    conv1_w = trt.Weights(weights['conv1.weight'])
    conv1_b = trt.Weights()
    conv1 = network.add_convolution(input=input_tensor, num_output_maps=32, kernel_shape=(1, 1), kernel=conv1_w, bias=conv1_b)
    conv2 = network.add_convolution(input=input_tensor, num_output_maps=32, kernel_shape=(1, 1), kernel=conv1_w, bias=conv1_b)
    add1 = network.add_elementwise(conv1.get_output(0), conv2.get_output(0), trt.ElementWiseOperation.SUM)
    network.mark_output(tensor=add1.get_output(0))

def build_engine(weights):
    # For more information on TRT basics, refer to the introductory samples.
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network:
        cfg = builder.create_builder_config()
        cfg.max_workspace_size = common.GiB(1)

        # Populate the network using the randomly generated weights.
        pointwise_fusion_bug(network, weights)
        # Build and return an engine.
        return builder.build_engine(network, cfg)

def main():
    common.add_help(description="Runs an MNIST network using a PyTorch model")
    # Generate random 1x1 conv weights and a random input; no PyTorch training is needed.
    wconv = np.random.uniform(-5, 5, size=[32, 16, 1, 1]).astype(np.float32)
    x = np.random.uniform(-5, 5, size=[1, 16, 1, 1]).astype(np.float32)
    weights = {"conv1.weight": wconv}
    # Do inference with TensorRT.
    with build_engine(weights) as engine:
        # Build an engine, allocate buffers and create a stream.
        # For more information on buffer allocation, refer to the introductory samples.
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        with engine.create_execution_context() as context:
            inputs[0].host[:] = x.reshape(-1)
            # For more information on performing inference, refer to the introductory samples.
            # The common.do_inference function will return a list of outputs - we only have one in this case.
            [output] = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)

if __name__ == '__main__':
    print(trt.__version__)
    main()

Hi @scrin,

Please share the complete error logs and an ONNX model that reproduces the issue so that we can try it on our end.

Thank you.

To get the error log, just run that script; it is a minimal reproduction. The script passes in TRT 7.1.3.4 but fails in TRT 8.0.1.6. The bug only occurs with NxCx1x1 tensors.
We use the TensorRT network API to construct the network, so no ONNX model is available. I think the attached script is enough for debugging.
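Because both convolutions in the repro share the same weights and have no bias, the expected network output is simply twice a per-sample matrix multiply. A NumPy reference like the following can be used to check the engine's result under TRT 7.1.3.4, or once a fixed build succeeds. The name `reference_output` is hypothetical and not part of the TensorRT samples; the weight and input shapes are taken from `main()` in the script.

```python
import numpy as np


def reference_output(x, w):
    """Hypothetical CPU reference: two bias-free 1x1 convs sharing weight w, then elementwise SUM."""
    conv = np.einsum("kc,nchw->nkhw", w[:, :, 0, 0], x)  # conv1 == conv2, shape (N, 32, 1, 1)
    return conv + conv


wconv = np.random.uniform(-5, 5, size=[32, 16, 1, 1]).astype(np.float32)
x = np.random.uniform(-5, 5, size=[1, 16, 1, 1]).astype(np.float32)
expected = reference_output(x, wconv)
print(expected.shape)  # (1, 32, 1, 1)
```

Assuming `common.do_inference` returns flat host buffers, as it does in the samples' `common.py`, the engine output could then be compared with something like `np.allclose(output, expected.reshape(-1), rtol=1e-3, atol=1e-3)`.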

Hi,
Please refer to the link below for the sample guide.

Refer to the installation steps at the link in case you are missing anything.

However, the suggested approach is to use the TRT NGC containers to avoid system-dependency issues.

To run the Python samples in an NGC container, make sure the TRT Python packages are installed:
/opt/tensorrt/python/python_setup.sh

If you are trying to run a custom model, please share your model and script with us so that we can assist you better.
Thanks!

The model is declared in that script:

def pointwise_fusion_bug(network, weights):
    # Configure the network layers based on the weights provided.
    input_tensor = network.add_input(name=ModelData.INPUT_NAME, dtype=ModelData.DTYPE, shape=ModelData.INPUT_SHAPE)
    conv1_w = trt.Weights(weights['conv1.weight'])
    conv1_b = trt.Weights()
    conv1 = network.add_convolution(input=input_tensor, num_output_maps=32, kernel_shape=(1, 1), kernel=conv1_w, bias=conv1_b)
    conv2 = network.add_convolution(input=input_tensor, num_output_maps=32, kernel_shape=(1, 1), kernel=conv1_w, bias=conv1_b)
    add1 = network.add_elementwise(conv1.get_output(0), conv2.get_output(0), trt.ElementWiseOperation.SUM)
    network.mark_output(tensor=add1.get_output(0))

What are you talking about?

Hi @scrin,

We were able to reproduce this error. We will get back to you on this.

@scrin,

Could you please try using the EXPLICIT_BATCH flag for builder.create_network() and let us know if you still face this issue?
For your reference, Developer Guide :: NVIDIA Deep Learning TensorRT Documentation

Thank you.


I have tested our model in explicit batch mode, but the bug is still there. The input of our model has a dynamic dimension on the batch axis. Here is another minimal reproduction script:


import numpy as np

import tensorrt as trt

# You can set the logger severity higher to suppress messages (or lower to display more messages).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

class ModelData(object):
    INPUT_NAME = "data"
    INPUT_SHAPE = (-1, 16, 1, 1)
    OUTPUT_NAME = "prob"
    OUTPUT_SIZE = 10
    DTYPE = trt.float32

def pointwise_fusion_bug(network, weights):
    # Configure the network layers based on the weights provided.
    input_tensor = network.add_input(name=ModelData.INPUT_NAME, dtype=ModelData.DTYPE, shape=ModelData.INPUT_SHAPE)
    conv1_w = trt.Weights(weights['conv1.weight'])
    conv1_b = trt.Weights()
    conv1 = network.add_convolution(input=input_tensor, num_output_maps=32, kernel_shape=(1, 1), kernel=conv1_w, bias=conv1_b)
    conv2 = network.add_convolution(input=input_tensor, num_output_maps=32, kernel_shape=(1, 1), kernel=conv1_w, bias=conv1_b)
    add1 = network.add_elementwise(conv1.get_output(0), conv2.get_output(0), trt.ElementWiseOperation.SUM)
    network.mark_output(tensor=add1.get_output(0))

def build_engine(weights):
    # For more information on TRT basics, refer to the introductory samples.
    mode = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(mode) as network:
        cfg = builder.create_builder_config()
        cfg.max_workspace_size = 1 << 30
        profile = builder.create_optimization_profile()
        profile.set_shape("data", [1, 16, 1, 1], [2, 16, 1, 1], [4, 16, 1, 1])
        cfg.add_optimization_profile(profile)
        # Populate the network using the randomly generated weights.
        pointwise_fusion_bug(network, weights)
        # Build and return an engine.
        return builder.build_engine(network, cfg)

def main():
    # Generate random 1x1 conv weights; no PyTorch training is needed.
    wconv = np.random.uniform(-5, 5, size=[32, 16, 1, 1]).astype(np.float32)
    weights = {"conv1.weight": wconv}
    # Building the engine is enough to trigger the error; no inference is run.
    with build_engine(weights) as engine:
        pass


if __name__ == '__main__':
    print(trt.__version__)
    main()

It passes in TRT 7.1.3.4 but fails in TRT 8.0.1.6.
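With a dynamic batch axis, the optimization profile must cover the -1 dimension and agree exactly with every static one. A small pure-Python sanity check can confirm that the profile in the script above is internally consistent, which suggests the error is not caused by a malformed profile. The helper `check_profile` is hypothetical and not a TensorRT API; TensorRT performs its own profile validation.

```python
def check_profile(declared, min_shape, opt_shape, max_shape):
    """Hypothetical check of a (min, opt, max) profile against a declared input shape.

    Dynamic dims (-1) need 0 <= min <= opt <= max; static dims must match exactly.
    """
    shapes = (min_shape, opt_shape, max_shape)
    if any(len(s) != len(declared) for s in shapes):
        return False
    for i, dim in enumerate(declared):
        lo, opt, hi = (s[i] for s in shapes)
        if dim == -1:
            if not 0 <= lo <= opt <= hi:
                return False
        elif not lo == opt == hi == dim:
            return False
    return True


# Values taken from the second repro script.
ok = check_profile((-1, 16, 1, 1), [1, 16, 1, 1], [2, 16, 1, 1], [4, 16, 1, 1])
print(ok)  # True
```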


Thank you for the confirmation. Please allow us some time to work on this.

Hi,

This issue has been fixed. Please try the latest TRT version.

Thank you.

Hi,

Is the latest TRT version 8.2.0?

Thank you.

Yes @xiazheng1996

Thank you!

I encounter the same error in the latest TensorRT 8.0 release (8.0.3.4). Why wasn't the issue fixed in TensorRT 8.0?
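Taken together, the thread reports that the repro passes on 7.1.3.4, fails on 8.0.1.6 and 8.0.3.4, and is reported fixed in 8.2.0. A hypothetical helper that encodes only those reports could be used to warn at startup, for example by calling `reported_status(trt.__version__)`. The name, the status strings, and the assumption that releases after 8.2.0 keep the fix are illustrative, not an NVIDIA compatibility statement.

```python
# Outcomes reported in this thread for the repro scripts (not an official compatibility list).
REPORTED = {
    "7.1.3.4": "passes",
    "8.0.1.6": "fails",
    "8.0.3.4": "fails",
}


def reported_status(version):
    """Hypothetical classifier for a TensorRT version string, based only on this thread."""
    if version in REPORTED:
        return REPORTED[version]
    major, minor = (int(p) for p in version.split(".")[:2])
    if (major, minor) >= (8, 2):
        return "reported fixed"  # assumption: the fix reported for 8.2.0 carries forward
    return "unknown"


print(reported_status("8.0.3.4"))  # fails
```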