Description
Hello, I am a student working on a project that my professor gave me.
The project is to convert ToMP101 from the pytracking framework to TensorRT. Because the model is tangled up with the training and inference phases (inside tompnet.py the forward function is used for training, not inference, and the same goes for the head), inference is done by calling individual functions, so I put those inside two separate forward functions: one for initialization and one for tracking.
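Roughly, the wrapping is just an nn.Module whose forward chains those individual calls; a minimal sketch (net.head here is a placeholder for the actual sequence of head functions from tompnet.py, not the real API):

import torch.nn as nn

class TrackForward(nn.Module):
    # hypothetical wrapper: exposes the tracking-time computation as a single forward()
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, sample_x, train_samples, target_labels, train_ltrb):
        # placeholder for the chain of individual functions called at inference time
        return self.net.head(train_samples, sample_x, target_labels, train_ltrb)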
The initialization model only uses the ResNet-101 backbone, so it converts to ONNX and then to TensorRT and runs without any problems.
But the model used for tracking poses a lot more problems. After a lot of trying I was finally able to convert it to ONNX and then to TensorRT and run inference, but now the inference accuracy is VERY LOW, and I cannot figure out what the problem is.
Code used to convert PyTorch to ONNX:
input_names = ['sample_x', 'train_samples', 'target_labels', 'train_ltrb']
output_names = ['target_scores', 'bbreg_test_feat_enc', 'bbreg_weights']
torch.onnx.export(model,
                  model_inputs,
                  "tomp101_head_latest3.onnx",
                  verbose=False,
                  export_params=True,
                  do_constant_folding=True,
                  opset_version=16,
                  input_names=input_names,
                  output_names=output_names,
                  dynamic_axes={'sample_x': {0: 'batch_size'},
                                'train_samples': {0: 'batch_size'},
                                'target_labels': {0: 'batch_size'},
                                'train_ltrb': {0: 'batch_size'},
                                'target_scores': {0: 'batch_size'},
                                'bbreg_test_feat_enc': {0: 'batch_size'},
                                'bbreg_weights': {0: 'batch_size'}})
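Before even touching TensorRT, a parity check against ONNX-Runtime can show whether the accuracy is already lost at export time. A minimal sketch, assuming model and model_inputs are the same objects used in the export above and that the model returns its outputs in the order of output_names:

import numpy as np
import onnxruntime as ort

# run the exported ONNX model on the exact inputs used for the export
sess = ort.InferenceSession("tomp101_head_latest3.onnx",
                            providers=["CPUExecutionProvider"])
ort_inputs = {name: t.detach().cpu().numpy()
              for name, t in zip(input_names, model_inputs)}
ort_outputs = sess.run(output_names, ort_inputs)

with torch.no_grad():
    torch_outputs = model(*model_inputs)

# a large max-abs-diff here means the export itself already diverges
for name, o_ort, o_pt in zip(output_names, ort_outputs, torch_outputs):
    print(name, np.abs(o_ort - o_pt.detach().cpu().numpy()).max())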
If I don’t specify dynamic_axes I get the error below, even though I made sure that everything is on CUDA (and, on a second attempt, that everything is on CPU):
params_dict = _C._jit_pass_onnx_constant_fold(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
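Since the failure happens inside the constant-folding pass, one workaround that might be worth trying (an assumption on my part, not verified on this model) is to export on CPU with folding disabled:

# everything on CPU, constant folding off, no dynamic axes
model_cpu = model.eval().cpu()
inputs_cpu = tuple(t.cpu() for t in model_inputs)
torch.onnx.export(model_cpu,
                  inputs_cpu,
                  "tomp101_head_nofold.onnx",
                  do_constant_folding=False,
                  opset_version=16,
                  input_names=input_names,
                  output_names=output_names)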
I converted the model using trtexec:
trtexec --onnx=tomp101_head_latest3.onnx --saveEngine=trt_tomp101_head_latest3.engine --verbose --explicitBatch --useCudaGraph
The verbose output gives me two kinds of warnings:
[TRT] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[01/12/2024-12:17:31] [W] [TRT] onnx2trt_utils.cpp:400: One or more weights outside the range of INT32 was clamped
I am hoping this is what is causing the inaccuracies.
I ran some code to track down where int64 is used in my model, searching for torch.int64 tensor types, and this is all I could find (I could be missing something):
Layer 'filter_predictor.box_encoding.1' has an INT64 buffer named 'num_batches_tracked'
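A minimal version of that kind of scan, in case someone wants to reproduce it:

import torch

def find_int64_tensors(model):
    # walk all parameters and buffers and report any int64 entries
    for name, t in list(model.named_parameters()) + list(model.named_buffers()):
        if t.dtype == torch.int64:
            print(f"{name}: shape={tuple(t.shape)}, dtype={t.dtype}")

find_int64_tensors(model)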
I tracked the int64 down to this nn.Sequential, even though it is used ONLY ONCE in the code and ONLY torch.float32 gets passed through it:
def MLP(channels, do_bn=True):
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(
            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
        if i < (n - 1):
            if do_bn:
                layers.append(nn.BatchNorm1d(channels[i]))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)
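As far as I can tell, the int64 does not come from the activations at all but from BatchNorm's bookkeeping buffer, which PyTorch always registers as int64 regardless of what flows through the layer. A quick check:

import torch.nn as nn

bn = nn.BatchNorm1d(4)
for name, buf in bn.named_buffers():
    print(name, buf.dtype)
# running_mean        torch.float32
# running_var         torch.float32
# num_batches_tracked torch.int64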
Can clamping this one int64 buffer to int32 be the cause of the inaccuracies in the converted TensorRT engine?
[01/12/2024-12:17:31] [W] Dynamic dimensions required for input: sample_x, but no shapes were provided. Automatically overriding shape to: 1x1024x18x18
[01/12/2024-12:17:31] [W] Dynamic dimensions required for input: train_samples, but no shapes were provided. Automatically overriding shape to: 1x1024x18x18
[01/12/2024-12:17:31] [W] Dynamic dimensions required for input: target_labels, but no shapes were provided. Automatically overriding shape to: 1x1x18x18
[01/12/2024-12:17:31] [W] Dynamic dimensions required for input: train_ltrb, but no shapes were provided. Automatically overriding shape to: 1x4x18x18
I am assuming that the second warning can be ignored, because the shapes it automatically falls back to are exactly the static shapes I use anyway.
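If batch size 1 is all I ever need, I could presumably pin the dynamic dimension explicitly instead of relying on the automatic override, something like:

trtexec --onnx=tomp101_head_latest3.onnx --saveEngine=trt_tomp101_head_latest3.engine --minShapes=sample_x:1x1024x18x18,train_samples:1x1024x18x18,target_labels:1x1x18x18,train_ltrb:1x4x18x18 --optShapes=sample_x:1x1024x18x18,train_samples:1x1024x18x18,target_labels:1x1x18x18,train_ltrb:1x4x18x18 --maxShapes=sample_x:1x1024x18x18,train_samples:1x1024x18x18,target_labels:1x1x18x18,train_ltrb:1x4x18x18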
I am running inference with polygraphy:
import tensorrt as trt
from polygraphy.backend.trt import EngineFromNetwork, NetworkFromOnnxPath, TrtRunner

with open("trt_tomp101_head_latest3.engine", "rb") as f:
    engine_data = f.read()
runtime1 = trt.Runtime(trt.Logger(trt.Logger.WARNING))
engine1 = runtime1.deserialize_cuda_engine(engine_data)
....
But the inference accuracy is terrible…
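To narrow down whether the drop happens in TensorRT or already in the ONNX model, both backends can be run side by side on identical inputs with polygraphy's Comparator API (a sketch based on polygraphy's documented examples, not verified on this exact model):

from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
from polygraphy.backend.trt import EngineFromBytes, TrtRunner
from polygraphy.comparator import Comparator

# feed identical inputs to ONNX-Runtime and the TensorRT engine, then compare
build_session = SessionFromOnnx("tomp101_head_latest3.onnx")
with open("trt_tomp101_head_latest3.engine", "rb") as f:
    load_engine = EngineFromBytes(f.read())

results = Comparator.run([OnnxrtRunner(build_session), TrtRunner(load_engine)])
print(Comparator.compare_accuracy(results))

The CLI polygraphy run tomp101_head_latest3.onnx --trt --onnxrt does roughly the same thing, rebuilding the engine from the ONNX file.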
ALSO
When I open the model in Netron, converting to ONNX has turned my outputs into this:
target_scores: float32[batch_size,1,Reshapetarget_scores_dim_2,18]
bbreg_test_feat_enc: float32[batch_size,1,Unsqueezebbreg_test_feat_enc_dim_2,Unsqueezebbreg_test_feat_enc_dim_3,18]
bbreg_weights: float32[batch_size,Unsqueezebbreg_weights_dim_1,Unsqueezebbreg_weights_dim_2,Unsqueezebbreg_weights_dim_3]
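Those odd dimension names look like the auto-generated symbols ONNX assigns to dynamic output dimensions it cannot resolve to constants. They can be inspected straight from the graph:

import onnx

model_proto = onnx.load("tomp101_head_latest3.onnx")
inferred = onnx.shape_inference.infer_shapes(model_proto)
for out in inferred.graph.output:
    # dim_param holds the symbolic name, dim_value the constant (if known)
    dims = [d.dim_param if d.dim_param else d.dim_value
            for d in out.type.tensor_type.shape.dim]
    print(out.name, dims)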
If anybody has experience with this, please help me, because I cannot seem to find an answer anywhere.
You can download my ONNX model from my Google Drive here.
If any additional information is needed, please let me know.
THANK YOU
Environment
TensorRT Version: 8.6
GPU Type: GTX 1660 Ti
Nvidia Driver Version: 546.01
CUDA Version: 12.1
CUDNN Version: 8.9.7
Operating System + Version: Windows 10
Python Version (if applicable): 3.10.13
PyTorch Version (if applicable): 2.1.2+cu121
Baremetal or Container (if container which image + tag): Bare metal, straight on Windows (no container)
Relevant Files
trtexec verbose log (without warnings): lognew.txt (4.9 MB)