Deepstream / Triton Server - YOLOv7

Just sharing some info.
I was able to configure DeepStream / Triton Server with a custom YOLOv7 model.

How to configure DeepStream and Triton Server

DeepStream Config:

config_infer_primary_triton-server.txt

infer_config {
  unique_id: 1
  gpu_ids: [0]
  max_batch_size: 10
  
  backend {
    triton {
      model_name: "yolov7"
      version: -1
      grpc {
        url: "0.0.0.0:8001"
      }
    }
  }

  preprocess {
    network_format: IMAGE_FORMAT_RGB 
    tensor_order: TENSOR_ORDER_LINEAR
    maintain_aspect_ratio: 1
    frame_scaling_hw: FRAME_SCALING_HW_DEFAULT
    frame_scaling_filter: 1
    normalize {
      scale_factor: 0.0039215697906911373  # = 1/255, scales 0-255 pixel values to 0-1
    }
  }

  postprocess {
    labelfile_path: "/apps/configs/labels.txt"
    detection {
      num_detected_classes: 8
      custom_parse_bbox_func: "NvDsInferParseCustomEfficientNMS"
      nms {
        confidence_threshold: 0.35
        topk: 100
        iou_threshold: 0.65
      }
    }
  }

  extra {
    copy_input_to_host_buffers: false
  }
  custom_lib {
    path : "/opt/nvidia/deepstream/deepstream-6.1/sources/libs/nvdsinfer_customparser/libnvds_infercustomparser.so"
  }
}

input_control {
  process_mode : PROCESS_MODE_FULL_FRAME
  interval : 0
}
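
For context, this config file is consumed by the Gst-nvinferserver plugin through its config-file-path property. Below is a minimal single-stream gst-launch-1.0 sketch, assuming an RTSP source and an on-screen sink; the URI, resolution and config path are placeholders, and the real 110-stream pipeline with tracking and encoding is of course more involved.

gst-launch-1.0 \
  uridecodebin uri=rtsp://<camera-url> ! m.sink_0 \
  nvstreammux name=m batch-size=1 width=1920 height=1080 ! \
  nvinferserver config-file-path=/apps/configs/config_infer_primary_triton-server.txt ! \
  nvvideoconvert ! nvdsosd ! nveglglessink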

To map the EfficientNMS model outputs to the NVIDIA NvDsObjectMeta data structure, you need the parser below, which has to be added to /opt/nvidia/deepstream/deepstream-6.1/sources/libs/nvdsinfer_customparser/nvdsinfer_custombboxparser.cpp and compiled.

extern "C"
bool NvDsInferParseCustomEfficientNMS (std::vector<NvDsInferLayerInfo> const &outputLayersInfo,
                                   NvDsInferNetworkInfo  const &networkInfo,
                                   NvDsInferParseDetectionParams const &detectionParams,
                                   std::vector<NvDsInferObjectDetectionInfo> &objectList) {
    if(outputLayersInfo.size() != 4)
    {
        std::cerr << "Mismatch in the number of output buffers."
                  << "Expected 4 output buffers, detected in the network :"
                  << outputLayersInfo.size() << std::endl;
        return false;
    }
    const char* log_enable = std::getenv("ENABLE_DEBUG");

    /* Output layer order as bound for this Triton model (matches the indices used below):
     * [0] = det_boxes, [1] = det_classes, [2] = det_scores, [3] = num_dets */
    int* p_keep_count = (int *) outputLayersInfo[3].buffer;

    float* p_bboxes = (float *) outputLayersInfo[0].buffer;
    NvDsInferDims inferDims_p_bboxes = outputLayersInfo[0].inferDims;
    int numElements_p_bboxes = inferDims_p_bboxes.numElements;

    float* p_scores = (float *) outputLayersInfo[2].buffer;
    unsigned int* p_classes = (unsigned int *) outputLayersInfo[1].buffer;


    const float threshold = detectionParams.perClassThreshold[0];

    float max_bbox=0;
    for (int i=0; i < numElements_p_bboxes; i++)
    {
        if ( max_bbox < p_bboxes[i] )
            max_bbox=p_bboxes[i];
    }

    if (p_keep_count[0] > 0)
    {
        assert (!(max_bbox < 2.0));
        for (int i = 0; i < p_keep_count[0]; i++) {

            if ( p_scores[i] < threshold) continue;
            assert((unsigned int) p_classes[i] < detectionParams.numClassesConfigured);

            NvDsInferObjectDetectionInfo object;
            object.classId = (int) p_classes[i];
            object.detectionConfidence = p_scores[i];

            object.left=p_bboxes[4*i];
            object.top=p_bboxes[4*i+1];
            object.width=(p_bboxes[4*i+2] - object.left);
            object.height= (p_bboxes[4*i+3] - object.top);

            if(log_enable != NULL && std::stoi(log_enable)) {
                std::cout << "label/conf/ x/y w/h -- "
                << p_classes[i] << " "
                << p_scores[i] << " "
                << object.left << " " << object.top << " " << object.width << " "<< object.height << " "
                << std::endl;
            }

            object.left=CLIP(object.left, 0, networkInfo.width - 1);
            object.top=CLIP(object.top, 0, networkInfo.height - 1);
            object.width=CLIP(object.width, 0, networkInfo.width - 1);
            object.height=CLIP(object.height, 0, networkInfo.height - 1);

            objectList.push_back(object);
        }
    }
    return true;
}

/* Sanity-check that the custom parser matches the expected function prototype. */
CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseCustomEfficientNMS);
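
After adding the function, rebuild the custom parser library so that libnvds_infercustomparser.so (the file referenced by custom_lib.path in the config above) contains the new symbol. A minimal sketch, assuming the Makefile shipped in that directory of DeepStream 6.1; depending on your install you may need to export CUDA_VER to match the local CUDA toolkit first:

cd /opt/nvidia/deepstream/deepstream-6.1/sources/libs/nvdsinfer_customparser
# export CUDA_VER=11.6   # only if the Makefile in your DeepStream version requires it
make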

please refer to yoloV7 onnx triton inference - #3 by fanzh

Thanks for sharing, great, we will add this topic to our FAQ.
Could you share your test performance? Thanks.

My setup is 2x RTX 3090:
1x GPU for DeepStream decode/tracking/encode (only 3 encoded outputs)
1x GPU for Triton Server - YOLOv7 - INT8 - 416x416

Triton Server Docker image: nvcr.io/nvidia/tritonserver:22.09-py3
DeepStream Docker image: nvcr.io/nvidia/deepstream:6.1.1-triton

I'm processing 110 RTSP streams at 10 fps each.
I get about 1100 fps at roughly 95% utilization of the RTX 3090, and I believe I can reach 1200 fps.

I followed these steps after finishing training my custom model.

1- Generate a new model using the reparameterization procedure.
Reparameterization folds the trainable BoF (bag-of-freebies) modules into the deploy model for fast inference.

2- Export the PyTorch YOLOv7 model to ONNX with grid, the EfficientNMS plugin and dynamic batch size (see the sketch below).

  • In this step you can also change the network input size - e.g. 416x416, 640x640
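
A typical export invocation, assuming the export.py script from the WongKinYiu/yolov7 repository; the weights file name is a placeholder and the NMS thresholds simply mirror the values used in the config above:

python export.py --weights yolov7-custom.pt \
  --grid --end2end --dynamic-batch --simplify \
  --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 \
  --img-size 416 416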

3- To use TensorRT INT8 you will need to generate a calibration file. I used this repo: YOLO-TensorRT8

4- Export with INT8 precision, min batch 1, opt batch 10 and max batch 10.
TensorRT Docker image: nvcr.io/nvidia/tensorrt:22.09-py3
Use calib.cache to generate the TensorRT INT8 engine:

./tensorrt/bin/trtexec --onnx=./yolov7-416.onnx \
--minShapes=images:1x3x416x416 \
--optShapes=images:10x3x416x416 \
--maxShapes=images:10x3x416x416 \
--int8 \
--calib=./calib.cache \
--workspace=8092 \
--saveEngine=yolov7-int8-1x10x10x416.engine \
--timingCacheFile=timing.cache
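
Before loading the plan into Triton, the engine can be sanity-checked and benchmarked with trtexec; a quick sketch using the opt batch size of 10 from above:

./tensorrt/bin/trtexec --loadEngine=yolov7-int8-1x10x10x416.engine \
--shapes=images:10x3x416x416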

Triton Server Config:

name: "yolov7-416"
platform: "tensorrt_plan"
max_batch_size: 10
input {
  name: "images"
  data_type: TYPE_FP32
  dims: 3
  dims: 416
  dims: 416
}
output {
  name: "num_dets"
  data_type: TYPE_INT32
  dims: 1
}
output {
  name: "det_boxes"
  data_type: TYPE_FP32
  dims: 100
  dims: 4
}
output {
  name: "det_scores"
  data_type: TYPE_FP32
  dims: 100
}
output {
  name: "det_classes"
  data_type: TYPE_INT32
  dims: 100
}
instance_group [
  {
  count: 12
  kind: KIND_GPU
  gpus: [0]
  }
]
dynamic_batching {
}
default_model_filename: "yolov7-int8-1x10x10x416.engine"
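
For reference, Triton expects the plan file and this config.pbtxt in a standard model-repository layout; the sketch below assumes a local models/ directory and shows one way to launch the 22.09 server container (paths, ports and GPU selection are examples). The engine has to be built with the same TensorRT version that the Triton backend ships with, which is why the matching 22.09 containers are used.

models/
└── yolov7-416/
    ├── config.pbtxt
    └── 1/
        └── yolov7-int8-1x10x10x416.engine

docker run --gpus all --rm \
  -p 8000:8000 -p 8001:8001 -p 8002:8002 \
  -v $(pwd)/models:/models \
  nvcr.io/nvidia/tritonserver:22.09-py3 \
  tritonserver --model-repository=/models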

Further tests:
Try PTQ calibration using Histogram(MSE)

Thanks for sharing. We did some CUDA acceleration of the postprocessing; please refer to the code if interested.


Hi @Levi_Pereira

I’m the author of the Triton deployment instructions in the YOLOv7 GitHub repo. Thanks for sharing your DeepStream setup. I have a few questions, if you don’t mind.

  1. In postprocess:
  postprocess {
    labelfile_path: "/apps/configs/labels.txt"
    detection {
      num_detected_classes: 8
      custom_parse_bbox_func: "NvDsInferParseCustomEfficientNMS"
      nms {
        confidence_threshold: 0.35
        topk: 100
        iou_threshold: 0.65
      }
    }
  }

Why is it necessary to set the NMS confidence, topk and IoU thresholds again here? It was my understanding that those are baked into the exported ONNX by the EfficientNMS plugin.

  2. When computing the boxes in the custom parser code, I don’t understand where the letterbox rescaling (maintained aspect ratio) is reverted. Shouldn’t the padding added during letterbox preprocessing also affect the box coordinates?
    The letterbox function is here: yolov7/datasets.py at 072f76c72c641c7a1ee482e39f604f6f8ef7ee92 · WongKinYiu/yolov7 · GitHub

  3. Does INT8 perform well? Any observable degradation in accuracy compared to FP16?

  4. Would you mind sharing your full pipeline, including tracking, just for reference?

Thanks again!

Hi @philipp.schmidt,

Q1. You are right, this configuration is obsolete; I forgot to remove the nms block (it was left over from testing older YOLO models).

Q2. I don’t have the exact answer as to where the letterbox rescaling happens.
I believe that after post-processing calls NvDsInferParseCustomEfficientNMS, the Gst-nvdspostprocess plugin handles the box rescaling.

Check:
/opt/nvidia/deepstream/deepstream/sources/gst-plugins/gst-nvdspostprocess/postprocesslib_impl/post_processor_detect.cpp

<snippet>
/**
 * Attach metadata for the detector. We will be adding a new metadata.
 */
void
DetectModelPostProcessor::attachMetadata (NvBufSurface *surf, gint batch_idx,
    NvDsBatchMeta  *batch_meta,
    NvDsFrameMeta  *frame_meta,
    NvDsObjectMeta  *object_meta,
    NvDsObjectMeta *parent_obj_meta,
    NvDsPostProcessFrameOutput & detection_output,
    NvDsPostProcessDetectionParams *all_params,
    std::set <gint> & filterOutClassIds,
    int32_t unique_id,
    gboolean output_instance_mask,
    gboolean process_full_frame,
    float segmentationThreshold,
    gboolean maintain_aspect_ratio)
{
  static gchar font_name[] = "Serif";
  NvDsObjectMeta *obj_meta = NULL;
  nvds_acquire_meta_lock (batch_meta);
  gint surf_width  = surf->surfaceList[batch_idx].width;
  gint surf_height = surf->surfaceList[batch_idx].height;
  float scale_x =
    (float)surf_width/(float)m_NetworkInfo.width;
  float scale_y =
    (float)surf_height/(float)m_NetworkInfo.height;

  //FIXME: Get preprocess data and scale ROI

  frame_meta->bInferDone = TRUE;
  /* Iterate through the inference output for one frame and attach the detected
   * bnounding boxes. */
  for (guint i = 0; i < detection_output.detectionOutput.numObjects; i++) {
    NvDsPostProcessObject & obj = detection_output.detectionOutput.objects[i];
    NvDsPostProcessDetectionParams & filter_params =
        all_params[obj.classIndex];

    /* Scale the bounding boxes proportionally based on how the object/frame was
     * scaled during input. */
    if (maintain_aspect_ratio){
      if (scale_x > scale_y){
        scale_y = scale_x;
      }
      else{
        scale_x = scale_y;
      }
    }
    obj.left = (obj.left - 0)*scale_x + 0;
    obj.top  = (obj.top - 0)*scale_y + 0;
    obj.width *= scale_x;
    obj.height *= scale_y;
</snippet>

Q3. Yes. On my custom dataset, mAP@0.5 decreased from 0.95 to 0.90, which is still acceptable for me. I’m working on TensorRT QAT INT8 using Histogram (MSE) calibration.

Q4. I don’t have clean code to share right now, but I will work on it in the next few days.

