Accuracy drops after integrating Mask R-CNN into DeepStream

Hi,
I have retrained the default Mask R-CNN model on a custom dataset. I converted the resulting .h5 model to UFF format following https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/sampleUffMaskRCNN. Once the UFF file was created with those steps, I converted it to a TensorRT engine and used that engine file in DeepStream.

The accuracy I get with DeepStream is much worse than with TensorFlow-based inference.

The configuration used for retraining was:
classes = 1 + 1 (person and background)
image_shape = (1024 1024 3)

The same configuration was used when converting the .h5 model to an engine.
Can someone point out the issue here?

Hi,

1. Which platform do you use: a desktop GPU or a Jetson device?
Did you generate the TensorRT engine on the same platform as DeepStream?

2. Do you use the same input sequence as in your TensorFlow test?
Could you double-check that the issue comes from the framework rather than from the model itself?

3. Which precision do you use for inference?
Please note that you will need to regenerate the calibration cache for INT8 inference.

Thanks.

Hi,

  1. I am using Tesla V100s. I generated the UFF file in a TensorFlow container from the NVIDIA NGC registry with TensorRT 6.0.1. For engine generation I used a DeepStream 4.0.2 container from NGC, and I run DeepStream with Mask R-CNN in that same container.
  2. As I mentioned earlier, the input shape used for training is [1024 1024 3], and the same shape is used for the UFF conversion in CHW order, (3 1024 1024). I have already double-checked by running inference outside DeepStream, and the results are quite good.
  3. I am using FP16 precision for inference because it gives better throughput.
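Point 2 depends on the TensorFlow model and the engine seeing exactly the same pixels, not just the same shape. A minimal sketch of the HWC to CHW conversion, assuming Matterport-style preprocessing where MEAN_PIXEL from the training configuration ([123.7 116.8 103.9], RGB order) is subtracted from an image already resized and padded to 1024 x 1024. The helper name `to_chw_input` is hypothetical.

```python
import numpy as np

# MEAN_PIXEL from the training configuration, in RGB order.
MEAN_PIXEL = np.array([123.7, 116.8, 103.9], dtype=np.float32)

def to_chw_input(image_hwc):
    """Hypothetical helper: turn a resized 1024x1024x3 RGB uint8 image (HWC)
    into the mean-subtracted float32 CHW tensor registered as (3, 1024, 1024)."""
    if image_hwc.shape != (1024, 1024, 3):
        raise ValueError("expected a 1024x1024x3 image, got %s" % (image_hwc.shape,))
    molded = image_hwc.astype(np.float32) - MEAN_PIXEL  # broadcasts over the channel axis
    return np.ascontiguousarray(molded.transpose(2, 0, 1))  # HWC -> CHW

img = np.zeros((1024, 1024, 3), dtype=np.uint8)
img[..., 0] = 200  # fill only the R channel
chw = to_chw_input(img)
print(chw.shape)  # (3, 1024, 1024)
```

If the DeepStream nvinfer settings (`net-scale-factor`, `offsets`, `model-color-format`) do not reproduce this mean subtraction and RGB order, the engine receives different input than TensorFlow did, which may be worth ruling out before looking at the model.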

Thanks

Hi,

To investigate further, could you share a simple reproducible example for us to debug?
It would be good to include the source code, the model and a test video.

Thanks.

Hi,

I have shared the configuration files for every step, from retraining to engine file creation.
I will also send the retrained model, the engine files we created and a test video by direct message.

Step 1: Retraining the COCO pre-trained Mask R-CNN model on a custom dataset.

The configuration used for retraining the model is as follows:

Configurations:
BACKBONE                       resnet101
BACKBONE_STRIDES               [4, 8, 16, 32, 64]
BATCH_SIZE                     1
BBOX_STD_DEV                   [0.1 0.1 0.2 0.2]
COMPUTE_BACKBONE_SHAPE         None
DETECTION_MAX_INSTANCES        100
DETECTION_MIN_CONFIDENCE       0.7
DETECTION_NMS_THRESHOLD        0.3
FPN_CLASSIF_FC_LAYERS_SIZE     1024
GPU_COUNT                      1
GRADIENT_CLIP_NORM             5.0
IMAGES_PER_GPU                 1
IMAGE_CHANNEL_COUNT            3
IMAGE_MAX_DIM                  1024
IMAGE_META_SIZE                14
IMAGE_MIN_DIM                  800
IMAGE_MIN_SCALE                0
IMAGE_RESIZE_MODE              square
IMAGE_SHAPE                    [1024 1024    3]
LEARNING_MOMENTUM              0.9
LEARNING_RATE                  0.001
LOSS_WEIGHTS                   {'rpn_class_loss': 1.0, 'rpn_bbox_loss': 1.0, 'mrcnn_class_loss': 1.0, 'mrcnn_bbox_loss': 1.0, 'mrcnn_mask_loss': 1.0}
MASK_POOL_SIZE                 14
MASK_SHAPE                     [28, 28]
MAX_GT_INSTANCES               100
MEAN_PIXEL                     [123.7 116.8 103.9]
MINI_MASK_SHAPE                (56, 56)
NAME                           cafe
NUM_CLASSES                    2
POOL_SIZE                      7
POST_NMS_ROIS_INFERENCE        1000
POST_NMS_ROIS_TRAINING         2000
PRE_NMS_LIMIT                  1024
ROI_POSITIVE_RATIO             0.33
RPN_ANCHOR_RATIOS              [0.5, 1, 2]
RPN_ANCHOR_SCALES              (32, 64, 128, 256, 512)
RPN_ANCHOR_STRIDE              1
RPN_BBOX_STD_DEV               [0.1 0.1 0.2 0.2]
RPN_NMS_THRESHOLD              0.7
RPN_TRAIN_ANCHORS_PER_IMAGE    256
STEPS_PER_EPOCH                1000
TOP_DOWN_PYRAMID_SIZE          256
TRAIN_BN                       False
TRAIN_ROIS_PER_IMAGE           200
USE_MINI_MASK                  True
USE_RPN_ROIS                   True
VALIDATION_STEPS               100
WEIGHT_DECAY                   0.0001
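Some of the printed values are derived from others, which makes them a quick consistency check between the training run and the conversion script. A sketch assuming the Matterport Mask R-CNN convention, where the image meta vector packs the image id (1), original shape (3), resized shape (3), window (4), scale (1) and one active-class flag per class; the function names are illustrative.

```python
def image_meta_size(num_classes):
    # image_id + original_image_shape + image_shape + window + scale + active_class_ids
    return 1 + 3 + 3 + 4 + 1 + num_classes

def batch_size(gpu_count, images_per_gpu):
    # BATCH_SIZE is derived as GPU_COUNT * IMAGES_PER_GPU
    return gpu_count * images_per_gpu

print(image_meta_size(2))   # 14, matching IMAGE_META_SIZE for NUM_CLASSES = 2
print(image_meta_size(81))  # 93 for the default 80 COCO classes + background
print(batch_size(1, 1))     # 1
```

If the conversion script's `config.display()` output shows a different IMAGE_META_SIZE, it was run with a different NUM_CLASSES than the one used here.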

The retrained model was then tested; the results, which we have shared with you, were very good.

Step 2: H5 to UFF conversion

As mentioned earlier, the UFF conversion followed these steps: https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/sampleUffMaskRCNN

TensorRT package versions in this conversion container:

ii  graphsurgeon-tf                        6.0.1-1+cuda10.1                  amd64        GraphSurgeon for TensorRT package
ii  libnvinfer-bin                         6.0.1-1+cuda10.1                  amd64        TensorRT binaries
ii  libnvinfer-dev                         6.0.1-1+cuda10.1                  amd64        TensorRT development libraries and headers
ii  libnvinfer-plugin-dev                  6.0.1-1+cuda10.1                  amd64        TensorRT plugin libraries
ii  libnvinfer-plugin6                     6.0.1-1+cuda10.1                  amd64        TensorRT plugin libraries
ii  libnvinfer6                            6.0.1-1+cuda10.1                  amd64        TensorRT runtime libraries
ii  libnvonnxparsers-dev                   6.0.1-1+cuda10.1                  amd64        TensorRT ONNX libraries
ii  libnvonnxparsers6                      6.0.1-1+cuda10.1                  amd64        TensorRT ONNX libraries
ii  libnvparsers-dev                       6.0.1-1+cuda10.1                  amd64        TensorRT parsers libraries
ii  libnvparsers6                          6.0.1-1+cuda10.1                  amd64        TensorRT parsers libraries
ii  python3-libnvinfer                     6.0.1-1+cuda10.1                  amd64        Python 3 bindings for TensorRT
ii  python3-libnvinfer-dev                 6.0.1-1+cuda10.1                  amd64        Python 3 development package for TensorRT
ii  uff-converter-tf                       6.0.1-1+cuda10.1                  amd64        UFF converter for TensorRT package

mrcnn_to_trt_single.py

from keras.models import model_from_json, Model
from keras import backend as K
from keras.layers import Input, Lambda
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
from mrcnn.model import *
import mrcnn.model as modellib
from mrcnn.config import Config
import sys
import os
import argparse
import uff

ROOT_DIR = os.path.abspath("./")
LOG_DIR = os.path.join(ROOT_DIR, "logs")

def parse_command_line_arguments(args=None):
    parser = argparse.ArgumentParser(prog='keras_to_trt', description='Convert trained keras .hdf5 model to trt .uff')
    parser.add_argument('-w', '--weights', type=str, default=None, required=True,
                        help="The checkpoint weights file of keras model.")
    parser.add_argument('-o', '--output_file', type=str, default=None, required=True,
                        help="The path to output .uff file.")
    parser.add_argument('-l', '--list-nodes', action='store_true',
                        help="show list of nodes contained in converted pb")
    parser.add_argument('-p', '--preprocessor', type=str, default=None,
                        help="The preprocess function for converting tf node to trt plugin")
    return parser.parse_args(args)

class CocoConfig(Config):
    """Configuration for training on MS COCO.
    Derives from the base Config class and overrides values specific
    to the COCO dataset.
    """
    # Give the configuration a recognizable name
    NAME = "cafe"
    # We use a GPU with 12GB memory, which can fit two images.
    # Adjust down if you use a smaller GPU.
    IMAGES_PER_GPU = 2
    # Uncomment to train on 8 GPUs (default is 1)
    # GPU_COUNT = 8
    # Number of classes (including background): person + background
    NUM_CLASSES = 1 + 1

class InferenceConfig(CocoConfig):
    # Set batch size to 1 since we'll be running inference on
    # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    NUM_CLASSES = 1 + 1

def main(args=None):
    # The TensorRT sample expects an NCHW graph in inference mode
    K.set_image_data_format('channels_first')
    K.set_learning_phase(0)
    args = parse_command_line_arguments(args)
    model_weights_path = args.weights
    output_file_path = args.output_file
    list_nodes = args.list_nodes
    config = InferenceConfig()
    config.display()
    model = modellib.MaskRCNN(mode="inference", model_dir=LOG_DIR, config=config).keras_model
    model.load_weights(model_weights_path, by_name=True)
    # Keep the graph up to the mask head output
    model_A = Model(inputs=model.input, outputs=model.get_layer('mrcnn_mask').output)
    model_A.summary()
    output_nodes = ['mrcnn_detection', "mrcnn_mask/Sigmoid"]
    convert_model(model_A, output_file_path, output_nodes, preprocessor=args.preprocessor,
                  text=True, list_nodes=list_nodes)

def convert_model(inference_model, output_path, output_nodes=None, preprocessor=None, text=False,
                  list_nodes=False):
    # Freeze the Keras session graph into a temporary .pb
    orig_output_node_names = [node.op.name for node in inference_model.outputs]
    print("The output names of tensorflow graph nodes: {}".format(str(orig_output_node_names)))
    sess = K.get_session()
    constant_graph = graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        orig_output_node_names)
    temp_pb_path = "../temp.pb"
    graph_io.write_graph(constant_graph, os.path.dirname(temp_pb_path), os.path.basename(temp_pb_path),
                         as_text=False)
    # Use the explicitly requested output nodes if given, otherwise the Keras outputs
    trt_output_nodes = output_nodes if output_nodes else orig_output_node_names
    # Convert .pb to .uff; the preprocessor (config.py) maps TF subgraphs to TensorRT plugin nodes
    uff.from_tensorflow_frozen_model(
        temp_pb_path,
        output_nodes=trt_output_nodes,
        preprocessor=preprocessor,
        text=text,
        list_nodes=list_nodes,
        output_filename=output_path,
        debug_mode=False
    )
    os.remove(temp_pb_path)

if __name__ == "__main__":
    main()

In addition, the number of classes was updated in config.h.

Step 3: Uff to engine conversion

For this step, a DeepStream 4.0.2 container was set up with the same TensorRT 6 version as the previous container.

uff_to_engine.py

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # GPU ID; set before TensorRT creates a CUDA context
import tensorrt as trt

BATCH_SIZE = 1
MODEL_INPUT = 'input_image'
MODEL_OUTPUTS = ["mrcnn_detection", "mrcnn_mask/Sigmoid"]
INPUT_SHAPE = (3, 1024, 1024)  # CHW, matching the NCHW input order registered below
FP16_FLAG = False  # set to True to build an FP16 engine
model_file = './uff/mask_rcnn_cafe_0018_81_370train.uff'
engine_file = 'mask_rcnn_cafe_0018_81_370_batch1.engine'

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Must initialize plugins (custom layers for MRCNN) before parsing
trt.init_libnvinfer_plugins(TRT_LOGGER, '')
print("Building engine file")
# Create builder, network and parser, then build the engine
with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
    parser.register_input(name=MODEL_INPUT, shape=INPUT_SHAPE, order=trt.UffInputOrder.NCHW)
    for output in MODEL_OUTPUTS:
        parser.register_output(name=output)
    if not parser.parse(file=model_file, network=network, weights_type=trt.float32):
        raise RuntimeError("Failed to parse " + model_file)
    builder.max_batch_size = BATCH_SIZE
    builder.max_workspace_size = int(8e8)
    builder.fp16_mode = FP16_FLAG
    engine = builder.build_cuda_engine(network)
    if engine is None:
        raise RuntimeError("Engine build failed")

# Serialize the engine so DeepStream can load it directly
with open(engine_file, "wb") as f:
    f.write(engine.serialize())

print("Finished")

The final engine file, which I integrated with DeepStream, has also been shared with you; the results were not as good as outside DeepStream.
I have also repeated all of the above steps with the default 81 classes, with no difference.

Thanks

Hi,

Thanks for sharing.

Is the accuracy good when running inference directly with the TensorRT OSS MaskRCNN sample?
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/sampleUffMaskRCNN#running-the-sample

This will help us determine whether the issue comes from TensorRT or from DeepStream.

Thanks.

Hi,
I have not tested the MaskRCNN sample yet. However, I found that pre_nms_limit is set to 1024 when converting to UFF, and increasing it causes an error in DeepStream.
When we retrained the model, pre_nms_limit was set to 6000, which may be why we get better results outside DeepStream.
Is there any way to increase pre_nms_limit when integrating with DeepStream?
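To put the 1024 vs 6000 values in perspective, a rough count of how many anchors the proposal step ranks before the pre-NMS top-k cut. This sketch assumes the Matterport anchor scheme (one scale per FPN level, three aspect ratios, RPN_ANCHOR_STRIDE = 1, feature-map side = ceil(image_side / stride)) with IMAGE_SHAPE, BACKBONE_STRIDES and RPN_ANCHOR_RATIOS from the training configuration; the TensorRT ProposalLayer_TRT plugin may organize its anchors differently.

```python
IMAGE_SIZE = 1024
BACKBONE_STRIDES = [4, 8, 16, 32, 64]
RPN_ANCHOR_RATIOS = [0.5, 1, 2]
RPN_ANCHOR_STRIDE = 1

def total_anchors(image_size, strides, num_ratios, anchor_stride=1):
    total = 0
    for s in strides:
        side = -(-image_size // s)            # feature-map side, rounded up
        cells = -(-side // anchor_stride)     # anchor positions along one axis
        total += cells * cells * num_ratios   # one anchor per ratio per position
    return total

anchors = total_anchors(IMAGE_SIZE, BACKBONE_STRIDES, len(RPN_ANCHOR_RATIOS), RPN_ANCHOR_STRIDE)
print(anchors)  # 261888
for topk in (1024, 2048, 4096, 6000):
    print("prenms_topk={}: {:.2%} of anchors reach NMS".format(topk, topk / anchors))
```

Under these assumptions, raising the limit from 1024 to 6000 lets roughly six times as many candidates reach NMS, which is consistent with the accuracy gap described here, though it does not prove it is the only cause.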

Thanks.

Hi,
Just following up: can pre_nms_limit currently be set to a value greater than 1024 with TensorRT?
When we retrained all layers of the pre-trained model, instead of only the final layer, with pre_nms_limit = 1024, the results in DeepStream were much better, but still not as good as outside DeepStream. It seems that if we could port a model that uses more proposals, accuracy should improve.

Hi,

Sorry, I cannot find a 'pre_nms_limit' parameter in either mrcnn_nchw.uff or mrcnn_to_trt_single.py.

Do you mean the prenms_topk_u_int parameter in the ROI node?

  nodes {
    id: "ROI"
    inputs: "rpn_class/concat"
    inputs: "rpn_bbox/concat"
    operation: "_ProposalLayer_TRT"
    ...
    fields {
      key: "prenms_topk_u_int"
      value {
        i: 1024
      }
    }
  }

Thanks.

Hi,
Yes, it is prenms_topk. In config.py, prenms_topk is set to 1024:

roi = gs.create_plugin_node("ROI", op="ProposalLayer_TRT", prenms_topk=1024, keep_topk=1000, iou_threshold=0.7)

Hi,

Do you get any error when converting the UFF model with pre_nms_limit = 6000?

roi = gs.create_plugin_node("ROI", op="ProposalLayer_TRT", prenms_topk=6000, keep_topk=1000, iou_threshold=0.7)

Could you also share the error log from DeepStream with us?
You may need to update the parameter used in the MaskRCNN parser here:

Thanks.

Hi,
We don't get any error when converting to UFF with pre_nms_limit = 6000.
But once we integrate the model with DeepStream, it fails with the following error:

Creating LL OSD context new
rvaas-ds: /root/TensorRT/TensorRT/plugin/common/kernels/maskRCNNKernels.cu:1221: 
cudaError_t proposalRefineBatchClassNMS(cudaStream_t, int, int, int, nvinfer1::DataType, const RefineNMSParameters&, const ProposalWorkSpace&, void*, const void*, const void*, const void*, const void*, void*): Assertion `false && "unsupported sortPerClass"' failed.
Aborted (core dumped)

Could you also specify which parameter you were suggesting we update in the MaskRCNN parser file? I don't see any parameter related to the NMS limit.

Thanks

Hi,

Based on this source, TensorRT only supports pre_nms_limit values up to 4096:

To avoid this error, you can use a lower pre_nms_limit value or implement sortPerClass for 6000 yourself.
Maybe you can check whether "sortPerClass<256, 32>(..)" works.
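The capacity of a sortPerClass<Threads, ItemsPerThreads> instantiation is Threads × ItemsPerThreads items, which is why <256, 32> (8192) could cover 6000. A small Python model of that rule, not the TensorRT source: `div_up` and `fits` are illustrative names, and tying the 4096 ceiling to ItemsPerThreads = 16 with 256 threads is an assumption.

```python
THREADS = 256

def div_up(a, b):
    # Integer ceiling division, mirroring divUp(samples, Threads)
    return (a + b - 1) // b

def fits(samples, threads, items_per_thread):
    # Capacity rule asserted in the NMS kernel: samples <= Threads * ItemsPerThreads
    return samples <= threads * items_per_thread

for samples in (1024, 4096, 6000):
    print(samples, div_up(samples, THREADS),
          fits(samples, THREADS, 16), fits(samples, THREADS, 32))
# 1024 4 True True
# 4096 16 True True
# 6000 24 False True
```

6000 needs at least 24 items per thread, so any custom tier for it has to be at least <256, 24>; 32 is simply the next power of two.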

Thanks.

Hi,
We tried pre_nms_limit = 2048, and now the error is different. It seems the else-if block for samples <= 2048 is not reached at this step.

/root/TensorRT/TensorRT/plugin/common/kernels/maskRCNNKernels.cu:380: void PerClassNMS_kernel(int, int, float, const void *, const void *, const void *, const void *, const void *, void *) [with DType = float, BoxType = float, Threads = 256, ItemsPerThreads = 4]: block: [0,0,0], thread: [224,0,0] Assertion samples <= Threads * ItemsPerThreads failed.

Hi,

Could you share how you set the pre_nms_limit value?

The error indicates a mismatch among samples, Threads and ItemsPerThreads.
However, that value is computed directly in the source, so it should be correct:

ItemsPerThreads := divUp(samples, Threads)
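Plugging the template values from the assertion log into that relation shows why the check fires, assuming `samples` is the 2048 set as pre_nms_limit: the kernel instance that ran was sized for 1024 samples.

$$
\text{Threads} \times \text{ItemsPerThreads} = 256 \times 4 = 1024 < 2048 = \text{samples},
\qquad
\left\lceil \frac{2048}{256} \right\rceil = 8
$$

A 2048-sample launch would therefore need the instance with ItemsPerThreads = 8, not 4.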

Thanks.

Hi,

We set this value in two files during UFF generation.

The first is https://github.com/NVIDIA/TensorRT/blob/master/samples/opensource/sampleUffMaskRCNN/converted/config.py

Here we set prenms_topk=2048 on the following line:

roi = gs.create_plugin_node("ROI", op="ProposalLayer_TRT", prenms_topk=2048, keep_topk=1000, iou_threshold=0.7)

The other is

https://github.com/NVIDIA/TensorRT/blob/master/samples/opensource/sampleUffMaskRCNN/converted/mrcnn_to_trt_single.py

Here we set
PRE_NMS_LIMIT = 2048 in InferenceConfig:

class InferenceConfig(CocoConfig):
    # Set batch size to 1 since we'll be running inference on
    # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    PRE_NMS_LIMIT = 2048

Also note that the model was trained with the same number of proposals.
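Since the same limit is set in two files, a small guard run before conversion could catch a mismatch or a value above the ceiling reported earlier in this thread. A sketch with hypothetical names (`check_pre_nms_limit`, `MAX_PRENMS_TOPK`); the 4096 ceiling is the one quoted for the TensorRT 6 Mask R-CNN kernels, and requiring both values to agree reflects the setup described here rather than a TensorRT requirement.

```python
MAX_PRENMS_TOPK = 4096  # ceiling quoted in this thread for the TensorRT 6 Mask R-CNN kernels

def check_pre_nms_limit(keras_pre_nms_limit, roi_prenms_topk, max_topk=MAX_PRENMS_TOPK):
    """Hypothetical guard: InferenceConfig.PRE_NMS_LIMIT and the prenms_topk passed
    to the ROI plugin node in config.py should agree and stay within range."""
    if keras_pre_nms_limit != roi_prenms_topk:
        raise ValueError("PRE_NMS_LIMIT (%d) != prenms_topk (%d)"
                         % (keras_pre_nms_limit, roi_prenms_topk))
    if not 0 < roi_prenms_topk <= max_topk:
        raise ValueError("prenms_topk=%d is outside 1..%d" % (roi_prenms_topk, max_topk))
    return roi_prenms_topk

print(check_pre_nms_limit(2048, 2048))  # 2048
try:
    check_pre_nms_limit(6000, 6000)
except ValueError as err:
    print(err)  # prenms_topk=6000 is outside 1..4096
```

In practice the two arguments would be the `PRE_NMS_LIMIT` attribute of InferenceConfig and the `prenms_topk` value used in the `gs.create_plugin_node("ROI", ...)` call.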

Hi,

I just want to clarify first.

Is the PRE_NMS_LIMIT value stored in the model itself, as a parameter value?
If so, is it also set to 4096?

Thanks.

Hi,
Yes, we also specify PRE_NMS_LIMIT when training the model. In our case we set it to the same value we use for the UFF conversion: 4096 in both places. However, the two values do not have to match for inference to run. If a higher value is used during training and 1024 is used for the UFF conversion, it still works, but accuracy then degrades significantly with DeepStream and TensorRT.