Unable to run two TensorRT models in a cascade manner

Hi,

Purpose: Need to load two TensorRT models at the same time.

I am new to Jetson and TensorRT. I've been trying to implement a CV pipeline in Python which needs to use two TensorRT models in a cascade manner.

Environment Info

TensorRT Version : 7.1
GPU Type : GTX 1070
Nvidia Driver Version : 440
CUDA Version : 10.2
CUDNN Version : 7.6
Operating System + Version : Ubuntu 18.04
Python Version : 3.6.10
PyTorch Version : 1.5.1

Pipeline workflow

  1. Detects objects in the video frame with a TensorRT detection model
  2. Tracks the objects detected by the TRT detection model using the DeepSort tracker

I am referencing @jkjung13’s tensorrt_demos repo to build and implement TensorRT models.

Here is the code for the TensorRT detection model inference:

import os.path as ops

import numpy as np
import pycuda.driver as cuda  # cuda.init() (or `import pycuda.autoinit`) must run before the models are built
import tensorrt as trt


class HostDeviceMem(object):
    """Simple helper data class that's a little nicer to use than a 2-tuple"""
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()

def allocate_buffers(engine):
    """Allocates all host/device in/out buffers required for an engine."""

    inputs   = []
    outputs  = []
    bindings = []
    stream   = cuda.Stream()
    for binding in engine:
        size  = trt.volume(engine.get_binding_shape(binding)) * \
                engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))

        # Allocate host and device buffers.
        host_mem   = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes) 
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))

    return inputs, outputs, bindings, stream

class TrtYOLOv4(object):
    """TrtYOLOv4 class encapsulates things needed to run TRT YOLO.
    """

    def _load_engine(self):
        TRTbin = self.engine_path
        with open(TRTbin, 'rb') as f, trt.Runtime(self.trt_logger) as runtime:
            return runtime.deserialize_cuda_engine(f.read())

    def _create_context(self):
        return self.engine.create_execution_context()

    def __init__(self, engine_path, input_shape, 
                nms_thres, conf_thres, num_classes=80):

        """Initialize parameters requried for building TensorRT engine.
           TensorRT plugins, engine and context.

        Parameters
        ----------
        engine_path : str
                      Path of the TensorRT engine model file
        input_shape : tuple
                      a tuple of (H, W)
        nms_thres   : float (between 0 and 1)
                      Threshold value for performing non-maximum suppression
        conf_thres  : float (between 0 and 1)
                      Threshold value for filtering the boxes output by the model
        num_classes : int
                      Total number of classes that the model can detect
        yolo_masks  : list of tuples
                      A list of 3 three-dimensional tuples for the YOLO masks
        yolo_anchors: list of tuples
                      A list of 9 two-dimensional tuples for the YOLO anchors

        Returns
        -------
        A TrtYOLOv4 instance capable of running inference on images
        """
        self.cuda_ctx    = cuda.Device(0).make_context() # Use GPU:0

        self.engine_path = engine_path
        # check whether the file exists or not
        assert ops.exists(self.engine_path), "Engine file does not exist. Please check!"

        self.input_shape = input_shape
        self.nms_thres   = nms_thres
        self.conf_thres  = conf_thres
        self.num_classes = num_classes

        # Setup YOLOv4 postprocessing parameters
        filters = (self.num_classes + 5) * 3
        h, w = self.input_shape
        self.yolo_masks   = [(0, 1, 2), (3, 4, 5), (6, 7, 8)]
        self.output_shapes = [(1, filters, h //  8, w //  8),
                              (1, filters, h // 16, w // 16),
                              (1, filters, h // 32, w // 32)]
        self.yolo_anchors = [(12, 16), (19, 36), (40, 28),
                             (36, 75), (76, 55), (72, 146),
                             (142, 110), (192, 243), (459, 401)]
        self.yolo_input_resolution = self.input_shape 

        # setup inference function
        self.inference_fn = do_inference

        # setup logger
        self.trt_logger = trt.Logger(trt.Logger.INFO)
        self.engine     = self._load_engine()

        # setup postprocess
        self.postprocessor = Postprocess(
            yolo_masks=self.yolo_masks,
            yolo_anchors=self.yolo_anchors,
            conf_threshold=self.conf_thres,
            nms_threshold=self.nms_thres,
            yolo_input_resolution=self.input_shape,
            num_classes=self.num_classes)

        try:
            self.context = self._create_context()
            self.inputs, self.outputs, self.bindings, self.stream = \
                allocate_buffers(self.engine)
        except Exception as e:
            self.cuda_ctx.pop()
            del self.cuda_ctx
            raise RuntimeError("Fail to allocate CUDA resources") from e

    def __del__(self):
        """Free CUDA memories"""
        del self.stream
        del self.outputs
        del self.inputs
        self.cuda_ctx.pop()
        del self.cuda_ctx
    
    def detect(self, img):
        """Detect objects in the input image."""
        shape_orig_WH = (img.shape[1], img.shape[0])
        img_resized = _preprocess_frame(img, self.input_shape)

        # Set host input to the image. The do_inference() function
        # will copy the input to the GPU before executing.
        self.inputs[0].host = np.ascontiguousarray(img_resized)
        trt_outputs = self.inference_fn(
            context=self.context,
            bindings=self.bindings,
            inputs=self.inputs,
            outputs=self.outputs,
            stream=self.stream)

        # Before doing post-processing, we need to reshape the outputs
        # as do_inference() will give us flat arrays.
        trt_outputs = [
            output.reshape(shape)
            for output, shape in zip(trt_outputs, self.output_shapes)]
        # Run the post-processing algorithms on the TensorRT outputs
        # and get the bounding box details of detected objects
        boxes, classes, scores = self.postprocessor.process(
            trt_outputs, shape_orig_WH, self.conf_thres)
        return boxes, scores, classes

Inference code

def do_inference(context, bindings, inputs, outputs, stream):
    """do_inference (for TensorRT 7.0+)

    This function is generalized for multiple inputs/outputs for full
    dimension networks.
    Inputs and outputs are expected to be lists of HostDeviceMem objects.
    """
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference.
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]

The DeepSort tracker also uses a ConvNet to extract features from the detections produced by the TensorRT detection model, so I converted it to a TensorRT model as well for better performance. The TRT inference implementation for the extractor network is nearly the same as the detection model code above.
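
For reference, here is a stripped-down sketch of that extractor wrapper. It mirrors the TrtYOLOv4 class above; the _preprocess_crop helper is just a placeholder for the actual resize/normalize step, not a function from the repo.

class Extractor(object):
    """TRT wrapper for the DeepSort appearance-feature extractor."""

    def __init__(self, engine_path, input_shape):
        # Same pattern as TrtYOLOv4 above: a per-model CUDA context,
        # deserialized engine, execution context and I/O buffers.
        self.cuda_ctx    = cuda.Device(0).make_context()
        self.input_shape = input_shape  # (W, H), e.g. (64, 128)

        self.trt_logger = trt.Logger(trt.Logger.INFO)
        with open(engine_path, 'rb') as f, trt.Runtime(self.trt_logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())

        self.context = self.engine.create_execution_context()
        self.inputs, self.outputs, self.bindings, self.stream = \
            allocate_buffers(self.engine)

    def __call__(self, crop):
        """Return the appearance embedding for one detection crop."""
        # _preprocess_crop is a placeholder for the resize/normalize step.
        img = _preprocess_crop(crop, self.input_shape)
        self.inputs[0].host = np.ascontiguousarray(img)
        features = do_inference(
            context=self.context, bindings=self.bindings,
            inputs=self.inputs, outputs=self.outputs, stream=self.stream)
        return features[0]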

After that, I create a "CVPipeline" class which instantiates these two models:

class CVPipeline(object):
    def __init__(self, cfg):
        # (input_size, thresholds and obj_list are read from cfg)
        # build yolov4 TRT model
        self.trt_detector = TrtYOLOv4(engine_path="checkpoint/yolov4.trt",
                                      input_shape=(self.input_size, self.input_size),
                                      nms_thres=self.nms_threshold,
                                      conf_thres=self.conf_threshold,
                                      num_classes=len(self.obj_list))
        # create feature extractor model, to be used in the DeepSort tracker
        self.extractor = Extractor("checkpoint/extractorv3.trt", (64, 128))

    def run(self, video):
        # something like this
        cap = cv2.VideoCapture(video)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            boxes, confs, cls = self.trt_detector.detect(frame)
            features = []
            for box in boxes:
                features.append(self.extractor(box))
            # Tracking stuff happens in here

When I create an instance of the CVPipeline object and run the process, this sort of error occurs on my machine:

[TensorRT] ERROR: ../rtSafe/cuda/caskConvolutionRunner.cpp (373) - Cask Error in checkCaskExecError<false>: 10 (Cask Convolution execution)
[TensorRT] ERROR: FAILED_EXECUTION: std::exception
[TensorRT] ERROR: ../rtSafe/cuda/caskConvolutionRunner.cpp (490) - Cuda Error in execute: 400 (invalid resource handle)
[TensorRT] ERROR: FAILED_EXECUTION: std::exception

Each of these models works as expected if I test them separately. After trying to debug the code, it seems that I can't create/load two TensorRT models at the same time.

Could you share with me some suggestions on how to fix this error so that the pipeline can run smoothly as I mentioned?

Thanks in advance!

Best Regards,
Htut

Hi @Htut,

As per the error, this looks like a threading issue.
Please ensure you are following the guidelines in the documentation.

Alternatively, you can use DeepStream to run multiple models.

Thanks!

Hi @AakankshaS,

Thanks for the reply.
It does seem to be a thread issue, according to the first link you provided:

2.4. Thread Safety
The TensorRT builder may only be used by one thread at a time. If you need to run multiple builds simultaneously, you will need to create multiple builders.
The TensorRT runtime can be used by multiple threads simultaneously, so long as each object uses a different execution context.

I see that this is causing the problem, but I am totally lost on how to solve it, so could you please provide an example?

Regards,
htut

Hi @Htut,
This link might help you get through the issue.
https://github.com/NVIDIA/TensorRT/issues/301#issuecomment-570558499

Thanks!

Hi @AakankshaS, I was able to solve this error. According to the Thread Safety section, it seems that a single CUDA context can be used by multiple TensorRT builders and execution contexts.

The problem with my implementation was that I was creating a separate CUDA context and execution context for each TensorRT model's inference code, as you can see here.

I was able to solve this error by creating a common CUDA context and passing it to each TensorRT model wrapper; each wrapper uses it and creates its own execution context.
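
In case it helps others, here is a minimal sketch of the change. The cuda_ctx constructor argument is just how I chose to pass the shared context into my wrappers (it is not part of the original classes posted above), and the threshold/input-size values are illustrative:

import pycuda.driver as cuda

cuda.init()
# One CUDA context shared by the whole pipeline.
shared_cuda_ctx = cuda.Device(0).make_context()

# The wrappers no longer call cuda.Device(0).make_context() themselves;
# they receive the shared context instead.
detector  = TrtYOLOv4(engine_path="checkpoint/yolov4.trt",
                      input_shape=(416, 416),
                      nms_thres=0.5, conf_thres=0.3,
                      cuda_ctx=shared_cuda_ctx)
extractor = Extractor("checkpoint/extractorv3.trt", (64, 128),
                      cuda_ctx=shared_cuda_ctx)

# Inside each wrapper, every inference call pushes/pops the shared context
# around do_inference(), while each model keeps its own TensorRT execution
# context:
#
#     self.cuda_ctx.push()
#     trt_outputs = self.inference_fn(...)
#     self.cuda_ctx.pop()

# Release the shared context once, when the pipeline shuts down.
shared_cuda_ctx.pop()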

Thanks for the help.
Regards,
Htut

Hi @Htut,

Could you share your modified code for this? I'm also trying to run two models in a cascade manner, but I did not get the error that you got; what happens in my case is a segmentation fault when the program ends.
I'd appreciate it if you could share your final code so that I can check mine.

Thanks

Hi,
This looks like a Jetson issue. We recommend you raise it on the respective platform forum via the link below.

Thanks!