Torch_tensorrt and YOLO

Description

Hi, folks.

I'm trying to convert a YOLO model with the new torch_tensorrt API, and I'm running into some issues.


Environment

All the libraries and dependencies are working well; I ran the SSD example and other tests successfully.

Steps To Reproduce

I’m using the following code:

```python
import torch
import torch_tensorrt

model = torch.hub.load('ultralytics/yolov3', 'yolov3')  # or yolov3-spp, yolov3-tiny, custom
img = 'https://ultralytics.com/images/zidane.jpg'  # or file, Path, PIL, OpenCV, numpy, list
model.eval()

traced_model = torch.jit.trace(model, [torch.randn((1, 3, 416, 416)).to("cuda")], strict=False)

torch.jit.save(traced_model, "yolo.jit.pt")
results = model(img)

results.print()
```

Output error message

```
Fusing layers…
Model Summary: 261 layers, 61922845 parameters, 0 gradients
Adding AutoShape…
/home/adriano/.cache/torch/hub/ultralytics_yolov3_master/models/yolo.py:58: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
/home/adriano/anaconda3/envs/torchtensorrt/lib/python3.9/site-packages/torch/jit/_trace.py:958: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for list, use a tuple instead. for dict, use a NamedTuple instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
  module._c._create_method_from_trace(
Traceback (most recent call last):
  File "/home/adriano/Documents/yolov3/test.py", line 13, in <module>
    traced_model = torch.jit.trace(model, [torch.randn((1,3,416,416)).to("cuda")])
  File "/home/adriano/anaconda3/envs/torchtensorrt/lib/python3.9/site-packages/torch/jit/_trace.py", line 741, in trace
    return trace_module(
  File "/home/adriano/anaconda3/envs/torchtensorrt/lib/python3.9/site-packages/torch/jit/_trace.py", line 983, in trace_module
    _check_trace(
  File "/home/adriano/anaconda3/envs/torchtensorrt/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/adriano/anaconda3/envs/torchtensorrt/lib/python3.9/site-packages/torch/jit/_trace.py", line 526, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
  Graph diff:
    graph(%self.1 : __torch__.models.common.AutoShape,
          %imgs.1 : Tensor):
      %model : __torch__.models.yolo.Model = prim::GetAttr[name="model"]
      %model.1 : __torch__.models.yolo.Model = prim::GetAttr[name="model"]
      %model.3 : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name="model"]
      %_0.1 : __torch__.models.common.Conv = prim::GetAttr[name="0"]
      %conv.1 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"]
      %weight.149 : Tensor = prim::GetAttr[name="weight"]
      %8 : int = prim::Constant[value=6]() # /home/adriano/.cache/torch/hub/ultralytics_yolov3_master/models/common.py:439:0
      %9 : int = prim::Constant[value=0]() # /home/adriano/.cache/torch/hub/ultralytics_yolov3_master/models/common.py:439:0
      %10 : Device = prim::Constant[value="cuda:0"]() # /home/adriano/.cache/torch/hub/ultralytics_yolov3_master/models/common.py:439:0
      %11 : NoneType = prim::Constant()

…
```
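The first TracerWarning is worth understanding: tracing runs the model once on the example input and records only the operations that actually executed, so a Python `if` on a tensor value is evaluated once and frozen into the graph. Below is a toy, stdlib-only illustration of that behavior. It is not how `torch.jit.trace` is implemented internally; `Recorder`, `head`, and `trace` are hypothetical names, and plain numbers stand in for tensors.

```python
class Recorder:
    """Wraps a number and logs each arithmetic op applied to it."""

    def __init__(self, value, ops):
        self.value = value
        self.ops = ops

    def __mul__(self, other):
        self.ops.append(("mul", other))
        return Recorder(self.value * other, self.ops)

    def __add__(self, other):
        self.ops.append(("add", other))
        return Recorder(self.value + other, self.ops)

    def __gt__(self, other):
        # Returns a plain Python bool: which branch it selects is NOT recorded.
        return self.value > other


def head(x):
    # Data-dependent control flow, loosely like the grid-shape check in yolo.py:58.
    if x > 10:
        return x * 2
    return x + 1


def trace(fn, example):
    """Run fn once on `example` and return a function that replays the recorded ops."""
    ops = []
    fn(Recorder(example, ops))

    def replay(v):
        for op, arg in ops:
            v = v * arg if op == "mul" else v + arg
        return v

    return replay


traced_head = trace(head, 416)
print(head(5), traced_head(5))  # prints: 6 10 (the trace replays the 416 branch)
```

For the YOLO head this should be harmless as long as every input has the same 416×416 shape used for tracing, which is what the warning means by "might not generalize to other inputs".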

P.S.: I don't know anyone who uses SSD in production, but I know many people who use YOLO. I don't understand why NVIDIA uses SSD instead of YOLO in its examples. You may say SSD is an easier model... OK, I know, but it isn't useful in practice.

Hi,
Please check the link below; it might answer your questions.

Thanks!

Sorry, but I think you misunderstood what I'm trying to say; maybe my description was unclear. Here is my code.

First, I'm using the Ultralytics repository (GitHub - ultralytics/yolov3: YOLOv3 in PyTorch > ONNX > CoreML > TFLite). To reproduce what I did, clone it, then create a file (any name works) inside the cloned repository and put this code in it:

```python
import cv2
import torch
import torch_tensorrt
from utils.general import non_max_suppression, scale_coords

import matplotlib.pyplot as plt
from models.experimental import attempt_load


# Letterbox: resize keeping the aspect ratio, then pad to a square
def letterbox(image, desired_size):
    old_size = image.shape[:2]  # (height, width)

    ratio = float(desired_size) / max(old_size)
    new_size = tuple([int(x * ratio) for x in old_size])

    im = cv2.resize(image, (new_size[1], new_size[0]))  # cv2 expects (width, height)

    delta_w = desired_size - new_size[1]
    delta_h = desired_size - new_size[0]
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)

    color = [0, 0, 0]
    new_img = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)

    return new_img


weights = "yolov3.pt"
device = "cuda"
img_size = 416
outfile = "yolov3-traced.ts"
img_path = "/home/adriano/Pictures/zidane.jpg"

model = attempt_load(weights, map_location=device, inplace=True, fuse=True)

# Eval mode
model.eval()

# Dummy input used for warm-up and tracing
img_base = torch.rand((1, 3) + (img_size, img_size)).to(device)

# Warm-up forward pass
model(img_base)

# Create traced model
traced_model = torch.jit.trace(model, [img_base], strict=False)

# Compile settings to be used with TensorRT
compile_settings = {
    "inputs": [torch_tensorrt.Input((1, 3, img_size, img_size), dtype=torch.float)],
    "enabled_precisions": {torch.float, torch.half},  # allow FP16 as well as FP32 kernels
    "workspace_size": 1 << 20,
}

# Problematic code!!!
# traced_model = torch_tensorrt.compile(traced_model, **compile_settings)

# Save the TorchScript file
torch.jit.save(traced_model, outfile)


# Torch-TensorRT inference
traced_model = torch.jit.load(outfile)

half = False

# Image processing
img = cv2.imread(img_path)
shape = img.shape
img_org = img.copy()

img = letterbox(img, img_size)
im = torch.from_numpy(img).to(device).unsqueeze_(0)
im = im.half() if half else im.float()  # uint8 to fp16/32
im /= 255  # 0 - 255 to 0.0 - 1.0
im = im.permute(0, 3, 1, 2)  # NHWC -> NCHW


# Detection steps
pred, _ = traced_model(im)
det = non_max_suppression(pred, 0.3, 0.3, None, False, max_det=1000)[0]
det[:, :4] = scale_coords(im.shape[2:], det[:, :4], shape).round()
out_detections = det.cpu().detach().numpy()


# Drawing the annotation
if len(out_detections) > 0:
    for d in out_detections:
        cv2.rectangle(img_org, (int(d[0]), int(d[1])), (int(d[2]), int(d[3])), (0, 255, 0), 2)
        print(d)
```
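`letterbox` shrinks the frame so its longer side equals 416 and pads the shorter side symmetrically; `scale_coords` later undoes that so boxes come out in original-image pixels. Here is a stdlib-only sketch of the same arithmetic, using hypothetical helper names `letterbox_geometry` and `unletterbox_box` and a hypothetical 1280×720 frame. It mirrors `letterbox()` above; Ultralytics' `scale_coords` recomputes gain and padding from the two shapes itself and also clips boxes, so treat this as an approximation of it.

```python
def letterbox_geometry(orig_h, orig_w, desired_size):
    """Same arithmetic as letterbox() above, without touching pixels."""
    ratio = float(desired_size) / max(orig_h, orig_w)
    new_h, new_w = int(orig_h * ratio), int(orig_w * ratio)
    delta_h, delta_w = desired_size - new_h, desired_size - new_w
    top, left = delta_h // 2, delta_w // 2
    # (top, bottom, left, right) padding, matching the cv2.copyMakeBorder call
    padding = (top, delta_h - top, left, delta_w - left)
    return ratio, (new_h, new_w), padding


def unletterbox_box(box, ratio, top, left):
    """Map an (x1, y1, x2, y2) box from the padded square back to original pixels."""
    x1, y1, x2, y2 = box
    return ((x1 - left) / ratio, (y1 - top) / ratio,
            (x2 - left) / ratio, (y2 - top) / ratio)


# Hypothetical 1280x720 (width x height) frame letterboxed to 416x416
ratio, new_size, (top, bottom, left, right) = letterbox_geometry(720, 1280, 416)
print(ratio, new_size, (top, bottom, left, right))
# A box covering the whole resized image should map back to the full frame
print(unletterbox_box((0, top, 416, top + new_size[0]), ratio, top, left))
```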

With this code, inference runs, even though this warning appears:

/home/adriano/Documents/yolov3/models/yolo.py:58: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
[        740          41        1150         712     0.93587           0]
[        439         430         515         718      0.8167          27]
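Each printed row is one detection from `non_max_suppression`: the box corners `x1, y1, x2, y2` in original-image pixels (after `scale_coords`), then the confidence and the class index. The sketch below turns such rows into readable labels. It assumes the stock `yolov3.pt` weights use the 80-class COCO ordering, where index 0 is "person" and 27 is "tie"; `describe` and `COCO_SUBSET` are hypothetical names, and only those two classes are listed.

```python
# Assumption: 80-class COCO ordering (0 = person, 27 = tie); other ids fall back to "class N".
COCO_SUBSET = {0: "person", 27: "tie"}


def describe(rows, conf_thres=0.5):
    """Turn [x1, y1, x2, y2, conf, cls] rows into readable strings."""
    labels = []
    for x1, y1, x2, y2, conf, cls in rows:
        if conf < conf_thres:
            continue  # drop low-confidence boxes
        name = COCO_SUBSET.get(int(cls), f"class {int(cls)}")
        labels.append(f"{name} {conf:.2f} at ({int(x1)}, {int(y1)})-({int(x2)}, {int(y2)})")
    return labels


# The two rows printed above (values as displayed, so already rounded)
rows = [
    [740, 41, 1150, 712, 0.93587, 0],
    [439, 430, 515, 718, 0.8167, 27],
]
for line in describe(rows):
    print(line)
```

The 0.5 default threshold here is an arbitrary choice for the sketch; the script above already filters at 0.3 inside `non_max_suppression`.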

However, if I uncomment this line:

```python
# Problematic code!!!
# traced_model = torch_tensorrt.compile(traced_model, **compile_settings)
```
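The traceback below shows the call as `torch_tensorrt.compile(traced_model, **compile_settings)`. The `**` matters: without it, Python binds the whole dict to the second positional parameter instead of spreading its keys into keyword arguments. As far as I know, in torch_tensorrt 1.x that second parameter is `ir`, but treat the exact signature as an assumption. The stdlib-only sketch below uses a hypothetical `compile_stub` with a similar shape to show the difference, with plain strings standing in for `torch_tensorrt.Input` and torch dtypes.

```python
def compile_stub(module, ir="default", inputs=(), enabled_precisions=None, **kwargs):
    """Hypothetical stand-in: only reports how Python bound each argument."""
    return {"ir": ir, "inputs": inputs, "enabled_precisions": enabled_precisions, "kwargs": kwargs}


# Plain strings stand in for torch_tensorrt.Input and torch dtypes
settings = {
    "inputs": ["Input((1, 3, 416, 416))"],
    "enabled_precisions": {"float", "half"},
    "workspace_size": 1 << 20,
}

positional = compile_stub("traced_model", settings)  # whole dict lands in `ir`
unpacked = compile_stub("traced_model", **settings)  # each key becomes a keyword argument

print(positional["ir"] is settings, positional["inputs"])
print(unpacked["inputs"], unpacked["kwargs"])
```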

I receive this error message:

WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT TorchScript Conversion Context] - IElementWiseLayer with inputs 360 and (Unnamed Layer* 250) [Shuffle]_output: first input has type Float but second input has type Int32.
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [layers.cpp::validate::2384] Error Code 4: Internal Error (%3080 : Tensor = aten::mul(%360, %10), scope: __module.model.28 # /home/adriano/Documents/yolov3/models/yolo.py:63:0: operation PROD has incompatible input types Float and Int32)
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [layers.cpp::validate::2384] Error Code 4: Internal Error (%3080 : Tensor = aten::mul(%360, %10), scope: __module.model.28 # /home/adriano/Documents/yolov3/models/yolo.py:63:0: operation PROD has incompatible input types Float and Int32)
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [layers.cpp::validate::2384] Error Code 4: Internal Error (%3080 : Tensor = aten::mul(%360, %10), scope: __module.model.28 # /home/adriano/Documents/yolov3/models/yolo.py:63:0: operation PROD has incompatible input types Float and Int32)
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [layers.cpp::validate::2384] Error Code 4: Internal Error (%3080 : Tensor = aten::mul(%360, %10), scope: __module.model.28 # /home/adriano/Documents/yolov3/models/yolo.py:63:0: operation PROD has incompatible input types Float and Int32)
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [layers.cpp::validate::2384] Error Code 4: Internal Error (%3080 : Tensor = aten::mul(%360, %10), scope: __module.model.28 # /home/adriano/Documents/yolov3/models/yolo.py:63:0: operation PROD has incompatible input types Float and Int32)
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [layers.cpp::validate::2384] Error Code 4: Internal Error (%3080 : Tensor = aten::mul(%360, %10), scope: __module.model.28 # /home/adriano/Documents/yolov3/models/yolo.py:63:0: operation PROD has incompatible input types Float and Int32)
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [layers.cpp::validate::2384] Error Code 4: Internal Error (%3080 : Tensor = aten::mul(%360, %10), scope: __module.model.28 # /home/adriano/Documents/yolov3/models/yolo.py:63:0: operation PROD has incompatible input types Float and Int32)
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [elementWiseNode.cpp::computeOutputExtents::14] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 3: [network.cpp::addConstant::1191] Error Code 3: Internal Error (Parameter check failed at: optimizer/api/network.cpp::addConstant::1191, condition: !weights.values == !weights.count
)
Traceback (most recent call last):
  File "/home/adriano/Documents/yolov3/test.py", line 59, in <module>
    traced_model = torch_tensorrt.compile(traced_model, **compile_settings)
  File "/home/adriano/anaconda3/envs/torchtensorrt/lib/python3.9/site-packages/torch_tensorrt/_compile.py", line 97, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/home/adriano/anaconda3/envs/torchtensorrt/lib/python3.9/site-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/conversion/converters/impl/select.cpp:236] Expected const_layer to be true but got false
Unable to create constant layer from node: %365 : Tensor = aten::slice(%y.1, %8, %25, %21, %27), scope: __module.model.28 # /home/adriano/Documents/yolov3/models/yolo.py:63:0

Hi,

We recommend posting your issue on the pytorch/TensorRT GitHub issue tracker (Issues · pytorch/TensorRT · GitHub) to get better help.

Thank you.

@AdrianoSantosPB I am encountering the same error trying to get YOLOv5 to run with libtorch-tensorrt in C++. Did you end up solving the issue?

Hi. No, I didn't. Instead, I implemented it in C++ with my own library, and it's working very well. Take a look at this repo: GitHub - wang-xinyu/tensorrtx: Implementation of popular deep learning networks with TensorRT network definition API
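For context, tensorrtx rebuilds each network layer by layer with the TensorRT network definition API and loads the weights from a plain-text `.wts` file exported from the PyTorch `state_dict` by its `gen_wts.py` scripts. As I understand that format (verify it against the repo version you use): the first line is the tensor count, then one line per tensor with its name, its element count, and each value as big-endian float32 hex. A stdlib-only sketch of writing and reading it; `write_wts` and `read_wts` are hypothetical names, and plain lists stand in for tensors.

```python
import os
import struct
import tempfile


def write_wts(weights, path):
    """Write {name: flat list of floats} as a tensorrtx-style .wts text file."""
    with open(path, "w") as f:
        f.write(f"{len(weights)}\n")  # first line: number of tensors
        for name, values in weights.items():
            # One line per tensor: name, element count, big-endian float32 hex values
            hex_values = " ".join(struct.pack(">f", float(v)).hex() for v in values)
            f.write(f"{name} {len(values)} {hex_values}\n")


def read_wts(path):
    """Parse the same format back into {name: list of floats}."""
    weights = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            name, n, *hex_values = f.readline().split()
            if len(hex_values) != int(n):
                raise ValueError(f"{name}: expected {n} values, got {len(hex_values)}")
            weights[name] = [struct.unpack(">f", bytes.fromhex(h))[0] for h in hex_values]
    return weights


# Usage with hypothetical tensor names and values (all exactly representable in float32)
demo = {"conv1.weight": [0.5, -1.25, 3.0], "conv1.bias": [0.0]}
path = os.path.join(tempfile.mkdtemp(), "demo.wts")
write_wts(demo, path)
print(read_wts(path) == demo)
```

Text hex is bulky, but it keeps the C++ loader trivial: it only has to parse integers and reinterpret 32-bit words as floats.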