Some PyTorch model with slicing operation fails on inference

Description

I had to convert a ReXNet PyTorch model into TensorRT to use it with DeepStream 6.0 on a Jetson Nano, but I ended up with the errors below:

# deepstream-app -t -c config.txt

... (after object appears and secondary-gie starts to infer)

ERROR: [TRT]: [shapeMachine.cpp::execute::565] Error Code 7: Internal Error (Slice_28: ISliceLayer has out of bounds access on axis 0
out of bounds access for slice
Instruction: CHECK_SLICE 2 0 16 1
)
ERROR: [TRT]: [executionContext.cpp::enqueueInternal::360] Error Code 2: Internal Error (Could not resolve slots: )
ERROR: Failed to enqueue trt inference batch
ERROR: Infer context enqueue buffer failed, nvinfer error:NVDSINFER_TENSORRT_ERROR
0:07:31.852496379   144   0x55a00d40f0 WARN                 nvinfer gstnvinfer.cpp:1324:gst_nvinfer_input_queue_loop:<secondary_gie_0> error: Failed to queue input batch for inferencing
ERROR from secondary_gie_0: Failed to queue input batch for inferencing
Debug info: /dvs/git/dirty/git-master_linux/deepstream/sdk/src/gst-plugins/gst-nvinfer/gstnvinfer.cpp(1324): gst_nvinfer_input_queue_loop (): /GstPipeline:pipeline/GstBin:secondary_gie_bin/GstNvInfer:secondary_gie_0
Quitting
[NvMultiObjectTracker] De-initialized

(deepstream-app:144): GLib-GObject-CRITICAL **: 03:22:14.558: g_object_unref: assertion 'G_IS_OBJECT (object)' failed
App run failed

Full log below:

# deepstream-app -t -c config.txt

Opening in BLOCKING MODE
ERROR: Deserialize engine failed because file path: /workspace/rexnet.engine open error
0:00:01.416849455   144   0x55a054d550 WARN                 nvinfer gstnvinfer.cpp:635:gst_nvinfer_logger:<secondary_gie_0> NvDsInferContext[UID 2]: Warning from NvDsInferContextImpl::deserializeEngineAndBackend() <nvdsinfer_context_impl.cpp:1889> [UID = 2]: deserialize engine from file :/workspace/rexnet.engine failed
0:00:01.416999302   144   0x55a054d550 WARN                 nvinfer gstnvinfer.cpp:635:gst_nvinfer_logger:<secondary_gie_0> NvDsInferContext[UID 2]: Warning from NvDsInferContextImpl::generateBackendContext() <nvdsinfer_context_impl.cpp:1996> [UID = 2]: deserialize backend context from engine from file :/workspace/rexnet.engine failed, try rebuild
0:00:01.417051178   144   0x55a054d550 INFO                 nvinfer gstnvinfer.cpp:638:gst_nvinfer_logger:<secondary_gie_0> NvDsInferContext[UID 2]: Info from NvDsInferContextImpl::buildModel() <nvdsinfer_context_impl.cpp:1914> [UID = 2]: Trying to create engine from model files
WARNING: [TRT]: onnx2trt_utils.cpp:364: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
WARNING: [TRT]: DLA requests all profiles have same min, max, and opt value. All dla layers are falling back to GPU
WARNING: [TRT]: Detected invalid timing cache, setup a local cache instead
WARNING: [TRT]: Min value of this profile is not valid
0:07:13.644605812   144   0x55a054d550 INFO                 nvinfer gstnvinfer.cpp:638:gst_nvinfer_logger:<secondary_gie_0> NvDsInferContext[UID 2]: Info from NvDsInferContextImpl::buildModel() <nvdsinfer_context_impl.cpp:1947> [UID = 2]: serialize cuda engine to file: /workspace/rexnet.onnx_b16_gpu0_fp16.engine successfully
INFO: [FullDims Engine Info]: layers num: 2
0   INPUT  kFLOAT images          3x224x224       min: 1x3x224x224     opt: 16x3x224x224    Max: 16x3x224x224
1   OUTPUT kFLOAT output          3               min: 0               opt: 0               Max: 0

0:07:13.804415333   144   0x55a054d550 INFO                 nvinfer gstnvinfer_impl.cpp:313:notifyLoadModelStatus:<secondary_gie_0> [UID 2]: Load new model:/workspace/secondary/config_infer_secondary.txt sucessfully
gstnvtracker: Loading low-level lib at /opt/nvidia/deepstream/deepstream-6.0/lib/libnvds_nvmultiobjecttracker.so
gstnvtracker: Batch processing is ON
gstnvtracker: Past frame output is ON
[NvMultiObjectTracker] Initialized
WARNING: [TRT]: Using an engine plan file across different models of devices is not recommended and is likely to affect performance or even cause errors.
0:07:14.378722694   144   0x55a054d550 INFO                 nvinfer gstnvinfer.cpp:638:gst_nvinfer_logger:<primary_gie> NvDsInferContext[UID 1]: Info from NvDsInferContextImpl::deserializeEngineAndBackend() <nvdsinfer_context_impl.cpp:1900> [UID = 1]: deserialized trt engine from :/workspace/primary.engine
INFO: [Implicit Engine Info]: layers num: 2
0   INPUT  kFLOAT images          3x768x768
1   OUTPUT kFLOAT output          12096x7

0:07:14.378848166   144   0x55a054d550 INFO                 nvinfer gstnvinfer.cpp:638:gst_nvinfer_logger:<primary_gie> NvDsInferContext[UID 1]: Info from NvDsInferContextImpl::generateBackendContext() <nvdsinfer_context_impl.cpp:2004> [UID = 1]: Use deserialized engine model: /workspace/primary.engine
0:07:14.390004525   144   0x55a054d550 INFO                 nvinfer gstnvinfer_impl.cpp:313:notifyLoadModelStatus:<primary_gie> [UID 1]: Load new model:/workspace/primary/config_infer_primary.txt sucessfully

Runtime commands:
        h: Print this help
        q: Quit

        p: Pause
        r: Resume

NOTE: To expand a source in the 2D tiled display and view object details, left-click on the source.
      To go back to the tiled display, right-click anywhere on the window.


**PERF:  FPS 0 (Avg)
**PERF:  0.00 (0.00)
** INFO: <bus_callback:194>: Pipeline ready

Opening in BLOCKING MODE
NvMMLiteOpen : Block : BlockType = 261
NVMEDIA: Reading vendor.tegra.display-size : status: 6
NvMMLiteBlockCreate : Block : BlockType = 261
** INFO: <bus_callback:180>: Pipeline running

NvMMLiteOpen : Block : BlockType = 4
===== NVMEDIA: NVENC =====
NvMMLiteBlockCreate : Block : BlockType = 4
H264: Profile = 66, Level = 0
NVMEDIA_ENC: bBlitMode is set to TRUE
**PERF:  30.59 (28.88)
**PERF:  29.16 (29.50)
**PERF:  29.11 (29.09)
**PERF:  29.09 (29.32)
**PERF:  29.12 (29.12)
ERROR: [TRT]: [shapeMachine.cpp::execute::565] Error Code 7: Internal Error (Slice_28: ISliceLayer has out of bounds access on axis 0
out of bounds access for slice
Instruction: CHECK_SLICE 2 0 16 1
)
ERROR: [TRT]: [executionContext.cpp::enqueueInternal::360] Error Code 2: Internal Error (Could not resolve slots: )
ERROR: Failed to enqueue trt inference batch
ERROR: Infer context enqueue buffer failed, nvinfer error:NVDSINFER_TENSORRT_ERROR
0:07:31.852496379   144   0x55a00d40f0 WARN                 nvinfer gstnvinfer.cpp:1324:gst_nvinfer_input_queue_loop:<secondary_gie_0> error: Failed to queue input batch for inferencing
ERROR from secondary_gie_0: Failed to queue input batch for inferencing
Debug info: /dvs/git/dirty/git-master_linux/deepstream/sdk/src/gst-plugins/gst-nvinfer/gstnvinfer.cpp(1324): gst_nvinfer_input_queue_loop (): /GstPipeline:pipeline/GstBin:secondary_gie_bin/GstNvInfer:secondary_gie_0
Quitting
[NvMultiObjectTracker] De-initialized

(deepstream-app:144): GLib-GObject-CRITICAL **: 03:22:14.558: g_object_unref: assertion 'G_IS_OBJECT (object)' failed
App run failed

However, I could mitigate this problem by changing the model's forward() from:

def forward(self, x):
    out = self.out(x)
    if self.use_shortcut:
        out[:, 0:self.in_channels] += x  # self.in_channels won't change during inference

    return out

to:

def forward(self, x):
    feature = self.out(x)
    if self.use_shortcut:
        # Instead of writing into a slice of `feature` in place, pad x with
        # zero channels up to feature's channel count and add the whole tensor.
        fB, fC, fH, fW = feature.shape
        zeros = torch.zeros(fB, fC - self.in_channels, fH, fW,
                            dtype=x.dtype, device=x.device)
        feature = feature + torch.cat([x, zeros], dim=1)

    return feature
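
For reference, the same shortcut could probably also be written with torch.nn.functional.pad instead of allocating a zero tensor and concatenating. This is only an untested sketch under the same assumptions as above (self.out, self.use_shortcut, self.in_channels); I have not checked whether TensorRT treats the resulting Pad node any differently:

import torch.nn.functional as F

def forward(self, x):
    feature = self.out(x)
    if self.use_shortcut:
        # Zero-pad x along the channel dimension (the last pair of the 6-value
        # pad tuple pads dim 1 of an NCHW tensor) up to feature's channel count.
        pad_channels = feature.shape[1] - self.in_channels
        feature = feature + F.pad(x, (0, 0, 0, 0, 0, pad_channels))

    return feature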

I’ve created a separate git repository for the ReXNet model and made commit #1 and #2; you can also take a look there.

Although the problem is gone, I wonder whether this is a TensorRT bug or a bug in the model code. Is anyone else facing the same issue?

Thank you.

Environment (PC - for converting PyTorch model to ONNX)

TensorRT Version: v8.0.3
GPU Type: NVIDIA RTX 3060 12GB
Nvidia Driver Version: 495.44
CUDA Version: 11.5.50
CUDNN Version: 8.3.0.96
Operating System + Version: Ubuntu Linux 20.04.3 LTS
Python Version (if applicable): 3.8.12
PyTorch Version (if applicable): 1.11.0a0+b6df043
Baremetal or Container (if container which image + tag): Containerized (nvcr.io/nvidia/pytorch:21.11-py3)

Environment (Jetson Nano - for converting ONNX to TensorRT)

TensorRT Version: v8.0.1.6
GPU Type: Jetson Nano (128-core Maxwell)
Nvidia Driver Version: NVIDIA Jetson Jetpack 4.6 (L4T 32.6.1)
CUDA Version: 10.2.300
CUDNN Version: 8.2.1.32
Operating System + Version: NVIDIA L4T 32.6.1 (Ubuntu Linux 18.04.6 LTS)
Baremetal or Container (if container which image + tag): Containerized (nvcr.io/nvidia/deepstream-l4t:6.0-triton)

Relevant Files

  1. I’ve uploaded the ONNX files to Google Drive for convenience (the models have uninitialized weights): Google Drive Link

    • Original exported ONNX file (after Step 2 - onnxsim): rexnetv1_1.0_before_forward_hack_b16_noweight.onnx
    • Mitigation-applied exported ONNX file: rexnetv1_1.0_after_forward_hack_b16_noweight.onnx
  2. DeepStream 6.0 configuration files below:

    • config.txt
    [application]
    enable-perf-measurement=1
    perf-measurement-interval-sec=3
    
    [tiled-display]
    enable=1
    rows=1
    columns=1
    width=1920
    height=1080
    gpu-id=0
    #(0): nvbuf-mem-default - Default memory allocated, specific to particular platform
    #(1): nvbuf-mem-cuda-pinned - Allocate Pinned/Host cuda memory, applicable for Tesla
    #(2): nvbuf-mem-cuda-device - Allocate Device cuda memory, applicable for Tesla
    #(3): nvbuf-mem-cuda-unified - Allocate Unified cuda memory, applicable for Tesla
    #(4): nvbuf-mem-surface-array - Allocate Surface Array memory, applicable for Jetson
    nvbuf-memory-type=0
    
    [source0]
    enable=1
    #Type - 1=CameraV4L2 2=URI 3=MultiURI 4=RTSP
    type=3
    uri=file:///opt/nvidia/deepstream/deepstream-6.0/samples/streams/sample_1080p_h264.mp4
    num-sources=1
    drop-frame-interval=0
    gpu-id=0
    # (0): memtype_device   - Memory type Device
    # (1): memtype_pinned   - Memory type Host Pinned
    # (2): memtype_unified  - Memory type Unified
    cudadec-memtype=0
    
    [sink0]
    enable=1
    #Type - 1=FakeSink 2=EglSink 3=File
    type=1
    sync=0
    source-id=0
    gpu-id=0
    qos=0
    nvbuf-memory-type=0
    overlay-id=1
    
    [sink1]
    enable=1
    #Type - 1=FakeSink 2=EglSink 3=File 4=RTSPStreaming
    type=4
    #1=h264 2=h265
    codec=1
    #encoder type 0=Hardware 1=Software
    enc-type=0
    sync=0
    bitrate=4000000
    #H264 Profile - 0=Baseline 2=Main 4=High
    #H265 Profile - 0=Main 1=Main10
    profile=0
    # set below properties in case of RTSPStreaming
    rtsp-port=8554
    udp-port=5400
    
    [osd]
    enable=1
    gpu-id=0
    border-width=1
    text-size=15
    text-color=1;1;1;1;
    text-bg-color=0.3;0.3;0.3;1
    font=Serif
    show-clock=0
    clock-x-offset=800
    clock-y-offset=820
    clock-text-size=12
    clock-color=1;0;0;0
    nvbuf-memory-type=0
    
    [streammux]
    gpu-id=0
    ##Boolean property to inform muxer that sources are live
    live-source=0
    batch-size=1
    ##time out in usec, to wait after the first buffer is available
    ##to push the batch even if the complete batch is not formed
    batched-push-timeout=40000
    ## Set muxer output width and height
    width=1920
    height=1080
    ##Enable to maintain aspect ratio wrt source, and allow black borders, works
    ##along with width, height properties
    enable-padding=1
    nvbuf-memory-type=0
    ## If set to TRUE, system timestamp will be attached as ntp timestamp
    ## If set to FALSE, ntp timestamp from rtspsrc, if available, will be attached
    # attach-sys-ts-as-ntp=1
    
    [primary-gie]
    enable=1
    gpu-id=0
    #Required by the app for OSD, not a plugin property
    bbox-border-color0=1;0;0;1
    bbox-border-color1=0;1;1;1
    bbox-border-color2=0;0;1;1
    bbox-border-color3=0;1;0;1
    interval=4
    gie-unique-id=1
    nvbuf-memory-type=0
    config-file=config_infer_primary.txt
    
    [tracker]
    enable=1
    # For NvDCF and DeepSORT tracker, tracker-width and tracker-height must be a multiple of 32, respectively
    tracker-width=640
    tracker-height=384
    ll-lib-file=/opt/nvidia/deepstream/deepstream-6.0/lib/libnvds_nvmultiobjecttracker.so
    # ll-config-file required to set different tracker types
    ll-config-file=tracker_configs/config_tracker_IOU.yml
    # ll-config-file=tracker_configs/config_tracker_NvDCF_max_perf.yml
    # ll-config-file=tracker_configs/config_tracker_NvDCF_perf.yml
    # ll-config-file=tracker_configs/config_tracker_NvDCF_accuracy.yml
    # ll-config-file=tracker_configs/config_tracker_DeepSORT.yml
    gpu-id=0
    enable-batch-process=1
    enable-past-frame=1
    display-tracking-id=1
    
    [secondary-gie0]
    enable=1
    gpu-id=0
    gie-unique-id=2
    operate-on-gie-id=1
    operate-on-class-ids=0;
    config-file=config_infer_secondary.txt
    
    [tests]
    file-loop=0
    
    • config_infer_secondary.txt
    [property]
    gpu-id=0
    net-scale-factor=0.00390625
    onnx-file=rexnet.onnx
    model-engine-file=rexnet.engine
    batch-size=16
    output-blob-names=output
    # 0=FP32, 1=INT8, 2=FP16 mode
    network-mode=2
    labelfile-path=labels.txt
    force-implicit-batch-dim=0
    model-color-format=0
    # 1 - primary, 2 - secondary
    process-mode=2
    is-classifier=1
    classifier-async-mode=0
    classifier-threshold=0.5
    input-object-min-width=32
    input-object-min-height=32
    operate-on-gie-id=1
    operate-on-class-ids=0;
    # NvBufSurfTransformInter_Bilinear
    scaling-filter=1
    scaling-compute-hw=0
    parse-classifier-func-name=NvDsInferClassiferParseCustomClassifier
    custom-lib-path=nvdsinfer_custom_impl_classifier/libnvdsinfer_custom_impl_classifier.so
    

Steps To Reproduce

I created the TensorRT engine through the following process:

  1. Convert the ReXNet model into an ONNX model as below:

    import torch
    from rexnetv1 import ReXNetV1

    ONNX_OUTPUT_PATH = 'rexnet.onnx'  # matches onnx-file in config_infer_secondary.txt
    batch_size = 16

    model = ReXNetV1().eval()
    torch.onnx.export(
        model,
        torch.randn(batch_size, 3, 224, 224),
        ONNX_OUTPUT_PATH,
        input_names=['images'],
        output_names=['output'],
        dynamic_axes=None,
        opset_version=11
    )
    
  2. Simplify the ONNX model using daquexian/onnx-simplifier (a quick onnxruntime sanity check of the result is sketched after these steps):

    import onnx
    from onnxsim import simplify

    onnx_model = onnx.load(ONNX_OUTPUT_PATH)
    model_simp, check = simplify(onnx_model,
                                 dynamic_input_shape=False,
                                 input_shapes=None)
    assert check, 'Simplified ONNX model failed validation'
    onnx.save(model_simp, ONNX_OUTPUT_PATH)  # overwrite with the simplified graph
    
  3. Set ONNX_OUTPUT_PATH as the onnx-file in the DeepStream 6.0 deepstream-app sample configuration (config_infer_secondary.txt above) and run the app.
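
Not part of the original steps, but to narrow down whether the problem sits in the exported graph or in TensorRT itself, a quick sanity check is to compare the ONNX output against PyTorch using onnxruntime. This is only a rough sketch, meant to run in the same session as Steps 1 and 2 (so that model and ONNX_OUTPUT_PATH still refer to the exported instance), and it assumes onnxruntime is installed:

    import numpy as np
    import onnxruntime as ort
    import torch

    dummy = torch.randn(16, 3, 224, 224)
    with torch.no_grad():
        torch_out = model(dummy).numpy()             # `model` from Step 1

    sess = ort.InferenceSession(ONNX_OUTPUT_PATH,    # file written in Steps 1-2
                                providers=['CPUExecutionProvider'])
    onnx_out = sess.run(['output'], {'images': dummy.numpy()})[0]
    print('max abs diff:', np.abs(torch_out - onnx_out).max())

If the outputs match here but the engine still fails in DeepStream, that points more towards the TensorRT side.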

Hi,
Can you try running your model with the trtexec command and share the "--verbose" log in case the issue persists?
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/trtexec
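
For example, something like this, with the input name and shapes taken from your log (paths are just placeholders):

# trtexec --onnx=rexnet.onnx --fp16 --verbose \
      --minShapes=images:1x3x224x224 --optShapes=images:16x3x224x224 --maxShapes=images:16x3x224x224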

You can refer to the link below for the list of supported operators; in case any operator is not supported, you need to create a custom plugin to support that operation.

Also, we request you to share your model and script, if not shared already, so that we can help you better.

Meanwhile, for some common errors and queries, please refer to the link below:

Thanks!

