Description
Converting a trained model from TensorFlow 2.8 → ONNX → TensorRT. Part of the model computes the minimum between the desired number of results and the number of results actually available, and passes that value as the K parameter to tf.math.top_k (that is, K is computed at runtime).
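For reference, a minimal NumPy sketch of the dynamic-K pattern described above (the function name is illustrative; the actual model uses tf.math.top_k):

```python
import numpy as np

def top_k_dynamic(scores, desired_k):
    """Top-K where K is computed at runtime as
    min(desired number of results, results actually available)."""
    k = min(desired_k, scores.shape[-1])
    # Indices of the k largest scores, in descending order.
    idx = np.argsort(scores)[..., ::-1][..., :k]
    return np.take_along_axis(scores, idx, axis=-1), idx
```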
The conversion to ONNX completes without issues, and the ONNX model checker reports no errors. However, when loading the model with TensorRT (via trtexec) I get the following error:
[I] [TRT] ----------------------------------------------------------------
[I] [TRT] Input filename: networks/centernet/inference_model/model.onnx
[I] [TRT] ONNX IR version: 0.0.7
[I] [TRT] Opset version: 13
[I] [TRT] Producer name: tf2onnx
[I] [TRT] Producer version: 1.9.3
[I] [TRT] Domain:
[I] [TRT] Model version: 0
[I] [TRT] Doc string:
[I] [TRT] ----------------------------------------------------------------
[E] [TRT] parsers/onnx/ModelImporter.cpp:782: input: "StatefulPartitionedCall/Reshape:0"
input: "Unsqueeze__1342:0"
output: "scores"
output: "StatefulPartitionedCall/TopKV2:1"
name: "StatefulPartitionedCall/TopKV2"
op_type: "TopK"
attribute {
name: "sorted"
i: 1
type: INT
}
[E] [TRT] parsers/onnx/ModelImporter.cpp:783: --- End node ---
[E] [TRT] parsers/onnx/ModelImporter.cpp:785: ERROR: parsers/onnx/builtin_op_importers.cpp:4519 In function importTopK:
[8] Assertion failed: (inputs.at(1).is_weights()) && "This version of TensorRT only supports input K as an initializer."
My questions are as follows:
- Are there any plans for TensorRT supporting a dynamic K value (one that is computed at runtime) in the near future?
- Is it an error for K to be larger than the number of elements along the specified axis, or will TopK return as many elements as it can, up to K?
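In the meantime, here is a NumPy sketch of the workaround I am considering, assuming a fixed upper bound K_MAX can be baked into the graph as a constant (so the ONNX TopK node receives K as an initializer, which the parser accepts), with results beyond the runtime-valid count masked afterwards:

```python
import numpy as np

K_MAX = 5  # hypothetical fixed upper bound, exported as a constant/initializer

def top_k_fixed(scores, num_valid):
    # Always take a constant K_MAX elements, then mask out entries beyond
    # the number of results actually available at runtime.
    order = np.argsort(scores)[::-1][:K_MAX]
    top = scores[order]
    mask = np.arange(len(top)) < min(num_valid, K_MAX)
    return np.where(mask, top, -np.inf)
```

This trades the dynamic K for a fixed-size output plus a mask, which downstream code would have to account for.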
Environment
TensorRT Version: 8.2.4
GPU Type: NVIDIA GeForce RTX 3060 Mobile
Nvidia Driver Version: 510.60.02
CUDA Version: 11.6
CUDNN Version: 8.3.3
Operating System + Version: Arch Linux (kernel 5.17.4)
Python Version (if applicable): N/A
TensorFlow Version (if applicable): 2.8
PyTorch Version (if applicable): N/A
Baremetal or Container (if container which image + tag): N/A