TensorRT conversion issues with an ONNX model trained with Quantization Aware Training + custom quantization scale

Description

I trained a detection model in TensorFlow 2 with Quantization Aware Training (QAT). The quantization ops were added after the following pattern:
Conv2D → BatchNorm → Activation (as the NVIDIA QAT guidelines recommend)
For a tensor x, I required the quantization scale to be computed per quantized tensor in the following way:
x = tf.quantization.quantize_and_dequantize(x, input_min=-127, input_max=127, range_given=False)
The argument range_given=False means input_min / input_max are ignored, so the range is taken from each quantized tensor’s own min/max values.
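For context, the insertion looked roughly like the following minimal Keras sketch (layer sizes and names are illustrative only, not the actual detection model):

import tensorflow as tf
from tensorflow.keras import layers

def quantized_conv_block(x, filters):
    # Conv2D -> BatchNorm -> Activation, then quantize-dequantize the
    # activation output; input_min/input_max are ignored because
    # range_given=False, so the range comes from the tensor itself.
    x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    return tf.quantization.quantize_and_dequantize(
        x, input_min=-127.0, input_max=127.0, range_given=False)

inputs = tf.keras.Input((224, 1408, 3))
outputs = quantized_conv_block(inputs, 16)
model = tf.keras.Model(inputs, outputs)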
tf2onnx recently added support for this mode, but it seems TensorRT doesn’t support such ONNX models: when converting the model to TRT, I get the following error while parsing the ONNX model:

[TensorRT] VERBOSE: ModelImporter.cpp:125: QuantLinearNode__20 [QuantizeLinear] inputs: [StatefulPartitionedCall/functional_3/tiny_yolov3/yolo_darknet/leaky_re_lu/LeakyRelu:0 -> (1, 16, 224, 1408)], [Max__18:0 -> ()], [zero_point__139 -> ()], 
ERROR: Failed to parse the ONNX file.
In node -1 (importQuantizeLinear): INVALID_NODE: Assertion failed: inputs.at(1).is_weights()

The structure of a quantized layer with range_given=True (using the same quant x_scale for all layers: 64/127 ~= 0.503)

The structure of a quantized layer with range_given=False (getting the quant x_scale from each tensor’s values):

I tried running onnx-graphsurgeon’s fold_constants function on the ONNX graph, but this didn’t help; I got the same error.
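The constant-folding attempt was essentially this (a minimal sketch; the file names are placeholders):

import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("model_qat.onnx"))
# fold_constants() only folds subgraphs that consist entirely of constants;
# here the QuantizeLinear scale is fed by a runtime Max node (Max__18 in the
# log above), so it stays dynamic and the parser error persists.
graph.fold_constants().cleanup()
onnx.save(gs.export_onnx(graph), "model_qat_folded.onnx")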
I would like to know whether there is a workaround for converting such an ONNX model to TensorRT without needing to update JetPack, etc.

Environment

TensorRT Version : 7.1.2
CUDA Version : 11.0
Operating System + Version : Ubuntu 18.04
Python Version (if applicable) : 3.6
TensorFlow Version (if applicable) : 2.3. The model was trained in TensorFlow 2.3, converted to ONNX, and then converted to a TensorRT engine.

I can’t share the relevant model for this.

Any help will be appreciated!

Hi,
Please share the ONNX model and the script, if not shared already, so that we can assist you better.
Meanwhile, you can try a few things:

  1. Validate your model with the snippet below:

check_model.py

import sys
import onnx

# Load the ONNX model and run the checker; it raises an exception if the
# model is malformed.
filename = sys.argv[1]  # path to your ONNX model
model = onnx.load(filename)
onnx.checker.check_model(model)
  2. Try running your model with the trtexec command.
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/trtexec
In case you are still facing the issue, please share the trtexec --verbose log for further debugging.
Thanks!

Hi @NVES ,
We already run onnx.checker.check_model(model) on our model when converting to ONNX, so this isn’t the issue.
I created a basic “dummy” model that reproduces the problem in this code:
basic_model_qat.py (917 Bytes)
I converted this to ONNX; this is the ONNX model:
dummy_qat_model_explicit_dims_batch_1.onnx (28.4 KB)
I ran trtexec and got the same error. Verbose log:
dummy_qat_trtexec_verbose_log.txt (16.1 KB)

Hi @weissrael,

Please try following.

  1. Do you have a particular reason to use x = tf.quantization.quantize_and_dequantize ? IIRC, TF QAT is able to insert FakeQuant nodes, i.e. you don’t need to manipulate the graph in most cases.
  2. The graph is different here: it looks like the quantization parameters need to be computed online when range_given=False . tf2onnx handles this, which is out of scope for TensorRT.
  3. onnx-graphsurgeon’s fold_constants doesn’t work because the quantization parameters are not constants in the graph.
  4. TensorRT currently only supports quantization parameters that are constants; this will be enhanced in the future. (The sketch after this list shows how to check which case applies to your model.)
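To make point 4 concrete, here is a small sketch (assuming onnx-graphsurgeon is installed; the file name refers to the attached dummy model) that reports whether each Quantize/DequantizeLinear scale is a constant:

import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("dummy_qat_model_explicit_dims_batch_1.onnx"))
for node in graph.nodes:
    if node.op in ("QuantizeLinear", "DequantizeLinear"):
        scale = node.inputs[1]
        # gs.Constant: the scale is baked into the graph as weights, which the
        # TensorRT 7 parser accepts. gs.Variable: the scale is computed at
        # runtime (e.g. fed by a Max node), which triggers the is_weights()
        # assertion shown in the error above.
        print(node.name, node.op, type(scale).__name__)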

Thank you.

Hi @spolisetty ,
I’m replying to each point separately:

  1. I’m following the NVIDIA guidelines for Quantization Aware Training: due to layer fusion, I insert the quantization op after the activation function in the pattern Conv2D → BatchNorm → Activation.
    Your TF samples pointed towards using tf.quantization.quantize_and_dequantize:
    a) NVIDIA’s SampleQAT links here: DeepLearningExamples/TensorFlow/Classification/ConvNets/resnet50v1.5 at master · NVIDIA/DeepLearningExamples · GitHub
    b) The following NVIDIA webinar also used tf.quantization.quantize_and_dequantize:
    http://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21664-toward-int8-inference-deploying-quantization-aware-trained-networks-using-tensorrt.pdf
    The reason I don’t use FakeQuant is that it seems to assume a constant quantization scale, the same as using x = tf.quantization.quantize_and_dequantize with range_given=True. I want to extract the quantization scale for each layer from its own values rather than use one constant quantization scale for all layers (see the sketch after this list). Does TensorRT support a custom quantization scale for each layer to begin with?

  2. tf2onnx indeed handles the quantization parameters for each layer as wanted. It seems that TensorRT currently does not support this setup of a custom quantization scale per layer, since it gives the error quoted in my post. Given that the conversion to ONNX works fine, this does seem to be in scope for TensorRT.

  3. Can onnx-graphsurgeon help somehow in the scenario I described, before converting the ONNX model to TensorRT?

  4. Do you have an estimate as to when this enhancement will be available?
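To illustrate the difference discussed in point 1, a minimal sketch of the two quantization modes (values are illustrative only):

import tensorflow as tf

x = tf.random.normal([1, 16, 224, 1408])

# (a) Fixed range shared by all layers, comparable to FakeQuant /
# range_given=True: the exported scale is a constant (64/127 here).
x_fixed = tf.quantization.quantize_and_dequantize(
    x, input_min=-64.0, input_max=64.0, range_given=True)

# (b) Per-tensor range taken from the tensor's own min/max
# (range_given=False, input_min/input_max ignored); tf2onnx exports the
# resulting scale as a computed tensor, which TensorRT 7 rejects.
x_dynamic = tf.quantization.quantize_and_dequantize(
    x, input_min=0.0, input_max=0.0, range_given=False)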

Hi @weissrael,

TensorRT supports per-tensor quantization parameters; please check the following API:
https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_tensor.html#a956f662b1d2ebe7ba3aba3391aedddf5
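For illustration only, a minimal sketch of the setDynamicRange route through the Python API (tensor names and amax values are placeholders; this assumes the network has already been populated, e.g. by the ONNX parser, without Q/DQ nodes):

import tensorrt as trt

# Hypothetical per-tensor ranges collected during QAT; keys must match the
# tensor names in the parsed network.
per_tensor_amax = {"conv1_out": 2.5, "leaky_relu_out": 4.0}

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
# ... populate `network` here ...

for i in range(network.num_layers):
    layer = network.get_layer(i)
    for j in range(layer.num_outputs):
        tensor = layer.get_output(j)
        if tensor.name in per_tensor_amax:
            amax = per_tensor_amax[tensor.name]
            # Per-tensor dynamic range; TensorRT derives the INT8 scale from it.
            tensor.set_dynamic_range(-amax, amax)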

For conversion from ONNX, the quantization parameters need to be constants, i.e. the scale input of QuantizeLinear/DequantizeLinear needs to be weights (an initializer) in ONNX.

We cannot provide an ETA for the enhancement of QAT support at this moment. Stay tuned for updates.

Thank you.