klinten
January 11, 2021, 7:59am
I found a working solution.
The documentation fails to mention that `BatchedNMSPlugin` is modeled directly on TensorFlow's `CombinedNonMaxSuppression` op. Compare that op's inputs and attributes with the plugin README:
# batchedNMSPlugin
**Table Of Contents**
- [Description](#description)
  - [Structure](#structure)
- [Parameters](#parameters)
- [Algorithms](#algorithms)
- [Additional resources](#additional-resources)
- [License](#license)
- [Changelog](#changelog)
- [Known issues](#known-issues)
## Description
The `batchedNMSPlugin` implements a non-maximum suppression (NMS) step over boxes for object detection networks.
Non-maximum suppression is typically the universal step in object detection inference. This plugin is used after you’ve processed the bounding box prediction and object classification to get the final bounding boxes for objects.
With this plugin, you can incorporate the non-maximum suppression step during TensorRT inference. During inference, the neural network generates a fixed number of bounding boxes with box coordinates, identified class and confidence levels. Not all bounding boxes, but the most representative ones, have to be drawn on the original image.
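For readers unfamiliar with the step the README describes, here is a minimal C++ sketch of classic greedy, single-class NMS. It drops boxes whose score is at or below a score threshold, sorts the rest by score, and keeps a box only if its IoU with every already-kept box is at most the IoU threshold. This is the textbook algorithm, not the plugin's CUDA implementation; `Box`, `IoU`, and `GreedyNms` are hypothetical names, and the per-class and per-batch loops are omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Axis-aligned box given by its corners; assumes x1 <= x2 and y1 <= y2.
struct Box {
  float x1, y1, x2, y2;
};

// Intersection-over-union of two boxes (0 when they do not overlap).
float IoU(const Box& a, const Box& b) {
  const float iw = std::max(0.0f, std::min(a.x2, b.x2) - std::max(a.x1, b.x1));
  const float ih = std::max(0.0f, std::min(a.y2, b.y2) - std::max(a.y1, b.y1));
  const float inter = iw * ih;
  const float area_a = (a.x2 - a.x1) * (a.y2 - a.y1);
  const float area_b = (b.x2 - b.x1) * (b.y2 - b.y1);
  const float uni = area_a + area_b - inter;
  return uni > 0.0f ? inter / uni : 0.0f;
}

// Greedy NMS for one class: returns indices of kept boxes, highest score first,
// keeping at most max_keep boxes.
std::vector<std::size_t> GreedyNms(const std::vector<Box>& boxes,
                                   const std::vector<float>& scores,
                                   float score_threshold, float iou_threshold,
                                   std::size_t max_keep) {
  assert(boxes.size() == scores.size());
  // Candidates: boxes scoring strictly above the score threshold.
  std::vector<std::size_t> order;
  for (std::size_t i = 0; i < boxes.size(); ++i) {
    if (scores[i] > score_threshold) order.push_back(i);
  }
  // Visit candidates from highest to lowest score.
  std::stable_sort(order.begin(), order.end(), [&](std::size_t a, std::size_t b) {
    return scores[a] > scores[b];
  });
  std::vector<std::size_t> kept;
  for (std::size_t idx : order) {
    if (kept.size() >= max_keep) break;
    // Suppress the candidate if it overlaps any kept box too much.
    bool suppressed = false;
    for (std::size_t k : kept) {
      if (IoU(boxes[idx], boxes[k]) > iou_threshold) {
        suppressed = true;
        break;
      }
    }
    if (!suppressed) kept.push_back(idx);
  }
  return kept;
}
```

With two heavily overlapping boxes and one distant box, only the higher-scoring box of the overlapping pair survives, alongside the distant one.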
So I modified my TF model to use CombinedNMS, then wrote a script using ONNX GraphSurgeon that converts `CombinedNonMaxSuppression` nodes to `BatchedNMSDynamic_TRT` nodes, based on the following mapping from the tf2tensorrt code:
```cpp
// ... (excerpt; starts mid-function)
      return errors::InvalidArgument(
          "Node ", node_name,
          " with is neither Placeholder nor Arg, instead ", node_def.op());
    }
    DataType tf_dtype = node_def.attr().at(type_key).type();
    if (tf_dtype == DT_RESOURCE) {
      VLOG(2) << "Adding engine input resource " << node_name;
      TF_RETURN_IF_ERROR(converter->AddInputResource(
          node_name, ctx->input(slot_number).flat<ResourceHandle>()(0)));
    } else {
      nvinfer1::DataType trt_dtype;
      nvinfer1::Dims trt_dims;
      int batch_size = -1;
      const auto shape = input_shapes.at(slot_number);
      const auto status = ValidateTensorProperties(
          node_def.op(), node_def.attr().at(type_key).type(), shape,
          use_implicit_batch, /*validation_only=*/false, &trt_dtype,
          &trt_dims, &batch_size);
      if (!status.ok()) {
        const string error_message =
            StrCat("Validation failed for ", node_name, " and input slot ",
// ... (embed truncated here)
```
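A minimal C++ sketch of what a CombinedNonMaxSuppression to BatchedNMS parameter mapping can look like, assuming the op's scalar inputs (`max_output_size_per_class`, `max_total_size`, `iou_threshold`, `score_threshold`) are constants that have already been read. The struct and function names (`CombinedNmsAttrs`, `BatchedNmsParams`, `MapCombinedNms`, `PluginOutputIndexFor`) are hypothetical; the field names follow the batchedNMSPlugin README, and `keepTopK` follows the `pad_per_class` semantics in the TensorFlow op docs. The `topK` and `isNormalized` choices are assumptions, not necessarily what tf2tensorrt does. The output reordering reflects the op returning `valid_detections` last while the plugin, as listed in its README, returns `num_detections` first.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Shapes and attributes of a CombinedNonMaxSuppression node (hypothetical holder).
struct CombinedNmsAttrs {
  int64_t num_boxes;                  // boxes: [batch, num_boxes, q, 4]
  int64_t q;                          // 1 = boxes shared by all classes
  int64_t num_classes;                // scores: [batch, num_boxes, num_classes]
  int64_t max_output_size_per_class;  // scalar input
  int64_t max_total_size;             // scalar input
  float iou_threshold;                // scalar input
  float score_threshold;              // scalar input
  bool pad_per_class;                 // attr
  bool clip_boxes;                    // attr
};

// BatchedNMSDynamic_TRT plugin fields (names from the plugin README).
struct BatchedNmsParams {
  bool shareLocation;
  int backgroundLabelId;
  int numClasses;
  int topK;
  int keepTopK;
  float scoreThreshold;
  float iouThreshold;
  bool isNormalized;
  bool clipBoxes;
};

BatchedNmsParams MapCombinedNms(const CombinedNmsAttrs& a) {
  // CombinedNMS allows either shared boxes (q == 1) or class-specific boxes.
  if (a.q != 1 && a.q != a.num_classes) {
    throw std::invalid_argument("q must be 1 or num_classes");
  }
  BatchedNmsParams p{};
  p.shareLocation = (a.q == 1);
  p.backgroundLabelId = -1;  // CombinedNMS has no background class.
  p.numClasses = static_cast<int>(a.num_classes);
  // Assumption: every input box is an NMS candidate for each class.
  p.topK = static_cast<int>(a.num_boxes);
  // pad_per_class: pad to max_output_size_per_class * num_classes, clipped to
  // max_total_size; otherwise pad/clip to max_total_size.
  p.keepTopK = static_cast<int>(
      a.pad_per_class
          ? std::min(a.max_output_size_per_class * a.num_classes, a.max_total_size)
          : a.max_total_size);
  p.scoreThreshold = a.score_threshold;
  p.iouThreshold = a.iou_threshold;
  p.isNormalized = true;  // Assumption: box coordinates lie in [0, 1].
  p.clipBoxes = a.clip_boxes;
  return p;
}

// TF outputs:     0 nmsed_boxes, 1 nmsed_scores, 2 nmsed_classes, 3 valid_detections
// Plugin outputs: 0 num_detections, 1 nmsed_boxes, 2 nmsed_scores, 3 nmsed_classes
// Returns the plugin output index that replaces TF output tf_index.
int PluginOutputIndexFor(int tf_index) {
  assert(tf_index >= 0 && tf_index < 4);
  return (tf_index + 1) % 4;
}
```

In a GraphSurgeon script, the consumers of each replaced node's outputs would then be reconnected with this index mapping, so that whatever read `valid_detections` reads the plugin's `num_detections` instead.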