How to work with explicit batches in Python

Description

I’m trying to understand how to build an engine in TensorRT and run inference with an explicit batch size.
I built a simple LeNet-like network in PyTorch and the same one with the TensorRT network API and wanted to compare the outputs.
But I’m stuck on how to actually run the inference with TensorRT.

Environment

TensorRT Version: 7.1.3
GPU Type: NVIDIA GeForce RTX 2080 Ti
Nvidia Driver Version: 470.86
CUDA Version: 10.2
CUDNN Version: 8.2
Operating System + Version: Ubuntu 18.04
Python Version: 3.6

Steps To Reproduce

Here is my code:

import torch
import torch.nn as nn
from dataclasses import dataclass
from torchvision import datasets, transforms
import torch.nn.functional as F

import common
import numpy as np

import time

import tensorrt as trt

# You can set the logger severity higher to suppress messages (or lower to display more messages).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

global BATCH_SIZE
BATCH_SIZE = 2

class LeNet5(nn.Module):
    def __init__(self):
        super().__init__()
        self._body = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5), 
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        
        self._head = nn.Sequential(
            nn.Linear(in_features=16 * 5 * 5, out_features=120), 
            nn.ReLU(inplace=True),
            nn.Linear(in_features=120, out_features=84), 
            nn.ReLU(inplace=True),
            nn.Linear(in_features=84, out_features=10)
        )

    def forward(self, x):
        x = self._body(x)
        x = x.view(x.size()[0], -1)
        x = self._head(x)
        return x

# initialize the model
lenet5_model = LeNet5()


class ModelData(object):
	INPUT_NAME = "data"
	MODEL_PATH = 'models/lenet5_mnist.pt'
	INPUT_SHAPE = (-1, 1, -1, -1)
	OUTPUT_NAME = "prob"
	DTYPE = trt.float32


class LeNet5TRT(object):
	def __init__(self, weights) -> None:
		super().__init__()
		self.weights = weights
		self.engine = self.build_engine()

	def populate_network(self):
		# Configure the network layers based on the self.weights provided.
		input_tensor = self.network.add_input(name=ModelData.INPUT_NAME, dtype=ModelData.DTYPE, shape=ModelData.INPUT_SHAPE)

		# body
		_body_conv1_w = self.weights['_body.0.weight'].numpy()
		_body_conv1_b = self.weights['_body.0.bias'].numpy()
		_body_conv1 = self.network.add_convolution(input=input_tensor,
                                             num_output_maps=6,
                                             kernel_shape=(5, 5),
                                             kernel=_body_conv1_w,
                                             bias=_body_conv1_b)
		_body_conv1.stride = (1, 1)
		_body_conv1.padding = (0, 0)

		_body_relu1 = self.network.add_activation(
        								    input=_body_conv1.get_output(0), 
        									type=trt.ActivationType.RELU)

		_body_maxpool1 = self.network.add_pooling(input=_body_relu1.get_output(0), 
										type=trt.PoolingType.MAX, 
										window_size=(2, 2))

		_body_conv2_w = self.weights['_body.3.weight'].numpy()
		_body_conv2_b = self.weights['_body.3.bias'].numpy()
		_body_conv2 = self.network.add_convolution(input=_body_maxpool1.get_output(0),
                                             num_output_maps=16,
                                             kernel_shape=(5, 5),
                                             kernel=_body_conv2_w,
                                             bias=_body_conv2_b)
		_body_conv2.stride = (1, 1)
		_body_conv2.padding = (0, 0)

		_body_relu2 = self.network.add_activation(
        								    input=_body_conv2.get_output(0), 
        									type=trt.ActivationType.RELU)
		_body_maxpool2 = self.network.add_pooling(input=_body_relu2.get_output(0), 
										type=trt.PoolingType.MAX, 
										window_size=(2, 2))

		# head
		_head_linear1_w = self.weights['_head.0.weight'].numpy()
		_head_linear1_b = self.weights['_head.0.bias'].numpy()
		_head_linear1 = self.network.add_fully_connected(
		    											 input=_body_maxpool2.get_output(0),
		    											 num_outputs=120,
		    											 kernel=_head_linear1_w,
		                                                 bias=_head_linear1_b)
		_head_relu1 = self.network.add_activation(
			input=_head_linear1.get_output(0),
			type=trt.ActivationType.RELU)

		_head_linear2_w = self.weights['_head.2.weight'].numpy()
		_head_linear2_b = self.weights['_head.2.bias'].numpy()
		_head_linear2 = self.network.add_fully_connected(
		    											 input=_head_relu1.get_output(0),
		    											 num_outputs=84,
		    											 kernel=_head_linear2_w,
		                                                 bias=_head_linear2_b)

		_head_relu2 = self.network.add_activation(
			input=_head_linear2.get_output(0),
			type=trt.ActivationType.RELU)

		_head_linear3_w = self.weights['_head.4.weight'].numpy()
		_head_linear3_b = self.weights['_head.4.bias'].numpy()
		_head_linear3 = self.network.add_fully_connected(
		    											 input=_head_relu2.get_output(0),
		    											 num_outputs=10,
		    											 kernel=_head_linear3_w,
		                                                 bias=_head_linear3_b)
		
		_head_linear3.get_output(0).name = "prob"
		self.network.mark_output(tensor=_head_linear3.get_output(0))

	def GiB(self, val):
		return val * 1 << 30


	def build_engine(self):
		# For more information on TRT basics, refer to the introductory samples.
		# with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network:
		EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
		with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network:
			self.network = network
			# builder.max_workspace_size = 32
			# builder.max_batch_size = 1 # always 1 for explicit batch
			# builder.fp16_mode = True
			# builder.create_optimization_profile().set_shape(ModelData.INPUT_NAME, (1, 32, 32), (3, 32, 32), (3, 64, 64)) 
			config = builder.create_builder_config()
			config.max_workspace_size = self.GiB(1)
			config.set_flag(trt.BuilderFlag.FP16)
			profile = builder.create_optimization_profile()
			profile.set_shape(ModelData.INPUT_NAME, (1, 1, 32, 32), (BATCH_SIZE, 1, 32, 32), (3, 1, 32, 32)) 
			config.add_optimization_profile(profile)
			# Populate the network using self.weights from the PyTorch model.
			self.populate_network()
			# Build and return an engine.
			# return builder.build_cuda_engine(self.network)
			return builder.build_engine(self.network, config)



def load_random_test_case(pagelocked_buffer):
	# Select an image at random to be the test case.
	img = np.random.rand(BATCH_SIZE,1,32,32).astype(np.float32)
	# Copy to the pagelocked input buffer
	np.copyto(pagelocked_buffer, img.ravel())
	return img

def main():
	common.add_help(description="Yeah!")
	# Get the PyTorch weights
	lenet5_model = LeNet5()
	lenet5_model.eval()
	lenet5_model.load_state_dict(torch.load(ModelData.MODEL_PATH))
	weights = lenet5_model.state_dict()

	# Do inference with TensorRT.
	with LeNet5TRT(weights).engine as engine:
		# Build an engine, allocate buffers and create a stream.
		# For more information on buffer allocation, refer to the introductory samples.
		with open('models/lenet5_mnist.trt', "wb") as f:
			f.write(engine.serialize())

		with open('models/lenet5_mnist.trt', "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
			engine = runtime.deserialize_cuda_engine(f.read())
			inputs, outputs, bindings, stream = common.allocate_buffers(engine)
			with engine.create_execution_context() as context:
				t = 0
				for _ in range(1):
					img = load_random_test_case(pagelocked_buffer=inputs[0].host)
					# For more information on performing inference, refer to the introductory samples.
					# The common.do_inference function will return a list of outputs 
					a = time.time()
					context.set_binding_shape(BATCH_SIZE, (1, 32, 32))
					context.active_optimization_profile = 0
					pred_trt = common.do_inference(context, bindings=bindings, inputs=inputs, 
						                           outputs=outputs, stream=stream, batch_size=BATCH_SIZE)
					t += time.time() - a

		with torch.no_grad():	
			pred_torch = lenet5_model.cuda()(torch.from_numpy(img).cuda())
			print('baseline: ', pred_torch.cpu().numpy())
		print(np.asarray(pred_trt, dtype=np.float32).shape)
		print('output:   ', np.asarray(pred_trt, dtype=np.float32).reshape((BATCH_SIZE, 10)))
		print('diff:    ', torch.max(torch.abs(pred_torch.cpu() - 
								     torch.as_tensor(np.asarray(pred_trt, dtype=np.float32).reshape((BATCH_SIZE, 10))))))
	print('Time: ', t)

# if __name__ == '__main__':
main()

The common module imported above is common.py taken from TensorRT/common.py at 96e23978cd6e4a8fe869696d3d8ec2b47120629b · NVIDIA/TensorRT · GitHub.
The error is:

$ python3 test.py 
Traceback (most recent call last):
  File "test.py", line 216, in <module>
    main()
  File "test.py", line 192, in main
    inputs, outputs, bindings, stream = common.allocate_buffers(engine)
  File "/yolact_edge/common.py", line 123, in allocate_buffers
    host_mem = cuda.pagelocked_empty(size, dtype)
pycuda._driver.MemoryError: cuMemHostAlloc failed: out of memory

It seems unlikely that such a small model would use all 12 GB of memory. It seems I don’t fully understand how to deal with an explicit batch size. Any tips or suggestions?
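
As a sanity check, here is a minimal diagnostic sketch (assuming the engine built above has been deserialized as engine, TensorRT 7.x API) that prints what size allocate_buffers is actually being asked for:

import tensorrt as trt

def inspect_bindings(engine):
    # Dynamic dimensions show up as -1 in the engine's binding shapes, and
    # trt.volume() is just the product of the dims, so it can come out negative.
    for binding in engine:
        shape = engine.get_binding_shape(binding)
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        print(f"{binding}: shape={tuple(shape)}, volume={trt.volume(shape)}, dtype={dtype}")

With INPUT_SHAPE = (-1, 1, -1, -1) the input binding reports a negative volume, which points at the allocation size rather than the model as the culprit.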

Hi,

Looks like there is some problem in allocating the buffer.
Please refer to the following similar issue, which may help you.

Also, please refer to the following sample.

Thank you.
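
Presumably the linked sample boils down to something like the sketch below (assumptions on my part: a single optimization profile, binding index 0 is the dynamic "data" input, TensorRT 7.x API). With an explicit-batch engine, the dynamic dims only become concrete on the execution context after set_binding_shape, so buffer sizes should be taken from there:

import tensorrt as trt

def resolve_binding_shapes(engine, batch_size):
    # The engine keeps -1 for dynamic dims; concrete shapes live on the context
    # once the profile is selected and the input shape is set.
    with engine.create_execution_context() as context:
        context.active_optimization_profile = 0                # profile added at build time
        context.set_binding_shape(0, (batch_size, 1, 32, 32))  # resolve the dynamic input dims
        assert context.all_binding_shapes_specified
        for i in range(engine.num_bindings):
            print(engine.get_binding_name(i), context.get_binding_shape(i))

Sizing host and device buffers from context.get_binding_shape() at this point gives positive sizes, unlike the -1 shapes returned by engine.get_binding_shape().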

Thanks for sharing. I tried to modify my code following the suggestions from the links, so now my code is:

import torch
import torch.nn as nn
from dataclasses import dataclass
from torchvision import datasets, transforms
import torch.nn.functional as F

import common
import numpy as np

import time

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

global BATCH_SIZE
BATCH_SIZE = 1

class LeNet5(nn.Module):
    def __init__(self):
        super().__init__()
        self._body = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5), 
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        
        self._head = nn.Sequential(
            nn.Linear(in_features=16 * 5 * 5, out_features=120), 
            nn.ReLU(inplace=True),
            nn.Linear(in_features=120, out_features=84), 
            nn.ReLU(inplace=True),
            nn.Linear(in_features=84, out_features=10)
        )

    def forward(self, x):
        x = self._body(x)
        x = x.view(x.size()[0], -1)
        x = self._head(x)
        return x

# initialize the model
lenet5_model = LeNet5()


class ModelData(object):
	INPUT_NAME = "data"
	MODEL_PATH = 'models/lenet5_mnist.pt'
	INPUT_SHAPE = (-1, 1, 32, 32)
	OUTPUT_NAME = "prob"
	DTYPE = trt.float32


class LeNet5TRT(object):
	def __init__(self, weights) -> None:
		super().__init__()
		self.weights = weights
		self.engine = self.build_engine()

	def populate_network(self):
		# Configure the network layers based on the self.weights provided.
		input_tensor = self.network.add_input(name=ModelData.INPUT_NAME, dtype=ModelData.DTYPE, shape=ModelData.INPUT_SHAPE)

		# body
		_body_conv1_w = self.weights['_body.0.weight'].numpy()
		_body_conv1_b = self.weights['_body.0.bias'].numpy()
		_body_conv1 = self.network.add_convolution(input=input_tensor,
                                             num_output_maps=6,
                                             kernel_shape=(5, 5),
                                             kernel=_body_conv1_w,
                                             bias=_body_conv1_b)
		_body_conv1.stride = (1, 1)
		_body_conv1.padding = (0, 0)

		_body_relu1 = self.network.add_activation(
        								    input=_body_conv1.get_output(0), 
        									type=trt.ActivationType.RELU)

		_body_maxpool1 = self.network.add_pooling(input=_body_relu1.get_output(0), 
										type=trt.PoolingType.MAX, 
										window_size=(2, 2))

		_body_conv2_w = self.weights['_body.3.weight'].numpy()
		_body_conv2_b = self.weights['_body.3.bias'].numpy()
		_body_conv2 = self.network.add_convolution(input=_body_maxpool1.get_output(0),
                                             num_output_maps=16,
                                             kernel_shape=(5, 5),
                                             kernel=_body_conv2_w,
                                             bias=_body_conv2_b)
		_body_conv2.stride = (1, 1)
		_body_conv2.padding = (0, 0)

		_body_relu2 = self.network.add_activation(
        								    input=_body_conv2.get_output(0), 
        									type=trt.ActivationType.RELU)
		_body_maxpool2 = self.network.add_pooling(input=_body_relu2.get_output(0), 
										type=trt.PoolingType.MAX, 
										window_size=(2, 2))

		# head
		_head_linear1_w = self.weights['_head.0.weight'].numpy()
		_head_linear1_b = self.weights['_head.0.bias'].numpy()
		_head_linear1 = self.network.add_fully_connected(
		    											 input=_body_maxpool2.get_output(0),
		    											 num_outputs=120,
		    											 kernel=_head_linear1_w,
		                                                 bias=_head_linear1_b)
		_head_relu1 = self.network.add_activation(
			input=_head_linear1.get_output(0),
			type=trt.ActivationType.RELU)

		_head_linear2_w = self.weights['_head.2.weight'].numpy()
		_head_linear2_b = self.weights['_head.2.bias'].numpy()
		_head_linear2 = self.network.add_fully_connected(
		    											 input=_head_relu1.get_output(0),
		    											 num_outputs=84,
		    											 kernel=_head_linear2_w,
		                                                 bias=_head_linear2_b)

		_head_relu2 = self.network.add_activation(
			input=_head_linear2.get_output(0),
			type=trt.ActivationType.RELU)

		_head_linear3_w = self.weights['_head.4.weight'].numpy()
		_head_linear3_b = self.weights['_head.4.bias'].numpy()
		_head_linear3 = self.network.add_fully_connected(
		    											 input=_head_relu2.get_output(0),
		    											 num_outputs=10,
		    											 kernel=_head_linear3_w,
		                                                 bias=_head_linear3_b)
		
		_head_linear3.get_output(0).name = "prob"
		self.network.mark_output(tensor=_head_linear3.get_output(0))

	def GiB(self, val):
		return val * 1 << 30


	def build_engine(self):
		# For more information on TRT basics, refer to the introductory samples.
		# with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network:
		EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
		with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network:
			self.network = network
			builder.max_batch_size = BATCH_SIZE
			config = builder.create_builder_config()
			config.max_workspace_size = 10 * (2 ** 30)  # 10 GiB
			config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
			config.set_flag(trt.BuilderFlag.FP16)
			profile = builder.create_optimization_profile()
			profile.set_shape(ModelData.INPUT_NAME, 
				(BATCH_SIZE, 1, 32, 32), 
				(BATCH_SIZE, 1, 32, 32), 
				(BATCH_SIZE, 1, 32, 32)) 
			config.add_optimization_profile(profile)
			# Populate the network using self.weights from the PyTorch model.
			self.populate_network()
			# Build and return an engine.
			# return builder.build_cuda_engine(self.network)
			return builder.build_engine(self.network, config)



def load_random_test_case(pagelocked_buffer):
	# Select an image at random to be the test case.
	img = np.random.rand(BATCH_SIZE,1,32,32).astype(np.float32)
	# Copy to the pagelocked input buffer
	np.copyto(pagelocked_buffer, img.ravel())
	return img

def main():
	common.add_help(description="Yeah!")
	# Get the PyTorch weights
	lenet5_model = LeNet5()
	lenet5_model.eval()
	lenet5_model.load_state_dict(torch.load(ModelData.MODEL_PATH))
	weights = lenet5_model.state_dict()

	# Do inference with TensorRT.
	with LeNet5TRT(weights).engine as engine:
		# Build an engine, allocate buffers and create a stream.
		# For more information on buffer allocation, refer to the introductory samples.
		with open('models/lenet5_mnist.trt', "wb") as f:
			f.write(engine.serialize())

		with open('models/lenet5_mnist.trt', "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
			engine = runtime.deserialize_cuda_engine(f.read())
			inputs, outputs, bindings, stream = common.allocate_buffers(engine, batch_size=-1)
			print(f'Len inputs {len(inputs)}')
			with engine.create_execution_context() as context:
				t = 0
				for _ in range(1):
					img = load_random_test_case(pagelocked_buffer=inputs[0].host)
					# For more information on performing inference, refer to the introductory samples.
					# The common.do_inference function will return a list of outputs 
					a = time.time()
					# context.set_binding_shape(BATCH_SIZE, (1, 32, 32))
					context.set_binding_shape(0, (BATCH_SIZE, 1, 32, 32))

					context.active_optimization_profile = 0
					pred_trt = common.do_inference_v2(context, bindings=bindings, inputs=inputs, 
						                              outputs=outputs, stream=stream, batch_size=BATCH_SIZE)
					t += time.time() - a

		with torch.no_grad():	
			pred_torch = lenet5_model.cuda()(torch.from_numpy(img).cuda())
			print('baseline: ', pred_torch.cpu().numpy())
		print(np.asarray(pred_trt, dtype=np.float32).shape)
		print('output:   ', np.asarray(pred_trt, dtype=np.float32).reshape((BATCH_SIZE, 10)))
		print('diff:    ', torch.max(torch.abs(pred_torch.cpu() - 
								     torch.as_tensor(np.asarray(pred_trt, dtype=np.float32).reshape((BATCH_SIZE, 10))))))
	print('Time: ', t)

# if __name__ == '__main__':
main()

I also modified the allocate_buffers function and set batch_size to -1:

def allocate_buffers(engine, batch_size=1):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    for binding in engine:
        # pdb.set_trace()
        size = trt.volume(engine.get_binding_shape(binding)) * batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
            print(f"input: shape:{engine.get_binding_shape(binding)} dtype:{engine.get_binding_dtype(binding)}")
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
            print(f"output: shape:{engine.get_binding_shape(binding)} dtype:{engine.get_binding_dtype(binding)}")
    return inputs, outputs, bindings, stream

And I get the same error:

python test.py
Traceback (most recent call last):
  File "test.py", line 223, in <module>
    main()
  File "test.py", line 196, in main
    inputs, outputs, bindings, stream = common.allocate_buffers(engine, batch_size=-1)
  File "/yolact_edge/common.py", line 146, in allocate_buffers
    host_mem = cuda.pagelocked_empty(size, dtype)
pycuda._driver.MemoryError: cuMemHostAlloc failed: out of memory

Hi,

size = trt.volume(engine.get_binding_shape(binding)) * batch_size

If you set batch_size to -1, size will become a negative number.
The correct approach is this:

  • engine.get_binding_shape(binding) will return the dimensions of the binding, including -1 for dynamic dims. For example, it may return [-1, 3, 224, 224].
  • Then, you can calculate the volume by replacing the -1 with your maximum batch size and multiplying the dims out: maxBS * 3 * 224 * 224 (see the sketch after this reply).

cuMemHostAlloc failed: out of memory

This error means you are passing a negative size to cuda.pagelocked_empty (the cuMemHostAlloc call in the traceback).

Thank you.
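
To make that concrete, here is a minimal sketch of an allocate_buffers variant along those lines (the name allocate_buffers_dynamic and the assumption that only the batch dimension is dynamic are mine, not from the original common.py): every -1 dimension is replaced with the maximum batch size before the volume is computed.

import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401, creates a CUDA context
import tensorrt as trt
from collections import namedtuple

HostDeviceMem = namedtuple("HostDeviceMem", ["host", "device"])

def allocate_buffers_dynamic(engine, max_batch_size):
    inputs, outputs, bindings = [], [], []
    stream = cuda.Stream()
    for binding in engine:
        # Replace dynamic (-1) dims with the profile maximum so the size is positive.
        dims = [max_batch_size if d == -1 else d for d in engine.get_binding_shape(binding)]
        size = trt.volume(dims)
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        host_mem = cuda.pagelocked_empty(size, dtype)   # pinned host buffer
        device_mem = cuda.mem_alloc(host_mem.nbytes)    # matching device buffer
        bindings.append(int(device_mem))
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    return inputs, outputs, bindings, stream

Calling allocate_buffers_dynamic(engine, BATCH_SIZE) in place of common.allocate_buffers(engine, batch_size=-1) then gives every binding a positive size that can hold the largest batch the profile allows.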

Thanks for explaining! I changed the batch size from -1 to 1 and it works, and now it is much clearer why the error occurred.
