Description
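When two 1x1 tensors (for example, the outputs of a global average/max pool) are each passed through a 1x1 convolution and the results are added, TensorRT 8.0.1 raises the error below; TensorRT 7.1.3.4 builds the same network without any problem.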
[TensorRT] ERROR: 2: [pointWiseV2Helpers.h::createTensorDesc::296] Error Code 2: Internal Error (Assertion tensor.extent.d[j] == 1 failed.)
There is no simple workaround for this bug; please fix it as soon as possible.
Environment
TensorRT Version: 8.0.1.6
GPU Type: GTX 1080
Nvidia Driver Version: 460.80
CUDA Version: 11.2
CUDNN Version: 8.2.1
Operating System + Version: Ubuntu 20.04
Python Version (if applicable): 3.8
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):
Steps To Reproduce
Run the following script from the samples/python/network_api_pytorch_mnist directory (it imports the samples' `common` helper module from the parent directory):
from PIL import Image
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt
import sys, os
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import common
# Minimum severity of TensorRT log messages that get printed. Raise it to
# suppress messages, or lower it (e.g. to trt.Logger.VERBOSE) to see more of
# the builder's output.
_LOG_SEVERITY = trt.Logger.WARNING
TRT_LOGGER = trt.Logger(_LOG_SEVERITY)
class ModelData(object):
    """Static configuration of the minimal reproduction network.

    The attribute names are inherited from the MNIST sample this script was
    adapted from; the values describe the network built by
    ``pointwise_fusion_bug``.
    """

    INPUT_NAME = "data"
    # Per-sample (C, H, W) shape: the network is created without the
    # explicit-batch flag, so the batch dimension is implicit. A 16-channel
    # 1x1 tensor stands in for the global-pool output described in the report.
    INPUT_SHAPE = (16, 1, 1)
    OUTPUT_NAME = "prob"
    # Both 1x1 convolutions use num_output_maps=32, so their element-wise sum
    # has 32 values per sample.
    OUTPUT_SIZE = 32
    DTYPE = trt.float32
def pointwise_fusion_bug(network, weights):
    """Populate ``network`` with the graph that triggers the TensorRT 8.0.1 error.

    Graph: input -> two parallel 1x1 convolutions (same kernel, no bias)
    -> element-wise SUM -> network output.

    Args:
        network: the TensorRT network definition to populate in place.
        weights: dict whose ``'conv1.weight'`` entry is the convolution kernel
            (a float32 array of shape (32, 16, 1, 1) as created in ``main``);
            it is used for both convolutions.
    """
    input_tensor = network.add_input(
        name=ModelData.INPUT_NAME,
        dtype=ModelData.DTYPE,
        shape=ModelData.INPUT_SHAPE,
    )
    kernel = trt.Weights(weights['conv1.weight'])
    no_bias = trt.Weights()  # empty Weights -> convolution without a bias term

    conv1 = network.add_convolution(
        input=input_tensor, num_output_maps=32, kernel_shape=(1, 1),
        kernel=kernel, bias=no_bias,
    )
    conv1.name = "conv1"
    conv2 = network.add_convolution(
        input=input_tensor, num_output_maps=32, kernel_shape=(1, 1),
        kernel=kernel, bias=no_bias,
    )
    conv2.name = "conv2"

    add1 = network.add_elementwise(
        conv1.get_output(0), conv2.get_output(0), trt.ElementWiseOperation.SUM
    )
    add1.name = "add1"

    output = add1.get_output(0)
    # Apply the output name declared in ModelData so the binding is identifiable.
    output.name = ModelData.OUTPUT_NAME
    network.mark_output(tensor=output)
def build_engine(weights):
    """Build a TensorRT engine for the reproduction network.

    Args:
        weights: dict of NumPy arrays forwarded to ``pointwise_fusion_bug``.

    Returns:
        The built TensorRT engine.

    Raises:
        RuntimeError: if ``builder.build_engine`` returns ``None``, i.e. the
            build failed (presumably the case when the internal assertion
            described in this report is hit; the details are in the log).
    """
    # create_network() without flags -> implicit-batch network, so input
    # shapes in ModelData exclude the batch dimension.
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network:
        config = builder.create_builder_config()
        config.max_workspace_size = common.GiB(1)
        pointwise_fusion_bug(network, weights)
        engine = builder.build_engine(network, config)

    # Fail loudly here instead of letting the caller's `with` on None raise an
    # unrelated AttributeError.
    if engine is None:
        raise RuntimeError(
            "TensorRT engine build failed; see the TensorRT log output above."
        )
    return engine
def main():
    """Build the reproduction engine, run it once and compare with NumPy."""
    # NOTE(review): description string is inherited from the MNIST sample.
    common.add_help(description="Runs an MNIST network using a PyTorch model")

    # No model is trained: random weights and input are enough, since the
    # report is about the engine build, not about numerical values.
    wconv = np.random.uniform(-5, 5, size=[32, 16, 1, 1]).astype(np.float32)
    x = np.random.uniform(-5, 5, size=[1, 16, 1, 1]).astype(np.float32)
    weights = {"conv1.weight": wconv}

    with build_engine(weights) as engine:
        # Allocate buffers and create a stream; see the introductory samples
        # for details on buffer allocation.
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        with engine.create_execution_context() as context:
            inputs[0].host[:] = x.reshape(-1)
            # common.do_inference returns a list of outputs; this network has one.
            [output] = common.do_inference(
                context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream
            )

    # Both convolutions use the same 1x1 kernel and no bias, so the summed
    # result is 2 * W @ x with W viewed as a (32, 16) matrix.
    expected = 2.0 * (wconv.reshape(32, 16) @ x.reshape(16))
    actual = np.asarray(output).reshape(-1)[: expected.size]
    max_abs_diff = float(np.max(np.abs(actual - expected)))
    status = "matches" if np.allclose(actual, expected, rtol=1e-3, atol=1e-2) else "MISMATCHES"
    print(f"TensorRT output {status} NumPy reference (max abs diff {max_abs_diff:.3g})")
if __name__ == "__main__":
    # Report which TensorRT build is being exercised before running the repro.
    trt_version = trt.__version__
    print(trt_version)
    main()