TensorRT int8 slower than FP16 due to reformat layer

Description

The int8 TensorRT engine built from the small test network below runs slower than the FP16 engine built from the same ONNX model; per-layer profiling points to reformat layers as the cause.

Environment

TensorRT Version: 10.2.0.19
GPU Type: RTX 3090
Nvidia Driver Version: 530.30.02
CUDA Version: 11.3
CUDNN Version: 8.2
Operating System + Version: Ubuntu 20.04.2 LTS
Python Version (if applicable): 3.8.5
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.10
Baremetal or Container (if container which image + tag):

Relevant Files

Steps To Reproduce

1. Convert the PyTorch model to ONNX

import torch
import torch.nn as nn

class test_net(nn.Module):
    """Stem conv followed by three conv+ReLU layers and a residual add."""
    def __init__(self, in_channels, out_channels):
        super(test_net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, bias=True, stride=1, padding=1)
        self.conv2_1 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, bias=True, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, bias=True, stride=1, padding=1)
        self.conv2_3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, bias=True, stride=1, padding=1)
        self.act = nn.ReLU()
    def forward(self, x):
        x1 = self.conv1(x)
        x2_1 = self.conv2_1(x1)
        x2_1 = self.act(x2_1)
        x2_2 = self.conv2_2(x2_1)
        x2_2 = self.act(x2_2)
        x2_3 = self.conv2_3(x2_2)
        x2_3 = self.act(x2_3)
        output = x1 + x2_3  # residual add: stem output + last activation
        # output = x2_3  # variant without the residual add
        return output

model = test_net(3, 16)
# NCHW input: batch 1, 3 channels, height 3008, width 4000.
input_patch = torch.rand((1, 3, 3008, 4000)).cuda()
output_onnx_path = 'deploy_models/testnet_testnet_1x3x3008x4000_fp32.onnx'
model.to('cuda:0')
model.eval()
torch.onnx.export(model, input_patch, output_onnx_path,
                  input_names=['input'],
                  output_names=['output'],
                  # In NCHW layout, dim 2 is height and dim 3 is width.
                  dynamic_axes={'input': {2: 'inp_height',
                                          3: 'inp_width'},
                                'output': {2: 'out_height',
                                           3: 'out_width'}},
                  opset_version=14)
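
Before building engines, the export can be sanity-checked with the onnx package (a minimal sketch; the path matches the export above):

import onnx

# Load the exported model and run ONNX's structural checker; this
# raises if the graph is malformed.
onnx_model = onnx.load('deploy_models/testnet_testnet_1x3x3008x4000_fp32.onnx')
onnx.checker.check_model(onnx_model)
print('ONNX export looks structurally valid.')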

2. Convert the FP32 ONNX model to an FP16 TRT engine

/data1/tensorrt/TensorRT-10.2.0.19/bin/trtexec   --fp16 --onnx=deploy_models/testnet_testnet_1x3x3008x4000_fp32.onnx  --shapes=input:1x3x3008x4000  --builderOptimizationLevel=5 --useCudaGraph  --saveEngine=deploy_models/testnet_testnet_1x3x3008x4000_fp32_direct_fp16.engine --verbose --exportProfile=deploy_models/testnet_testnet_1x3x3008x4000_fp32_direct_fp16.engine.profile.json  --exportLayerInfo=deploy_models/testnet_testnet_1x3x3008x4000_fp32_direct_fp16.engine.graph.json --profilingVerbosity=detailed  --inputIOFormats=fp32:chw --outputIOFormats=fp32:chw

2.1. Visualize the FP16 TRT engine with trex
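
For reference, a minimal sketch of how I load the exported JSON files with trex (assuming the EnginePlan / graphing API of NVIDIA's TensorRT Engine Explorer; names such as 'latency.avg_time' may differ between trex releases):

from trex import EnginePlan, to_dot, layer_type_formatter, render_dot

prefix = 'deploy_models/testnet_testnet_1x3x3008x4000_fp32_direct_fp16.engine'

# Load the layer graph and per-layer timings exported by trtexec above.
plan = EnginePlan(f'{prefix}.graph.json', f'{prefix}.profile.json')

# plan.df is a pandas DataFrame with one row per engine layer.
print(plan.df[['Name', 'type', 'latency.avg_time']])

# Render the engine graph to SVG for inspection.
render_dot(to_dot(plan, layer_type_formatter), 'testnet_fp16', 'svg')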

2.2. Run the FP16 TRT engine

/data1/tensorrt/TensorRT-10.2.0.19/bin/trtexec --loadEngine=deploy_models/testnet_testnet_1x3x3008x4000_fp32_direct_fp16.engine

[10/11/2024-17:19:48] [I] GPU Compute Time: min = 7.23355 ms, max = 8.04761 ms, mean = 7.60943 ms, median = 7.61963 ms, percentile(90%) = 7.62775 ms, percentile(95%) = 7.63086 ms, percentile(99%) = 8.04761 ms

3. Convert the FP32 ONNX model to an int8 TRT engine

/data1/tensorrt/TensorRT-10.2.0.19/bin/trtexec   --int8 --onnx=deploy_models/testnet_testnet_1x3x3008x4000_fp32.onnx  --shapes=input:1x3x3008x4000  --builderOptimizationLevel=5 --useCudaGraph  --saveEngine=deploy_models/testnet_testnet_1x3x3008x4000_fp32_direct_int8.engine --verbose --exportProfile=deploy_models/testnet_testnet_1x3x3008x4000_fp32_direct_int8.engine.profile.json  --exportLayerInfo=deploy_models/testnet_testnet_1x3x3008x4000_fp32_direct_int8.engine.graph.json --profilingVerbosity=detailed  --inputIOFormats=int8:chw --outputIOFormats=int8:chw

3.1. Visualize the int8 TRT engine with trex
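
The same trex snippet from step 2.1 applies here; pointing EnginePlan at the _direct_int8 .graph.json/.profile.json files shows where the extra Reformat layers sit.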

3.2. Run the int8 TRT engine

/data1/tensorrt/TensorRT-10.2.0.19/bin/trtexec --loadEngine=deploy_models/testnet_testnet_1x3x3008x4000_fp32_direct_int8.engine

[10/11/2024-17:20:23] [I] GPU Compute Time: min = 10.1407 ms, max = 10.4663 ms, mean = 10.2806 ms, median = 10.2698 ms, percentile(90%) = 10.3618 ms, percentile(95%) = 10.37 ms, percentile(99%) = 10.4028 ms

4. Problems

The int8 TensorRT engine is slower than the FP16 one, mainly because of the reformat layers: each individual convolution is slightly faster in int8 than in FP16, but the reformat layers add roughly 1 ms in total. What can I do to eliminate these reformats?
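
To quantify the reformat cost, I sum the per-layer times from the profile JSON exported by --exportProfile (a minimal sketch; the "name"/"averageMs" record fields are assumed from the trtexec JSON layout and may differ by TensorRT version):

import json

profile_path = 'deploy_models/testnet_testnet_1x3x3008x4000_fp32_direct_int8.engine.profile.json'
with open(profile_path) as f:
    records = json.load(f)

reformat_ms = 0.0
total_ms = 0.0
for rec in records:
    # Skip the header record (e.g. {"count": N}); layer records are
    # assumed to carry "name" and "averageMs" fields.
    if 'name' not in rec:
        continue
    avg = float(rec.get('averageMs', 0.0))
    total_ms += avg
    if 'reformat' in rec['name'].lower():
        reformat_ms += avg

print(f'reformat layers: {reformat_ms:.3f} ms of {total_ms:.3f} ms per inference')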