I am getting the same error using the C++ API with TensorRT 5 RC and CUDA 10.0. We are using Tensorflow 1.11 (we attempted 1.9, but that would require downgrading CUDA to 9.0).
Parameter check failed at: Utils.cpp::reshapeWeights::70, condition: input.values != nullptr
UFFParser: Parser error: model/batch_normalization/Const: reshape weights failed!
I applied transformations to the frozen pb model to prune the network and remove layers associated with training. I also tried the non-pruned version of the network, but that resulted in unsupported-layer errors, for example for Switch.
We are now testing loading the pruned network back into Tensorflow to see if it is able to load the UFF model.
Is there a suggested way to prune a network or design the Tensorflow network to enable easy integration to inference with TRT?
Also, this stems from using some version of batch normalization from Tensorflow. If I remove the batch normalization layers, the network can be loaded and evaluated in TRT. However, without batch normalization there is a massive hit in accuracy.
Edit:
Here is a more detailed summary of what we have tried:
# Tensorflow to TensorRT
## Goal
Train a fully convolutional network in Tensorflow, then convert the frozen saved model to TensorRT for production inference.
## Configuration
- Ubuntu 18.04
- Cuda 10.0
- CuDNN 7.3.1
- NCCL 2.3.5
- Tensorflow 1.11 within a Conda virtual environment
- TensorRT 5.0 RC
- Python 3.6
Conda Environment
(tensorflow) mmajursk@N116117:~/Downloads/uff-testing$ conda list
# packages in environment at /home/mmajursk/anaconda3/envs/tensorflow:
# Name Version Build Channel
_tflow_190_select 0.0.1 gpu
_tflow_select 2.1.0 gpu
absl-py 0.3.0 py36_0
astor 0.7.1 py36_0
blas 1.0 mkl
bzip2 1.0.6 h14c3975_5
ca-certificates 2018.03.07 0
caffe2 0.8.dev py36_2018.08.07 caffe2
cairo 1.14.12 h8948797_3
certifi 2018.10.15 py36_0
cudatoolkit 9.0 h13b8566_0
cudnn 7.1.2 cuda9.0_0
cupti 9.0.176 0
ffmpeg 3.4 h7985aa0_0
fontconfig 2.13.0 h9420a91_0
freetype 2.9.1 h8a8886c_0
future 0.16.0 py36_0
gast 0.2.0 py36_0
gflags 2.2.1 hf484d3e_0
glib 2.56.1 h000015b_0
glog 0.3.5 hf484d3e_1
graphite2 1.3.11 h16798f4_2
graphsurgeon 0.2.2
grpcio 1.12.1 py36hdbcaa40_0
h5py 2.7.1 py36h3585f63_0
harfbuzz 1.8.4 hec2c2bc_0
hdf5 1.10.1 h9caa474_1
horovod 0.13.11
icu 58.2 h9c2bf20_1
intel-openmp 2018.0.3 0
jasper 1.900.1 hd497a04_4
jpeg 9b h024ee3a_2
keras-applications 1.0.6 py36_0
keras-preprocessing 1.0.5 py36_0
libedit 3.1.20170329 h6b74fdf_2
libffi 3.2.1 hd88cf55_4
libgcc-ng 8.2.0 hdf63c60_1
libgfortran-ng 7.2.0 hdf63c60_3
libopus 1.2.1 hb9ed12e_0
libpng 1.6.34 hb9fc6fc_0
libprotobuf 3.6.0 hdbcaa40_0
libstdcxx-ng 8.2.0 hdf63c60_1
libtiff 4.0.9 he85c1e1_1
libuuid 1.0.3 h1bed415_2
libvpx 1.7.0 h439df22_0
libxcb 1.13 h1bed415_1
libxml2 2.9.8 h26e45fe_1
markdown 2.6.11 py36_0
mkl 2018.0.3 1
mkl-include 2018.0.3 1
mkl_fft 1.0.4 py36h4414c95_1
mkl_random 1.0.1 py36h4414c95_1
ncurses 6.1 hf484d3e_0
numpy 1.14.3 py36h28100ab_2
numpy-base 1.15.0 py36h3dfced4_0
onnx 1.3.0
onnx-tf 1.2.0
opencv 3.3.1 py36h0a11808_0
openssl 1.0.2p h14c3975_0
pcre 8.42 h439df22_0
pip 10.0.1 py36_0
pixman 0.34.0 hceecf20_3
protobuf 3.6.0 py36hf484d3e_0
python 3.6.6 hc3d631a_0
python-lmdb 0.94 py36h14c3975_0
PyYAML 3.13
readline 7.0 ha6073c6_4
scipy 1.1.0 py36hd20e5f9_0
setuptools 39.2.0 py36_0
six 1.11.0 py36_1
sqlite 3.24.0 h84994c4_0
tensorboard 1.11.0 py36hf484d3e_0
tensorflow 1.11.0 gpu_py36h4459f94_0
tensorflow-base 1.11.0 gpu_py36h8e0ae2d_0
tensorflow-gpu 1.11.0 h0d30ee6_0
tensorrt 5.0.0.10
termcolor 1.1.0 py36_1
tk 8.6.7 hc745277_3
typing 3.6.4 py36_0
typing-extensions 3.6.6
uff 0.5.1
werkzeug 0.14.1 py36_0
wheel 0.31.1 py36_0
xz 5.2.4 h14c3975_4
zlib 1.2.11 ha838bed_2
## Model Architecture
Yolo v3 Region Proposal Network (no bounding box regression)
This is a fully convolutional neural network with residual connections between certain layers.
### Inference layers
This model uses the following Tensorflow layers during the inference pass of the model:
tf.layers.conv2d
tf.layers.batch_normalization
tf.add
I have tried batch normalization both fused and non-fused:
tf.layers.batch_normalization(fused=True)
tf.layers.batch_normalization(fused=False)
I have since gone back to using Tensorflow's default of `batch_normalization(fused=True)`.
### Training Layers
The following additional layers are used during training:
tf.transpose
tf.reshape
tf.nn.softmax_cross_entropy_with_logits_v2
tf.reduce_mean
tf.train.AdamOptimizer(learning_rate).minimize(loss)
All the source code to construct the CNN and train it is available in the src folder. The code to generate the database, and the database itself are not available.
## Model Training
This model is trainable with both `batch_normalization` configurations.
The model trains to convergence on a small test dataset (not the full dataset; the goal is only to confirm that the training procedure works correctly).
During training I am saving Tensorflow checkpoint files after each epoch.
After training completes I save the frozen protobuf form of the model using the following code:
    import os
    import tensorflow as tf

    # save the graph_def
    proto_file = os.path.join(frozen_model_folder, 'graph.pb')
    with open(proto_file, "wb") as file:
        graph = tf.get_default_graph().as_graph_def(add_shapes=True)
        file.write(graph.SerializeToString())
    # Convert the checkpoint file (a valid checkpoint) and the graph_def into a frozen model
    from tensorflow.python.tools import freeze_graph
    output_graph_filepath = os.path.join(frozen_model_folder, 'frozen_graph.pb')
    freeze_graph.freeze_graph(input_graph=proto_file,
                              input_saver="",
                              input_checkpoint=checkpoint_filepath,
                              checkpoint_version=2,
                              output_graph=output_graph_filepath,
                              output_node_names='model/feature_map_NCHW/BiasAdd',
                              input_binary=True,
                              restore_op_name="unused",
                              filename_tensor_name="unused",
                              clear_devices=True,
                              initializer_nodes="")
At the same time I use the uff conversion tool to create the uff form of the same model, but more on that later.
## Evaluate the Frozen Model
Tensorflow provides some nice tools for examining the frozen model.
I summarized the graph and found the frozen protobuf uses the following layers:
(tensorflow) mmajursk@N116117:~/Downloads/uff-testing$ /home/mmajursk/Programs/tensorflow/bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/frozen_graph.pb
Found 1 possible inputs: (name=inputs, type=float(1), shape=[?,1,512,512])
No variables spotted.
Found 1 possible outputs: (name=model/feature_map_NCHW/BiasAdd, op=BiasAdd)
Found 40604310 (40.60M) const parameters, 0 (0) variable parameters, and 0 control_edges
Op types used: 366 Const, 210 Identity, 53 BiasAdd, 53 Conv2D, 52 FusedBatchNorm, 52 Maximum, 52 Mul, 23 Add, 1 Placeholder
To use with tensorflow/tools/benchmark:benchmark_model try these arguments:
bazel run tensorflow/tools/benchmark:benchmark_model -- --graph=/home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/frozen_graph.pb --show_flops --input_layer=inputs --input_layer_type=float --input_layer_shape=-1,1,512,512 --output_layer=model/feature_map_NCHW/BiasAdd
All these layers are supported by TensorRT according to the [Developer Guide](https://docs.nvidia.com/deeplearning/sdk/pdf/TensorRT-Developer-Guide.pdf)
If I import this frozen model into Tensorflow, I can successfully use it to run inference on new data with the following import code:
    # Load the Frozen Model
    print('Loading trained CNN model')
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(frozen_model_file, 'rb') as fid:
            graph_def.ParseFromString(fid.read())
            tf.import_graph_def(graph_def, name='')
        # Get the input and output nodes
        net_input = detection_graph.get_tensor_by_name('inputs:0')
        net_output = detection_graph.get_tensor_by_name('model/feature_map_NCHW/BiasAdd:0')
        # initialize_all_variables (the frozen graph has no variables, so this has no effect)
        tf.global_variables_initializer()

    # Load the Tensorflow model into memory.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(0)
    with tf.Session(config=config, graph=detection_graph) as sess:
        # code to load and inference new images goes here
        pass
### Tensorflow transform_graph
Tensorflow has a tool called transform_graph which can optimize the graph for inference.
If I apply that tool with the following command, the number of Ops is reduced:
(tensorflow) mmajursk@N116117:~/Downloads/uff-testing$ /home/mmajursk/Programs/tensorflow/bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=/home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/frozen_graph.pb \
--out_graph=/home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/transformed_frozen_graph.pb \
--inputs='inputs' \
--outputs='model/feature_map_NCHW/BiasAdd' \
--transforms='
sort_by_execution_order
strip_unused_nodes(type=float, shape="1,1,1024,1024")
remove_nodes(op=Identity, op=CheckNumerics)
fold_constants(ignore_errors=true)
fold_batch_norms
fold_old_batch_norms
remove_control_dependencies
merge_duplicate_nodes
'
2018-10-22 15:05:08.122193: I tensorflow/tools/graph_transforms/transform_graph.cc:317] Applying sort_by_execution_order
2018-10-22 15:05:08.386022: I tensorflow/tools/graph_transforms/transform_graph.cc:317] Applying strip_unused_nodes
2018-10-22 15:05:08.693232: I tensorflow/tools/graph_transforms/transform_graph.cc:317] Applying remove_nodes
2018-10-22 15:05:09.602897: I tensorflow/tools/graph_transforms/transform_graph.cc:317] Applying fold_constants
2018-10-22 15:05:09.911632: I tensorflow/tools/graph_transforms/transform_graph.cc:317] Applying fold_batch_norms
2018-10-22 15:05:10.059284: I tensorflow/tools/graph_transforms/transform_graph.cc:317] Applying fold_old_batch_norms
2018-10-22 15:05:10.510293: I tensorflow/tools/graph_transforms/transform_graph.cc:317] Applying remove_control_dependencies
2018-10-22 15:05:10.587677: I tensorflow/tools/graph_transforms/transform_graph.cc:317] Applying merge_duplicate_nodes
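Since the multi-line `--transforms` argument is easy to mangle with shell quoting, here is a minimal sketch of assembling the same invocation as an argument list in Python. The flags and transform names are the ones from the command above; the helper name `build_transform_graph_cmd` and the shortened file paths are placeholders of mine, not part of any Tensorflow API.

```python
import shlex

def build_transform_graph_cmd(tool, in_graph, out_graph, inputs, outputs, transforms):
    """Return an argv list for transform_graph; no shell quoting is involved."""
    return [
        tool,
        "--in_graph=" + in_graph,
        "--out_graph=" + out_graph,
        "--inputs=" + inputs,
        "--outputs=" + outputs,
        # The tool receives all transforms as one whitespace-separated string.
        "--transforms=" + " ".join(transforms),
    ]

TRANSFORMS = [
    "sort_by_execution_order",
    'strip_unused_nodes(type=float, shape="1,1,1024,1024")',
    "remove_nodes(op=Identity, op=CheckNumerics)",
    "fold_constants(ignore_errors=true)",
    "fold_batch_norms",
    "fold_old_batch_norms",
    "remove_control_dependencies",
    "merge_duplicate_nodes",
]

cmd = build_transform_graph_cmd(
    tool="transform_graph",  # placeholder: substitute the bazel-bin path
    in_graph="frozen_graph.pb",  # placeholder path
    out_graph="transformed_frozen_graph.pb",  # placeholder path
    inputs="inputs",
    outputs="model/feature_map_NCHW/BiasAdd",
    transforms=TRANSFORMS,
)

# Print a copy-pasteable shell form; subprocess.run(cmd, check=True) would execute it.
print(" ".join(shlex.quote(arg) for arg in cmd))
```

Passing the list straight to `subprocess.run` avoids the shell entirely, so the embedded double quotes in `strip_unused_nodes` need no escaping.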
Which produces the following graph_summary:
(tensorflow) mmajursk@N116117:~/Downloads/uff-testing$ /home/mmajursk/Programs/tensorflow/bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/transformed_frozen_graph.pb
Found 1 possible inputs: (name=inputs, type=float(1), shape=[?,1,512,512])
No variables spotted.
Found 1 possible outputs: (name=model/feature_map_NCHW/BiasAdd, op=BiasAdd)
Found 40604259 (40.60M) const parameters, 0 (0) variable parameters, and 0 control_edges
Op types used: 212 Const, 53 BiasAdd, 53 Conv2D, 52 FusedBatchNorm, 52 Maximum, 52 Mul, 23 Add, 1 Placeholder
To use with tensorflow/tools/benchmark:benchmark_model try these arguments:
bazel run tensorflow/tools/benchmark:benchmark_model -- --graph=/home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/transformed_frozen_graph.pb --show_flops --input_layer=inputs --input_layer_type=float --input_layer_shape=-1,1,512,512 --output_layer=model/feature_map_NCHW/BiasAdd
If I import this transformed model into Tensorflow, I can successfully use it to run inference on new data using the same code as above, just loading transformed_frozen_graph.pb instead of frozen_graph.pb.
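To make the effect of the transform concrete, here is a small sketch that diffs the two "Op types used" lines reported by summarize_graph before and after transform_graph. The two input strings are copied from the summaries above; the helper name `parse_op_types` is my own.

```python
import re
from collections import Counter

def parse_op_types(line):
    """Parse an 'Op types used: N Op, ...' line from summarize_graph."""
    _, _, body = line.partition("Op types used:")
    return Counter({name: int(count) for count, name in re.findall(r"(\d+) (\w+)", body)})

before = parse_op_types(
    "Op types used: 366 Const, 210 Identity, 53 BiasAdd, 53 Conv2D, "
    "52 FusedBatchNorm, 52 Maximum, 52 Mul, 23 Add, 1 Placeholder"
)
after = parse_op_types(
    "Op types used: 212 Const, 53 BiasAdd, 53 Conv2D, "
    "52 FusedBatchNorm, 52 Maximum, 52 Mul, 23 Add, 1 Placeholder"
)

# Counter subtraction keeps only the ops whose count went down.
removed = before - after
print(dict(removed))  # {'Const': 154, 'Identity': 210}
print("FusedBatchNorm remaining:", after["FusedBatchNorm"])  # 52
```

The diff shows that only Const and Identity nodes were removed. All 52 FusedBatchNorm ops are still present, so the two batch norm folding transforms did not fold anything in this graph.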
## Convert to UFF
### Frozen Model
After I save the frozen model, within the same script I convert that newly saved .pb file into UFF format using the following code:
    import uff
    uff_graph = uff.from_tensorflow_frozen_model(output_graph_filepath)
    output_uff_graph_filepath = os.path.join(frozen_model_folder, 'frozen_graph.uff')
    with open(output_uff_graph_filepath, "wb") as file:
        file.write(uff_graph)
I have found no meaningful difference between using the above API and this convert to uff tool:
convert-to-uff /home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/frozen_graph.pb -o /home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/frozen_graph.uff -O "model/feature_map_NCHW/BiasAdd"
The conversion to UFF produces the following logging:
=== Automatically deduced input nodes ===
[name: "inputs"
op: "Placeholder"
attr {
  key: "_output_shapes"
  value {
    list {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 1
        }
        dim {
          size: 512
        }
        dim {
          size: 512
        }
      }
    }
  }
}
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: -1
      }
      dim {
        size: 1
      }
      dim {
        size: 512
      }
      dim {
        size: 512
      }
    }
  }
}
]
=== Automatically deduced output nodes ===
[name: "model/feature_map_NCHW/BiasAdd"
op: "BiasAdd"
input: "model/feature_map_NCHW/Conv2D"
input: "model/feature_map_NCHW/bias"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "_output_shapes"
  value {
    list {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 2
        }
        dim {
          size: 16
        }
        dim {
          size: 16
        }
      }
    }
  }
}
attr {
  key: "data_format"
  value {
    s: "NCHW"
  }
}
]
Using output node model/feature_map_NCHW/BiasAdd
Converting to UFF graph
No. nodes: 653
### Transformed Frozen Model
In addition to relying on the non-optimized frozen_graph, I convert the transformed graph into uff as well using the following code:
(tensorflow) mmajursk@N116117:~/Downloads/uff-testing$ convert-to-uff /home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/transformed_frozen_graph.pb -o /home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/transformed_frozen_graph.uff -O "model/feature_map_NCHW/BiasAdd"
/home/mmajursk/anaconda3/envs/tensorflow/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  return f(*args, **kwds)
/home/mmajursk/anaconda3/envs/tensorflow/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Loading /home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/transformed_frozen_graph.pb
=== Automatically deduced input nodes ===
[name: "inputs"
op: "Placeholder"
attr {
  key: "_output_shapes"
  value {
    list {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 1
        }
        dim {
          size: 512
        }
        dim {
          size: 512
        }
      }
    }
  }
}
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: -1
      }
      dim {
        size: 1
      }
      dim {
        size: 512
      }
      dim {
        size: 512
      }
    }
  }
}
]
Using output node model/feature_map_NCHW/BiasAdd
Converting to UFF graph
No. nodes: 499
UFF Output written to /home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/transformed_frozen_graph.uff
Just as before, the conversion process produces no errors.
### UFF Import in C++
The UFF parser in C++ throws an error when loading the model:
W1019 12:36:32.058586 39879 StarInferenceTask.cpp:50] Parameter check failed at: Utils.cpp::reshapeWeights::70, condition: input.values != nullptr
W1019 12:36:32.059023 39879 StarInferenceTask.cpp:50] UFFParser: Parser error: model/batch_normalization_51/Const_1: reshape weights failed!
This happens regardless of whether we are using:
- frozen_graph.pb converted to uff
- transformed_frozen_graph.pb converted to uff
## UFF Import Debugging
The import error appears to stem from the fact that the Const named "model/batch_normalization_51/Const_1" is null or empty.
If I export the full set of layers from transformed_frozen_graph.pb, I can see that "model/batch_normalization_51/Const_1" does in fact not have a value.
This command:
/home/mmajursk/Programs/tensorflow/bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/transformed_frozen_graph.pb --print_structure=true > /home/mmajursk/MITS-PIP/cnn/data/tmp/frozen-model/transformed_frozen_graph_structure.txt
Generates a file containing the following information (along with several hundred other lines, one per Op).
model/batch_normalization_51/Const_1 (Const): , value=Tensor<type: float shape: [0] values: >
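Rather than reading through several hundred lines by hand, the structure dump can be scanned for Const nodes with a zero-sized shape. This is a minimal sketch that assumes the line format shown above; the function name `empty_consts` is mine, and the second sample line is made up purely for contrast.

```python
import re

# Matches Const lines as printed by summarize_graph --print_structure=true.
CONST_LINE = re.compile(
    r"^(?P<name>\S+) \(Const\):.*value=Tensor<type: (?P<dtype>\w+) shape: \[(?P<shape>[^\]]*)\]"
)

def empty_consts(structure_lines):
    """Return names of Const nodes whose shape contains a zero dimension."""
    names = []
    for line in structure_lines:
        match = CONST_LINE.match(line.strip())
        if not match:
            continue
        dims = [d for d in match.group("shape").split(",") if d.strip()]
        # A scalar prints as shape [] and still holds one value, so it is not flagged.
        if any(int(d) == 0 for d in dims):
            names.append(match.group("name"))
    return names

sample = [
    "model/batch_normalization_51/Const_1 (Const): , value=Tensor<type: float shape: [0] values: >",
    # Hypothetical non-empty Const, not taken from the real graph:
    "example/nonempty_const (Const): , value=Tensor<type: float shape: [2] values: 0 0>",
]
print(empty_consts(sample))
```

To scan the real file, pass the open file object for transformed_frozen_graph_structure.txt instead of `sample`.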
The Op "model/batch_normalization_51/Const_1" is used in many places within the transformed graph, due to the "merge_duplicate_nodes" option provided to the graph_transform tool.
In context, this node appears to be the 4th and 5th parameter input into batch_normalization.
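For reference, the Tensorflow `FusedBatchNorm` op takes its inputs in the order x, scale (gamma), offset (beta), mean, variance, so the 4th and 5th inputs are the mean and variance. The sketch below shows the inference-mode arithmetic per channel in plain Python, to illustrate why an importer that needs those two inputs as weights cannot proceed when they are empty. The function name, the sample numbers, and the `epsilon` value (the `tf.layers.batch_normalization` default of 0.001) are assumptions for illustration.

```python
import math

def batch_norm_inference(x, gamma, beta, mean, variance, epsilon=1e-3):
    """Apply inference-mode batch norm to one value per channel."""
    channels = len(x)
    for name, values in (("gamma", gamma), ("beta", beta),
                         ("mean", mean), ("variance", variance)):
        if len(values) != channels:
            raise ValueError(f"{name} has {len(values)} values, expected {channels}")
    # y = gamma * (x - mean) / sqrt(variance + epsilon) + beta, per channel
    return [
        g * (xi - m) / math.sqrt(v + epsilon) + b
        for xi, g, b, m, v in zip(x, gamma, beta, mean, variance)
    ]

x = [1.0, 2.0]  # made-up activations, one per channel
y = batch_norm_inference(x, gamma=[1.0, 1.0], beta=[0.0, 0.0],
                         mean=[1.0, 0.0], variance=[1.0, 4.0])
print(y)

# With empty mean and variance, as in the Const above, there is nothing to apply.
try:
    batch_norm_inference(x, gamma=[1.0, 1.0], beta=[0.0, 0.0], mean=[], variance=[])
except ValueError as err:
    print("cannot normalize:", err)
```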
Here are the last three layers of the network (3 conv_2d, 2 of them with batch_normalization):
model/conv2d_50/Conv2D (Conv2D): [model/Add_21, model/conv2d_50/kernel]
model/conv2d_50/BiasAdd (BiasAdd): [model/conv2d_50/Conv2D, model/conv2d_50/bias]
model/conv2d_50/LeakyRelu/mul (Mul): [model/conv2d_51/LeakyRelu/alpha, model/conv2d_50/BiasAdd]
model/conv2d_50/LeakyRelu (Maximum): [model/conv2d_50/LeakyRelu/mul, model/conv2d_50/BiasAdd]
model/batch_normalization_50/FusedBatchNorm (FusedBatchNorm): [model/conv2d_50/LeakyRelu, model/batch_normalization_50/gamma, model/batch_normalization_50/beta, model/batch_normalization_51/Const_1, model/batch_normalization_51/Const_1]
model/conv2d_51/Conv2D (Conv2D): [model/batch_normalization_50/FusedBatchNorm, model/conv2d_51/kernel]
model/conv2d_51/BiasAdd (BiasAdd): [model/conv2d_51/Conv2D, model/conv2d_51/bias]
model/conv2d_51/LeakyRelu/mul (Mul): [model/conv2d_51/LeakyRelu/alpha, model/conv2d_51/BiasAdd]
model/conv2d_51/LeakyRelu (Maximum): [model/conv2d_51/LeakyRelu/mul, model/conv2d_51/BiasAdd]
model/batch_normalization_51/FusedBatchNorm (FusedBatchNorm): [model/conv2d_51/LeakyRelu, model/batch_normalization_51/gamma, model/batch_normalization_51/beta, model/batch_normalization_51/Const_1, model/batch_normalization_51/Const_1]
model/Add_22 (Add): [model/batch_normalization_43/FusedBatchNorm, model/batch_normalization_51/FusedBatchNorm]
model/feature_map_NCHW/Conv2D (Conv2D): [model/Add_22, model/feature_map_NCHW/kernel]
model/feature_map_NCHW/BiasAdd (BiasAdd): [model/feature_map_NCHW/Conv2D, model/feature_map_NCHW/bias]
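The same structure dump can be used to list every node that consumes the empty Const and at which input position. A minimal sketch, assuming the `name (Op): [inputs]` line format shown above; the function name `consumers_of` is mine, and positions are reported 0-based.

```python
import re
from collections import defaultdict

NODE_LINE = re.compile(r"^(?P<name>\S+) \((?P<op>\w+)\): \[(?P<inputs>[^\]]*)\]")

def consumers_of(structure_lines, target):
    """Map (node name, op) to the 0-based input positions fed by `target`."""
    uses = defaultdict(list)
    for line in structure_lines:
        match = NODE_LINE.match(line.strip())
        if not match:
            continue
        inputs = [name.strip() for name in match.group("inputs").split(",")]
        for index, name in enumerate(inputs):
            if name == target:
                uses[(match.group("name"), match.group("op"))].append(index)
    return dict(uses)

lines = [
    "model/conv2d_51/Conv2D (Conv2D): [model/batch_normalization_50/FusedBatchNorm, model/conv2d_51/kernel]",
    "model/batch_normalization_50/FusedBatchNorm (FusedBatchNorm): [model/conv2d_50/LeakyRelu, model/batch_normalization_50/gamma, model/batch_normalization_50/beta, model/batch_normalization_51/Const_1, model/batch_normalization_51/Const_1]",
    "model/batch_normalization_51/FusedBatchNorm (FusedBatchNorm): [model/conv2d_51/LeakyRelu, model/batch_normalization_51/gamma, model/batch_normalization_51/beta, model/batch_normalization_51/Const_1, model/batch_normalization_51/Const_1]",
]
uses = consumers_of(lines, "model/batch_normalization_51/Const_1")
for (name, op), positions in uses.items():
    print(op, name, positions)
```

On the three sample lines, both FusedBatchNorm nodes report positions 3 and 4, that is, the 4th and 5th inputs.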
## Conclusions
This transformed_frozen_graph.pb can be imported into Tensorflow and used to perform inference. The numeric results are nonsense because the network was only trained for 1 epoch on a very small dataset. Once the machinery to convert a nonsense model into UFF and TensorRT is proven, I will spend the effort to construct a well-trained model in Tensorflow.
Therefore, the problem seems to be with 1 of 2 things:
1) How the transformed_frozen_graph.pb is converted into UFF, causing the Op "model/batch_normalization_51/Const_1" to be set to null/empty.
- Since the value is missing in the transformed_frozen_graph.pb, Tensorflow is obviously not using those 2 Op inputs when performing inference, or it would also throw an error.
- When you examine the frozen_graph.pb (without graph transformations), the same layer is a problem, and it is also null/empty. There are just many more like it with different names. All the graph transformation did was compress them all into a single Const.
2) How the UFF model is loaded into TensorRT. If the empty value of "model/batch_normalization_51/Const_1" is not initialized into something reasonable, like maybe 0, the C++ TensorRT UFF parser cannot create the correct graph representation in memory.