TensorRT's OnnxParser problem

EDIT Dec 3rd, 11:32 (GMT+1):

  • Fixed some mistakes in the description of the problem
  • The final command given to reproduce the error was launching the wrong script (which actually worked). Added the correct command.

Description

Hello,

I’ve spent the past few days trying to convert a PyTorch model into a TensorRT engine for fast inference, but I always get some strange errors related to the output of the model.

    # Initialize ONNX parser
    parser = trt.OnnxParser(network, TRT_LOGGER)
    
    # Parse ONNX model
    print('TENSORRT CONVERTER: Parsing ONNX model...')
    with open(self.onnx_path, 'rb') as model:
        parser.parse(model.read())
    print('TENSORRT CONVERTER: Completed parsing of ONNX file.')

    last_layer = network.get_layer(network.num_layers - 1)
    # Check if the last layer recognizes its output
    if not last_layer.get_output(0):
        print("TENSORRT CONVERTER: Manually marking output...")
        # If not, then mark the output using TensorRT API
        network.mark_output(last_layer.get_output(0))

    print("Network layers: {}".format(network.num_layers))
    print("Num inputs: {}".format(network.num_inputs))
    print("Num outputs: {}".format(network.num_outputs))
    print("Output shape: {}".format(last_layer.get_output(0).shape))
    print("Output: {}".format(network.get_output(0)))

My output is:

TENSORRT CONVERTER: Parsing ONNX model...
[TensorRT] WARNING: onnx2trt_utils.cpp:220: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
TENSORRT CONVERTER: Completed parsing of ONNX file.
Network layers: 586
Num inputs: 1
Num outputs: 0
Output shape: (1, 1, 60, 80)
Output: None
TENSORRT CONVERTER: Using FP16.
TENSORRT CONVERTER: Building conversion engine...
[TensorRT] ERROR: Network must have at least one output
[TensorRT] ERROR: Network validation failed.
TENSORRT CONVERTER: Completed creating engine!
TENSORRT CONVERTER: Serializing engine on models/bts_161_8m_640x480_pre_nyu_optimized_fp16.engine
Traceback (most recent call last):
  File "pytorch_to_tensortrt/convert_to_TensorRT.py", line 131, in <module>
    trt_c.convert_model()
  File "pytorch_to_tensortrt/convert_to_TensorRT.py", line 50, in convert_model
    f.write(engine.serialize())
AttributeError: 'NoneType' object has no attribute 'serialize'

Now, apart from the AttributeError, there are some strange things going on here:

  • The number of layers is lower than expected
  • The network has apparently 0 outputs, but the last layer does have one output of shape (1,1,60,80)
  • The output of the last layer has an unexpected size. The output of the model should be [1,1,480,640]

Honestly, it looks like the parser stopped at a certain layer, but never informed me of any errors.
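
(In hindsight, parse() returns a boolean and the parser keeps a list of errors, so they can be printed explicitly instead of being lost. A minimal sketch adapted from the TensorRT Python samples, assuming the same network and parser objects as above:)

    # Parse the ONNX model and report parser errors instead of failing silently
    with open(self.onnx_path, 'rb') as model:
        if not parser.parse(model.read()):
            print('TENSORRT CONVERTER: Failed to parse the ONNX file')
            for i in range(parser.num_errors):
                print(parser.get_error(i))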

The original model I’m trying to convert is this one: bts/pytorch at master · cleinc/bts · GitHub. I’ve made some minor modifications, such as reducing the number of outputs to 1 (I just don’t need the other outputs) and replacing the repeat_interleave torch function, which is not yet supported by ONNX, with a similar sequence of functions that returns the same result (a sketch of the idea is below).
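
For context, repeat_interleave with a scalar repeat count can be reproduced with an unsqueeze/expand/reshape sequence, which ONNX export handles fine. A rough sketch of the idea (illustrative only; the exact code in the attached scripts may differ):

    import torch

    def repeat_interleave_like(x, repeats, dim):
        # Insert a new axis after `dim`, broadcast it to `repeats`, then fold it back in
        expand_shape = list(x.shape)
        expand_shape.insert(dim + 1, repeats)
        out_shape = list(x.shape)
        out_shape[dim] *= repeats
        return x.unsqueeze(dim + 1).expand(*expand_shape).reshape(out_shape)

    # Same result as torch.repeat_interleave for a scalar repeat count
    x = torch.arange(6).reshape(2, 3)
    assert torch.equal(repeat_interleave_like(x, 2, dim=1),
                       torch.repeat_interleave(x, 2, dim=1))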

Then, I wrote some scripts to load a checkpoint for the model, convert it to ONNX format and save the result to disk.

For the conversion to ONNX, I use:

torch.onnx.export(self.model, self.inputs, self.ONNX_FILE_PATH,
                  input_names=['input'], output_names=['output'],
                  export_params=True, opset_version=11)

A lower opset results in the following warning:

/home/volpepe/anaconda3/lib/python3.7/site-packages/torch/onnx/symbolic_helper.py:267: UserWarning: You are trying to export the model with onnx:Upsample for ONNX opset version 9. This operator might cause results to not match the expected results by PyTorch.
ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. Attributes to determine how to transform the input were added in onnx:Resize in opset 11 to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).
We recommend using opset 11 and above for models using this operator. 
  "" + str(_export_onnx_opset_version) + ". "

and several errors after that, so I’d rather stick to opset 11, if possible.

I then simplified the ONNX model with this tool: GitHub - daquexian/onnx-simplifier (I mention this step for completeness, but the same error happens even without the simplification).

I assume that the conversion to ONNX was successful for the following reasons:

  • After the export, I check the model with:
            onnx_model = onnx.load(self.ONNX_FILE_PATH)
            onnx.checker.check_model(onnx_model)
    
  • I have built a small test with onnxruntime that loads both the PyTorch model and the generated ONNX translation, runs inference on them with the same random input and checks whether the outputs match. They do match, up to small numerical differences (roughly the check sketched after this list).
  • Netron (https://netron.app/) correctly shows the outline of the network when loading the .onnx model. It also shows the correct input and output shapes.
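
Roughly, that onnxruntime check looks like this (illustrative sketch: the input name matches the export call above, while the input shape and model loading are assumptions; the real script is in the attached zip):

    import numpy as np
    import onnxruntime as ort
    import torch

    def compare_torch_vs_onnx(model, onnx_path, input_shape=(1, 3, 480, 640)):
        # Run the same random input through PyTorch and ONNX Runtime and compare outputs
        model.eval()
        dummy = torch.randn(*input_shape)
        with torch.no_grad():
            torch_out = model(dummy).cpu().numpy()

        sess = ort.InferenceSession(onnx_path)
        ort_out = sess.run(None, {'input': dummy.numpy()})[0]

        print('Max abs difference:', np.abs(torch_out - ort_out).max())
        np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5)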

Therefore, I believe it must be a problem with TensorRT’s OnnxParser or some incompatibility.

bts_trt.zip (407.8 KB)

Environment

TensorRT Version: 7.2.1.6
GPU Type: GeForce 1660 Ti Mobile
Nvidia Driver Version: 455.45.01
CUDA Version: 11.1
CUDNN Version: 8.0.5
Operating System + Version: Ubuntu 20.04
Python Version (if applicable): 3.7.7
TensorFlow Version (if applicable): None
PyTorch Version (if applicable): 1.7.0
Baremetal or Container (if container which image + tag): Baremetal

Relevant Files and Steps to Reproduce

I have attached a zip with all relevant scripts. The weights for the model are downloaded from the original repository (https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_densenet161.zip). Once downloaded, they should be put into the “models” folder. bts_saver.py can then be run in order to produce the .pth file. Then,

     python pytorch_to_tensortrt/convert_to_ONNX.py -m models/bts_161_8m_640x480_pre_nyu.pth -n bts_161_8m_640x480_pre_nyu -o models

to create the .onnx file.

python -m onnxsim models/bts_161_8m_640x480_pre_nyu.onnx models/bts_161_8m_640x480_pre_nyu_optimized.onnx

is the additional step to simplify the model with onnx-simplifier (linked above).

Finally, the error is produced when running:

python pytorch_to_tensortrt/convert_to_TensorRT.py -m models/bts_161_8m_640x480_pre_nyu_optimized.onnx -o models -n bts_161_8m_640x480_pre_nyu_optimized -t fp32

Hi @volpepe,
Can you please share your onnx model as well with us?

Thanks!

Of course; sorry I didn’t think of that sooner, but it was late and my connection was slow.

Here’s a Google Drive link to the zip containing both the .onnx model generated by torch.onnx.export and the “simplified” version (has “_optimized” in its file name).

Hi @volpepe,
When I tried running your model on the latest TRT release, I got the error below:

[error screenshot]

For that, you can refer to the link below:
https://github.com/onnx/onnx-tensorrt/issues/562#issuecomment-735436875

Thanks!

Hi again,

I think I fixed this problem by rewriting the conversion function like this:

def convert_model(self):
    #save the model
    print("ONNX CONVERTER: Saving the model...")
    torch.onnx.export(self.model, self.inputs, self.ONNX_FILE_PATH,
                    input_names=['input'], output_names=['output'],
                    export_params=True, opset_version=11,)
                    #keep_initializers_as_inputs=True)
    # Load the exported model back so we can patch the graph
    onnx_model = onnx.load(self.ONNX_FILE_PATH)

    if self.model_name[:3] == 'bts':
        # Fix Clip layers: give every Clip node an explicit 'max' input

        original_graph = onnx_model.graph
        original_nodes = onnx_model.graph.node
        fixed_nodes = []
        initializers = onnx_model.graph.initializer
        initializers.append(onnx.helper.make_tensor(
            'max', onnx.TensorProto.FLOAT16, [1], [1]
        ))

        for i in range(len(original_nodes)):
            node = original_nodes[i]
            if node.op_type == "Clip":
                # Three inputs: input, min and max
                inputs = [node.input[0], node.input[1], 'max']
                outputs = node.output
                node = onnx.helper.make_node('Clip',
                            name=node.name, inputs=inputs, outputs=outputs) 
            fixed_nodes.append(node)

        final_graph = onnx.helper.make_graph(
            fixed_nodes, 'bts_optimized_graph', inputs=original_graph.input,
            outputs=original_graph.output, initializer=initializers
        )

        onnx_model = onnx.helper.make_model(final_graph, producer_name='volpepe')

    onnx.checker.check_model(onnx_model)
    print("ONNX CONVERTER: Model saved!")

    onnx.save(onnx_model, self.ONNX_FILE_PATH)

I basically set 1 (chosen arbitrarily) as the maximum in all Clip layers. I had never used ONNX before and I don’t know whether this is a correct solution, but I stopped getting the problem I mentioned above.

That said, I have another problem now. I’ll leave the traceback here just in case, but I understand that it’s an ONNX problem more than a TensorRT or even NVIDIA-related one.

ONNX CONVERTER: Saving the model...
/home/volpepe/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py:3103: UserWarning: The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. 
  warnings.warn("The default behavior for interpolate/upsample with float scale_factor changed "
Traceback (most recent call last):
  File "pytorch_to_tensortrt/convert_to_ONNX.py", line 92, in <module>
    OnnxConverter(args.model_path, args.save_path, args.model_name).convert_model()
  File "pytorch_to_tensortrt/convert_to_ONNX.py", line 84, in convert_model
    onnx.checker.check_model(onnx_model)
  File "/home/volpepe/anaconda3/lib/python3.7/site-packages/onnx/checker.py", line 102, in check_model
    C.check_model(protobuf_string)
onnx.onnx_cpp2py_export.checker.ValidationError: Node (Unsqueeze_575) has input size 1 not in range [min=2, max=2].

==> Context: Bad node spec: input: "1743" output: "1744" name: "Unsqueeze_575" op_type: "Unsqueeze" attribute { name: "axes" ints: 1 type: INTS }

I’d be glad if you could help; otherwise, I’ll consider this solved.

I finally fixed it and managed to produce a TensorRT FP16 engine.

The solution was downgrading ONNX to version 1.6.0 through pip.
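
For reference, the downgrade is just:

    pip install onnx==1.6.0

If I understand correctly, onnx.helper.make_model without an explicit opset_imports stamps the rebuilt graph with the default opset of the installed onnx package. With a newer onnx that default is opset 13, where Unsqueeze takes its axes as a second input, so the checker complained about my opset-11-style nodes; with onnx 1.6.0 the default is opset 11 and the check passes.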

I also had to change the initialization of the max tensor from the previous reply from:

initializers.append(onnx.helper.make_tensor(
    'max', onnx.TensorProto.FLOAT16, [1], [1]
))

to:

initializers.append(onnx.helper.make_tensor(
    'max', onnx.TensorProto.FLOAT, dims=[], vals=[1]
))

Again, 1 was chosen as an arbitrary max value, so it’s probably not perfect, but three wrong layers in an 821-layer network is fine, I guess.

Thanks for the support.