Myelin error converting a TensorFlow 2.5.0 Object Detection API model to TensorRT on Jetson Nano

Hi, I am working on deploying MobileNetV2 on Jetson Nano and am having trouble converting the network from the TensorFlow 2.5.0 Object Detection API to TensorRT. I tried to convert the .onnx model to a TensorRT engine, but I get this Myelin error:

[06/21/2021-10:11:26] [V] [TRT] Formats and tactics selection completed in 157.841 seconds.
[06/21/2021-10:11:26] [V] [TRT] After reformat layers: 135 layers
[06/21/2021-10:11:26] [V] [TRT] Block size 16777216
[06/21/2021-10:11:26] [V] [TRT] Block size 8640000
[06/21/2021-10:11:26] [V] [TRT] Block size 3240448
[06/21/2021-10:11:26] [V] [TRT] Block size 1440256
[06/21/2021-10:11:26] [V] [TRT] Block size 540160
[06/21/2021-10:11:26] [V] [TRT] Block size 39424
[06/21/2021-10:11:26] [V] [TRT] Block size 7680
[06/21/2021-10:11:26] [V] [TRT] Block size 7680
[06/21/2021-10:11:26] [V] [TRT] Block size 2048
[06/21/2021-10:11:26] [V] [TRT] Block size 1536
[06/21/2021-10:11:26] [V] [TRT] Block size 1024
[06/21/2021-10:11:26] [V] [TRT] Block size 512
[06/21/2021-10:11:26] [V] [TRT] Block size 512
[06/21/2021-10:11:26] [V] [TRT] Block size 512
[06/21/2021-10:11:26] [V] [TRT] Total Activation Memory: 30699008
[06/21/2021-10:11:26] [I] [TRT] Detected 1 inputs and 4 output network tensors.
[06/21/2021-10:11:26] [E] [TRT] ../builder/myelin/codeGenerator.cpp (112) - Myelin Error in addNodeToMyelinGraph: 0 (map/while/TensorArrayV2Read/TensorListGetItem operation not supported within a loop body.)
[06/21/2021-10:11:26] [V] [TRT] Builder timing cache: created 72 entries, 22 hit(s)
[06/21/2021-10:11:26] [E] [TRT] ../builder/myelin/codeGenerator.cpp (112) - Myelin Error in addNodeToMyelinGraph: 0 (map/while/TensorArrayV2Read/TensorListGetItem operation not supported within a loop body.)
[06/21/2021-10:11:26] [E] Engine creation failed
[06/21/2021-10:11:26] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec # trtexec --explicitBatch --onnx=/home/varya/weights/tf22/model_float_nms.onnx --verbose

Environment

TensorRT Version : 7.2
NVIDIA GPU : Jetson Nano L4T 32.4.3; nvidia-jetpack: 4.4.1
Python Version : 3.6
TensorFlow Version : 2.3.1+nv20.12
onnx : 1.9.0
tensorflow-onnx : 1.9.0
onnx-graphsurgeon : 0.3.7

Relevant Files

onnx models:

Steps To Reproduce

To convert, I do the following:

  1. Convert the saved_model to an .onnx file:
    python -m tf2onnx.convert --saved-model /home/tf2/saved_model --output /home/tf2/model.onnx --opset 11
  2. Change the input data type from uint8 to float32:
import onnx
import numpy as np
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("/home/tf2/model.onnx"))
for inp in graph.inputs:
    inp.dtype = np.float32
onnx.save(gs.export_onnx(graph), "/home/tf2/model_float.onnx")
  3. Replace the ONNX NonMaxSuppression nodes with the BatchedNMSDynamic_TRT plugin:
import onnx
import numpy as np
import onnx_graphsurgeon as gs

print ("Running BatchedNMSDynamic_TRT plugin the ONNX model.. ")

input_model_path = "/home/tf2/model_float.onnx"
output_model_path = "/home/tf2/model_nms.onnx"

def replace_nms(graph, boxes, scores, nms_out):
    # Disconnect the original NonMaxSuppression subgraph so that
    # graph.cleanup() removes it later.
    scores.outputs.clear()
    nms_out.inputs.clear()
    nms_attrs = {'shareLocation': True,
                'backgroundLabelId': -1,
                'numClasses': 4,
                'topK': 20,
                'keepTopK': 20,
                'scoreThreshold': 0.3,
                'iouThreshold': 0.6,
                'isNormalized': True,
                'clipBoxes': True}
    
    nms_num_detections = gs.Variable(name="nms_num_detections", dtype=np.int32, shape=(1, 1))
    nms_boxes =  gs.Variable(name="nms_boxes", dtype=np.float32, shape=(1, 20, 4))
    nms_scores = gs.Variable(name="nms_scores", dtype=np.float32, shape=(1, 20))
    nms_classes = gs.Variable(name="nms_classes", dtype=np.float32, shape=(1, 20))

    # Reshape boxes to the [batch, num_boxes, 1, 4] layout the plugin expects.
    boxes_reshaped = gs.Variable(name="boxes_reshaped", dtype=np.float32, shape=(1, -1, 1, 4))
    boxes_shape = gs.Constant(name="boxes_shape", values=np.array([1, -1, 1, 4], dtype=np.int32))
    
    reshape = gs.Node(op="Reshape", inputs=[boxes, boxes_shape], outputs=[boxes_reshaped])
    graph.nodes.append(reshape)

    node = gs.Node(op="BatchedNMSDynamic_TRT", attrs=nms_attrs, 
                inputs=[reshape.outputs[0], scores], outputs=[nms_num_detections, nms_boxes, nms_scores, nms_classes])
    graph.nodes.append(node)
    graph.outputs = [nms_num_detections, nms_boxes, nms_scores, nms_classes]
    return graph

graph = gs.import_onnx(onnx.load(input_model_path))
graph.inputs[0].shape=[1,300,300,3]
tmap = graph.tensors()

# The tensor names below are specific to this exported model.
boxes, scores, nms_out = tmap["Unsqueeze__568:0"], tmap["Unsqueeze__671:0"], tmap["NonMaxSuppression__673:0"]
boxes.outputs.clear()
graph = replace_nms(graph, boxes, scores, nms_out)

scores, nms_out = tmap["Unsqueeze__637:0"], tmap["NonMaxSuppression__639:0"]
graph = replace_nms(graph, boxes, scores, nms_out)

scores, nms_out = tmap["Unsqueeze__603:0"], tmap["NonMaxSuppression__605:0"]
graph = replace_nms(graph, boxes, scores, nms_out)

scores, nms_out = tmap["Unsqueeze__569:0"], tmap["NonMaxSuppression__571:0"]
graph = replace_nms(graph, boxes, scores, nms_out)

graph.cleanup()
graph.toposort()
onnx.save_model(gs.export_onnx(graph), output_model_path)
  4. Convert to a TensorRT engine:
    trtexec --explicitBatch --onnx=/home/tf2/model_nms.onnx --verbose
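
For reference (not part of the original steps), a minimal diagnostic sketch, assuming only the onnx Python package: it walks the Loop nodes of the exported model and prints what runs inside each loop body, which is where the Myelin error points (map/while/.../TensorListGetItem).

import onnx

model = onnx.load("/home/tf2/model_nms.onnx")

# Print every node inside each Loop body subgraph to locate the op that
# Myelin rejects (a TensorListGetItem inside a loop body).
for node in model.graph.node:
    if node.op_type == "Loop":
        body = next(a.g for a in node.attribute if a.name == "body")
        print("Loop:", node.name)
        for n in body.node:
            print("  ", n.op_type, n.name)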

Hi,

Thanks for reporting this.

We are trying to reproduce it internally.
Will get back to you later.

Hi,

Thanks for your patience.
We can reproduce both issues (model_float.onnx & model_nms.onnx) in our environment.

The Myelin error might be caused by the NMS plugin.
We need more time to check this more deeply.

Will share more information with you later.
Thanks.

Hi,

We tested your model on both TensorRT v7.1 and v8.0.
Unfortunately, the error remains.

To check further, could you share which TensorFlow model you used?
Since there is a lot of variation in NMS layers, this will help us fix the issue.

Thanks.

Hi, thanks for the answer.
For training, I am using the SSD MobileNet v2 320x320 model from the TF2 Detection Model Zoo (models/tf2_detection_zoo.md at master · tensorflow/models · GitHub).
I have added my saved model to Yandex Drive.

Hi,

Thanks for sharing the file.
We are checking this internally. Will share more information with you later.

Thanks.

Hi,

Sorry for keeping you waiting.

We have root-caused the error when converting a TF2 SSD MobileNetV2 model into TensorRT.
The error is caused by an unsupported TopK layer in the model.

Currently, TensorRT v8.0 only supports TopK with a pre-defined (fixed) K value.
However, this model derives the K value from a runtime tensor output, which is dynamic,
so it cannot be converted to TensorRT.

Thanks.
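
For completeness, a possible (untested) workaround sketch, assuming onnx-graphsurgeon is available: pin the K input of every TopK node to a fixed constant so TensorRT sees a pre-defined K. The K value and output path are illustrative assumptions, not values confirmed in this thread, and detection results should be re-validated afterwards.

import onnx
import numpy as np
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("/home/tf2/model_float.onnx"))

FIXED_K = 100  # assumed upper bound on candidates to keep per image

# TopK takes K as its second input; replace the runtime tensor with a constant
# so the value is known at engine build time.
for i, node in enumerate(graph.nodes):
    if node.op == "TopK":
        node.inputs[1] = gs.Constant(
            name="fixed_k_{}".format(i),
            values=np.array([FIXED_K], dtype=np.int64),
        )

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "/home/tf2/model_fixed_k.onnx")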