Description
Hi! I am trying to export a trained segmentation model from PyTorch to TensorRT. I first convert the PyTorch model to ONNX and then convert the ONNX model to TensorRT. All of this is done in Python.
The problem is that the exported TensorRT model produces slightly (but significantly) different results than its PyTorch and ONNX counterparts. Because of this, accuracy drops by more than 10% on my dataset. Note that I am using float32, so I would not expect such a performance degradation.
Below you will find the environment to run everything, as well as the Python code with the model in PyTorch (and its weights) and the conversion process. At the end of the script, the outputs of PyTorch, ONNX and TensorRT are compared.
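For context, the attached script compares the raw logits with np.testing.assert_almost_equal; the >10% accuracy drop was measured with my own evaluation pipeline and roughly corresponds to the per-pixel argmax of the logits changing between backends. A minimal sketch of that kind of check (the helper below is illustrative only and is not part of the attached script):
import numpy as np

def argmax_mismatch_rate(logits_a: np.ndarray, logits_b: np.ndarray) -> float:
    # logits_*: (N, num_classes, H, W) float32 outputs of two backends.
    # Returns the fraction of pixels whose predicted class differs.
    return float(np.mean(np.argmax(logits_a, axis=1) != np.argmax(logits_b, axis=1)))

# e.g. argmax_mismatch_rate(torch_logits_np, trt_logits), using the arrays
# computed at the end of the attached script.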
Environment
TensorRT Version: 8.2.1
GPU Type: NVIDIA GeForce RTX 3070 Laptop GPU
Nvidia Driver Version: 515.48.07
CUDA Version: 11.5
CUDNN Version: 8
Operating System + Version: Ubuntu 20.04.3 LTS (Focal Fossa)
Python Version (if applicable): 3.8.10
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.8.0
Baremetal or Container (if container which image + tag): Docker container nvcr.io/nvidia/tensorrt:21.12-py3
Relevant Files
You can find the code below, along with the Dockerfile used.
test_trt.py (15.7 KB)
Steps To Reproduce
I am using the following Dockerfile:
FROM nvcr.io/nvidia/tensorrt:21.12-py3
RUN pip install https://download.pytorch.org/whl/cu111/torch-1.8.1%2Bcu111-cp38-cp38-linux_x86_64.whl onnx onnxruntime
And the Python code:
import os
import subprocess
from typing import Tuple

import numpy as np
import onnxruntime
import pycuda.autoinit as cudainit
import pycuda.driver as cuda
import tensorrt as trt
import torch
import torch.nn as nn
import torch.nn.functional as F
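# Basic Conv2d -> BatchNorm2d -> ReLU block used throughout the network.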
class ConvBNReLU(nn.Module):
def __init__(
self,
in_chan,
out_chan,
ks=3,
stride=1,
padding=1,
dilation=1,
groups=1,
bias=False,
):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(
in_chan,
out_chan,
kernel_size=ks,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
self.bn = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
feat = self.conv(x)
feat = self.bn(feat)
feat = self.relu(feat)
return feat
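# Detail branch: three strided conv stages that keep spatial detail (output stride 8, 128 channels).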
class DetailBranch(nn.Module):
def __init__(self):
super(DetailBranch, self).__init__()
self.S1 = nn.Sequential(
ConvBNReLU(3, 64, 3, stride=2),
ConvBNReLU(64, 64, 3, stride=1),
)
self.S2 = nn.Sequential(
ConvBNReLU(64, 64, 3, stride=2),
ConvBNReLU(64, 64, 3, stride=1),
ConvBNReLU(64, 64, 3, stride=1),
)
self.S3 = nn.Sequential(
ConvBNReLU(64, 128, 3, stride=2),
ConvBNReLU(128, 128, 3, stride=1),
ConvBNReLU(128, 128, 3, stride=1),
)
def forward(self, x):
feat = self.S1(x)
feat = self.S2(feat)
feat = self.S3(feat)
return feat
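# Stem block: parallel conv and max-pool paths, concatenated and fused to 16 channels (output stride 4).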
class StemBlock(nn.Module):
def __init__(self):
super(StemBlock, self).__init__()
self.conv = ConvBNReLU(3, 16, 3, stride=2)
self.left = nn.Sequential(
ConvBNReLU(16, 8, 1, stride=1, padding=0),
ConvBNReLU(8, 16, 3, stride=2),
)
self.right = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
self.fuse = ConvBNReLU(32, 16, 3, stride=1)
def forward(self, x):
feat = self.conv(x)
feat_left = self.left(feat)
feat_right = self.right(feat)
feat = torch.cat([feat_left, feat_right], dim=1)
feat = self.fuse(feat)
return feat
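# Context embedding block: global-average-pooling branch added back onto the input feature map.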
class CEBlock(nn.Module):
def __init__(self):
super(CEBlock, self).__init__()
self.bn = nn.BatchNorm2d(128)
self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
# TODO: in the paper this is a plain conv2d, without bn-relu
self.conv_last = ConvBNReLU(128, 128, 3, stride=1)
def forward(self, x):
feat = torch.mean(x, dim=(2, 3), keepdim=True)
feat = self.bn(feat)
feat = self.conv_gap(feat)
feat = feat + x
feat = self.conv_last(feat)
return feat
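# Gather-and-expand layer, stride 1: 3x3 conv, depthwise expansion, 1x1 projection, residual shortcut.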
class GELayerS1(nn.Module):
def __init__(self, in_chan, out_chan, exp_ratio=6):
super(GELayerS1, self).__init__()
mid_chan = in_chan * exp_ratio
self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
self.dwconv = nn.Sequential(
nn.Conv2d(
in_chan,
mid_chan,
kernel_size=3,
stride=1,
padding=1,
groups=in_chan,
bias=False,
),
nn.BatchNorm2d(mid_chan),
nn.ReLU(inplace=True), # not shown in paper
)
self.conv2 = nn.Sequential(
nn.Conv2d(
mid_chan, out_chan, kernel_size=1, stride=1, padding=0, bias=False
),
nn.BatchNorm2d(out_chan),
)
self.conv2[1].last_bn = True
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
feat = self.conv1(x)
feat = self.dwconv(feat)
feat = self.conv2(feat)
feat = feat + x
feat = self.relu(feat)
return feat
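# Gather-and-expand layer, stride 2: strided depthwise expansion, second depthwise conv, 1x1 projection, and a strided depthwise shortcut.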
class GELayerS2(nn.Module):
def __init__(self, in_chan, out_chan, exp_ratio=6):
super(GELayerS2, self).__init__()
mid_chan = in_chan * exp_ratio
self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
self.dwconv1 = nn.Sequential(
nn.Conv2d(
in_chan,
mid_chan,
kernel_size=3,
stride=2,
padding=1,
groups=in_chan,
bias=False,
),
nn.BatchNorm2d(mid_chan),
)
self.dwconv2 = nn.Sequential(
nn.Conv2d(
mid_chan,
mid_chan,
kernel_size=3,
stride=1,
padding=1,
groups=mid_chan,
bias=False,
),
nn.BatchNorm2d(mid_chan),
nn.ReLU(inplace=True), # not shown in paper
)
self.conv2 = nn.Sequential(
nn.Conv2d(
mid_chan, out_chan, kernel_size=1, stride=1, padding=0, bias=False
),
nn.BatchNorm2d(out_chan),
)
self.conv2[1].last_bn = True
self.shortcut = nn.Sequential(
nn.Conv2d(
in_chan,
in_chan,
kernel_size=3,
stride=2,
padding=1,
groups=in_chan,
bias=False,
),
nn.BatchNorm2d(in_chan),
nn.Conv2d(
in_chan, out_chan, kernel_size=1, stride=1, padding=0, bias=False
),
nn.BatchNorm2d(out_chan),
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
feat = self.conv1(x)
feat = self.dwconv1(feat)
feat = self.dwconv2(feat)
feat = self.conv2(feat)
shortcut = self.shortcut(x)
feat = feat + shortcut
feat = self.relu(feat)
return feat
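# Semantic branch: stem block followed by gather-and-expand stages and the context embedding block; intermediate features are returned as well.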
class SegmentBranch(nn.Module):
def __init__(self):
super(SegmentBranch, self).__init__()
self.S1S2 = StemBlock()
self.S3 = nn.Sequential(
GELayerS2(16, 32),
GELayerS1(32, 32),
)
self.S4 = nn.Sequential(
GELayerS2(32, 64),
GELayerS1(64, 64),
)
self.S5_4 = nn.Sequential(
GELayerS2(64, 128),
GELayerS1(128, 128),
GELayerS1(128, 128),
GELayerS1(128, 128),
)
self.S5_5 = CEBlock()
def forward(self, x):
feat2 = self.S1S2(x)
feat3 = self.S3(feat2)
feat4 = self.S4(feat3)
feat5_4 = self.S5_4(feat4)
feat5_5 = self.S5_5(feat5_4)
return feat2, feat3, feat4, feat5_4, feat5_5
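# Bilateral guided aggregation layer: fuses detail-branch features (x_d) with semantic-branch features (x_s) via mutual sigmoid gating.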
class BGALayer(nn.Module):
def __init__(self):
super(BGALayer, self).__init__()
self.left1 = nn.Sequential(
nn.Conv2d(
128, 128, kernel_size=3, stride=1, padding=1, groups=128, bias=False
),
nn.BatchNorm2d(128),
nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0, bias=False),
)
self.left2 = nn.Sequential(
nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False),
)
self.right1 = nn.Sequential(
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(128),
)
self.right2 = nn.Sequential(
nn.Conv2d(
128, 128, kernel_size=3, stride=1, padding=1, groups=128, bias=False
),
nn.BatchNorm2d(128),
nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0, bias=False),
)
# TODO: does this really have no relu?
self.conv = nn.Sequential(
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True), # not shown in paper
)
def forward(self, x_d, x_s):
dsize = x_d.size()[2:]
left1 = self.left1(x_d)
left2 = self.left2(x_d)
right1 = self.right1(x_s)
right2 = self.right2(x_s)
right1 = F.interpolate(right1, size=dsize, mode="bilinear", align_corners=True)
left = left1 * torch.sigmoid(right1)
right = left2 * torch.sigmoid(right2)
right = F.interpolate(right, size=dsize, mode="bilinear", align_corners=True)
out = self.conv(left + right)
return out
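# Segmentation head: 3x3 ConvBNReLU, dropout, 1x1 conv to num_classes, optional bilinear upsample to the input size.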
class SegmentHead(nn.Module):
def __init__(self, in_chan, mid_chan, num_classes):
super(SegmentHead, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
self.drop = nn.Dropout(0.1)
self.conv_out = nn.Conv2d(
mid_chan, num_classes, kernel_size=1, stride=1, padding=0, bias=True
)
def forward(self, x, size=None):
feat = self.conv(x)
feat = self.drop(feat)
feat = self.conv_out(feat)
if size is not None:
feat = F.interpolate(feat, size=size, mode="bilinear", align_corners=True)
return feat
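# Full model: detail + semantic branches fused by the BGA layer, followed by a single segmentation head.
# Pretrained weights are downloaded on first use and loaded with strict=False.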
class MyModel(nn.Module):
def __init__(self, num_classes: int):
super().__init__()
self.detail = DetailBranch()
self.segment = SegmentBranch()
self.bga = BGALayer()
# Create architecture with cityscapes classes
self.define_heads_layers(num_classes)
# check if we have the model already locally
if not os.path.exists("my_model.pth"):
# download pretrained model
subprocess.call(
"wget https://storage.googleapis.com/public-dl-models/my_model.pth",
shell=True,
)
# Load all layers with pretrained weights
self.load_state_dict(
torch.load("my_model.pth", map_location="cpu"), strict=False
)
def define_heads_layers(self, num_classes):
self.head = SegmentHead(128, 1024, num_classes)
def forward(self, x):
size = x.size()[2:]
feat_d = self.detail(x)
feat2, feat3, feat4, feat5_4, feat_s = self.segment(x)
feat_head = self.bga(feat_d, feat_s)
logits = self.head(feat_head, size)
return logits
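# Build a TensorRT engine from an ONNX file. The network is parsed with an explicit
# batch dimension and built with the default (FP32) builder flags; no FP16/INT8 is requested.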
def onnx_to_trt(onnx_path: str, trt_path: str):
logger = trt.Logger(trt.Logger.INFO)
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
with trt.Builder(logger) as builder, builder.create_network(
EXPLICIT_BATCH
) as network, builder.create_builder_config() as config, trt.OnnxParser(
network, logger
) as parser:
# Parse model file
with open(onnx_path, "rb") as fr:
if not parser.parse(fr.read()):
errors = []
for error in range(parser.num_errors):
errors.append(parser.get_error(error))
raise RuntimeError(f"Error parsing ONNX model: {errors}")
# build settings
builder.max_batch_size = 1
config.max_workspace_size = 1 << 32 # 4GB
plan = builder.build_serialized_network(network, config)
if plan is None:
raise RuntimeError("Error building TensorRT engine")
# Serialize the plan
with open(trt_path, "wb") as fw:
fw.write(plan)
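# Container pairing a pagelocked host buffer with its device allocation, plus the
# binding shape and name so outputs can be reshaped after inference.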
class HostDeviceMem(object):
def __init__(
self, host_mem, device_mem, binding_shape: Tuple[int, int, int, int], name: str
):
self.host = host_mem
self.device = device_mem
self.binding_shape = binding_shape
self.name = name
def __str__(self):
return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
def __repr__(self):
return self.__str__()
def trt_dtype_to_np(dtype: trt.DataType) -> np.dtype:
if dtype == trt.DataType.FLOAT:
return np.float32
elif dtype == trt.DataType.INT32:
return np.int32
elif dtype == trt.DataType.INT8:
return np.int8
elif dtype == trt.DataType.HALF:
return np.float16
else:
raise NotImplementedError(f"{dtype} conversion to numpy not implemented")
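# Allocate one pagelocked host buffer and one device buffer per engine binding;
# returns them split into inputs and outputs, together with the bindings list and a CUDA stream.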
def allocate_buffers(engine, max_batch_size):
inputs = []
outputs = []
bindings = []
stream = cuda.Stream()
for binding in engine:
binding_shape = engine.get_binding_shape(binding)
dtype = engine.get_binding_dtype(binding)
index = engine.get_binding_index(binding)
name = engine.get_binding_name(index)
np_dtype = trt_dtype_to_np(dtype)
size = trt.volume(binding_shape) * max_batch_size
host_mem = cuda.pagelocked_empty(size, np_dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes)
bindings.append(int(device_mem))
if engine.binding_is_input(binding):
inputs.append(HostDeviceMem(host_mem, device_mem, binding_shape, name))
else:
outputs.append(HostDeviceMem(host_mem, device_mem, binding_shape, name))
return inputs, outputs, bindings, stream
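# Load and deserialize a TensorRT engine from disk.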
def deserialize_engine(engine_path: str) -> trt.ICudaEngine:
logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(None, "")
with open(engine_path, "rb") as f, trt.Runtime(logger) as runtime:
engine = runtime.deserialize_cuda_engine(f.read())
if engine is None:
raise RuntimeError("Error deserializing TensorRT engine")
return engine
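# End-to-end check: run the same random images through the PyTorch model, ONNX Runtime
# and the TensorRT engine, then compare the resulting logits element-wise.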
if __name__ == "__main__":
model = MyModel(num_classes=19)
num_images = 8
onnx_path = "my_model.onnx"
trt_engine_path = "my_model.trt"
# set seeds
torch.manual_seed(42)
np.random.seed(42)
images = (np.random.randint(0, 255, size=(num_images, 3, 360, 640)) / 255.0).astype(
np.float32
)
torch_images = torch.from_numpy(images)
model.eval()
torch_logits = model(torch_images)
torch_logits_np = torch_logits.detach().numpy()
# Export to ONNX
input_names = ["input_0"]
output_names = ["output_0"]
torch.onnx.export(
model,
tuple([torch_images[0][None, ...]]), # batch size = 1
onnx_path,
do_constant_folding=True,
verbose=False,
input_names=input_names,
output_names=output_names,
opset_version=11,
)
# ONNX inference
onnx_logits = []
ort_session = onnxruntime.InferenceSession(onnx_path)
for image in images:
ort_inputs = {ort_session.get_inputs()[0].name: image[None, ...]}
outputs = ort_session.run(None, ort_inputs)
onnx_logits.append(outputs[0][0])
onnx_logits = np.array(onnx_logits)
# Convert ONNX to TensorRT
onnx_to_trt(onnx_path, trt_engine_path)
# TensorRT inference
engine = deserialize_engine(trt_engine_path)
context = engine.create_execution_context()
inputs, outputs, bindings, stream = allocate_buffers(engine, 1)
trt_logits = []
for image in images:
np.copyto(inputs[0].host, image.ravel())
for inp in inputs:
cuda.memcpy_htod_async(inp.device, inp.host, stream)
context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
for out in outputs:
cuda.memcpy_dtoh_async(out.host, out.device, stream)
stream.synchronize()
result = {out.name: out.host.reshape(out.binding_shape) for out in outputs}
trt_logits.append(result["output_0"][0])
trt_logits = np.array(trt_logits)
# Compare results
print("tensorrt - torch diff: ", np.max(np.abs(torch_logits_np - trt_logits)))
print("tensorrt - onnx diff: ", np.max(np.abs(trt_logits - onnx_logits)))
print("onnx - torch diff: ", np.max(np.abs(onnx_logits - torch_logits_np)))
check_decimals = 3
np.testing.assert_almost_equal(onnx_logits, torch_logits_np, decimal=check_decimals)
np.testing.assert_almost_equal(trt_logits, torch_logits_np, decimal=check_decimals)
np.testing.assert_almost_equal(trt_logits, onnx_logits, decimal=check_decimals)
Just run the Python code inside the Docker container and you will see that the assertions comparing the TensorRT outputs with the PyTorch outputs fail.