I trained a detection model in TensorFlow 2 with Quantization Aware Training (QAT) and then converted it to ONNX. We convert the model from ONNX to TensorRT (TRT) based on NVIDIA's sample code.
When converting this model to TRT, we get the following error:
[TensorRT] ERROR: ../builder/Network.cpp (1653) - Assertion Error in validateExplicitPrecision: 0 (layer.getNbInputs() == 2)
I see that trtexec handles the dynamic range by setting it for all layers in the ONNX model, as can be seen in setTensorScales.
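For reference, this is roughly how I understand the "set a range on every tensor" approach when done from Python. It is only a sketch: `apply_dynamic_ranges` and the `ranges` dict (tensor name to absolute bound) are names I made up for illustration, and I am assuming symmetric ranges. The only TensorRT calls it relies on are `num_inputs`, `get_input`, `num_layers`, `get_layer`, `num_outputs`, `get_output` and `ITensor.set_dynamic_range`, so it takes an already-parsed `INetworkDefinition` and does not import tensorrt itself.

```python
def apply_dynamic_ranges(network, ranges, default=None):
    """Set a symmetric dynamic range on every network input and layer output.

    network: a parsed tensorrt.INetworkDefinition (only duck-typed here).
    ranges:  hypothetical dict mapping tensor name -> absolute bound (float).
    default: bound used for tensors missing from `ranges`; None means skip.
    Returns (applied, missing) lists of tensor names for inspection.
    """
    # Collect network inputs first, then every layer output.
    tensors = [network.get_input(i) for i in range(network.num_inputs)]
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        for j in range(layer.num_outputs):
            tensors.append(layer.get_output(j))

    applied, missing = [], []
    for tensor in tensors:
        if tensor is None:  # optional outputs may be absent
            continue
        bound = ranges.get(tensor.name, default)
        if bound is None:
            missing.append(tensor.name)
            continue
        # Symmetric range [-bound, bound] is an assumption of this sketch.
        tensor.set_dynamic_range(-abs(bound), abs(bound))
        applied.append(tensor.name)
    return applied, missing
```

The `missing` list is there because the open question in this post is exactly which tensors need a range; printing it shows which tensors were left without one.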
I tried to mimic this logic in my Python conversion code, but it is not trivial to set the dynamic ranges correctly for the relevant layers because of the layer fusion that TensorRT applies during conversion. Should I set the dynamic range on the QDQ layers themselves? On the inputs of the QDQ layers? Or on the inputs of those inputs? My quantization ops were inserted after the pattern Conv2d -> BatchNorm -> Activation.
I tried extracting the quantization scale of each QDQ layer, calculating the dynamic range from it, and setting that range, but the model is sensitive to which layers get a dynamic range, so it is tricky to set the dynamic range for specific layers.
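To be concrete about the scale-to-range calculation, here is a minimal sketch of what I mean. It assumes symmetric per-tensor quantization with a zero point of 0 and a narrow int8 range of [-127, 127]; if the QAT graph uses the full [-128, 127] range or per-channel scales, the numbers would differ. `scale_to_dynamic_range` is just a helper name for illustration.

```python
def scale_to_dynamic_range(scale, num_bits=8):
    """Convert a per-tensor quantization scale into a symmetric (min, max) range.

    Assumes zero_point == 0 and a narrow signed range, i.e. the largest
    quantized magnitude is 2**(num_bits - 1) - 1 (127 for int8).
    """
    if scale <= 0:
        raise ValueError("quantization scale must be positive")
    qmax = 2 ** (num_bits - 1) - 1
    bound = scale * qmax
    return (-bound, bound)


# Example: a QDQ node with scale 0.5 covers [-63.5, 63.5].
print(scale_to_dynamic_range(0.5))
```

The resulting pair is what would be passed to `set_dynamic_range(min, max)` on whichever tensor turns out to be the right one.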
This dynamic range handling from the ONNX graph should be done automatically during conversion of a QAT model, and not by the customer, so that the ONNX model won't be altered in a bad way…
Please suggest a way for converting a QAT model from ONNX while utilizing the quantization scale for each layer so we can get its dynamic range correctly.
TensorRT Version : 7.1.2
CUDA Version : 11.0
Operating System + Version : Ubuntu 18.04
Python Version (if applicable) : 3.6
TensorFlow Version (if applicable) : The model was trained on TF 2.3, converted to ONNX, and then converted to a TensorRT engine.
I can attach a dummy ONNX model that helps to reproduce this issue:
basic_model_qat.py (956 Bytes)
dummy_qat_model_range_given_explicit_dims_batch_1.onnx (28.1 KB) for reproducing the issue with the conversion code I attached in the post (just run it without calibration, so that explicit int8 precision is set)