Description
I want to run inference with a TensorRT engine on PyTorch GPU tensors. However, with the code below, if I create the input tensor after I have created my execution context, I get the following error:
import tensorrt as trt
import torch
import pycuda.driver as cuda
import pycuda.autoinit

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

with open("model.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime, runtime.deserialize_cuda_engine(f.read()) as engine, engine.create_execution_context() as context:
    output_buffer = cuda.mem_alloc(4 * 288 * 768 * 4)
    stream = cuda.Stream()
    for i in range(1):
        tensor = torch.randn((4, 288, 768, 4), dtype=torch.float32, device=torch.device('cuda'))
        context.execute_async_v2(bindings=[int(tensor.data_ptr()), int(output_buffer)],
                                 stream_handle=stream.handle)
        stream.synchronize()
[TensorRT] ERROR: …/rtExt/cuda/cudaGatherRunner.cpp (111) - Cuda Error in execute: 400 (invalid resource handle)
[TensorRT] ERROR: FAILED_EXECUTION: std::exception
If I create the tensor before I create the execution context, there are no errors.
import tensorrt as trt
import torch
import pycuda.driver as cuda
import pycuda.autoinit

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

tensor = torch.randn((4, 288, 768, 4), dtype=torch.float32, device=torch.device('cuda'))

with open("model.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime, runtime.deserialize_cuda_engine(f.read()) as engine, engine.create_execution_context() as context:
    output_buffer = cuda.mem_alloc(4 * 288 * 768 * 4)
    stream = cuda.Stream()
    for i in range(1):
        context.execute_async_v2(bindings=[int(tensor.data_ptr()), int(output_buffer)],
                                 stream_handle=stream.handle)
        stream.synchronize()
Is there any way to create a TRT engine and then perform inference on PyTorch tensors that are created after the execution context? I assume this has to do with CUDA contexts?
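For reference, a minimal sketch of one possible workaround, under the assumption that the failure comes from `pycuda.autoinit` creating its own CUDA context while PyTorch allocates tensors in the device's primary context: attaching pycuda to the primary context via `pycuda.driver.Device.retain_primary_context()` should put the TensorRT execution context and the PyTorch tensors in the same CUDA context. This is untested against this exact engine; `model.engine`, the device index `0`, and the tensor shape are simply carried over from the snippets above.

```python
import tensorrt as trt
import torch
import pycuda.driver as cuda  # note: no pycuda.autoinit

# Share the device's primary CUDA context (the one PyTorch uses)
# instead of letting pycuda.autoinit create a separate context.
cuda.init()
primary_ctx = cuda.Device(0).retain_primary_context()
primary_ctx.push()

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

with open("model.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime, \
        runtime.deserialize_cuda_engine(f.read()) as engine, \
        engine.create_execution_context() as context:
    output_buffer = cuda.mem_alloc(4 * 288 * 768 * 4)
    stream = cuda.Stream()
    # Tensor created *after* the execution context, but in the same
    # CUDA context, so the bindings should be valid.
    tensor = torch.randn((4, 288, 768, 4), dtype=torch.float32,
                         device=torch.device('cuda'))
    context.execute_async_v2(bindings=[int(tensor.data_ptr()), int(output_buffer)],
                             stream_handle=stream.handle)
    stream.synchronize()

primary_ctx.pop()
```

Recent pycuda versions also ship a `pycuda.autoprimaryctx` module that does the retain/push automatically, which may be a drop-in replacement for `pycuda.autoinit` here.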
Environment
TensorRT Version: 7.2
GPU Type: Quadro RTX 3000
Nvidia Driver Version: 460.56
CUDA Version: 11.1
CUDNN Version:
Operating System + Version:
Python Version: 3.6
TensorFlow Version (if applicable):
PyTorch Version: 1.8