Segmentation fault (core dumped) after running IExecutionContext.execute_async_v3()

Description

I used the following commands to convert an ONNX model to a TensorRT engine (input.onnx is the original model):

polygraphy surgeon sanitize --fold-constants ./input.onnx  -o output.onnx
trtexec --onnx=output.onnx --saveEngine=model.plan --minShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 \
        --optShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 \
        --maxShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 --verbose=true

Then I tried to run inference with TensorRT but hit a “Segmentation fault (core dumped)” error. My model information and code are below; I would greatly appreciate any assistance.
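
For reference, the engine summary below can also be read programmatically. Here is a minimal sketch using the TensorRT 10 Python API (assuming the model.plan path from the trtexec command above):

import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
with open("model.plan", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

# Print each I/O tensor's mode (INPUT/OUTPUT), dtype, and build-time shape.
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    print(name, engine.get_tensor_mode(name),
          engine.get_tensor_dtype(name), engine.get_tensor_shape(name))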

TensorRT Engine

[I] ==== TensorRT Engine ====
    Name: Unnamed Network 0 | Explicit Batch Engine
    
    ---- 5 Engine Input(s) ----
    {audio [dtype=float32, shape=(1, 256, 128)],
     trunc_start [dtype=int64, shape=(1,)],
     offset [dtype=int64, shape=(1,)],
     att_cache [dtype=float32, shape=(32, 20, 128, 128)],
     frame_cache [dtype=float32, shape=(1, 4, 128)]}
    
    ---- 2 Engine Output(s) ----
    {output [dtype=float32, shape=(64, 6400)],
     att_cache_out [dtype=float32, shape=(32, 20, 256, 128)]}
    
    ---- Memory ----
    Device Memory: 13281280 bytes
    
    ---- 1 Profile(s) (7 Tensor(s) Each) ----
    - Profile: 0
        Tensor: audio                  (Input), Index: 0 | Shapes: min=(1, 256, 128), opt=(1, 256, 128), max=(1, 256, 128)
        Tensor: trunc_start            (Input), Index: 1 | Shapes: min=(1,), opt=(1,), max=(1,)
        Tensor: offset                 (Input), Index: 2 | Shapes: min=(1,), opt=(1,), max=(1,)
        Tensor: att_cache              (Input), Index: 3 | Shapes: min=(32, 20, 128, 128), opt=(32, 20, 128, 128), max=(32, 20, 128, 128)
        Tensor: frame_cache            (Input), Index: 4 | Shapes: min=(1, 4, 128), opt=(1, 4, 128), max=(1, 4, 128)
        Tensor: output                (Output), Index: 5 | Shape: (64, 6400)
        Tensor: att_cache_out         (Output), Index: 6 | Shape: (32, 20, 256, 128)
    
    ---- 505 Layer(s) ----

My Code

import argparse

import onnxruntime as ort
import tensorrt as trt
import torch


class TensorRTInfer:
    """
    Implements inference for the Model TensorRT engine.
    """
    def __init__(self, engine_path):
        """
        :param engine_path: The path to the serialized engine to load from disk.
        """
        # Load TRT engine
        self.logger = trt.Logger(trt.Logger.INFO)
        trt.init_libnvinfer_plugins(self.logger, namespace="")
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            assert runtime
            self.engine = runtime.deserialize_cuda_engine(f.read())
        assert self.engine
        self.context = self.engine.create_execution_context()
        assert self.context

    def infer(self):
        # Keep the torch.cuda.Stream object alive on self; storing only the
        # raw cuda_stream handle leaves the Stream eligible for collection.
        self.torch_stream = torch.cuda.Stream()
        self.stream = self.torch_stream.cuda_stream
        # Prepare the input data. All of these tensors are freshly allocated
        # (hence contiguous) and stay referenced until execution completes.
        audio = torch.randn(1, 256, 128, dtype=torch.float32, device="cuda")
        trunc_start = torch.tensor([2], dtype=torch.int64, device="cuda")
        offset = torch.tensor([126], dtype=torch.int64, device="cuda")
        att_cache = torch.randn(
            32, 20, 128, 128, dtype=torch.float32, device="cuda")
        frame_cache = torch.randn(
            1, 4, 128, dtype=torch.float32, device="cuda")
        self.context.set_input_shape('audio', tuple(audio.shape))
        self.context.set_tensor_address('audio', audio.data_ptr())
        self.context.set_input_shape('trunc_start', tuple(trunc_start.shape))
        self.context.set_tensor_address('trunc_start', trunc_start.data_ptr())
        self.context.set_input_shape('offset', tuple(offset.shape))
        self.context.set_tensor_address('offset', offset.data_ptr())
        self.context.set_input_shape('att_cache', tuple(att_cache.shape))
        self.context.set_tensor_address('att_cache', att_cache.data_ptr())
        self.context.set_input_shape('frame_cache', tuple(frame_cache.shape))
        self.context.set_tensor_address('frame_cache', frame_cache.data_ptr())
        # Prepare the output buffers
        att_cache_out = torch.zeros(
            32, 20, 256, 128, dtype=torch.float32, device="cuda")
        output = torch.zeros(64, 6400, dtype=torch.float32, device="cuda")
        self.context.set_tensor_address(
            'att_cache_out', att_cache_out.data_ptr())
        self.context.set_tensor_address('output', output.data_ptr())
        # self.context.set_optimization_profile_async(0, self.stream)
        torch.cuda.synchronize()
        self.context.execute_async_v3(self.stream)
        torch.cuda.synchronize()
        return output


def trt_infer(args):
    session = TensorRTInfer(args.engine)
    session.infer()


def onnx_infer():
    ort_session = ort.InferenceSession(
        "./poly/encoder.onnx", providers=['CUDAExecutionProvider'])
    audio = torch.randn(1, 256, 128, dtype=torch.float32, device="cuda")
    trunc_start = torch.tensor([2], dtype=torch.int64, device="cuda")
    offset = torch.tensor([126], dtype=torch.int64, device="cuda")
    att_cache = torch.randn(
        32, 20, 128, 128, dtype=torch.float32, device="cuda")
    frame_cache = torch.randn(
        1, 4, 128, dtype=torch.float32, device="cuda")
    outputs = ort_session.run(None, {
        "audio": audio.cpu().numpy(),
        "trunc_start": trunc_start.cpu().numpy(),
        "offset": offset.cpu().numpy(),
        "att_cache": att_cache.cpu().numpy(),
        "frame_cache": frame_cache.cpu().numpy(),
    })
    print(f"output shape={outputs[0].shape} att shape={outputs[1].shape}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-e", "--engine", default="./model.plan", help="The serialized TensorRT engine"
    )
    args = parser.parse_args()
    trt_infer(args)
    # onnx_infer()
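
A general sanity check that might help localize crashes like this: before launching, confirm that every I/O tensor has a non-null address bound, and note which inputs the engine expects in host memory. This is only a sketch against the TensorRT 10 Python API, with engine and context as constructed in TensorRTInfer above:

def check_bindings(engine, context):
    """Print binding metadata and fail fast on unbound tensors."""
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        addr = context.get_tensor_address(name)
        # get_tensor_location reports whether TensorRT reads this tensor
        # from DEVICE or HOST memory; a pointer in the wrong memory space
        # can crash inside execute_async_v3.
        print(f"{name}: mode={engine.get_tensor_mode(name)}, "
              f"dtype={engine.get_tensor_dtype(name)}, "
              f"location={engine.get_tensor_location(name)}, "
              f"shape_io={engine.is_shape_inference_io(name)}, "
              f"addr={addr:#x}")
        assert addr != 0, f"no buffer bound for tensor {name!r}"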

Environment

TensorRT Version: 10.3.0
GPU Type: H20
Nvidia Driver Version: 535.161.08
CUDA Version: 12.4
CUDNN Version: 8.9
Operating System + Version:
Python Version (if applicable): 3.10.4
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):

Relevant Files

Please attach or include links to any models, data, files, or scripts necessary to reproduce your issue. (Github repo, Google Drive, Dropbox, etc.)

Steps To Reproduce

Please include:

  • Exact steps/commands to build your repro
  • Exact steps/commands to run your repro
  • Full traceback of errors encountered

Hi @simonzgx,
Could you please share your ONNX model and the verbose logs?

Thanks

Hi @AakankshaS, thanks for your reply!
I’m sorry I can’t share detailed information about my model. However, I eventually traced the core dump described above to the following code in the model:

# audio shape:[1, 130, 1280]
audio = torch.narrow(audio, dim=1, start=trunc_start, length=128)

Here, trunc_start is a model input with a value of 2.
I’ve implemented a workaround to bypass this issue for now, but if possible I’d like to understand how this problem occurs; it has been quite puzzling to me.
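
One hypothesis I am still checking: exporting torch.narrow with a tensor-valued start produces an ONNX Slice whose starts input comes from trunc_start, and TensorRT can classify such an input as a shape tensor that it reads from host memory. My script binds a CUDA pointer for trunc_start, so if the engine expects it on the host, execute_async_v3 would dereference a device address on the CPU, which would match the segfault. A quick check, assuming engine and context as in the script above:

# Inspect how the engine classifies trunc_start.
name = "trunc_start"
print(engine.get_tensor_location(name))    # TensorLocation.HOST supports the hypothesis
print(engine.is_shape_inference_io(name))

# If it is read on the host, bind a CPU tensor instead (and keep it alive):
trunc_start = torch.tensor([2], dtype=torch.int64)  # host memory, no device="cuda"
context.set_tensor_address(name, trunc_start.data_ptr())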