TensorRT with ONNX model and RGB opencv data

Description

According to the following post:

Custom trained SSD inception model in tensorRT c++ version - #14 by AastaLLL

It should be possible to convert OpenCV Mat images so that the TensorRT engine can work with them. However, my ONNX models do not produce the same output in TensorRT as in onnxruntime (only the shape is relevant; the color differences are due to different plot methods):
onnxruntime: [image: onnx_runtime output]
TensorRT: [image: tensorRT output]

Environment

TensorRT Version: 7.1.3-1
GPU Type: Jetson Nano
Runtime: nvcr.io/nvidia/l4t-base:r32.2

Steps To Reproduce

Here is the code I used to process the input images for TensorRT:

// Prepare input according to:
// - https://forums.developer.nvidia.com/t/custom-trained-ssd-inception-model-in-tensorrt-c-version/143048/14
// Host buffers of the two engine inputs, looked up by tensor name.
float* image_1 = static_cast<float*>(buffers.getHostBuffer("image_1:0"));
float* image_2 = static_cast<float*>(buffers.getHostBuffer("image_2:0"));
cv::Vec3b bgr;
unsigned i, j, k, volImg, volChl;
volChl = inputH * inputW;  // floats per channel plane
volImg = 3 * volChl;       // floats per image (3 channel planes)
for (i = 0; i < batchSize; i++)
{
    for (j = 0; j < inputH; j++)
    {
        for (k = 0; k < inputW; k++)
        {
            // OpenCV pixels are interleaved BGR; write them as planar RGB (CHW).
            bgr = prevImage.at<cv::Vec3b>(j, k);
            image_1[i * volImg + 0 * volChl + j * inputW + k] = float(bgr[2]);  // R
            image_1[i * volImg + 1 * volChl + j * inputW + k] = float(bgr[1]);  // G
            image_1[i * volImg + 2 * volChl + j * inputW + k] = float(bgr[0]);  // B

            bgr = currImage.at<cv::Vec3b>(j, k);
            image_2[i * volImg + 0 * volChl + j * inputW + k] = float(bgr[2]);  // R
            image_2[i * volImg + 1 * volChl + j * inputW + k] = float(bgr[1]);  // G
            image_2[i * volImg + 2 * volChl + j * inputW + k] = float(bgr[0]);  // B
        }
    }
}
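The loop above writes each image as planar CHW in RGB order. The Python reference, by contrast, feeds cv2.imread's BGR pixels unchanged in interleaved NHWC shape (1, size, size, 3), so the two programs may not be handing the model the same tensor. Below is a minimal sketch of a packing helper that can produce either layout. packImage is a hypothetical name (not a TensorRT or OpenCV API), and which layout the engine actually expects is an assumption to verify against the model's input shape, for example session.get_inputs()[0].shape in onnxruntime. The image-to-input pairing may also be worth checking: the Python script feeds curr (1.jpg) to its first input, while the C++ loop writes prevImage to image_1:0.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: copies one interleaved 8-bit BGR image of size H x W
// (rows stored back to back, as in a continuous CV_8UC3 cv::Mat) into a
// float input buffer.
//   nhwc == true : interleaved HWC, BGR order kept (what the Python script
//                  feeds to onnxruntime as shape (1, H, W, 3))
//   nhwc == false: planar CHW with BGR -> RGB swap (like the C++ loop above)
void packImage(const std::uint8_t* bgr, std::size_t H, std::size_t W,
               bool nhwc, float* dst)
{
    assert(bgr != nullptr && dst != nullptr);
    const std::size_t volChl = H * W;  // elements per channel plane
    for (std::size_t y = 0; y < H; ++y)
    {
        for (std::size_t x = 0; x < W; ++x)
        {
            const std::size_t pix = y * W + x;  // pixel index within the image
            for (std::size_t c = 0; c < 3; ++c)  // c: 0 = B, 1 = G, 2 = R
            {
                const float v = static_cast<float>(bgr[pix * 3 + c]);
                if (nhwc)
                    dst[pix * 3 + c] = v;             // HWC, channel order unchanged
                else
                    dst[(2 - c) * volChl + pix] = v;  // CHW, plane 0 = R
            }
        }
    }
}
```

For a batch, each image starts 3 * inputH * inputW floats after the previous one, e.g. packImage(img.data, inputH, inputW, true, image_1 + i * 3 * inputH * inputW) for a cv::Mat img whose isContinuous() returns true.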

And here is the Python code that produces the correct output with the same ONNX model:

import numpy as np
import onnxruntime
import cv2

input_1 = 'target/data/1.jpg'
input_2 = 'target/data/2.jpg'
size = 128
session = onnxruntime.InferenceSession('m1.onnx', None)
input_name_1 = session.get_inputs()[0].name
input_name_2 = session.get_inputs()[1].name

output_name = session.get_outputs()[0].name
print(input_name_1)
print(input_name_2)
print(output_name)

# cv2.imread returns uint8 arrays of shape (H, W, 3) in BGR channel order
prev = cv2.imread(input_2)
curr = cv2.imread(input_1)

curr = cv2.resize(curr, dsize=(size, size), interpolation=cv2.INTER_AREA)
prev = cv2.resize(prev, dsize=(size, size), interpolation=cv2.INTER_AREA)

# Add a batch axis -> NHWC (1, size, size, 3); channels stay in BGR order
curr = curr[np.newaxis].astype(np.float32)
prev = prev[np.newaxis].astype(np.float32)

result = session.run([output_name], {input_name_1: curr, input_name_2: prev})
mask = result[0].reshape(size, size)

Any information about how to correctly put data into the TensorRT engine buffers would be appreciated!

Hi,
Please share the ONNX model and the script, if not shared already, so that we can assist you better.
Alongside, you can try a few things:

  1. Validate your model with the snippet below:

check_model.py

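import sys
import onnx

filename = sys.argv[1]  # path to your ONNX model
model = onnx.load(filename)
onnx.checker.check_model(model)  # raises onnx.checker.ValidationError if the model is invalid
print("The model is valid.")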
  2. Try running your model with the trtexec command:
     https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/trtexec

If you are still facing the issue, please share the trtexec --verbose log for further debugging.
Thanks!

Hi @raphael_zingg,

We request you to share the inference script for the generated TensorRT engine and the ONNX model. Please also share sample images that reproduce the issue.

We would like to reproduce the issue from our end for better assistance.

Thank you.

Hi @spolisetty and @NVES
I have prepared steps to reproduce the issue and attached all files and models here:

I use test_onny.py to compare the Keras and ONNX Runtime results:

onnx: [[('n01491361', 'tiger_shark', 0.79704237), ('n01494475', 'hammerhead', 0.047171276), ('n01484850', 'great_white_shark', 0.031171514), ('n02640242', 'sturgeon', 0.002908883), ('n02071294', 'killer_whale', 0.0022062678)]]
keras: [[('n01491361', 'tiger_shark', 0.79704267), ('n01494475', 'hammerhead', 0.047171205), ('n01484850', 'great_white_shark', 0.031171512), ('n02640242', 'sturgeon', 0.0029088883), ('n02071294', 'killer_whale', 0.002206256)]]

I convert the model on my Jetson Nano with:
sesd@nano:/usr/src/tensorrt/bin$ sudo ./trtexec --onnx=/home/sesd/target/model.onnx --explicitBatch --verbose --workspace=1024 --exportTimes=trace.json >> /home/sesd/target/trt_out.txt --saveEngine=/home/sesd/target/model.trt

It runs without any errors (see trt_out.txt). The TensorRT version is 7.1.3-1+cuda10.2.
Running the TensorRT engine in Python with shark_image_net.py produces the expected output:

trt: [[('n01491361', 'tiger_shark', 0.79704326), ('n01494475', 'hammerhead', 0.047170877), ('n01484850', 'great_white_shark', 0.031171363), ('n02640242', 'sturgeon', 0.0029088575), ('n02071294', 'killer_whale', 0.0022062561)]]

However, the final result of the C++ application, shark_image_net.cpp, is off:

[04/19/2021-21:32:35] [I] idx:2 prob:0.0740943
[04/19/2021-21:32:35] [I] idx:3 prob:0.763998
[04/19/2021-21:32:35] [I] idx:4 prob:0.0382723
[04/19/2021-21:32:35] [I] idx:6 prob:0.0198122

The full index-to-label mapping can be found here; the relevant entries are:

 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
 3: 'tiger shark, Galeocerdo cuvieri',
 4: 'hammerhead, hammerhead shark',
 6: 'stingray',
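The onnxruntime and Python TensorRT results are top-5 lists, while shark_image_net.cpp prints idx/prob pairs. To compare them like for like, the same top-k list can be extracted from the engine's output on the C++ side. A minimal sketch, assuming the output is a flat float vector of class probabilities; topK is a hypothetical helper, not part of TensorRT or of the application above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper: returns the k highest-scoring (class index, probability)
// pairs of a flat softmax output, sorted by descending probability.
std::vector<std::pair<std::size_t, float>> topK(const std::vector<float>& probs, std::size_t k)
{
    assert(k <= probs.size());
    std::vector<std::size_t> idx(probs.size());
    std::iota(idx.begin(), idx.end(), std::size_t{0});  // 0, 1, ..., N-1
    // Only the first k positions need to be ordered.
    std::partial_sort(idx.begin(), idx.begin() + static_cast<std::ptrdiff_t>(k), idx.end(),
                      [&probs](std::size_t a, std::size_t b) { return probs[a] > probs[b]; });
    std::vector<std::pair<std::size_t, float>> result;
    result.reserve(k);
    for (std::size_t i = 0; i < k; ++i)
        result.emplace_back(idx[i], probs[idx[i]]);
    return result;
}
```

std::partial_sort only orders the first k indices, which is cheaper than sorting all classes when k is small.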

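The top-1 prediction (tiger shark) is correct, but its probability is slightly off (0.764 instead of 0.797), and some of the other top-5 predictions are wrong: great white shark now ranks above hammerhead, and stingray appears instead of sturgeon and killer whale. Is this expected behaviour?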
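Whether a difference is "expected" can be made concrete by comparing outputs numerically. In the numbers above, Keras, ONNX Runtime and the Python TensorRT engine agree to within about 1e-6, whereas the C++ result differs by up to about 0.04. A minimal sketch of such a check follows; maxAbsDiff, outputsMatch and the 1e-4 tolerance are assumptions chosen for illustration, not values defined by TensorRT.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: largest element-wise absolute difference between two
// output vectors of equal length (e.g. onnxruntime vs. TensorRT probabilities).
float maxAbsDiff(const std::vector<float>& a, const std::vector<float>& b)
{
    assert(a.size() == b.size());
    float worst = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i)
        worst = std::max(worst, std::fabs(a[i] - b[i]));
    return worst;
}

// Treats two outputs as matching if no element differs by more than `tol`
// (the default tolerance is an illustrative assumption).
bool outputsMatch(const std::vector<float>& a, const std::vector<float>& b, float tol = 1e-4f)
{
    return a.size() == b.size() && maxAbsDiff(a, b) <= tol;
}
```

Applied to the tiger shark, hammerhead and great white shark probabilities from the outputs above (in that order), the check passes for ONNX Runtime vs. Python TensorRT and fails for ONNX Runtime vs. the C++ application.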

Again, I think something is wrong with how I process the input. If I use another model, which is based on this example model but is more sensitive to output precision and has two inputs, the output differs even more (see my initial post). Another possible cause is a problem with the TensorRT EfficientNet architecture; see: Output from ONNX inference and trt inference are different · Issue #1194 · NVIDIA/TensorRT · GitHub


Hi @raphael_zingg,

We could reproduce the issue. As you mentioned, this is the same issue as:
https://github.com/NVIDIA/TensorRT/issues/1194

This is a known bug related to the Resize layer. It will be fixed in a future release.

Thank you.


@spolisetty Thanks for the confirmation of the issue! Any idea when the future release is coming?

Hi @raphael_zingg,

We will update you once we have any news regarding this. Meanwhile, please check for updates on the NVIDIA or TensorRT forums.

Thanks