Summary:
How to deploy a Mask R-CNN model (GitHub - matterport/Mask_RCNN: Mask R-CNN for object detection and instance segmentation on Keras and TensorFlow) with a ResNet50 backbone and a changed number of classes in the .h5 model?
tensorflow-gpu: 1.9.0
Keras: 2.1.3
CUDA: 9.0
cuDNN: 7.0
Python: 2.7
H5 model: mask_rcnn_nucleus_0080.h5 (Google Drive), renamed to "mask_rcnn_coco_restnet50.h5" in the steps below
1. Verify the weights with samples/demo.ipynb (a minimal verification sketch follows the note below)
a. Code change
diff --git a/mrcnn/config.py b/mrcnn/config.py
--- a/mrcnn/config.py
+++ b/mrcnn/config.py
@@ -52,7 +52,7 @@ class Config(object):
# You can also provide a callable that should have the signature
# of model.resnet_graph. If you do so, you need to supply a callable
# to COMPUTE_BACKBONE_SHAPE as well
-    BACKBONE = "resnet101"
+    BACKBONE = "resnet50"
# Only useful if you supply a callable to BACKBONE. Should compute
# the shape of each layer of the FPN Pyramid.
diff --git a/samples/coco_config/coco_config.py b/samples/coco_config/coco_config.py
--- a/samples/coco_config/coco_config.py
+++ b/samples/coco_config/coco_config.py
@@ -49,7 +49,7 @@ from mrcnn.config import Config
from mrcnn import model as modellib, utils
# Path to trained weights file
-COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
+COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco_restnet50.h5")
# Directory to save logs and model checkpoints, if not provided
# through the command line argument --logs
@@ -77,5 +77,5 @@ class CocoConfig(Config):
# GPU_COUNT = 8
# Number of classes (including background)
-    NUM_CLASSES = 1 + 80  # COCO has 80 classes
+    NUM_CLASSES = 1 + 1   # background + 1 foreground class
b. Note: don't apply the TensorRT sample patch "0001-Update-the-Mask_RCNN-model-from-NHWC-to-NCHW.patch" for this step.
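For reference, a minimal sketch of this check (the paths, the config-class name, and the test image are illustrative assumptions, not part of the original steps):

import os
import skimage.io
import mrcnn.model as modellib
from mrcnn.config import Config

class NucleusInferenceConfig(Config):
    # hypothetical name; mirrors the config edits above
    NAME = "nucleus"
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    NUM_CLASSES = 1 + 1   # background + 1 foreground class
    BACKBONE = "resnet50"

ROOT_DIR = os.path.abspath("./")
config = NucleusInferenceConfig()
model = modellib.MaskRCNN(mode="inference", model_dir=os.path.join(ROOT_DIR, "logs"), config=config)
# by_name=True matches layers by name, so the resnet50/2-class heads load cleanly
model.load_weights(os.path.join(ROOT_DIR, "mask_rcnn_coco_restnet50.h5"), by_name=True)

image = skimage.io.imread("test_image.png")  # any sample image
results = model.detect([image], verbose=1)
print("class_ids: {}".format(results[0]["class_ids"]))
print("scores: {}".format(results[0]["scores"]))

If the class count or backbone in the config doesn't match the .h5 file, load_weights fails with a shape mismatch, which makes this a quick consistency check before conversion.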
2. Convert the .h5 model to .uff
a. Refer to https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/sampleUffMaskRCNN
b. Modified config.py (replaces the sample's version; a sanity-check sketch follows the listing):
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import graphsurgeon as gs
import tensorflow as tf
fpn_p5upsampled = gs.create_plugin_node("fpn_p5upsampled", op="ResizeNearest_TRT", dtype=tf.float32, scale=2.0)
fpn_p4upsampled = gs.create_plugin_node("fpn_p4upsampled", op="ResizeNearest_TRT", dtype=tf.float32, scale=2.0)
fpn_p3upsampled = gs.create_plugin_node("fpn_p3upsampled", op="ResizeNearest_TRT", dtype=tf.float32, scale=2.0)
roi = gs.create_plugin_node("ROI", op="ProposalLayer_TRT", prenms_topk=1024, keep_topk=1000, iou_threshold=0.7)
roi_align_classifier = gs.create_plugin_node("roi_align_classifier", op="PyramidROIAlign_TRT", pooled_size=7)
mrcnn_detection = gs.create_plugin_node("mrcnn_detection", op="DetectionLayer_TRT", num_classes=2, keep_topk=100, score_threshold=0.7, iou_threshold=0.3)  # num_classes must match NUM_CLASSES (1 + 1)
roi_align_mask = gs.create_plugin_node("roi_align_mask_trt", op="PyramidROIAlign_TRT", pooled_size=14)
mrcnn_detection_bboxes = gs.create_plugin_node("mrcnn_detection_bboxes", op="SpecialSlice_TRT")
namespace_plugin_map = {
    "fpn_p5upsampled": fpn_p5upsampled,
    "fpn_p4upsampled": fpn_p4upsampled,
    "fpn_p3upsampled": fpn_p3upsampled,
    "roi_align_classifier": roi_align_classifier,
    "mrcnn_detection": mrcnn_detection,
    "ROI": roi,
    "roi_align_mask": roi_align_mask,
    "lambda_1": mrcnn_detection_bboxes,
}
timedistributed_remove_list = [
"mrcnn_class_conv1/Reshape/shape", "mrcnn_class_conv1/Reshape", "mrcnn_class_conv1/Reshape_1/shape", "mrcnn_class_conv1/Reshape_1",
"mrcnn_class_bn1/Reshape/shape", "mrcnn_class_bn1/Reshape", "mrcnn_class_bn1/Reshape_5/shape", "mrcnn_class_bn1/Reshape_5",
"mrcnn_class_conv2/Reshape/shape", "mrcnn_class_conv2/Reshape", "mrcnn_class_conv2/Reshape_1/shape", "mrcnn_class_conv2/Reshape_1",
"mrcnn_class_bn2/Reshape/shape", "mrcnn_class_bn2/Reshape", "mrcnn_class_bn2/Reshape_5/shape", "mrcnn_class_bn2/Reshape_5",
"mrcnn_class_logits/Reshape/shape", "mrcnn_class_logits/Reshape","mrcnn_class_logits/Reshape_1/shape", "mrcnn_class_logits/Reshape_1",
"mrcnn_class/Reshape/shape", "mrcnn_class/Reshape","mrcnn_class/Reshape_1/shape", "mrcnn_class/Reshape_1",
"mrcnn_bbox_fc/Reshape/shape", "mrcnn_bbox_fc/Reshape","mrcnn_bbox_fc/Reshape_1/shape", "mrcnn_bbox_fc/Reshape_1",
"mrcnn_mask_conv1/Reshape/shape", "mrcnn_mask_conv1/Reshape", "mrcnn_mask_conv1/Reshape_1/shape", "mrcnn_mask_conv1/Reshape_1",
"mrcnn_mask_bn1/Reshape/shape", "mrcnn_mask_bn1/Reshape", "mrcnn_mask_bn1/Reshape_5/shape", "mrcnn_mask_bn1/Reshape_5",
"mrcnn_mask_conv2/Reshape/shape", "mrcnn_mask_conv2/Reshape", "mrcnn_mask_conv2/Reshape_1/shape", "mrcnn_mask_conv2/Reshape_1",
"mrcnn_mask_bn2/Reshape/shape", "mrcnn_mask_bn2/Reshape", "mrcnn_mask_bn2/Reshape_5/shape", "mrcnn_mask_bn2/Reshape_5",
"mrcnn_mask_conv3/Reshape/shape", "mrcnn_mask_conv3/Reshape", "mrcnn_mask_conv3/Reshape_1/shape", "mrcnn_mask_conv3/Reshape_1",
"mrcnn_mask_bn3/Reshape/shape", "mrcnn_mask_bn3/Reshape", "mrcnn_mask_bn3/Reshape_5/shape", "mrcnn_mask_bn3/Reshape_5",
"mrcnn_mask_conv4/Reshape/shape", "mrcnn_mask_conv4/Reshape", "mrcnn_mask_conv4/Reshape_1/shape", "mrcnn_mask_conv4/Reshape_1",
"mrcnn_mask_bn4/Reshape/shape", "mrcnn_mask_bn4/Reshape", "mrcnn_mask_bn4/Reshape_5/shape", "mrcnn_mask_bn4/Reshape_5",
"mrcnn_mask_deconv/Reshape/shape", "mrcnn_mask_deconv/Reshape", "mrcnn_mask_deconv/Reshape_1/shape", "mrcnn_mask_deconv/Reshape_1",
"mrcnn_mask/Reshape/shape", "mrcnn_mask/Reshape", "mrcnn_mask/Reshape_1/shape", "mrcnn_mask/Reshape_1",
]
timedistributed_connect_pairs = [
("mrcnn_mask_deconv/Relu", "mrcnn_mask/convolution"), # mrcnn_mask_deconv -> mrcnn_mask
("activation_40/Relu", "mrcnn_mask_deconv/conv2d_transpose"), #active74 -> mrcnn_mask_deconv
("mrcnn_mask_bn4/batchnorm/add_1","activation_40/Relu"), # mrcnn_mask_bn4 -> active74
("mrcnn_mask_conv4/BiasAdd", "mrcnn_mask_bn4/batchnorm/mul_1"), #mrcnn_mask_conv4 -> mrcnn_mask_bn4
("activation_39/Relu", "mrcnn_mask_conv4/convolution"), #active73 -> mrcnn_mask_conv4
("mrcnn_mask_bn3/batchnorm/add_1","activation_39/Relu"), #mrcnn_mask_bn3 -> active73
("mrcnn_mask_conv3/BiasAdd", "mrcnn_mask_bn3/batchnorm/mul_1"), #mrcnn_mask_conv3 -> mrcnn_mask_bn3
("activation_38/Relu", "mrcnn_mask_conv3/convolution"), #active72 -> mrcnn_mask_conv3
("mrcnn_mask_bn2/batchnorm/add_1","activation_38/Relu"), #mrcnn_mask_bn2 -> active72
("mrcnn_mask_conv2/BiasAdd", "mrcnn_mask_bn2/batchnorm/mul_1"), #mrcnn_mask_conv2 -> mrcnn_mask_bn2
("activation_37/Relu", "mrcnn_mask_conv2/convolution"), #active71 -> mrcnn_mask_conv2
("mrcnn_mask_bn1/batchnorm/add_1","activation_37/Relu"), #mrcnn_mask_bn1 -> active71
("mrcnn_mask_conv1/BiasAdd", "mrcnn_mask_bn1/batchnorm/mul_1"), #mrcnn_mask_conv1 -> mrcnn_mask_bn1
("roi_align_mask_trt", "mrcnn_mask_conv1/convolution"), #roi_align_mask -> mrcnn_mask_conv1
("mrcnn_class_bn2/batchnorm/add_1","activation_35/Relu"), # mrcnn_class_bn2 -> active 69
("mrcnn_class_conv2/BiasAdd", "mrcnn_class_bn2/batchnorm/mul_1"), # mrcnn_class_conv2 -> mrcnn_class_bn2
("activation_34/Relu", "mrcnn_class_conv2/convolution"), # active 68 -> mrcnn_class_conv2
("mrcnn_class_bn1/batchnorm/add_1","activation_34/Relu"), # mrcnn_class_bn1 -> active 68
("mrcnn_class_conv1/BiasAdd", "mrcnn_class_bn1/batchnorm/mul_1"), # mrcnn_class_conv1 -> mrcnn_class_bn1
("roi_align_classifier", "mrcnn_class_conv1/convolution"), # roi_align_classifier -> mrcnn_class_conv1
]
dense_compatible_patch = ["pool_squeeze/Squeeze", "pool_squeeze/Squeeze_1", # no need to squeeze the dimensions for the TRT dense layer
    "mrcnn_bbox/Shape", "mrcnn_bbox/strided_slice/stack", # mrcnn_bbox (Reshape): no reshape needed, because it can be processed as a 1-D array in the detection layer's kernel
    "mrcnn_bbox/strided_slice/stack_1", "mrcnn_bbox/strided_slice/stack_2",
    "mrcnn_bbox/strided_slice", "mrcnn_bbox/Reshape/shape/1",
    "mrcnn_bbox/Reshape/shape/2", "mrcnn_bbox/Reshape/shape/3",
    "mrcnn_bbox/Reshape/shape", "mrcnn_bbox/Reshape"]
dense_compatible_connect_pairs = [
("activation_35/Relu","mrcnn_bbox_fc/MatMul"), #activation_69 -> mrcnn_bbox_fc
("activation_35/Relu", "mrcnn_class_logits/MatMul"), #activation_69 -> mrcnn_class_logits
("mrcnn_class_logits/BiasAdd", "mrcnn_class/Softmax"), #mrcnn_class_logits -> mrcnn_class
("mrcnn_class/Softmax", "mrcnn_detection"), #mrcnn_class -> mrcnn_detection
("mrcnn_bbox_fc/BiasAdd", "mrcnn_detection"), #mrcnn_bbox_fc -> mrcnn_detection
]
def connect(dynamic_graph, connections_list):
    for node_a_name, node_b_name in connections_list:
        if node_a_name not in dynamic_graph.node_map[node_b_name].input:
            dynamic_graph.node_map[node_b_name].input.insert(0, node_a_name)

def preprocess(dynamic_graph):
    # Now create a new graph by collapsing namespaces
    dynamic_graph.collapse_namespaces(namespace_plugin_map, unique_inputs=True)
    dynamic_graph.remove(timedistributed_remove_list)
    dynamic_graph.remove(dense_compatible_patch)
    dynamic_graph.remove(['input_anchors', 'input_image_meta'])
    connect(dynamic_graph, timedistributed_connect_pairs)
    connect(dynamic_graph, dense_compatible_connect_pairs)
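The UFF converter imports the file passed via -p/--preprocessor and runs its preprocess(dynamic_graph) on the frozen graph before conversion. An optional sanity check (a sketch, assuming you keep the intermediate frozen .pb around and this file is importable as config.py; the "temp.pb" path is an assumption):

import graphsurgeon as gs
import config  # the file above

# Load the frozen TF graph and apply the same preprocessing the UFF
# converter would run via the -p flag.
dynamic_graph = gs.DynamicGraph("temp.pb")
config.preprocess(dynamic_graph)

# The collapsed plugin nodes should now be present in the graph.
names = set(node.name for node in dynamic_graph.as_graph_def().node)
for expected in ("ROI", "mrcnn_detection", "roi_align_classifier", "roi_align_mask_trt"):
    print("{}: {}".format(expected, expected in names))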
c. Modified mrcnn_to_trt_single.py (replaces the sample's version):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from keras.models import model_from_json, Model
from keras import backend as K
from keras.layers import Input, Lambda
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
from mrcnn.model import *
import mrcnn.model as modellib
from mrcnn.config import Config
import sys
import os
import argparse
import uff

ROOT_DIR = os.path.abspath("./")
LOG_DIR = os.path.join(ROOT_DIR, "logs")
def parse_command_line_arguments(args=None):
    parser = argparse.ArgumentParser(prog='keras_to_trt', description='Convert trained keras .hdf5 model to trt .uff')
    parser.add_argument(
        '-w',
        '--weights',
        type=str,
        default=None,
        required=True,
        help="The checkpoint weights file of keras model."
    )
    parser.add_argument(
        '-o',
        '--output_file',
        type=str,
        default=None,
        required=True,
        help="The path to output .uff file."
    )
    parser.add_argument(
        '-l',
        '--list-nodes',
        action='store_true',
        help="show list of nodes contained in converted pb"
    )
    parser.add_argument(
        '-p',
        '--preprocessor',
        type=str,
        default=False,
        help="The preprocess function for converting tf node to trt plugin"
    )
    return parser.parse_args(args)
class CocoConfig(Config):
    """Configuration for training on MS COCO.
    Derives from the base Config class and overrides values specific
    to the COCO dataset.
    """
    # Give the configuration a recognizable name
    NAME = "coco"
    # We use a GPU with 12GB memory, which can fit two images.
    # Adjust down if you use a smaller GPU.
    IMAGES_PER_GPU = 2
    # Uncomment to train on 8 GPUs (default is 1)
    # GPU_COUNT = 8
    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # background + 1 foreground class
    BACKBONE = 'resnet50'

class InferenceConfig(CocoConfig):
    # Set batch size to 1 since we'll be running inference on
    # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    BACKBONE = 'resnet50'  # ResNet50 backbone
def main(args=None):
    K.set_image_data_format('channels_first')
    K.set_learning_phase(0)
    args = parse_command_line_arguments(args)
    model_weights_path = args.weights
    output_file_path = args.output_file
    list_nodes = args.list_nodes
    config = InferenceConfig()
    config.display()
    model = modellib.MaskRCNN(mode="inference", model_dir=LOG_DIR, config=config).keras_model
    model.load_weights(model_weights_path, by_name=True)
    model_A = Model(inputs=model.input, outputs=model.get_layer('mrcnn_mask').output)
    model_A.summary()
    output_nodes = ['mrcnn_detection', "mrcnn_mask/Sigmoid"]
    convert_model(model_A, output_file_path, output_nodes, preprocessor=args.preprocessor,
                  text=True, list_nodes=list_nodes)
def convert_model(inference_model, output_path, output_nodes=[], preprocessor=None, text=False,
                  list_nodes=False):
    # convert the keras model to pb
    orig_output_node_names = [node.op.name for node in inference_model.outputs]
    print("The output names of tensorflow graph nodes: {}".format(str(orig_output_node_names)))
    sess = K.get_session()
    constant_graph = graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        orig_output_node_names)
    # expanduser: graph_io/gfile do not expand "~" themselves
    temp_pb_path = os.path.expanduser("~/temp.pb")
    graph_io.write_graph(constant_graph, os.path.dirname(temp_pb_path), os.path.basename(temp_pb_path),
                         as_text=False)
    predefined_output_nodes = output_nodes
    if predefined_output_nodes != []:
        trt_output_nodes = predefined_output_nodes
    else:
        trt_output_nodes = orig_output_node_names
    # convert .pb to .uff
    uff.from_tensorflow_frozen_model(
        temp_pb_path,
        output_nodes=trt_output_nodes,
        preprocessor=preprocessor,
        text=text,
        list_nodes=list_nodes,
        output_filename=output_path,
        debug_mode=False
    )
    os.remove(temp_pb_path)

if __name__ == "__main__":
    main()
d. $ python mrcnn_to_trt_single.py -w /path/to/data/mask_rcnn_coco_restnet50.h5 -o /path/to/data/mrcnn_nchw_resnet50.uff -p ./config.py
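To confirm the generated UFF parses with the Mask R-CNN plugins registered, a quick check with the TensorRT Python API (a sketch for the TRT 5/6-era API; the input name "input_image" and CHW shape 3x1024x1024 follow sampleUffMaskRCNN's defaults, the paths are assumptions, and the OSS build of libnvinfer_plugin, which provides the *_TRT plugins, must be on the loader path):

import ctypes
import tensorrt as trt

ctypes.CDLL("libnvinfer_plugin.so")          # load the plugin library
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(TRT_LOGGER, "")  # register the *_TRT plugin creators

with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
    parser.register_input("input_image", (3, 1024, 1024))
    parser.register_output("mrcnn_detection")
    parser.register_output("mrcnn_mask/Sigmoid")
    ok = parser.parse("/path/to/data/mrcnn_nchw_resnet50.uff", network)
    print("UFF parsed: {}".format(ok))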
3. Run with sampleUffMaskRCNN (https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/sampleUffMaskRCNN)
a. Code change
--- a/mrcnn_config.h
+++ b/mrcnn_config.h
@@ -57,7 +57,7 @@ static const int FPN_CLASSIF_FC_LAYERS_SIZE = 1024;
static const int TOP_DOWN_PYRAMID_SIZE = 256;
// Number of classification classes (including background)
-static const int NUM_CLASSES = 1 + 80; // COCO has 80 classes
+static const int NUM_CLASSES = 1 + 1; // background + 1 foreground class
// Length of square anchor side in pixels
static const std::vector<float> RPN_ANCHOR_SCALES = {32, 64, 128, 256, 512};
@@ -94,7 +94,7 @@ static const std::vector<std::string> CLASS_NAMES = {
"book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
};
-static const std::string MODEL_NAME = "mrcnn_nchw.uff";
+static const std::string MODEL_NAME = "mrcnn_nchw_resnet50.uff";
b. Run the sample. (Since NUM_CLASSES changed, the CLASS_NAMES list in mrcnn_config.h should likely also be trimmed to match, e.g. {"BG", "nucleus"}.)
4. DeepStream deployment
https://github.com/NVIDIA-AI-IOT/deepstream_4.x_apps
Users can implement the mask OSD (on-screen display) by referring to the dsexample plugin.