Converting TensorFlow autoencoder decoder to TensorRT engine via UFF

I have a trained TensorFlow (Keras) network that I am trying to convert into a TensorRT engine for inference with the TensorRT C++ API, via the intermediate UFF format. The network architecture is a little tricky: it’s only the decoder half of a 2D convolutional autoencoder. Below is the code I use to build the full autoencoder in TensorFlow (Keras); the decoder is extracted from it afterwards.

import numpy as np
from tensorflow.keras.layers import (Conv2D, Conv2DTranspose, Dense, Flatten,
                                     Input, Reshape)
from tensorflow.keras.models import Model


def build_CAE(featShape, numConvLayers, latentSize,
              kernelInitDist, biasInitDist,
              encodeKernelList, encodeStrideList, encodeFilterList, encodePaddingList, encodeActivationList,
              decodeKernelList, decodeStrideList, decodeFilterList, decodePaddingList, decodeActivationList,
              denseActivation):

    inputEncoder = Input(shape=featShape, name='inputEncode')

    # convolutional half of autoencoder
    def build_encoder():

        x = inputEncoder

        # convolutional layers
        for convNum in range(numConvLayers):
            x = Conv2D(filters=encodeFilterList[convNum],
                       kernel_size=encodeKernelList[convNum],
                       strides=encodeStrideList[convNum],
                       padding=encodePaddingList[convNum],
                       activation=encodeActivationList[convNum],
                       kernel_initializer=kernelInitDist,
                       bias_initializer=biasInitDist,
                       name='conv' + str(convNum))(x)

        # flatten for input to dense layer
        shape_before_flatten = x.shape.as_list()[1:]
        x = Flatten(name='flatten')(x)

        # dense layer
        x = Dense(latentSize,
                  activation=denseActivation,
                  kernel_initializer=kernelInitDist,
                  bias_initializer=biasInitDist,
                  name='fcnConv')(x)

        return Model(inputEncoder, x), shape_before_flatten

    # build encoder, save tensor dimensions before flattening
    encoder, shape_before_flatten = build_encoder()
    dim_before_flatten = np.prod(shape_before_flatten)

    # decoder half of autoencoder
    def build_decoder():

        # input from latent space
        inputDecoder = Input(shape=encoder.layers[-1].output_shape[1:])
        x = inputDecoder

        # dense layer
        x = Dense(dim_before_flatten,
                  activation=denseActivation,
                  kernel_initializer=kernelInitDist,
                  bias_initializer=biasInitDist,
                  name='fcnDeconv')(x)

        # reshape to appropriate tensor dimensions for input to transpose convolution layers
        x = Reshape(target_shape=shape_before_flatten, name='reshapeConv')(x)

        # transpose convolutional layers
        for deconvNum in range(numConvLayers):
            x = Conv2DTranspose(filters=decodeFilterList[deconvNum],
                                kernel_size=decodeKernelList[deconvNum],
                                strides=decodeStrideList[deconvNum],
                                padding=decodePaddingList[deconvNum],
                                activation=decodeActivationList[deconvNum],
                                kernel_initializer=kernelInitDist,
                                bias_initializer=biasInitDist,
                                name='conv' + str(deconvNum))(x)

        return Model(inputDecoder, x)

    # build decoder
    decoder = build_decoder()

    # build autoencoder from encoder and decoder
    return Model(inputEncoder, decoder(encoder(inputEncoder)))
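Because the decoder has to undo the encoder's downsampling exactly, it can help to trace the spatial sizes by hand before training. The sketch below is a minimal illustration under stated assumptions: dilation 1, no `output_padding`, only `'same'`/`'valid'` padding, and the per-axis output-length rules that Keras applies to `Conv2D` and `Conv2DTranspose`. The helper names (`conv_out`, `deconv_out`, `trace_spatial`) and the 64x64 example settings are hypothetical, not part of the original code.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a Conv2D along one spatial axis (dilation 1)."""
    if padding == 'same':
        length = size
    elif padding == 'valid':
        length = size - kernel + 1
    else:
        raise ValueError('unsupported padding: %s' % padding)
    return (length + stride - 1) // stride  # ceiling division by the stride


def deconv_out(size, kernel, stride, padding):
    """Output length of a Conv2DTranspose along one axis (dilation 1, no output_padding)."""
    if padding == 'same':
        return size * stride
    if padding == 'valid':
        return size * stride + max(kernel - stride, 0)
    raise ValueError('unsupported padding: %s' % padding)


def trace_spatial(size, enc_kernels, enc_strides, enc_pads,
                  dec_kernels, dec_strides, dec_pads):
    """Return the per-layer sizes through the encoder, then through the decoder."""
    encoder_sizes = [size]
    for k, s, p in zip(enc_kernels, enc_strides, enc_pads):
        encoder_sizes.append(conv_out(encoder_sizes[-1], k, s, p))
    decoder_sizes = [encoder_sizes[-1]]
    for k, s, p in zip(dec_kernels, dec_strides, dec_pads):
        decoder_sizes.append(deconv_out(decoder_sizes[-1], k, s, p))
    return encoder_sizes, decoder_sizes


# Hypothetical settings: four layers, 3x3 kernels, stride 2, 'same' padding.
k, s, p = [3] * 4, [2] * 4, ['same'] * 4
enc, dec = trace_spatial(64, k, s, p, k, s, p)
print(enc)  # [64, 32, 16, 8, 4]
print(dec)  # [4, 8, 16, 32, 64]

# An odd input size is not recovered by the round trip.
enc_odd, dec_odd = trace_spatial(63, k, s, p, k, s, p)
print(enc_odd[-1], dec_odd[-1])  # 4 64
```

With these hypothetical settings, `shape_before_flatten` would be `[4, 4, encodeFilterList[-1]]`. The last case shows that an input size of 63 comes back as 64, so `featShape` and the stride/padding lists have to be chosen so that the round trip is exact.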

For a CAE with four (transpose) convolutional layers and ELU activations on all but the last layer, I can extract the encoder and decoder halves from the layer objects of the assembled convolutional autoencoder and save them as usual. The frozen graph of the decoder, from input to output, has the form:

x
model_1/fcnDeconv/MatMul/ReadVariableOp/resource
model_1/fcnDeconv/MatMul/ReadVariableOp
model_1/fcnDeconv/MatMul
model_1/fcnDeconv/BiasAdd/ReadVariableOp/resource
model_1/fcnDeconv/BiasAdd/ReadVariableOp
model_1/fcnDeconv/BiasAdd
model_1/fcnDeconv/Elu
model_1/reshapeConv/Shape
model_1/reshapeConv/strided_slice/stack
model_1/reshapeConv/strided_slice/stack_1
model_1/reshapeConv/strided_slice/stack_2
model_1/reshapeConv/strided_slice
model_1/reshapeConv/Reshape/shape/1
model_1/reshapeConv/Reshape/shape/2
model_1/reshapeConv/Reshape/shape/3
model_1/reshapeConv/Reshape/shape
model_1/reshapeConv/Reshape
model_1/conv0/Shape
model_1/conv0/strided_slice/stack
model_1/conv0/strided_slice/stack_1
model_1/conv0/strided_slice/stack_2
model_1/conv0/strided_slice
model_1/conv0/strided_slice_1/stack
model_1/conv0/strided_slice_1/stack_1
model_1/conv0/strided_slice_1/stack_2
model_1/conv0/strided_slice_1
model_1/conv0/mul/y
model_1/conv0/mul
model_1/conv0/strided_slice_2/stack
model_1/conv0/strided_slice_2/stack_1
model_1/conv0/strided_slice_2/stack_2
model_1/conv0/strided_slice_2
model_1/conv0/mul_1/y
model_1/conv0/mul_1
model_1/conv0/stack/3
model_1/conv0/stack
model_1/conv0/conv2d_transpose/ReadVariableOp/resource
model_1/conv0/conv2d_transpose/ReadVariableOp
model_1/conv0/conv2d_transpose
model_1/conv0/BiasAdd/ReadVariableOp/resource
model_1/conv0/BiasAdd/ReadVariableOp
model_1/conv0/BiasAdd
model_1/conv0/Elu
##### etc. for conv1, conv2, conv3 ######
model_1/conv3/conv2d_transpose/ReadVariableOp/resource
model_1/conv3/conv2d_transpose/ReadVariableOp
model_1/conv3/conv2d_transpose
model_1/conv3/BiasAdd/ReadVariableOp/resource
model_1/conv3/BiasAdd/ReadVariableOp
model_1/conv3/BiasAdd
Identity
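For reference, the first block of nodes above (fcnDeconv/MatMul, fcnDeconv/BiasAdd, fcnDeconv/Elu, then reshapeConv/Reshape) can be mirrored in NumPy. This is a hedged sketch of what those ops compute, not TensorFlow's implementation; the helper names and the sizes (latentSize = 8, shape_before_flatten = (4, 4, 16)) are hypothetical. It makes explicit that the MatMul consumes a rank-2 `(batch_size, latentSize)` tensor, i.e. a single non-batch dimension.

```python
import numpy as np


def elu(x, alpha=1.0):
    """ELU: x where x > 0, alpha * (exp(x) - 1) elsewhere (Keras default alpha=1)."""
    # np.minimum keeps exp() from overflowing on large positive inputs.
    return np.where(x > 0, x, alpha * (np.exp(np.minimum(x, 0.0)) - 1.0))


def decoder_front(z, kernel, bias, shape_before_flatten):
    """Mirror fcnDeconv/MatMul, BiasAdd, Elu and reshapeConv/Reshape for a rank-2 input."""
    batch = z.shape[0]              # reshapeConv/Shape + strided_slice read the batch size
    h = z @ kernel                  # fcnDeconv/MatMul: (batch, latent) @ (latent, units)
    h = h + bias                    # fcnDeconv/BiasAdd
    h = elu(h)                      # fcnDeconv/Elu
    # reshapeConv/Reshape: target shape = (batch,) + shape_before_flatten
    return h.reshape((batch,) + tuple(shape_before_flatten))


# Hypothetical sizes: latentSize = 8, shape_before_flatten = (4, 4, 16).
rng = np.random.RandomState(0)
latent, shape_bf = 8, (4, 4, 16)
units = int(np.prod(shape_bf))              # dim_before_flatten = 256
z = rng.standard_normal((2, latent))        # rank 2: (batch_size, latentSize)
kernel = rng.standard_normal((latent, units))
bias = rng.standard_normal(units)
x = decoder_front(z, kernel, bias, shape_bf)
print(z.shape, x.shape)  # (2, 8) (2, 4, 4, 16)
```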

My trouble initially stemmed from the fact that the input to the decoder is two-dimensional (batch size and latent space dimension): it is the latent vector output by the encoder’s final Dense layer. In TensorFlow (Keras), the output of a Dense layer conventionally has the shape (batch_size, units). I can successfully convert the saved Keras model to UFF through a TensorFlow frozen graph. However, when I use the TensorRT C++ API to build a TRT engine, I am greeted by the following messages:

[02/10/2020-19:55:57] [E] [TRT] model_1/fcnDeconv/MatMul: at least 3 dimensions are required for input
[02/10/2020-19:55:57] [E] [TRT] model_1/fcnDeconv/MatMul: at least 3 dimensions are required for input
[02/10/2020-19:55:57] [E] [TRT] model_1/fcnDeconv/MatMul: at least 3 dimensions are required for input
[02/10/2020-19:55:57] [E] [TRT] UffParser: Parser error: model_1/fcnDeconv/BiasAdd: The input to the Scale Layer is required to have a minimum of 3 dimensions.
[02/10/2020-19:55:57] [E] [TRT] Network must have at least one output
[02/10/2020-19:55:57] [E] [TRT] Network validation failed.

After some snooping, I found in the TensorRT C++ API documentation that IFullyConnectedLayer does indeed require three or more non-batch dimensions (e.g. CHW). Okay, I’ll just reshape the encoder output so that it has three dimensions, two of which are singletons. I add the following line after the encoder’s Dense layer:

x = Reshape((1,1,latentSize))(x)

Training goes fine, and I get the same results (when running the Keras model) as before. However, now I cannot even convert the Keras model to UFF! convert_to_uff throws the following error:

uff.model.exceptions.UffException: Transpose permutation has op ConcatV2, expected Const. Only constant permuations are supported in UFF.

The frozen graph for this network has the form:

x
model_1/fcnDeconv/Tensordot/free
model_1/fcnDeconv/Tensordot/axes
model_1/fcnDeconv/Tensordot/concat/axis
model_1/fcnDeconv/Tensordot/concat
model_1/fcnDeconv/Tensordot/transpose
model_1/fcnDeconv/Tensordot/Shape
model_1/fcnDeconv/Tensordot/GatherV2/axis
model_1/fcnDeconv/Tensordot/GatherV2
model_1/fcnDeconv/Tensordot/Const
model_1/fcnDeconv/Tensordot/Prod
model_1/fcnDeconv/Tensordot/GatherV2_1/axis
model_1/fcnDeconv/Tensordot/GatherV2_1
model_1/fcnDeconv/Tensordot/Const_1
model_1/fcnDeconv/Tensordot/Prod_1
model_1/fcnDeconv/Tensordot/stack
model_1/fcnDeconv/Tensordot/Reshape
model_1/fcnDeconv/Tensordot/ReadVariableOp/resource
model_1/fcnDeconv/Tensordot/ReadVariableOp
model_1/fcnDeconv/Tensordot/transpose_1/perm
model_1/fcnDeconv/Tensordot/transpose_1
model_1/fcnDeconv/Tensordot/Reshape_1/shape
model_1/fcnDeconv/Tensordot/Reshape_1
model_1/fcnDeconv/Tensordot/MatMul
model_1/fcnDeconv/Tensordot/Const_2
model_1/fcnDeconv/Tensordot/concat_1/axis
model_1/fcnDeconv/Tensordot/concat_1
model_1/fcnDeconv/Tensordot
model_1/fcnDeconv/BiasAdd/ReadVariableOp/resource
model_1/fcnDeconv/BiasAdd/ReadVariableOp
model_1/fcnDeconv/BiasAdd
model_1/fcnDeconv/Elu
model_1/reshapeConv/Shape
model_1/reshapeConv/strided_slice/stack
model_1/reshapeConv/strided_slice/stack_1
model_1/reshapeConv/strided_slice/stack_2
model_1/reshapeConv/strided_slice
model_1/reshapeConv/Reshape/shape/1
model_1/reshapeConv/Reshape/shape/2
model_1/reshapeConv/Reshape/shape/3
model_1/reshapeConv/Reshape/shape
model_1/reshapeConv/Reshape
model_1/conv0/Shape
model_1/conv0/strided_slice/stack
model_1/conv0/strided_slice/stack_1
model_1/conv0/strided_slice/stack_2
model_1/conv0/strided_slice
model_1/conv0/strided_slice_1/stack
model_1/conv0/strided_slice_1/stack_1
model_1/conv0/strided_slice_1/stack_2
model_1/conv0/strided_slice_1
model_1/conv0/mul/y
model_1/conv0/mul
model_1/conv0/strided_slice_2/stack
model_1/conv0/strided_slice_2/stack_1
model_1/conv0/strided_slice_2/stack_2
model_1/conv0/strided_slice_2
model_1/conv0/mul_1/y
model_1/conv0/mul_1
model_1/conv0/stack/3
model_1/conv0/stack
model_1/conv0/conv2d_transpose/ReadVariableOp/resource
model_1/conv0/conv2d_transpose/ReadVariableOp
model_1/conv0/conv2d_transpose
model_1/conv0/BiasAdd/ReadVariableOp/resource
model_1/conv0/BiasAdd/ReadVariableOp
model_1/conv0/BiasAdd
model_1/conv0/Elu
##### etc. for conv1, conv2, conv3 ######
model_1/conv3/conv2d_transpose/ReadVariableOp/resource
model_1/conv3/conv2d_transpose/ReadVariableOp
model_1/conv3/conv2d_transpose
model_1/conv3/BiasAdd/ReadVariableOp/resource
model_1/conv3/BiasAdd/ReadVariableOp
model_1/conv3/BiasAdd
Identity

Clearly, by making the Dense layer’s input a 3D tensor (rank 4 including the batch dimension), I have turned the Dense layer from a simple MatMul into a Tensordot, which requires ops that UFF does not support: the permutation fed to model_1/fcnDeconv/Tensordot/transpose comes from model_1/fcnDeconv/Tensordot/concat (a ConcatV2) rather than from a Const. I am incredibly surprised that the UFF converter cannot handle what I would think is the most basic of neural network operations! It’s just a Dense layer, with its input reshaped to meet the dimensionality requirements set by TensorRT.
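To see what changed, it helps to spell out what the Tensordot nodes listed above compute. The NumPy sketch below is an illustrative mirror of that node sequence, not TensorFlow's code: contract the last axis by transposing with `perm = free + axes`, flattening to 2-D, running one MatMul, and reshaping back. The helper name and the sizes are hypothetical. For a `(batch, 1, 1, latentSize)` input the permutation is the identity `[0, 1, 2, 3]`, so the arithmetic is the same as the plain rank-2 MatMul in the first graph; judging from the Tensordot/concat node feeding Tensordot/transpose, the difference is that TensorFlow builds this permutation at run time (presumably because the batch size is unknown), which is the ConcatV2 that convert_to_uff rejects.

```python
import numpy as np


def dense_as_tensordot(x, kernel, bias):
    """Dense layer on a rank-N input, spelled out step by step like the Tensordot nodes."""
    rank = x.ndim
    axes = [rank - 1]                                     # Tensordot/axes: contract last axis
    free = [i for i in range(rank) if i not in axes]      # Tensordot/free
    perm = free + axes                                    # Tensordot/concat (the ConcatV2)
    x_t = np.transpose(x, perm)                           # Tensordot/transpose
    free_dims = [x.shape[i] for i in free]                # Tensordot/GatherV2
    prod_free = int(np.prod(free_dims))                   # Tensordot/Prod
    prod_axes = int(np.prod([x.shape[i] for i in axes]))  # Tensordot/GatherV2_1 + Prod_1
    x_2d = x_t.reshape(prod_free, prod_axes)              # Tensordot/stack + Reshape
    y_2d = x_2d @ kernel                                  # Tensordot/MatMul
    y = y_2d.reshape(free_dims + [kernel.shape[1]])       # Tensordot/concat_1 + final reshape
    return y + bias, perm                                 # BiasAdd (Elu omitted here)


rng = np.random.RandomState(1)
batch, latent, units = 3, 8, 256                    # hypothetical sizes
kernel = rng.standard_normal((latent, units))
bias = rng.standard_normal(units)
z4 = rng.standard_normal((batch, 1, 1, latent))     # after Reshape((1, 1, latentSize))

y4, perm = dense_as_tensordot(z4, kernel, bias)
y2 = z4.reshape(batch, latent) @ kernel + bias      # the rank-2 MatMul + BiasAdd path
print(perm)                                         # [0, 1, 2, 3]
print(y4.shape)                                     # (3, 1, 1, 256)
print(np.allclose(y4.reshape(batch, units), y2))    # True
print(np.allclose(y4, np.tensordot(z4, kernel, axes=[[3], [0]]) + bias))  # True
```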

Is there any way to resolve this issue, perhaps by structuring the network architecture in Keras so that the graph is compatible with UFF? Am I doing something wrong in reshaping the input to the decoder’s Dense layer? I would rather fix this issue on the Keras end, if possible. I am averse to using TF-TRT, as there is no guarantee that all nodes will be converted to TensorRT nodes, and I absolutely need to use the optimized TRT engine with the C++ API ONLY, no Python whatsoever. I am also not keen on using tf2onnx, as I don’t see any guarantee that the ONNX format will resolve my issue; furthermore, it doesn’t have published support for TF 2.0 networks. But if someone can assure me that ONNX supports these operations, then by all means I will attack the problem from that angle.

I have tested this issue with the same results on the following systems:

  1. Local machine
    Ubuntu 18.04
    TensorFlow 2.0
    CUDA 10.2
    cuDNN 7.6.4
    TensorRT 7
    NVIDIA Quadro K1200

  2. Remote cluster (TensorRT builds and runs samples just fine, despite no explicit support for SLES)
    SLES 12.3
    TensorFlow 1.14
    CUDA 9.0
    cuDNN 7.6.5
    TensorRT 7
    NVIDIA Tesla P100

Hi,

Please note that the Caffe Parser and UFF Parser are deprecated as of TRT 7:
https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/tensorrt-700/tensorrt-release-notes/tensorrt-7.html#rel_7-0-0

You can try tf2onnx plus the ONNX parser as an alternative. Any layers that are not supported need to be replaced by custom plugins.
https://github.com/onnx/tensorflow-onnx
https://github.com/onnx/onnx-tensorrt/blob/master/operators.md

Alternatively, you can convert the model to TRT using TF-TRT and serialize the resulting engine to a .plan file, then deserialize that .plan file from C++ (with TensorRT’s C++ API or through the TensorRT Inference Server).
See:
https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html#usage-example
and
https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html#tensorrt-plan

If the issue persists, could you please share your model file so we can help further?

Thanks