TensorRT fails to exit properly

I’m trying to run TensorRT inference from multiple threads by modifying this example so that 2 (or more) threads run at the same time. The code runs fine and prints correct results, but the error

Segmentation fault (core dumped)

always occurs when the program exits. Is there any way to fix this? Thank you in advance.

Here is my modified code:
my_tensorrt_code.py:

from PIL import Image
import numpy as np
import tensorrt as trt
import pycuda.autoinit
import pycuda.driver as cuda
import threading
import time
import math


class TRTInference:
    def __init__(self, trt_engine_path, trt_engine_datatype, batch_size):
        self.cfx = cuda.Device(0).make_context()
        stream = cuda.Stream()

        TRT_LOGGER = trt.Logger(trt.Logger.INFO)
        trt.init_libnvinfer_plugins(TRT_LOGGER, '')
        runtime = trt.Runtime(TRT_LOGGER)

        # deserialize engine
        with open(trt_engine_path, 'rb') as f:
            buf = f.read()
            engine = runtime.deserialize_cuda_engine(buf)
        context = engine.create_execution_context()

        # prepare buffer
        host_inputs = []
        cuda_inputs = []
        host_outputs = []
        cuda_outputs = []
        bindings = []

        for binding in engine:
            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
            host_mem = cuda.pagelocked_empty(size, np.float32)
            cuda_mem = cuda.mem_alloc(host_mem.nbytes)

            bindings.append(int(cuda_mem))
            if engine.binding_is_input(binding):
                host_inputs.append(host_mem)
                cuda_inputs.append(cuda_mem)
            else:
                host_outputs.append(host_mem)
                cuda_outputs.append(cuda_mem)

        # store
        self.stream = stream
        self.context = context
        self.engine = engine

        self.host_inputs = host_inputs
        self.cuda_inputs = cuda_inputs
        self.host_outputs = host_outputs
        self.cuda_outputs = cuda_outputs
        self.bindings = bindings

    def infer(self, input_img_path):
        self.cfx.push()

        # restore
        stream = self.stream
        context = self.context
        engine = self.engine

        host_inputs = self.host_inputs
        cuda_inputs = self.cuda_inputs
        host_outputs = self.host_outputs
        cuda_outputs = self.cuda_outputs
        bindings = self.bindings

        # read image (np.float was removed in NumPy 1.24; the input buffer is float32 anyway)
        image = 1 - (np.asarray(Image.open(input_img_path), dtype=np.float32) / 255)
        np.copyto(host_inputs[0], image.ravel())

        # inference
        start_time = time.time()
        cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
        context.execute_async(bindings=bindings, stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
        stream.synchronize()
        print("execute times " + str(time.time() - start_time))

        # parse output
        output = np.array([math.exp(o) for o in host_outputs[0]])
        output /= sum(output)
        for i in range(len(output)): print("%d: %.2f" % (i, output[i]))

        self.cfx.pop()

    def destory(self):
        self.cfx.pop()

test.py:

import threading
import time
from my_tensorrt_code import TRTInference, trt

exitFlag = 0


class myThread(threading.Thread):
    def __init__(self, func, args):
        threading.Thread.__init__(self)
        self.func = func
        self.args = args

    def run(self):
        print("Starting " + self.args[0])
        self.func(*self.args)
        print("Exiting " + self.args[0])


if __name__ == '__main__':
    # Create new threads
    '''
    Thread format:
        - func: the function to run in the thread
        - args: the list of arguments passed to func
    '''

    trt_engine_path = 'mnist.trt'

    max_batch_size = 1
    trt_inference_wrapper1 = TRTInference(trt_engine_path,
                                          trt_engine_datatype=trt.DataType.FLOAT,
                                          batch_size=max_batch_size)
    trt_inference_wrapper2 = TRTInference(trt_engine_path,
                                          trt_engine_datatype=trt.DataType.FLOAT,
                                          batch_size=max_batch_size)

    # Input images (MNIST digits), one per thread
    input_img_path1 = 'pgms/3.pgm'
    input_img_path2 = 'pgms/1.pgm'

    thread1 = myThread(trt_inference_wrapper1.infer, [input_img_path1])
    thread2 = myThread(trt_inference_wrapper2.infer, [input_img_path2])

    # Start new Threads
    thread1.start()
    thread2.start()
    thread1.join()
    thread2.join()
    trt_inference_wrapper1.destory()
    trt_inference_wrapper2.destory()
    print("Exiting Main Thread")
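The myThread subclass in test.py re-implements what threading.Thread already provides: its constructor accepts target= and args=, and the default run() calls target(*args). Below is a minimal, stdlib-only sketch of the same start/join pattern without a subclass. run_in_threads and infer_stub are hypothetical names; infer_stub stands in for TRTInference.infer so the snippet runs without a GPU.

```python
import threading
from typing import Callable, Dict, List, Sequence, Tuple

Job = Tuple[Callable[[str], object], str]


def run_in_threads(jobs: Sequence[Job]) -> Dict[str, object]:
    """Run each (func, arg) pair in its own thread and collect func(arg) keyed by arg."""
    results: Dict[str, object] = {}
    lock = threading.Lock()

    def worker(func: Callable[[str], object], arg: str) -> None:
        print("Starting " + arg)
        value = func(arg)
        with lock:  # explicit lock so correctness does not rely on CPython's GIL
            results[arg] = value
        print("Exiting " + arg)

    # threading.Thread(target=..., args=...) replaces the custom myThread subclass.
    threads: List[threading.Thread] = [
        threading.Thread(target=worker, args=(func, arg)) for func, arg in jobs
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


def infer_stub(path: str) -> int:
    # Hypothetical stand-in for TRTInference.infer: reads the digit from the file name.
    return int(path.rsplit("/", 1)[-1].split(".")[0])


if __name__ == "__main__":
    print(run_in_threads([(infer_stub, "pgms/3.pgm"), (infer_stub, "pgms/1.pgm")]))
```

With the real wrappers, the jobs list would be [(trt_inference_wrapper1.infer, input_img_path1), (trt_inference_wrapper2.infer, input_img_path2)]; because infer() prints instead of returning, the collected values would be None.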

Here is the output when running:

vinhtq115@Dell-G7-7588:~/PycharmProjects/TensorRT-multithreading$ python test.py 
[TensorRT] INFO: [MemUsageChange] Init CUDA: CPU +150, GPU +0, now: CPU 175, GPU 197 (MiB)
[TensorRT] INFO: Loaded engine size: 0 MB
[TensorRT] INFO: [MemUsageSnapshot] deserializeCudaEngine begin: CPU 175 MiB, GPU 197 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +230, GPU +94, now: CPU 405, GPU 293 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +185, GPU +80, now: CPU 590, GPU 373 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 590, GPU 357 (MiB)
[TensorRT] INFO: [MemUsageSnapshot] deserializeCudaEngine end: CPU 590 MiB, GPU 357 MiB
[TensorRT] INFO: [MemUsageSnapshot] ExecutionContext creation begin: CPU 590 MiB, GPU 357 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 590, GPU 365 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 590, GPU 373 (MiB)
[TensorRT] INFO: [MemUsageSnapshot] ExecutionContext creation end: CPU 590 MiB, GPU 373 MiB
[TensorRT] WARNING: The logger passed into createInferRuntime differs from one already provided for an existing builder, runtime, or refitter. TensorRT maintains only a single logger pointer at any given time, so the existing value, which can be retrieved with getLogger(), will be used instead. In order to use a new logger, first destroy all existing builder, runner or refitter objects.

[TensorRT] INFO: [MemUsageChange] Init CUDA: CPU +40, GPU +0, now: CPU 635, GPU 496 (MiB)
[TensorRT] INFO: Loaded engine size: 0 MB
[TensorRT] INFO: [MemUsageSnapshot] deserializeCudaEngine begin: CPU 635 MiB, GPU 496 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +33, GPU +48, now: CPU 669, GPU 546 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +72, GPU +82, now: CPU 741, GPU 628 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 773, GPU 656 (MiB)
[TensorRT] INFO: [MemUsageSnapshot] deserializeCudaEngine end: CPU 773 MiB, GPU 656 MiB
[TensorRT] INFO: [MemUsageSnapshot] ExecutionContext creation begin: CPU 773 MiB, GPU 656 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 773, GPU 664 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 773, GPU 672 (MiB)
[TensorRT] INFO: [MemUsageSnapshot] ExecutionContext creation end: CPU 773 MiB, GPU 672 MiB
Starting pgms/3.pgm
Starting pgms/1.pgm
execute times 0.0002894401550292969
0: 0.00
1: 0.00
2: 0.00
3: 1.00
execute times 0.0002601146697998047
4: 0.00
0: 0.00
5: 0.00
6: 0.00
7: 0.00
1: 1.00
2: 0.00
8: 0.00
3: 0.00
9: 0.00
4: 0.00
Exiting pgms/3.pgm
5: 0.00
6: 0.00
7: 0.00
8: 0.00
9: 0.00
Exiting pgms/1.pgm
Exiting Main Thread
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 951, GPU 874 (MiB)
Segmentation fault (core dumped)
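The "# parse output" step in infer() is a softmax: it applies math.exp() to each raw output and divides by the sum. That evidently works for this MNIST model, but math.exp() raises OverflowError for arguments above roughly 709, so large logits from another model would crash that loop. A minimal stdlib-only sketch of the usual numerically stable variant, which subtracts the largest logit before exponentiating; stable_softmax is a hypothetical helper, not part of TensorRT or pycuda.

```python
import math
from typing import List, Sequence


def stable_softmax(logits: Sequence[float]) -> List[float]:
    """Softmax that subtracts the largest logit before exponentiating.

    exp(x - m) / sum(exp(x_j - m)) equals exp(x) / sum(exp(x_j)) for any
    constant m; choosing m = max(x) keeps every exponent <= 0, so
    math.exp() cannot overflow.
    """
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    # Hypothetical raw outputs for a 10-class model; math.exp(1000.0) alone would overflow.
    scores = [1000.0, 0.0, -3.0] + [0.0] * 7
    for i, p in enumerate(stable_softmax(scores)):
        print("%d: %.2f" % (i, p))  # prints 0: 1.00, then 0.00 for classes 1-9
```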

Hi,
The links below might be useful for you:
https://docs.nvidia.com/deeplearning/tensorrt/best-practices/index.html#thread-safety

https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
For multithreading/multi-stream use cases, we suggest using DeepStream or Triton.
For more details, we recommend raising the question on the DeepStream or Triton forum.

Thanks!
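As we read the thread-safety section linked above, one deserialized engine can be shared between threads as long as each thread uses its own execution context. Below is a stdlib-only sketch of the bookkeeping side of that rule, using threading.local so each thread lazily creates and then reuses its own object. PerThread is a hypothetical helper, and object() is a placeholder for what would be engine.create_execution_context() in real code (with pycuda, the thread's CUDA context would also need to be current at that point). No TensorRT calls are made here.

```python
import threading
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class PerThread(Generic[T]):
    """Create one object per thread on first use, then reuse it on that thread."""

    def __init__(self, factory: Callable[[], T]) -> None:
        self._factory = factory
        self._local = threading.local()  # attributes are private to each thread

    def get(self) -> T:
        if not hasattr(self._local, "obj"):
            self._local.obj = self._factory()  # first call on this thread
        return self._local.obj


if __name__ == "__main__":
    # object() is a placeholder for engine.create_execution_context().
    contexts = PerThread(object)
    per_thread = {}

    def worker(name: str) -> None:
        ctx = contexts.get()
        assert ctx is contexts.get()  # same thread, same context
        per_thread[name] = ctx

    threads = [threading.Thread(target=worker, args=("t%d" % i,)) for i in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(len({id(c) for c in per_thread.values()}))  # 3 distinct contexts
```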

Thanks, but is it possible to fix this without relying on DeepStream or Triton? I’d like to use Python instead of C++.

Hi,

Could you please share the ONNX model and the trtexec command you used to generate the engine, so we can try to reproduce this on our end for better debugging?

Thank you.

I’m using TensorRT 8.0.1.6. The ONNX model comes from the MNIST sample. The trtexec command I used was: `./trtexec --onnx=TensorRT-8.0.1.6/data/mnist/mnist.onnx --saveEngine=mnist.trt`
model.onnx (25.8 KB)

Hi,

Could you please also share the inputs you used for testing? We are facing some issues with other test data.

Thank you.

Hi.

Here is the code and the inputs. I’m running this with TensorRT 8.0.1.6.
TensorRT-multithreading.zip (34.5 KB)

I think I have found the solution: moving TRT_LOGGER out of the class (to module level) solved the issue for me. Maybe the logger needs to stay alive even after TRTInference is deleted?
my_tensorrt_code.py:

from PIL import Image
import numpy as np
import tensorrt as trt
import pycuda.autoinit
import pycuda.driver as cuda
import threading
import time
import math


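# Module-level logger: it must outlive every Runtime, engine and execution
# context created from it, because TensorRT keeps only a pointer to it.
TRT_LOGGER = trt.Logger(trt.Logger.INFO)


class TRTInference: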
    def __init__(self, trt_engine_path, trt_engine_datatype, batch_size):
        self.cfx = cuda.Device(0).make_context()
        stream = cuda.Stream()

        trt.init_libnvinfer_plugins(TRT_LOGGER, '')
        runtime = trt.Runtime(TRT_LOGGER)

        # deserialize engine
        with open(trt_engine_path, 'rb') as f:
            buf = f.read()
            engine = runtime.deserialize_cuda_engine(buf)
        context = engine.create_execution_context()

        # prepare buffer
        host_inputs = []
        cuda_inputs = []
        host_outputs = []
        cuda_outputs = []
        bindings = []

        for binding in engine:
            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
            host_mem = cuda.pagelocked_empty(size, np.float32)
            cuda_mem = cuda.mem_alloc(host_mem.nbytes)

            bindings.append(int(cuda_mem))
            if engine.binding_is_input(binding):
                host_inputs.append(host_mem)
                cuda_inputs.append(cuda_mem)
            else:
                host_outputs.append(host_mem)
                cuda_outputs.append(cuda_mem)

        # store
        self.stream = stream
        self.context = context
        self.engine = engine

        self.host_inputs = host_inputs
        self.cuda_inputs = cuda_inputs
        self.host_outputs = host_outputs
        self.cuda_outputs = cuda_outputs
        self.bindings = bindings

    def infer(self, input_img_path):
        self.cfx.push()

        # restore
        stream = self.stream
        context = self.context
        engine = self.engine

        host_inputs = self.host_inputs
        cuda_inputs = self.cuda_inputs
        host_outputs = self.host_outputs
        cuda_outputs = self.cuda_outputs
        bindings = self.bindings

        # read image (np.float was removed in NumPy 1.24; the input buffer is float32 anyway)
        image = 1 - (np.asarray(Image.open(input_img_path), dtype=np.float32) / 255)
        np.copyto(host_inputs[0], image.ravel())

        # inference
        start_time = time.time()
        cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
        context.execute_async(bindings=bindings, stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
        stream.synchronize()
        print("execute times " + str(time.time() - start_time))

        # parse output
        output = np.array([math.exp(o) for o in host_outputs[0]])
        output /= sum(output)
        for i in range(len(output)): print("%d: %.2f" % (i, output[i]))

        self.cfx.pop()

    def destory(self):
        self.cfx.pop()

The segmentation fault no longer occurs on exit:

vinhtq115@Dell-G7-7588:~/PycharmProjects/TensorRT-multithreading$ python main.py 
[TensorRT] INFO: [MemUsageChange] Init CUDA: CPU +150, GPU +0, now: CPU 175, GPU 197 (MiB)
[TensorRT] INFO: Loaded engine size: 0 MB
[TensorRT] INFO: [MemUsageSnapshot] deserializeCudaEngine begin: CPU 175 MiB, GPU 197 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +230, GPU +94, now: CPU 405, GPU 293 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +185, GPU +80, now: CPU 590, GPU 373 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 590, GPU 357 (MiB)
[TensorRT] INFO: [MemUsageSnapshot] deserializeCudaEngine end: CPU 590 MiB, GPU 357 MiB
[TensorRT] INFO: [MemUsageSnapshot] ExecutionContext creation begin: CPU 590 MiB, GPU 357 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 590, GPU 365 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 590, GPU 373 (MiB)
[TensorRT] INFO: [MemUsageSnapshot] ExecutionContext creation end: CPU 590 MiB, GPU 373 MiB
[TensorRT] INFO: [MemUsageChange] Init CUDA: CPU +40, GPU +0, now: CPU 635, GPU 496 (MiB)
[TensorRT] INFO: Loaded engine size: 0 MB
[TensorRT] INFO: [MemUsageSnapshot] deserializeCudaEngine begin: CPU 635 MiB, GPU 496 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +33, GPU +48, now: CPU 669, GPU 546 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +72, GPU +82, now: CPU 741, GPU 628 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 773, GPU 656 (MiB)
[TensorRT] INFO: [MemUsageSnapshot] deserializeCudaEngine end: CPU 773 MiB, GPU 656 MiB
[TensorRT] INFO: [MemUsageSnapshot] ExecutionContext creation begin: CPU 773 MiB, GPU 656 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 773, GPU 664 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 773, GPU 672 (MiB)
[TensorRT] INFO: [MemUsageSnapshot] ExecutionContext creation end: CPU 773 MiB, GPU 672 MiB
Starting pgms/3.pgm
Starting pgms/1.pgm
execute times 0.00012302398681640625
0: 0.00
1: 1.00
2: 0.00
3: 0.00
execute times 0.00030493736267089844
4: 0.00
0: 0.00
5: 0.00
1: 0.00
6: 0.00
2: 0.00
7: 0.00
3: 1.00
8: 0.00
4: 0.00
9: 0.00
5: 0.00
Exiting pgms/1.pgm
6: 0.00
7: 0.00
8: 0.00
9: 0.00
Exiting pgms/3.pgm
Exiting Main Thread
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 951, GPU 874 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 768, GPU 575 (MiB)
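One more thing worth hardening in infer(): self.cfx.pop() is only reached if nothing before it raises, so an exception (for example, a bad image path) would leave the CUDA context pushed on that thread. A common Python pattern is to pair push() and pop() in a context manager. The sketch below uses contextlib with a hypothetical FakeContext that only records calls, so it runs without a GPU; in real code the object would be the context returned by cuda.Device(0).make_context(). pushed is a hypothetical helper name.

```python
import contextlib
from typing import Iterator, List


class FakeContext:
    """Hypothetical stand-in for a pycuda context: it only records push/pop calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def push(self) -> None:
        self.calls.append("push")

    def pop(self) -> None:
        self.calls.append("pop")


@contextlib.contextmanager
def pushed(ctx) -> Iterator[None]:
    """Make ctx current for the duration of the with-block, and pop it even on error."""
    ctx.push()
    try:
        yield
    finally:
        ctx.pop()


if __name__ == "__main__":
    ctx = FakeContext()
    try:
        with pushed(ctx):
            raise FileNotFoundError("pgms/missing.pgm")  # simulated failure inside infer()
    except FileNotFoundError:
        pass
    print(ctx.calls)  # ['push', 'pop']
```

Inside infer(), everything currently between self.cfx.push() and self.cfx.pop() would move under `with pushed(self.cfx):`.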

Hi @starcraft6723,

Please find more details in the TensorRT Python API documentation:
https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/Logger.html

The logger used to create an instance of IBuilder, IRuntime or IRefitter is used for all objects created through that interface. The logger should be valid until all objects created are released.

Thank you.
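The warning in the first log ("TensorRT maintains only a single logger pointer") explains the crash: the first TRTInference's local logger was the one TensorRT kept, and once that Python object was freed, the engines logging during teardown (the last [TensorRT] INFO line before the segfault) went through a dangling pointer. A pure-Python analogy, assuming nothing about TensorRT internals: the hypothetical FakeEngine holds its logger through a weakref, which, like a raw C++ pointer, does not keep the target alive. In CPython, reference counting frees the local logger as soon as the function returns; Python can detect the dead reference, whereas the C++ equivalent is undefined behavior.

```python
import weakref
from typing import List, Optional


class Logger:
    """Stand-in for trt.Logger: it just collects messages."""

    def __init__(self) -> None:
        self.messages: List[str] = []

    def log(self, msg: str) -> None:
        self.messages.append(msg)


class FakeEngine:
    """Hypothetical engine that logs on destruction via a non-owning reference."""

    def __init__(self, logger: Logger) -> None:
        self._logger = weakref.ref(logger)  # does not keep the logger alive

    def destroy(self) -> bool:
        """Return True if the logger was still alive when we tried to log."""
        logger: Optional[Logger] = self._logger()
        if logger is None:
            return False  # in C++ this would be a dangling-pointer dereference
        logger.log("engine destroyed")
        return True


def build_with_local_logger() -> FakeEngine:
    logger = Logger()  # local, like TRT_LOGGER inside __init__ in the original code
    return FakeEngine(logger)  # the last strong reference dies when we return


MODULE_LOGGER = Logger()  # module level, like the fixed code


def build_with_module_logger() -> FakeEngine:
    return FakeEngine(MODULE_LOGGER)


if __name__ == "__main__":
    print(build_with_local_logger().destroy())   # False: logger already freed
    print(build_with_module_logger().destroy())  # True: logger still alive
```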