TensorRT Batch Inference: different results

Description

Hi,

I am using YOLOv4 detection models for my project. I train custom models with AlexeyAB’s darknet fork, then use Tianxiaomo’s pytorch-YOLOv4 to parse the darknet weights into PyTorch and export the model to ONNX with torch.onnx.export.
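For reference, the export step looks roughly like this (a sketch, not the exact code from Tianxiaomo’s repo; the output names are assumptions):

import torch

# `model` is assumed to be the YOLOv4-tiny network already loaded from the
# parsed darknet weights. Batch size 2 is baked into the dummy input, so the
# exported graph gets a static (2, 3, 416, 416) input shape.
model.eval()
dummy_input = torch.randn(2, 3, 416, 416)
torch.onnx.export(model, dummy_input,
                  'yolov4tiny_2_3_416_416_static.onnx',
                  input_names=['input'],
                  output_names=['boxes', 'confs'],
                  opset_version=11)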

The issue is that inference with the TensorRT model works fine for batch size 1, but for batch size > 1 the inference results differ depending on the TensorRT conversion method.

Method #1: Conversion using trtexec

trtexec --onnx=onnx_models/yolov4tiny_2_3_416_416_static.onnx --maxBatch=2 --fp16  --saveEngine=trt_models/trtexec/yolov4tiny2-416.trt

For testing, I created a NumPy batch (batch size 2) by duplicating a single image. However, the detection outputs for the first and second frame are not the same despite being the same image: 5 objects are detected in the first frame but 11 in the second.
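The duplicate-image batch was built along these lines (a minimal sketch; the exact preprocessing is an assumption):

import cv2
import numpy as np

img = cv2.imread('TownCenter-008.jpg')
img = cv2.resize(img, (416, 416))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
img = img.transpose(2, 0, 1)                         # HWC -> CHW
batch = np.ascontiguousarray(np.stack([img, img]))   # (2, 3, 416, 416), identical frames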

Method #2: Conversion using TensorRT Python API

import traceback

import tensorrt as trt

EXPLICIT_BATCH = []
if trt.__version__[0] >= '7':
    EXPLICIT_BATCH.append(
        1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
print(EXPLICIT_BATCH)

def build_engine(onnx_file_path, engine_file_path, verbose=False, batch_size=1):
    """Takes an ONNX file and creates a TensorRT engine."""
    TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger()
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(*EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:

        if trt.__version__[0] >= '8':
            config = builder.create_builder_config()
            config.max_workspace_size = 1 << 28
            builder.max_batch_size = batch_size  # ignored by explicit-batch engines
            config.set_flag(trt.BuilderFlag.FP16)
            # config.set_flag(trt.BuilderFlag.STRICT_TYPES)
        else:
            builder.max_workspace_size = 1 << 28
            builder.max_batch_size = batch_size
            builder.fp16_mode = True
            # builder.strict_type_constraints = True

        # Parse model file
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            if not parser.parse(model.read()):
                print('ERROR: Failed to parse the ONNX file.')
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None
        if trt.__version__[0] >= '7':
            # The ONNX model may have been exported with a different batch
            # size; overwrite the input batch dimension with the requested one.
            shape = list(network.get_input(0).shape)
            print(shape)
            shape[0] = batch_size
            network.get_input(0).shape = shape
        print('Completed parsing of ONNX file')

        print('Building an engine; this may take a while...')
        if trt.__version__[0] >= '8':
            engine = builder.build_engine(network, config)
        else:
            engine = builder.build_cuda_engine(network)
        print('Completed creating engine')
        if engine is None:
            return None
        try:
            with open(engine_file_path, 'wb') as f:
                f.write(engine.serialize())
            return engine
        except Exception:
            traceback.print_exc()

The TensorRT model converted via the Python API produces different results from the trtexec one: it produces 11 detections for the first image in the batch (batch size 2) but empty results for the second, similar to this issue by @thomallain. The detection outputs only cover the first frame; all detection arrays for the second frame are zeros.
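To make the comparison concrete, I check the two frames of each output like this (a small sketch; the output is assumed to carry the batch on axis 0):

import numpy as np

def compare_batch_outputs(output, batch_size=2):
    """Reshape a flat engine output to (batch, -1) and compare the frames."""
    per_frame = np.asarray(output).reshape(batch_size, -1)
    print('frame0 == frame1 :', np.allclose(per_frame[0], per_frame[1]))
    print('frame1 all zeros :', not np.any(per_frame[1]))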

Environment

TensorRT Version: 8.0.0.3
GPU Type: GTX 1070
Nvidia Driver Version: 465.31
CUDA Version: 8.0.0.3
CUDNN Version: 8.2.2
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.6
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.5
Baremetal or Container (if container which image + tag): Baremetal

Relevant Files

File Structure
.
├── onnx_models
│   ├── yolov4_-1_3_416_416_dynamic.onnx
│   ├── yolov4_1_3_416_416_static.onnx
│   └── yolov4_2_3_416_416_static.onnx
├── TownCenter-008.jpg
├── trt_convert.py
├── trt_models
│   ├── pythonapi
│   │   └── yolov4tiny2-416.trt
│   └── trtexec
│       └── yolov4tiny2-416.trt
└── yolov4-tiny.py

Relevant scripts, ONNX models, and converted TensorRT models can be downloaded via this Google Drive link: issue - Google Drive

Steps To Reproduce

  • Conversion via trtexec can be done with the aforementioned command.
  • Conversion with the Python API can be done with trt_convert.py by passing the desired ONNX model as a parameter.
  • Inference can be done with yolov4-tiny.py by passing the desired converted TensorRT model as a parameter:
python yolov4-tiny.py --trt trt_models/trtexec/yolov4tiny2-416.trt --conf_threshold 0.3 --nms_threshold 0.4 --num_classes 1 --batch_size 2
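For context, the inference script allocates its buffers following the usual explicit-batch pattern, roughly like this (a sketch assuming pycuda; the actual script is in the Drive link):

import pycuda.autoinit  # noqa: F401  (creates a CUDA context)
import pycuda.driver as cuda
import tensorrt as trt

def allocate_buffers(engine):
    """Allocate host/device buffers sized for the full batch."""
    inputs, outputs, bindings = [], [], []
    stream = cuda.Stream()
    for binding in engine:
        # With explicit batch, the binding shape already includes the batch dim.
        size = trt.volume(engine.get_binding_shape(binding))
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        bindings.append(int(device_mem))
        if engine.binding_is_input(binding):
            inputs.append((host_mem, device_mem))
        else:
            outputs.append((host_mem, device_mem))
    return inputs, outputs, bindings, stream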

Could you share some suggestions on how to fix this error so that batch inference runs as expected?

Thanks in advance!

Best Regards,
Htut

Hi,
Please share the ONNX model and the script if you have not already, so that we can assist you better.
Alongside, you can try a few things:

  1. validating your model with the below snippet

check_model.py

import sys

import onnx

model = onnx.load(sys.argv[1])
onnx.checker.check_model(model)
  2. Try running your model with the trtexec command.
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/trtexec
In case you are still facing the issue, please share the trtexec "--verbose" log for further debugging.
Thanks!

@NVES , Hi, Thanks for the reply.

I have already added a link to download all the files required to reproduce this error via Google Drive. I’ve also attached a text file containing the verbose output from converting the model with the trtexec method.

check_model.py also runs fine with the ONNX model attached in the Google Drive link.
verbose.txt (1.3 MB)

@Htut,

Sorry for the delayed response. Are you using dynamic shape input? Since you are working with batches, please make sure you are marking dynamic shapes correctly. It looks like you are also not specifying the input layer name and the dynamic shape ranges. Please refer to the following:

https://github.com/NVIDIA/TensorRT/blob/master/samples/opensource/trtexec/README.md#example-4-running-an-onnx-model-with-full-dimensions-and-dynamic-shapes
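For example, with your dynamic ONNX model the command would look something like this (assuming the input tensor is named "input"; adjust the name and shape ranges to match your model):

trtexec --onnx=onnx_models/yolov4_-1_3_416_416_dynamic.onnx --minShapes=input:1x3x416x416 --optShapes=input:2x3x416x416 --maxShapes=input:4x3x416x416 --fp16 --saveEngine=trt_models/trtexec/yolov4_dynamic-416.trt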

Hi @spolisetty, I am sorry for the very late response. FYI, I specifically used an ONNX model generated with static shape input. After debugging, converting the static-shape batch ONNX model with trtexec (method #1) now works as expected. However, I am still having trouble with method #2: TensorRT conversion of the batch ONNX model using the Python API still shows the same problem.

In my test case, I converted a static-shape (2, 3, 416, 416) batch ONNX model to TensorRT. The conversion itself went fine, but when I run inference with the converted engine, it produces correct results for the first frame and nothing at all for the second frame in the batch. Upgrading to TensorRT 8.0.1.6 did not help either. I do need the Python API conversion of the static-shape batch ONNX model, because I want to do INT8 quantization later, which requires the Python API.

New TensorRT 8 based conversion script in Python API for batch onnx model, generated with static input shape

#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

import os

import tensorrt as trt


def set_net_batch(network, batch_size):
    """Set network input batch size explicitly.

    The ONNX file might have been generated with a different batch size.
    """
    if trt.__version__[0] >= '7':
        shape = list(network.get_input(0).shape)
        shape[0] = batch_size
        network.get_input(0).shape = shape

    return network


def build_engine(onnx_model, trt_engine, input_dims, do_int8, dla_core, batch_size=1, verbose=False):
    """Build a TensorRT engine from an ONNX model.

    onnx_model    : str
                    Path of the ONNX model file
    trt_engine    : str
                    Name of the output TensorRT engine file
    batch_size    : int
                    Explicit batch size to be assigned to the generated engine
    input_dims    : list
                    Input dimensions of the ONNX model, as (Height, Width)
    do_int8       : bool
                    Whether to quantize the model as INT8.
                    If False, the engine is built as FP16 instead.
    dla_core      : int
                    Index of the DLA (Deep Learning Accelerator) core to
                    offload supported layers to; use -1 to build for GPU only.
    verbose       : bool
                    If enabled, a higher verbosity level will be set on the TensorRT logger.
    """

    print("Loading the ONNX file...")

    if not os.path.isfile(onnx_model):
        raise FileNotFoundError("ERROR: file {} not found!  You might want to run yolo_to_onnx.py first to generate it.".format(onnx_model))
    else:
        with open(onnx_model, "rb") as f:
            onnx_data = f.read()

    # network parameters
    net_c = 3
    net_h = int(input_dims[0])
    net_w = int(input_dims[1])

    print("TensorRT Version : {}".format(trt.__version__))

    TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger()
    EXPLICIT_BATCH = [1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)]
    print("Explicit Batch : {}".format(EXPLICIT_BATCH))
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(*EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:

        if do_int8 and not builder.platform_has_fast_int8:
            raise RuntimeError('INT8 not supported on this platform')
        if not parser.parse(onnx_data):
            print('ERROR: Failed to parse the ONNX file.')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return

        network = set_net_batch(network, batch_size)

        print('Naming the input tensor as "input".')
        network.get_input(0).name = 'input'

        print('Building the TensorRT engine.  This may take a while...')
        builder.max_batch_size = batch_size  # ignored by explicit-batch engines
        config = builder.create_builder_config()
        config.max_workspace_size = 1 << 30
        print("Workspace Size : {}".format(1 << 30))
        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
        config.set_flag(trt.BuilderFlag.FP16)
        profile = builder.create_optimization_profile()
        profile.set_shape(
            'input',  # input tensor name
            (batch_size, net_c, net_h, net_w),  # min shape
            (batch_size, net_c, net_h, net_w),  # opt shape
            (batch_size, net_c, net_h, net_w))  # max shape
        config.add_optimization_profile(profile)

        if do_int8:
            pass  # INT8 calibration to be added here (see the calibrator sketch below)

        if dla_core >= 0:
            config.default_device_type = trt.DeviceType.DLA
            config.DLA_core = dla_core
            config.set_flag(trt.BuilderFlag.STRICT_TYPES)
            print('Using DLA core %d.' % dla_core)

        # engine = builder.build_engine(network, config)
        engine = builder.build_serialized_network(network, config)

    if engine is None:
        raise RuntimeError('Failed to build the TensorRT engine.')
    print("Completed building TensorRT engine.")

    with open(trt_engine, 'wb') as engine_file:
        engine_file.write(engine)

    print("Serialized the TensorRT engine to file: {}".format(trt_engine))


if __name__ == "__main__":
    build_engine(onnx_model="yolov4_2_3_416_416_static.onnx",
                 trt_engine="yolov4tiny2-416.trt",
                 input_dims=[416, 416],
                 do_int8=False,
                 dla_core=-1,
                 batch_size=2,
                 verbose=False)
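The do_int8 branch above is still a placeholder. When I get to the INT8 step, the plan is the standard calibrator route: set the INT8 builder flag and attach an IInt8EntropyCalibrator2. A rough sketch of what that branch would use (assuming pycuda and a hypothetical iterable of preprocessed calibration batches):

import os

import numpy as np
import pycuda.autoinit  # noqa: F401  (creates a CUDA context)
import pycuda.driver as cuda
import tensorrt as trt


class EntropyCalibrator(trt.IInt8EntropyCalibrator2):
    """Feeds preprocessed (N, 3, H, W) float32 batches to the builder."""

    def __init__(self, batches, batch_size, cache_file='calib.cache'):
        trt.IInt8EntropyCalibrator2.__init__(self)
        self.batches = iter(batches)
        self.batch_size = batch_size
        self.cache_file = cache_file
        self.device_input = None

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        try:
            batch = np.ascontiguousarray(next(self.batches), dtype=np.float32)
        except StopIteration:
            return None  # no more batches: calibration is done
        if self.device_input is None:
            self.device_input = cuda.mem_alloc(batch.nbytes)
        cuda.memcpy_htod(self.device_input, batch)
        return [int(self.device_input)]

    def read_calibration_cache(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, 'rb') as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, 'wb') as f:
            f.write(cache)


# Inside build_engine(), the do_int8 branch would then become roughly:
#     config.set_flag(trt.BuilderFlag.INT8)
#     config.int8_calibrator = EntropyCalibrator(calib_batches, batch_size)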

Regards,
Htut