Calibration and int8 inference on ONNX model

Description

I use TensorRT to accelerate Inception v1 in ONNX format and get 67.5% top-1 accuracy in fp32 and 67.5% in fp16, but only 0.1% in int8 after calibration.

The image preprocessing for this model uses BGR format with mean subtraction [103.939, 116.779, 123.680]. Since TensorRT is not open-sourced, I have no idea what is going on inside the calibration tools. The images fed into the calibration tool should be in the same format as the ones used for inference, right?
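For reference, here is a minimal sketch of that preprocessing, assuming a 224x224 input and PIL/numpy; the function name and resize policy are assumptions, not taken from the attached script:

import numpy as np
from PIL import Image

def preprocess_bgr(img_path):
    # Resize, convert RGB -> BGR, subtract the per-channel mean, no further scaling.
    img = Image.open(img_path).convert('RGB').resize((224, 224), resample=Image.BILINEAR)
    arr = np.array(img, dtype=np.float32)[:, :, ::-1]                # HWC, RGB -> BGR
    arr -= np.array([103.939, 116.779, 123.680], dtype=np.float32)   # per-channel mean (B, G, R)
    arr = arr.transpose(2, 0, 1)                                     # HWC -> CHW
    return np.ascontiguousarray(arr[np.newaxis])                     # add batch dimension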

Did I do anything wrong in the calibration or inference? Or is this kind of unnormalized image format simply not friendly as input?

I attached my script, ONNX weights and calibration cache below.

https://drive.google.com/drive/folders/1niT1dvsUdHyfKoWWRwztcKL5V4Lp1UYd?usp=sharing

Could you help to inspect it? Thanks.

Environment

The environment is the one flashed by JetPack 4.4 on a Jetson AGX Xavier.

TensorRT Version: 7.1 (flashed by JetPack 4.4)
GPU Type: Jetson AGX Xavier integrated GPU
Nvidia Driver Version:
CUDA Version: 10.2
CUDNN Version:
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.6
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.2 aarch64, downloaded from this forum
Baremetal or Container (if container which image + tag):

Hi @yeahfd,
Your query has been noted. Please allow us some time to check on this.
Thanks!

Hi @yeahfd,
Looking at the calibration cache.

TRT-7103-EntropyCalibration
data_0: 5c810a14   ---> 2.90571e+17 

the decoded value seems to be wrong.
Can you please try IInt8EntropyCalibrator2 instead?
Thanks!
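The decoded value quoted above can be reproduced by interpreting the hex string in the cache as the big-endian IEEE-754 float32 bit pattern of the per-tensor calibration value; a quick sketch (not an official tool):

import struct

def decode_cache_value(hex_str):
    # Calibration-cache entries are big-endian IEEE-754 float32 values written in hex.
    return struct.unpack('>f', bytes.fromhex(hex_str))[0]

print(decode_cache_value('5c810a14'))  # ~2.90571e+17, the suspicious data_0 entry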

Actually I tried both versions of the calibrator.

Here is the IInt8EntropyCalibrator2 result

data_0: 5c810a14
conv1/7x7_s2_1: 5db1c61c
conv1/7x7_s2_2: 5dbee29b
pool1/3x3_s2_1: 5dbee29b
pool1/norm1_1: 5d9d2427
conv2/3x3_reduce_1: 5cf8a62a
conv2/3x3_reduce_2: 5d110825
conv2/3x3_1: 5d452ca4
conv2/3x3_2: 5cf63814
conv2/norm2_1: 5d15bde1
pool2/3x3_s2_1: 5d15bde1
inception_3a/1x1_1: 5c175b59
inception_3a/1x1_2: 5bcabc1d
inception_3a/3x3_reduce_1: 5bc0e6af
inception_3a/3x3_reduce_2: 5bc0e6af
inception_3a/3x3_1: 5c037974
inception_3a/3x3_2: 5bcabc1d
inception_3a/5x5_reduce_1: 5bda8670
inception_3a/5x5_reduce_2: 5bda8670
inception_3a/5x5_1: 5bac1839
inception_3a/5x5_2: 5bcabc1d
inception_3a/pool_1: 5d15bde1
inception_3a/pool_proj_1: 5c13b2b4
inception_3a/pool_proj_2: 5bcabc1d
inception_3a/output_1: 5bcabc1d
inception_3b/1x1_1: 5bb19575
inception_3b/1x1_2: 5bb0e2d4
inception_3b/3x3_reduce_1: 5b940d1d
inception_3b/3x3_reduce_2: 5b94d34a
inception_3b/3x3_1: 5bfb13a4
inception_3b/3x3_2: 5bb0e2d4
inception_3b/5x5_reduce_1: 5bc2cdda
inception_3b/5x5_reduce_2: 5bc8017d
inception_3b/5x5_1: 5c00c944
inception_3b/5x5_2: 5bb0e2d4
inception_3b/pool_1: 5bcabc1d
inception_3b/pool_proj_1: 5c1d55b7
inception_3b/pool_proj_2: 5bb0e2d4
inception_3b/output_1: 5bb0e2d4
pool3/3x3_s2_1: 5bb0e2d4
inception_4a/1x1_1: 5be0dcf0
inception_4a/1x1_2: 5ba7eb6f
inception_4a/3x3_reduce_1: 5bc9bd32
inception_4a/3x3_reduce_2: 5bc76092
inception_4a/3x3_1: 5c2a6110
inception_4a/3x3_2: 5ba7eb6f
inception_4a/5x5_reduce_1: 5bca9d5a
inception_4a/5x5_reduce_2: 5bcbdddd
inception_4a/5x5_1: 5c0a8457
inception_4a/5x5_2: 5ba7eb6f
inception_4a/pool_1: 5bb0e2d4
inception_4a/pool_proj_1: 5c60767a
inception_4a/pool_proj_2: 5ba7eb6f
inception_4a/output_1: 5ba7eb6f
inception_4b/1x1_1: 5b3f666d
inception_4b/1x1_2: 5b472a31
inception_4b/3x3_reduce_1: 5b0f8d0f
inception_4b/3x3_reduce_2: 5b1ee009
inception_4b/3x3_1: 5b2b7c60
inception_4b/3x3_2: 5b472a31
inception_4b/5x5_reduce_1: 5b3f98dc
inception_4b/5x5_reduce_2: 5b4d370e
inception_4b/5x5_1: 5b8f33d6
inception_4b/5x5_2: 5b472a31
inception_4b/pool_1: 5ba7eb6f
inception_4b/pool_proj_1: 5bb0975c
inception_4b/pool_proj_2: 5b472a31
inception_4b/output_1: 5b472a31
inception_4c/1x1_1: 5b65285b
inception_4c/1x1_2: 5b3765ca
inception_4c/3x3_reduce_1: 5b126f16
inception_4c/3x3_reduce_2: 5b126f16
inception_4c/3x3_1: 5b4170f6
inception_4c/3x3_2: 5b3765ca
inception_4c/5x5_reduce_1: 5b283f35
inception_4c/5x5_reduce_2: 5aed9c76
inception_4c/5x5_1: 5ae3f2f0
inception_4c/5x5_2: 5b3765ca
inception_4c/pool_1: 5b472a31
inception_4c/pool_proj_1: 5b974eee
inception_4c/pool_proj_2: 5b3765ca
inception_4c/output_1: 5b3765ca
inception_4d/1x1_1: 5b3a49fe
inception_4d/1x1_2: 5b047fa2
inception_4d/3x3_reduce_1: 5b2a319f
inception_4d/3x3_reduce_2: 5b01e841
inception_4d/3x3_1: 5b256afe
inception_4d/3x3_2: 5b047fa2
inception_4d/5x5_reduce_1: 5ae99132
inception_4d/5x5_reduce_2: 5aba5146
inception_4d/5x5_1: 5ae2f90a
inception_4d/5x5_2: 5b047fa2
inception_4d/pool_1: 5b3765ca
inception_4d/pool_proj_1: 5b766c30
inception_4d/pool_proj_2: 5b047fa2
inception_4d/output_1: 5b047fa2
inception_4e/1x1_1: 5a95d64b
inception_4e/1x1_2: 5a8f0ef3
inception_4e/3x3_reduce_1: 5a2f0475
inception_4e/3x3_reduce_2: 5a3269be
inception_4e/3x3_1: 5a96c39f
inception_4e/3x3_2: 5a8f0ef3
inception_4e/5x5_reduce_1: 5a8b6f43
inception_4e/5x5_reduce_2: 5aa36478
inception_4e/5x5_1: 5b0103a0
inception_4e/5x5_2: 5a8f0ef3
inception_4e/pool_1: 5b047fa2
inception_4e/pool_proj_1: 5b0f6a32
inception_4e/pool_proj_2: 5a8f0ef3
inception_4e/output_1: 5a8f0ef3
pool4/3x3_s2_1: 5a8f0ef3
inception_5a/1x1_1: 5a7e9c35
inception_5a/1x1_2: 5a46681e
inception_5a/3x3_reduce_1: 5a46b6cf
inception_5a/3x3_reduce_2: 5a49f5eb
inception_5a/3x3_1: 5a4048cd
inception_5a/3x3_2: 5a46681e
inception_5a/5x5_reduce_1: 5a4cbc76
inception_5a/5x5_reduce_2: 5a7a87fa
inception_5a/5x5_1: 5a45d620
inception_5a/5x5_2: 5a46681e
inception_5a/pool_1: 5a8f0ef3
inception_5a/pool_proj_1: 5b11b4e7
inception_5a/pool_proj_2: 5a46681e
inception_5a/output_1: 5a46681e
inception_5b/1x1_1: 59dcad3b
inception_5b/1x1_2: 59768900
inception_5b/3x3_reduce_1: 5a00ebaa
inception_5b/3x3_reduce_2: 5998eefa
inception_5b/3x3_1: 5957e92e
inception_5b/3x3_2: 59768900
inception_5b/5x5_reduce_1: 59d7afc3
inception_5b/5x5_reduce_2: 59655347
inception_5b/5x5_1: 59296929
inception_5b/5x5_2: 59768900
inception_5b/pool_1: 5a46681e
inception_5b/pool_proj_1: 5a16b618
inception_5b/pool_proj_2: 59768900
inception_5b/output_1: 59768900
pool5/7x7_s1_1: 59768900
pool5/7x7_s1_2: 585216d3
OC2_DUMMY_0: 585216d3
(Unnamed Layer* 142) [Constant]_output: 3a87a23a
OC2_DUMMY_2: 3a87a23a
(Unnamed Layer* 144) [Matrix Multiply]_output: 58905ec7
(Unnamed Layer* 145) [Constant]_output: 3c6b9c60
(Unnamed Layer* 146) [Shuffle]_output: 3c6b9c60
loss3/classifier_1: 58905ec7
(Unnamed Layer* 148) [Shuffle]_output: 58905ec7
(Unnamed Layer* 149) [Softmax]_output: 3c010a14
prob_1: 3c010a14

Hi @yeahfd,

Could you please manually check whether the batch data generated in "load_calibrate_data" is correct?
Also, please refer to the sample below for your reference.

Thanks

Hi @SunilJB @yeahfd,
Has this issue been solved?

The batch data used for calibration should be OK.
Actually the situation is as follows:

  • When I use the model from torchvision, whose preprocessing steps are to divide by 255, subtract the mean and divide by the std: if I do the same steps in IInt8EntropyCalibrator2, the final int8 model does not have good accuracy, while if I do no preprocessing at all, the result is good.
  • When I tried the model attached in this topic (I actually downloaded it from Cadence's ONNX model zoo), whose only preprocessing step in the original float32 ONNX environment is to subtract the mean: when I quantize it in TensorRT, I tried almost every kind of preprocessing policy in the calibration stage and none of them gave good accuracy.

Hi @yeahfd,
Have you tried calibration cache creation on detector models??

No.
Actually it is not the type of task that affects the accuracy, but the type of preprocessing used by the original model.
Since TensorRT is a black box, I just wonder whether there are any extra steps done inside the calibration tool.

Hi @yeahfd
So are you saying that if we change the classifier model to a detector model it should work?

Nope… I am just confused about the preprocessing stage. I think the type of task has no relationship with the problem.

Any quantization-related problem you get with a classifier will also exist when you switch to a detector, since the backbones of detectors are still these commonly used networks.

Hi @yeahfd, @AakankshaS @SunilJB
are you supplying labels for your calibration files??

Labels are not necessary for the calibration stage. Only images are fed into the model, to collect the activation (blob) tensor of each convolutional layer.

Oh ok

@yeahfd
I am seeing the same calibration behaviour with TRT 7.1.3. A calibrator that applies the normalisation used during training in PyTorch does not make any sense: it just sets all tensor ranges to [-inf, inf]. Whereas if you just divide the image array by 255, the calibrated model performs well. Did you manage to get any explanation?

Hi,
Request you to share the ONNX model and the script, if not shared already, so that we can assist you better.
Alongside, you can try a few things:

  1. Validate your model with the below snippet

check_model.py

import onnx
filename = "yourONNXmodel"  # path to your ONNX file
model = onnx.load(filename)
onnx.checker.check_model(model)
  2. Try running your model with the trtexec command.
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/trtexec
In case you are still facing the issue, request you to share the trtexec --verbose log for further debugging.
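A typical invocation could look like this (file names are placeholders):

trtexec --onnx=model.onnx --int8 --calib=calibration.cache --verbose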
Thanks!

@NVES
Thanks for your reply. ONNX model: https://drive.google.com/file/d/1JVUiIBysRZjAA0a9lK9vKjHWG88d-qls/view?usp=sharing; onnx.checker.check_model returns no errors.

The model was trained in PyTorch with [0, 1] input normalised with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
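For context, that training-time preprocessing presumably corresponds to the standard torchvision pipeline, roughly as below; the resize/crop sizes are assumptions:

from torchvision import transforms

# Assumed training-time preprocessing: PIL image -> [0, 1] tensor -> per-channel normalisation.
preprocess = transforms.Compose([
    transforms.Resize(256),            # resize/crop sizes are assumptions
    transforms.CenterCrop(224),
    transforms.ToTensor(),             # converts to float and scales pixels to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])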

My env:

'JETSON_JETPACK': '4.4.1'
'JETSON_L4T_REVISION': '4.4'
'JETSON_TENSORRT': '7.1.3.0'
'JETSON_CUDNN': '8.0.0.180'
'JETSON_CUDA': '10.2.89'

Calibration script:

import os
import glob
import yaml
import numpy as np

from PIL import Image

IMG_SUFFIX = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']

def get_img_arr(img_path, img_size, mean=None, std=None):

    if any([mean is not None, std is not None]):
        assert all([mean is not None, std is not None]), 'Provide both mean and std or neither'
 
    img_pil = Image.open(img_path).resize((img_size[1],img_size[0]), resample=Image.BILINEAR) # swap H,W to W,H for PIL
    img_arr = (np.array(img_pil)/255).astype(np.float32).transpose(2,0,1)
    if mean is not None and std is not None:
        img_arr = (img_arr - mean)/std
    img_arr = np.ascontiguousarray(np.expand_dims(img_arr,0))
    # img_arr = np.expand_dims(img_arr,0)

    return img_arr

class Dataset:

    def __init__(self, folder, img_size, batch_size, norm):
     
        images = []
        for suffix in IMG_SUFFIX:
            images.extend(glob.glob('{}/**/*{}'.format(folder, suffix)))
        self.images = images
        norm_cfg = os.path.join(folder, 'norm.yaml')
        self.norm = os.path.exists(norm_cfg) and norm
        self.mean = None
        self.std = None
        if self.norm:
            with open(norm_cfg) as f:
                cfg = yaml.safe_load(f)         
            self.mean = np.array(cfg['mean']).reshape(-1,1,1)
            self.std = np.array(cfg['std']).reshape(-1,1,1)
        self.img_max_nbytes = max([get_img_arr(img, img_size, mean=self.mean, std=self.std).nbytes for img in self.images])
        self.img_size = img_size

    def __iter__(self):

        return iter([get_img_arr(img, self.img_size, mean=self.mean, std=self.std) for img in self.images])

import os
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

from dataset import Dataset


class Calibrator(trt.IInt8EntropyCalibrator2):

    def __init__(self, data_folder, img_size, cache_file, cache_exist=False, batch_size=1, norm=False):

        trt.IInt8EntropyCalibrator2.__init__(self)
        self.cache_file = cache_file
        self.cache_exist = cache_exist
        self.batch_size = batch_size
        dataset = Dataset(data_folder, img_size, batch_size, norm=norm)
        self.batches = iter(dataset)
        self.device_input = cuda.mem_alloc(dataset.img_max_nbytes * self.batch_size)

    def get_batch_size(self):
        return self.batch_size
    
    def get_batch(self, names):
        try:
            # Assume self.batches is a generator that provides batch data.
            data = next(self.batches) # for calibration we take just img array
            print('IMG', data.shape, data.min(), data.max())
            # Assume that self.device_input is a device buffer allocated by the constructor.
            cuda.memcpy_htod(self.device_input, data)
            return [int(self.device_input)]
        except StopIteration:
            # When we're out of batches, we return either [] or None.
            # This signals to TensorRT that there is no calibration data remaining.
            return None
    
    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
        if self.cache_exist:
            print('READ FROM EXISTING CALIBRATION CACHE')
        # if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        else:
            print('NO EXISTING CALIBRATION CACHE. PERFORM CALIBRATION')

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)
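
For completeness, this calibrator is then attached to the builder roughly as follows. This is only a sketch against the TensorRT 7 Python API; the function name build_int8_engine, the workspace size and the file paths are placeholders rather than the exact values from my run:

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

def build_int8_engine(onnx_path, calibrator):
    # Explicit-batch network, as required for ONNX models in TensorRT 7.
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(explicit_batch) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser, \
         builder.create_builder_config() as config:
        with open(onnx_path, 'rb') as f:
            if not parser.parse(f.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
                return None
        config.max_workspace_size = 1 << 28        # 256 MiB, arbitrary
        config.set_flag(trt.BuilderFlag.INT8)
        config.int8_calibrator = calibrator        # calibration runs during engine build
        return builder.build_engine(network, config)

calibrator = Calibrator('calib_images/', (224, 224), 'classifier_int8_noqat.cache', norm=False)  # the class defined above
engine = build_int8_engine('classifier.onnx', calibrator)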

Ok, now two cases.

Case 1. Calibration without normalisation. Cache - classifier_int8_noqat_unnorm.cache (2.6 KB). Log - log_classifier_int8_noqat_unnorm.txt (1.8 MB).
As you can see from the log, during calibration the engine sees the input in the range [0, 1], i.e. not normalised. If we look further down the log, the scales and activation ranges make sense. But what is interesting is the input range, which is [-1.00393, 1.00393], while the engine sees inputs in the range [0, 1]. Is it doing some normalisation by default?

Case 2. Calibration with normalisation. Using mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]. Cache - classifier_int8_noqat_norm.cache (2.6 KB). Log - log_classifier_int8_noqat_norm.txt (1.8 MB).
We can see from the log that during calibration the engine sees the normalised input in the range [-value, +value]. But then the actual scales and ranges do not make any sense, with most of them set to [-inf, inf]. What is strange is the input range [-2.04692e+38, 2.04692e+38].

I have faced a similar problem on another detection task, and it seems to me that having the input normalised between 0 and 1 is the source of the problem in the first place if you plan to use int8 inference. int8 is an 8-bit integer type; by normalising the input data from 0-255 into 0-1 it is converted into small floating-point values, so already in the first convolution the results are fractions, and the int8 conversion loses a lot of information. Thus, I just tried keeping the input data in its original range (0-255), retrained the model (torch), did the torch-to-engine conversion, and the results after int8 calibration look reasonable to me compared to the float32 counterpart.
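For illustration, under that approach the only change needed in the get_img_arr helper above is dropping the division by 255; a hypothetical variant, assuming the model has been retrained on unscaled inputs:

import numpy as np
from PIL import Image

def get_img_arr_unscaled(img_path, img_size):
    # Same as get_img_arr, but keeps pixel values in the original [0, 255] range.
    img_pil = Image.open(img_path).resize((img_size[1], img_size[0]), resample=Image.BILINEAR)
    img_arr = np.array(img_pil).astype(np.float32).transpose(2, 0, 1)
    return np.ascontiguousarray(np.expand_dims(img_arr, 0))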