Why does TensorRT get stuck when using ThreadPool in Python?

Description

After referencing this draft and this draft, I wrote the code below. It gets stuck when run through a thread pool.

Can anyone help me figure out how to make it work properly?

I also want to serve my model with the Flask framework using multithreading (see the lock-based sketch after the script below).

import numpy as np
import tensorrt as trt
from cuda import cuda, cudart
import threading

def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))


def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res


class HostDeviceMem:
    """Tracks the shape/dtype of one I/O tensor and owns its device buffer."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype
        size = trt.volume(shape) * dtype.itemsize
        self.device = cuda_call(cudart.cudaMalloc(size))


class TrtModel:
    def __init__(self, model_path, device_id=0):
        err, = cuda.cuInit(0)
        check_cuda_err(err)
        device = cuda_call(cuda.cuDeviceGet(device_id))
        self.ctx = cuda_call(cuda.cuCtxCreate(cuda.CUctx_flags.CU_CTX_SCHED_YIELD, device))
        self.logger = trt.Logger(trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(self.logger, namespace="")

        with open(model_path, 'rb') as f, trt.Runtime(self.logger) as runtime:
            assert runtime, 'Can not create TensorRT Runtime'
            self.engine = runtime.deserialize_cuda_engine(f.read())

        assert self.engine, 'Can not load engine file'

        self.context = self.engine.create_execution_context()
        assert self.context, 'Can not create execution context'

        self.inputs, self.outputs, self.bindings, self.stream, self.max_batch_size = self.allocate_buffers(self.engine)

    @property
    def input_names(self):
        return list(self.inputs.keys())

    @property
    def output_names(self):
        return list(self.outputs.keys())

    def allocate_buffers(self, engine: trt.ICudaEngine):
        inputs = {}
        outputs = {}
        bindings = []
        stream = cuda_call(cudart.cudaStreamCreate())
        max_batch_size = 1
        for i in range(engine.num_io_tensors):
            name = engine.get_tensor_name(i)
            is_input = engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
            dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(name)))
            shape = self.context.get_tensor_shape(name)
            if is_input and shape[0] < 0:
                assert engine.num_optimization_profiles > 0, \
                    'Engine has dynamic axes but no optimization profiles exist'
                profile_shape = engine.get_tensor_profile_shape(name, 0)
                assert len(profile_shape) == 3  # min,opt,max
                # Set the *max* profile as binding shape
                self.context.set_input_shape(name, profile_shape[2])
                shape = self.context.get_tensor_shape(name)

            if is_input:
                max_batch_size = shape[0]

            # Allocate host and device buffers
            binding_memory = HostDeviceMem(shape, dtype)

            # Append the device buffer to device bindings.
            bindings.append(int(binding_memory.device))

            # Append to the appropriate list.
            if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                inputs[name] = binding_memory
            else:
                outputs[name] = binding_memory
        return inputs, outputs, bindings, stream, max_batch_size

    def _inference(self, output_names, input_feed: dict):
        kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
        batch_size = self.max_batch_size
        for name, value in self.inputs.items():  # make sure we set all inputs required
            feed = input_feed[name]
            value.host = feed
            assert feed.dtype == value.dtype, 'Wrong dtype, {} expected'.format(value.dtype)
            feed = np.ascontiguousarray(feed)
            self.context.set_input_shape(name, feed.shape)
            batch_size = min(batch_size, feed.shape[0])
            num_bytes = trt.volume(feed.shape) * value.dtype.itemsize
            cuda_call(cudart.cudaMemcpyAsync(value.device, feed, num_bytes, kind, self.stream))
        self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream)
        kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost

        results = []
        for key in output_names:
            out = self.outputs[key]
            shape = self.context.get_tensor_shape(key)
            out_batch_size = min(shape[0], batch_size)
            shape = out_batch_size, *shape[1:]
            num_bytes = trt.volume(shape) * out.dtype.itemsize
            host = np.zeros(shape, dtype=out.dtype)
            cuda_call(cudart.cudaMemcpyAsync(host, out.device, num_bytes, kind, self.stream))
            results.append(host)
        cuda_call(cudart.cudaStreamSynchronize(self.stream))
    
        return results

    def _chunk(self, input_feed: dict):
        batch_size = next(iter(input_feed.values())).shape[0]
        results = {}
        for key, value in input_feed.items():  # we don't check input batch size consistency
            chunks = [value[i: i + self.max_batch_size] for i in range(0, batch_size, self.max_batch_size)]
            results[key] = chunks

        # dict of list to list of dict
        return [dict(zip(results, t)) for t in zip(*results.values())]

    def run(self, output_names, input_feed: dict):
        """Chunk the input by max batch size and run inference sequentially."""
        # Make this model's CUDA context current on the calling thread,
        # and always pop it again so the thread's context stack stays balanced.
        cuda_call(cuda.cuCtxPushCurrent(self.ctx))
        try:
            if next(iter(input_feed.values())).shape[0] <= self.max_batch_size:
                return self._inference(output_names, input_feed)

            input_feeds = self._chunk(input_feed)
            results = [self._inference(output_names, inp) for inp in input_feeds]
            return [np.concatenate(v, axis=0) for v in zip(*results)]
        finally:
            cuda_call(cuda.cuCtxPopCurrent())

if __name__ == '__main__':
    from multiprocessing.pool import ThreadPool

    model = TrtModel('path/to/model.engine', device_id=2)
    model2 = TrtModel('another model', device_id=3)


    inputs = np.random.rand(32, 3, 112, 112).astype(np.float32)

    while True:
        with ThreadPool(5) as pool:
            res1 = [pool.apply_async(model.run, (model.output_names, {model.input_names[0]: inputs})) for _ in range(5)]
            res2 = [pool.apply_async(model2.run, (model2.output_names, {model2.input_names[0]: inputs})) for _ in range(5)]

            res = res1 + res2

            for v in res:
                v = v.get()
                print(v[0].shape)
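
On the Flask point above: a TensorRT execution context is not meant to be driven by several threads at once, so one common pattern is to keep a single context per model and serialize access to it with a lock. The sketch below only illustrates that pattern and is not a confirmed fix for the hang; the Flask route, port, input shape, and engine path are placeholders, and it reuses the TrtModel class from the script above.

# Hedged sketch: serialize access to each TrtModel with a lock so its single
# execution context, stream, and buffers are never touched by two threads at once.
# The route name, port, engine path, and input shape are assumptions for illustration.
import threading

import numpy as np
from flask import Flask, jsonify, request

app = Flask(__name__)

model = TrtModel('path/to/model.engine', device_id=0)   # hypothetical engine path
model_lock = threading.Lock()                           # one lock per TrtModel instance


@app.route('/infer', methods=['POST'])
def infer():
    # Assume the client posts a flat float32 list; reshape to the engine's input layout.
    data = np.asarray(request.get_json()['data'], dtype=np.float32).reshape(-1, 3, 112, 112)
    with model_lock:
        # Only one request thread uses the execution context at a time.
        outputs = model.run(model.output_names, {model.input_names[0]: data})
    return jsonify([out.tolist() for out in outputs])


if __name__ == '__main__':
    # threaded=True lets Flask serve requests concurrently; the lock keeps TensorRT calls serialized.
    app.run(host='0.0.0.0', port=8000, threaded=True)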



Environment

TensorRT Version: 8.6.1
GPU Type: A6000
Nvidia Driver Version:
CUDA Version: V11.2.152
CUDNN Version:
Operating System + Version: Ubuntu 20.04
Python Version (if applicable): 3.8.13
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag): Container

Relevant Files

Please attach or include links to any models, data, files, or scripts necessary to reproduce your issue. (Github repo, Google Drive, Dropbox, etc.)

Steps To Reproduce

Just run the script above with your own engine files to reproduce the issue.

Hi,

The link below might be useful for you:

https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html

For multi-threading/streaming, we suggest using DeepStream or Triton.

For more details, we recommend raising the query on the DeepStream forum

or

in the Triton Inference Server GitHub issues section.
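
If you do move to Triton, its Python HTTP client can drive a deployed model from many threads while the server owns the GPU contexts. A rough sketch, assuming a model already deployed on a local Triton server; the model name ("my_trt_model"), tensor names ("input"/"output"), URL, and shape are placeholders that must match your model's config.pbtxt:

# Hedged sketch of a Triton HTTP client request; all names, the URL, and the shape
# are placeholders and must match the deployed model's configuration.
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

data = np.random.rand(32, 3, 112, 112).astype(np.float32)

infer_input = httpclient.InferInput("input", list(data.shape), "FP32")
infer_input.set_data_from_numpy(data)

result = client.infer(model_name="my_trt_model", inputs=[infer_input])
print(result.as_numpy("output").shape)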

Thanks!