INT8 calibration: error "get_batch() takes 2 positional arguments but 3 were given"

Python: 3.5
TensorRT: 5.1.2.2
Ubuntu: 16.04
CUDA: 10.0
cuDNN: 7.5
GPU: RTX 2080 Ti

I tried to use INT8 mode and implemented a custom calibrator:

# -*- coding: utf-8 -*-
"""
Created on Tue May 21 16:57:28 2019

@author: gs
"""

import os

import cv2
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt

class customEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, batch_data_dir, cache_file, batch_size, input_size):
        trt.IInt8EntropyCalibrator2.__init__(self)
        assert os.path.exists(batch_data_dir)
        if not isinstance(input_size, (tuple, list)):
            self.input_size = [input_size, input_size]
        else:
            self.input_size = list(input_size)
        self.calib_image_paths = [os.path.join(batch_data_dir, f)
                                  for f in os.listdir(batch_data_dir)]
        self.cache_file = cache_file
        self.batch_size = batch_size
        self.shape = [self.batch_size, 3] + self.input_size
        # Device buffer big enough for one calibration batch (float32, NCHW).
        self.device_input = cuda.mem_alloc(trt.volume(self.shape) * trt.float32.itemsize)
        self.indices = np.arange(len(self.calib_image_paths))
        np.random.shuffle(self.indices)

        def load_batches():
            # Yield only full batches; a trailing partial batch is dropped.
            for i in range(0, len(self.calib_image_paths) - self.batch_size + 1, self.batch_size):
                batch_indices = self.indices[i:i + self.batch_size]
                yield self.read_batch_file([self.calib_image_paths[j] for j in batch_indices])

        self.batches = load_batches()

    def read_batch_file(self, filenames):
        images = list()
        for filename in filenames:
            assert os.path.exists(filename)
            image = cv2.imread(filename)
            assert image is not None
            images.append(image)
        # blobFromImages returns a float32 NCHW blob matching self.shape.
        return cv2.dnn.blobFromImages(images, scalefactor=1. / 255,
                                      size=tuple(self.input_size), swapRB=True, crop=False)

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        try:
            data = next(self.batches)
            cuda.memcpy_htod(self.device_input, data)
            return [int(self.device_input)]
        except StopIteration:
            # No data left: tell TensorRT calibration is finished.
            return None

    def read_calibration_cache(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, 'rb') as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, 'wb') as f:
            f.write(cache)
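
As a sanity check on the preprocessing: blobFromImages returns a float32 NCHW array whose byte size must match the device allocation passed to memcpy_htod. A standalone sketch (the image and blob sizes here are placeholders):

import cv2
import numpy as np

# Dummy BGR images just to inspect the blob layout.
imgs = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(4)]
blob = cv2.dnn.blobFromImages(imgs, scalefactor=1. / 255,
                              size=(416, 416), swapRB=True, crop=False)
print(blob.shape, blob.dtype)  # (4, 3, 416, 416) float32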

My build_engine function:

def build_engine(onnx_file_path, max_batch_size, engine_file_path, mode, calib):
    assert mode in ["fp32", "fp16", "int8"]
    assert os.path.exists(onnx_file_path)
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network() as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        builder.max_workspace_size = 1 << 30
        builder.max_batch_size = max_batch_size
        if mode == "fp16":
            builder.fp16_mode = True
        elif mode == "int8":
            builder.int8_mode = True
            assert calib is not None
            assert max_batch_size == calib.get_batch_size()
            builder.int8_calibrator = calib
        with open(onnx_file_path, 'rb') as model:
            assert parser.parse(model.read())  # parse() returns False on failure
        engine = builder.build_cuda_engine(network)
        assert engine is not None
        with open(engine_file_path, 'wb') as f:
            f.write(engine.serialize())
        print("trt {} engine saved {}".format(mode, engine_file_path))

When I use these to build an engine in INT8 mode, I get this error:

DEPRECATED: This variant of get_batch is deprecated. Please use the single argument variant described in the documentation instead.
Traceback (most recent call last):
  File "build_engine.py", line 24, in <module>
    build_engine(onnx_file_path,max_batch_size,trt_engine_path,mode,calib)
  File "/home/trt_yolo_int8/utils.py", line 40, in build_engine
    engine = builder.build_cuda_engine(network)
TypeError: get_batch() takes 2 positional arguments but 3 were given

Can anyone tell me how to fix this? Thanks.

I tried converting the ONNX model to TRT in fp32 mode, and that works.

I changed get_batch() to the old TRT signature, like this:

def get_batch(self, bindings, names):
    try:
        data = next(self.batches)
        cuda.memcpy_htod(self.device_input, data)
        return [int(self.device_input)]
    except StopIteration:
        return None

The only difference between the old and new versions is the extra bindings argument. I checked the definition of trt.IInt8EntropyCalibrator2.get_batch() in TRT 5.1.2.2:

>>> help(trt.IInt8EntropyCalibrator2.get_batch)
Help on instancemethod in module tensorrt.tensorrt:

get_batch(...)
    get_batch(self: tensorrt.tensorrt.IInt8EntropyCalibrator2, names: List[str]) -> object
    
    
    Get a batch of input for calibration. The batch size of the input must match the batch size returned by :func:`get_batch_size` .
    
    A possible implementation may look like this:
    ::
    
        def get_batch(names):
            try:
                # Assume self.batches is a generator that provides batch data.
                data = next(self.batches)
                # Assume that self.device_input is a device buffer allocated by the constructor.
                cuda.memcpy_htod(self.device_input, data)
                return [int(self.device_input)]
            except StopIteration:
                # When we're out of batches, we return either [] or None.
                # This signals to TensorRT that there is no calibration data remaining.
                return None
    
    :arg names: The names of the network inputs for each object in the bindings array.
    
    :returns: A :class:`list` of device memory pointers set to the memory containing each network input data, or an empty :class:`list` if there are no more batches for calibration. You can allocate these device buffers with pycuda, for example, and then cast them to :class:`int` to retrieve the pointer.

The new get_batch() takes one argument rather than two, but when I implement the new interface I get the error above. Is this a bug, or am I missing something? If I build the INT8 engine with the changed (two-argument) get_batch(), will it work?

I built the INT8 engine successfully, and it works.
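
For anyone else hitting this: on 5.1.2.2 the builder still calls the deprecated two-argument variant (hence the "3 were given" TypeError), so the two-argument signature above is what works. A variant that tolerates both calling conventions might look like this sketch (only the two-argument path is verified on my setup):

def get_batch(self, *args):
    # TRT 5.1.2.2 calls get_batch(bindings, names); the docs describe
    # get_batch(names). Accepting *args covers both conventions.
    try:
        data = next(self.batches)
        cuda.memcpy_htod(self.device_input, data)
        return [int(self.device_input)]
    except StopIteration:
        return None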

I have the same issue and basically the same system (OS, etc.). How did you solve the get_batch() problem?

Hi,
I am also working on an INT8 calibrator. I am a bit confused about the 'names' argument of get_batch(). Is it used/needed at all?

Also, it looks like some official NVIDIA implementations are not correct:

  def get_batch(self, bindings, names):
    batch = self.stream.next_batch()
    if not batch.size:   
      return None
      
    cuda.memcpy_htod(self.d_input, batch)
    for i in self.input_layers[0]:
      assert names[0] != i

    bindings[0] = int(self.d_input)
    return bindings

I would expect:

for i in self.input_layers:
  assert names[0] == i

Indeed, the C++ implementation here:

https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/sampleINT8#calibrator-interface

is:

bool getBatch(void* bindings[], const char* names[], int nbBindings) override
{
	if (!mStream.next())
		return false;

	CHECK(cudaMemcpy(mDeviceInput, mStream.getBatch(), mInputCount * sizeof(float), cudaMemcpyHostToDevice));
	assert(!strcmp(names[0], INPUT_BLOB_NAME));
	bindings[0] = mDeviceInput;
	return true;
}
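
For comparison, a rough Python equivalent of that C++ logic under the old two-argument binding (INPUT_BLOB_NAME here is a stand-in for whatever the network input is actually called):

def get_batch(self, bindings, names):
    batch = self.stream.next_batch()
    if not batch.size:
        return None
    cuda.memcpy_htod(self.d_input, batch)
    # Mirrors the C++ assert: the first binding should be the network input.
    assert names[0] == INPUT_BLOB_NAME  # hypothetical constant
    bindings[0] = int(self.d_input)
    return bindings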

Thanks