Can't accelerate BIDIRECTION rnn(lstm) in TensorRT5.1.5.0

My RNN model uses the TensorFlow API tf.contrib.cudnn_rnn.CudnnLSTM to create a bidirectional RNN. From the docs, I think TensorRT supports this API. (https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#working_with_tf_rnn)

But I don’t know how to set parameters.

Does anyone know how to solve it? Or is there an example for BiRNN?

Thanks.

Here is my env:

  1. Linux version: Ubuntu 18.04
  2. GPU: GTX745
  3. Nvidia driver version: 430.26
  4. CUDA version: 10.1
  5. CUDNN version: 7
  6. TensorRT version: 5.1.5.0
  7. python3.6
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import tensorrt as trt
import common
from tensorflow.python import pywrap_tensorflow

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

class ModelData(object):
    # Static I/O description used when defining the TensorRT network.
    INPUT_NAME = "x"
    # NOTE(review): fed to network.add_input as-is; looks like
    # (channels?, batch?, height, width) = (48, 1, 32, 256) but the axis
    # meaning isn't established here -- confirm against the TF graph.
    INPUT_SHAPE = (48, 1, 32, 256)
    OUTPUT_NAME = "output"
    # Number of classes produced by the final fully-connected layer.
    OUTPUT_SIZE = 355
    # Precision used for the network input tensor.
    DTYPE = trt.float32

def _conv_block(network, weights, x, name, num_maps, ksize=(3, 3)):
    """Add one scale + SAME-padded convolution stage; returns its output tensor.

    The conv bias is taken from the layer's BatchNorm beta.
    NOTE(review): gamma/mean/variance of the BatchNorm are not applied here
    (only a uniform scale precedes the conv) -- confirm the checkpoint was
    exported with BN folded into the conv weights.
    """
    scale = network.add_scale(input=x, mode=trt.ScaleMode.UNIFORM)
    conv = network.add_convolution(input=scale.get_output(0),
                                   num_output_maps=num_maps,
                                   kernel_shape=ksize,
                                   kernel=weights[name + '/weights'],
                                   bias=weights[name + '/BatchNorm/beta'])
    conv.padding_mode = trt.PaddingMode.SAME_UPPER
    return conv.get_output(0)


def _max_pool(network, x, stride):
    """Add a 2x2 SAME-padded max pool with the given stride; returns its output."""
    pool = network.add_pooling(input=x, type=trt.PoolingType.MAX, window_size=(2, 2))
    pool.stride = stride
    pool.padding_mode = trt.PaddingMode.SAME_UPPER
    return pool.get_output(0)


def _add_cnn_backbone(network, weights, x):
    """Add the conv/pool feature extractor and reshape NCHW features into a
    time-major (seq, batch, feature) tensor for the RNN."""
    x = _conv_block(network, weights, x, 'conv1', 64)
    x = _conv_block(network, weights, x, 'conv2', 64)
    x = _max_pool(network, x, stride=(2, 2))
    x = _conv_block(network, weights, x, 'conv3', 128)
    x = _conv_block(network, weights, x, 'conv4', 128)
    x = _max_pool(network, x, stride=(2, 1))
    x = _conv_block(network, weights, x, 'conv5', 256)
    x = _conv_block(network, weights, x, 'conv6', 256)
    x = _max_pool(network, x, stride=(2, 1))
    x = _conv_block(network, weights, x, 'conv7', 512)
    x = _conv_block(network, weights, x, 'conv8', 512)
    # 'pool4' is actually a 4x3 convolution in the checkpoint naming.
    x = _conv_block(network, weights, x, 'pool4', 512, ksize=(4, 3))
    x = _max_pool(network, x, stride=(2, 1))

    # Flatten the spatial dims into a per-timestep feature vector.
    seq = network.add_shuffle(input=x)
    seq.first_transpose = (0, 2, 3, 1)
    seq.reshape_dims = (seq.get_output(0).shape[0], -1, seq.get_output(0).shape[-1])
    # Make the tensor time-major for add_rnn_v2.
    tm = network.add_shuffle(input=seq.get_output(0))
    tm.first_transpose = (1, 0, 2)
    return tm.get_output(0)


def _set_lstm_gates(rnn, pseudo_layer, gate_order, kernel, bias, hidden):
    """Copy one TF CudnnCompatibleLSTMCell kernel/bias into one RNNv2 pseudo-layer.

    TF stores the kernel as [input_size + hidden, 4*hidden] with the
    input-to-hidden rows first and the gate columns in i, c, f, o order.
    TensorRT wants a separate row-major W (input) and R (recurrent) matrix
    per gate, hence the column slice + transpose below.
    """
    input_size = kernel.shape[0] - hidden
    for g, gate in enumerate(gate_order):
        cols = slice(g * hidden, (g + 1) * hidden)
        w = np.ascontiguousarray(kernel[:input_size, cols].T)
        r = np.ascontiguousarray(kernel[input_size:, cols].T)
        rnn.set_weights_for_gate(layer_index=pseudo_layer, gate=gate, is_w=True, weights=w)
        rnn.set_weights_for_gate(layer_index=pseudo_layer, gate=gate, is_w=False, weights=r)
        # TF keeps a single fused bias per gate; give it to the W side and
        # zero the recurrent side so the sum matches.
        b = np.ascontiguousarray(bias[cols])
        rnn.set_bias_for_gate(layer_index=pseudo_layer, gate=gate, is_w=True, bias=b)
        rnn.set_bias_for_gate(layer_index=pseudo_layer, gate=gate, is_w=False,
                              bias=np.zeros_like(b))


def _add_birnn(network, weights, feat):
    """Add the 2-layer bidirectional LSTM and load its checkpoint weights."""
    hidden = 256
    layer_count = 2
    rnn = network.add_rnn_v2(input=feat, layer_count=layer_count, hidden_size=hidden,
                             max_seq_length=48, op=trt.RNNOperation.LSTM)
    # The TF cells apply an input projection, so LINEAR (not SKIP) is the
    # matching input mode; SKIP would drop the input weights entirely.
    rnn.input_mode = trt.RNNInputMode.LINEAR
    rnn.direction = trt.RNNDirection.BIDIRECTION

    # CudnnCompatibleLSTMCell gate column order: input, cell, forget, output.
    gate_order = [trt.RNNGateType.INPUT, trt.RNNGateType.CELL,
                  trt.RNNGateType.FORGET, trt.RNNGateType.OUTPUT]

    base = ('rnn/stack_bidirectional_rnn/cell_%d/bidirectional_rnn/%s/'
            'cudnn_compatible_lstm_cell/%s')
    # A bidirectional RNNv2 exposes 2*layer_count pseudo-layers:
    # index 2*layer is the forward pass, 2*layer + 1 the backward pass.
    for layer in range(layer_count):
        for d, direction in enumerate(('fw', 'bw')):
            kernel = weights[base % (layer, direction, 'kernel')]
            bias = weights[base % (layer, direction, 'bias')]
            _set_lstm_gates(rnn, 2 * layer + d, gate_order, kernel, bias, hidden)
    return rnn.get_output(0)


def _add_classifier(network, weights, rnn_out):
    """Add the final FC + ReLU head and mark the network output."""
    fc_w = weights['rnn/logits/kernel']
    fc_b = weights['rnn/logits/bias']
    # NOTE(review): TF dense kernels are [in, out] while add_fully_connected
    # expects row-major [out, in]; a transpose may be required -- verify
    # against the checkpoint tensor shape.

    # add_fully_connected needs >= 4 trailing dims, so pad with two unit axes.
    shape = rnn_out.shape
    pre = network.add_shuffle(input=rnn_out)
    pre.reshape_dims = (shape[0], shape[1], 1, 1, shape[-1])

    fc = network.add_fully_connected(input=pre.get_output(0),
                                     num_outputs=ModelData.OUTPUT_SIZE,
                                     kernel=fc_w, bias=fc_b)
    act = network.add_activation(input=fc.get_output(0), type=trt.ActivationType.RELU)

    # Squeeze the unit axes back out.
    out = network.add_shuffle(input=act.get_output(0))
    a = act.get_output(0).shape
    out.reshape_dims = (a[0], a[1], a[2])

    out.get_output(0).name = ModelData.OUTPUT_NAME
    network.mark_output(tensor=out.get_output(0))


def populate_network(network, weights):
    """Build the CRNN graph: CNN backbone -> bidirectional LSTM -> FC head.

    Args:
        network: trt.INetworkDefinition being populated.
        weights: dict mapping TF checkpoint variable names to numpy arrays
            (as produced by get_weights).

    Fixes over the previous version: the dedented lines that made the module
    unparseable, and the gate-weight loop that wrote the full fused kernel to
    every gate of pseudo-layers 0/1 only (bidirectional RNNv2 needs per-gate
    slices across 2*layer_count pseudo-layers).
    """
    input_tensor = network.add_input(name=ModelData.INPUT_NAME,
                                     dtype=ModelData.DTYPE,
                                     shape=ModelData.INPUT_SHAPE)
    feat = _add_cnn_backbone(network, weights, input_tensor)
    rnn_out = _add_birnn(network, weights, feat)
    _add_classifier(network, weights, rnn_out)

def get_weights(mod_path='../0912/fix.ckpt'):
    """Load every variable from a TF checkpoint into a {name: ndarray} dict."""
    print('run')
    reader = pywrap_tensorflow.NewCheckpointReader(mod_path)
    return {name: reader.get_tensor(name)
            for name in reader.get_variable_to_shape_map()}

def build_engine(weights):
    """Build a TensorRT engine from checkpoint weights.

    Returns the engine, or None if the build fails (build_cuda_engine
    returns None on error).
    """
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network:
        builder.max_workspace_size = common.GiB(1)
        builder.max_batch_size = 1
        # Populate the network using weights.
        populate_network(network, weights)
        print('network', network)
        # Build exactly once and return that engine -- the previous version
        # called build_cuda_engine twice, doing the expensive build a second
        # time and returning a different object from the one it printed.
        engine = builder.build_cuda_engine(network)
        print('engine', engine)
        return engine

def get_engine(engine_file_path, weights):
    """Return a deserialized engine from disk, or build one from the weights.

    Any failure to load (missing file, corrupt blob, version mismatch) falls
    back to building a fresh engine.
    """
    try:
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            blob = f.read()
            # A plugin factory must be supplied here when deserializing an
            # engine built with IPlugin/IPluginExt layers; none is used here.
            return runtime.deserialize_cuda_engine(serialized_engine=blob,
                                                   plugin_factory=None)
    except Exception:
        # Deliberate best-effort fallback: rebuild from the checkpoint.
        engine = build_engine(weights)
        # Caching the built engine to disk (engine.serialize()) is left
        # disabled for now.
        print('engine', engine)
        return engine

def main():
    """Load weights, obtain an engine, and run a single inference pass."""
    weights = get_weights()
    print('success')
    with get_engine('crnn.engine', weights) as engine:
        # Allocate host/device buffers and a CUDA stream for this engine.
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        with engine.create_execution_context() as context:
            [output] = common.do_inference(context, bindings=bindings,
                                           inputs=inputs, outputs=outputs,
                                           stream=stream)
            print('success')

# Script entry point.
if __name__ == '__main__':
    main()

This is my pb file:

https://drive.google.com/file/d/15Qnyg99Yyj-PhbZHkzVbMQ3KcqVk2d08/view?usp=sharing