Description
Hello,
I plan to run two execution contexts of a single engine concurrently via CUDA streams. Each context creates its own CUDA stream; however, based on the trace, the execution does not behave as I expected. In addition, I assigned stream 20 the higher priority, yet you can see in the trace that kernels in this stream are blocked for a while. I suspect the device-to-host (D-to-H) memory copy blocks the other thread; how can I solve this?
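For what it's worth, the host buffers in my script are plain (pageable) numpy arrays. If pinned (page-locked) host memory is what is needed for the async copies to actually overlap, I assume it would look roughly like this (a minimal sketch using cuMemHostAlloc from the same cuda-python bindings; alloc_pinned_array is just an illustrative helper, not part of my script):
import ctypes
import numpy as np
from cuda import cuda

def alloc_pinned_array(shape, dtype=np.float32):
    # Allocate page-locked host memory so cuMemcpyHtoDAsync / cuMemcpyDtoHAsync
    # can overlap with work in other streams (copies from pageable numpy
    # buffers may not be truly asynchronous).
    # Assumes a CUDA context is already current (it is after the engine is deserialized).
    nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
    err, ptr = cuda.cuMemHostAlloc(nbytes, 0)  # 0 = default flags
    if err != cuda.CUresult(0):
        raise RuntimeError("cuMemHostAlloc failed")
    buf = (ctypes.c_byte * nbytes).from_address(int(ptr))
    # View the pinned buffer as a numpy array; release it later with cuda.cuMemFreeHost(ptr)
    return np.frombuffer(buf, dtype=dtype).reshape(shape), ptr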
My code:
import numpy as np
from cuda import cuda
import tensorrt as trt
import time
import threading
import os
import re
# Use os.getcwd() to get the current working directory
current_dir = os.getcwd()
# Regular expression to match any content before "strait"
pattern = r'(.*?)/strait'
# Using search() to find the match
match = re.search(pattern, current_dir)
# Directory of Project
project_dir = match.group(1)
# Initialize TensorRT logger
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
# Create a runtime object
runtime = trt.Runtime(TRT_LOGGER)
class tensorrt_context:
    def __init__(self, trt_model, priority):
        try:
            with open(trt_model, "rb") as f:
                deserialized_engine = f.read()
        except FileNotFoundError:
            print(f"Error: File {trt_model} not found")
            exit(1)
        self.engine = runtime.deserialize_cuda_engine(deserialized_engine)
        if self.engine is None:
            print("Error: Failed to deserialize the engine")
            exit(1)
        # Create an execution context
        self.context = self.engine.create_execution_context()
        if self.context is None:
            print("Error: Failed to create execution context")
            exit(1)
        # Input tensor name
        self.input_tensor_name = self.engine.get_tensor_name(0)
        # Output tensor name
        self.output_tensor_name = self.engine.get_tensor_name(1)
        # Minimum batch size for this engine model
        self.minimum_batch_size = self.engine.get_tensor_profile_shape(self.input_tensor_name, 0)[0][0]
        # Maximum batch size for this engine model
        self.max_batch_size = self.engine.get_tensor_profile_shape(self.input_tensor_name, 0)[-1][0]
        # Input data shape
        self.input_shape = tuple(self.engine.get_tensor_shape(self.input_tensor_name))[1:]
        # Output data shape
        self.output_shape = tuple(self.engine.get_tensor_shape(self.output_tensor_name))[1:]
        # Maximum batch input
        self.max_input_shape = (self.max_batch_size,) + self.input_shape
        # Maximum batch output
        self.max_output_shape = (self.max_batch_size,) + self.output_shape
        # Initialize a CUDA stream
        self.init_cuda_stream(priority)
        self.device_memory_allocation()
        self.warm_up()
        self.test_data()
    def init_cuda_stream(self, priority):
        # Create a CUDA stream with the requested priority
        err_stream, self.stream = cuda.cuStreamCreateWithPriority(0, priority)
        # cuda.cuCtxGetStreamPriorityRange()
        print(f"stream ID: {cuda.cuStreamGetId(self.stream)[1]}, priority: {cuda.cuStreamGetPriority(self.stream)[1]}")
    def engine_data_type(self):
        if self.engine.get_tensor_dtype(self.input_tensor_name) == trt.DataType.FLOAT:
            self.data_byte = np.float32().itemsize
            return "FP32"
        elif self.engine.get_tensor_dtype(self.input_tensor_name) == trt.DataType.HALF:
            self.data_byte = np.float16().itemsize
            return "FP16"
        elif self.engine.get_tensor_dtype(self.input_tensor_name) == trt.DataType.INT32:
            self.data_byte = np.int32().itemsize
            return "INT32"
        elif self.engine.get_tensor_dtype(self.input_tensor_name) == trt.DataType.INT8:
            self.data_byte = np.int8().itemsize
            return "INT8"
    def data_generation(self, shape, empty=False):
        engine_data_type = self.engine_data_type()
        if engine_data_type == "FP32":
            if not empty:
                return np.random.random(shape).astype(np.float32)
            return np.empty(shape, dtype=np.float32)
    def device_memory_allocation(self):
        Input_Byte_Size = int(np.prod(self.max_input_shape) * np.float32().itemsize)
        Output_Byte_Size = int(np.prod(self.max_output_shape) * np.float32().itemsize)
        err_di, self.d_input = cuda.cuMemAllocAsync(Input_Byte_Size, self.stream)
        err_do, self.d_output = cuda.cuMemAllocAsync(Output_Byte_Size, self.stream)
        if err_di != cuda.CUresult(0) or err_do != cuda.CUresult(0):
            print("Device memory allocation failed")
            exit(1)
        self.context.set_tensor_address(self.input_tensor_name, int(self.d_input))
        self.context.set_tensor_address(self.output_tensor_name, int(self.d_output))
    def cuda_stream_execution(self, h_input, d_input, h_output, d_output, Input_Byte_Size, Output_Byte_Size, stream):
        t1 = time.perf_counter_ns()
        cuda.cuMemcpyHtoDAsync(d_input, h_input.ctypes.data, Input_Byte_Size, stream)
        t2 = time.perf_counter_ns()
        print(f"t2-t1: {round(float((t2 - t1) / 10**6), 1)} ms")
        # Execute the inference asynchronously
        self.context.execute_async_v3(stream_handle=stream)
        t3 = time.perf_counter_ns()
        print(f"t3-t2: {round(float((t3 - t2) / 10**6), 1)} ms")
        # Transfer predictions back from the GPU asynchronously
        cuda.cuMemcpyDtoHAsync(h_output.ctypes.data, d_output, Output_Byte_Size, stream)
        cuda.cuStreamSynchronize(stream)
        t4 = time.perf_counter_ns()
        print(f"t4-t3: {round(float((t4 - t3) / 10**6), 1)} ms")
    def warm_up(self):
        minimum_input_shape = (self.minimum_batch_size,) + self.input_shape
        minimum_output_shape = (self.minimum_batch_size,) + self.output_shape
        warm_up_input = self.data_generation(minimum_input_shape)
        warm_up_output = self.data_generation(minimum_output_shape, True)
        # Setting the context binding
        self.context.set_input_shape(self.input_tensor_name, minimum_input_shape)
        # Transfer input data to device asynchronously
        Input_Byte_Size = int(np.prod(minimum_input_shape) * self.data_byte)
        Output_Byte_Size = int(np.prod(minimum_output_shape) * self.data_byte)
        self.cuda_stream_execution(warm_up_input,
                                   self.d_input,
                                   warm_up_output,
                                   self.d_output,
                                   Input_Byte_Size,
                                   Output_Byte_Size,
                                   self.stream)
    def test(self):
        Input_Byte_Size = int(np.prod(self.test_input_data_shape) * self.data_byte)
        Output_Byte_Size = int(np.prod(self.test_output_data_shape) * self.data_byte)
        self.context.set_input_shape(self.input_tensor_name, self.test_input_data_shape)
        self.cuda_stream_execution(self.test_h_input,
                                   self.d_input,
                                   self.test_h_output,
                                   self.d_output,
                                   Input_Byte_Size,
                                   Output_Byte_Size,
                                   self.stream)
    def test_data(self):
        batch_size = 8
        self.test_input_data_shape = (batch_size,) + self.input_shape
        self.test_output_data_shape = (batch_size,) + self.output_shape
        self.test_h_input = self.data_generation(self.test_input_data_shape)
        self.test_h_output = self.data_generation(self.test_output_data_shape, True)
model_path = project_dir + '/strait/model/trt/'
start_timing = time.perf_counter_ns()
context1 = tensorrt_context(model_path + "resnet.trt", -1)
end_timing = time.perf_counter_ns()
response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
print(f"response latency: {response_latency} ms")
start_timing = time.perf_counter_ns()
context2 = tensorrt_context(model_path + "resnet.trt", -2)
end_timing = time.perf_counter_ns()
response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
print(f"response latency: {response_latency} ms")
start_timing = time.perf_counter_ns()
context1.test()
end_timing = time.perf_counter_ns()
response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
print(f"response latency: {response_latency} ms")
start_timing = time.perf_counter_ns()
context2.test()
end_timing = time.perf_counter_ns()
response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
print(f"response latency: {response_latency} ms")
def run_test(context, context_id):
    start_timing = time.perf_counter_ns()
    context.test()
    end_timing = time.perf_counter_ns()
    response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
    print(f"Execution time for context{context_id}: {response_latency} ms")
cuda.cuProfilerStart()
# Create and start two threads
thread1 = threading.Thread(target=run_test, args=(context1, 1))
thread2 = threading.Thread(target=run_test, args=(context2, 2))
start_timing = time.perf_counter_ns()
thread1.start()
thread2.start()
# Wait for both threads to finish
thread1.join()
thread2.join()
end_timing = time.perf_counter_ns()
response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
print(f"Total execution time for both tests: {response_latency} ms")
Environment
TensorRT Version: 10
GPU Type: RTX 3500 Ada
Nvidia Driver Version:
CUDA Version:
Operating System + Version: Ubuntu
Relevant Files
Steps To Reproduce
test_report.zip (100.3 KB)