Hello,
I am using two CUDA streams to serve ML inference, stream 16 and stream 20, with the latter assigned a higher priority. However, even though the kernel in stream 20 finishes execution earlier, its memory copy operations are delayed for a while, and I cannot work out why. The trace report and code are attached.
Thanks in advance.
report.zip (97.4 KB)
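For reference, here is a minimal standalone sketch (separate from the attached script; device 0 and the variable names are only illustrative) of how the two stream priorities are set up and how the allowed priority range can be queried with cuCtxGetStreamPriorityRange. In CUDA, numerically lower values mean higher priority:

from cuda import cuda

# Illustrative sketch: initialize the driver API and create a context on device 0 (assumed)
(err,) = cuda.cuInit(0)
err, device = cuda.cuDeviceGet(0)
err, ctx = cuda.cuCtxCreate(0, device)

# Query the allowed stream priority range for this context
err, least_priority, greatest_priority = cuda.cuCtxGetStreamPriorityRange()
print(f"priority range: least = {least_priority}, greatest = {greatest_priority}")

# Create two streams with flag 0 (CU_STREAM_DEFAULT); -2 is a higher priority than -1
err, stream_normal = cuda.cuStreamCreateWithPriority(0, -1)
err, stream_high = cuda.cuStreamCreateWithPriority(0, -2)
print(f"stream {cuda.cuStreamGetId(stream_normal)[1]} priority: {cuda.cuStreamGetPriority(stream_normal)[1]}")
print(f"stream {cuda.cuStreamGetId(stream_high)[1]} priority: {cuda.cuStreamGetPriority(stream_high)[1]}")

The full script is reproduced below.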
import numpy as np
from cuda import cuda
import tensorrt as trt
import time
import threading
import os
import re
import ctypes
# Use os.getcwd() to get the current working directory
current_dir = os.getcwd()
# Regular expression to match any content before "strait"
pattern = r'(.*?)/strait'
# Using search() to find the match
match = re.search(pattern, current_dir)
# Directory of Project
project_dir = match.group(1)
# Initialize TensorRT logger
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
# Create a runtime object
runtime = trt.Runtime(TRT_LOGGER)
class tensorrt_context:
    def __init__(self, trt_model, priority):
        try:
            with open(trt_model, "rb") as f:
                deserialized_engine = f.read()
        except FileNotFoundError:
            print(f"Error: File {trt_model} not found")
            exit(1)
        self.engine = runtime.deserialize_cuda_engine(deserialized_engine)
        if self.engine is None:
            print("Error: Failed to deserialize the engine")
            exit(1)
        # Create an execution context
        self.context = self.engine.create_execution_context()
        if self.context is None:
            print("Error: Failed to create execution context")
            exit(1)
        # Input tensor name
        self.input_tensor_name = self.engine.get_tensor_name(0)
        # Output tensor name
        self.output_tensor_name = self.engine.get_tensor_name(1)
        # Minimum batch size for this engine model
        self.minimum_batch_size = self.engine.get_tensor_profile_shape(self.input_tensor_name, 0)[0][0]
        # Maximum batch size for this engine model
        self.max_batch_size = self.engine.get_tensor_profile_shape(self.input_tensor_name, 0)[-1][0]
        # Input data shape
        self.input_shape = tuple(self.engine.get_tensor_shape(self.input_tensor_name))[1:]
        # Output data shape
        self.output_shape = tuple(self.engine.get_tensor_shape(self.output_tensor_name))[1:]
        # Maximum batch input shape
        self.max_input_shape = (self.max_batch_size,) + self.input_shape
        # Maximum batch output shape
        self.max_output_shape = (self.max_batch_size,) + self.output_shape
        # Initialize a CUDA stream
        self.init_cuda_stream(priority)
        self.device_memory_allocation()
        self.host_memory_allocation()
        self.warm_up()
        self.test_data()

    def init_cuda_stream(self, priority):
        # Create a CUDA stream with the given priority (lower value = higher priority)
        err_stream, self.stream = cuda.cuStreamCreateWithPriority(0, priority)
        # cuda.cuCtxGetStreamPriorityRange()
        print(f"stream ID: {cuda.cuStreamGetId(self.stream)[1]}, priority: {cuda.cuStreamGetPriority(self.stream)[1]}")

    def engine_data_type(self):
        # Map the engine's input tensor dtype to its byte size and a readable label
        if self.engine.get_tensor_dtype(self.input_tensor_name) == trt.DataType.FLOAT:
            self.data_byte = np.float32().itemsize
            return "FP32"
        elif self.engine.get_tensor_dtype(self.input_tensor_name) == trt.DataType.HALF:
            self.data_byte = np.float16().itemsize
            return "FP16"
        elif self.engine.get_tensor_dtype(self.input_tensor_name) == trt.DataType.INT32:
            self.data_byte = np.int32().itemsize
            return "INT32"
        elif self.engine.get_tensor_dtype(self.input_tensor_name) == trt.DataType.INT8:
            self.data_byte = np.int8().itemsize
            return "INT8"

    def data_generation(self, shape, empty=False):
        # Generate random input data (or an empty buffer) matching the engine dtype
        engine_data_type = self.engine_data_type()
        if engine_data_type == "FP32":
            if not empty:
                return np.random.random(shape).astype(np.float32)
            return np.empty(shape, dtype=np.float32)

    def host_memory_allocation(self):
        # Allocate page-locked (pinned) host buffers sized for the largest batch
        Input_Byte_Size = int(np.prod(self.max_input_shape) * np.float32().itemsize)
        Output_Byte_Size = int(np.prod(self.max_output_shape) * np.float32().itemsize)
        err_hi, self.h_input = cuda.cuMemHostAlloc(Input_Byte_Size, cuda.CU_MEMHOSTALLOC_WRITECOMBINED)
        err_ho, self.h_output = cuda.cuMemHostAlloc(Output_Byte_Size, cuda.CU_MEMHOSTALLOC_DEVICEMAP)
        if err_hi != cuda.CUresult(0) or err_ho != cuda.CUresult(0):
            print("Page-locked host memory allocation failed")
            exit(1)

    def device_memory_allocation(self):
        # Allocate device buffers asynchronously on this context's stream
        Input_Byte_Size = int(np.prod(self.max_input_shape) * np.float32().itemsize)
        Output_Byte_Size = int(np.prod(self.max_output_shape) * np.float32().itemsize)
        err_di, self.d_input = cuda.cuMemAllocAsync(Input_Byte_Size, self.stream)
        err_do, self.d_output = cuda.cuMemAllocAsync(Output_Byte_Size, self.stream)
        if err_di != cuda.CUresult(0) or err_do != cuda.CUresult(0):
            print("Device memory allocation failed")
            exit(1)
        # Bind the device buffers to the execution context's I/O tensors
        self.context.set_tensor_address(self.input_tensor_name, int(self.d_input))
        self.context.set_tensor_address(self.output_tensor_name, int(self.d_output))

    def cuda_stream_execution(self, h_input, d_input, h_output, d_output, Input_Byte_Size, Output_Byte_Size, stream):
        # Transfer input data to the device asynchronously
        cuda.cuMemcpyHtoDAsync(d_input, h_input, Input_Byte_Size, stream)
        # Execute the inference asynchronously
        self.context.execute_async_v3(stream_handle=stream)
        #t3 = time.perf_counter_ns()
        #print(f"t3 - t2: {round(float((t3 - t2) / 10**6), 1)}")
        # Transfer predictions back from the GPU asynchronously
        cuda.cuMemcpyDtoHAsync(h_output, d_output, Output_Byte_Size, stream)
        # Block until all work queued on this stream has finished
        cuda.cuStreamSynchronize(stream)

    def warm_up(self):
        # Run one inference at the minimum batch size to warm up the engine
        minimum_input_shape = (self.minimum_batch_size,) + self.input_shape
        minimum_output_shape = (self.minimum_batch_size,) + self.output_shape
        h_input = self.data_generation(minimum_input_shape)
        # = self.data_generation(minimum_output_shape, True)
        ctypes.memmove(self.h_input, h_input.ctypes.data, h_input.nbytes)
        # Setting the context binding
        self.context.set_input_shape(self.input_tensor_name, minimum_input_shape)
        # Transfer input data to device asynchronously
        Input_Byte_Size = int(np.prod(minimum_input_shape) * self.data_byte)
        Output_Byte_Size = int(np.prod(minimum_output_shape) * self.data_byte)
        self.cuda_stream_execution(self.h_input,
                                   self.d_input,
                                   self.h_output,
                                   self.d_output,
                                   Input_Byte_Size,
                                   Output_Byte_Size,
                                   self.stream)
        read_output = np.empty(minimum_output_shape, dtype=np.float32)
        ctypes.memmove(read_output.ctypes.data, self.h_output, Output_Byte_Size)

    def test(self):
        # Run one inference on the prepared test batch
        Input_Byte_Size = int(np.prod(self.test_input_data_shape) * self.data_byte)
        Output_Byte_Size = int(np.prod(self.test_output_data_shape) * self.data_byte)
        self.context.set_input_shape(self.input_tensor_name, self.test_input_data_shape)
        self.cuda_stream_execution(self.h_input,
                                   self.d_input,
                                   self.h_output,
                                   self.d_output,
                                   Input_Byte_Size,
                                   Output_Byte_Size,
                                   self.stream)
        read_output = np.empty(self.test_output_data_shape, dtype=np.float32)
        ctypes.memmove(read_output.ctypes.data, self.h_output, Output_Byte_Size)

    def test_data(self):
        # Prepare a fixed test batch and copy it into the pinned input buffer
        batch_size = 8
        self.test_input_data_shape = (batch_size,) + self.input_shape
        self.test_output_data_shape = (batch_size,) + self.output_shape
        h_input = self.data_generation(self.test_input_data_shape)
        ctypes.memmove(self.h_input, h_input.ctypes.data, h_input.nbytes)

model_path = project_dir + '/strait/model/trt/'
start_timing = time.perf_counter_ns()
context1 = tensorrt_context(model_path + "ResNet-50.trt", -1)
end_timing = time.perf_counter_ns()
response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
print(f"response latency: {response_latency} ms")
start_timing = time.perf_counter_ns()
context2 = tensorrt_context(model_path + "ResNet-50.trt", -2)
end_timing = time.perf_counter_ns()
response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
print(f"response latency: {response_latency} ms")
start_timing = time.perf_counter_ns()
context1.test()
end_timing = time.perf_counter_ns()
response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
print(f"response latency: {response_latency} ms")
start_timing = time.perf_counter_ns()
context2.test()
end_timing = time.perf_counter_ns()
response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
print(f"response latency: {response_latency} ms")

def run_test(context, context_id):
    start_timing = time.perf_counter_ns()
    context.test()
    end_timing = time.perf_counter_ns()
    response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
    print(f"Execution time for context{context_id}: {response_latency} ms")

cuda.cuProfilerStart()
# Create and start two threads
thread1 = threading.Thread(target=run_test, args=(context1, 1))
thread2 = threading.Thread(target=run_test, args=(context2, 2))
start_timing = time.perf_counter_ns()
thread1.start()
thread2.start()
# Wait for both threads to finish
thread1.join()
thread2.join()
end_timing = time.perf_counter_ns()
response_latency = round(float((end_timing - start_timing) / 10**6), 1) # Express latency in milliseconds
print(f"Total execution time for both tests: {response_latency} ms")