Description
I trained an SSD Resnet50_v1_coco model using the TensorFlow Object Detection API, then successfully converted the SavedModel to ONNX using tf2onnx.convert, but converting it to TensorRT failed because the "NonMaxSuppression" layer is apparently not supported in TensorRT. What is the most appropriate way to solve this?
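For reference, the conversion steps were roughly the following (the paths are placeholders for my actual files):

python3 -m tf2onnx.convert --saved-model exported_model/saved_model --output model.onnx
trtexec --onnx=model.onnx --saveEngine=model.engine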
I've tried modifying the ONNX model using onnx-graphsurgeon by replacing the unsupported layer with "BatchedNMS_TRT", with no success (the code is posted below).
Environment
TensorRT Version: 8
GPU Type: NVIDIA RTX 2060
Nvidia Driver Version: 460.73.01
CUDA Version: 11.2
CUDNN Version: 8
Operating System + Version: Ubuntu 20.04
Python Version (if applicable): 3.8.5
TensorFlow Version (if applicable): 2.4.1
Steps To Reproduce
import numpy as np
import onnx
import onnx_graphsurgeon as gs

# Placeholder paths; adjust to your files.
input_model_path = "model.onnx"
output_model_path = "model_batched_nms.onnx"

# Register a helper that disconnects the original NMS subgraph and wires
# the TensorRT BatchedNMS_TRT plugin in its place.
@gs.Graph.register()
def trt_batched_nms(self, boxes_input, scores_input, nms_output,
                    share_location, num_classes):
    # Detach the tensors from the nodes they currently connect to.
    boxes_input.outputs.clear()
    scores_input.outputs.clear()
    nms_output.inputs.clear()
    attrs = {
        "shareLocation": share_location,
        "numClasses": num_classes,
        "backgroundLabelId": 0,
        "topK": 116740,
        "keepTopK": 100,
        "scoreThreshold": 0.3,
        "iouThreshold": 0.6,
        "isNormalized": True,
        "clipBoxes": True
    }
    return self.layer(op="BatchedNMS_TRT", attrs=attrs,
                      inputs=[boxes_input, scores_input],
                      outputs=[nms_output])

graph = gs.import_onnx(onnx.load(input_model_path))

# Fix the batch dimension of the image input.
graph.inputs[0].shape = [1, 320, 320, 3]
print(graph.inputs[0])
print(graph.inputs[0].shape)

# np.int is a deprecated alias for the builtin int; TF OD API image
# inputs are typically uint8.
for inp in graph.inputs:
    inp.dtype = np.uint8

tmap = graph.tensors()
graph.trt_batched_nms(tmap["StatefulPartitionedCall/Postprocessor/BatchMultiClassNonMaxSuppression/MultiClassNonMaxSuppression/unstack__830:0"],
                      tmap["Unsqueeze__833:0"],
                      tmap["NonMaxSuppression__837:0"],
                      share_location=True,
                      num_classes=4)

graph.cleanup()
graph.toposort()

onnx.checker.check_model(gs.export_onnx(graph))
onnx.save_model(gs.export_onnx(graph), output_model_path)
When performing onnx.checker.check_model, an error pops up: the Unsqueeze op apparently has an invalid input shape = [1], while the min limit and the max limit are both equal to 2. Any recommendations?
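In case it helps, this is a quick sketch (an assumption on my part, not a fix) of how I locate the offending node right before the export step above, by printing every Unsqueeze node with its inputs:

# Print each Unsqueeze node's input names, shapes, and dtypes to see
# which one the checker is complaining about.
for node in graph.nodes:
    if node.op == "Unsqueeze":
        print(node.name, [(t.name, t.shape, t.dtype) for t in node.inputs])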
Thanks in advance.