I tried to minimize the problem to a reproducible case.
The code below works fine with TensorRT 4 or 5 on Linux.
However, an error occurs only on the PX 2 AutoChauffeur (DRIVE OS 5.0.10.3 Linux SDK for DRIVE PX 2).
1. Simple Network and Frozen Graph (why.pb)
I defined a very minimal network; running the script below generates why.pb.
import tensorflow as tf
import numpy as np

slim = tf.contrib.slim

NAME = 'why'
IMAGE_HEIGHT = 5
IMAGE_WIDTH = 5

with tf.Graph().as_default():
    # very simple network
    image_ph = tf.placeholder(tf.float32, [1, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
    net = slim.conv2d(image_ph, 3, [3, 3])
    net = slim.conv2d(net, 3, [3, 3])
    branches = []
    for i in range(2):
        with tf.variable_scope('branch_%d' % i):
            net_ = slim.conv2d(net, 3, [3, 3])
            net_ = tf.reshape(net_, [-1, 1])
            branches.append(net_)

    # just a simple placeholder for the plugin layer; py_func passes one
    # argument per input tensor, so merge() takes both branches, and the
    # return value must match the declared tf.float32 dtype
    def merge(b1, b2):
        return np.float32(0)

    net = tf.py_func(merge, branches, tf.float32, name="output")
    net.set_shape((1,))

    # frozen graph
    gpu_config = tf.ConfigProto(allow_soft_placement=True)
    gpu_config.gpu_options.allow_growth = True
    with tf.Session(config=gpu_config) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

        # specify tensors that I will need when doing inference
        output_names = ['output']
        output_tensors = [tf.get_default_graph().get_tensor_by_name(n + ":0") for n in output_names]

        graphdef = tf.get_default_graph().as_graph_def()
        frozen_graph = tf.graph_util.convert_variables_to_constants(sess, graphdef, output_names)
        frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)
        tf.train.write_graph(frozen_graph, '.', NAME + '.pb', as_text=False)
2. Graph Surgery for UFF (why.uff)
For my use case, I want to skip the reshape layers (branch_0/Reshape, branch_1/Reshape) and define my own custom output layer, which takes two inputs.
Below is the graph surgery file (uff-surgery-why.py) used for UFF conversion.
import graphsurgeon as gs
import tensorflow as tf

conv_cls1 = gs.create_node("conv_cls1", op="reshape_to_4", dtype=tf.float32)
conv_cls2 = gs.create_node("conv_cls2", op="reshape_to_4", dtype=tf.float32)
output = gs.create_node("output", op="flatten", dtype=tf.float32)

namespace_plugin_map = {
    "output": output,
    "branch_0/Reshape": conv_cls1,
    "branch_1/Reshape": conv_cls2,
}

def preprocess(dynamic_graph):
    # swap in the plugin nodes, then splice the reshape replacements out
    # of the graph so their inputs feed the output plugin directly
    dynamic_graph.collapse_namespaces(namespace_plugin_map)

    node = dynamic_graph.find_nodes_by_name('conv_cls1')
    dynamic_graph.forward_inputs(node)
    node = dynamic_graph.find_nodes_by_name('conv_cls2')
    dynamic_graph.forward_inputs(node)
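To make the intent of the surgery concrete, here is a toy sketch of what the two steps are meant to do, modeling the graph as a plain dict of node name to input names. This is not the graphsurgeon implementation, and the intermediate conv node names below are made up for illustration; only the Reshape and output names come from the real graph.

```python
def collapse(graph, mapping):
    """Rename nodes per the plugin map and rewrite all references to them."""
    return {mapping.get(n, n): [mapping.get(i, i) for i in ins]
            for n, ins in graph.items()}

def forward_inputs(graph, name):
    """Remove `name` and wire its inputs directly into its consumers."""
    inputs = graph.pop(name)
    return {n: [x for i in ins for x in (inputs if i == name else [i])]
            for n, ins in graph.items()}

# hypothetical conv node names; Reshape/output names match the frozen graph
graph = {
    "branch_0/conv/Relu": ["input"],
    "branch_1/conv/Relu": ["input"],
    "branch_0/Reshape": ["branch_0/conv/Relu"],
    "branch_1/Reshape": ["branch_1/conv/Relu"],
    "output": ["branch_0/Reshape", "branch_1/Reshape"],
}
mapping = {"branch_0/Reshape": "conv_cls1", "branch_1/Reshape": "conv_cls2"}

g = collapse(graph, mapping)          # Reshape nodes become plugin stand-ins
g = forward_inputs(g, "conv_cls1")    # splice the stand-ins out again,
g = forward_inputs(g, "conv_cls2")    # so "output" reads both convs directly
print(g["output"])  # ['branch_0/conv/Relu', 'branch_1/conv/Relu']
```

The end state is what the desktop log confirms: the flatten plugin sees two (3, 5, 5) inputs coming straight from the conv branches, with the reshapes gone.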
I generated the UFF file by running one of the commands below.
convert-to-uff tensorflow --input-file why.pb -O output -p uff-surgery-why.py # TensorRT 4
convert-to-uff --input-file why.pb -O output -p uff-surgery-why.py # TensorRT 5
3. Inference
Then I ran inference on why.uff using the TensorRT C++ API. The code runs fine with TensorRT 4 or 5 on my laptop and two desktops (all Linux). However, when I run it on the PX 2, the UFF parser fails with the error below.
# -----------------------
# PX 2
# -----------------------
Begin parsing model...
terminate called after throwing an instance of 'std::out_of_range'
what(): _Map_base::at
Aborted (core dumped)
# -----------------------
# My desktop, laptop, etc.
# -----------------------
Begin parsing model...
_flatten
Flatten::Flatten()
Flatten::getOutputDimensions()
nbInputDims 2
--input 0, (3, 5, 5, )
--input 1, (3, 5, 5, )
End parsing model...
Begin building engine...
Flatten::configure()
nbInputs 2
--input 0, (3, 5, 5, )
--input 1, (3, 5, 5, )
nbOutputs 1
--output 0, (3, 5, 5, )
Flatten::getWorkspaceSize()
Flatten::getWorkspaceSize()
Flatten::initialize()
End building engine...
Flatten::getSerializationSize()
Flatten::serialize()
Flatten::terminate()
Flatten::~Flatten()
*** deserializing
_flatten_HL_1804289383
Flatten::Flatten()
Flatten::initialize()
engine created
batch size: 1
nbBindings: 2
size of binding 0: 75
size of binding 1: 75
----------input binding 0
----------safe malloc 0, 75
----------safe malloc 1, 75
input image copy done.
inference takes 2.68461 ms.
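One observation about the PX 2 crash: `_Map_base::at` is the message libstdc++ produces when `std::unordered_map::at` is called with a missing key, so my guess (an assumption, not confirmed) is that the PX 2 parser looks up something such as a custom op name in an internal map that was never populated on that platform. A rough Python analogue of that failure pattern, with a hypothetical registry:

```python
# hypothetical plugin registry: only one of the two custom ops registered
plugin_ops = {"flatten": object()}

try:
    plugin_ops["reshape_to_4"]  # lookup of an unregistered op
    found = True
except KeyError:                # C++: std::unordered_map::at throws
    found = False               # std::out_of_range ("_Map_base::at")

print("reshape_to_4 registered:", found)
```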
Could you please check this?