UFF deprecation - Are there good alternatives for the TensorFlow workflow?

Hi,

After the UFF deprecation notice, we have two options for working with TensorFlow:

  • Use TF2ONNX. A number of layers that TensorRT supports are not supported by TF2ONNX. Example: ConvTranspose2D is supported by TensorRT but not by TF2ONNX.

  • Use TF-TRT. This requires installing TensorFlow on Xavier to be able to generate a .plan file to deploy onto a Xavier (roughly the flow sketched below). Is my understanding correct? Installing TensorFlow on Xavier is not exactly straightforward, so is this really the recommended workflow?
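For reference, this is a minimal sketch of the TF-TRT flow I mean, assuming a TF 1.x SavedModel; the paths are placeholders and I haven’t run this on Xavier:

# tf_trt_sketch.py -- hypothetical illustration of the TF-TRT flow
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Rewrites supported TF subgraphs into TensorRT ops inside the SavedModel;
# engines are built for the local GPU, hence the need to run this on Xavier.
converter = trt.TrtGraphConverter(input_saved_model_dir="./saved_model")
converter.convert()
converter.save("./saved_model_trt")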

It feels like UFF has been deprecated without providing an alternative that is reliable and easy to work with.

Hi,

I don’t have a ton of experience with TensorFlow, so please correct me if any of this looks wrong.

I put together a simple example to verify whether tf2onnx can handle ConvTranspose2D.

  1. Create simple model with conv2d_transpose op:
pip install tensorflow==1.14
# repro.py

import numpy as np
import tensorflow as tf

# https://datascience.stackexchange.com/questions/26451/how-to-calculate-the-output-shape-of-conv2d-transpose
x = tf.placeholder(dtype=tf.float32, shape=(None, 7, 7, 32))
dcout = tf.layers.conv2d_transpose(x, 64, 4, 3, padding="valid")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    xin = np.random.rand(1,7,7,32)
    out = sess.run(dcout, feed_dict={x:xin})
    print(out.shape)

    # freeze the graph so that it can be converted to onnx
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        [n.name for n in tf.get_default_graph().as_graph_def().node])

    # https://github.com/onnx/tensorflow-onnx/issues/77#issuecomment-445066091
    for node in output_graph_def.node:
        print(node.name, node.op)
        if node.op == "Assign":
            node.op = "Identity"
            if 'use_locking' in node.attr: del node.attr['use_locking']
            if 'validate_shape' in node.attr: del node.attr['validate_shape']
            if len(node.input) == 2:
                # input0: ref: Should be from a Variable node. May be uninitialized.
                # input1: value: The value to be assigned to the variable.
                node.input[0] = node.input[1]
                del node.input[1]

    output_graph = "model_tf_test.pb"
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
python3 repro.py
  2. Convert to ONNX with tf2onnx:
pip install onnx==1.6.0 tf2onnx
$ python3 -m tf2onnx.convert --input model_tf_test.pb --inputs Placeholder:0 --outputs conv2d_transpose/BiasAdd:0 --opset=11 --output conv2dtranspose.opset11.onnx
...
2020-02-11 18:51:27,976 - INFO - Using tensorflow=1.14.0, onnx=1.6.0, tf2onnx=1.5.4/ee756a
2020-02-11 18:51:27,976 - INFO - Using opset <onnx, 11>
2020-02-11 18:51:28,014 - INFO - Optimizing ONNX model
2020-02-11 18:51:28,078 - INFO - After optimization: Const -28 (39->11), Gather +1 (0->1), Identity -1 (1->0), Squeeze -1 (3->2), Unsqueeze -2 (4->2)
2020-02-11 18:51:28,081 - INFO - 
2020-02-11 18:51:28,081 - INFO - Successfully converted TensorFlow model model_tf_test.pb to ONNX
2020-02-11 18:51:28,082 - INFO - ONNX model is saved at conv2dtranspose.opset11.onnx
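
As an optional sanity check (not part of the steps here), the converted model’s output can be compared against TensorFlow’s with onnxruntime. A minimal sketch, assuming onnxruntime is pip-installed:

# check_onnx.py -- optional sanity check of the converted model
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("conv2dtranspose.opset11.onnx")
xin = np.random.rand(1, 7, 7, 32).astype(np.float32)
out = sess.run(None, {"Placeholder:0": xin})[0]
print(out.shape)  # should match repro.py: (1, 22, 22, 64)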
  3. Convert ONNX to TensorRT:
$ dpkg -l | grep -i tensorrt
...
ii  tensorrt                    7.0.0.11-1+cuda10.2               amd64        Meta package of TensorRT

$ trtexec --explicitBatch --onnx=conv2dtranspose.opset11.onnx
...
&&&& PASSED TensorRT.trtexec # trtexec --explicitBatch --onnx=conv2dtranspose.opset11.onnx

The above was all using TensorRT 7.0 and ONNX opset 11 on an x86 machine.
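
Since you mentioned generating a .plan file: once the ONNX parse succeeds, the engine can also be built and serialized through the TensorRT Python API rather than only verified with trtexec. A sketch for TensorRT 7 (untested here; file names follow the steps above):

# build_engine.py -- sketch using the TensorRT 7 Python API
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(EXPLICIT_BATCH)
parser = trt.OnnxParser(network, TRT_LOGGER)

with open("conv2dtranspose.opset11.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("failed to parse ONNX model")

builder.max_workspace_size = 1 << 28  # 256 MiB of scratch space
engine = builder.build_cuda_engine(network)

with open("conv2dtranspose.plan", "wb") as f:
    f.write(engine.serialize())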

I noticed you mentioned Xavier, so I tried to reproduce using TensorRT 6.0 and ONNX opset 9/10.

Seems like TensorRT 6 doesn’t support some of these ops, even when building the OSS ONNX parser (release 19.12):

$ dpkg -l | grep -i tensorrt
ii  tensorrt                    6.0.1.8-1+cuda10.2                amd64        Meta package of TensorRT

$ python3 -m tf2onnx.convert --input model_tf_test.pb --inputs Placeholder:0 --outputs conv2d_transpose/BiasAdd:0 --opset=10 --output=conv2dtranspose.opset10.onnx

$ trtexec --onnx=conv2dtranspose.opset10.onnx 
...
WARNING: ONNX model has a newer ir_version (0.0.6) than this parser was built against (0.0.3).
While parsing node number 1 [ConvTranspose]:
ERROR: ModelImporter.cpp:296 In function importModel:
[5] Assertion failed: tensors.count(input_name)

$ bash /opt/tensorrt/install_opensource.sh 
...
Done!

$ trtexec --explicitBatch --onnx=conv2dtranspose.opset10.onnx 
...
ERROR: /opt/tensorrt/TensorRT/parsers/onnx/builtin_op_importers.cpp:286 In function importCast:
[8] Assertion failed: trt_dtype == nvinfer1::DataType::kHALF && cast_dtype == ::ONNX_NAMESPACE::TensorProto::FLOAT
[01/11/2020-18:58:30] [E] Failed to parse onnx file
[01/11/2020-18:58:30] [E] Parsing model failed
[01/11/2020-18:58:30] [E] Engine could not be created
&&&& FAILED TensorRT.trtexec # trtexec --explicitBatch --onnx=conv2dtranspose.opset10.onnx

This seems to have since been fixed in TensorRT 7.

TRT6 fails here: https://github.com/onnx/onnx-tensorrt/blob/397cdbafd898153f69a2b1a87dcc1c4dc5add18b/builtin_op_importers.cpp

TRT7 fixed here: https://github.com/onnx/onnx-tensorrt/blob/84b5be1d6fc03564f2c0dba85a2ee75bad242c2e/builtin_op_importers.cpp#L313

So you might just have to wait for TensorRT >= 7 on Xavier for this, but I can look into whether there is a workaround.


If you have a model that is not working, please share it so we can investigate what’s unsupported and perhaps make a feature request to tf2onnx if necessary.

Thanks for the detailed explanation!

According to tf2onnx, Conv2DTranspose is not in the list of officially supported layers:

So it could be that support is still not mature. We are currently experiencing issues converting our Conv2DTranspose layers to ONNX. I’ll have a look at the workaround you mentioned in your code to see if it solves the issue; otherwise I’ll open a ticket with tf2onnx.

Still I do believe Nvidia should have verified that there’s a 1:1 officially supported mapping between TensorRT and ONNX before deprecating UFF :)

According to tf2onnx, Conv2DTranspose is not in the list of officially supported layers:

Seems like that doc has only been updated about once since July. Maybe it makes some simplifications/assumptions to avoid enumerating every op, such as listing “Conv2D” and “Transpose” but not “Conv2DTranspose”, though I’m not sure.

Still I do believe Nvidia should have verified that there’s a 1:1 officially supported mapping between TensorRT and ONNX before deprecating UFF :)

I agree; however, I do think the ONNX parser provides a much wider range of support than the UFF parser. Sometimes it comes down to waiting on other projects (pytorch, tf2onnx, keras2onnx, etc.) to catch up in converting model ops to ONNX.

The UFF parser should still work for now as it did before if you require it for your workflow; it’s just not getting the same attention as ONNX, doesn’t support dynamic shapes, and isn’t recommended going forward.

The ONNX parser, on the other hand, is open source, so you can better understand what’s happening when things fail, contribute new ops, fork, debug, and read the code: https://github.com/onnx/onnx-tensorrt

Feel free to link your tensorflow-onnx GitHub issue here if you raise one.

I do agree ONNX is the way forward given the number of supported layers. I just wish there was a smooth transition :)

I tried to reproduce the code you posted above, and it works fine. However, in our project we are using this instead:

conv = tf.compat.v1.layers.Conv2DTranspose(
            filters=16,
            kernel_size=3,
            strides=2,
            padding='same',
            data_format='channels_first',
            name='conv_transposed')

dcout = conv(x)

When importing the generated protobuf, I’m getting errors similar to the ones you patched for “Assign”, but for “IsVariableInitialized”:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node import/conv_transposed/IsVariableInitialized was passed float from import/conv_transposed/bias:0 incompatible with expected float_ref.

Are you familiar with this error? I understand this is a TensorFlow issue, so I don’t expect to solve it here :) But it blocks us from trying the proposed solution.
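
One thing I plan to try (untested): the freeze step in your repro lists every node in the graph as an output, which keeps variable-housekeeping ops like IsVariableInitialized alive; freezing with only the real output node should prune them. A sketch, assuming our output tensor is conv_transposed/BiasAdd:

# Hypothetical workaround: pass only the actual output node name so that
# Assign/IsVariableInitialized housekeeping ops are pruned during freezing.
output_graph_def = tf.graph_util.convert_variables_to_constants(
    sess,
    sess.graph.as_graph_def(),
    ["conv_transposed/BiasAdd"])  # assumed output node name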

I opened a tensorflow-onnx issue: https://github.com/onnx/tensorflow-onnx/issues/801

Is it possible for you to use the TensorFlow Keras API? I have successfully used https://github.com/onnx/keras-onnx to convert a network with Conv2DTranspose to ONNX, and then used onnx2trt to create and deploy a TensorRT inference engine. I was never able to get tf2onnx to work and eventually gave up on it; I’ve been very pleased with keras2onnx.
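
For illustration, the flow is roughly the following; a minimal sketch with a toy model, not our actual network:

# keras2onnx_sketch.py -- toy example of the keras2onnx flow
import tensorflow as tf
import keras2onnx

model = tf.keras.Sequential([
    tf.keras.layers.Conv2DTranspose(64, 4, strides=3, padding="valid",
                                    input_shape=(7, 7, 32)),
])
onnx_model = keras2onnx.convert_keras(model, model.name)
keras2onnx.save_model(onnx_model, "conv2dtranspose_keras.onnx")
# then, with onnx-tensorrt installed:
#   onnx2trt conv2dtranspose_keras.onnx -o conv2dtranspose_keras.trt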