Hello,
I have a pre-trained Keras model (MobileNetV2). I followed the steps to convert the Keras model into a frozen TensorFlow graph (.pb) and then reload the graph during inference.
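For context, the conversion step followed the standard freeze-graph recipe; it looked roughly like this (a sketch with placeholder paths, assuming the imports shown below, not my exact script):

K.set_learning_phase(0)  # inference mode before building the graph
model = load_model("Path/to/keras/model.h5")
sess = K.get_session()
# Freeze variables into constants, using the model's own output op names.
frozen = graph_util.convert_variables_to_constants(
    sess, sess.graph.as_graph_def(), [out.op.name for out in model.outputs])
graph_io.write_graph(frozen, "Path/to/output/dir", "frozen_model.pb", as_text=False)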
My inference code looks like this:
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
import pdb
import os
import os.path as osp
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
from tensorflow.keras.models import load_model
from tensorflow.keras import backend as K
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile
...
...
...
...
with gfile.FastGFile("Path/to/.pb/file", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# sess.graph.as_default()
# g_in = tf.import_graph_def(graph_def)

output_names = 'import/dense/Softmax:0'
trt_graph = trt.create_inference_graph(
    input_graph_def=graph_def,
    outputs=output_names,
    max_batch_size=1,
    max_workspace_size_bytes=1 << 15,
    precision_mode='FP16',
    minimum_segment_size=10
)
tf.import_graph_def(trt_graph, name='')

# write to TensorBoard (check TensorBoard for each op name)
writer = tf.summary.FileWriter("Path/to/logs/folder")
writer.add_graph(sess.graph)
writer.flush()
writer.close()

tensor_output = sess.graph.get_tensor_by_name('import/dense/Softmax:0')
tensor_input = sess.graph.get_tensor_by_name('import/mobilenetv2_1.00_224_input:0')
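The plan after that is to feed a preprocessed image through these tensors, along these lines (a sketch; the preprocessing is omitted and `image` is a placeholder name):

# `image` would be a preprocessed batch of shape (1, 224, 224, 3).
predictions = sess.run(tensor_output, feed_dict={tensor_input: image})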
When I just import the graph_def, I can see the graph on TensorBoard and there isn't any error.
But when I run create_inference_graph as shown above, I get the following error:
2019-06-28 09:29:06.387102: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 2942 MB memory) -> physical GPU (device: 0, name: NVIDIA Tegra X2, pci bus id: 0000:00:00.0, compute capability: 6.2)
2019-06-28 09:29:07.813229: E tensorflow/core/grappler/grappler_item_builder.cc:321] Invalid fetch node name skipping this input
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/grappler/tf_optimizer.py", line 43, in OptimizeGraph
    verbose, graph_id, status)
SystemError: <built-in function TF_OptimizeGraph> returned NULL without setting an error

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "create_inference_graph_v3_inference.py", line 75, in <module>
    inference()
  File "create_inference_graph_v3_inference.py", line 35, in inference
    trt_graph = trt.create_inference_graph(
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tensorrt/python/trt_convert.py", line 364, in create_inference_graph
    session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph")
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/grappler/tf_optimizer.py", line 43, in OptimizeGraph
    verbose, graph_id, status)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Failed to import metagraph, check error log for more info.
I am not sure why the graph extraction goes wrong, because if I simply do:
g_in = tf.import_graph_def(graph_def)
it works and I don’t see any problem.
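To double-check the fetch name, the node names in the frozen graph can be listed like this (a quick sanity-check snippet):

# Print every node name and op type in the frozen graph
# to verify the output node actually exists under that name.
for node in graph_def.node:
    print(node.name, node.op)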
I am using:
Jetson TX2 for inference, flashed with JetPack 4.2
TensorFlow 1.13.1
TensorRT 5.0
Looking forward to your suggestions. I appreciate the help; thank you in advance!
Regards,
T