TensorRT INT8 Calibration Issue

Description

When calibrating for INT8 optimization, I get an error: F tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:432] Check failed: t.TotalBytes() == device_tensor->TotalBytes() (372 vs. 332)

I am able to optimize for FP32 and FP16 with good results, but when I try INT8 I get the error.

The issue seems to be related to the calibration dataset or how it is being loaded. I have included the code I run, minus the saved_model and the calibration dataset, which I cannot share due to IP restrictions. The standard EfficientDet-D0 can be used to replicate the issue. The calibration dataset is just validation data saved as JPEGs at the training resolution of 512x512.

In the code below there are two functions, input_fn_works and input_fn_doesnt_work. input_fn_works uses random data and doesn't throw the error; however, the resulting optimized model doesn't perform at all (I get scores of less than 0.01 for all bboxes). input_fn_doesnt_work throws the check-failed error when it reaches the second image in the load sequence. I have tried various ways of loading the data and none of them change the outcome. Loading the images in a different order just changes the numbers in the error from (372 vs. 332) to some other pair.
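
As a sanity check, a loop like the following (the path is a placeholder, as in the code below) can confirm that every decoded batch ends up with the same shape and dtype:

import glob
import tensorflow as tf

# Placeholder path: every calibration batch should share one shape and dtype
shapes = set()
for filename in glob.glob('/path/to/saved/images/*.jpg')[:500]:
    image = tf.image.decode_jpeg(tf.io.read_file(filename), channels=3)
    image = tf.image.resize(image, size=(512, 512))
    image = tf.cast(image, tf.uint8)[tf.newaxis, ...]
    shapes.add((tuple(image.shape), image.dtype.name))
print(shapes)  # expected: {((1, 512, 512, 3), 'uint8')}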

Any ideas of what is going wrong?

Environment

TensorRT Version: 6.0.1
GPU Type: Titan V
Nvidia Driver Version: 455
CUDA Version: 10.1
CUDNN Version: 7.6.5
Operating System + Version: Ubuntu 18
Python Version (if applicable): 3.6
TensorFlow Version (if applicable): TensorFlow 2
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):

Relevant Files

import glob
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from PIL import Image

def config_gpu_memory(gpu_mem_cap):
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if not gpus:
        return
    print('Found the following GPUs:')
    for gpu in gpus:
        print('    ', gpu)
    for gpu in gpus:
        try:
            if not gpu_mem_cap:
                # No cap requested: allow GPU memory to grow as needed
                tf.config.experimental.set_memory_growth(gpu, True)
            else:
                # Cap GPU memory usage at gpu_mem_cap megabytes
                tf.config.experimental.set_virtual_device_configuration(
                        gpu,
                        [tf.config.experimental.VirtualDeviceConfiguration(
                                memory_limit=gpu_mem_cap)])
        except RuntimeError as e:
            print('Can not set GPU memory config', e)
            
def get_trt_conversion_params(max_workspace_size_bytes,
                              precision_mode,
                              minimum_segment_size,
                              max_batch_size):
    # use_calibration is only meaningful for INT8; TF-TRT then runs a
    # calibration pass using data from the calibration_input_fn
    conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
            max_workspace_size_bytes=max_workspace_size_bytes,
            precision_mode=precision_mode,
            minimum_segment_size=minimum_segment_size,
            use_calibration=(precision_mode == 'INT8'),
            max_batch_size=max_batch_size)
    return conversion_params

def get_func_from_saved_model(saved_model_dir):
    saved_model_loaded = tf.saved_model.load(
            saved_model_dir, tags=[tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    return graph_func

# It seems to process the first image, then throws the error on the second
def input_fn_doesnt_work():
    input_size = (512, 512)

    filenames = glob.glob('/path/to/saved/images/*.jpg')
    for i in range(500):
        # Decode, resize, and cast each image to match the model's uint8 input
        batched_input = tf.io.read_file(filenames[i])
        batched_input = tf.image.decode_jpeg(batched_input, channels=3)
        batched_input = tf.image.resize(batched_input, size=input_size)
        batched_input = tf.cast(batched_input, tf.uint8)
        batched_input = tf.expand_dims(batched_input, axis=0)
        yield (batched_input,)

# Works and doesn't crash, but output results are not good. AP after this is 0;
# most bboxes have a score of < 0.1
def input_fn_works():
    input_size = (512, 512)

    for i in range(500):
        # Random uint8 data; note that np.random.random() yields values in
        # [0, 1), which all cast to 0 as uint8, so use randint for 0-255
        batched_input = np.random.randint(
                0, 256, size=(1, input_size[0], input_size[1], 3), dtype=np.uint8)
        batched_input = tf.constant(batched_input)
        yield (batched_input,)
        
config_gpu_memory(0)
# 1 GiB workspace, INT8 precision, minimum_segment_size=2, max_batch_size=1
conversion_params = get_trt_conversion_params(1 << 30, 'INT8', 2, 1)

converter = trt.TrtGraphConverterV2(
        input_saved_model_dir='/path/to/effdet-d0/saved_model/',
        conversion_params=conversion_params,
)

converter.convert(calibration_input_fn=input_fn_doesnt_work)
converter.build(input_fn=input_fn_doesnt_work)
converter.save(output_saved_model_dir='/path/to/save/trt_int8/')

Steps To Reproduce

Change the paths and run via a Jupyter notebook or Python script.

Hi @CodeMonkeyOrSomething,

Please refer to the following for INT8 calibration in Python.

https://github.com/NVIDIA/TensorRT/tree/master/samples/python/int8_caffe_mnist

Thank you.

@spolisetty, the link you provided says it is for TensorRT 7; I'm using 6. In any case, I'm using the TensorRT support available through TensorFlow (TF-TRT), and those examples are significantly different. I am following the TF-TRT guide provided here: Accelerating Inference In TF-TRT User Guide :: NVIDIA Deep Learning Frameworks Documentation. Most of the code is a direct copy and it's failing. Any idea why it may not be working?

More specifically, this is the code link: https://github.com/tensorflow/tensorrt/blob/master/tftrt/examples/object_detection/object_detection.py

Hi @CodeMonkeyOrSomething ,

It looks good in general. The conversion code (using input_fn_doesnt_work) does work for me with the tf.keras.applications.EfficientNetB0 model. A few things to note:

  1. The calibration input has to have the same shape and dtype as the original TF model; see the sketch after this list. (In my case, I changed the cast to float32 and the image size to [224, 224] to get EfficientNetB0 converted.)
  2. Calibration has to have realistic input data. What you describe regarding accuracy with input_fn_works is expected.
  3. Unrelated to the problem: the max_batch_size argument is ignored in TF2; the max batch size is given by the input data seen during the first inference (performed by converter.build in this case).
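
For point 1, a quick way to check what the SavedModel expects (the path is a placeholder) is to print the serving signature's input spec:

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

# Placeholder path: print the dtype and shape the serving signature expects
model = tf.saved_model.load('/path/to/effdet-d0/saved_model/',
                            tags=[tag_constants.SERVING])
func = model.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
print(func.structured_input_signature)
# e.g. ((), {'input_tensor': TensorSpec(shape=(1, None, None, 3), dtype=tf.uint8, ...)})
# (the exact name and shape vary by model)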

Thank you.

Hi @spolisetty ,

I checked the casting and it looks correct as uint8 for my model. I used input_fn_works as a sanity check, and since it runs I think that means the data is formatted correctly. I also removed the max_batch_size argument since it's ignored. I am still getting the error.

Could you perhaps try the EfficientDet D0 512x512 model here: models/tf2_detection_zoo.md at master · tensorflow/models · GitHub

To run with this you would download the archive, extract it, and then change the path to point at the saved_model folder inside the extracted folder.
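
If it helps, something like the following should fetch and unpack the model (the URL is the one the zoo page listed at the time; adjust it if the archive has moved):

import tensorflow as tf

# URL as listed on the TF2 detection zoo page; adjust if it has moved
archive = tf.keras.utils.get_file(
        'efficientdet_d0_coco17_tpu-32.tar.gz',
        'http://download.tensorflow.org/models/object_detection/tf2/20200711/'
        'efficientdet_d0_coco17_tpu-32.tar.gz',
        extract=True)
print(archive)  # the saved_model folder sits inside the extracted directory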

I am starting to wonder if there is maybe some issue in my environment setup.

Thanks,
B

Hi @CodeMonkeyOrSomething,

We could reproduce this issue and will work on it.
There is a workaround: if we set minimum_segment_size=20, calibration works.
A note on performance: even with minimum_segment_size=20, 28 engines are created for the model, and a large number of small engines is usually not very efficient. You can try even larger segment sizes, like 50 or 100, to see if that improves performance.
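
As a sketch, the workaround only changes the minimum_segment_size argument in the script you posted:

# Same script as above, with minimum_segment_size raised from 2 to 20
conversion_params = get_trt_conversion_params(1 << 30, 'INT8', 20, 1)
converter = trt.TrtGraphConverterV2(
        input_saved_model_dir='/path/to/effdet-d0/saved_model/',
        conversion_params=conversion_params)
converter.convert(calibration_input_fn=input_fn_doesnt_work)
converter.build(input_fn=input_fn_doesnt_work)
converter.save(output_saved_model_dir='/path/to/save/trt_int8/')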

Thank you.

Hi @spolisetty,

Good to know that it wasn't just me. I appreciate your help figuring this out. The TF-TRT documentation here: Accelerating Inference in TensorFlow with TensorRT User Guide - NVIDIA Docs says: "We have observed that the default value 3 gives the best performance for most models," which is why I used a small value.

Thanks!


Hello @spolisetty,

I have run into what appears to be the same issue as well. While we wait for a fix, can you elaborate a bit more on what precisely triggers the issue, if it is known?

Thank you!