Adding multiple inference on TensorRT (Invalid Resource Handle Error)

Platform Info:
Linux distro and version: Ubuntu 18.04.3 LTS
GPU type: Maxwell
nvidia driver version: Installed with Jetpack 4.2
CUDA version: 10.0.326
Python version: 3.6
Tensorflow version: 1.14.0
TensorRT version: 5.1.6.1
If Jetson, OS, hw versions: Jetson Nano

I am trying to run two inferences in a pipeline on a Jetson Nano. The first inference is object detection using MobileNet and TensorRT. My code for the first inference is essentially copied from the [AastaNV/TRT_object_detection](https://github.com/AastaNV/TRT_object_detection) repository. The only difference is that I changed the code so that it resides inside a class, Inference1.

The second inference uses the outputs of the first to run further analysis. For this inference, I use TensorFlow (not TensorRT, but I assume it is called in the backend?) with a custom model. The model is loaded from a .pb file (frozen graph). Once it is loaded, inference is performed by calling TensorFlow's session.run().

If I run ONLY Inference1 or ONLY Inference2, the code runs without any errors. However, when I chain them in the same pipeline, I get the following error:

WARNING:tensorflow:From /usr/lib/python3.6/dist-packages/graphsurgeon/DynamicGraph.py:4: The name tf.GraphDef is deprecated. Please use tf.compat.v1.GraphDef instead.

[TensorRT] INFO: Glob Size is 14049908 bytes.
[TensorRT] INFO: Added linear block of size 5760000
[TensorRT] INFO: Added linear block of size 2880000
[TensorRT] INFO: Added linear block of size 409600
[TensorRT] INFO: Added linear block of size 218624
[TensorRT] INFO: Added linear block of size 61440
[TensorRT] INFO: Added linear block of size 57344
[TensorRT] INFO: Added linear block of size 30720
[TensorRT] INFO: Added linear block of size 20992
[TensorRT] INFO: Added linear block of size 9728
[TensorRT] INFO: Added linear block of size 9216
[TensorRT] INFO: Added linear block of size 2560
[TensorRT] INFO: Added linear block of size 2560
[TensorRT] INFO: Added linear block of size 1024
[TensorRT] INFO: Added linear block of size 512
[TensorRT] INFO: Found Creator FlattenConcat_TRT
[TensorRT] INFO: Found Creator GridAnchor_TRT
[TensorRT] INFO: Found Creator FlattenConcat_TRT
[TensorRT] INFO: Found Creator NMS_TRT
[TensorRT] INFO: Deserialize required 5159079 microseconds.
Infering on input.mp4
WARNING:tensorflow:From /home/user/Desktop/SVM_TensorRT/deep_sort/tools/generate_detections.py:75: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

2018-01-29 02:01:38.254282: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcuda.so.1
2018-01-29 02:01:38.286962: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:972] ARM64 does not support NUMA - returning NUMA node zero
2018-01-29 02:01:38.287300: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1640] Found device 0 with properties: 
name: NVIDIA Tegra X1 major: 5 minor: 3 memoryClockRate(GHz): 0.9216
pciBusID: 0000:00:00.0
2018-01-29 02:01:38.287552: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudart.so.10.0
2018-01-29 02:01:38.287744: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcublas.so.10.0
2018-01-29 02:01:38.287983: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcufft.so.10.0
2018-01-29 02:01:38.288201: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcurand.so.10.0
2018-01-29 02:01:38.415478: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcusolver.so.10.0
2018-01-29 02:01:38.484010: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcusparse.so.10.0
2018-01-29 02:01:38.484668: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudnn.so.7
2018-01-29 02:01:38.485343: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:972] ARM64 does not support NUMA - returning NUMA node zero
2018-01-29 02:01:38.486009: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:972] ARM64 does not support NUMA - returning NUMA node zero
2018-01-29 02:01:38.486286: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1763] Adding visible gpu devices: 0
2018-01-29 02:01:38.665379: W tensorflow/core/platform/profile_utils/cpu_utils.cc:98] Failed to find bogomips in /proc/cpuinfo; cannot determine CPU frequency
2018-01-29 02:01:38.682935: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x24f9ea50 executing computations on platform Host. Devices:
2018-01-29 02:01:38.683009: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): <undefined>, <undefined>
2018-01-29 02:01:38.764975: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:972] ARM64 does not support NUMA - returning NUMA node zero
2018-01-29 02:01:38.765291: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x572614c0 executing computations on platform CUDA. Devices:
2018-01-29 02:01:38.765349: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): NVIDIA Tegra X1, Compute Capability 5.3
2018-01-29 02:01:38.766014: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:972] ARM64 does not support NUMA - returning NUMA node zero
2018-01-29 02:01:38.766158: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1640] Found device 0 with properties: 
name: NVIDIA Tegra X1 major: 5 minor: 3 memoryClockRate(GHz): 0.9216
pciBusID: 0000:00:00.0
2018-01-29 02:01:38.766716: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudart.so.10.0
2018-01-29 02:01:38.766814: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcublas.so.10.0
2018-01-29 02:01:38.766879: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcufft.so.10.0
2018-01-29 02:01:38.767002: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcurand.so.10.0
2018-01-29 02:01:38.767174: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcusolver.so.10.0
2018-01-29 02:01:38.767311: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcusparse.so.10.0
2018-01-29 02:01:38.767423: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudnn.so.7
2018-01-29 02:01:38.767731: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:972] ARM64 does not support NUMA - returning NUMA node zero
2018-01-29 02:01:38.768049: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:972] ARM64 does not support NUMA - returning NUMA node zero
2018-01-29 02:01:38.768136: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1763] Adding visible gpu devices: 0
2018-01-29 02:01:38.783718: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudart.so.10.0
2018-01-29 02:01:41.046094: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1181] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-01-29 02:01:41.046260: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1187]      0 
2018-01-29 02:01:41.046311: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1200] 0:   N 
2018-01-29 02:01:41.054160: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:972] ARM64 does not support NUMA - returning NUMA node zero
2018-01-29 02:01:41.054730: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:972] ARM64 does not support NUMA - returning NUMA node zero
2018-01-29 02:01:41.112041: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1326] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 85 MB memory) -> physical GPU (device: 0, name: NVIDIA Tegra X1, pci bus id: 0000:00:00.0, compute capability: 5.3)
WARNING:tensorflow:From /home/user/Desktop/SVM_TensorRT/deep_sort/tools/generate_detections.py:76: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.

WARNING:tensorflow:From /home/user/Desktop/SVM_TensorRT/deep_sort/tools/generate_detections.py:80: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

[TensorRT] ERROR: CUDA cask failure at execution for trt_maxwell_scudnn_128x32_relu_small_nn_v1.
[TensorRT] ERROR: cuda/caskConvolutionLayer.cpp (355) - Cuda Error in execute: 33 (invalid resource handle)
[TensorRT] ERROR: cuda/caskConvolutionLayer.cpp (355) - Cuda Error in execute: 33 (invalid resource handle)

From what I see in the log, the TensorRT serialized graph is loaded without any problems. TensorFlow is also imported and recognizes my GPU. From searching online, I have found that this problem may be related to CUDA contexts. Below is how I have set up the CUDA context in my code. create_cuda_context is called only once, during initialization of the Inference1 class. run_inference_for_single_image is called on every iteration.

    def create_cuda_context(self):
        # Allocates buffers, a stream and a TensorRT execution context.
        # No CUDA context is created explicitly here.
        self.host_inputs, self.host_outputs = [], []
        self.cuda_inputs, self.cuda_outputs = [], []
        self.bindings = []
        self.stream = cuda.Stream()

        for binding in self.engine:
            # One page-locked host buffer and one device buffer per binding
            size = trt.volume(self.engine.get_binding_shape(binding)) * self.engine.max_batch_size
            host_mem = cuda.pagelocked_empty(size, np.float32)
            cuda_mem = cuda.mem_alloc(host_mem.nbytes)

            self.bindings.append(int(cuda_mem))
            if self.engine.binding_is_input(binding):
                self.host_inputs.append(host_mem)
                self.cuda_inputs.append(cuda_mem)
            else:
                self.host_outputs.append(host_mem)
                self.cuda_outputs.append(cuda_mem)
        self.context = self.engine.create_execution_context()
        self.stop = False

    def run_inference_for_single_image(self, image):
        ''' Copies the image (already raveled) input into GPU memory, performs the forward pass
        and copies the result back to CPU memory
        '''
        np.copyto(self.host_inputs[0], image)
        # Host -> device copy, forward pass and device -> host copies all use self.stream
        cuda.memcpy_htod_async(self.cuda_inputs[0], self.host_inputs[0], self.stream)
        self.context.execute_async(bindings=self.bindings, stream_handle=self.stream.handle)
        cuda.memcpy_dtoh_async(self.host_outputs[1], self.cuda_outputs[1], self.stream)
        cuda.memcpy_dtoh_async(self.host_outputs[0], self.cuda_outputs[0], self.stream)
        # Wait until all work queued on the stream has finished
        self.stream.synchronize()
        return self.host_outputs[0]

Any help would be greatly appreciated!

I ran inference inside a ROS callback function and got the same problem, but have found no solution!

My environment is:
Ubuntu: 16.04 64bit
CUDA: 10.0
CUDNN: 7.5.0
TensorRT: 5.1.5
python: 3.6
ROS: kinetic
GPU: GeForce GTX 1080 with Max-Q
NVIDIA Driver Version: 430.40

Hi,

The TensorRT runtime can be used by multiple threads simultaneously, so long as each object uses a different execution context.
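
As a rough illustration of that rule, the sketch below hands each thread its own execution context from one shared engine. It assumes only that the engine object exposes create_execution_context(), as the engine in the question does. PerThreadExecutionContext, FakeEngine and contexts_seen_by_threads are hypothetical names made up for this example, and FakeEngine is a stand-in so the code runs without a GPU. Buffers and streams are not covered here.

```python
import threading


class PerThreadExecutionContext:
    """Hands each thread its own execution context from one shared engine.

    `engine` is assumed to expose create_execution_context(); any object
    with that method works here.
    """

    def __init__(self, engine):
        self._engine = engine
        self._local = threading.local()

    def get(self):
        # Create the context lazily, the first time a given thread asks for it.
        ctx = getattr(self._local, "context", None)
        if ctx is None:
            ctx = self._engine.create_execution_context()
            self._local.context = ctx
        return ctx


class FakeEngine:
    """Hypothetical stand-in for a deserialized engine (no GPU needed)."""

    def __init__(self):
        self._lock = threading.Lock()
        self.created = 0

    def create_execution_context(self):
        with self._lock:
            self.created += 1
        return object()


def contexts_seen_by_threads(provider, n_threads):
    """Run n_threads workers; return the (first, second) contexts each one got."""
    pairs = []
    lock = threading.Lock()

    def worker():
        first = provider.get()
        second = provider.get()  # same thread, so the cached context is returned
        with lock:
            pairs.append((first, second))

    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return pairs
```

In a real pipeline, each worker thread (for example a ROS callback thread) would call provider.get() and run its inference on the context it receives, instead of sharing a single self.context across threads.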

Please refer to the link below for more details:
https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/tensorrt-601/tensorrt-best-practices/index.html#thread-safety
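
Since the question mentions CUDA contexts, a related thing to check is which CUDA context is current when the TensorRT call runs, because another framework in the same process may have made its own context current. The sketch below is one possible pattern, not a confirmed fix for this error: it wraps an inference call so that a given context is pushed before the call and popped afterwards. It assumes the context object offers push() and pop(), as pycuda.driver.Context does. made_current, run_with_context and FakeCudaContext are hypothetical names for this example, and FakeCudaContext is a stand-in so the code runs without a GPU.

```python
from contextlib import contextmanager


@contextmanager
def made_current(cuda_ctx):
    """Keep cuda_ctx current for the duration of the with-block.

    cuda_ctx is assumed to offer push() and pop().
    """
    cuda_ctx.push()
    try:
        yield cuda_ctx
    finally:
        # Pop again even if the inference call raised.
        cuda_ctx.pop()


class FakeCudaContext:
    """Hypothetical stand-in that records calls (no GPU needed)."""

    def __init__(self):
        self.calls = []

    def push(self):
        self.calls.append("push")

    def pop(self):
        self.calls.append("pop")


def run_with_context(cuda_ctx, infer, image):
    """Call infer(image) while cuda_ctx is the current context."""
    with made_current(cuda_ctx):
        return infer(image)
```

With pycuda, the context object might come from pycuda.driver.Device(0).make_context(), and infer would be run_inference_for_single_image from the question. Both are assumptions, since the thread does not show how pycuda is initialized.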

Thanks