Custom detection ONNX model gives wrong outputs using nvinfer with DeepStream 5.1

• Hardware Platform (Jetson / GPU) 1080 Ti

• DeepStream Version 5.1 (Official Docker Container)

• Problem Description:

Hi, I’m currently working on a face detection task, and I’m having trouble with the output produced by nvinfer.

I exported a pre-trained detection model from PyTorch to ONNX, then used DeepStream 5.1 to run the ONNX model on a single image (ref: deepstream-image-decode-test) following the official guide, and got an incorrect result.

I have implemented the custom library in C++, and a simple unit test confirms that anchor creation and the bbox decoding process work correctly.

Meanwhile, I have run the ONNX model with the onnxruntime package, and the result also showed all the faces.

So the problem appears to be: using DeepStream nvinfer, the output of the same model is different from the PyTorch and onnxruntime results.

• How to reproduce the issue ?

Config and code files are all available on Google Drive.

ONNX file: https://drive.google.com/file/d/1UAyNPi9tylJcGWdG86-6PkG4PILUdr6S/view?usp=sharing

DeepStream image-decode-test file (Only the path of nvinfer config file is modified): https://drive.google.com/file/d/1rUNhBUiZWkxgWGqu2Ta7RbHQZR_Pwj8z/view?usp=sharing

DeepStream nvinfer config file: https://drive.google.com/file/d/1VJS46QM8c8Ubmx6ljyAiR3AkWioLCTFn/view?usp=sharing

[property]
gpu-id=0
#0=RGB, 1=BGR
model-color-format=1
onnx-file=face_detector.onnx
model-engine-file=face_detector.onnx_b1_gpu0_fp32.engine
labelfile-path=labels.txt
## 0=FP32, 1=INT8, 2=FP16 mode
network-mode=0
network-type=0
num-detected-classes=1
gie-unique-id=1
process-mode=1
interval=0
maintain-aspect-ratio=0
infer-dims=3;768;1024
output-blob-names=bbox;conf;landmark
parse-bbox-func-name=NvDsInferParseFaceDetector
custom-lib-path=nvdsinfer_custom_impl_face_detection/libnvdsinfer_custom_impl_face_detection.so
offsets=104;117;123

[class-attrs-all]
pre-cluster-threshold=0.1
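
For reference, nvinfer derives each input tensor element as net-scale-factor * (pixel - offset), applied per channel on the frame after scaling to the network resolution (1024x768 here) and converting to the configured color format (BGR, matching the cv2.imread and img -= (104, 117, 123) steps in the PyTorch script). With net-scale-factor left unset it defaults to 1.0, so this config amounts to plain per-channel mean subtraction. A minimal illustrative sketch of that per-pixel formula (not the actual nvinfer code):

// Per-pixel preprocessing implied by the config above:
// y = net-scale-factor * (x - offset[c]) on the BGR frame scaled to 1024x768.
static const float kOffsets[3] = {104.f, 117.f, 123.f};  // B, G, R (model-color-format=1)
static const float kNetScaleFactor = 1.0f;               // default when net-scale-factor is unset

inline float nvinfer_preprocess(float pixel, int channel) {
  return kNetScaleFactor * (pixel - kOffsets[channel]);
}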

Custom implementation C++ file: https://drive.google.com/file/d/1egRZGtO7Jnd4RKfZZNky80wipUiCBOfE/view?usp=sharing

#include "nvdsinfer_custom_impl.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <stack>
#include <string>
#include <vector>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

extern "C" struct bbox {
  float x1;
  float y1;
  float x2;
  float y2;
  float s;
};

extern "C" struct box {
  float cx;
  float cy;
  float sx;
  float sy;
};

extern "C" bool cmp(bbox a, bbox b);
extern "C" bool cmp(bbox a, bbox b) {
  if (a.s > b.s)
    return true;
  return false;
}

extern "C" void get_priors(std::vector<box> &anchors, int w, int h);
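// Build RetinaFace/RFB-style prior boxes for strides {8, 16, 32, 64}; each anchor
// is stored as normalized [cx, cy, w, h] relative to the network input size.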
extern "C" void get_priors(std::vector<box> &anchors, int w, int h) {
  anchors.clear();
  std::vector<std::vector<int>> feature_map(4), min_sizes(4);
  float steps[] = {8, 16, 32, 64};
  for (int i = 0; i < feature_map.size(); i++) {
    feature_map[i].push_back(ceil(h / steps[i]));
    feature_map[i].push_back(ceil(w / steps[i]));
  }
  std::vector<int> minsize1 = {10, 16, 24};
  min_sizes[0] = minsize1;
  std::vector<int> minsize2 = {32, 48};
  min_sizes[1] = minsize2;
  std::vector<int> minsize3 = {64, 96};
  min_sizes[2] = minsize3;
  std::vector<int> minsize4 = {128, 192, 256};
  min_sizes[3] = minsize4;

  for (int k = 0; k < feature_map.size(); k++) {
    std::vector<int> min_size = min_sizes[k];
    for (int i = 0; i < feature_map[k][0]; i++) {
      for (int j = 0; j < feature_map[k][1]; j++) {
        for (int l = 0; l < min_size.size(); l++) {
          float s_kx = min_size[l] * 1.0 / w;
          float s_ky = min_size[l] * 1.0 / h;
          float cx = (j + 0.5) * steps[k] / w;
          float cy = (i + 0.5) * steps[k] / h;
          box box = {cx, cy, s_kx, s_ky};
          anchors.push_back(box);
        }
      }
    }
  }
}

extern "C" bbox decode(box &anchor, box &det, float conf, int w, int h);
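// Apply the predicted offsets to an anchor using variances 0.1 (center) and 0.2
// (size), scale to pixel coordinates, and clip the box to the image bounds.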
extern "C" bbox decode(box &anchor, box &det, float conf, int w, int h) {
  box output_box;
  output_box.cx = anchor.cx + det.cx * 0.1 * anchor.sx;
  output_box.cy = anchor.cy + det.cy * 0.1 * anchor.sy;
  output_box.sx = anchor.sx * exp(det.sx * 0.2);
  output_box.sy = anchor.sy * exp(det.sy * 0.2);

  bbox result;
  result.x1 = std::max((float)0, (output_box.cx - output_box.sx / 2) * w);
  result.y1 = std::max((float)0, (output_box.cy - output_box.sy / 2) * h);
  result.x2 = std::min((float)w, (output_box.cx + output_box.sx / 2) * w);
  result.y2 = std::min((float)h, (output_box.cy + output_box.sy / 2) * h);
  result.s = conf;

  return result;
}

extern "C" void nms(std::vector<bbox> &input_boxes, float NMS_THRESH);
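// Greedy NMS on boxes already sorted by descending score: drop any box whose
// IoU with a higher-scoring box is >= NMS_THRESH.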
extern "C" void nms(std::vector<bbox> &input_boxes, float NMS_THRESH) {
  std::vector<float> vArea(input_boxes.size());
  for (int i = 0; i < int(input_boxes.size()); ++i) {
    vArea[i] = (input_boxes.at(i).x2 - input_boxes.at(i).x1 + 1) *
               (input_boxes.at(i).y2 - input_boxes.at(i).y1 + 1);
  }
  for (int i = 0; i < int(input_boxes.size()); i++) {
    for (int j = i + 1; j < int(input_boxes.size());) {
      float xx1 = std::max(input_boxes[i].x1, input_boxes[j].x1);
      float yy1 = std::max(input_boxes[i].y1, input_boxes[j].y1);
      float xx2 = std::min(input_boxes[i].x2, input_boxes[j].x2);
      float yy2 = std::min(input_boxes[i].y2, input_boxes[j].y2);
      float w = std::max(float(0), xx2 - xx1 + 1);
      float h = std::max(float(0), yy2 - yy1 + 1);
      float inter = w * h;
      float ovr = inter / (vArea[i] + vArea[j] - inter);
      if (ovr >= NMS_THRESH) {
        input_boxes.erase(input_boxes.begin() + j);
        vArea.erase(vArea.begin() + j);
      } else {
        j++;
      }
    }
  }
}

extern "C" bool NvDsInferParseFaceDetector(
    std::vector<NvDsInferLayerInfo> const &outputLayersInfo,
    NvDsInferNetworkInfo const &networkInfo,
    NvDsInferParseDetectionParams const &detectionParams,
    std::vector<NvDsInferObjectDetectionInfo> &objectList);

/* C-linkage to prevent name-mangling */
extern "C" bool NvDsInferParseFaceDetector(
    std::vector<NvDsInferLayerInfo> const &outputLayersInfo,
    NvDsInferNetworkInfo const &networkInfo,
    NvDsInferParseDetectionParams const &detectionParams,
    std::vector<NvDsInferObjectDetectionInfo> &objectList) {

  static int bboxLayerIndex = -1;
  static int confLayerIndex = -1;
  static bool classMismatchWarn = false;

  static const int NUM_CLASSES_FACE_DETECTOR = 1;
  static const float nmsIOUThreshold = 0.4;

  if (bboxLayerIndex == -1) {
    for (unsigned int i = 0; i < outputLayersInfo.size(); i++) {
      if (strcmp(outputLayersInfo[i].layerName, "bbox") == 0) {
        bboxLayerIndex = i;
        break;
      }
    }
    if (bboxLayerIndex == -1) {
      std::cerr << "Could not find bbox layer buffer while parsing"
                << std::endl;
      return false;
    }
  }

  if (confLayerIndex == -1) {
    for (unsigned int i = 0; i < outputLayersInfo.size(); i++) {
      if (strcmp(outputLayersInfo[i].layerName, "conf") == 0) {
        confLayerIndex = i;
        break;
      }
    }
    if (confLayerIndex == -1) {
      std::cerr << "Could not find conf layer buffer while parsing"
                << std::endl;
      return false;
    }
  }

  if (!classMismatchWarn) {
    if (NUM_CLASSES_FACE_DETECTOR != detectionParams.numClassesConfigured) {
      std::cerr << "WARNING: Num classes mismatch. Configured:"
                << detectionParams.numClassesConfigured
                << ", detected by network: " << NUM_CLASSES_FACE_DETECTOR
                << std::endl;
    }
    classMismatchWarn = true;
  }

  static int numClassesToParse =
      MIN(NUM_CLASSES_FACE_DETECTOR, detectionParams.numClassesConfigured);

  float *confOut = (float *)outputLayersInfo[confLayerIndex].buffer;
  float *detectionOut = (float *)outputLayersInfo[bboxLayerIndex].buffer;
  int count = outputLayersInfo[confLayerIndex].inferDims.d[0];

  std::vector<bbox> bboxes;
  std::vector<box> anchors;
  get_priors(anchors, networkInfo.width, networkInfo.height);

  int classId = 0;

  for (int i = 0; i < count; i++) {
    float *det = detectionOut + i * 4;
    float conf = (confOut + i * 2)[1];
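    // The conf layer has two scores per anchor; index 1 is the face confidence
    // (the same column as conf[:, 1] in the PyTorch code below).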

    if (classId >= numClassesToParse)
      continue;

    float threshold = detectionParams.perClassPreclusterThreshold[classId];

    if (conf < threshold) {
      continue;
    } else {
      box anchor = anchors[i];
      box det_box;
      det_box.cx = det[0];
      det_box.cy = det[1];
      det_box.sx = det[2];
      det_box.sy = det[3];

      bbox result =
          decode(anchor, det_box, conf, networkInfo.width, networkInfo.height);

      std::cout << "conf: " << conf << " output: [" << det_box.cx << ", "
                << det_box.cy << ", " << det_box.sx << ", " << det_box.sy
                << "] bbox: [" << result.x1 << ", " << result.y1 << ", "
                << result.x2 << ", " << result.y2 << "]" << std::endl;

      bboxes.push_back(result);
    }
  }

  std::sort(bboxes.begin(), bboxes.end(), cmp);
  nms(bboxes, nmsIOUThreshold);

  std::cout << "Object Lists: " << bboxes.size() << std::endl;
  for (int j = 0; j < bboxes.size(); j++) {
    NvDsInferObjectDetectionInfo object;
    object.classId = classId;
    object.detectionConfidence = bboxes[j].s;
    object.left = bboxes[j].x1;
    object.top = bboxes[j].y1;
    object.width = bboxes[j].x2;
    object.height = bboxes[j].y2;

    std::cout << object.detectionConfidence << ": [" << object.left << ", "
              << object.top << ", " << object.width << ", " << object.height
              << "]" << std::endl;

    objectList.push_back(object);
  }

  return true;
}

/* Check that the custom function has been defined correctly */
CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseFaceDetector);
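
For reference, a minimal sketch of the kind of unit test mentioned in the first post, assuming it is compiled together with the parser source above (so box, bbox, get_priors and decode are in scope) and with the DeepStream headers on the include path. The expected 45120 anchors follow from the 1024x768 input and the four strides:

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<box> anchors;
  get_priors(anchors, 1024, 768);  // w, h as implied by infer-dims=3;768;1024

  // (128*96)*3 + (64*48)*2 + (32*24)*2 + (16*12)*3 = 45120 anchors
  assert(anchors.size() == 45120u);

  // First anchor: centre of the first 8x8 cell with min_size 10, normalized.
  assert(std::fabs(anchors[0].cx - 0.5f * 8 / 1024) < 1e-6f);
  assert(std::fabs(anchors[0].sx - 10.0f / 1024) < 1e-6f);

  // decode() with zero offsets reproduces the anchor scaled to pixels and
  // clipped to the image: x1 clips to 0, x2 = (cx + sx/2) * w = 9.
  box zero = {0.f, 0.f, 0.f, 0.f};
  bbox b = decode(anchors[0], zero, 0.9f, 1024, 768);
  assert(std::fabs(b.x1) < 1e-3f && std::fabs(b.x2 - 9.0f) < 1e-3f);

  std::printf("prior/decode sanity checks passed (%zu anchors)\n", anchors.size());
  return 0;
}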

PyTorch detection code: Face-Detector-1MB-with-landmark/detect.py at master · biubug6/Face-Detector-1MB-with-landmark · GitHub

from __future__ import print_function
import os
import argparse
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import time
from data import cfg_mnet, cfg_slim, cfg_rfb
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
import cv2
from models.retinaface import RetinaFace
from models.net_slim import Slim
from models.net_rfb import RFB
from utils.box_utils import decode, decode_landm
from utils.timer import Timer


parser = argparse.ArgumentParser(description='Test')
parser.add_argument('-m', '--trained_model', default='./weights/RBF_Final.pth',
                    type=str, help='Trained state_dict file path to open')
parser.add_argument('--network', default='RFB', help='Backbone network mobile0.25 or slim or RFB')
parser.add_argument('--origin_size', default=True, type=str, help='Whether use origin image size to evaluate')
parser.add_argument('--long_side', default=640, help='when origin_size is false, long_side is scaled size(320 or 640 for long side)')
parser.add_argument('--save_folder', default='./widerface_evaluate/widerface_txt/', type=str, help='Dir to save txt results')
parser.add_argument('--cpu', action="store_true", default=False, help='Use cpu inference')
parser.add_argument('--confidence_threshold', default=0.02, type=float, help='confidence_threshold')
parser.add_argument('--top_k', default=5000, type=int, help='top_k')
parser.add_argument('--nms_threshold', default=0.4, type=float, help='nms_threshold')
parser.add_argument('--keep_top_k', default=750, type=int, help='keep_top_k')
parser.add_argument('--save_image', action="store_true", default=True, help='show detection results')
parser.add_argument('--vis_thres', default=0.6, type=float, help='visualization_threshold')
args = parser.parse_args()


def check_keys(model, pretrained_state_dict):
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    print('Missing keys:{}'.format(len(missing_keys)))
    print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    print('Used keys:{}'.format(len(used_pretrained_keys)))
    assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
    return True


def remove_prefix(state_dict, prefix):
    ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
    print('remove prefix \'{}\''.format(prefix))
    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}


def load_model(model, pretrained_path, load_to_cpu):
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model


if __name__ == '__main__':
    torch.set_grad_enabled(False)

    cfg = None
    net = None
    if args.network == "mobile0.25":
        cfg = cfg_mnet
        net = RetinaFace(cfg = cfg, phase = 'test')
    elif args.network == "slim":
        cfg = cfg_slim
        net = Slim(cfg = cfg, phase = 'test')
    elif args.network == "RFB":
        cfg = cfg_rfb
        net = RFB(cfg = cfg, phase = 'test')
    else:
        print("Don't support network!")
        exit(0)

    net = load_model(net, args.trained_model, args.cpu)
    net.eval()
    print('Finished loading model!')
    print(net)
    cudnn.benchmark = True
    device = torch.device("cpu" if args.cpu else "cuda")
    net = net.to(device)

    # testing begin
    for i in range(100):
        image_path = "./img/sample.jpg"

        img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
        img = np.float32(img_raw)

        # testing scale
        target_size = args.long_side
        max_size = args.long_side
        im_shape = img.shape
        im_size_min = np.min(im_shape[0:2])
        im_size_max = np.max(im_shape[0:2])
        resize = float(target_size) / float(im_size_min)
        # prevent bigger axis from being more than max_size:
        if np.round(resize * im_size_max) > max_size:
            resize = float(max_size) / float(im_size_max)
        if args.origin_size:
            resize = 1

        if resize != 1:
            img = cv2.resize(img, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
        im_height, im_width, _ = img.shape


        scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
        img -= (104, 117, 123)
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.to(device)
        scale = scale.to(device)

        tic = time.time()
        loc, conf, landms = net(img)  # forward pass
        print('net forward time: {:.4f}'.format(time.time() - tic))

        priorbox = PriorBox(cfg, image_size=(im_height, im_width))
        priors = priorbox.forward()
        priors = priors.to(device)
        prior_data = priors.data
        boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
        boxes = boxes * scale / resize
        boxes = boxes.cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
        landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
        scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
                               img.shape[3], img.shape[2], img.shape[3], img.shape[2],
                               img.shape[3], img.shape[2]])
        scale1 = scale1.to(device)
        landms = landms * scale1 / resize
        landms = landms.cpu().numpy()

        # ignore low scores
        inds = np.where(scores > args.confidence_threshold)[0]
        boxes = boxes[inds]
        landms = landms[inds]
        scores = scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1][:args.top_k]
        boxes = boxes[order]
        landms = landms[order]
        scores = scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = py_cpu_nms(dets, args.nms_threshold)
        # keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
        dets = dets[keep, :]
        landms = landms[keep]

        # keep top-K faster NMS
        dets = dets[:args.keep_top_k, :]
        landms = landms[:args.keep_top_k, :]

        dets = np.concatenate((dets, landms), axis=1)

        # show image
        if args.save_image:
            for b in dets:
                if b[4] < args.vis_thres:
                    continue
                text = "{:.4f}".format(b[4])
                b = list(map(int, b))
                cv2.rectangle(img_raw, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
                cx = b[0]
                cy = b[1] + 12
                cv2.putText(img_raw, text, (cx, cy),
                            cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))

                # landms
                cv2.circle(img_raw, (b[5], b[6]), 1, (0, 0, 255), 4)
                cv2.circle(img_raw, (b[7], b[8]), 1, (0, 255, 255), 4)
                cv2.circle(img_raw, (b[9], b[10]), 1, (255, 0, 255), 4)
                cv2.circle(img_raw, (b[11], b[12]), 1, (0, 255, 0), 4)
                cv2.circle(img_raw, (b[13], b[14]), 1, (255, 0, 0), 4)
            # save image

            name = "test.jpg"
            cv2.imwrite(name, img_raw)

Sorry for the late response; we will investigate and update soon.

Thank you so much for your reply.

I have run the engine file auto-generated from the ONNX model with the Python TensorRT API, and the result is correct. So maybe there’s something wrong with the nvinfer preprocessing step. I’m working on it and trying to find out why.

Can you try adding

net-scale-factor=0.00392156862745098

to your [property] group?

Yes, I’ve tried that before, and with it the DeepStream nvinfer model cannot detect any objects.

The input image to the PyTorch model is like:

tensor([[[[ 114.,  114.,  114.,  ...,   81.,   81.,   81.],
          [ 114.,  114.,  115.,  ...,   81.,   81.,   81.],
          [ 114.,  115.,  115.,  ...,   81.,   81.,   81.],
          ...,
          [  17.,   16.,   15.,  ...,  -90.,  -89.,  -88.],
          [  17.,   16.,   15.,  ...,  -91.,  -90.,  -89.],
          [  17.,   16.,   15.,  ...,  -92.,  -90.,  -90.]],

         [[  93.,   93.,   93.,  ...,   67.,   67.,   67.],
          [  93.,   93.,   94.,  ...,   67.,   67.,   67.],
          [  93.,   94.,   94.,  ...,   67.,   67.,   67.],
          ...,
          [  12.,   11.,   10.,  ..., -104., -103., -102.],
          [  12.,   11.,   10.,  ..., -104., -103., -102.],
          [  12.,   11.,   10.,  ..., -105., -103., -103.]],

         [[  80.,   80.,   80.,  ...,   63.,   63.,   63.],
          [  80.,   80.,   81.,  ...,   63.,   63.,   63.],
          [  80.,   81.,   81.,  ...,   63.,   63.,   63.],
          ...,
          [  43.,   42.,   41.,  ..., -108., -107., -106.],
          [  43.,   42.,   41.,  ..., -110., -109., -108.],
          [  43.,   42.,   41.,  ..., -111., -109., -109.]]]])

Is this the preprocessed image input?
If it is, I think you need to skip some steps or retrain the model:

  • Maybe do just the resize; I think the img = img.transpose(2, 0, 1) step should be handled by the model’s first layer.
  • DeepStream expects NCHW format.

Thank you for your reply!

Yes, that is the preprocessed image input.

I’ve tried another image with shape [3, 640, 368], so resize here is 1.0 and no OpenCV resize is applied, but the result in DeepStream is still incorrect.

The model input shape should be [N, 3, H, W] in both PyTorch and DeepStream, i.e., NCHW format. In the Python code, a raw image read by OpenCV is in HWC format, hence the transpose operation.
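
To make the layout point concrete, here is a small illustrative sketch (not from the actual pipeline) of how the same element (c, y, x) maps to different flat offsets in the planar NCHW tensor the model expects versus the interleaved HWC image that cv2.imread returns:

#include <cstddef>

// Flat offset of element (c, y, x) in a planar CHW tensor (batch dimension N = 1 omitted).
inline std::size_t chw_index(int c, int y, int x, int H, int W) {
  return static_cast<std::size_t>(c) * H * W + static_cast<std::size_t>(y) * W + x;
}

// Flat offset of the same element in an interleaved HWC image (OpenCV layout).
inline std::size_t hwc_index(int c, int y, int x, int C, int W) {
  return (static_cast<std::size_t>(y) * W + x) * C + c;
}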


Is the output matrix from DeepStream the same as the output from PyTorch?

The shape of the output is correct, but the values of the tensor are not.


Have you tried the option maintain-aspect-ratio=1?

Sorry for the late reply.

I’ve tried it before, but unfortunately it did not work.

Hi @infinitesamsarax,
Could you try "2. [DS5.0GA_Jetson_GPU_Plugin] Dump the Inference Input" in DeepStream SDK FAQ - #9 by mchi to dump the inference image sent to TensorRT enqueue() and check whether it matches your standalone sample?

Thanks!

Thank you and sorry for the late reply.

I dumped the DeepStream-preprocessed nvinfer input to files with different color-format and net-scale-factor settings. Compared with the Python-preprocessed input, they are quite similar. Then I forwarded the DeepStream-preprocessed input (file: gie-1_input-0_batch-0_frame-1_format-1_scale-1.000000.png) to the PyTorch network and the result seemed fine (although it failed to detect one of the faces).

All the pics are available in: https://drive.google.com/drive/folders/1PdeGYT3zh-ZYEN_596YhflS2FfpbElwz

OK, could you refer to Very small Bounding Boxes with custom sgie model - #13 by mchi to dump the DeepStream raw inference output and check whether it is right?
I think that in DeepStream, if the input is right, the TensorRT inference should be fine, so the problem is probably in post-processing.

Thanks!

I followed the post and got the dumped raw inference output binary files.

Then, using Python, I read them with NumPy and saved the data as a CSV file (the landmark output is ignored for now):

import numpy as np
import pandas as pd

file_layer_1 = 'gie-1-output-layer-index-1-inferDims-45120-4-numElements-180480-dataType-0'
# file_layer_2 = 'gie-1-output-layer-index-2-inferDims-45120-10-numElements-451200-dataType-0'
file_layer_3 = 'gie-1-output-layer-index-3-inferDims-45120-2-numElements-90240-dataType-0'

loc = np.fromfile(file_layer_1, dtype=np.float32).reshape((45120, 4))
# landms = np.fromfile(file_layer_2, dtype=np.float32).reshape((45120, 10))
conf = np.fromfile(file_layer_3, dtype=np.float32).reshape((45120, 2))

data = np.concatenate([np.expand_dims(conf[:, 1], 1), loc], axis=1)
data = pd.DataFrame(data, columns=['conf', 'cx', 'cy', 'sx', 'sy'])

data.to_csv('output_in_dumped_binary_file.csv', index=False)

The output data seems to be the same as what my custom parsing function receives. However, compared with the Python TensorRT inference output CSV produced from the same engine file, they are quite different. Meanwhile, the final output image from the dumped binary file, decoded with the same Python code, is still incorrect.

Since that output image is very different from the direct output image of DeepStream, I suspect there is also something wrong with my custom parsing function.

I’ll try running it again with the DeepStream output replaced by the Python TensorRT inference CSV file, and then update this post soon.

All the files and pics are available at: https://drive.google.com/drive/folders/1moJ8j1NGGcSpnHo28uL_CNLj7-J6JKlF?usp=sharing

Thank you!

Thanks for the update!

I think I have solved this problem.

The cause of the bug:
NvDsInferObjectDetectionInfo expects [left, top, width, height], but I was filling it with [x1, y1, x2, y2], which is why the output looked so weird.
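
In other words, the loop at the end of the parsing function should convert the decoded corners into the left/top/width/height layout. A minimal sketch of the corrected conversion, assuming it lives in the same source file as the parser above so the bbox struct and the DeepStream types are in scope:

// Convert a decoded [x1, y1, x2, y2] box into the [left, top, width, height]
// layout that NvDsInferObjectDetectionInfo expects.
static NvDsInferObjectDetectionInfo toDetection(const bbox &b, int classId) {
  NvDsInferObjectDetectionInfo object;
  object.classId = classId;
  object.detectionConfidence = b.s;
  object.left = b.x1;
  object.top = b.y1;
  object.width = b.x2 - b.x1;
  object.height = b.y2 - b.y1;
  return object;
}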

Now everything works and I can successfully detect faces in DeepStream. But I’m still wondering why, in the previous post, the final output image from the dumped binary file (parsed with the Python code) was incorrect, so there is still some debugging to do.

Thanks to kayccc, PhongNT and mchi! I appreciate your help so much.
