TensorRT 4.0 UFF parser fails to parse Keras Resnet50

Hello,

I am trying to use TensorRT 4.0.0.3 to perform inference on a ResNet50 model that I have trained in Keras (with the TensorFlow backend).
I am able to freeze the TensorFlow graph and convert it to UFF format, but when I try to create a TensorRT engine from the UFF graph, I receive the following error:

[TensorRT] ERROR: Parameter check failed at: ../builder/Network.cpp::addInput::377, condition: isValidDims(dims)
[1]    14165 segmentation fault (core dumped)  python tmp.py

Here is the program I am trying to run:

import tensorflow as tf
from keras import backend as K
from keras.applications.resnet50 import ResNet50

import uff
import tensorrt as trt
from tensorrt.parsers import uffparser

with K.get_session() as sess:
    image_batch_t = tf.placeholder(tf.float32, shape=(None, 224, 224, 3))

    model = ResNet50(input_tensor=image_batch_t,
                     weights='imagenet',
                     include_top=False)
    K.set_learning_phase(0)
    conf_t = model(image_batch_t)
    output_names = [conf_t.name[:-2]]
    graphdef = sess.graph.as_graph_def()
    frozen_graph = tf.graph_util.convert_variables_to_constants(sess, graphdef, output_names)
    frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)

uff_model = uff.from_tensorflow(frozen_graph, output_names)

G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)

parser = uffparser.create_uff_parser()
input_shape = (3, 224, 224)
parser.register_input("placeholder", input_shape, 1)
parser.register_output(output_names[0])
engine = trt.utils.uff_to_trt_engine(G_LOGGER,
    uff_model,
    parser,
    1,
    1 << 30)

parser.destroy()
print('success')
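Side note on the script above: it derives the output node name with `conf_t.name[:-2]`, which assumes the tensor name ends in a two-character suffix such as `:0`. That breaks for an output index with two digits (for example `:10`). A slightly more defensive sketch, using a hypothetical helper name, splits on the last colon instead:

```python
def tensor_to_node_name(tensor_name):
    """Strip the ':<output index>' suffix from a TensorFlow tensor name.

    Hypothetical helper, e.g. 'resnet50/avg_pool/AvgPool:0' becomes
    'resnet50/avg_pool/AvgPool'. Names without a suffix are returned unchanged.
    """
    node, sep, index = tensor_name.rpartition(":")
    # Only strip when there is a colon followed by a purely numeric index.
    if sep and index.isdigit():
        return node
    return tensor_name


print(tensor_to_node_name("resnet50/avg_pool/AvgPool:0"))  # resnet50/avg_pool/AvgPool
print(tensor_to_node_name("some/op:10"))                   # some/op
```

In the script it would replace the slicing: `output_names = [tensor_to_node_name(conf_t.name)]`.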

The whole output of this program is:

2018-04-16 15:21:41.178976: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-04-16 15:21:43.057830: W tensorflow/stream_executor/cuda/cuda_driver.cc:527] A non-primary context 0x551ea70 for device 0 exists before initializing the StreamExecutor. The primary context is now 0x55283c0. We haven't verified StreamExecutor works with that.
2018-04-16 15:21:43.392835: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1344] Found device 0 with properties:
name: GeForce GTX 1070 major: 6 minor: 1 memoryClockRate(GHz): 1.683
pciBusID: 0000:82:00.0
totalMemory: 7.92GiB freeMemory: 7.74GiB
2018-04-16 15:21:43.392877: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1423] Adding visible gpu devices: 0
2018-04-16 15:21:44.983678: I tensorflow/core/common_runtime/gpu/gpu_device.cc:911] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-04-16 15:21:44.983744: I tensorflow/core/common_runtime/gpu/gpu_device.cc:917]      0
2018-04-16 15:21:44.983753: I tensorflow/core/common_runtime/gpu/gpu_device.cc:930] 0:   N
2018-04-16 15:21:44.984160: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1041] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 7459 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1070, pci bus id: 0000:82:00.0, compute capability: 6.1)
Converted 318 variables to const ops.
Using output node resnet50/avg_pool/AvgPool
Converting to UFF graph
No. nodes: 917
[TensorRT] ERROR: Parameter check failed at: ../builder/Network.cpp::addInput::377, condition: isValidDims(dims)
[1]    14165 segmentation fault (core dumped)  python tmp.py

Library information:

  • Cuda 9.0
  • Cudnn 7.0.5
  • TensorRT 4.0.0.3
  • Tensorflow 1.7.0
  • Keras 2.1.4

Additional comments:

  • I have read that TensorRT will remove TensorFlow Reshape ops from the graph_def, which can lead to dimension incompatibility issues. I have checked the Keras ResNet50 graph used above, and it contains no Reshape ops.
  • Thanks for any help.
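One way to do the Reshape check from the comments above is to walk the nodes of the frozen GraphDef and group them by op type. The sketch below is a hypothetical, pure-Python helper: it only assumes an object with a `.node` sequence whose items have `.name` and `.op`, which is how a `tf.GraphDef` is laid out, so it can be tried on a small stand-in without TensorFlow installed.

```python
from collections import Counter
from types import SimpleNamespace


def op_type_counts(graph_def):
    """Count nodes per op type in a GraphDef-like object."""
    return Counter(node.op for node in graph_def.node)


def nodes_with_op(graph_def, op_type):
    """Return the names of all nodes whose op type equals `op_type`."""
    return [node.name for node in graph_def.node if node.op == op_type]


# Tiny stand-in for a frozen graph (hypothetical node names).
example_graph = SimpleNamespace(node=[
    SimpleNamespace(name="input_1", op="Placeholder"),
    SimpleNamespace(name="conv1/Conv2D", op="Conv2D"),
    SimpleNamespace(name="flatten_1/Reshape", op="Reshape"),
])

print(op_type_counts(example_graph)["Reshape"])  # 1
print(nodes_with_op(example_graph, "Reshape"))   # ['flatten_1/Reshape']
```

With the script from the first post, `nodes_with_op(frozen_graph, "Reshape")` should return an empty list if the graph really has no Reshape ops.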

I got exactly the same error after successfully generating the .uff file. My script is:

import os
import subprocess
import uff
import pdb
import sys
import tensorrt as trt

frozen_graph_filename = path_to_pb_file
input_shape = (3, 300, 300)
output_name = 'resnet_v1_50/predictions/Reshape_1'
engine = trt.lite.Engine(framework="tf", # Source framework
                         path=frozen_graph_filename, # Model File
                         max_batch_size=1, # Max number of images to be processed at a time
                         input_nodes={"in": input_shape}, # Input layers
                         output_nodes=[output_name], # Output layers
                         preprocessors={"in": 'normalize'}, # Preprocessing functions
                         postprocessors={"out": 'argmax'})

I also tried https://github.com/NVIDIA-Jetson/tf_to_trt_image_classification, but step 5 of its instructions ("Clone and build this project") fails for me.

Running on a Tesla M40 with:

  • Python 3.5.2
  • CUDA 9.0
  • cuDNN 7.0.5
  • TensorRT 4.0.0.3
  • Tensorflow-GPU 1.7.0 (installed from source)

Could you give some help? Thanks.

I managed to solve this issue. I needed to set the TensorFlow placeholder's name to match the name used in the register_input call on the UFF parser. Also, I think I had the wrong dimension order argument (the third argument to register_input). So the changes to the code above are:

- image_batch_t = tf.placeholder(tf.float32, shape=(None, 224, 224, 3))
+ image_batch_t = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='image_tensor')

and

- parser.register_input("placeholder", input_shape, 1)
+ parser.register_input("image_tensor", input_shape, 0)
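Why the name change matters: a `tf.placeholder` created without `name=` gets the default node name `Placeholder` (capital P, with `_1`, `_2`, ... appended if that name is already taken), so `register_input("placeholder", ...)` in the original script did not refer to any node in the graph. The sketch below is a hypothetical pure-Python pre-flight check (not part of the TensorRT API): it confirms that the name you intend to register exists as a Placeholder in the frozen graph, and turns the NHWC placeholder shape into the CHW tuple that the fix passes to `register_input`. It assumes a GraphDef-like object with `.node` items exposing `.name` and `.op`.

```python
from types import SimpleNamespace


def uff_input_spec(graph_def, input_name, nhwc_shape):
    """Validate an input name and build CHW dims for the UFF parser.

    Hypothetical helper. `nhwc_shape` is the placeholder shape, for example
    (None, 224, 224, 3). Returns (input_name, (C, H, W)).
    """
    placeholders = [n.name for n in graph_def.node if n.op == "Placeholder"]
    if input_name not in placeholders:
        raise ValueError(
            f"no Placeholder named {input_name!r}; graph has {placeholders}")
    if len(nhwc_shape) != 4:
        raise ValueError(f"expected an NHWC shape, got {nhwc_shape!r}")
    _, height, width, channels = nhwc_shape  # the batch dimension is dropped
    if None in (height, width, channels):
        # Assumption: the engine needs concrete per-image dimensions.
        raise ValueError(f"H, W and C must be known, got {nhwc_shape!r}")
    return input_name, (channels, height, width)


# Tiny stand-in for the frozen graph after the fix (hypothetical nodes).
example_graph = SimpleNamespace(node=[
    SimpleNamespace(name="image_tensor", op="Placeholder"),
    SimpleNamespace(name="conv1/Conv2D", op="Conv2D"),
])

print(uff_input_spec(example_graph, "image_tensor", (None, 224, 224, 3)))
# ('image_tensor', (3, 224, 224))
```

Used with the fixed script, this would be `name, dims = uff_input_spec(frozen_graph, "image_tensor", (None, 224, 224, 3))` followed by `parser.register_input(name, dims, 0)`. Calling it with `"placeholder"` raises a `ValueError` that lists the real placeholder names instead of failing later inside the engine builder.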

Sebastian,

Glad you were able to resolve this.

Weizh888,

Does this resolve the issue for you?

-Chris

I solved it by first converting the Keras model to a TensorFlow model and then loading that into TensorRT.

Hello, I also used the model from keras.applications.resnet50.ResNet50, and I ran into the following problem when converting the frozen graph to UFF:

Using output node fc1000_1/Softmax
Converting to UFF graph
DEBUG: convert reshape to flatten node
Warning: No conversion function registered for layer: Merge yet.
Converting as custom op Merge bn5a_branch1_1/cond/Merge
name: "bn5a_branch1_1/cond/Merge"
op: "Merge"
input: "bn5a_branch1_1/cond/batchnorm/add_1"
input: "bn5a_branch1_1/cond/Switch_1:1"
attr {
  key: "N"
  value {
    i: 2
  }
}
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}

Has anyone run into the same issue and resolved it? Thanks.

Hi taipei11408, I got similar warnings. Did you solve it?

Warning: keepdims is ignored by the UFF Parser and defaults to True
Warning: No conversion function registered for layer: Merge yet.
Converting as custom op Merge bn5a_branch1_1/cond/Merge
name: "bn5a_branch1_1/cond/Merge"
op: "Merge"
input: "bn5a_branch1_1/cond/FusedBatchNorm"
input: "bn5a_branch1_1/cond/Switch_1:1"
attr {
  key: "N"
  value {
    i: 2
  }
}
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}

Thanks
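About the Merge warnings in the last two posts: nodes under `bn*/cond/` with `Switch`/`Merge` ops are TensorFlow control flow that Keras inserts when a `BatchNormalization` layer is built while the learning phase is still a placeholder, so the layer can choose between training and inference behavior at run time. The UFF converter has no conversion for `Merge`, hence the custom-op fallback. A commonly suggested remedy (an assumption here, not tested on these exact versions) is to call `K.set_learning_phase(0)` before constructing the model; note that the first script in this thread calls it only after `ResNet50(...)`. The hypothetical pure-Python check below lists any such nodes left in a frozen graph before you convert it, again assuming a GraphDef-like object with `.node` items exposing `.name` and `.op`:

```python
from types import SimpleNamespace

# Op types produced by the training/inference conditional: Merge is the one
# the UFF converter warns about; Switch is its counterpart in the same cond.
CONTROL_FLOW_OPS = {"Switch", "Merge"}


def leftover_control_flow(graph_def):
    """Return the sorted names of Switch/Merge nodes in a GraphDef-like object."""
    return sorted(n.name for n in graph_def.node if n.op in CONTROL_FLOW_OPS)


# Stand-in built from the node names in the logs above.
example_graph = SimpleNamespace(node=[
    SimpleNamespace(name="bn5a_branch1_1/cond/Switch_1", op="Switch"),
    SimpleNamespace(name="bn5a_branch1_1/cond/FusedBatchNorm", op="FusedBatchNorm"),
    SimpleNamespace(name="bn5a_branch1_1/cond/Merge", op="Merge"),
])

print(leftover_control_flow(example_graph))
# ['bn5a_branch1_1/cond/Merge', 'bn5a_branch1_1/cond/Switch_1']
```

If `leftover_control_flow(frozen_graph)` is non-empty, the model was most likely built with a dynamic learning phase, and rebuilding it after `K.set_learning_phase(0)` is the first thing to try.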