Wrong inference results in TensorRT after converting a Keras model to TensorRT

I ran into problems converting a tf.keras model to a TensorRT model, and I believe the problem lies in the TensorRT conversion step.

According to NVIDIA’s TensorRT guide, the conversion from tf.keras to TensorRT goes like this:
(The first two steps are generic and run in our own cloud; the last two steps are conversions done on the Jetson Nano.)

tf.keras model → converted to frozen_model → converted to UFF → converted to TensorRT

However, the final converted model gives inaccurate results at inference time.

I use TF 2.0, and the model was also generated with TF 2.0; the TF version on the Jetson Nano is 1.x. However, the Keras-to-pb conversion uses the tf.compat.v1 compatibility package, which is equivalent to converting a v2 Keras model into a v1 frozen model. If the problem were in the Keras data structures, the predictions at this step would already be wrong. But in my actual test, the Keras and pb models give the same predictions. I therefore conclude that the inaccuracy of the final conversion has nothing to do with TF 2.

My test results:
    Tf.keras model √
    Convert to frozen_model √
    Convert to uff (converted successfully)
    Converted to tensorrt (converted successfully, but the result is not accurate)
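
A minimal sketch of how the outputs of two stages could be compared numerically for the same input, for example the frozen model against the TensorRT engine. The helper name compare_outputs and the tolerance are assumptions made for illustration; they are not part of the original code.

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-4):
    """Compare two model outputs that were produced from the same input.

    `reference` could be the Keras / frozen-graph prediction and
    `candidate` the TensorRT output. `atol` is an illustrative
    tolerance, not a value taken from the thread.
    """
    reference = np.asarray(reference, dtype=np.float32).ravel()
    candidate = np.asarray(candidate, dtype=np.float32).ravel()
    if reference.shape != candidate.shape:
        raise ValueError("outputs have different sizes: %s vs %s"
                         % (reference.shape, candidate.shape))
    # Largest element-wise deviation between the two outputs
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    return {"max_abs_diff": max_abs_diff, "close": bool(max_abs_diff <= atol)}
```

Running this after each conversion step narrows down the first stage whose output diverges from the Keras prediction.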

Here is the code:

Model

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
reshape (Reshape)            (None, 100, 400, 1)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 100, 400, 32)      64        
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 50, 200, 32)       0         
_________________________________________________________________
dropout (Dropout)            (None, 50, 200, 32)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 50, 200, 64)       2112      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 25, 100, 64)       0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 25, 100, 64)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 25, 100, 64)       4160      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 13, 50, 64)        0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 13, 50, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 41600)             0         
_________________________________________________________________
dense (Dense)                (None, 1024)              42599424  
_________________________________________________________________
dropout_3 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 203)               208075    
=================================================================
Total params: 42,813,835
Trainable params: 42,813,835
Non-trainable params: 0
_________________________________________________________________
None

convertTRT

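import tensorrt as trt
import uff

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

# Convert the frozen TensorFlow graph to UFF
uff_model = uff.from_tensorflow_frozen_model('frozen_model.pb', ['dense_1/BiasAdd'], output_filename='tmp.uff')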

with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
    builder.max_workspace_size = 1 << 28
    builder.max_batch_size = 1

    parser.register_input('reshape_input', (1, 40000))
    parser.register_output('dense_1/BiasAdd')
    parser.parse('tmp.uff', network)
    engine = builder.build_cuda_engine(network)

    buf = engine.serialize()

    with open('model.bin', 'wb') as f:
        f.write(buf)

inference

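import cv2 as cv
import numpy as np
import pycuda.autoinit  # creates the CUDA context
import pycuda.driver as cuda
import tensorrt as trt

# initialize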
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(TRT_LOGGER, '')
runtime = trt.Runtime(TRT_LOGGER)

# create engine
with open('model.bin', 'rb') as f:
    buf = f.read()
    engine = runtime.deserialize_cuda_engine(buf)

# create buffer
host_inputs  = []
cuda_inputs  = []
host_outputs = []
cuda_outputs = []
bindings = []
stream = cuda.Stream()

for binding in engine:
    size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
    host_mem = cuda.pagelocked_empty(size, np.float32)
    cuda_mem = cuda.mem_alloc(host_mem.nbytes)

    bindings.append(int(cuda_mem))
    if engine.binding_is_input(binding):
        host_inputs.append(host_mem)
        cuda_inputs.append(cuda_mem)
    else:
        host_outputs.append(host_mem)
        cuda_outputs.append(cuda_mem)
context = engine.create_execution_context()

ori = cv.imread('test.jpg')
image = cv.cvtColor(ori, cv.COLOR_BGR2RGB)
image = convert2gray(image)
image = image.flatten() / 255.0
image = np.expand_dims(image, axis = 0)
np.copyto(host_inputs[0], image.ravel())

cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
context.execute_async(bindings=bindings, stream_handle=stream.handle)
cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
stream.synchronize()

# [7 * 29] words embedding 
output = host_outputs[0]
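
A minimal sketch of how the 203 output values might be decoded, assuming the "[7 * 29]" comment means 7 positions with 29 class scores each and that the predicted class per position is the arg-max. This layout is an assumption based only on that comment; the constants and the decode_output helper are hypothetical.

```python
import numpy as np

NUM_POSITIONS = 7   # assumption, taken from the "[7 * 29]" comment
NUM_CLASSES = 29    # assumption, taken from the "[7 * 29]" comment


def decode_output(flat_output):
    """Reshape the flat 203-element output and pick the best class per position."""
    scores = np.asarray(flat_output, dtype=np.float32)
    scores = scores.reshape(NUM_POSITIONS, NUM_CLASSES)
    # One class index per position
    return scores.argmax(axis=1).tolist()
```

Mapping the indices back to characters needs the label table used during training, which is not shown in the thread.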

Model file:
https://drive.google.com/drive/folders/1SbobJWI9VJ-i4Bfgv5Jotn9eS0SIhXly?usp=sharing

Hi,

You can find the compatibility information here:
https://docs.nvidia.com/deeplearning/sdk/tensorrt-release-notes/tensorrt-5.html#tensorrt-5

For TensorRT v5.1.x, you will need to use TensorFlow v1.12.0. (v1.13.0 is also good.)

How did you verify your v1.x model? Did you use Keras with the v1.x TF backend or with v2.0 TF?
By the way, is it possible to serialize your model directly with the v1.x API?

Thanks.

Hi,

The model was built with the TF v2.0 backend and frozen with TF v2.0’s v1.x compat package, which is fully equivalent to the v1 version.

I verified the frozen model with TF v2.0’s v1.x compat package, and it works well.

I converted the frozen model to a TensorRT model on the Jetson Nano, but the inference results are not the same as those of the frozen model. The TensorFlow version is v1.13, and the TensorRT version is 5.0.6.3.


Message update:

I retested the same frozen model with TF v1.14, and it also works well. (To avoid confusion, I will not upload this one to the drive.)


Message update:

  1. I froze the model with TF v1.13.1 on a PC and converted the model on the Jetson Nano.
  2. I tested the frozen model on the Jetson Nano; it works well and gives accurate results.
  3. I tested the TensorRT model on the Jetson Nano; it does not work well and gives wrong results.

I will upload the updated model files to the drive.


Test model file:
Carlicense - Google Drive

Thanks.

Hi,

Please make sure that you configure the TensorRT logger like this when converting the frozen graph:

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

Afterwards run the conversion script again and check if you see the following line:

DEBUG: convert reshape to flatten node

Please report back if this line is present in the converter output.

Hi,

Yes, it said:

Using output node dense_1/BiasAdd
Converting to UFF graph
DEBUG: convert reshape to flatten node
No. nodes: 40
UFF Output written to tmp.uff

Should I remove the reshape layer in the model?

Hi,

The debug message is misleading. The Flatten layer is the problem (reference: https://devtalk.nvidia.com/default/topic/1036228/tensorrt/fc-layer-unsupported-with-trt/).

In short: there is (at least to my understanding) a known bug in the TensorRT / UFF parser that affects networks with single-channel input (e.g. greyscale image input) that use the Flatten layer. You need to replace the Flatten layer with a Reshape layer that does the same job. To do this you need to retrain your network (sorry) - do not attempt any graph hacks with the graph surgeon (gs)!

If I may, I would like to give you some advice, because I was also affected by this and it cost me several hours to figure out.

I assume your model definition code around the Flatten layer looks similar to this:

y = Dropout(0.5)(y)
y = Flatten()(y)
y = Dense(units=1024, activation='relu')(y)

Now replace the Flatten layer with a Reshape layer:

y = Dropout(0.5)(y)
y = Reshape([y._keras_shape[1] * y._keras_shape[2] * y._keras_shape[3]])(y)
y = Dense(units=1024, activation='relu')(y)

This should do the trick. Unfortunately TensorRT is not able to parse simpler statements like Reshape(-1).
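
A minimal sketch of the size computation behind that Reshape, written in plain Python so it does not depend on a particular Keras flavour. The helper name flat_dim is hypothetical. As far as I know, `_keras_shape` is an attribute of standalone Keras tensors and is not available in tf.keras; there the static shape can be read with tf.keras.backend.int_shape(y) instead.

```python
import math


def flat_dim(shape):
    """Return the product of all non-batch dimensions of a static shape.

    `shape` is a tuple such as (None, 13, 50, 64), where the first
    entry is the batch dimension.
    """
    dims = shape[1:]
    if any(d is None for d in dims):
        raise ValueError("all non-batch dimensions must be known: %r" % (shape,))
    return math.prod(dims)


# Shape of the last Dropout output in the model summary of this thread
units = flat_dim((None, 13, 50, 64))

# Possible use with tf.keras (assumption, not taken from the thread):
# y = Reshape([flat_dim(tf.keras.backend.int_shape(y))])(y)
```

For the model in this thread the result is 13 * 50 * 64 = 41600, which matches the Flatten output shape in the summary.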

I have some questions about your code:

  1. Why is there a Reshape layer present at the top of your network? I guess it has something to do with the way you are feeding the image to the network. I assume your image has the shape (100, 400, 1) after it is loaded. You do not need to feed a 1D array to the network. Just remove the Reshape layer and set the input shape of the Conv2D layer to (100, 400, 1).

  2. Why are you doing

ori = cv.imread('test.jpg')
image = cv.cvtColor(ori, cv.COLOR_BGR2RGB)
image = convert2gray(image)

to load a greyscale image? Why not do this instead:

ori = cv.imread('test.jpg', 0)

Btw. is there really no need to resize the input image?

  3. Why are you flattening and raveling the NumPy array?

image = image.flatten() / 255.0
image = np.expand_dims(image, axis = 0)
np.copyto(host_inputs[0], image.ravel())

Does this have something to do with my first question? From my perspective you can replace this code (assuming you have removed the Reshape layer at the top of your network) with the following:

img_np = np.asarray(image) / 255.
img_np = img_np.transpose((2, 0, 1))
img_np = img_np.astype(np.float32)
np.copyto(host_inputs[0], img_np.ravel())

I would recommend that you first try to replace the Flatten layer. You can try the other things afterwards. Just keep in mind that TensorRT expects the network input to be in (N, C, H, W) order even if the input did not look like that in Keras.
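
A minimal sketch of the (H, W, C) to (C, H, W) conversion using NumPy only. The helper name to_chw and the stand-in array are assumptions for illustration; in the real script the array would come from the loaded image.

```python
import numpy as np


def to_chw(image_hwc):
    """Convert an (H, W, C) image to a contiguous float32 (C, H, W) array."""
    chw = np.transpose(image_hwc, (2, 0, 1))
    return np.ascontiguousarray(chw, dtype=np.float32)


# Stand-in for a 100 x 400 greyscale image with values scaled to [0, 1)
image = np.arange(100 * 400, dtype=np.float32).reshape(100, 400, 1) / 40000.0
chw = to_chw(image)
flat = chw.ravel()  # this is what gets copied into the host input buffer
```

For a single-channel image the flattened buffer is identical before and after the transpose, so the element order only changes for inputs with more than one channel.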

Please report back.

Hi,

I changed the Flatten layer to a Reshape layer, retrained the model, and exported it to TensorRT. Thank god, it works!

Here is the final model.

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 100, 400, 32)      64        
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 50, 200, 32)       0         
_________________________________________________________________
dropout (Dropout)            (None, 50, 200, 32)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 50, 200, 64)       2112      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 25, 100, 64)       0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 25, 100, 64)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 25, 100, 64)       4160      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 13, 50, 64)        0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 13, 50, 64)        0         
_________________________________________________________________
reshape (Reshape)            (None, 41600)             0         
_________________________________________________________________
dense (Dense)                (None, 1024)              42599424  
_________________________________________________________________
dropout_3 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 203)               208075    
=================================================================
Total params: 42,813,835
Trainable params: 42,813,835
Non-trainable params: 0
_________________________________________________________________

And about your questions:

  1. The model’s dataset was generated with the convert2gray function, which is np.mean(img, -1).

  2. My test image is exactly the same size as the images in the training dataset.

  3. My final inference code is below.

ori = cv.imread('test.jpg')
image = cv.cvtColor(ori, cv.COLOR_BGR2RGB)
image = convert2gray(image)
image = image / 255.0
image = image.reshape((100, 400, 1))
image = image.transpose(2, 0, 1)
np.copyto(host_inputs[0], image.ravel())
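
A minimal sketch of why the greyscale conversion at inference should match the one used for training. convert2gray is written as described above (np.mean(img, -1)); luminance_gray is a hypothetical helper that applies the weighted formula OpenCV documents for RGB to grey conversion (0.299 R + 0.587 G + 0.114 B). Reading the file directly as greyscale may use yet another conversion inside the image codec, so treat the comparison as illustrative.

```python
import numpy as np


def convert2gray(img):
    """Greyscale as the plain mean over the channel axis (as in this thread)."""
    return np.mean(img, -1)


def luminance_gray(img_rgb):
    """Greyscale with the weighted formula documented for cv.COLOR_RGB2GRAY."""
    weights = np.array([0.299, 0.587, 0.114])
    return img_rgb @ weights


# A single pure-red pixel gives clearly different grey values
pixel = np.array([[[255.0, 0.0, 0.0]]])
mean_value = convert2gray(pixel)[0, 0]        # 255 / 3
weighted_value = luminance_gray(pixel)[0, 0]  # 255 * 0.299
```

Because the two methods disagree on coloured pixels, a network trained on the mean-based greyscale can see shifted input values if the image is loaded with a different conversion.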

Hi,

How did you add the Reshape layer to the model? Can you please provide the steps?
I am getting invalid predictions when I convert from Keras to TensorRT.
For me, adding a Reshape layer causes an issue at the 4th line:

model = Sequential()
model.add(ResNet50(include_top=False, pooling='max', weights=RESNET_WEIGHTS_PATH))
model.add(Dense(NUM_CLASSES, activation='softmax', name='output_tensor'))
model.compile(optimizer='sgd', loss='mse', metrics=["accuracy"])

Thanks,
gp

Hi gp,

Please open a new topic for your issue if you need support. Thanks.

Sure, thanks.