Description
Hi, I'm using multiple CUDA streams to improve TensorRT inference latency and throughput. Here's the inference code, which I modified from the example in the TensorRT repo (`common_runtime.py`):
```python
import ctypes
import time
from typing import List, Optional, Union

import numpy as np
from cuda import cudart

# cuda_call is the error-checking helper from the TensorRT samples: it unwraps
# the (err, result) tuples returned by cudart and raises on error.


# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem:
    """Pair of host and device memory, where the host memory is wrapped in a
    numpy array. One host/device buffer pair is allocated per stream."""

    def __init__(self, size: int, dtype: Optional[np.dtype] = None, stream_num: int = 1):
        dtype = dtype or np.dtype(np.uint8)
        self._nbytes = size * dtype.itemsize
        self.stream_num = stream_num
        self._host_mem = []
        self._device_mem = []
        for _ in range(stream_num):
            # Allocate pinned host memory and wrap it in a numpy array.
            host_mem = cuda_call(cudart.cudaMallocHost(self._nbytes))
            pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))
            host_array = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,))
            self._host_mem.append(host_array)
            # Allocate the matching device buffer.
            device_mem_ptr = cuda_call(cudart.cudaMalloc(self._nbytes))
            self._device_mem.append(device_mem_ptr)

    @property
    def host(self) -> List[np.ndarray]:
        assert isinstance(self._host_mem, list)
        return self._host_mem

    @host.setter
    def host(self, data: Union[np.ndarray, bytes]):
        if isinstance(data, np.ndarray):
            if data.size > self.host[0].size:
                raise ValueError(
                    f"Tried to fit an array of size {data.size} into host memory of size {self.host[0].size}"
                )
            # Copy the same input into every stream's host buffer.
            for i in range(self.stream_num):
                np.copyto(self.host[i][:data.size], data.flat, casting='safe')
        else:
            assert self.host[0].dtype == np.uint8
            for i in range(self.stream_num):
                self.host[i][:self.nbytes] = np.frombuffer(data, dtype=np.uint8)

    @property
    def device(self) -> List[int]:
        assert isinstance(self._device_mem, list)
        return self._device_mem

    @property
    def nbytes(self) -> int:
        return self._nbytes

    def __str__(self):
        return f"Host:\n{[str(h) for h in self._host_mem]}\nDevice:\n{self._device_mem}\nSize:\n{self._nbytes}\n"

    def __repr__(self):
        return self.__str__()

    def free(self):
        """Free all allocated host and device memory."""
        for device in self._device_mem:
            cuda_call(cudart.cudaFree(device))
        for host in self._host_mem:
            cuda_call(cudart.cudaFreeHost(host.ctypes.data))
```
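For reference, the buffers and streams are set up roughly along these lines (a simplified sketch, not my exact code; `allocate_buffers_multi_stream` is just an illustrative name):

```python
import tensorrt as trt

# Sketch only: allocate one HostDeviceMem per I/O tensor, each holding
# stream_num host/device buffer pairs, plus one CUDA stream per slot.
def allocate_buffers_multi_stream(engine, stream_num):
    inputs, outputs = [], []
    streams = [cuda_call(cudart.cudaStreamCreate()) for _ in range(stream_num)]
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(name)))
        size = trt.volume(engine.get_tensor_shape(name))
        mem = HostDeviceMem(size, dtype, stream_num=stream_num)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            inputs.append(mem)
        else:
            outputs.append(mem)
    return inputs, outputs, streams
```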
And the inference function (`inputs` and `outputs` are lists of `HostDeviceMem`):

```python
def do_inference_v3_with_timing(context, bindings, inputs, outputs, streams, engine):
    num_streams = len(streams)
    total = 40
    itr = total // num_streams
    cudart.cudaDeviceSynchronize()
    eventsBefore = [cuda_call(cudart.cudaEventCreate()) for _ in range(num_streams)]
    eventsAfter = [cuda_call(cudart.cudaEventCreate()) for _ in range(num_streams)]
    tik = time.perf_counter()
    num_io = engine.num_io_tensors
    for _ in range(1):
        for i, stream in enumerate(streams):
            # Async H2D copies of all inputs on this stream.
            kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
            [cuda_call(cudart.cudaMemcpyAsync(inp.device[i], inp.host[i], inp.nbytes, kind, stream)) for inp in inputs]
            # Record an event after the copies and make the stream wait on it,
            # intending to keep inference ordered after the H2D copies.
            cudart.cudaEventRecord(eventsBefore[i], stream)
            cudart.cudaStreamWaitEvent(stream, eventsBefore[i], cudart.cudaEventWaitDefault)
            # Enqueue inference with this stream's bindings (tensor addresses).
            context.execute_async_v2(bindings=bindings[i], stream_handle=stream)
            cudart.cudaEventRecord(eventsAfter[i], stream)
            cudart.cudaStreamWaitEvent(stream, eventsAfter[i], cudart.cudaEventWaitDefault)
            # Async D2H copies of all outputs on this stream.
            kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
            [cuda_call(cudart.cudaMemcpyAsync(out.host[i], out.device[i], out.nbytes, kind, stream)) for out in outputs]
    cudart.cudaDeviceSynchronize()
    tok = time.perf_counter()
    print(f"time: {(tok - tik) * 1000}ms")
    return [out.host[0] for out in outputs]
```
The `inputs` and `outputs` are effectively per-stream buffer lists: `inputs[i].device[j]` is the device memory of the i-th input tensor allocated for stream j, and likewise for the outputs. The idea is that on every execution we pass a fresh set of bindings (tensor addresses) for that stream, as sketched below.
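For example, the per-stream bindings could be built roughly like this (a sketch, assuming the engine's binding order is all inputs followed by all outputs):

```python
# Sketch: one flat bindings list (device addresses) per stream.
bindings = [
    [int(inp.device[j]) for inp in inputs] + [int(out.device[j]) for out in outputs]
    for j in range(num_streams)
]
```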
But the inference results are wrong: each run produces different results. However, after I inserted a `cudart.cudaStreamSynchronize(stream)` right after the D2H copy loop over the outputs, the results appear to be correct. So I checked the timeline with Nsight Systems, and without the sync there are H2D memcpys running during the inference kernels (first screenshot).
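Concretely, the only change that makes the results correct is adding the synchronize after the output copies:

```python
kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
[cuda_call(cudart.cudaMemcpyAsync(out.host[i], out.device[i], out.nbytes, kind, stream)) for out in outputs]
cuda_call(cudart.cudaStreamSynchronize(stream))  # added: results become correct
```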
But with the sync in place, there is no H2D memcpy during inference (second screenshot).
I want to know why this happens, since I already inserted CUDA events to prevent it. And even if I don't insert the CUDA events at all, just the stream synchronize makes the results correct (second screenshot). From searching online, it seems I may need multiple execution contexts and MPS? Please help me if you know the best way to use multiple streams.
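For reference, the "multiple execution contexts" approach I found would look roughly like this (a sketch only; I have not verified this is the recommended pattern):

```python
# Sketch: one IExecutionContext per stream, so concurrent executions do not
# share a single context's internal state.
contexts = [engine.create_execution_context() for _ in range(num_streams)]

for i, stream in enumerate(streams):
    kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    [cuda_call(cudart.cudaMemcpyAsync(inp.device[i], inp.host[i], inp.nbytes, kind, stream)) for inp in inputs]
    contexts[i].execute_async_v2(bindings=bindings[i], stream_handle=stream)
    kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
    [cuda_call(cudart.cudaMemcpyAsync(out.host[i], out.device[i], out.nbytes, kind, stream)) for out in outputs]
```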
Environment
TensorRT Version: 8.6.2
GPU Type: A10
Nvidia Driver Version: 11.4
CUDA Version: 11.4
Operating System + Version: CentOS x86_64
Python Version (if applicable): 3.9