Accelerating Inference Up to 6x Faster in PyTorch with Torch-TensorRT

Originally published at: Accelerating Inference Up to 6x Faster in PyTorch with Torch-TensorRT | NVIDIA Developer Blog

Torch-TensorRT is a PyTorch integration for TensorRT inference optimizations on NVIDIA GPUs. With just one line of code, it can speed up inference by up to 6x.

Does this one line of code need to be run on the target machine for optimization like previous conversions?

Hi, thank you for your post.
Can I expect a speedup from using Torch-TensorRT on a low-performance laptop GPU such as a GeForce GTX 1060 as well?

“Torch-TensorRT” shouldn’t be used for higher precision inference, right?

Yes, Torch-TensorRT should be used on the target machine since TensorRT optimizations are dependent on the system’s configuration.

You can expect some speedup (not always) but that’d depend on the DNN too.

DNNs are typically trained in FP32 or in mixed precision (FP32 + FP16) and run for inference in FP32/FP16/INT8/INT4.
A higher precision (like FP64) is typically not used for inference.
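To make the precision trade-off concrete: IEEE 754 half precision (FP16) packs a value into 2 bytes versus 4 for single precision (FP32), which is where much of the inference memory and bandwidth saving comes from, at the cost of mantissa bits. A stdlib-only sketch (no Torch-TensorRT involved; `"f"` and `"e"` are the standard struct format codes for FP32 and FP16):

```python
import struct

# Pack the same value in single (FP32, "f") and half (FP16, "e") precision.
fp32 = struct.pack("f", 0.1)
fp16 = struct.pack("e", 0.1)

print(len(fp32), len(fp16))  # FP16 uses half the bytes: 4 vs 2

# Half precision keeps far fewer mantissa bits, so the round-trip is lossier:
print(struct.unpack("f", fp32)[0])  # close to 0.1
print(struct.unpack("e", fp16)[0])  # noticeably further from 0.1
```

INT8/INT4 inference shrinks the footprint further still, but requires calibration or quantization-aware training to preserve accuracy.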

Besides using the NGC PyTorch Docker images, can we install Torch-TensorRT in a custom environment, e.g. with pip install?

I was able to manually download the wheels and install everything, so you should be able to. I don’t have sudo access on the cloud machine, so I had to do everything without sudo, which was a pain. Otherwise there should be easier instructions to follow.


Where did you download the wheel file? I only found torch-tensorrt 0.0.0 on PyPI.

Are there any docs for Torch-TensorRT? How can I learn these APIs?

Hi Ashish
Thank you for the article.

I found one typo in your code in:

traced_model = torch.jit.trace(model, torch.randn((1,3,224,224)).to("cuda")])

The ] at the end is not needed.

Also, the last benchmark fails with error

>>> benchmark(trt_model, input_shape=(1, 3, 224, 224), nruns=100, dtype="fp16")
Warm up ...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 9, in benchmark
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: [Error thrown at core/runtime/register_trt_op.cpp:89] Expected inputs[pyt_idx].dtype() == expected_type to be true but got false
Expected input tensors to have type Float, found type c10::Half

I think you need to use dtype="fp32".
The model was compiled with half precision, but input_data should still be float32.

Tested in docker pull nvcr.io/nvidia/pytorch:22.04-py3
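For readers without the blog open: the `benchmark` helper referenced above isn’t reproduced in this thread. A simplified, device-agnostic sketch of such a timing harness follows; the names and defaults are illustrative, and the blog’s actual version also builds the input from `input_shape`/`dtype` and calls `torch.cuda.synchronize()` around the timed region for accurate GPU numbers:

```python
import time

def benchmark(model, make_input, nwarmup=10, nruns=100):
    """Time `model` on inputs produced by `make_input`; return seconds per run."""
    print("Warm up ...")
    for _ in range(nwarmup):
        model(make_input())
    start = time.perf_counter()
    for _ in range(nruns):
        model(make_input())
    elapsed = time.perf_counter() - start
    print(f"Average latency: {elapsed / nruns * 1000:.3f} ms")
    return elapsed / nruns

# Usage with any callable, e.g. a compiled TorchScript module:
#   avg = benchmark(trt_model, lambda: input_data)
```

The warm-up loop matters for the FP16 case above: the first few calls trigger lazy initialization, which is exactly where a dtype mismatch like the one in the traceback surfaces.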

@asardana @anishm @ashishsardana

Torch-TensorRT has recently moved inside the PyTorch project. You’ll find the documentation here: Torch-TensorRT — Torch-TensorRT master documentation

Hi Alexander, thanks for trying out the code for Torch-TensorRT and pointing out the typo!

When publishing this blog (+ code), we tested it in the pytorch:19.11 environment (docker image - nvcr.io/nvidia/pytorch:19.11-py3). This docker image contains Torch-TensorRT v0.4.1
The environment you are in (pytorch:22.04) contains Torch-TensorRT v1.1, hence the difference in the output you’re observing.

@apivovarov – Fixed the code typo, thanks!

Ashish,

  1. In pytorch/TensorRT 0.4.1, the module name was trtorch. Your blog uses torch_tensorrt.
  2. trtorch was renamed to torch_tensorrt in v1.0.0.
  3. I tried nvcr.io/nvidia/pytorch:19.11-py3. It has neither trtorch nor torch_tensorrt.

If your blog was updated to use the Torch-TensorRT v1.0.0+ API, then the code should also be updated.
It would be nice if you fixed the code and replaced

benchmark(trt_model, input_shape=(1, 3, 224, 224), nruns=100, dtype="fp16")

with

benchmark(trt_model, input_shape=(1, 3, 224, 224), nruns=100, dtype="fp32")
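As an aside on the trtorch → torch_tensorrt rename discussed above: code that has to run in both old (pre-1.0) and new containers can try both module names. A hedged sketch, assuming at most one of the two packages is installed in any given environment:

```python
# Try the post-1.0 name first, then fall back to the pre-1.0 name "trtorch".
try:
    import torch_tensorrt as trt_mod  # v1.0.0+
except ImportError:
    try:
        import trtorch as trt_mod  # v0.4.x and earlier
    except ImportError:
        trt_mod = None  # neither package is installed in this environment

if trt_mod is not None:
    print("Using", trt_mod.__name__)
```

Both versions expose a `compile` entry point, but the argument names differ between 0.4.x and 1.x, so an alias alone is not a full compatibility layer.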

@jwitsoe @ashishsardana @anishm

If we want to build torch_tensorrt outside the Docker image, does it work with the new AGX Orin?


Hi, thanks for your support.

I’m using the Python API guide from Using Torch-TensorRT in Python — Torch-TensorRT master documentation

my code:

def load_torchmodel():
    ...  # (model-loading body elided in the post)

model = load_torchmodel()

inputs = [torch_tensorrt.Input(
    min_shape=[1, 1, 320, 320],
    opt_shape=[1, 1, 640, 640],
    max_shape=[1, 1, 1280, 1280],
    dtype=torch.half,
)]
enabled_precisions = {torch.float, torch.half}  # Run with fp16

trt_ts_module = torch_tensorrt.compile(model, inputs=inputs, enabled_precisions=enabled_precisions)

input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")

but get an error below. Could you help me?

Traceback (most recent call last):
  File "trt_infer3.py", line 236, in <module>
    trt_ts_module = torch_tensorrt.compile(model, inputs=inputs, enabled_precisions=enabled_precisions)
  File "/usr/local/lib/python3.6/dist-packages/torch_tensorrt/_compile.py", line 96, in compile
    ts_mod = torch.jit.script(module)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_script.py", line 1258, in script
    obj, torch.jit._recursive.infer_methods_to_compile
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 451, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 463, in create_script_module_impl
    method_stubs = stubs_fn(nn_module)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 732, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 66, in make_stub_from_method
    return make_stub(func, method_name)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 51, in make_stub
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 264, in get_jit_def
    return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 315, in build_def
    build_stmts(ctx, body))
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 137, in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 137, in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 287, in __call__
    return method(ctx, node)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 550, in build_Return
    return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 287, in __call__
    return method(ctx, node)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 988, in build_DictComp
    raise NotSupportedError(r, "Comprehension ifs are not supported yet")
torch.jit.frontend.NotSupportedError: Comprehension ifs are not supported yet:
  File "/home/Project/YOLOX/yolox/models/darknet.py", line 179
        x = self.dark5(x)
        outputs["dark5"] = x
        return {k: v for k, v in outputs.items() if k in self.out_features}
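The failing line is a dict comprehension with an if clause, which the TorchScript frontend rejected at that version. One workaround is to rewrite the return as an explicit loop, which scripts fine. Sketched here as a standalone function (the real darknet.py code filters its stage outputs inside the model's forward; plain dicts are used here for illustration):

```python
def filter_outputs(outputs, out_features):
    # Equivalent to {k: v for k, v in outputs.items() if k in out_features},
    # but without the "comprehension if" that TorchScript rejects.
    result = {}
    for k, v in outputs.items():
        if k in out_features:
            result[k] = v
    return result

print(filter_outputs({"dark3": 3, "dark4": 4, "dark5": 5}, ("dark4", "dark5")))
# {'dark4': 4, 'dark5': 5}
```

Alternatively, tracing the model with torch.jit.trace instead of scripting sidesteps the frontend restriction entirely, since tracing records operations rather than parsing the source.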