TensorRT LPDNet output

Hi,

I am trying to use LPDNet in Python with a TensorRT engine. The following code successfully runs inference, but I cannot interpret the output.

import os
import time

import cv2
#import matplotlib.pyplot as plt
import numpy as np

# Select the GPU before pycuda.autoinit creates the CUDA context;
# setting CUDA_VISIBLE_DEVICES afterwards has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt
import pdb


class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()


def load_engine(trt_runtime, engine_path):
    with open(engine_path, "rb") as f:
        engine_data = f.read()
    engine = trt_runtime.deserialize_cuda_engine(engine_data)
    return engine

# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine, batch_size=1):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    for binding in engine:
        # pdb.set_trace()
        size = trt.volume(engine.get_binding_shape(binding)) * batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
            print(f"input: shape:{engine.get_binding_shape(binding)} dtype:{engine.get_binding_dtype(binding)}")
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
            print(f"output: shape:{engine.get_binding_shape(binding)} dtype:{engine.get_binding_dtype(binding)}")
    return inputs, outputs, bindings, stream



def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference.
    context.execute_async(
        batch_size=batch_size, bindings=bindings, stream_handle=stream.handle
    )
    # Transfer predictions back from the GPU.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]

# TensorRT logger singleton
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt_engine_path = "lpd.trt"

trt_runtime = trt.Runtime(TRT_LOGGER)
# pdb.set_trace()
trt_engine = load_engine(trt_runtime, trt_engine_path)
# Execution context is needed for inference
context = trt_engine.create_execution_context()
# This allocates memory for network inputs/outputs on both CPU and GPU
inputs, outputs, bindings, stream = allocate_buffers(trt_engine)

# pdb.set_trace()
image = cv2.imread("car.jpg")
image = cv2.resize(image, (640, 480))

np.copyto(inputs[0].host, image.ravel())

outputs = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
test1 = np.reshape(outputs[0], (4, 30, 40))
test2 = np.reshape(outputs[1], (1, 30, 40))
# cv2.imshow("1", test1)
cv2.imshow("2", test2)

cv2.waitKey()
print(outputs)

The output defined by the engine is as follows:
output: shape:(4, 30, 40) dtype:DataType.FLOAT
output: shape:(1, 30, 40) dtype:DataType.FLOAT

From the documentation:

The usa pruned models are intended for easy deployment to the edge using DeepStream SDK or TensorRT. They accept 640x480x3 dimension input tensors and outputs 40x30x12 bbox coordinate tensor and 40x30x3 class confidence tensor.

The TensorRT engine was built with the following command:

tlt-converter -k nvidia_tlt -d 3,480,640 -p image_input,1x3x480x640,4x3x480x640,16x3x480x640 usa_pruned.etlt -t fp16 -e lpr_engine.trt

The problem is that the confidence and bbox arrays have different shapes than the documentation describes. How do I go about drawing the bounding boxes on the image after inference?

Simon

Hi,

LPDNet is a DetectNet_v2 detector with a ResNet18 feature extractor.
The output is in the DetectNet_v2 format.

It is a GridBox output, which divides the input image into a 30x40 grid.
For each grid cell, the output contains the bounding-box parameters (xc, yc, w, h) and the corresponding confidence.

Below is a DetectNet_v2 output parser for your reference:
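
For reference, here is a minimal NumPy sketch of such a parser. The stride of 16 and the bbox normalization factor of 35.0 are assumptions taken from the default DetectNet_v2 spec (they are not stated in this thread), the confidence threshold is arbitrary, and the per-cell boxes would normally still be clustered (e.g. DBSCAN or NMS) afterwards:

import numpy as np

def parse_detectnet_v2(bbox, cov, stride=16, bbox_norm=35.0, conf_thresh=0.2):
    """Convert raw DetectNet_v2 GridBox outputs into (x1, y1, x2, y2, conf) boxes.

    bbox: (4, grid_h, grid_w) bounding-box tensor, cov: (1, grid_h, grid_w) coverage tensor.
    Coordinates are returned in input-image pixels (640x480 here).
    """
    grid_h, grid_w = cov.shape[1], cov.shape[2]

    # Centres of the grid cells, normalized by bbox_norm.
    cell_x = (np.arange(grid_w) * stride + 0.5) / bbox_norm             # (grid_w,)
    cell_y = ((np.arange(grid_h) * stride + 0.5) / bbox_norm)[:, None]  # (grid_h, 1)

    # The four channels are interpreted here as offsets of the box edges
    # from the cell centre; undo the normalization to get pixel coordinates.
    x1 = (bbox[0] - cell_x) * -bbox_norm
    y1 = (bbox[1] - cell_y) * -bbox_norm
    x2 = (bbox[2] + cell_x) * bbox_norm
    y2 = (bbox[3] + cell_y) * bbox_norm

    # Keep the cells whose coverage (confidence) exceeds the threshold.
    ys, xs = np.where(cov[0] > conf_thresh)
    boxes = [(float(x1[r, c]), float(y1[r, c]), float(x2[r, c]), float(y2[r, c]),
              float(cov[0, r, c])) for r, c in zip(ys, xs)]
    # A full parser would now cluster/NMS these per-cell boxes.
    return boxes

# Usage with the reshaped outputs from the script above:
# boxes = parse_detectnet_v2(test1, test2)
# for x1, y1, x2, y2, conf in boxes:
#     cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)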

Thanks.

Thank you for the help.

Is it expected for the bounding box array to contain decimal values?
Here is the flattened output of the model:
outputs[0]

array([0.4873047 , 0.60253906, 1.0849609 , ..., 0.6376953 , 0.6113281 ,
       0.41455078], dtype=float32)

The confidence array is also strange. I input the following image:

[attached image]

and the max confidence, max(outputs[1]) = 0.00026947918, seems very low for this image.


Hi,

Based on the configuration file below:
https://github.com/NVIDIA-AI-IOT/deepstream_lpr_app/blob/master/deepstream-lpr-app/lpd_ccpd_config.txt

LPDNet requires scale=0.0039 and NCHW input.
Could you update the input with the following and try it again?

image = cv2.imread("car.jpg")
image = cv2.resize(image, (640, 480))
image = image / 255
image = image.transpose((2, 0, 1))
...
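
For completeness, a minimal sketch of how the rest of the original script would then continue (the explicit float32 cast is an assumption, added to match the dtype of the page-locked host buffer allocated earlier):

image = image.astype(np.float32)          # match the float32 host buffer
np.copyto(inputs[0].host, image.ravel())  # copy the normalized CHW tensor into pinned memory
outputs = do_inference(context, bindings=bindings, inputs=inputs,
                       outputs=outputs, stream=stream)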

Thanks.
