Bug in LayerNormPlugin

Description

The output of LayerNormPlugin has a large error compared to the output of torch.nn.LayerNorm.

Environment

Official Docker container 22.12

Relevant Files

related files: https://cloud.tsinghua.edu.cn/f/40c33b2aeb2347678cea/?dl=1

Steps To Reproduce

import os
from polygraphy.backend.trt import TrtRunner
import torch
import numpy as np

# The class must be defined here so that torch.load can unpickle the saved model.
class MyLayerNorm(torch.nn.LayerNorm):
    def __init__(self, dim, eps=1e-5):
        super().__init__(dim, eps)
    def forward(self, x):
        # Rearrange the input to shape (2, 16, 256, 256); LayerNorm then runs over the last axis.
        x = x.view(2, 256, 4, 16, 4, 16)
        x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
        x = x.view(2, 16, 256, 256)
        return super().forward(x)

# Reference output from PyTorch.
input_0, model = torch.load('LN_bug.pt')
with torch.no_grad():
    outputs_pt = model(input_0).numpy()

feed_dict = {'input_0': input_0.numpy()}

import tensorrt as trt

# Register the built-in TensorRT plugins (including LayerNormPlugin).
TRT_LOGGER = trt.Logger()
trt.init_libnvinfer_plugins(TRT_LOGGER, '')

def load_engine(engine_file_path):
    # Deserialize a previously built engine from disk.
    assert os.path.exists(engine_file_path)
    print("Reading engine from file {}".format(engine_file_path))
    with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

# Build an FP16 engine from the ONNX model, then run it on the same input.
os.system('trtexec --onnx=LN_bug.onnx --saveEngine=LN_bug.trt --fp16 --buildOnly')
engine = load_engine('LN_bug.trt')
with TrtRunner(engine) as runner:
    outputs_trt = runner.infer(feed_dict)

# Largest element-wise difference between the PyTorch and TensorRT outputs.
print('max error', np.abs(outputs_pt-outputs_trt['output_0']).max())

The script reports a maximum absolute error greater than 11.
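To narrow down where such an error comes from, it can help to check both outputs against an independent reference and to locate the worst element. The following is a minimal sketch in NumPy, not part of the original report: `layer_norm_reference` and `error_summary` are hypothetical helper names, and the reference assumes normalization over the last axis with optional per-channel `weight` and `bias`, as torch.nn.LayerNorm does.

```python
import numpy as np


def layer_norm_reference(x, weight=None, bias=None, eps=1e-5):
    """LayerNorm over the last axis, computed in float64 (hypothetical helper)."""
    x = np.asarray(x, dtype=np.float64)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance (ddof=0)
    y = (x - mean) / np.sqrt(var + eps)
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y


def error_summary(reference, candidate):
    """Report the size and location of the largest element-wise difference."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    err = np.abs(reference - candidate)
    worst = np.unravel_index(np.argmax(err), err.shape)
    return {
        "max_abs": float(err.max()),
        "mean_abs": float(err.mean()),
        "worst_index": tuple(int(i) for i in worst),
        "reference_at_worst": float(reference[worst]),
        "candidate_at_worst": float(candidate[worst]),
    }
```

In the script above, `error_summary(outputs_pt, outputs_trt['output_0'])` would show whether the error is concentrated in a few elements or spread across the whole tensor. For normalized values of order 1, FP16 rounding alone would normally be expected to give errors far smaller than 11.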

Hi,
Please refer to the links below on custom plugin implementation and samples:

While the IPluginV2 and IPluginV2Ext interfaces are still supported for backward compatibility with TensorRT 5.1 and 6.0.x respectively, we recommend that you write new plugins or refactor existing ones to target the IPluginV2DynamicExt or IPluginV2IOExt interfaces instead.

Thanks!

Well… I think officially provided plugins should at least be bug-free…

Hi @lqs1,
You should export your model with opset 17 in order to use the new normalization layer in TensorRT 8.6.

Thanks