TensorRT ERROR: pointWiseV2Helpers.h::launchPwgenKernel::532 Cuda Driver (invalid resource handle)

Description

I want to use TensorRT to run inference with a pose estimation model (OpenPifPaf) on a Jetson Xavier NX.

I converted a PyTorch model to ONNX and then to a TensorRT engine.

When I run my inference script, I get a Cuda Driver (invalid resource handle) error. I think it happens exactly when I call cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream).
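For context, the host-side staging step that runs just before `cuda.memcpy_htod_async` can be sketched as below. This is a minimal sketch, assuming a flat input buffer allocated with `cuda.pagelocked_empty` from the engine's input binding; the helper name `stage_input` is hypothetical. It fills the existing buffer in place, whereas `inference()` in openpifpaf_tensorrt.py rebinds `host_inputs[0]` to `np.ravel(np.zeros_like(np_img))`, a regular pageable array, which throws away the page-locked allocation. That alone is probably not what triggers the invalid-resource-handle error, but it keeps the transfer path as intended.

```python
import numpy as np


def stage_input(host_buf, np_img):
    """Copy a preprocessed image into an existing (page-locked) host buffer.

    host_buf is assumed to be the flat buffer returned by
    cuda.pagelocked_empty(); it is filled in place and never replaced, so the
    following memcpy_htod_async still reads from pinned memory.
    """
    # Convert to the buffer's dtype and flatten (NCHW order is preserved).
    flat = np.ascontiguousarray(np_img, dtype=host_buf.dtype).ravel()
    if flat.size != host_buf.size:
        raise ValueError(
            "input has %d elements but the binding buffer holds %d"
            % (flat.size, host_buf.size)
        )
    np.copyto(host_buf, flat)
    return host_buf
```

On the Jetson, `host_buf` would be `host_inputs[0]`, and the call would replace the `np.zeros_like`/`np.copyto` pair before the asynchronous copy.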

Environment

TensorRT Version: 8.0.1.6
GPU Type: Jetson Xavier NX
Nvidia Driver Version: 440
CUDA Version: 10.2
CUDNN Version: 8.2.1 (Jetpack 4.6)
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.6
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.9.0
Baremetal or Container (if container which image + tag):

Relevant Files

I have attached the TensorRT scripts and the config file that I use for inference.
inference.py (4.3 KB)
openpifpaf_tensorrt.py (11.3 KB)
config-pose-single-image.ini (540 Bytes)

I have also attached the generated TensorRT engine.
openpifpaf_mobilenetv3large_129_97_d16.trt (10.2 MB)
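The attached engine name follows the path that `_load_engine` in openpifpaf_tensorrt.py builds from the config (`'models/tensorrt/openpifpaf_' + model + '_{}_{}_d{}.trt'`). A stdlib-only sketch of that naming, plus a trtexec argument list that writes the engine straight to that location, is shown below. The helper names `engine_path` and `trtexec_command` are hypothetical, and the flags are only the ones already used in step 3. Note that step 3 saves to `pose.engine`, which `_load_engine` would not find unless the file is renamed or moved.

```python
def engine_path(model, width, height, precision, root="models/tensorrt"):
    """Return the engine path that _load_engine() looks for.

    Mirrors: 'models/tensorrt/openpifpaf_' + model + '_{}_{}_d{}.trt'
    """
    name = "openpifpaf_{}_{}_{}_d{}.trt".format(model, width, height, precision)
    return root.rstrip("/") + "/" + name


def trtexec_command(onnx_file, out_file, fp16=True,
                    trtexec="/usr/src/tensorrt/bin/trtexec"):
    """Build a trtexec argument list that saves the engine to out_file."""
    cmd = [trtexec, "--explicitBatch", "--onnx=" + onnx_file,
           "--saveEngine=" + out_file]
    if fp16:
        cmd.append("--fp16")
    return cmd
```

For example, `subprocess.run(trtexec_command("mobilenetv3large.onnx", engine_path("mobilenetv3large", 129, 97, 16)), check=True)` would produce the file name attached above.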

Steps To Reproduce

  1. First, I installed all the required dependencies for OpenPifPaf and TensorRT following this repository. Note that this repository uses a Dockerfile; I am not using one, and instead installed all the libraries directly on my Jetson.

  2. This is how I export the ONNX model from the official OpenPifPaf repo:

python3 -m openpifpaf.export_onnx --input-width 129 --input-height 97 --checkpoint tshufflenetv2k16 --outfile tshufflenetv2k16.onnx
  3. I then convert the ONNX model into a TensorRT engine using trtexec (onnx2trt does not work for me):
/usr/src/tensorrt/bin/trtexec --explicitBatch --threads --onnx=mobilenetv3large.onnx --fp16 --batch=1 --saveEngine=pose.engine
  4. I run inference via:
python3 inference.py config-pose-single-image.ini
  5. The inference.py script calls openpifpaf_tensorrt.py. Note that I use both torch and pycuda in the same script, and I run the input transfer asynchronously (cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)). This is what openpifpaf_tensorrt.py looks like:
import time
import os
import torch
import numpy as np
import cv2
import openpifpaf
import pycuda.driver as cuda
# import pycuda.autoinit # added to solve cuDeviceGet failed: initialization error
import PIL
import pdb

import tensorrt as trt
# from decoder import CifCafDecoder
# from libs.utils.fps_calculator import convert_infr_time_to_fps

def allocate_buffers(engine):
    stream = cuda.Stream()  # create a CUDA stream to run inference
    host_inputs = []
    cuda_inputs = []
    host_outputs = []
    cuda_outputs = []
    bindings = []
    for i in range(engine.num_bindings):
        binding = engine[i]
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        host_mem = cuda.pagelocked_empty(size, np.float32) # create page-locked memory buffers(i.e. won't be swapped to disk) to hold host inputs/outputs.
        cuda_mem = cuda.mem_alloc(host_mem.nbytes)
        bindings.append(int(cuda_mem))
        if engine.binding_is_input(binding):
            host_inputs.append(host_mem)
            cuda_inputs.append(cuda_mem)
        else:
            host_outputs.append(host_mem)
            cuda_outputs.append(cuda_mem)
    # stream = cuda.Stream()  # create a CUDA stream to run inference        
    return bindings, host_inputs, cuda_inputs, host_outputs, cuda_outputs, stream

class PoseEstimator:
    """
    :param config: Is a ConfigEngine instance which provides necessary parameters.
    """
    def __init__(self, config):
        self.config = config
        self.fps = None
        self.model = self.config['Model']['BaselineModel']
        self.w, self.h = [int(i) for i in self.config['PoseEstimator']['InputSize'].split(',')]
        self.trt_logger = trt.Logger(trt.Logger.INFO) # TensorRT logger singleton
        self.model_input_size = (self.w, self.h)
        self.device = None
        self.cuda_context = None
        self._init_cuda_stuff()

    def _init_cuda_stuff(self):
        cuda.init()
        self.engine = self._load_engine()
        self.device = cuda.Device(0)  # enter your GPU id here, print(torch.cuda.current_device())
        self.cuda_context = self.device.make_context()
        self.engine_context = self.engine.create_execution_context()
        bindings, host_inputs, cuda_inputs, host_outputs, cuda_outputs, stream = allocate_buffers(self.engine)
        self.bindings = bindings
        self.host_inputs = host_inputs
        self.host_outputs = host_outputs
        self.cuda_inputs = cuda_inputs
        self.cuda_outputs = cuda_outputs
        self.stream = stream

        # print("[INFO] CUDA available:", torch.cuda.is_available())
        # print("[INFO] CUDA current device:", torch.cuda.current_device())
        # print("[INFO] CUDA device count:", torch.cuda.device_count())
        # print("[INFO] CUDA device name:", torch.cuda.get_device_name(0))

    def _load_engine(self):
        """
        Loads TRT engine. 
        Model flow: Pytorch --> ONNX --> TRT
        """
        precision=int(self.config['PoseEstimator']['TensorrtPrecision'])
        TRTbinPath='models/tensorrt/openpifpaf_'+self.model+'_{}_{}_d{}.trt'.format(self.w,self.h,precision)
        print("[INFO] TRT Model Path:", TRTbinPath)

        # create TRT engine (if it does not exist)
        if not os.path.exists(TRTbinPath):
            print("Creating TRT engine")
            # os.system('bash generate_tensorrt.bash config-pose.ini 1')
            os.system('bash generate_tensorrt.bash config-pose-single-image.ini 1')

        # load TRT engine    
        with open(TRTbinPath, 'rb') as f, trt.Runtime(self.trt_logger) as runtime:
            print("Loading TRT engine")    
            return runtime.deserialize_cuda_engine(f.read())

    def __del__(self):
        """ Free CUDA memory. """
        self.cuda_context.pop()
        del self.cuda_context
        del self.engine_context
        del self.engine

    def inference(self, resized_rgb_image):
        """
        This method runs inference and returns the raw OpenPifPaf head outputs.
        Args:
            resized_rgb_image: uint8 numpy array with shape (img_height, img_width, channels)

        Returns:
            fields: a list [cif, caf] with the CIF and CAF head outputs for the image

        """
        image = resized_rgb_image
        image = cv2.resize(image, self.model_input_size)
        pil_im = PIL.Image.fromarray(image)
        preprocess = None

        # define openpifpaf PilImageList
        data = openpifpaf.datasets.PilImageList([pil_im], preprocess=preprocess)
        loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False, pin_memory=True, collate_fn=openpifpaf.datasets.collate_images_anns_meta)
        
        for images_batch, _, __ in loader:
            np_img = images_batch.numpy()
        
        # Retrieve info from buffers
        bindings = self.bindings
        host_inputs = self.host_inputs
        host_outputs = self.host_outputs
        cuda_inputs = self.cuda_inputs
        cuda_outputs = self.cuda_outputs
        stream = self.stream

        host_inputs[0] = np.ravel(np.zeros_like(np_img))

        self.cuda_context.push()

        np.copyto(host_inputs[0], np.ravel(np_img)) # example: numpy array [1.3926706 1.3584211 1.4097953 ... 1.7162529 1.733682  1.7162529]

        # Transfer input data to the GPU.
        cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
        # cuda.memcpy_dtod_async(cuda_inputs[0], host_inputs[0], stream)        
        # cuda.memcpy_htod(cuda_inputs[0], host_inputs[0], stream)

        # Run inference.
        self.engine_context.execute_async(batch_size=1, bindings=bindings, stream_handle=stream.handle)

        cif=[None] * 1          
        caf=[None] * 1
        cif_names=['cif']
        caf_names=['caf']

        # Transfer predictions back from the GPU.
        for i in range(1, self.engine.num_bindings):      
            cuda.memcpy_dtoh_async(host_outputs[i - 1], cuda_outputs[i - 1], stream)
            # cuda.memcpy_dtod_async(host_outputs[i - 1], cuda_outputs[i - 1], stream)
            # cuda.memcpy_dtoh(host_outputs[i - 1], cuda_outputs[i - 1], stream)     

        # Synchronize the stream
        stream.synchronize()    

        # Return host outputs
        for i in range(1, self.engine.num_bindings):      
            shape = self.engine.get_binding_shape(i)      
            name = self.engine.get_binding_name(i)        
            total_shape = np.prod(shape)
            output = host_outputs[i - 1][0: total_shape]  
            output = np.reshape(output, tuple(shape))     
            if name in cif_names:      
                index_n = cif_names.index(name)           
                tmp = torch.from_numpy(output[0])         
                cif = tmp.cpu().numpy()
            elif name in caf_names:    
                index_n = caf_names.index(name)           
                tmp = torch.from_numpy(output[0])         
                caf = tmp.cpu().numpy()

        heads = [cif, caf]    
        self.cuda_context.pop() 

        fields = heads
        return fields
  6. Traceback: this is the error I get:
[INFO] Configuring Parser
[INFO] Defining Pose Estimator (Creating / Loading Engine)
[INFO] TRT Model Path: models/tensorrt/openpifpaf_mobilenetv3large_129_97_d16.trt
[TensorRT] INFO: [MemUsageChange] Init CUDA: CPU +353, GPU +0, now: CPU 432, GPU 5316 (MiB)
Loading TRT engine
[TensorRT] INFO: Loaded engine size: 10 MB
[TensorRT] INFO: [MemUsageSnapshot] deserializeCudaEngine begin: CPU 443 MiB, GPU 5336 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +226, GPU +314, now: CPU 673, GPU 5654 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +307, GPU +416, now: CPU 980, GPU 6070 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 980, GPU 6070 (MiB)
[TensorRT] INFO: [MemUsageSnapshot] deserializeCudaEngine end: CPU 980 MiB, GPU 6070 MiB
[TensorRT] INFO: [MemUsageSnapshot] ExecutionContext creation begin: CPU 1323 MiB, GPU 6415 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +226, GPU +228, now: CPU 1549, GPU 6643 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +306, GPU +306, now: CPU 1855, GPU 6949 (MiB)
[TensorRT] INFO: [MemUsageSnapshot] ExecutionContext creation end: CPU 1855 MiB, GPU 6950 MiB
[INFO] Running Inference w TRT Engine
[TensorRT] ERROR: 1: [pointWiseV2Helpers.h::launchPwgenKernel::532] Error Code 1: Cuda Driver (invalid resource handle)
DEBUG:openpifpaf.visualizer.base:cif: indices = []
DEBUG:openpifpaf.show.painters:color connections = True, lw = 2, marker = 6
DEBUG:openpifpaf.show.painters:color connections = False, lw = 6, marker = 3
DEBUG:openpifpaf.visualizer.base:caf: indices = []
DEBUG:openpifpaf.show.painters:color connections = True, lw = 2, marker = 6
DEBUG:openpifpaf.show.painters:color connections = False, lw = 6, marker = 3
DEBUG:openpifpaf.decoder.cifcaf:initial annotations = 0
DEBUG:openpifpaf.decoder.utils.cif_hr:target_intensities 0.006s
DEBUG:openpifpaf.decoder.utils.cif_seeds:seeds 0, 0.006s (C++ 0.004s)
DEBUG:openpifpaf.decoder.utils.caf_scored:scored caf (0, 0) in 0.003s
DEBUG:openpifpaf.decoder.cifcaf:annotations 0, 0.017s
INFO:openpifpaf.decoder.cifcaf:0 annotations: []
inference time is 9.053708233999942 and decoder time is :0.04443562999995265 and fps:0

predictions []
Output image is saved!
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 3311, GPU 7569 (MiB)

As you can see, my pipeline does not detect any keypoints (the predictions array is empty). I have tested this with multiple models/engines (ResNet50, ShuffleNet, MobileNetV3).
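The output loop in step 5 slices each flat host buffer, reshapes it by binding shape, and routes it to `cif` or `caf` by binding name. A compact NumPy-only sketch of that step is below, assuming the names and shapes are read from the engine beforehand (for example with `engine.get_binding_name(i)` and `engine.get_binding_shape(i)`, as the script does); the helper name `unpack_outputs` is hypothetical. It also drops the `torch.from_numpy(...).cpu().numpy()` round trip, which returns the same data. Whenever the `[TensorRT] ERROR` line appears, the kernels did not run correctly, so the contents of these buffers, and the empty prediction list built from them, should not be trusted.

```python
import numpy as np


def unpack_outputs(host_outputs, names, shapes, batch_index=0):
    """Turn flat host output buffers into {binding_name: array} for one image.

    host_outputs: flat host buffers, in output-binding order.
    names, shapes: names and shapes of those output bindings, same order.
    """
    fields = {}
    for buf, name, shape in zip(host_outputs, names, shapes):
        shape = tuple(shape)
        count = int(np.prod(shape))
        # The buffer may be larger than one binding volume; keep only the front.
        arr = np.asarray(buf[:count]).reshape(shape)
        # Drop the batch dimension and copy, so the result no longer aliases
        # the host buffer that the next inference call will overwrite.
        fields[name] = arr[batch_index].copy()
    return fields
```

With the binding names used in the script, `heads = [fields["cif"], fields["caf"]]` would replace the `cif_names`/`caf_names` bookkeeping.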

Hi,

This looks like a setup-related issue. We are moving this post to the Jetson Xavier forum so you can get better help.

Thank you.

Hi,

Since trtexec serializes the engine successfully, the issue is most likely in your implementation.

Do you run inference in multiple threads?
If not, you don’t need to push/pop the CUDA context.

self.cuda_context.push()
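In line with this suggestion, a single-threaded sketch that never pushes or pops a context could look like the following. It assumes the TensorRT 8.0 Python API and PyCUDA as used in the original script, and a static input shape; the class name `SingleContextRunner` is hypothetical. The important ordering is that `pycuda.autoinit` makes a context current before the engine is deserialized, so the engine, execution context, device buffers and stream all belong to one CUDA context. The original `_init_cuda_stuff` deserializes the engine first and only then calls `make_context()`; mixing contexts that way is a common cause of "invalid resource handle", although that is an assumption here and was not confirmed in this thread. Because the engine was built with `--explicitBatch`, the sketch calls `execute_async_v2`, which takes no batch size.

```python
import numpy as np


class SingleContextRunner(object):
    """Single-threaded TensorRT runner: one CUDA context, no push/pop.

    Hypothetical helper class; assumes TensorRT 8.0 and PyCUDA as in the post.
    """

    def __init__(self, engine_path):
        # Imported lazily so this module can be imported without a GPU.
        # pycuda.autoinit creates a context on device 0 and makes it current,
        # so the engine, buffers and stream below all live in that context.
        import pycuda.autoinit  # noqa: F401
        import pycuda.driver as cuda
        import tensorrt as trt

        self._cuda = cuda
        logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, "rb") as f, trt.Runtime(logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        self.context = self.engine.create_execution_context()
        self.stream = cuda.Stream()

        self.host_in, self.dev_in = [], []
        self.host_out, self.dev_out = [], []
        self.bindings = []
        for i in range(self.engine.num_bindings):
            dtype = trt.nptype(self.engine.get_binding_dtype(i))
            size = trt.volume(self.engine.get_binding_shape(i))
            host = cuda.pagelocked_empty(size, dtype)  # pinned host buffer
            dev = cuda.mem_alloc(host.nbytes)
            self.bindings.append(int(dev))
            if self.engine.binding_is_input(i):
                self.host_in.append(host)
                self.dev_in.append(dev)
            else:
                self.host_out.append(host)
                self.dev_out.append(dev)

    def infer(self, np_img):
        """Run one image; return copies of the output buffers in binding order."""
        cuda = self._cuda
        # Fill the pinned input buffer in place instead of replacing it.
        np.copyto(self.host_in[0], np.ravel(np_img).astype(self.host_in[0].dtype))
        cuda.memcpy_htod_async(self.dev_in[0], self.host_in[0], self.stream)
        # Engine built with --explicitBatch: execute_async_v2 takes no batch_size.
        self.context.execute_async_v2(bindings=self.bindings,
                                      stream_handle=self.stream.handle)
        for host, dev in zip(self.host_out, self.dev_out):
            cuda.memcpy_dtoh_async(host, dev, self.stream)
        self.stream.synchronize()
        return [host.copy() for host in self.host_out]

    def close(self):
        # Release TensorRT objects while the CUDA context is still alive.
        del self.context
        del self.engine
```

Outputs come back in binding order; they can be mapped to `cif`/`caf` with `engine.get_binding_name(i)`, as the original script does.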

You can find an example for python inference below:
https://elinux.org/Jetson/L4T/TRT_Customized_Example#OpenCV_with_PLAN_model

If you still encounter the error, please share the ONNX model with us so we can reproduce it.
Thanks.
