ONNX Python post-processing vs. TAO train post-processing

Hello!

I’ve been training TrafficCamNet with TAO train on my own dataset of images that contain only cars. After training, instead of using TAO inference, I want to run inference in Python with an exported ONNX model. I’ve written post-processing code, but when I run debug images through my post-processor, the detections do not match the ones I see in ClearML bbox_preds: the detection coordinates are off by about 10 pixels in both x and y. This is before I run the detections through a clustering algorithm such as NMS or DBSCAN.

To write my post-processor, I dug through the [tao_tensorflow1_backend](https://github.com/NVIDIA/tao_tensorflow1_backend/tree/main/nvidia_tao_tf1/cv/detectnet_v2/postprocessor) repository. Here is my code:

```python
class TrafficCamNetPostProcess(object):

    def __init__(self, width, height, score=0.15, classes=None):
        '''
        Params:
            width, height (int): size of the image the bounding boxes are scaled to
            score (float): minimum confidence for a grid cell to produce a box
            classes (list of int): indices of the output classes to decode (default [0])
        '''
        self.image_width = width
        self.image_height = height

        self.model_h = 544
        self.model_w = 960
        self.stride = 16.0
        self.box_norm = 35.0

        self.grid_h = int(self.model_h / self.stride)
        self.grid_w = int(self.model_w / self.stride)
        self.grid_size = self.grid_h * self.grid_w

        self.grid_centers_w = []
        self.grid_centers_h = []

        for i in range(self.grid_h):
            value = (i * self.stride + 0.5) / self.box_norm
            self.grid_centers_h.append(value)

        for i in range(self.grid_w):
            value = (i * self.stride + 0.5) / self.box_norm
            self.grid_centers_w.append(value)

        # min_confidence (float): minimum confidence to accept a detection
        # analysis_classes (list of int): indices of the classes to consider
        self.min_confidence = score
        self.analysis_classes = [0] if classes is None else list(classes)

    def applyBoxNorm(self, o1, o2, o3, o4, x, y):
        """
        Applies the GridNet box normalization
        Args:
            o1 (float): raw x1 regression output for this grid cell
            o2 (float): raw y1 regression output for this grid cell
            o3 (float): raw x2 regression output for this grid cell
            o4 (float): raw y2 regression output for this grid cell
            x: column index on the grid
            y: row index on the grid
        Returns:
            float: x1 in model-input pixels
            float: y1 in model-input pixels
            float: x2 in model-input pixels
            float: y2 in model-input pixels
        """
        o1 = (o1 - self.grid_centers_w[x]) * -self.box_norm
        o2 = (o2 - self.grid_centers_h[y]) * -self.box_norm
        o3 = (o3 + self.grid_centers_w[x]) * self.box_norm
        o4 = (o4 + self.grid_centers_h[y]) * self.box_norm
        return o1, o2, o3, o4

    def change_model_size_to_real(self, model_size, axis):
        # Rescale a model-input coordinate to the target image size.
        real_size = 0
        if axis == 'x':
            real_size = (model_size / float(self.model_w)) * self.image_width
        elif axis == 'y':
            real_size = (model_size / float(self.model_h)) * self.image_height
        real_size = int(real_size)
        return real_size

    def start(self, buffer_bbox, buffer_scores):
        """
        Postprocesses the inference output
        Args:
            buffer_bbox (1-D array of float): flattened bounding-box output
            buffer_scores (1-D array of float): flattened coverage (confidence) output
        Returns:
            list of tuple: one (xmin, ymin, xmax, ymax) tuple per detection, in image pixels
        """

        bbs = []
        # Loop only over the requested classes instead of a hard-coded range(3).
        for c in self.analysis_classes:
            x1_idx = (c * 4 * self.grid_size)
            y1_idx = x1_idx + self.grid_size
            x2_idx = y1_idx + self.grid_size
            y2_idx = x2_idx + self.grid_size

            for h in range(self.grid_h):
                for w in range(self.grid_w):
                    i = w + h * self.grid_w
                    score = buffer_scores[c * self.grid_size + i]
                    if score >= self.min_confidence:
                        o1 = buffer_bbox[x1_idx + i]
                        o2 = buffer_bbox[y1_idx + i]
                        o3 = buffer_bbox[x2_idx + i]
                        o4 = buffer_bbox[y2_idx + i]

                        o1, o2, o3, o4 = self.applyBoxNorm(
                            o1, o2, o3, o4, w, h)

                        # Keep float precision until the final rescale instead of
                        # truncating to int both before and after scaling.
                        xmin_image = self.change_model_size_to_real(o1, 'x')
                        ymin_image = self.change_model_size_to_real(o2, 'y')
                        xmax_image = self.change_model_size_to_real(o3, 'x')
                        ymax_image = self.change_model_size_to_real(o4, 'y')

                        rect = (xmin_image, ymin_image, xmax_image, ymax_image)

                        bbs.append(rect)
        return bbs
```
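To cross-check the loop above, the same decoding can be written with NumPy broadcasting. This is a sketch that assumes the same flat layouts the class indexes into (boxes as class, coordinate, row, column; coverage as class, row, column) and the same constants: with a 960x544 input and stride 16 the grid is 60x34, i.e. 2,040 cells per class. `decode_detectnet_v2` is a hypothetical helper name, and unlike the class it returns float coordinates rather than truncating to int.

```python
import numpy as np

def decode_detectnet_v2(bbox, cov, class_id=0, model_w=960, model_h=544,
                        image_w=960, image_h=544, stride=16.0, box_norm=35.0,
                        min_confidence=0.15):
    """Vectorised grid decoding for one class (sketch, not TAO's own code).

    Returns an (N, 4) array of (x1, y1, x2, y2) boxes scaled to image_w x image_h,
    and the (N,) coverage scores of the grid cells that passed min_confidence.
    """
    grid_h, grid_w = int(model_h / stride), int(model_w / stride)
    # Assumed layouts: [class][coord][row][col] for boxes, [class][row][col] for coverage.
    bbox = np.asarray(bbox, dtype=np.float32).reshape(-1, 4, grid_h, grid_w)[class_id]
    cov = np.asarray(cov, dtype=np.float32).reshape(-1, grid_h, grid_w)[class_id]
    # Grid-cell centres normalised by box_norm, shaped to broadcast over the grid.
    cx = ((np.arange(grid_w) * stride + 0.5) / box_norm)[np.newaxis, :]  # (1, grid_w)
    cy = ((np.arange(grid_h) * stride + 0.5) / box_norm)[:, np.newaxis]  # (grid_h, 1)
    x1 = (bbox[0] - cx) * -box_norm
    y1 = (bbox[1] - cy) * -box_norm
    x2 = (bbox[2] + cx) * box_norm
    y2 = (bbox[3] + cy) * box_norm
    # Boolean masking flattens row-major, i.e. the same order as the h/w loops above.
    keep = cov >= min_confidence
    boxes = np.stack([x1[keep], y1[keep], x2[keep], y2[keep]], axis=1)
    sx, sy = image_w / model_w, image_h / model_h
    return boxes * np.array([sx, sy, sx, sy]), cov[keep]
```

Feeding it the same flattened arrays passed to `start` should give boxes that agree with the class up to the int truncation.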

Here is how one of the test images looks in ClearML:

When I run it through ONNX inference and my post-processor like this:

```python
results = onnx_infer(image_np, ort_session)
bboxs = results[0].flatten()
scores = results[1].flatten()

model_w, model_h = 960, 544  # network input size, so boxes come back in model-input pixels
trafficcamnet_postprocess = TrafficCamNetPostProcess(model_w, model_h, score=0.001, classes=[0])

rects = trafficcamnet_postprocess.start(bboxs, scores)
```
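`onnx_infer` is not shown in the post. As a hedged sketch, a helper like it might look as follows, assuming the usual DetectNet_v2 input convention (RGB, scaled to [0, 1] by 1/255 with no mean subtraction, NCHW at 960x544) and an `onnxruntime.InferenceSession`. Returning the outputs keyed by name avoids relying on the `results[0]`/`results[1]` order. `preprocess` and `onnx_infer` are hypothetical names, and the nearest-neighbour resize is only there to keep the example dependency-free; a real pipeline would more likely use `cv2.resize` or PIL.

```python
import numpy as np

MODEL_W, MODEL_H = 960, 544  # network input size from the training spec

def preprocess(image_rgb):
    """HWC uint8 RGB image -> 1x3xHxW float32 tensor in [0, 1] (assumed scaling)."""
    h, w = image_rgb.shape[:2]
    # Nearest-neighbour resize: pick source rows/columns for each output pixel.
    rows = (np.arange(MODEL_H) * h / MODEL_H).astype(int)
    cols = (np.arange(MODEL_W) * w / MODEL_W).astype(int)
    resized = image_rgb[rows][:, cols]
    x = resized.astype(np.float32) / 255.0
    return np.transpose(x, (2, 0, 1))[np.newaxis]  # HWC -> NCHW with batch dim

def onnx_infer(image_rgb, session):
    """Run one image through an onnxruntime session; return {output_name: array}."""
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: preprocess(image_rgb)})
    names = [o.name for o in session.get_outputs()]
    return dict(zip(names, outputs))
```

Printing `onnx_infer(image_np, ort_session).keys()` then shows which output is the bounding-box tensor and which is the coverage tensor in the exported model.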

I get a result which looks like this:

I thought some background augmentation during TAO training could be causing this offset, so I set translate_max_x and translate_max_y to 0.0 in my training config. Here are my augmentation_config and postprocessing_config:

```
augmentation_config {
  preprocessing {
    output_image_width: 960
    output_image_height: 544
    output_image_channel: 3
  }
  spatial_augmentation {
    hflip_probability: 0.0
    zoom_min: 1.0
    zoom_max: 1.0
    translate_max_x: 0.0
    translate_max_y: 0.0
  }
  color_augmentation {
    hue_rotation_max: 25.0
    saturation_shift_max: 0.20000000298
    contrast_scale_max: 0.10000000149
    contrast_center: 0.5
  }
}
postprocessing_config {
  target_class_config {
    key: "car"
    value {
      clustering_config {
        clustering_algorithm: DBSCAN
        dbscan_confidence_threshold: 0.1
        coverage_threshold: 0.001
        dbscan_eps: 0.15
        dbscan_min_samples: 1
        minimum_bounding_box_height: 20
      }
    }
  }
}
```
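The postprocessing_config above clusters the raw grid boxes with DBSCAN. The sketch below is a simplified stand-in, not TAO's implementation: it uses 1 - IoU as the distance between boxes (an assumption about how TAO measures it) and relies on the fact that with `dbscan_min_samples: 1` every point is a core point, so DBSCAN clusters reduce to connected components of the eps-neighbourhood graph. Each cluster is merged into a coverage-weighted average box; using the summed coverage as the cluster confidence is likewise an assumption. `pairwise_iou` and `cluster_boxes` are hypothetical names, and the defaults mirror the config values.

```python
import numpy as np

def pairwise_iou(boxes):
    """IoU matrix for an (N, 4) array of (x1, y1, x2, y2) boxes."""
    x1 = np.maximum(boxes[:, None, 0], boxes[None, :, 0])
    y1 = np.maximum(boxes[:, None, 1], boxes[None, :, 1])
    x2 = np.minimum(boxes[:, None, 2], boxes[None, :, 2])
    y2 = np.minimum(boxes[:, None, 3], boxes[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union = area[:, None] + area[None, :] - inter
    return inter / np.maximum(union, 1e-9)

def cluster_boxes(boxes, scores, eps=0.15, coverage_threshold=0.001,
                  confidence_threshold=0.1, min_height=20):
    """DBSCAN with min_samples=1 (connected components), then merge and filter."""
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    scores = np.asarray(scores, dtype=float).reshape(-1)
    keep = scores >= coverage_threshold  # drop low-coverage grid cells first
    boxes, scores = boxes[keep], scores[keep]
    n = len(boxes)
    if n == 0:
        return np.zeros((0, 4)), np.zeros(0)
    neighbours = (1.0 - pairwise_iou(boxes)) <= eps  # eps-neighbourhood graph
    labels = np.full(n, -1)
    n_clusters = 0
    for seed in range(n):
        if labels[seed] >= 0:
            continue
        labels[seed] = n_clusters
        stack = [seed]
        while stack:  # flood-fill one connected component
            i = stack.pop()
            for j in np.flatnonzero(neighbours[i] & (labels < 0)):
                labels[j] = n_clusters
                stack.append(j)
        n_clusters += 1
    out_boxes, out_scores = [], []
    for k in range(n_clusters):
        w = scores[labels == k]
        merged = (boxes[labels == k] * w[:, None]).sum(axis=0) / max(w.sum(), 1e-9)
        confidence = w.sum()  # assumption: summed coverage as cluster confidence
        if confidence >= confidence_threshold and merged[3] - merged[1] >= min_height:
            out_boxes.append(merged)
            out_scores.append(confidence)
    return np.array(out_boxes).reshape(-1, 4), np.array(out_scores)
```

In this sketch, clusters whose confidence falls below `dbscan_confidence_threshold` or whose height is under `minimum_bounding_box_height` are dropped, which is one place detections can disappear after decoding.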

Any idea why I am seeing this offset? Thanks!

Please use TAO Deploy to generate a TensorRT engine and run inference with it to double-check.
Refer to the "DetectNet_v2 with TAO Deploy" page in the NVIDIA Docs and https://github.com/NVIDIA/tao_deploy/tree/main/nvidia_tao_deploy/cv/detectnet_v2.

I generated a TensorRT engine using TAO Deploy with the following commands:

```
# Need to pass the actual image directory instead of data root for tao deploy to locate images for calibration
!sed -i "s|/workspace/tao-experiments/data/forsight-trafficcamnet-replace-anything-samples-cars-only-padded|/workspace/tao-experiments/data/forsight-trafficcamnet-replace-anything-samples-cars-only-padded/data|g" $LOCAL_SPECS_DIR/trafficcamnet_train_resnet18_kitti.txt
# Convert to TensorRT engine (INT8)
!tao deploy detectnet_v2 gen_trt_engine \
                  -m $USER_EXPERIMENT_DIR/forsight-replace-anything-samples-cars-only-paded-experiment-dir/final/model_epoch-100.onnx \
                  --data_type int8 \
                  --batches 10 \
                  --batch_size 4 \
                  --max_batch_size 64 \
                  --engine_file $USER_EXPERIMENT_DIR/forsight-replace-anything-samples-cars-only-paded-experiment-dir/final/model_epoch-100.engine \
                  --cal_cache_file $USER_EXPERIMENT_DIR/forsight-replace-anything-samples-cars-only-paded-experiment-dir/final/trafficcamnet_int8.txt \
                  -e $SPECS_DIR/trafficcamnet_train_resnet18_kitti.txt \
                  --results_dir $USER_EXPERIMENT_DIR/forsight-replace-anything-samples-cars-only-paded-experiment-dir/final \
                  --verbose
# Convert back the spec file
!sed -i "s|/workspace/tao-experiments/data/forsight-trafficcamnet-replace-anything-samples-cars-only-padded/data|/workspace/tao-experiments/data/forsight-trafficcamnet-replace-anything-samples-cars-only-padded|g" $LOCAL_SPECS_DIR/trafficcamnet_train_resnet18_kitti.txt
```

Next, I ran inference with the TensorRT engine using https://github.com/NVIDIA/tao_deploy/tree/main/nvidia_tao_deploy/cv/detectnet_v2/scripts/inference.py, with the same postprocessing config I used for training. I got the following output:


So the same offset I saw in the ONNX model's outputs is present here too. Many detections are also filtered out, and I’m not sure why. To debug further, I also ran inference on the exact same image with the .hdf5 model file created during training. With this model I got the results I expected:

It seems to me that something goes wrong when I export my .hdf5 model to ONNX after training. This is how I am exporting:

```
!tao model detectnet_v2 export \
                  -m $USER_EXPERIMENT_DIR/forsight-replace-anything-samples-cars-only-paded-experiment-dir/model.epoch-100.hdf5 \
                  -e $SPECS_DIR/trafficcamnet_train_resnet18_kitti.txt \
                  -o $USER_EXPERIMENT_DIR/forsight-replace-anything-samples-cars-only-paded-experiment-dir/final/model_epoch-100.onnx \
                  --gen_ds_config
```

To recap: inference with my .hdf5 model differs from inference with my ONNX/TensorRT model. With ONNX/TensorRT, the bounding boxes are offset and some boxes are filtered out. Am I doing something wrong during model export? Also, could you point me to some Python code for .hdf5 inference? It would not be a good long-term solution, but it would certainly help my model analysis pipeline in the short term. Thanks.
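To put a number on the offset between the .hdf5 and ONNX/TensorRT results, one option is to match each reference box to its highest-IoU counterpart and look at the coordinate differences. A minimal sketch; `box_offsets` is a hypothetical helper, and the 0.3 IoU match threshold is an arbitrary choice.

```python
import numpy as np

def box_offsets(ref_boxes, test_boxes, min_iou=0.3):
    """Return (test - ref) coordinate differences for IoU-matched box pairs.

    Boxes are (x1, y1, x2, y2). Each reference box is matched to the test box
    with the highest IoU; matches below min_iou are treated as missing.
    """
    ref = np.asarray(ref_boxes, dtype=float).reshape(-1, 4)
    test = np.asarray(test_boxes, dtype=float).reshape(-1, 4)
    if len(test) == 0:
        return np.zeros((0, 4))
    test_area = (test[:, 2] - test[:, 0]) * (test[:, 3] - test[:, 1])
    diffs = []
    for r in ref:
        iw = np.clip(np.minimum(r[2], test[:, 2]) - np.maximum(r[0], test[:, 0]), 0, None)
        ih = np.clip(np.minimum(r[3], test[:, 3]) - np.maximum(r[1], test[:, 1]), 0, None)
        inter = iw * ih
        union = (r[2] - r[0]) * (r[3] - r[1]) + test_area - inter
        iou = inter / np.maximum(union, 1e-9)
        best = int(np.argmax(iou))
        if iou[best] >= min_iou:
            diffs.append(test[best] - r)
    return np.array(diffs).reshape(-1, 4)
```

A systematic shift shows up as similar values in `box_offsets(hdf5_boxes, onnx_boxes).mean(axis=0)` (both variable names hypothetical), while unmatched reference boxes hint at detections being filtered out.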

Please add the following option to your export command and retry:
```
--onnx_route tf2onnx
```

For .hdf5 inference in Python, you can refer to the source code at https://github.com/NVIDIA/tao_tensorflow1_backend/blob/c7a3926ddddf3911842e057620bceb45bb5303cc/nvidia_tao_tf1/cv/detectnet_v2/scripts/inference.py.

I added the above parameter to my export command and it fixed the issue.

Can this be added to the documentation here: DetectNet_v2 - NVIDIA Docs?

I didn’t see any mention of the onnx_route parameter there. Maybe it is explained somewhere else, but it would have saved me a lot of time if it had been on that page. Anyway, thank you!

Glad to know it is working now. Yes, we will improve the documentation.
