TensorRT inference speed is lower than PyTorch

Description

A DGCNN point-cloud classification model (code below) runs at 18 ms per inference in PyTorch on a Jetson Xavier, but after export to ONNX and conversion with trtexec, the TensorRT engine takes 114 ms in FP32 and 91 ms in FP16. Profiling shows the four TopK layers alone account for roughly a third of the runtime.

Environment

TensorRT Version: 8.2.1
GPU Type: Jetson Xavier
Nvidia Driver Version:
CUDA Version: 10.2
CUDNN Version:
Operating System + Version:
Python Version (if applicable):
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):

I have a point cloud classification model, written in PyTorch:

#  Ref: https://github.com/WangYueFt/dgcnn/blob/master/pytorch/model.py
#  Ref: https://github.com/hansen7/OcCo/blob/master/OcCo_Torch/models/dgcnn_cls.py

import torch
import torch.nn as nn
import torch.nn.functional as F

def knn(x, k):
    # Pairwise squared distances via the expansion
    # ||x_i - x_j||^2 = ||x_i||^2 - 2 * x_i . x_j + ||x_j||^2,
    # negated so that topk returns the k *nearest* neighbours.
    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous()
    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx


def get_graph_feature(x, k=20, idx=None, extra_dim=False):
    # For each point, gather its k nearest neighbours and build the edge
    # feature (x_j - x_i, x_i); output is (batch_size, 2*num_dims, num_points, k).
    batch_size, num_dims, num_points = x.size()
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        if extra_dim is False:
            idx = knn(x, k=k)
        else:
            idx = knn(x[:, 6:], k=k)  # idx = knn(x[:, :3], k=k)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx += idx_base
    idx = idx.view(-1)

    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size*num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()

    return feature  # (batch_size, 2 * num_dims, num_points, k)


class get_model(nn.Module):

    def __init__(self, args, num_channel=3, num_class=40, **kwargs):
        super(get_model, self).__init__()
        self.args = args
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(args.emb_dims)

        self.conv1 = nn.Sequential(nn.Conv2d(num_channel*2, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv1d(512, args.emb_dims, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))

        self.linear1 = nn.Linear(args.emb_dims*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, num_class)

    def forward(self, x):
        batch_size = x.size()[0]
        x = get_graph_feature(x, k=self.args.k)
        x = self.conv1(x)
        x1 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x1, k=self.args.k)
        x = self.conv2(x)
        x2 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x2, k=self.args.k)
        x = self.conv3(x)
        x3 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x3, k=self.args.k)
        x = self.conv4(x)
        x4 = x.max(dim=-1, keepdim=False)[0]

        x = torch.cat((x1, x2, x3, x4), dim=1)

        x = self.conv5(x)
        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        x = torch.cat((x1, x2), 1)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
        x = self.dp2(x)
        x = self.linear3(x)
        return x


class get_loss(torch.nn.Module):
    def __init__(self):
        super(get_loss, self).__init__()

    @staticmethod
    def cal_loss(pred, gold, smoothing=True, label_weights=False):
        """Calculate cross entropy loss, apply label smoothing if needed."""
        gold = gold.contiguous().view(-1)

        if smoothing and isinstance(label_weights, bool):
            eps = 0.2
            loss = F.cross_entropy(pred, gold, reduction='mean', label_smoothing=eps)

            # n_class = pred.size()[1]
            # one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)  # (num_points, num_class)
            # one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
            # log_prb = F.log_softmax(pred, dim=1)
            # loss = -(one_hot * log_prb).sum(dim=1).mean()  # ~ F.nll_loss(log_prb, gold)
        elif not isinstance(label_weights, bool) and not smoothing:
            label_weights = torch.as_tensor(label_weights, device="cuda", dtype=torch.float)
            loss = F.cross_entropy(pred, gold, reduction='mean', weight=label_weights)
        elif not isinstance(label_weights, bool) and smoothing:
            eps = 0.2
            label_weights = torch.as_tensor(label_weights, device="cuda", dtype=torch.float)
            loss = F.cross_entropy(pred, gold, reduction='mean',
                                   weight=label_weights, label_smoothing=eps)
        else:
            loss = F.cross_entropy(pred, gold, reduction='mean')

        return loss

    def forward(self, pred, target, label_weights=False):
        return self.cal_loss(pred, target, smoothing=True, label_weights=label_weights)

Model inference time is 18 ms in PyTorch.
Then I convert the model to ONNX and from ONNX to a TensorRT engine through trtexec, so that I can profile it.
The inference time is 114 ms with FP32 and 91 ms with FP16.
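For reference, the export step looks roughly like this (a sketch; the file name, input shape, and opset version are placeholders, not necessarily the exact values used):

import torch

model = get_model(args).eval().cuda()
dummy = torch.randn(1, 3, 1024, device='cuda')  # (batch, channels, num_points); placeholder shape
torch.onnx.export(model, dummy, 'dgcnn_cls.onnx',
                  input_names=['points'], output_names=['logits'],
                  opset_version=13)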
Here is the profiling output of the model:

[03/03/2023-05:54:28] [I] === Profile (37 iterations ) ===
[03/03/2023-05:54:28] [I]                                                                                                           Layer   Time (ms)   Avg. Time (ms)   Time %
[03/03/2023-05:54:28] [I]                                                                                              [HostToDeviceCopy]        1.17           0.0316      0.0
[03/03/2023-05:54:28] [I]                                                                 PWN(68 + (Unnamed Layer* 23) [Shuffle], Pow_21)        0.89           0.0240      0.0
[03/03/2023-05:54:28] [I]                                                                                                    ReduceSum_22        1.04           0.0281      0.0
[03/03/2023-05:54:28] [I]                                                       Reformatting CopyNode for Output Tensor 0 to ReduceSum_22        1.51           0.0409      0.0
[03/03/2023-05:54:28] [I]                                                                                                    Transpose_25        0.26           0.0069      0.0
[03/03/2023-05:54:28] [I]                                                           Reformatting CopyNode for Input Tensor 0 to MatMul_17        3.68           0.0994      0.1
[03/03/2023-05:54:28] [I]                                                           Reformatting CopyNode for Input Tensor 1 to MatMul_17        2.35           0.0636      0.1
[03/03/2023-05:54:28] [I]                                                                                                       MatMul_17       13.00           0.3515      0.4
[03/03/2023-05:54:28] [I]                          PWN(PWN(Neg_23, PWN(PWN(66 + (Unnamed Layer* 20) [Shuffle], Mul_19), Sub_24)), Sub_26)       26.20           0.7081      0.8
[03/03/2023-05:54:28] [I]                                                                                                         TopK_27      288.02           7.7843      8.5
[03/03/2023-05:54:28] [I]                                                                                                        Range_29        1.63           0.0440      0.0
[03/03/2023-05:54:28] [I]                                                                      {ForeignNode[Transpose_36...Transpose_58]}        5.91           0.1598      0.2
[03/03/2023-05:54:28] [I]                                                                                                        Range_88        0.83           0.0225      0.0
[03/03/2023-05:54:28] [I]                                                             Reformatting CopyNode for Input Tensor 0 to Conv_59        4.30           0.1162      0.1
[03/03/2023-05:54:28] [I]                                                                                                         Conv_59       25.55           0.6906      0.8
[03/03/2023-05:54:28] [I]                                                   Reformatting CopyNode for Input Tensor 0 to PWN(LeakyRelu_60)       29.31           0.7923      0.9
[03/03/2023-05:54:28] [I]                                                                                               PWN(LeakyRelu_60)       25.72           0.6950      0.8
[03/03/2023-05:54:28] [I]                                                                                                    ReduceMax_61       15.38           0.4157      0.5
[03/03/2023-05:54:28] [I]                                                               PWN(145 + (Unnamed Layer* 108) [Shuffle], Pow_80)        1.59           0.0429      0.0
[03/03/2023-05:54:28] [I]                                                                                                    ReduceSum_81        1.46           0.0395      0.0
[03/03/2023-05:54:28] [I]                                                                                                    Transpose_84        0.02           0.0007      0.0
[03/03/2023-05:54:28] [I]                                                                                                       MatMul_76       24.27           0.6561      0.7
[03/03/2023-05:54:28] [I]                        PWN(PWN(Neg_82, PWN(PWN(143 + (Unnamed Layer* 105) [Shuffle], Mul_78), Sub_83)), Sub_85)       26.11           0.7058      0.8
[03/03/2023-05:54:28] [I]                                                                                                         TopK_86      288.28           7.7913      8.5
[03/03/2023-05:54:28] [I]                                                                     {ForeignNode[Transpose_95...Transpose_117]}      121.05           3.2715      3.6
[03/03/2023-05:54:28] [I]                                                                                                       Range_147        1.52           0.0410      0.0
[03/03/2023-05:54:28] [I]                                                            Reformatting CopyNode for Input Tensor 0 to Conv_118       80.58           2.1778      2.4
[03/03/2023-05:54:28] [I]                                                                                                        Conv_118       57.20           1.5459      1.7
[03/03/2023-05:54:28] [I]                                                  Reformatting CopyNode for Input Tensor 0 to PWN(LeakyRelu_119)       29.88           0.8075      0.9
[03/03/2023-05:54:28] [I]                                                                                              PWN(LeakyRelu_119)       26.59           0.7188      0.8
[03/03/2023-05:54:28] [I]                                                                                                   ReduceMax_120       15.39           0.4161      0.5
[03/03/2023-05:54:28] [I]                                                              PWN(222 + (Unnamed Layer* 193) [Shuffle], Pow_139)        1.55           0.0418      0.0
[03/03/2023-05:54:28] [I]                                                                                                   ReduceSum_140        1.45           0.0393      0.0
[03/03/2023-05:54:28] [I]                                                                                                   Transpose_143        0.02           0.0005      0.0
[03/03/2023-05:54:28] [I]                                                                                                      MatMul_135       24.27           0.6559      0.7
[03/03/2023-05:54:28] [I]                    PWN(PWN(Neg_141, PWN(PWN(220 + (Unnamed Layer* 190) [Shuffle], Mul_137), Sub_142)), Sub_144)       26.79           0.7242      0.8
[03/03/2023-05:54:28] [I]                                                                                                        TopK_145      288.35           7.7932      8.5
[03/03/2023-05:54:28] [I]                                                                    {ForeignNode[Transpose_154...Transpose_176]}      121.10           3.2729      3.6
[03/03/2023-05:54:28] [I]                                                                                                       Range_206        1.63           0.0442      0.0
[03/03/2023-05:54:28] [I]                                                            Reformatting CopyNode for Input Tensor 0 to Conv_177       80.58           2.1780      2.4
[03/03/2023-05:54:28] [I]                                                                                                        Conv_177       85.97           2.3235      2.5
[03/03/2023-05:54:28] [I]                                                  Reformatting CopyNode for Input Tensor 0 to PWN(LeakyRelu_178)       61.84           1.6715      1.8
[03/03/2023-05:54:28] [I]                                                                                              PWN(LeakyRelu_178)       54.37           1.4695      1.6
[03/03/2023-05:54:28] [I]                                                                                                   ReduceMax_179       31.18           0.8427      0.9
[03/03/2023-05:54:28] [I]                                                              PWN(299 + (Unnamed Layer* 278) [Shuffle], Pow_198)        2.87           0.0777      0.1
[03/03/2023-05:54:28] [I]                                                                                                   ReduceSum_199        2.53           0.0684      0.1
[03/03/2023-05:54:28] [I]                                                                                                   Transpose_202        0.02           0.0006      0.0
[03/03/2023-05:54:28] [I]                                                                                                      MatMul_194       31.82           0.8599      0.9
[03/03/2023-05:54:28] [I]                    PWN(PWN(Neg_200, PWN(PWN(297 + (Unnamed Layer* 275) [Shuffle], Mul_196), Sub_201)), Sub_203)       27.80           0.7514      0.8
[03/03/2023-05:54:28] [I]                                                                                                        TopK_204      288.13           7.7872      8.5
[03/03/2023-05:54:28] [I]                                                                    {ForeignNode[Transpose_213...Transpose_235]}      242.46           6.5531      7.1
[03/03/2023-05:54:28] [I]                                                            Reformatting CopyNode for Input Tensor 0 to Conv_236      162.61           4.3949      4.8
[03/03/2023-05:54:28] [I]                                                                                                        Conv_236      249.61           6.7463      7.3
[03/03/2023-05:54:28] [I]                                                  Reformatting CopyNode for Input Tensor 0 to PWN(LeakyRelu_237)      121.65           3.2879      3.6
[03/03/2023-05:54:28] [I]                                                                                              PWN(LeakyRelu_237)      108.79           2.9401      3.2
[03/03/2023-05:54:28] [I]                                                                                                   ReduceMax_238       64.71           1.7488      1.9
[03/03/2023-05:54:28] [I]                                                                                                        125 copy        3.13           0.0847      0.1
[03/03/2023-05:54:28] [I]                                                                                                        202 copy        2.92           0.0788      0.1
[03/03/2023-05:54:28] [I]                                                                                                        279 copy        5.55           0.1500      0.2
[03/03/2023-05:54:28] [I]                                                                                                        356 copy       10.76           0.2909      0.3
[03/03/2023-05:54:28] [I]                                                                                  (Unnamed Layer* 349) [Shuffle]        0.03           0.0007      0.0
[03/03/2023-05:54:28] [I]  Reformatting CopyNode for Input Tensor 0 to shuffle_between_(Unnamed Layer* 349) [Shuffle]_output_and_Conv_240       15.84           0.4281      0.5
[03/03/2023-05:54:28] [I]                                              shuffle_between_(Unnamed Layer* 349) [Shuffle]_output_and_Conv_240        0.03           0.0007      0.0
[03/03/2023-05:54:28] [I]                                                                                                        Conv_240       82.80           2.2380      2.4
[03/03/2023-05:54:28] [I]             Reformatting CopyNode for Input Tensor 0 to shuffle_after_(Unnamed Layer* 350) [Convolution]_output       25.24           0.6823      0.7
[03/03/2023-05:54:28] [I]                                                         shuffle_after_(Unnamed Layer* 350) [Convolution]_output        0.02           0.0007      0.0
[03/03/2023-05:54:28] [I]                                                                                              PWN(LeakyRelu_241)       20.63           0.5575      0.6
[03/03/2023-05:54:28] [I]                                                                                     squeeze_after_LeakyRelu_241        0.02           0.0005      0.0
[03/03/2023-05:54:28] [I]                                                                                               GlobalMaxPool_242       13.35           0.3608      0.4
[03/03/2023-05:54:28] [I]                                                                                                     Reshape_245        0.22           0.0061      0.0
[03/03/2023-05:54:28] [I]                                                                                           GlobalAveragePool_246       13.68           0.3697      0.4
[03/03/2023-05:54:28] [I]                                                                                                     Reshape_249        0.17           0.0047      0.0
[03/03/2023-05:54:28] [I]                                                                                  (Unnamed Layer* 373) [Shuffle]        0.01           0.0004      0.0
[03/03/2023-05:54:28] [I]                                                          Reformatting CopyNode for Input Tensor 0 to MatMul_251        0.02           0.0006      0.0
[03/03/2023-05:54:28] [I]                                                                                                      MatMul_251        1.42           0.0385      0.0
[03/03/2023-05:54:28] [I]                                                                                              PWN(LeakyRelu_253)        0.22           0.0060      0.0
[03/03/2023-05:54:28] [I]                                                                                                        Gemm_254        0.63           0.0172      0.0
[03/03/2023-05:54:28] [I]                                                  Reformatting CopyNode for Input Tensor 0 to PWN(LeakyRelu_256)        0.02           0.0005      0.0
[03/03/2023-05:54:28] [I]                                                                                              PWN(LeakyRelu_256)        0.21           0.0057      0.0
[03/03/2023-05:54:28] [I]                                                            Reformatting CopyNode for Input Tensor 0 to Gemm_257        0.01           0.0004      0.0
[03/03/2023-05:54:28] [I]                                                                                                        Gemm_257        0.29           0.0080      0.0
[03/03/2023-05:54:28] [I]                                      Reformatting CopyNode for Input Tensor 0 to (Unnamed Layer* 420) [Shuffle]        0.25           0.0069      0.0
[03/03/2023-05:54:28] [I]                                                                                  (Unnamed Layer* 420) [Shuffle]        0.01           0.0004      0.0
[03/03/2023-05:54:28] [I]                                                                                                           Total     3400.34          91.9011    100.0
[03/03/2023-05:54:28] [I] 

The TopK operations take about 34% of the inference time, approx. 30 ms in total. But even if I resolve the issue with these operations, inference would still take about 60 ms, which is more than 3 times slower than PyTorch. How can I work out where the problem is? I have converted tons of models for image classification, detection, and segmentation, and all of them performed much better than their PyTorch counterparts.
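For completeness, the PyTorch baseline can be reproduced with a timing loop along these lines (a sketch; the input shape is a placeholder, and torch.cuda.synchronize() ensures the GPU work is actually included in the measurement):

import time
import torch

model = get_model(args).eval().cuda()
x = torch.randn(1, 3, 1024, device='cuda')  # placeholder input

with torch.no_grad():
    for _ in range(10):  # warm-up
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        model(x)
    torch.cuda.synchronize()
    print((time.time() - start) / 100 * 1000, 'ms per inference')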

P.S. I cannot move to a newer TensorRT version because of dependencies.

Hi,

Please share the model, script, profiler, and performance output, if not shared already, so that we can help you better.

Alternatively, you can try running your model with the trtexec command.
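For example, a per-layer profile like the one above can be generated with a command along these lines (the ONNX file name is a placeholder):

trtexec --onnx=model.onnx --fp16 --dumpProfile --separateProfileRun --iterations=100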

While measuring the model performance, make sure you consider the latency and throughput of the network inference, excluding the data pre- and post-processing overhead.
Please refer to the below links for more details:

Thanks!

I already tried trtexec (see my post above), and the model script is posted.