Description
A DGCNN-style point-cloud classification model runs in 18 ms in PyTorch, but the TensorRT engine built from its ONNX export takes 114 ms (FP32) / 91 ms (FP16) on Jetson Xavier. Per-layer profiling shows four TopK layers consuming ~34% of the time, but even without them the engine would still be roughly 3x slower than PyTorch.
Environment
TensorRT Version: 8.2.1
GPU Type: Jetson Xavier
Nvidia Driver Version:
CUDA Version: 10.2
CUDNN Version:
Operating System + Version:
Python Version (if applicable):
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):
I have a point-cloud classification model written in PyTorch:
# Ref: https://github.com/WangYueFt/dgcnn/blob/master/pytorch/model.py
# Ref: https://github.com/hansen7/OcCo/blob/master/OcCo_Torch/models/dgcnn_cls.py
import torch
import torch.nn as nn
import torch.nn.functional as F

def knn(x, k):
    # Pairwise squared distances via ||xi - xj||^2 = ||xi||^2 - 2*xi.xj + ||xj||^2,
    # negated so that topk picks the k nearest neighbours.
    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous()
    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx
def get_graph_feature(x, k=20, idx=None, extra_dim=False):
    batch_size, num_dims, num_points = x.size()
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        if extra_dim is False:
            idx = knn(x, k=k)
        else:
            idx = knn(x[:, 6:], k=k)  # idx = knn(x[:, :3], k=k)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Flatten the neighbour indices so the whole batch can be gathered in one indexing op.
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx += idx_base
    idx = idx.view(-1)
    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    # Edge features: concatenate (neighbour - centre) with the centre point itself.
    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature  # (batch_size, 2 * num_dims, num_points, k)
class get_model(nn.Module):
    def __init__(self, args, num_channel=3, num_class=40, **kwargs):
        super(get_model, self).__init__()
        self.args = args
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(args.emb_dims)
        self.conv1 = nn.Sequential(nn.Conv2d(num_channel*2, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv1d(512, args.emb_dims, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.linear1 = nn.Linear(args.emb_dims*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, num_class)

    def forward(self, x):
        batch_size = x.size()[0]
        # Four EdgeConv blocks: rebuild the kNN graph, convolve, then max over neighbours.
        x = get_graph_feature(x, k=self.args.k)
        x = self.conv1(x)
        x1 = x.max(dim=-1, keepdim=False)[0]
        x = get_graph_feature(x1, k=self.args.k)
        x = self.conv2(x)
        x2 = x.max(dim=-1, keepdim=False)[0]
        x = get_graph_feature(x2, k=self.args.k)
        x = self.conv3(x)
        x3 = x.max(dim=-1, keepdim=False)[0]
        x = get_graph_feature(x3, k=self.args.k)
        x = self.conv4(x)
        x4 = x.max(dim=-1, keepdim=False)[0]
        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = self.conv5(x)
        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        x = torch.cat((x1, x2), 1)
        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
        x = self.dp2(x)
        x = self.linear3(x)
        return x
class get_loss(torch.nn.Module):
    def __init__(self):
        super(get_loss, self).__init__()

    @staticmethod
    def cal_loss(pred, gold, smoothing=True, label_weights=False):
        """Calculate cross entropy loss, apply label smoothing if needed."""
        gold = gold.contiguous().view(-1)
        if smoothing and type(label_weights) == bool:
            eps = 0.2
            loss = F.cross_entropy(pred, gold, reduction='mean', label_smoothing=eps)
            # n_class = pred.size()[1]
            # one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)  # (num_points, num_class)
            # one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
            # log_prb = F.log_softmax(pred, dim=1)
            # loss = -(one_hot * log_prb).sum(dim=1).mean()  # ~ F.nll_loss(log_prb, gold)
        elif type(label_weights) != bool and not smoothing:
            label_weights = torch.as_tensor(label_weights, device="cuda", dtype=torch.float)
            loss = F.cross_entropy(pred, gold, reduction='mean', weight=label_weights)
        elif type(label_weights) != bool and smoothing:
            eps = 0.2
            label_weights = torch.as_tensor(label_weights, device="cuda", dtype=torch.float)
            loss = F.cross_entropy(pred, gold, reduction='mean',
                                   weight=label_weights, label_smoothing=eps)
        else:
            loss = F.cross_entropy(pred, gold, reduction='mean')
        return loss

    def forward(self, pred, target, label_weights=False):
        return self.cal_loss(pred, target, smoothing=True, label_weights=label_weights)
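For reference, a minimal smoke test showing how I run the model (the hyperparameters here, k=20, emb_dims=1024, dropout=0.5, and the 1024-point input are placeholders, not necessarily my real settings):

from types import SimpleNamespace

# Placeholder hyperparameters; substitute the real training settings.
args = SimpleNamespace(k=20, emb_dims=1024, dropout=0.5)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(args, num_channel=3, num_class=40).to(device).eval()

points = torch.randn(1, 3, 1024, device=device)  # (batch, xyz channels, num_points)
with torch.no_grad():
    logits = model(points)
print(logits.shape)  # torch.Size([1, 40])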
Model inference time is 18 ms in PyTorch. I then converted the model to ONNX, and from ONNX to TensorRT via trtexec, so that I could profile it. Inference takes 114 ms when the engine is built in FP32 and 91 ms in FP16.
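For completeness, this is roughly the export path (the file names, opset version, and fixed input shape are placeholders; `model` and `points` come from the smoke test above):

# Export sketch; names, opset and the fixed 1024-point input shape are assumptions.
torch.onnx.export(model, points, 'dgcnn.onnx', opset_version=13,
                  input_names=['points'], output_names=['logits'])
# Engine build and per-layer profiling with trtexec (FP16 shown):
#   trtexec --onnx=dgcnn.onnx --fp16 --saveEngine=dgcnn.plan --dumpProfile --separateProfileRun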
Here is the profiling output of the model:
[03/03/2023-05:54:28] [I] === Profile (37 iterations ) ===
[03/03/2023-05:54:28] [I] Layer Time (ms) Avg. Time (ms) Time %
[03/03/2023-05:54:28] [I] [HostToDeviceCopy] 1.17 0.0316 0.0
[03/03/2023-05:54:28] [I] PWN(68 + (Unnamed Layer* 23) [Shuffle], Pow_21) 0.89 0.0240 0.0
[03/03/2023-05:54:28] [I] ReduceSum_22 1.04 0.0281 0.0
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Output Tensor 0 to ReduceSum_22 1.51 0.0409 0.0
[03/03/2023-05:54:28] [I] Transpose_25 0.26 0.0069 0.0
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to MatMul_17 3.68 0.0994 0.1
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 1 to MatMul_17 2.35 0.0636 0.1
[03/03/2023-05:54:28] [I] MatMul_17 13.00 0.3515 0.4
[03/03/2023-05:54:28] [I] PWN(PWN(Neg_23, PWN(PWN(66 + (Unnamed Layer* 20) [Shuffle], Mul_19), Sub_24)), Sub_26) 26.20 0.7081 0.8
[03/03/2023-05:54:28] [I] TopK_27 288.02 7.7843 8.5
[03/03/2023-05:54:28] [I] Range_29 1.63 0.0440 0.0
[03/03/2023-05:54:28] [I] {ForeignNode[Transpose_36...Transpose_58]} 5.91 0.1598 0.2
[03/03/2023-05:54:28] [I] Range_88 0.83 0.0225 0.0
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to Conv_59 4.30 0.1162 0.1
[03/03/2023-05:54:28] [I] Conv_59 25.55 0.6906 0.8
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to PWN(LeakyRelu_60) 29.31 0.7923 0.9
[03/03/2023-05:54:28] [I] PWN(LeakyRelu_60) 25.72 0.6950 0.8
[03/03/2023-05:54:28] [I] ReduceMax_61 15.38 0.4157 0.5
[03/03/2023-05:54:28] [I] PWN(145 + (Unnamed Layer* 108) [Shuffle], Pow_80) 1.59 0.0429 0.0
[03/03/2023-05:54:28] [I] ReduceSum_81 1.46 0.0395 0.0
[03/03/2023-05:54:28] [I] Transpose_84 0.02 0.0007 0.0
[03/03/2023-05:54:28] [I] MatMul_76 24.27 0.6561 0.7
[03/03/2023-05:54:28] [I] PWN(PWN(Neg_82, PWN(PWN(143 + (Unnamed Layer* 105) [Shuffle], Mul_78), Sub_83)), Sub_85) 26.11 0.7058 0.8
[03/03/2023-05:54:28] [I] TopK_86 288.28 7.7913 8.5
[03/03/2023-05:54:28] [I] {ForeignNode[Transpose_95...Transpose_117]} 121.05 3.2715 3.6
[03/03/2023-05:54:28] [I] Range_147 1.52 0.0410 0.0
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to Conv_118 80.58 2.1778 2.4
[03/03/2023-05:54:28] [I] Conv_118 57.20 1.5459 1.7
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to PWN(LeakyRelu_119) 29.88 0.8075 0.9
[03/03/2023-05:54:28] [I] PWN(LeakyRelu_119) 26.59 0.7188 0.8
[03/03/2023-05:54:28] [I] ReduceMax_120 15.39 0.4161 0.5
[03/03/2023-05:54:28] [I] PWN(222 + (Unnamed Layer* 193) [Shuffle], Pow_139) 1.55 0.0418 0.0
[03/03/2023-05:54:28] [I] ReduceSum_140 1.45 0.0393 0.0
[03/03/2023-05:54:28] [I] Transpose_143 0.02 0.0005 0.0
[03/03/2023-05:54:28] [I] MatMul_135 24.27 0.6559 0.7
[03/03/2023-05:54:28] [I] PWN(PWN(Neg_141, PWN(PWN(220 + (Unnamed Layer* 190) [Shuffle], Mul_137), Sub_142)), Sub_144) 26.79 0.7242 0.8
[03/03/2023-05:54:28] [I] TopK_145 288.35 7.7932 8.5
[03/03/2023-05:54:28] [I] {ForeignNode[Transpose_154...Transpose_176]} 121.10 3.2729 3.6
[03/03/2023-05:54:28] [I] Range_206 1.63 0.0442 0.0
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to Conv_177 80.58 2.1780 2.4
[03/03/2023-05:54:28] [I] Conv_177 85.97 2.3235 2.5
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to PWN(LeakyRelu_178) 61.84 1.6715 1.8
[03/03/2023-05:54:28] [I] PWN(LeakyRelu_178) 54.37 1.4695 1.6
[03/03/2023-05:54:28] [I] ReduceMax_179 31.18 0.8427 0.9
[03/03/2023-05:54:28] [I] PWN(299 + (Unnamed Layer* 278) [Shuffle], Pow_198) 2.87 0.0777 0.1
[03/03/2023-05:54:28] [I] ReduceSum_199 2.53 0.0684 0.1
[03/03/2023-05:54:28] [I] Transpose_202 0.02 0.0006 0.0
[03/03/2023-05:54:28] [I] MatMul_194 31.82 0.8599 0.9
[03/03/2023-05:54:28] [I] PWN(PWN(Neg_200, PWN(PWN(297 + (Unnamed Layer* 275) [Shuffle], Mul_196), Sub_201)), Sub_203) 27.80 0.7514 0.8
[03/03/2023-05:54:28] [I] TopK_204 288.13 7.7872 8.5
[03/03/2023-05:54:28] [I] {ForeignNode[Transpose_213...Transpose_235]} 242.46 6.5531 7.1
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to Conv_236 162.61 4.3949 4.8
[03/03/2023-05:54:28] [I] Conv_236 249.61 6.7463 7.3
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to PWN(LeakyRelu_237) 121.65 3.2879 3.6
[03/03/2023-05:54:28] [I] PWN(LeakyRelu_237) 108.79 2.9401 3.2
[03/03/2023-05:54:28] [I] ReduceMax_238 64.71 1.7488 1.9
[03/03/2023-05:54:28] [I] 125 copy 3.13 0.0847 0.1
[03/03/2023-05:54:28] [I] 202 copy 2.92 0.0788 0.1
[03/03/2023-05:54:28] [I] 279 copy 5.55 0.1500 0.2
[03/03/2023-05:54:28] [I] 356 copy 10.76 0.2909 0.3
[03/03/2023-05:54:28] [I] (Unnamed Layer* 349) [Shuffle] 0.03 0.0007 0.0
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to shuffle_between_(Unnamed Layer* 349) [Shuffle]_output_and_Conv_240 15.84 0.4281 0.5
[03/03/2023-05:54:28] [I] shuffle_between_(Unnamed Layer* 349) [Shuffle]_output_and_Conv_240 0.03 0.0007 0.0
[03/03/2023-05:54:28] [I] Conv_240 82.80 2.2380 2.4
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to shuffle_after_(Unnamed Layer* 350) [Convolution]_output 25.24 0.6823 0.7
[03/03/2023-05:54:28] [I] shuffle_after_(Unnamed Layer* 350) [Convolution]_output 0.02 0.0007 0.0
[03/03/2023-05:54:28] [I] PWN(LeakyRelu_241) 20.63 0.5575 0.6
[03/03/2023-05:54:28] [I] squeeze_after_LeakyRelu_241 0.02 0.0005 0.0
[03/03/2023-05:54:28] [I] GlobalMaxPool_242 13.35 0.3608 0.4
[03/03/2023-05:54:28] [I] Reshape_245 0.22 0.0061 0.0
[03/03/2023-05:54:28] [I] GlobalAveragePool_246 13.68 0.3697 0.4
[03/03/2023-05:54:28] [I] Reshape_249 0.17 0.0047 0.0
[03/03/2023-05:54:28] [I] (Unnamed Layer* 373) [Shuffle] 0.01 0.0004 0.0
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to MatMul_251 0.02 0.0006 0.0
[03/03/2023-05:54:28] [I] MatMul_251 1.42 0.0385 0.0
[03/03/2023-05:54:28] [I] PWN(LeakyRelu_253) 0.22 0.0060 0.0
[03/03/2023-05:54:28] [I] Gemm_254 0.63 0.0172 0.0
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to PWN(LeakyRelu_256) 0.02 0.0005 0.0
[03/03/2023-05:54:28] [I] PWN(LeakyRelu_256) 0.21 0.0057 0.0
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to Gemm_257 0.01 0.0004 0.0
[03/03/2023-05:54:28] [I] Gemm_257 0.29 0.0080 0.0
[03/03/2023-05:54:28] [I] Reformatting CopyNode for Input Tensor 0 to (Unnamed Layer* 420) [Shuffle] 0.25 0.0069 0.0
[03/03/2023-05:54:28] [I] (Unnamed Layer* 420) [Shuffle] 0.01 0.0004 0.0
[03/03/2023-05:54:28] [I] Total 3400.34 91.9011 100.0
[03/03/2023-05:54:28] [I]
The TopK operations take about 34% of the inference time, roughly 30 ms. But even if I resolve the issue with these operations, the engine would still take about 60 ms, which is more than 3x slower than PyTorch. How can I work out where the problem is? I have converted plenty of image classification, detection, and segmentation models, and all of them performed much better in TensorRT than in PyTorch.
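For comparison, a quick sketch for timing the same TopK in isolation on the PyTorch side (the (1, 1024, 1024) pairwise-distance shape and k=20 are assumptions; substitute the real num_points and k):

import time

# Placeholder pairwise-distance matrix of shape (batch, num_points, num_points).
d = torch.randn(1, 1024, 1024, device='cuda')
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(100):
    d.topk(k=20, dim=-1)
torch.cuda.synchronize()
print('topk: %.3f ms' % ((time.perf_counter() - t0) / 100 * 1e3))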
P.S. I cannot move to a newer TensorRT version because of dependency constraints.