ONNX -> TensorRT FP32 conversion: accuracy degradation, different outputs

Description

Hi! I am trying to export a trained segmentation model from PyTorch to TensorRT. I first convert the PyTorch model to ONNX and then convert the ONNX model to TensorRT. All of this is done in Python.

The problem is that the exported TensorRT model produces slightly (but significantly) different results than its PyTorch and ONNX counterparts. Because of this, the accuracy on my dataset drops by more than 10%. Note that I am using float32, so I wouldn’t expect such a degradation.

Below you’ll find the environment needed to run everything, as well as the Python code with the PyTorch model (and its weights) and the conversion process. At the end of the script, the outputs of PyTorch, ONNX and TensorRT are compared.

Environment

TensorRT Version: 8.2.1
GPU Type: NVIDIA GeForce RTX 3070 Laptop GPU
Nvidia Driver Version: 515.48.07
CUDA Version: 11.5
CUDNN Version: 8
Operating System + Version: Ubuntu 20.04.3 LTS (Focal Fossa)
Python Version (if applicable): 3.8.10
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.8.0
Baremetal or Container (if container which image + tag): Docker container nvcr.io/nvidia/tensorrt:21.12-py3

Relevant Files

You can find the code and the Dockerfile used below.
test_trt.py (15.7 KB)

Steps To Reproduce

I am using the following Dockerfile:

FROM nvcr.io/nvidia/tensorrt:21.12-py3
RUN pip install https://download.pytorch.org/whl/cu111/torch-1.8.1%2Bcu111-cp38-cp38-linux_x86_64.whl onnx onnxruntime

And the Python code:

import os
import subprocess
from typing import Tuple

import numpy as np
import onnxruntime
import pycuda.autoinit  # noqa: F401  (imported for its side effect: initializes CUDA and creates a context)
import pycuda.driver as cuda
import tensorrt as trt
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBNReLU(nn.Module):
    def __init__(
        self,
        in_chan,
        out_chan,
        ks=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
        bias=False,
    ):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_chan,
            out_chan,
            kernel_size=ks,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        feat = self.conv(x)
        feat = self.bn(feat)
        feat = self.relu(feat)
        return feat


class DetailBranch(nn.Module):
    def __init__(self):
        super(DetailBranch, self).__init__()
        self.S1 = nn.Sequential(
            ConvBNReLU(3, 64, 3, stride=2),
            ConvBNReLU(64, 64, 3, stride=1),
        )
        self.S2 = nn.Sequential(
            ConvBNReLU(64, 64, 3, stride=2),
            ConvBNReLU(64, 64, 3, stride=1),
            ConvBNReLU(64, 64, 3, stride=1),
        )
        self.S3 = nn.Sequential(
            ConvBNReLU(64, 128, 3, stride=2),
            ConvBNReLU(128, 128, 3, stride=1),
            ConvBNReLU(128, 128, 3, stride=1),
        )

    def forward(self, x):
        feat = self.S1(x)
        feat = self.S2(feat)
        feat = self.S3(feat)
        return feat


class StemBlock(nn.Module):
    def __init__(self):
        super(StemBlock, self).__init__()
        self.conv = ConvBNReLU(3, 16, 3, stride=2)
        self.left = nn.Sequential(
            ConvBNReLU(16, 8, 1, stride=1, padding=0),
            ConvBNReLU(8, 16, 3, stride=2),
        )
        self.right = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
        self.fuse = ConvBNReLU(32, 16, 3, stride=1)

    def forward(self, x):
        feat = self.conv(x)
        feat_left = self.left(feat)
        feat_right = self.right(feat)
        feat = torch.cat([feat_left, feat_right], dim=1)
        feat = self.fuse(feat)
        return feat


class CEBlock(nn.Module):
    def __init__(self):
        super(CEBlock, self).__init__()
        self.bn = nn.BatchNorm2d(128)
        self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
        # TODO: in paper here is naive conv2d, no bn-relu
        self.conv_last = ConvBNReLU(128, 128, 3, stride=1)

    def forward(self, x):
        feat = torch.mean(x, dim=(2, 3), keepdim=True)
        feat = self.bn(feat)
        feat = self.conv_gap(feat)
        feat = feat + x
        feat = self.conv_last(feat)
        return feat


class GELayerS1(nn.Module):
    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super(GELayerS1, self).__init__()
        mid_chan = in_chan * exp_ratio
        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        self.dwconv = nn.Sequential(
            nn.Conv2d(
                in_chan,
                mid_chan,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=in_chan,
                bias=False,
            ),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU(inplace=True),  # not shown in paper
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, out_chan, kernel_size=1, stride=1, padding=0, bias=False
            ),
            nn.BatchNorm2d(out_chan),
        )
        self.conv2[1].last_bn = True
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.dwconv(feat)
        feat = self.conv2(feat)
        feat = feat + x
        feat = self.relu(feat)
        return feat


class GELayerS2(nn.Module):
    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super(GELayerS2, self).__init__()
        mid_chan = in_chan * exp_ratio
        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        self.dwconv1 = nn.Sequential(
            nn.Conv2d(
                in_chan,
                mid_chan,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=in_chan,
                bias=False,
            ),
            nn.BatchNorm2d(mid_chan),
        )
        self.dwconv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan,
                mid_chan,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=mid_chan,
                bias=False,
            ),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU(inplace=True),  # not shown in paper
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, out_chan, kernel_size=1, stride=1, padding=0, bias=False
            ),
            nn.BatchNorm2d(out_chan),
        )
        self.conv2[1].last_bn = True
        self.shortcut = nn.Sequential(
            nn.Conv2d(
                in_chan,
                in_chan,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=in_chan,
                bias=False,
            ),
            nn.BatchNorm2d(in_chan),
            nn.Conv2d(
                in_chan, out_chan, kernel_size=1, stride=1, padding=0, bias=False
            ),
            nn.BatchNorm2d(out_chan),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.dwconv1(feat)
        feat = self.dwconv2(feat)
        feat = self.conv2(feat)
        shortcut = self.shortcut(x)
        feat = feat + shortcut
        feat = self.relu(feat)
        return feat


class SegmentBranch(nn.Module):
    def __init__(self):
        super(SegmentBranch, self).__init__()
        self.S1S2 = StemBlock()
        self.S3 = nn.Sequential(
            GELayerS2(16, 32),
            GELayerS1(32, 32),
        )
        self.S4 = nn.Sequential(
            GELayerS2(32, 64),
            GELayerS1(64, 64),
        )
        self.S5_4 = nn.Sequential(
            GELayerS2(64, 128),
            GELayerS1(128, 128),
            GELayerS1(128, 128),
            GELayerS1(128, 128),
        )
        self.S5_5 = CEBlock()

    def forward(self, x):
        feat2 = self.S1S2(x)
        feat3 = self.S3(feat2)
        feat4 = self.S4(feat3)
        feat5_4 = self.S5_4(feat4)
        feat5_5 = self.S5_5(feat5_4)
        return feat2, feat3, feat4, feat5_4, feat5_5


class BGALayer(nn.Module):
    def __init__(self):
        super(BGALayer, self).__init__()
        self.left1 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1, padding=1, groups=128, bias=False
            ),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0, bias=False),
        )
        self.left2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False),
        )
        self.right1 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
        )
        self.right2 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1, padding=1, groups=128, bias=False
            ),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0, bias=False),
        )
        # TODO: does this really have no ReLU?
        self.conv = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),  # not shown in paper
        )

    def forward(self, x_d, x_s):
        dsize = x_d.size()[2:]
        left1 = self.left1(x_d)
        left2 = self.left2(x_d)
        right1 = self.right1(x_s)
        right2 = self.right2(x_s)
        right1 = F.interpolate(right1, size=dsize, mode="bilinear", align_corners=True)
        left = left1 * torch.sigmoid(right1)
        right = left2 * torch.sigmoid(right2)
        right = F.interpolate(right, size=dsize, mode="bilinear", align_corners=True)
        out = self.conv(left + right)
        return out


class SegmentHead(nn.Module):
    def __init__(self, in_chan, mid_chan, num_classes):
        super(SegmentHead, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
        self.drop = nn.Dropout(0.1)
        self.conv_out = nn.Conv2d(
            mid_chan, num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )

    def forward(self, x, size=None):
        feat = self.conv(x)
        feat = self.drop(feat)
        feat = self.conv_out(feat)
        if size is not None:
            feat = F.interpolate(feat, size=size, mode="bilinear", align_corners=True)
        return feat


class MyModel(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.detail = DetailBranch()
        self.segment = SegmentBranch()
        self.bga = BGALayer()

        # Create architecture with cityscapes classes
        self.define_heads_layers(num_classes)

        # check if we have the model already locally
        if not os.path.exists("my_model.pth"):
            # download pretrained model
            subprocess.call(
                "wget https://storage.googleapis.com/public-dl-models/my_model.pth",
                shell=True,
            )
        # Load all layers with pretrained weights
        self.load_state_dict(
            torch.load("my_model.pth", map_location="cpu"), strict=False
        )

    def define_heads_layers(self, num_classes):
        self.head = SegmentHead(128, 1024, num_classes)

    def forward(self, x):
        size = x.size()[2:]
        feat_d = self.detail(x)
        feat2, feat3, feat4, feat5_4, feat_s = self.segment(x)
        feat_head = self.bga(feat_d, feat_s)

        logits = self.head(feat_head, size)
        return logits


def onnx_to_trt(onnx_path: str, trt_path: str):
    logger = trt.Logger(trt.Logger.INFO)
    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(logger) as builder, builder.create_network(
        EXPLICIT_BATCH
    ) as network, builder.create_builder_config() as config, trt.OnnxParser(
        network, logger
    ) as parser:
        # Parse model file
        with open(onnx_path, "rb") as fr:
            if not parser.parse(fr.read()):
                errors = []
                for error in range(parser.num_errors):
                    errors.append(parser.get_error(error))
                raise RuntimeError(f"Error parsing ONNX model: {errors}")

        # build settings
        builder.max_batch_size = 1
        config.max_workspace_size = 1 << 32  # 4GB

        plan = builder.build_serialized_network(network, config)
        if plan is None:
            raise RuntimeError("Error building TensorRT engine")

        # Write the serialized plan to disk
        with open(trt_path, "wb") as fw:
            fw.write(plan)


class HostDeviceMem(object):
    def __init__(
        self, host_mem, device_mem, binding_shape: Tuple[int, int, int, int], name: str
    ):
        self.host = host_mem
        self.device = device_mem
        self.binding_shape = binding_shape
        self.name = name

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()


def trt_dtype_to_np(dtype: trt.DataType) -> np.dtype:
    if dtype == trt.DataType.FLOAT:
        return np.float32
    elif dtype == trt.DataType.INT32:
        return np.int32
    elif dtype == trt.DataType.INT8:
        return np.int8
    elif dtype == trt.DataType.HALF:
        return np.float16
    else:
        raise NotImplementedError(f"{dtype} conversion to numpy not implemented")


def allocate_buffers(engine, max_batch_size):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()

    for binding in engine:
        binding_shape = engine.get_binding_shape(binding)
        dtype = engine.get_binding_dtype(binding)
        index = engine.get_binding_index(binding)
        name = engine.get_binding_name(index)

        np_dtype = trt_dtype_to_np(dtype)
        size = trt.volume(binding_shape) * max_batch_size
        host_mem = cuda.pagelocked_empty(size, np_dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)

        bindings.append(int(device_mem))

        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem, binding_shape, name))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem, binding_shape, name))

    return inputs, outputs, bindings, stream


def deserialize_engine(engine_path: str) -> trt.ICudaEngine:
    logger = trt.Logger(trt.Logger.INFO)
    trt.init_libnvinfer_plugins(None, "")
    with open(engine_path, "rb") as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    if engine is None:
        raise RuntimeError("Error deserializing TensorRT engine")
    return engine


if __name__ == "__main__":

    model = MyModel(num_classes=19)

    num_images = 8
    onnx_path = "my_model.onnx"
    trt_engine_path = "my_model.trt"

    # set seeds
    torch.manual_seed(42)
    np.random.seed(42)
    images = (np.random.randint(0, 255, size=(num_images, 3, 360, 640)) / 255.0).astype(
        np.float32
    )

    torch_images = torch.from_numpy(images)
    model.eval()
    torch_logits = model(torch_images)
    torch_logits_np = torch_logits.detach().numpy()

    # Export to ONNX
    input_names = ["input_0"]
    output_names = ["output_0"]
    torch.onnx.export(
        model,
        tuple([torch_images[0][None, ...]]),  # batch size = 1
        onnx_path,
        do_constant_folding=True,
        verbose=False,
        input_names=input_names,
        output_names=output_names,
        opset_version=11,
    )

    # ONNX inference
    onnx_logits = []
    ort_session = onnxruntime.InferenceSession(onnx_path)
    for image in images:
        ort_inputs = {ort_session.get_inputs()[0].name: image[None, ...]}
        outputs = ort_session.run(None, ort_inputs)
        onnx_logits.append(outputs[0][0])
    onnx_logits = np.array(onnx_logits)

    # Convert ONNX to TensorRT
    onnx_to_trt(onnx_path, trt_engine_path)

    # TensorRT inference
    engine = deserialize_engine(trt_engine_path)
    context = engine.create_execution_context()
    inputs, outputs, bindings, stream = allocate_buffers(engine, 1)

    trt_logits = []
    for image in images:
        np.copyto(inputs[0].host, image.ravel())
        for inp in inputs:
            cuda.memcpy_htod_async(inp.device, inp.host, stream)
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        for out in outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, stream)
        stream.synchronize()
        result = {out.name: out.host.reshape(out.binding_shape) for out in outputs}
        trt_logits.append(result["output_0"][0])
    trt_logits = np.array(trt_logits)

    # Compare results
    print("tensorrt - torch diff: ", np.max(np.abs(torch_logits_np - trt_logits)))
    print("tensorrt - onnx diff: ", np.max(np.abs(trt_logits - onnx_logits)))
    print("onnx - torch diff: ", np.max(np.abs(onnx_logits - torch_logits_np)))

    check_decimals = 3
    np.testing.assert_almost_equal(onnx_logits, torch_logits_np, decimal=check_decimals)
    np.testing.assert_almost_equal(trt_logits, torch_logits_np, decimal=check_decimals)
    np.testing.assert_almost_equal(trt_logits, onnx_logits, decimal=check_decimals)

Just run the Python code inside the Docker container and you’ll see that the assertions between the TensorRT outputs and the PyTorch outputs fail.
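
To narrow down where the outputs diverge, a per-image breakdown can be more informative than a single global maximum. The following is a minimal sketch using NumPy only; `per_image_max_abs_diff` is a hypothetical helper (not part of the script above), and the arrays are synthetic stand-ins for stacks of logits with shape (N, C, H, W).

```python
import numpy as np


def per_image_max_abs_diff(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical helper: max absolute difference per entry along axis 0."""
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    # Flatten everything except the leading (image) axis, then reduce.
    return np.abs(a - b).reshape(a.shape[0], -1).max(axis=1)


# Synthetic stand-ins for two stacks of logits, shape (N, C, H, W).
rng = np.random.default_rng(0)
reference = rng.standard_normal((4, 2, 3, 3)).astype(np.float32)
candidate = reference.copy()
candidate[1] += 0.5  # perturb only image 1

diffs = per_image_max_abs_diff(reference, candidate)
for index, diff in enumerate(diffs):
    print(f"image {index}: max abs diff = {diff:.6f}")
```

Applied to the script's arrays, e.g. `per_image_max_abs_diff(torch_logits_np, trt_logits)`, this would show whether all images differ or only some of them.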

Hi,

We recommend that you try the latest TensorRT version, 8.5.1, and let us know if you still face this issue.

Thank you.

I cannot easily change the TensorRT version where it is deployed. However, I just tried it with version 8.5 and I see the same behavior…

I managed to solve the problem. The problem was in the inference code, which was heavily inspired by the sample code in TensorRT/samples/python at main · NVIDIA/TensorRT · GitHub. More specifically, it was based on TensorRT/common.py at main · NVIDIA/TensorRT · GitHub and on the class abstraction for easy usage in TensorRT/infer.py at main · NVIDIA/TensorRT · GitHub.
In all that Python code (and in mine shared in this thread), we first allocate the buffers for inference and then run a loop calling the predict API. The only way I get correct results is to allocate a fresh output buffer each time the model is called. The code would look like this:

    for image in images:
        np.copyto(inputs[0].host, image.ravel())
        for inp in inputs:
            cuda.memcpy_htod_async(inp.device, inp.host, stream)
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        for out in outputs:
            # ALLOCATE A NEW FRESH OUTPUT BUFFER
            out.host = np.zeros(out.binding_shape, np.float32)
            cuda.memcpy_dtoh_async(out.host, out.device, stream)
        stream.synchronize()
        result = {out.name: out.host.reshape(out.binding_shape) for out in outputs}
        trt_logits.append(result["output_0"][0])
    trt_logits = np.array(trt_logits)

which is weird, because none of the sample code does this. Any clue what might be happening? I want to reproduce this with sample code from your repo and create an issue if that turns out to be the case.
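
One possible explanation, stated here as an assumption rather than something confirmed in this thread: `out.host.reshape(out.binding_shape)` returns a view of the reused host buffer, not a copy, so every element appended to `trt_logits` refers to the same memory, and each later device-to-host copy overwrites the earlier results. The sketch below reproduces that pattern with NumPy only; `fake_inference` is a hypothetical stand-in for the device-to-host copy, and no TensorRT or CUDA calls are involved.

```python
import numpy as np


def fake_inference(host_buffer: np.ndarray, value: float) -> None:
    """Hypothetical stand-in for memcpy_dtoh_async + synchronize:
    writes the "result" into the existing host buffer in place."""
    host_buffer[:] = value


shape = (1, 2, 3)
# Stand-in for the single host buffer allocated once before the loop.
host = np.empty(6, dtype=np.float32)

aliased = []  # views into the shared buffer, as in the original loop
copied = []   # independent snapshots of each result

for value in (1.0, 2.0, 3.0):
    fake_inference(host, value)
    view = host.reshape(shape)[0]  # a view: shares memory with `host`
    aliased.append(view)
    copied.append(view.copy())

aliased = np.array(aliased)
copied = np.array(copied)

print(aliased[:, 0, 0])  # [3. 3. 3.]  every entry shows the last result
print(copied[:, 0, 0])   # [1. 2. 3.]  each entry keeps its own result
```

If this is the cause, copying the result before appending it (for example `result["output_0"][0].copy()`) should have the same effect as allocating a fresh output buffer on every iteration.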

Hi,

Sorry for the delayed response. We recommend that you reach out via Issues · NVIDIA/TensorRT · GitHub to get better help on this.

Thank you.