Doing inference in Python with YOLO v4 in TensorRT - postprocessing

Is there an example of inference with YOLO v4 in Python?

I found this: YOLO v4 inference with TensorRT after training with TLT 3.0 - #4 by johan_b

But it is missing the postprocessing. Do you know where I can find it?

Thanks!

Is it this one? GitHub - NVIDIA-AI-IOT/deepstream_tao_apps: Sample apps to demonstrate how to deploy models trained with TAO on DeepStream

Please refer to deepstream_tao_apps/post_processor at master · NVIDIA-AI-IOT/deepstream_tao_apps · GitHub or https://forums.developer.nvidia.com/t/inferring-yolo-v3-trt-model-in-python/
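
As a quick orientation, the BatchedNMS output layout used by those postprocessors is: keep count, bounding boxes, scores and class ids. Here is a minimal sketch of unpacking the four host buffers (the function name and threshold are only illustrative):

import numpy as np

def parse_batched_nms(outputs, threshold=0.3):
    """Unpack the four BatchedNMS host buffers returned after inference."""
    keep_count = int(outputs[0][0])              # BatchedNMS: number of kept detections
    boxes = np.array(outputs[1]).reshape(-1, 4)  # BatchedNMS_1: one [x1, y1, x2, y2] per detection
    scores = np.array(outputs[2])                # BatchedNMS_2: confidence per detection
    classes = np.array(outputs[3])               # BatchedNMS_3: class id per detection

    detections = []
    for i in range(keep_count):
        if scores[i] >= threshold:
            detections.append((boxes[i], float(scores[i]), int(classes[i])))
    return detections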

OK, with your recommendations I found a working example of inference with YOLO v4, but I still have some issues:

The model I'm using is a custom YOLO v4 trained on our own dataset, following the TLT example notebook (tlt_vc_samples_v1.1.0/yolo_v4/yolo_v4.ipynb). It is trained for the classes Person, Car and Two_wheels.
Model: trt2-yolo.engine - Google Drive

For exporting the model we use:

!tlt yolo_v4 export -m $USER_EXPERIMENT_DIR/experiment_dir_unpruned/weights/yolov4_resnet18_epoch_080.tlt \
                    -k $KEY \
                    -o $USER_EXPERIMENT_DIR/export_unpruned/yolov4_resnet18_epoch_080_fp16.etlt \
                    -e $SPECS_DIR/yolo_v4_train_resnet18_kitti.txt \
                    --batch_size 1 \
                    --data_type fp16

Convert to engine:

!tlt tlt-converter -k $KEY \
                   -p Input,1x3x544x960,8x3x544x960,16x3x544x960 \
                   -m 1 \
                   -d 3x544x960 \
                   -o BatchedNMS \
                   -e $USER_EXPERIMENT_DIR/export_fp16/trt3.engine \
                   -t fp32 \
                   -i nchw \
                   $USER_EXPERIMENT_DIR/export_unpruned/yolov4_resnet18_epoch_080_fp16.etlt

This validation is working fine:

!tlt yolo_v4 inference -m $USER_EXPERIMENT_DIR/export_fp16/trt3.engine \
                       -e $SPECS_DIR/yolo_v4_train_resnet18_kitti.txt \
                       -i $DATA_DOWNLOAD_DIR/testing/images \
                       -o $USER_EXPERIMENT_DIR/yolo_infer_images \
                       -k $KEY \
                       -t 0.5

This is the code that I'm using:

'''Class for running inference with a YOLO TRT engine'''
import time
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt
from PIL import Image
import cv2


class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()


class TRTLoader:

    def __init__(self, trt_engine_path, model_w, model_h, num_classes, threshold, nms_threshold,
                 box_norm, stride=16):

        self.model_w = model_w
        self.model_h = model_h
        self.stride = stride
        self.box_norm = box_norm
        self.min_confidence = threshold
        self.NUM_CLASSES = num_classes

        self.grid_h = int(model_h / stride)
        self.grid_w = int(model_w / stride)
        self.grid_size = self.grid_h * self.grid_w
        self.trt_engine = self.load_engine(trt_engine_path)
        self.context = self.trt_engine.create_execution_context()
        inputs, outputs, bindings, stream = self.allocate_buffers(self.trt_engine)
        self.inputs = inputs
        self.outputs = outputs
        self.bindings = bindings
        self.stream = stream
        self.nms_threshold = nms_threshold

    def load_engine(self, engine_path):
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(None, '')
        trt_runtime = trt.Runtime(TRT_LOGGER)

        with open(engine_path, "rb") as f:
            engine_data = f.read()

        engine = trt_runtime.deserialize_cuda_engine(engine_data)
        return engine

    def allocate_buffers(self, engine, batch_size=1):
        """Allocates host and device buffer for TRT engine inference.
        This function is similar to the one in common.py, but
        converts network outputs (which are np.float32) appropriately
        before writing them to Python buffer. This is needed, since
        TensorRT plugins don't support output type description, and
        in our particular case, we use NMS plugin as network output.
        Args:
            engine (trt.ICudaEngine): TensorRT engine
        Returns:
            inputs [HostDeviceMem]: engine input memory
            outputs [HostDeviceMem]: engine output memory
            bindings [int]: buffer to device bindings
            stream (cuda.Stream): cuda stream for engine inference synchronization
        """
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()

        # Current NMS implementation in TRT only supports DataType.FLOAT but
        # it may change in the future, which could break this sample here
        # when using lower precision [e.g. NMS output would not be np.float32
        # anymore, even though this is assumed in binding_to_type]
        binding_to_type = {
            'Input': np.float32,
            'BatchedNMS': np.int32,
            'BatchedNMS_1': np.float32,
            'BatchedNMS_2': np.float32,
            'BatchedNMS_3': np.float32
        }

        for binding in engine:
            size = trt.volume(engine.get_binding_shape(binding)) * batch_size
            dtype = binding_to_type[str(binding)]
            # Allocate host and device buffers
            #size = abs(size)
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            # Append the device buffer to device bindings.
            bindings.append(int(device_mem))
            # Append to the appropriate list.
            if engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))

        return inputs, outputs, bindings, stream

    def process_image(self, arr):

        # image = Image.fromarray(np.uint8(arr))

        # image_resized = image.resize(size=(self.model_w, self.model_h), resample=Image.BILINEAR)
        image_resized = cv2.resize(arr, (self.model_w, self.model_h))
        img_np = image_resized.astype(np.float32)
        # HWC -> CHW
        img_np = img_np.transpose((2, 0, 1))
        # print(img_np)
        # Normalize to [0.0, 1.0] interval (expected by model)
        # img_np = (1.0 / 255.0) * img_np
        img_np = img_np.ravel()
        return img_np

    # def process_image(self,arr):

    #     image = Image.fromarray(np.uint8(arr))
    #     # image_resized = image.thumbnail((self.model_w, self.model_h), Image.ANTIALIAS)
    #     image_resized = image.resize(size=(self.model_w, self.model_h),resample = Image.NEAREST)
    #     img_np = np.array(image_resized,dtype=np.float32)
    #     print(img_np.shape)
    #     # HWC -> CHW
    #     img_np = img_np.transpose((2, 0, 1))
    #     # Normalize to [0.0, 1.0] interval (expected by model)
    #     # img_np =  img_np/255.0
    #     img_np = (1.0 / 255.0) * img_np
    #     print(img_np)

    #     img_np = img_np.ravel()
    #     return img_np

    # This function is generalized for multiple inputs/outputs.
    # inputs and outputs are expected to be lists of HostDeviceMem objects.
    def do_inference(self, context, bindings, inputs, outputs, stream, batch_size=1):
        # Transfer input data to the GPU.
        [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
        # Run inference.
        context.execute_async(
            batch_size=batch_size, bindings=bindings, stream_handle=stream.handle
        )
        # Transfer predictions back from the GPU.
        [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
        # Synchronize the stream
        stream.synchronize()
        # Return only the host outputs.
        return [out.host for out in outputs]

    def _nms_boxes(self, boxes, box_confidences):
        """Apply the Non-Maximum Suppression (NMS) algorithm on the bounding boxes with their
        confidence scores and return an array with the indexes of the bounding boxes we want to
        keep (and display later).
        Keyword arguments:
        boxes -- a NumPy array containing N bounding-box coordinates that survived filtering,
        with shape (N,4); 4 for x,y,height,width coordinates of the boxes
        box_confidences -- a Numpy array containing the corresponding confidences with shape N
        """
        x_coord = boxes[:, 0]
        y_coord = boxes[:, 1]
        width = boxes[:, 2]
        height = boxes[:, 3]

        areas = width * height
        ordered = box_confidences.argsort()[::-1]

        keep = list()
        while ordered.size > 0:
            # Index of the current element:
            i = ordered[0]
            keep.append(i)
            xx1 = np.maximum(x_coord[i], x_coord[ordered[1:]])
            yy1 = np.maximum(y_coord[i], y_coord[ordered[1:]])
            xx2 = np.minimum(x_coord[i] + width[i], x_coord[ordered[1:]] + width[ordered[1:]])
            yy2 = np.minimum(y_coord[i] + height[i], y_coord[ordered[1:]] + height[ordered[1:]])

            width1 = np.maximum(0.0, xx2 - xx1 + 1)
            height1 = np.maximum(0.0, yy2 - yy1 + 1)
            intersection = width1 * height1
            union = (areas[i] + areas[ordered[1:]] - intersection)

            # Compute the Intersection over Union (IoU) score:
            iou = intersection / union

            # The goal of the NMS algorithm is to reduce the number of adjacent bounding-box
            # candidates to a minimum. In this step, we keep only those elements whose overlap
            # with the current bounding box is lower than the threshold:
            indexes = np.where(iou <= self.nms_threshold)[0]
            ordered = ordered[indexes + 1]

        keep = np.array(keep)
        return keep

    def predict(self, image):
        """Infers model on batch of same sized images resized to fit the model.
        Args:
            image_paths (str): paths to images, that will be packed into batch
                and fed into model
        """

        img = self.process_image(image)

        self.img_shape = image.shape
        print('image_shape ', image.shape)

        inputs = self.inputs
        outputs = self.outputs
        bindings = self.bindings
        stream = self.stream

        np.copyto(inputs[0].host, img)

        # When inferring on a single image, we measure inference
        # time to output it to the user

        # Fetch output from the model
        detection_out = self.do_inference(
            self.context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream
        )

        # Output inference time

        """print(
            "TensorRT inference time: {} ms".format(
                int(round((time.time() - inference_start_time) * 1000))
            )
        )"""

        # And return results
        return detection_out

    def postprocess(self, outputs, wh_format=True):
        """
        Postprocesses the inference output
        Args:
            outputs (list of np.ndarray): inference output buffers
                (keep count, bounding boxes, scores, class ids)

        Returns: per-class lists of NMS-filtered bounding boxes, class ids and scores
        """

        p_keep_count = outputs[0]
        p_bboxes = outputs[1]
        p_scores = outputs[2]
        p_classes = outputs[3]
        analysis_classes = list(range(self.NUM_CLASSES))
        threshold = self.min_confidence
        p_bboxes = np.array_split(p_bboxes, len(p_bboxes) // 4)
        bbs = []
        class_ids = []
        scores = []
        print(p_bboxes)
        x_scale = self.img_shape[1] / self.model_w
        y_scale = self.img_shape[0] / self.model_h

        for i in range(p_keep_count[0]):
            assert (p_classes[i] < len(analysis_classes))
            if p_scores[i] > threshold:
                x1 = int(np.round(p_bboxes[i][0] * x_scale))
                y1 = int(np.round(p_bboxes[i][1] * y_scale))
                x2 = int(np.round(p_bboxes[i][2] * x_scale))
                y2 = int(np.round(p_bboxes[i][3] * y_scale))

                bbs.append([x1, y1, x2, y2])
                class_ids.append(p_classes[i])
                scores.append(p_scores[i])
        # print(class_ids)
        bbs = np.asarray(bbs)
        print(scores)
        class_ids = np.asarray(class_ids)
        scores = np.asarray(scores)

        nms_boxes, nms_categories, nscores = list(), list(), list()
        for category in set(class_ids):
            idxs = np.where(class_ids == category)
            box = bbs[idxs]
            category = class_ids[idxs]
            confidence = scores[idxs]

            keep = self._nms_boxes(box, confidence)
            print('keep', keep)
            nms_boxes.append(box[keep])
            nms_categories.append(category[keep])
            nscores.append(confidence[keep])
        if len(nms_boxes) == 0:
            return [], [], []

        return nms_boxes, nms_categories, nscores


if __name__ == '__main__':
    engine = TRTLoader(trt_engine_path='trt2-yolo.engine', model_w=960, model_h=544,
                       num_classes=3, threshold=0.3, nms_threshold=0.5, box_norm=True,
                       stride=16)

    img = Image.open('person3.jpg')
    img = np.array(img)  # im2arr.shape: height x width x channel

    data = engine.predict(img)

We have a couple of errors:

1.

File "test4.py", line 98, in allocate_buffers
    host_mem = cuda.pagelocked_empty(size, dtype)
pycuda._driver.MemoryError: cuMemHostAlloc failed: out of memory

Doing some debugging I found that the value of engine.get_binding_shape(binding) is (-1, 3, 544, 960), so size is -1566720.

To get past this I did size = abs(size), but I don't know if this is correct. It does “fix” that first error.

2.

After trying this, the problem is:

[TensorRT] ERROR: Parameter check failed at: engine.cpp::resolveSlots::1318, condition: allInputDimensionsSpecified(routine)

So I don't know what I'm doing wrong.

Firstly, can you run the official inference successfully against your TRT engine?
See the command in the YOLOv4 — Transfer Learning Toolkit 3.0 documentation.
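
If you do want to keep a dynamic optimization profile instead of rebuilding the engine, here is a minimal sketch of the buffer allocation (assuming the TensorRT 7.x Python API calls get_profile_shape, set_binding_shape and execute_async_v2, and reusing the imports and HostDeviceMem class from the code above; this is an illustration, not the posted code). The idea is to size the buffers from the profile's maximum shape and make the input dimensions concrete on the execution context, instead of taking abs(size):

def allocate_buffers_dynamic(engine, context):
    # Same output type mapping as in allocate_buffers above.
    binding_to_type = {
        'Input': np.float32,
        'BatchedNMS': np.int32,
        'BatchedNMS_1': np.float32,
        'BatchedNMS_2': np.float32,
        'BatchedNMS_3': np.float32
    }
    inputs, outputs, bindings = [], [], []
    stream = cuda.Stream()
    for binding in engine:
        idx = engine.get_binding_index(binding)
        if engine.binding_is_input(binding):
            # Profile 0 gives (min, opt, max) shapes; allocate for the max shape.
            shape = engine.get_profile_shape(0, idx)[2]
            # Making the input dimensions concrete is what the
            # "allInputDimensionsSpecified" check is asking for.
            context.set_binding_shape(idx, shape)
        else:
            # Output shapes become concrete once all input shapes are set.
            shape = context.get_binding_shape(idx)
        size = trt.volume(shape)
        dtype = binding_to_type[str(binding)]
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        bindings.append(int(device_mem))
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    return inputs, outputs, bindings, stream

With an explicit-batch, dynamic-shape engine, inference then goes through context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) instead of execute_async(batch_size=...).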

I fixed the problem. When converting the model, it seems that using different input shapes in the optimization profile doesn't work, so I changed it to: -p Input,1x3x544x960,1x3x544x960,1x3x544x960

!tlt tlt-converter -k $KEY \
                   -p Input,1x3x544x960,1x3x544x960,1x3x544x960 \
                   -m 1 \
                   -d 3x544x960 \
                   -o BatchedNMS \
                   -e $USER_EXPERIMENT_DIR/export_fp16/trt3.engine \
                   -t fp32 \
                   -i nchw \
                   $USER_EXPERIMENT_DIR/export_unpruned/yolov4_resnet18_epoch_080_fp16.etlt
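
A quick way to verify the rebuilt engine is to deserialize it and print the binding shapes; the Input binding shows whether a dynamic (-1) dimension is still present (the engine path below is just illustrative):

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(None, '')
runtime = trt.Runtime(TRT_LOGGER)

with open('trt3.engine', 'rb') as f:  # engine path is illustrative
    engine = runtime.deserialize_cuda_engine(f.read())

for binding in engine:
    print(binding, engine.get_binding_shape(binding))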

Or refer to Python run LPRNet with TensorRT show pycuda._driver.MemoryError: cuMemHostAlloc failed: out of memory - #8 by Morganh
