Is there an example of inference with YOLO v4 in Python?
I found this: YOLO v4 inference with TensorRT after training with TLT 3.0 - #4 by johan_b
But it is missing the post-processing; do you know where I can find it?
Thanks!
Please refer to deepstream_tao_apps/post_processor at master · NVIDIA-AI-IOT/deepstream_tao_apps · GitHub or https://forums.developer.nvidia.com/t/inferring-yolo-v3-trt-model-in-python/
OK, with your recommendations I found a working example of inferring with YOLO v4, but I still have some issues:
The model I'm using is a custom YOLO v4 trained on our own dataset, following the TLT example (tlt_vc_samples_v1.1.0/yolo_v4/yolo_v4.ipynb). It is trained for Person, Car and Two_wheels.
Model: trt2-yolo.engine - Google Drive
For exporting the model we use:
!tlt yolo_v4 export -m $USER_EXPERIMENT_DIR/experiment_dir_unpruned/weights/yolov4_resnet18_epoch_080.tlt \
-k $KEY \
-o $USER_EXPERIMENT_DIR/export_unpruned/yolov4_resnet18_epoch_080_fp16.etlt \
-e $SPECS_DIR/yolo_v4_train_resnet18_kitti.txt \
--batch_size 1 \
--data_type fp16
Convert to engine:
!tlt tlt-converter -k $KEY \
-p Input,1x3x544x960,8x3x544x960,16x3x544x960 \
-m 1 \
-d 3x544x960 \
-o BatchedNMS \
-e $USER_EXPERIMENT_DIR/export_fp16/trt3.engine \
-t fp32 \
-i nchw \
$USER_EXPERIMENT_DIR/export_unpruned/yolov4_resnet18_epoch_080_fp16.etlt
This validation is working fine:
!tlt yolo_v4 inference -m $USER_EXPERIMENT_DIR/export_fp16/trt3.engine \
-e $SPECS_DIR/yolo_v4_train_resnet18_kitti.txt \
-i $DATA_DOWNLOAD_DIR/testing/images \
-o $USER_EXPERIMENT_DIR/yolo_infer_images \
-k $KEY \
-t 0.5
This is the code that I'm using:
'''Class for inferring a YOLO TRT engine'''
import time
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt
from PIL import Image
import cv2


class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()


class TRTLoader:
    def __init__(self, trt_engine_path, model_w, model_h, num_classes, threshold, nms_threshold,
                 box_norm, stride=16):
        self.model_w = model_w
        self.model_h = model_h
        self.stride = stride
        self.box_norm = box_norm
        self.min_confidence = threshold
        self.NUM_CLASSES = num_classes
        self.grid_h = int(model_h / stride)
        self.grid_w = int(model_w / stride)
        self.grid_size = self.grid_h * self.grid_w

        self.trt_engine = self.load_engine(trt_engine_path)
        self.context = self.trt_engine.create_execution_context()
        inputs, outputs, bindings, stream = self.allocate_buffers(self.trt_engine)
        self.inputs = inputs
        self.outputs = outputs
        self.bindings = bindings
        self.stream = stream
        self.nms_threshold = nms_threshold

    def load_engine(self, engine_path):
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(None, '')
        trt_runtime = trt.Runtime(TRT_LOGGER)
        with open(engine_path, "rb") as f:
            engine_data = f.read()
        engine = trt_runtime.deserialize_cuda_engine(engine_data)
        return engine

    def allocate_buffers(self, engine, batch_size=1):
        """Allocates host and device buffers for TRT engine inference.

        This function is similar to the one in common.py, but
        converts network outputs (which are np.float32) appropriately
        before writing them to the Python buffer. This is needed, since
        TensorRT plugins don't support output type description, and
        in our particular case, we use the NMS plugin as network output.

        Args:
            engine (trt.ICudaEngine): TensorRT engine

        Returns:
            inputs [HostDeviceMem]: engine input memory
            outputs [HostDeviceMem]: engine output memory
            bindings [int]: buffer to device bindings
            stream (cuda.Stream): cuda stream for engine inference synchronization
        """
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()

        # Current NMS implementation in TRT only supports DataType.FLOAT but
        # it may change in the future, which could break this sample here
        # when using lower precision [e.g. NMS output would not be np.float32
        # anymore, even though this is assumed in binding_to_type]
        binding_to_type = {
            'Input': np.float32,
            'BatchedNMS': np.int32,
            'BatchedNMS_1': np.float32,
            'BatchedNMS_2': np.float32,
            'BatchedNMS_3': np.float32
        }

        for binding in engine:
            size = trt.volume(engine.get_binding_shape(binding)) * batch_size
            dtype = binding_to_type[str(binding)]
            # Allocate host and device buffers
            # size = abs(size)
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            # Append the device buffer to device bindings.
            bindings.append(int(device_mem))
            # Append to the appropriate list.
            if engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))
        return inputs, outputs, bindings, stream

    def process_image(self, arr):
        # image = Image.fromarray(np.uint8(arr))
        # image_resized = image.resize(size=(self.model_w, self.model_h), resample=Image.BILINEAR)
        image_resized = cv2.resize(arr, (self.model_w, self.model_h))
        img_np = image_resized.astype(np.float32)
        # HWC -> CHW
        img_np = img_np.transpose((2, 0, 1))
        # print(img_np)
        # Normalize to [0.0, 1.0] interval (expected by model)
        # img_np = (1.0 / 255.0) * img_np
        img_np = img_np.ravel()
        return img_np

    # def process_image(self, arr):
    #     image = Image.fromarray(np.uint8(arr))
    #     # image_resized = image.thumbnail((self.model_w, self.model_h), Image.ANTIALIAS)
    #     image_resized = image.resize(size=(self.model_w, self.model_h), resample=Image.NEAREST)
    #     img_np = np.array(image_resized, dtype=np.float32)
    #     print(img_np.shape)
    #     # HWC -> CHW
    #     img_np = img_np.transpose((2, 0, 1))
    #     # Normalize to [0.0, 1.0] interval (expected by model)
    #     # img_np = img_np / 255.0
    #     img_np = (1.0 / 255.0) * img_np
    #     print(img_np)
    #     img_np = img_np.ravel()
    #     return img_np

    # This function is generalized for multiple inputs/outputs.
    # inputs and outputs are expected to be lists of HostDeviceMem objects.
    def do_inference(self, context, bindings, inputs, outputs, stream, batch_size=1):
        # Transfer input data to the GPU.
        [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
        # Run inference.
        context.execute_async(
            batch_size=batch_size, bindings=bindings, stream_handle=stream.handle
        )
        # Transfer predictions back from the GPU.
        [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
        # Synchronize the stream
        stream.synchronize()
        # Return only the host outputs.
        return [out.host for out in outputs]

    def _nms_boxes(self, boxes, box_confidences):
        """Apply the Non-Maximum Suppression (NMS) algorithm on the bounding boxes with their
        confidence scores and return an array with the indexes of the bounding boxes we want to
        keep (and display later).

        Keyword arguments:
        boxes -- a NumPy array containing N bounding-box coordinates that survived filtering,
        with shape (N, 4); 4 for x, y, height, width coordinates of the boxes
        box_confidences -- a NumPy array containing the corresponding confidences with shape N
        """
        x_coord = boxes[:, 0]
        y_coord = boxes[:, 1]
        width = boxes[:, 2]
        height = boxes[:, 3]

        areas = width * height
        ordered = box_confidences.argsort()[::-1]

        keep = list()
        while ordered.size > 0:
            # Index of the current element:
            i = ordered[0]
            keep.append(i)
            xx1 = np.maximum(x_coord[i], x_coord[ordered[1:]])
            yy1 = np.maximum(y_coord[i], y_coord[ordered[1:]])
            xx2 = np.minimum(x_coord[i] + width[i], x_coord[ordered[1:]] + width[ordered[1:]])
            yy2 = np.minimum(y_coord[i] + height[i], y_coord[ordered[1:]] + height[ordered[1:]])

            width1 = np.maximum(0.0, xx2 - xx1 + 1)
            height1 = np.maximum(0.0, yy2 - yy1 + 1)
            intersection = width1 * height1
            union = (areas[i] + areas[ordered[1:]] - intersection)

            # Compute the Intersection over Union (IoU) score:
            iou = intersection / union

            # The goal of the NMS algorithm is to reduce the number of adjacent bounding-box
            # candidates to a minimum. In this step, we keep only those elements whose overlap
            # with the current bounding box is lower than the threshold:
            indexes = np.where(iou <= self.nms_threshold)[0]
            ordered = ordered[indexes + 1]

        keep = np.array(keep)
        return keep

    def predict(self, image):
        """Infers the model on an image resized to fit the model.

        Args:
            image (np.ndarray): image that will be preprocessed and fed into the model
        """
        img = self.process_image(image)
        self.img_shape = image.shape
        print('image_shape ', image.shape)
        inputs = self.inputs
        outputs = self.outputs
        bindings = self.bindings
        stream = self.stream
        np.copyto(inputs[0].host, img)

        # When inferring on a single image, we measure inference
        # time to output it to the user
        # Fetch output from the model
        detection_out = self.do_inference(
            self.context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream
        )
        # Output inference time
        """print(
            "TensorRT inference time: {} ms".format(
                int(round((time.time() - inference_start_time) * 1000))
            )
        )"""
        # And return results
        return detection_out

    def postprocess(self, outputs, wh_format=True):
        """
        Postprocesses the inference output
        Args:
            outputs (list of float): inference output
            min_confidence (float): min confidence to accept detection
            analysis_classes (list of int): indices of the classes to consider
        Returns: list of list tuple: each element is a two list tuple (x, y) representing the corners of a bb
        """
        p_keep_count = outputs[0]
        p_bboxes = outputs[1]
        p_scores = outputs[2]
        p_classes = outputs[3]
        analysis_classes = list(range(self.NUM_CLASSES))
        threshold = self.min_confidence
        p_bboxes = np.array_split(p_bboxes, len(p_bboxes) / 4)
        bbs = []
        class_ids = []
        scores = []
        print(p_bboxes)
        x_scale = self.img_shape[1] / self.model_w
        y_scale = self.img_shape[0] / self.model_h

        for i in range(p_keep_count[0]):
            assert (p_classes[i] < len(analysis_classes))
            if p_scores[i] > threshold:
                x1 = int(np.round(p_bboxes[i][0] * x_scale))
                y1 = int(np.round(p_bboxes[i][1] * y_scale))
                x2 = int(np.round(p_bboxes[i][2] * x_scale))
                y2 = int(np.round(p_bboxes[i][3] * y_scale))
                bbs.append([x1, y1, x2, y2])
                class_ids.append(p_classes[i])
                scores.append(p_scores[i])

        # print(class_ids)
        bbs = np.asarray(bbs)
        print(scores)
        class_ids = np.asarray(class_ids)
        scores = np.asarray(scores)

        nms_boxes, nms_categories, nscores = list(), list(), list()
        for category in set(class_ids):
            idxs = np.where(class_ids == category)
            box = bbs[idxs]
            category = class_ids[idxs]
            confidence = scores[idxs]

            keep = self._nms_boxes(box, confidence)
            print('keep', keep)

            nms_boxes.append(box[keep])
            nms_categories.append(category[keep])
            nscores.append(confidence[keep])

        if len(nms_boxes) == 0:
            return [], [], []

        return nms_boxes, nms_categories, nscores


if __name__ == '__main__':
    engine = TRTLoader(trt_engine_path='trt2-yolo.engine', model_w=960, model_h=544,
                       num_classes=3, threshold=0.3, nms_threshold=0.5, box_norm=True,
                       stride=16)
    img = Image.open('person3.jpg')
    img = np.array(img)  # im2arr.shape: height x width x channel
    data = engine.predict(img)
We have a couple of errors:
1.
File "test4.py", line 98, in allocate_buffers
host_mem = cuda.pagelocked_empty(size, dtype)
pycuda._driver.MemoryError: cuMemHostAlloc failed: out of memory
Doing some debugging I found that the value of engine.get_binding_shape(binding) is (-1, 3, 544, 960), so size is -1566720.
To work around this I set size = abs(size), but I don't know if this is correct. It does "fix" the allocation, though.
After trying this, the problem is this:
[TensorRT] ERROR: Parameter check failed at: engine.cpp::resolveSlots::1318, condition: allInputDimensionsSpecified(routine)
So I don't know what I'm doing wrong.
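My guess is that, because the engine was built with a dynamic batch dimension (the -1 in the binding shape), the execution context needs the input shape set before the buffers are sized and inference is run. Below is a rough sketch of what I mean, reusing HostDeviceMem and binding_to_type from the code above and assuming binding 0 is 'Input' and batch size 1; I have not verified this against this engine:

# Rough sketch: fix the dynamic dim on the context, then size buffers from the context.
# Assumes `engine` is the deserialized engine, binding 0 is 'Input', batch size 1.
context = engine.create_execution_context()
context.set_binding_shape(0, (1, 3, 544, 960))

inputs, outputs, bindings = [], [], []
stream = cuda.Stream()
for idx, binding in enumerate(engine):
    shape = context.get_binding_shape(idx)   # fully specified now, no -1
    size = trt.volume(shape)
    dtype = binding_to_type[str(binding)]
    host_mem = cuda.pagelocked_empty(size, dtype)
    device_mem = cuda.mem_alloc(host_mem.nbytes)
    bindings.append(int(device_mem))
    if engine.binding_is_input(binding):
        inputs.append(HostDeviceMem(host_mem, device_mem))
    else:
        outputs.append(HostDeviceMem(host_mem, device_mem))

# Engines with an explicit batch dimension would then be run with execute_async_v2:
# context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)

But I'm not sure this is the right approach here.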
Firstly, can you run the official inference successfully against your TRT engine?
See the command in the YOLOv4 — Transfer Learning Toolkit 3.0 documentation.
I fixed the problem. When I convert the model, it seems that using different input shapes in the optimization profile doesn't work, so I changed it to -p Input,1x3x544x960,1x3x544x960,1x3x544x960:
!tlt tlt-converter -k $KEY \
-p Input,1x3x544x960,1x3x544x960,1x3x544x960 \
-m 1 \
-d 3x544x960 \
-o BatchedNMS \
-e $USER_EXPERIMENT_DIR/export_fp16/trt3.engine \
-t fp32 \
-i nchw \
$USER_EXPERIMENT_DIR/export_unpruned/yolov4_resnet18_epoch_080_fp16.etlt
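For reference, a minimal usage sketch of the class above once the re-converted engine loads correctly; the engine and image file names here are only placeholders:

import numpy as np
from PIL import Image

# Load the engine through the TRTLoader class from the code above.
loader = TRTLoader(trt_engine_path='trt3.engine', model_w=960, model_h=544,
                   num_classes=3, threshold=0.3, nms_threshold=0.5, box_norm=True)

# Run inference and apply the post-processing (BatchedNMS outputs -> boxes, classes, scores).
img = np.array(Image.open('person3.jpg'))
outputs = loader.predict(img)
boxes, classes, scores = loader.postprocess(outputs)
print(boxes, classes, scores)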