Better inference performance with PyTorch than with TensorRT

Hello,

I am running ResNet50 inference with TensorRT in Python, with JetPack 5.0.2-b231 on a Jetson AGX Xavier. I process a variable number of detections per frame to extract features, so the engine was generated from an ONNX model with dynamic batch size (variable input and output shapes). The problem is that with dynamic batch the process is much slower with TensorRT than with the original PyTorch model. You can find the original model here: GitHub - HobbitLong/SupContrast: PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

I would like to know if there is any way to get higher performance using dynamic shape with TensorRT.

Thanks, Paula.

Hi,

Could you share the performance data you got with TensorRT and PyTorch?
Also, the detailed steps to reproduce the numbers, so we can check them in our environment as well.

Thanks.

Hello,

The inference time with PyTorch is about 63 ms per frame, while with TensorRT it is about 686 ms per frame.
Every frame has about 10 detections, so I have created the engine with the minShapes, maxShapes and optShapes parameters:

/usr/src/tensorrt/bin/trtexec --onnx=/path/supcon_batch_variable.onnx --saveEngine=/path/supcon_batch_variable_fp32_opt8.trt --workspace=6000 --minShapes=input:1x3x224x224 --maxShapes=input:16x3x224x224 --optShapes=input:8x3x224x224

Also, I have created the onnx model using dynamic axes:

batch_size = 1
dummy_input = torch.randn(batch_size, 3, 224, 224)
torch.onnx.export(model, dummy_input, '/home/path/supcon_batch_variable.onnx',
  input_names=['input'],    # the model's input names
  output_names=['output'],  # the model's output names
  dynamic_axes={'input': {0: 'batch_size'},    # variable-length axes
                'output': {0: 'batch_size'}})

Finally, the inference with PyTorch is carried out as in the linked repository: GitHub - HobbitLong/SupContrast: PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally). With TensorRT, I have created an execution context using the Python API.
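
In outline, my TensorRT side looks roughly like this (a minimal sketch rather than my exact code; the engine path and the pycuda-based memory handling are assumptions):

import numpy as np
import tensorrt as trt
import pycuda.autoinit  # creates a CUDA context
import pycuda.driver as cuda

logger = trt.Logger(trt.Logger.WARNING)
with open('/path/supcon_batch_variable_fp32_opt8.trt', 'rb') as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

batch = 10  # number of detections in the current frame
context.set_binding_shape(0, (batch, 3, 224, 224))  # required with dynamic shapes

inp = np.random.rand(batch, 3, 224, 224).astype(np.float32)  # stands in for the crops
out = np.empty(context.get_binding_shape(1), dtype=np.float32)

d_in = cuda.mem_alloc(inp.nbytes)
d_out = cuda.mem_alloc(out.nbytes)
stream = cuda.Stream()
cuda.memcpy_htod_async(d_in, inp, stream)
context.execute_async_v2([int(d_in), int(d_out)], stream.handle)
cuda.memcpy_dtoh_async(out, d_out, stream)
stream.synchronize()  # out now holds the feature descriptors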

Hi,

We met some issues when converting the model to a TensorRT engine:

...
[10/25/2022-09:25:38] [I] [TRT] ----------------------------------------------------------------
[10/25/2022-09:25:38] [E] [TRT] ModelImporter.cpp:773: While parsing node number 49 [Conv -> "onnx::Relu_510"]:
[10/25/2022-09:25:38] [E] [TRT] ModelImporter.cpp:774: --- Begin node ---
[10/25/2022-09:25:38] [E] [TRT] ModelImporter.cpp:775: input: "input.4"
input: "onnx::Conv_511"
input: "onnx::Conv_512"
output: "onnx::Relu_510"
name: "Conv_49"
op_type: "Conv"
attribute {
  name: "dilations"
  ints: 1
  ints: 1
  type: INTS
}
attribute {
  name: "group"
  i: 1
  type: INT
}
attribute {
  name: "kernel_shape"
  ints: 1
  ints: 1
  type: INTS
}
attribute {
  name: "pads"
  ints: 0
  ints: 0
  ints: 0
  ints: 0
  type: INTS
}
attribute {
  name: "strides"
  ints: 1
  ints: 1
  type: INTS
}

[10/25/2022-09:25:38] [E] [TRT] ModelImporter.cpp:776: --- End node ---
[10/25/2022-09:25:38] [E] [TRT] ModelImporter.cpp:778: ERROR: ModelImporter.cpp:163 In function parseGraph:
[6] Invalid Node - Conv_49
The bias tensor is required to be an initializer for the Conv operator. Try applying constant folding on the model using Polygraphy: https://github.com/NVIDIA/TensorRT/tree/master/tools/Polygraphy/examples/cli/surgeon/02_folding_constants
[10/25/2022-09:25:38] [E] Failed to parse onnx file
[10/25/2022-09:25:38] [I] Finish parsing network model
[10/25/2022-09:25:38] [E] Parsing model failed
[10/25/2022-09:25:38] [E] Failed to create engine from model or file.
[10/25/2022-09:25:38] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v8401] # /usr/src/tensorrt/bin/trtexec --onnx=./supcon_batch_variable.onnx --saveEngine=./supcon_batch_variable_fp32_opt8.trt --workspace=6000 --minShapes=input:1x3x224x224 --maxShapes=input:16x3x224x224 --optShapes=input:8x3x224x224
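
For reference, the constant folding that the error message suggests can be applied with Polygraphy (assuming it is installed), for example:

polygraphy surgeon sanitize ./supcon_batch_variable.onnx --fold-constants -o ./supcon_batch_variable_folded.onnx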

Below is our conversion script; could you please check whether there is any difference from yours?

from resnet_big import SupConResNet
import torch

model = SupConResNet(name='resnet50')

batch_size = 1
dummy_input = torch.randn(batch_size, 3, 224, 224)
torch.onnx.export(model, dummy_input, './supcon_batch_variable.onnx',
  input_names=['input'],
  output_names=['output'],
  dynamic_axes={'input' : {0 : 'batch_size'}, 'output': {0 : 'batch_size'}})

Thanks.

Hello,

I found some differences between your code and mine:

from utils import resnet_big
from collections import OrderedDict
import torch

batch_size = 1
dummy_input = torch.randn(batch_size, 3, 224, 224)
state_dict = torch.load('./supcon.pth')['model']

model = resnet_big.SupConResNet()

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove the 'module.' prefix added by DataParallel
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

I don't know if this code solves your problem. Does the ONNX model convert correctly now?

Thanks.

Hi,

Did you use the pre-trained model shared in the repository?
We tested it, but the weights are not compatible and we met a weight-loading error.

Could you share the network architecture, or the model itself if you trained it on your own?
Thanks.

Hello,

Yes, I am using the model shared in the repository. You must follow exactly the same steps that I showed in my previous post, loading the model using the repository's network definition.

I have shared the ONNX model with you in a zipped folder; I think you can convert it to TensorRT directly.

Thanks.
supcon_batch_variable.zip (99.3 MB)

Hello,

Is there any news about this topic?

Thanks.

Hi,

Thanks for your patience.
Could you verify if the PyTorch inference is also using batch size=8?

We tested the ONNX model with TensorRT and ONNXRuntime on Xavier.

In TensorRT, we got 84.6757 ms for batch size=1 and 652.653 ms for batch size=8.
In ONNXRuntime, batch size=1 takes 94.296 ms while batch size=8 needs 684.891 ms.

So it looks like the performance difference comes from the different batch sizes used.

Thanks.

Hello,

Yes, PyTorch inference is also using batch size=8.

I agree with your TensorRT numbers; I get similar times for batch size=1 and batch size=8. But the question is: why is the process so much slower with TensorRT than with PyTorch?

And finally, as I asked in my first post: is there any way to get higher performance using dynamic shapes with TensorRT?

Thanks.

Hi,

Could you share the inference source so we can reproduce the PyTorch result?

We tested the ONNX model with ONNXRuntime.
The elapsed time for batch size=8 is 684.891 ms, which is longer than with TensorRT.
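
Our ONNXRuntime measurement is along these lines (a minimal sketch; the provider list is an assumption, while the input/output names match the export code above):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('./supcon_batch_variable.onnx',
  providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
inp = np.random.rand(8, 3, 224, 224).astype(np.float32)  # batch size=8
out = sess.run(['output'], {'input': inp})[0]  # dynamic batch: any N works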

It would be good if we could reproduce the PyTorch result first.
Thanks.

Hello,

You can find the original model and the original code here: GitHub - HobbitLong/SupContrast: PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

In general, the inference process that I have performed looks like this:

import resnet_big
from collections import OrderedDict
import torch

state_dict = torch.load('/path/supcon.pth')['model']
model = resnet_big.SupConResNet()

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove the 'module.' prefix added by DataParallel
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

model = model.float().eval().cuda()

# crops: the detection images, shape (batch, H, W, C)
crops = torch.Tensor(crops).permute(0, 3, 1, 2)  # to (batch, C, H, W)
descriptors = model(crops.cuda())
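
For reference, I time the forward pass roughly like this (a sketch continuing the snippet above, not my exact benchmarking code; the synchronize calls are needed for an accurate GPU measurement):

import time
torch.cuda.synchronize()  # make sure pending GPU work is finished
start = time.perf_counter()
descriptors = model(crops.cuda())
torch.cuda.synchronize()  # wait for the forward pass to complete
elapsed_ms = (time.perf_counter() - start) * 1000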

Thanks.

Hi,

Have you tried the latest model and the latest source shared in the repository?
It doesn't work for us because of the error mentioned on Oct 26:

size mismatch for encoder.conv1.weight: copying a param with shape torch.Size([64, 3, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).

Based on your implementation, does the input contain 8 images of size 3x224x224?

Thanks.

Hello,

Yes, I have tried the latest model shared in the repository and the code too.

It seems like a code problem; you must follow exactly the same steps that I showed you on the 4th of November. The import below comes from the following code: https://github.com/HobbitLong/SupContrast/tree/master/networks

import resnet_big

Regarding the second question, my input size is 8x3x224x224.

Thanks.

Hello,

Sorry, I was wrong about the code. I did not remember that I had changed the code to make it work. You must replace self.shortcut with self.downsample in the resnet_big.py file, and change the kernel_size in line 80 to 7. I hope it works.
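
For context, the size mismatch quoted earlier (checkpoint shape torch.Size([64, 3, 7, 7])) suggests the kernel_size change amounts to restoring the standard ImageNet-style ResNet stem, something like the line below; the stride and padding values follow the standard torchvision ResNet and are assumptions:

# in resnet_big.py, first conv of the encoder, matching the checkpoint shape
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)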

Sorry again, and thank you.

Hi,

Thanks for the hint.
We can run the model with PyTorch after the change you mentioned.

Below is the performance data that we measured for batch=1 and batch=8.
It seems that TensorRT gives better performance than ONNXRuntime or PyTorch.

TensorRT

  • Batch=1: 84.6757ms
  • Batch=8: 652.653ms

PyTorch (tested with inference.py (979 Bytes))

  • Batch=1: 123.793ms
  • Batch=8: 920.4219ms

ONNXRuntime

  • Batch=1: 94.296ms
  • Batch=8: 684.891ms

Could you help to confirm it?

Thanks.

Hello,

Sorry for the delay.

Checking your PyTorch code, I discovered that I was doing layer fusion, and that is why my PyTorch model was faster.
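
For anyone finding this thread later: one common way to fuse layers in PyTorch is torch.quantization.fuse_modules, which merges Conv2d + BatchNorm2d (+ ReLU) groups before inference. A sketch with hypothetical module names (the actual names depend on the model definition, and this may not be my exact method):

import torch
model = model.eval()  # fusion for inference requires eval mode
fused = torch.quantization.fuse_modules(
  model, [['encoder.conv1', 'encoder.bn1']])  # hypothetical module names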

Thanks for your time.

This topic was automatically closed 14 days after the last reply. New replies are no longer allowed.