Description
Onnx model int8 quantization fails post calibration with error: “[TRT] [E] 1: Unexpected exception _Map_base::at while assigning tensor scales” Here are last few lines of logs.
[03/07/2024-21:05:53] [TRT] [I] Post Processing Calibration data in 573.155 seconds.
[03/07/2024-21:05:53] [TRT] [V] Overriding tensor scales: /Concat_output_0 [Quantization(scale: {0.00787594,}, zero-point: {0,})] using /Concat_output_0 [Quantization(scale: {0.00787594,}, zero-point: {0,})]
[03/07/2024-21:05:53] [TRT] [V] Overriding tensor scales: /Concat_1_output_0 [Quantization(scale: {0.0361088,}, zero-point: {0,})] using /Concat_1_output_0 [Quantization(scale: {0.0361088,}, zero-point: {0,})]
[03/07/2024-21:05:53] [TRT] [V] Assigning tensor scales: /input_blocks.7/input_blocks.7.1/blocks.0/Concat_4_output_0 using /input_blocks.7/input_blocks.7.1/blocks.0/Concat_4_output_0 [
[03/07/2024-21:05:54] [TRT] [E] 1: Unexpected exception _Map_base::at
I can’t figure out what can be wrong as this error does not provide much information even with Verbose mode.
Export to float32 or float16 works properly, only with int8 I get this error.
I have a custom private model so I can’t share it here. But here is my conversion code:
import argparse
import os
import pickle
import pycuda.driver as cuda
import pycuda.autoinit
from pycuda.driver import Context
import tensorrt as trt
import numpy as np
TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
def get_parser(**parser_kwargs):
parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument("-o", "--onnx", type=str, required=True, help="ONNX model file path.")
parser.add_argument("-e", "--engine", type=str, required=True, help="TensorRT engine file path.")
parser.add_argument("-w", "--workspace", type=int, default=4, help="The workspace size in GB.")
parser.add_argument("--fp16", action="store_true", help="Enable FP16 mode.")
parser.add_argument("--int8", action="store_true", help="Enable INT8 mode.")
parser.add_argument("--calibration_data", type=str, help="Path to the calibration data .pkl file for INT8 quantization.")
parser.add_argument("--max_calibration_data", type=int, help="Maximum number of calibration data samples to use for INT8 quantization.")
return parser.parse_args()
class CalibrationDataLoader(trt.IInt8EntropyCalibrator2):
def __init__(self, calibration_data_path, input_layers, batch_size=1, max_examples=None):
super(CalibrationDataLoader, self).__init__()
self.batch_size = batch_size
self.current_index = 0
with open(calibration_data_path, 'rb') as f:
all_calibration_data = pickle.load(f)
self.calibration_data = all_calibration_data[:max_examples] if max_examples is not None else all_calibration_data
self.input_layers = input_layers
# Sizes for 'x' and 'lq' inputs
x_lq_num_elements = np.prod([1, 3, 64, 64]) # Product of the shape's dimensions
x_lq_element_size = np.float32().itemsize # Size in bytes of float32
x_lq_size = x_lq_num_elements * x_lq_element_size
# Size for 'timesteps' input
timesteps_num_elements = 1 # Single value
timesteps_element_size = np.int32().itemsize # Size in bytes of int32
timesteps_size = timesteps_num_elements * timesteps_element_size
# Allocate device memory
self.device_input_x = cuda.mem_alloc(int(x_lq_size))
self.device_input_lq = cuda.mem_alloc(int(x_lq_size))
self.device_input_timesteps = cuda.mem_alloc(int(timesteps_size))
def get_batch_size(self):
return self.batch_size
def get_batch(self, names):
if self.current_index < len(self.calibration_data):
batch = self.calibration_data[self.current_index]
x, timesteps, lq = batch
# Flatten and prepare x and lq, which are float32 tensors.
x_data = x.numpy().ravel().astype(np.float32)
lq_data = lq.numpy().ravel().astype(np.float32)
timesteps_data = timesteps.numpy().ravel().astype(np.int32)
cuda.memcpy_htod(self.device_input_x, x_data)
cuda.memcpy_htod(self.device_input_lq, lq_data)
cuda.memcpy_htod(self.device_input_timesteps, timesteps_data)
self.current_index += 1
print(f"Processing batch {self.current_index}/{len(self.calibration_data)}")
return [int(self.device_input_x), int(self.device_input_timesteps), int(self.device_input_lq)]
else:
print("All calibration batches processed.")
return None
def read_calibration_cache(self):
return None
def write_calibration_cache(self, cache):
pass
def convert_onnx_model_to_tensorrt_engine(onnx_file_path, engine_file_path, workspace=2, fp16_mode=False, int8_mode=False, calibration_data_path="", max_examples=None):
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser, trt.Runtime(TRT_LOGGER) as runtime:
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace * (2 ** 30))
profile = builder.create_optimization_profile()
profile.set_shape('x', min=(1, 3, 64, 64), opt=(1, 3, 64, 64), max=(1, 3, 64, 64))
profile.set_shape('lq', min=(1, 3, 64, 64), opt=(1, 3, 64, 64), max=(1, 3, 64, 64))
profile.set_shape('noise', min=(1, 3, 64, 64), opt=(1, 3, 64, 64), max=(1, 3, 64, 64))
config.add_optimization_profile(profile)
if fp16_mode:
config.set_flag(trt.BuilderFlag.FP16)
if int8_mode:
assert calibration_data_path, "Calibration data path must be provided for INT8 mode"
calibrator = CalibrationDataLoader(calibration_data_path, ['x', 'timesteps', 'lq'], max_examples=max_examples)
config.set_flag(trt.BuilderFlag.INT8)
config.int8_calibrator = calibrator
with open(onnx_file_path, "rb") as model:
if not parser.parse(model.read()):
print("ERROR: Failed to parse the ONNX file.")
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
print("Building an engine; this may take a while...")
serialized_network = builder.build_serialized_network(network, config)
if serialized_network is None:
print("Failed to build serialized network.")
return None
engine = runtime.deserialize_cuda_engine(serialized_network)
if engine is None:
print("Failed to deserialize CUDA engine.")
return None
print("Completed creating Engine")
with open(engine_file_path, "wb") as f:
f.write(serialized_network)
return engine
def main():
args = get_parser()
convert_onnx_model_to_tensorrt_engine(args.onnx, args.engine, workspace=args.workspace, fp16_mode=args.fp16, int8_mode=args.int8, calibration_data_path=args.calibration_data, max_examples=args.max_calibration_data)
if __name__ == '__main__':
main()
Environment
TensorRT Version: 8.6.1 (8.6.1.post1)
GPU Type: V100 (Tesla V100-SXM2-16GB)
Nvidia Driver Version: cuda_12.2.r12.2/compiler.33191640_0
CUDA Version: 12.2
CUDNN Version: 8.9.6
ONNX version: 1.15.0
ONNX model opset: 16
ONNXruntime: 1.17.1
Operating System + Version: Ubuntu 22.04.3 LTS
Python Version (if applicable): 3.10.12
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 2.2.0+cu121
Baremetal or Container (if container which image + tag): Google Colab
Steps To Reproduce
!python onnx_to_tensorrt_int8.py -o ./pre-trained/enhance.onnx -e ./pre-trained/enhance_v100_int8.trt -w 16 --int8 --calibration_data ./calibration_data.pkl --max_calibration_data 100