ONNX to TensorRT with NVIDIA plugins (gridAnchor)

Hi, I’m trying to replicate the TensorRT implementations for the models shown here, https://github.com/AastaNV/TRT_object_detection,
but using tf2onnx instead of the UFF parser.

Is there any guidance on how to do this? I am mostly using the code discussed there as a guide.

My specific error comes from the GridAnchor plugin:

[06/27/2021-17:10:52] [F] [TRT] Assertion failed: numExpectedLayers == numLayers
gridAnchorPlugin.cpp:454

I’m using the same GridAnchor parameters as the UFF examples. Here’s the code. The input ONNX graph was created from an SSD_Inception_V2 frozen_graph.pb file using tf2onnx. The same .pb file can be converted to UFF and run with TensorRT as shown in https://github.com/AastaNV/TRT_object_detection.
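For reference, the tf2onnx conversion was along these lines (file names are placeholders, not my exact command; opset 13 matches the domain_version reported by the checker later in this thread):

python -m tf2onnx.convert --graphdef frozen_graph.pb --output model.onnx --inputs image_tensor:0 --outputs detection_boxes:0,detection_scores:0,detection_classes:0,num_detections:0 --opset 13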

import onnx_graphsurgeon as gs
import argparse
import onnx
import numpy as np


####
#### Model parameters
####
num_classes = 9
init_inputs = ["image_tensor:0"]
init_outputs = ["detection_boxes:0",
                "detection_scores:0", 
                "detection_classes:0", 
                "num_detections:0",]

def generate_prior_var_shapes(batch_size):
    # One entry per feature map level: [N, 2, H*W*boxesPerCell*4, 1].
    # Dim 1 holds (anchor coordinates, variances); the first level uses
    # 3 boxes per cell, the remaining levels use 6.
    prior_var_shapes = [[batch_size, 2, 3*19*19*4, 1],
                        [batch_size, 2, 6*10*10*4, 1],
                        [batch_size, 2, 6*5*5*4, 1],
                        [batch_size, 2, 6*3*3*4, 1],
                        [batch_size, 2, 6*2*2*4, 1],
                        [batch_size, 2, 6*1*1*4, 1]]
    return prior_var_shapes
prior_box_attrs = {
    'numLayers': 6,
    'minSize': 0.2,
    'maxSize': 0.95,
    'aspectRatios': [1.0, 2.0, 0.5, 3.0, 0.33],
    'variance': [0.1, 0.1, 0.2, 0.2],
    'featureMapShapes': [19, 10, 5, 3, 2, 1],
}
NMS_attrs = {
    'shareLocation': 1,
    'varianceEncodedInTarget': 0,
    'backgroundLabelId': 0,
    'confidenceThreshold': 1e-8,
    'nmsThreshold': 0.6,
    'topK': 100,
    'keepTopK': 100,
    'numClasses': 9,
    'inputOrder': [0, 2, 1],
    'confSigmoid': 1,
    'isNormalized': 1,
}
def create_and_add_plugin_node(graph, input_shape=None):
    if input_shape is None:
        # Read the NHWC shape from the graph input
        batch_size = graph.inputs[0].shape[0]
        input_h = graph.inputs[0].shape[1]
        input_w = graph.inputs[0].shape[2]
        channels = graph.inputs[0].shape[3]
    else:
        batch_size, input_h, input_w, channels = input_shape

    tensors = graph.tensors()

    prior_var_shapes = generate_prior_var_shapes(batch_size)

    grid_anchor_outs = []
    total_boxes = 0
    # One GridAnchor output tensor per feature map level
    for idx, var_shape in enumerate(prior_var_shapes):
        total_boxes += var_shape[2]
        grid_anchor_outs.append(
            gs.Variable(name="GridAnchor_{}_out".format(idx)).to_variable(
                dtype=np.float32, shape=var_shape))  # anchor/variance data is float32
    box_node = gs.Node(name="GridAnchor",
                       op="GridAnchor_TRT",
                       attrs=prior_box_attrs,
                       outputs=grid_anchor_outs)
    concat_priorbox_out = gs.Variable(name="concat_priorbox_out").to_variable(
        dtype=np.float32, shape=[batch_size, 2, total_boxes, 1])
    concat_priorbox = gs.Node(name="concat_priorbox", op="ConcatV2",
                              attrs={'axis': 2},
                              inputs=grid_anchor_outs,
                              outputs=[concat_priorbox_out])
    
    graph.nodes.append(concat_priorbox)
    graph.nodes.append(box_node)

    ###
    ### Below sets up the NMS plugin; not directly related to the gridAnchor error
    ###
    original_concat_box_loc = tensors["concat:0"]
    original_concat_box_conf = tensors["concat_1:0"]

    keepTopK = NMS_attrs['keepTopK']
    NMS_out_0 = gs.Variable(name="NMS_out_0").to_variable(
        dtype=np.float32, shape=[batch_size, 1, keepTopK, 7])
    NMS_out_1 = gs.Variable(name="NMS_out_1").to_variable(
        dtype=np.int32, shape=[batch_size, 1, 1, 1])
    nms_node = gs.Node(name="NMS",
                       op="NMS_TRT",
                       attrs=NMS_attrs,
                       inputs=[original_concat_box_loc, concat_priorbox_out, original_concat_box_conf],
                       outputs=[NMS_out_0, NMS_out_1])
    
    graph.nodes.append(nms_node)
    graph.outputs = [NMS_out_0]
    return graph.cleanup().toposort()


def main():
    parser = argparse.ArgumentParser(description="Add NMS")
    parser.add_argument("-onnx_path", "--onnx_path", 
                        help="Path to the ONNX model generated by export_model.py", 
                        )
    parser.add_argument("--input_shape", "--input_shape", default=None,type=int,nargs='+',
                    help="input shape as list (in NHWC order) e.g. '-1 300 300 3'", 
                    )
    args, _ = parser.parse_known_args()

    graph = gs.import_onnx(onnx.load(args.onnx_path))
    
    graph = create_and_add_plugin_node(graph,input_shape=args.input_shape)
    
    onnx.save(gs.export_onnx(graph), args.onnx_path + ".nms.onnx")


if __name__ == '__main__':
    main()
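Invocation looks like this (the script name is whatever you saved it as; add_nms.py here is a placeholder):

python add_nms.py --onnx_path model.onnx --input_shape -1 300 300 3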

Environment

TensorRT Version: 8.0
GPU Type: 1070
Nvidia Driver Version: 460.73.01
CUDA Version: 11.2
CUDNN Version: 8
Operating System + Version: Ubuntu 18.04.5 LTS
Python Version (if applicable): 3.7
TensorFlow Version (if applicable): 2.5

Hi,
Request you to share the ONNX model and the script, if not shared already, so that we can assist you better.
Alongside, you can try a few things:

  1. Validate your model with the snippet below:

check_model.py

import onnx

filename = "your_model.onnx"  # path to your ONNX model
model = onnx.load(filename)
onnx.checker.check_model(model)
  2. Try running your model with the trtexec command:
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/trtexec

In case you are still facing the issue, please share the trtexec --verbose log for further debugging.
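For example (the model filename here is a placeholder):

trtexec --onnx=model.nms.onnx --verbose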
Thanks!

The models are here (one source ONNX from before running the above script, and one from after (.nms.onnx)).

The source ONNX passes onnx.checker.check_model but fails trtexec. It’s the Postprocessor layer that we want to replace with the NMS plugin. (Note: I don’t think this specific error is important here; other parts of the Postprocessor have errors too.)

[06/27/2021-19:17:23] [V] [TRT] Registering layer: Postprocessor/Reshape for ONNX node: Postprocessor/Reshape
[06/27/2021-19:17:23] [E] Error[4]: Postprocessor/Reshape: volume mismatch. Input dimensions [2147483647,1917,4] have volume 16466904605196 and output dimensions [2147481731,4] have volume 8589926924.
[06/27/2021-19:17:23] [E] [TRT] ModelImporter.cpp:738: While parsing node number 299 [Reshape -> "Postprocessor/Reshape:0"]:
[06/27/2021-19:17:23] [E] [TRT] ModelImporter.cpp:739: --- Begin node ---
[06/27/2021-19:17:23] [E] [TRT] ModelImporter.cpp:740: input: "Postprocessor/Tile:0"
input: "const_fold_opt__7205"
output: "Postprocessor/Reshape:0"
name: "Postprocessor/Reshape"
op_type: "Reshape"

[06/27/2021-19:17:23] [E] [TRT] ModelImporter.cpp:741: --- End node ---
[06/27/2021-19:17:23] [E] [TRT] ModelImporter.cpp:744: ERROR: ModelImporter.cpp:197 In function parseGraph:
[6] Invalid Node - Postprocessor/Reshape
Postprocessor/Reshape: volume mismatch. Input dimensions [2147483647,1917,4] have volume 16466904605196 and output dimensions [2147481731,4] have volume 8589926924.

The NMS-plugged model (.nms.onnx) fails the onnx check_model; this was expected due to the addition of the TensorRT plugin ops:

    C.check_model(protobuf_string)
onnx.onnx_cpp2py_export.checker.ValidationError: No Op registered for GridAnchor_TRT with domain_version of 13

==> Context: Bad node spec for node. Name: GridAnchor OpType: GridAnchor_TRT

It also fails trtexec, giving the “Assertion failed: numExpectedLayers == numLayers gridAnchorPlugin.cpp:454” error described above.

@philminhnguyen,

This looks like a problem with the plugin. The gridAnchorPlugin is open source.
Please look at what the plugin expects in the source, for example whether static_cast<int>(fMapShapes.size()) >> (isFMapRect ? 1 : 0) equals numLayers.
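Paraphrasing the check around gridAnchorPlugin.cpp:454 (a sketch, not the exact upstream source):

// fMapShapes holds one value per square feature map, or (H, W) pairs for
// rectangular maps, hence the halving in the rectangular case
int numExpectedLayers = static_cast<int>(fMapShapes.size()) >> (isFMapRect ? 1 : 0);
ASSERT(numExpectedLayers == numLayers); // fires when fMapShapes is longer than expected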

Thank you.

How would you recommend I go about debugging this? gdb + trtexec doesn’t give any information on line numbers or local variables.

Hi @philminhnguyen,

If you build the TensorRT OSS repo, you will get a libnvinfer_plugin.so. Then just set LD_LIBRARY_PATH to load this libnvinfer_plugin.so instead of the one shipped with the TensorRT package.
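For example (the build path is illustrative; point it at the directory containing your freshly built libnvinfer_plugin.so):

export LD_LIBRARY_PATH=/path/to/TensorRT/build/out:$LD_LIBRARY_PATH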

Thank you.

Ah, that makes sense.
I’m currently working inside one of the TensorRT container images, and the LD_LIBRARY_PATH is

/usr/local/cuda/compat/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64

Seems like libnvinfer_plugin.so is located at

/usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so

So how is libnvinfer_plugin.so being found? It works fine with the TensorRT Python API. Is /usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so the correct file to replace?

Upon debugging, it seems that all list lengths are getting multiplied by 4, and I’m not quite sure why.

For example, the ‘featureMapShapes’ attribute of my GridAnchor_TRT plugin (as given in the Python code in my opening post) is:
[19, 10, 5, 3, 2, 1]

However, within the C++ code the length of that field (“fields[i].length”) is 24, so the resulting fMapShapes is:
{19, 10, 5, 3, 2, 1, 81, 0, 60757408, 22006, 60757576, 22006, 8, 0, 1769103734, 1701015137, 0, 0, 60757488, 22006, 60757504, 22006,
60757504, 22006}
The first 6 values match what is expected; the following 18 numbers are gibberish.
Similarly, the PluginFieldCollection (the list of fields passed to the plugin) is also 4x as long, so every list gets multiplied by 4 again. As a result, by the line where the assertion fails, fMapShapes has length 96 (6 x 4 x 4). This is what gdb shows:

(gdb) p fMapShapes
$15 = std::vector of length 96, capacity 96 = {19, 10, 5, 3, 2, 1, 81, 0, 60757408, 22006, 60757576, 22006, 8, 0, 1769103734, 1701015137, 0, 0, 60757488, 22006, 60757504, 22006,
60757504, 22006, 19, 10, 5, 3, 2, 1, 81, 0, 60757408, 22006, 60757576, 22006, 8, 0, 1769103734, 1701015137, 0, 0, 60757488, 22006, 60757504, 22006, 60757504, 22006, 19, 10, 5,
3, 2, 1, 81, 0, 60757408, 22006, 60757576, 22006, 8, 0, 1769103734, 1701015137, 0, 0, 60757488, 22006, 60757504, 22006, 60757504, 22006, 19, 10, 5, 3, 2, 1, 81, 0, 60757408,
22006, 60757576, 22006, 8, 0, 1769103734, 1701015137, 0, 0, 60757488, 22006, 60757504, 22006, 60757504, 22006}

@philminhnguyen,

Thank you for sharing. Please allow us some time to work on this.

Hi, I also tried to replicate the TensorRT import of TensorFlow object detection models using tf2onnx and the TensorRT ONNX parser, and ran into the same issue.

Adding the line mPluginAttributes.clear(); at the start of the constructor of GridAnchorBasePluginCreator in gridAnchorPlugin.cpp solved the issue for me.
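For context, the fix sits at the top of the creator constructor, roughly like this (a sketch of the open-source plugin, not the exact upstream code; the field list is abbreviated). mPluginAttributes is a static member shared by the GridAnchor creators, so without the clear() every creator construction appends another copy of the fields, which would explain the multiplied field lengths observed above:

// gridAnchorPlugin.cpp (sketch)
GridAnchorBasePluginCreator::GridAnchorBasePluginCreator()
{
    mPluginAttributes.clear(); // reset the shared static field list before re-filling it
    mPluginAttributes.emplace_back(PluginField("numLayers", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("minSize", nullptr, PluginFieldType::kFLOAT32, 1));
    mPluginAttributes.emplace_back(PluginField("maxSize", nullptr, PluginFieldType::kFLOAT32, 1));
    mPluginAttributes.emplace_back(PluginField("aspectRatios", nullptr, PluginFieldType::kFLOAT32, 1));
    mPluginAttributes.emplace_back(PluginField("featureMapShapes", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("variance", nullptr, PluginFieldType::kFLOAT32, 4));
    // ... remaining fields elided ...
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}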

However, after that, further issues occurred: the shapes of the plugin node’s output tensors could not be inferred, which caused a segfault in the parser.
In the end, it was necessary to implement IPluginV2DynamicExt versions of the GridAnchor and FlattenConcat plugins to be able to import them with the ONNX parser (which requires the explicitBatch flag to be set on the network).

I don’t know if this particular issue has been solved, but I now use modifications of this script to convert my models.

Hi,

The model has a GridAnchor_TRT plugin node with no inputs, which is invalid. There is also a bug in the plugin implementation: https://github.com/NVIDIA/TensorRT/tree/master/plugin/gridAnchorPlugin

Thank you.