TensorRT 8.2.1 fails to convert LSTM model

Description

When I’m using the latest TensorRT (8.2.1) to convert an OCR model, an assertion fails inside the Myelin compiler. After some digging, I found that the problem is caused by the LSTM operator in the model. The error message:

/root/gpgpu/MachineLearning/myelin/src/compiler/optimizer/formats.cpp:3052: bool myelin::ir::no_data_move(const myelin::tensor_descriptor_t*, const std::vector<int>&): Assertion `perm[i] >= 0 && perm[i] < (int) out->get_const_dimensions().size()' failed.

Environment

TensorRT Version: 8.2.1.8
GPU Type: Tesla T4
Nvidia Driver Version: 440.33.01
CUDA Version: 10.2
CUDNN Version: 8.2.1
Operating System + Version: CentOS 7
Python Version (if applicable): 3.7
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.7.0
Baremetal or Container (if container which image + tag):

Relevant Files

Exported model lstm.onnx (99.0 KB)

Full error log convert.log (15.5 KB)

Steps To Reproduce

Minimal steps to reproduce the bug:

  1. Export the lstm.onnx model using PyTorch 1.7.0 (newer PyTorch versions export a slightly different graph structure, but the error still occurs; see the op-chain sketch after step 2):
import torch
import torch.nn as nn
import numpy as np


class Model(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Model, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)

    def forward(self, input):
        # the permute and squeeze steps are copied from the original OCR
        # model and are needed to reproduce the bug
        input = input.permute((0, 3, 1, 2)).squeeze(3)
        recurrent, _ = self.rnn(input)
        return recurrent


batch_size = 10
time_step = 16
input_size = 64
hidden_size = 32

# random input of shape (batch, input_size, 1, time_step), as in the original OCR model
data = torch.FloatTensor(np.random.rand(batch_size, input_size, 1, time_step))
model = Model(input_size, hidden_size)
torch.onnx.export(model, data, "lstm.onnx", input_names=['data'],
                  export_params=True, opset_version=10, verbose=True)
  2. Convert lstm.onnx with TensorRT using this Python script:
import pycuda.autoinit
import tensorrt as trt
import onnx

logger = trt.Logger(trt.Logger.VERBOSE)
builder = trt.Builder(logger)
network = builder.create_network(1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

model = onnx.load('lstm.onnx')

shape = (10, 64, 1, 16)

if not parser.parse(model.SerializeToString()):
    error = parser.get_error(0)
    msg = "While parsing node number %i:\n" % error.node()
    msg += ("%s:%i In function %s:\n[%i] %s" %
            (error.file(), error.line(), error.func(),
             error.code(), error.desc()))
    raise RuntimeError(msg)

config = builder.create_builder_config()
config.max_workspace_size = 1024 << 20  # 1 GiB

# static optimization profile matching the exported input shape
profile = builder.create_optimization_profile()
profile.set_shape("data", shape, shape, shape)
config.add_optimization_profile(profile)

# this produces the error
engine = builder.build_serialized_network(network, config)

with open('lstm.trt', 'wb') as f:
    f.write(bytes(engine))
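
(For context: as far as I can tell, the permute/squeeze in forward() plus batch_first=True make the exporter emit a Transpose -> Squeeze -> Transpose chain in front of the ONNX LSTM node, since ONNX LSTM expects seq-major input. A minimal sketch to confirm the op chain:)

import onnx

m = onnx.load('lstm.onnx')
# print the operator types of the exported graph in order
print([n.op_type for n in m.graph.node])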

Hi,
We request you to share the ONNX model and the script, if not shared already, so that we can assist you better.
Alongside, you can try a few things:

  1. Validate your model with the below snippet:

check_model.py

import onnx

filename = "lstm.onnx"  # path to your ONNX model
model = onnx.load(filename)
onnx.checker.check_model(model)
  2. Try running your model with the trtexec command:
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/trtexec
In case you are still facing the issue, we request you to share the trtexec --verbose log for further debugging.
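
For example, on the attached model:

trtexec --onnx=lstm.onnx --verbose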
Thanks!

  1. ONNX checker

    The model passes the ONNX checker. In fact, the same model file converts successfully with TensorRT 7.1.3.4.

  2. Run with the trtexec command

    I ran trtexec --onnx=lstm.onnx --verbose; the error log is almost identical to the convert.log attached above.

@NVES Any updates?

I get the same error log as you.

Managed to work around the problem by rewriting the ONNX graph: drop the LSTM's optional inputs, fold the Transpose that sits right before the LSTM into the first Transpose, and feed the Squeeze output directly into the LSTM:

import onnx
import onnx_graphsurgeon as gs

prefix = 'lstm'

graph = gs.import_onnx(onnx.load(prefix + '.onnx'))

# find the LSTM node and drop its optional inputs
# (sequence_lens, initial_h, initial_c)
lstm = [n for n in graph.nodes if 'LSTM' in n.name][0]
lstm.inputs = lstm.inputs[:4]

# walk upstream from the LSTM input:
# Transpose (trans1) -> Squeeze -> Transpose (trans2) -> LSTM
trans2 = lstm.inputs[0].inputs[0]
squeeze = trans2.inputs[0].inputs[0]
trans1 = squeeze.inputs[0].inputs[0]

# fold trans2 into trans1 and feed the Squeeze output straight into
# the LSTM; cleanup() then removes the now-dangling trans2
trans1.attrs['perm'] = [3, 0, 1, 2]
lstm.inputs[0] = squeeze.outputs[0]

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), prefix + '-fix.onnx')
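
(To sanity-check the rewrite before rebuilding the engine, a minimal sketch, assuming onnxruntime is installed:)

import numpy as np
import onnxruntime as ort

# run the original and the fixed model on the same random input
data = np.random.rand(10, 64, 1, 16).astype(np.float32)
out_orig = ort.InferenceSession('lstm.onnx').run(None, {'data': data})[0]
out_fix = ort.InferenceSession('lstm-fix.onnx').run(None, {'data': data})[0]
print(np.abs(out_orig - out_fix).max())  # should be ~0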


Awesome, that works! Thank you!

Awesome