Inference chaining using DeepStream and Triton

Please provide complete information as applicable to your setup.

• Hardware Platform (Jetson / GPU) AGX Orin 64 GB
• DeepStream Version 7.1.0
• JetPack Version (valid for Jetson only) 6.2
• TensorRT Version TensorRT v101000
• Triton Inference Server NVIDIA Release 25.05 (build 170551412) - Triton Server Version 2.58.0

Hello,

I have successfully tested a model on an AGX Orin 64 GB with Triton Inference Server and DeepStream.
Now I would like to run the same model on a 2x AGX Orin 32 GB configuration.

For that purpose, I have split the ONNX model in half and converted each half to a TensorRT engine using trtexec.

I would like to know how to configure the DeepStream pipeline when using an ONNX model split in half. As I am using Triton Inference Server, I am using the nvinferserver plugin.

Here are extracts of the two config.pbtxt files (model_part1 and model_part2):

name: "model_part1"
platform: "tensorrt_plan"
max_batch_size: 1
default_model_filename: "model_part1.engine"
input [
  {
    name: "input"
    data_type: TYPE_FP32
    format: FORMAT_NCHW
    dims: [ 3, 720, 1280 ]
  }
]
output [
  {
    name: "/conv3_1/Conv_output_0"
    data_type: TYPE_FP32
    dims: [ 64,90,160 ]
  }
]

name: "model_part2"
platform: "tensorrt_plan"
max_batch_size: 1
default_model_filename: "model_part2.engine"
input [
  {
    name: "/conv3_1/Conv_output_0"
    data_type: TYPE_FP32
    dims: [ 64,90,160 ]
  }
]
output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [ 5, 720, 1280 ]
  }
]
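
As a quick sanity check, both servers can be queried over gRPC to confirm that the output of model_part1 lines up with the input of model_part2. A minimal sketch using the tritonclient Python package; the server addresses are assumed placeholders:

# Hypothetical sanity check (not part of the DeepStream pipeline): the output
# layer of model_part1 should match the input layer of model_part2.
import tritonclient.grpc as grpcclient

part1 = grpcclient.InferenceServerClient(url="localhost:8001")       # Triton on Orin 1 (assumed address)
part2 = grpcclient.InferenceServerClient(url="192.168.121.2:8001")   # Triton on Orin 2 (assumed address)

out_meta = part1.get_model_metadata("model_part1").outputs[0]
in_meta = part2.get_model_metadata("model_part2").inputs[0]

print("part1 output:", out_meta.name, out_meta.datatype, list(out_meta.shape))
print("part2 input :", in_meta.name, in_meta.datatype, list(in_meta.shape))
assert out_meta.name == in_meta.name and list(out_meta.shape) == list(in_meta.shape)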

I have tried to configure the second model as an SGIE with the option process_mode: PROCESS_MODE_CLIP_OBJECTS, but it doesn't seem to work.

Could you help me?

Kind regards

  1. Please refer to /opt/nvidia/deepstream/deepstream/sources/apps/sample_apps/deepstream-test2 for a PGIE+SGIE nvinferserver sample.
  2. If the app still doesn't work, what are the two models used for, respectively? How do you know that the outputs of the first model are correct? How do you do the preprocessing for the second model?

Hello,

  1. Please refer to /opt/nvidia/deepstream/deepstream/sources/apps/sample_apps/deepstream-test2 for a PGIE+SGIE nvinferserver sample.

I tried to take some inspiration from the suggested example, but it is not the same scenario. I don't have a PGIE+SGIE setup; I have only one PGIE, but split into two smaller parts. In other words, I don't have two different models; it is one model that I have split into two halves.

Therefore the output of the first model is the same as the input of the second model (64x90x160), but it doesn't represent anything by itself.

Conceptually, I would like to run inference on the first model, get the output tensor (no postprocessing), and then feed the second model with this tensor (no preprocessing). I don't need anything in between.
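
To make the intent concrete, here is a minimal sketch of that chaining done directly with the Triton gRPC Python client (model and tensor names are taken from the config.pbtxt extracts above; the server addresses and the dummy input frame are assumptions):

# Chaining sketch: the raw FP32 output of model_part1 is fed to model_part2
# untouched. Server addresses and the dummy input frame are assumptions.
import numpy as np
import tritonclient.grpc as grpcclient

c1 = grpcclient.InferenceServerClient(url="localhost:8001")       # Triton on Orin 1 (assumed)
c2 = grpcclient.InferenceServerClient(url="192.168.121.2:8001")   # Triton on Orin 2 (assumed)

frame = np.zeros((1, 3, 720, 1280), dtype=np.float32)  # already-preprocessed frame (dummy)

inp1 = grpcclient.InferInput("input", list(frame.shape), "FP32")
inp1.set_data_from_numpy(frame)
mid = c1.infer("model_part1", [inp1]).as_numpy("/conv3_1/Conv_output_0")  # (1, 64, 90, 160)

inp2 = grpcclient.InferInput("/conv3_1/Conv_output_0", list(mid.shape), "FP32")
inp2.set_data_from_numpy(mid)
out = c2.infer("model_part2", [inp2]).as_numpy("output")  # (1, 5, 720, 1280)

In DeepStream terms, the question is how to reproduce the hand-off of the intermediate tensor between two nvinferserver instances.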

  1. If the app still doesn't work, what are the two models used for, respectively? How do you know that the outputs of the first model are correct? How do you do the preprocessing for the second model?

As I said, I don't care about the output of the first model; I just need to get it and feed it directly to the second model without preprocessing.

The reason for all of this is to create a cluster of Triton servers and be able to run large models by splitting them into smaller ones.

Thank you

Thanks for sharing! Please refer to the sample /opt/nvidia/deepstream/deepstream/sources/TritonBackendEnsemble. There, the SGIE is a Triton ensemble model that combines Secondary_VehicleMake and Secondary_VehicleTypes. You can use your two models as an ensemble model.

The example you are referencing could indeed work if both half-models were on the same Triton instance.

The idea here is for each GPU (AGX Orin) to run its own Triton server instance. Therefore, an ensemble model is not a possibility in my case.

What I want to do is conceptually really simple: take one model, get its output tensor, and then feed another model with this tensor (simple model chaining without altering the data in between).

Does DeepStream have no plugin or configuration for such a simple behaviour?

How did you split the ONNX model? What is the output format? Are there two ONNX models after splitting?

Currently there is no ready-made sample for passing inference results to another device for inference. Here are some ideas.

  1. Run "model_part1" on the first device. In particular, "output-tensor-meta=1" needs to be set for nvinfer; then you can access the inference results from the user meta. Please refer to the sample deepstream-infer-tensor-meta-test. Then send the results to the other device with your custom IPC implementation (see the probe sketch after this list).
  2. After receiving the inference results on the other device, you can continue inference with "model_part2".
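
For illustration, a minimal probe sketch in Python/pyds following the pattern of the deepstream-infer-tensor-meta / ssd-parser samples; the probe name, the fixed 64x90x160 shape, and the dump-to-file step (standing in for the custom IPC) are assumptions:

# Probe sketch: read the raw output tensor of model_part1 from the frame user
# meta and dump it to a file (the file stands in for a custom IPC channel).
import ctypes
import numpy as np
import gi
gi.require_version("Gst", "1.0")
from gi.repository import Gst
import pyds

def dump_part1_tensor_probe(pad, info, user_data):
    batch_meta = pyds.gst_buffer_get_nvds_batch_meta(hash(info.get_buffer()))
    l_frame = batch_meta.frame_meta_list
    while l_frame is not None:
        frame_meta = pyds.NvDsFrameMeta.cast(l_frame.data)
        l_user = frame_meta.frame_user_meta_list
        while l_user is not None:
            user_meta = pyds.NvDsUserMeta.cast(l_user.data)
            if user_meta.base_meta.meta_type == pyds.NvDsMetaType.NVDSINFER_TENSOR_OUTPUT_META:
                tensor_meta = pyds.NvDsInferTensorMeta.cast(user_meta.user_meta_data)
                layer = pyds.get_nvds_LayerInfo(tensor_meta, 0)
                ptr = ctypes.cast(pyds.get_ptr(layer.buffer), ctypes.POINTER(ctypes.c_float))
                tensor = np.ctypeslib.as_array(ptr, shape=(64, 90, 160)).copy()
                np.save("model_part1_output.npy", tensor)  # replace with custom IPC
            l_user = l_user.next
        l_frame = l_frame.next
    return Gst.PadProbeReturn.OK

Attached to the src pad of the first GIE (with output tensor meta enabled), this at least confirms the tensor is available before it is shipped to the other device.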

Yes, there are two ONNX models after splitting, and I then convert them to engine files using trtexec.

In this scenario, I am losing the benefit of using DeepStream pipelines. Since I am using the nvinferserver plugin, I can reach both Triton servers from the same pipeline; the issue is only with the data sent to the second model.

If I understand correctly, there is nothing available as standard in DeepStream to handle such a scenario…

I was hoping not to have to dig into the DeepStream plugin source code, but it seems that I will need to create a custom nvinferserver for the second part of the model, or at least a custom custom_process_funcion.

  1. Could you elaborate on this? How did you run one pipeline on two devices?
  2. Do you mean running "model_part1" on one device and "model_part2" on the other device? If so, how do you pass the inference results between the two devices? As written in my last comment, there is currently no ready-made sample for passing inference results to another device for inference, and custom_process_funcion can't fix this issue. Please refer to /opt/nvidia/deepstream/deepstream/sources/TritonOnnxYolo/nvdsinferserver_custom_impl_yolo/nvdsinferserver_custom_process_yolo.cpp for how to use custom_process_funcion.

The DeepStream pipeline is running on Orin 1 only, but each Orin has its own Triton server instance.
In the nvinferserver configuration file, I can specify the URL of the Triton server over gRPC for each model.
Taking vehicle tracking as an example, I can configure the PGIE and SGIE like this:

## PGIE
infer_config {
  unique_id: 1
  gpu_ids: [0]
  max_batch_size: 1
  backend {
    triton {
      model_name: "trafficcamnet"
      version: -1
      grpc {
        url: "0.0.0.0:8001" ## -> triton server on the ORIN 1
        enable_cuda_buffer_sharing: true
      }
    }
  }
## SGIE
infer_config {
  unique_id: 1
  gpu_ids: [0]
  max_batch_size: 0
  backend {
    triton {
      model_name: "vehiclemakenet"
      version: -1
      grpc {
        url: "192.168.121.2:8001" ## -> triton server on the ORIN 2
        enable_cuda_buffer_sharing: true
      }
    }
  }

In the DeepStream pipeline, I have: … → pgie → sgie → …

And that works fine. Of course, there are some preprocess and postprocess definitions in the config files to make it work.
But it seems that in this scenario the SGIE does not take the output of the PGIE as its input; it seems to use some metadata as its input (the bounding boxes produced by the PGIE's postprocessing).

In my original scenario (split ONNX model), I would like the "SGIE" to take as its input the raw output tensor from the PGIE (because the metadata is meaningless, as we are in the middle of the model).

Using nvinferserver with gRPC covers the communication side. I still need to solve the data handling part.

You can use this pipeline: "…->nvinferserver(1)->nvdspreprocess->nvinferserver(2)->…". There are some things to note (a sketch of this segment follows the list below).

  1. In the config of nvinferserver 1, "output_tensor_meta: true" needs to be set to save the inference results to user meta. Please refer to /opt/nvidia/deepstream/deepstream/sources/apps/sample_apps/deepstream-3d-action-recognition/config_triton_infer_primary_2d_action.txt.
  2. The nvdspreprocess plugin provides a custom library interface for preprocessing input streams. In the custom lib, you can get the inference results from user meta and copy them into new tensors. Please refer to this sample for how to prepare tensors.
  3. In the config of nvinferserver 2, input_tensor_from_meta needs to be set. With this configuration, nvinferserver will use the tensors directly instead of doing its own preprocessing. Please refer to /opt/nvidia/deepstream/deepstream/sources/apps/sample_apps/deepstream-3d-action-recognition/config_triton_infer_primary_2d_action.txt.
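
For illustration, a minimal Python/GStreamer sketch of that segment; the element names and config file paths are placeholders, and streammux/OSD/sink elements are omitted:

# Sketch of the "... -> nvinferserver(1) -> nvdspreprocess -> nvinferserver(2) -> ..."
# segment. Config file paths are placeholders.
import gi
gi.require_version("Gst", "1.0")
from gi.repository import Gst

Gst.init(None)
pipeline = Gst.Pipeline.new("split-model-chain")

pgie1 = Gst.ElementFactory.make("nvinferserver", "pgie-part1")
pgie1.set_property("config-file-path", "config_nvinferserver_part1.txt")  # output_tensor_meta: true

bridge = Gst.ElementFactory.make("nvdspreprocess", "tensor-bridge")
bridge.set_property("config-file", "config_preprocess_bridge.txt")        # custom tensor preparation lib

pgie2 = Gst.ElementFactory.make("nvinferserver", "pgie-part2")
pgie2.set_property("config-file-path", "config_nvinferserver_part2.txt")  # input_tensor_from_meta

for elem in (pgie1, bridge, pgie2):
    pipeline.add(elem)
pgie1.link(bridge)
bridge.link(pgie2)
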
  1. In the config of nvinferserver 1, "output_tensor_meta: true" needs to be set to save the inference results to user meta. Please refer to /opt/nvidia/deepstream/deepstream/sources/apps/sample_apps/deepstream-3d-action-recognition/config_triton_infer_primary_2d_action.txt.

OK, I have done that, and it works: I can see the tensor with a probe by checking user_meta.base_meta.meta_type == pyds.NvDsMetaType.NVDSINFER_TENSOR_OUTPUT_META.

  1. The nvdspreprocess plugin provides a custom library interface for preprocessing input streams. In the custom lib, you can get the inference results from user meta and copy them into new tensors. Please refer to this sample for how to prepare tensors.

I think I am blocked here. I modified the CustomTensorPreparation function like this:

NvDsPreProcessStatus
CustomTensorPreparation(CustomCtx *ctx, NvDsPreProcessBatch *batch, NvDsPreProcessCustomBuf *&buf,
                        CustomTensorParams &tensorParam, NvDsPreProcessAcquirer *acquirer)
{
   printf("CustomTensorPreparation called\n");
   
    NvDsPreProcessStatus status = NVDSPREPROCESS_TENSOR_NOT_READY;

    // Acquire a buffer from tensor pool
    buf = acquirer->acquire();
    void *pDst = buf->memory_ptr; // Destination GPU pointer

    GstBuffer *inbuf = (GstBuffer *)batch->inbuf;
    NvDsBatchMeta *batch_meta = gst_buffer_get_nvds_batch_meta(inbuf);

    if (!batch_meta) {
        g_printerr("Failed to get batch_meta from GstBuffer\n");
        return NVDSPREPROCESS_TENSOR_NOT_READY;
    }

    bool tensor_found = false;

    // Iterate over frames in batch
    for (NvDsMetaList *l_frame = batch_meta->frame_meta_list; l_frame != nullptr; l_frame = l_frame->next) {
        NvDsFrameMeta *frame_meta = (NvDsFrameMeta *)l_frame->data;

        // Iterate through frame user metadata to find tensor output
        for (NvDsMetaList *l_user = frame_meta->frame_user_meta_list; l_user != nullptr; l_user = l_user->next) {
            NvDsUserMeta *user_meta = (NvDsUserMeta *)l_user->data;

            if (user_meta->base_meta.meta_type == NVDSINFER_TENSOR_OUTPUT_META) {
                // Found tensor output from previous nvinferserver
                NvDsInferTensorMeta *tensor_meta = (NvDsInferTensorMeta *)user_meta->user_meta_data;

                g_print("Found tensor meta: %u output layers\n", tensor_meta->num_output_layers);

                for (uint i = 0; i < tensor_meta->num_output_layers; ++i) {
                    void *src_gpu_ptr = tensor_meta->out_buf_ptrs_dev[i];
                    NvDsInferDims dims = tensor_meta->output_layers_info[i].inferDims;
                    size_t num_elements = 1;
                    for (uint d = 0; d < dims.numDims; ++d) {
                        num_elements *= dims.d[d];
                    }
                    size_t layer_size_bytes = num_elements * sizeof(float); // assuming float32

                    g_print("Copying layer %u of size %zu bytes\n", i, layer_size_bytes);

                    // Copy data from previous model output (GPU) to current buffer (GPU)
                    cudaMemcpy(pDst,
                               src_gpu_ptr,      // Source: GPU pointer
                               layer_size_bytes, // Size
                               cudaMemcpyDeviceToDevice);

                    // Advance destination pointer for next layer (if needed)
                    pDst = (char *)pDst + layer_size_bytes;
                }

                tensor_found = true;
                status = NVDSPREPROCESS_SUCCESS;
                break;
            }
        }

        if (tensor_found)
            break;
    }

    if (!tensor_found) {
        g_printerr("No NvDsInferTensorMeta found in frame metadata!\n");
    }

    return status;
}

And I can see this in the logs:

CustomTensorPreparation called
Found tensor meta: 1 output layers
Copying layer 0 of size 3686400 bytes

This seems correct, because the output tensor of model_part1 is 64x90x160 values (x4 bytes for FP32) = 3,686,400 bytes.
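
For reference, the byte count works out as follows:

# Expected size of the FP32 tensor copied above: 64 x 90 x 160 values, 4 bytes each.
assert 64 * 90 * 160 * 4 == 3686400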

Here is the config file of the nvdspreprocess plugin:

[property]
enable=1
    # uniquely identify the metadata generated by this element
unique-id=5
process-on-frame=1
target-unique-ids=1
network-input-order=0
    # gpu-id to be used
gpu-id=0
    # if enabled maintain the aspect ratio while scaling
#maintain-aspect-ratio=1
    # if enabled pad symmetrically with maintain-aspect-ratio enabled
#symmetric-padding=1
    # processing width/height at which the image is scaled
processing-width=160
processing-height=90
    # max buffer in scaling buffer pool
scaling-buf-pool-size=1
    # max buffer in tensor buffer pool
tensor-buf-pool-size=1
    # tensor shape based on network-input-order
network-input-shape= 1;64;90;160
    # 0=RGB, 1=BGR, 2=GRAY
network-color-format=0
    # 0=FP32, 1=UINT8, 2=INT8, 3=UINT32, 4=INT32, 5=FP16
tensor-data-type=0
    # tensor name same as input layer name
tensor-name=/conv3_1/Conv_output_0
    # 0=NVBUF_MEM_DEFAULT 1=NVBUF_MEM_CUDA_PINNED 2=NVBUF_MEM_CUDA_DEVICE 3=NVBUF_MEM_CUDA_UNIFIED
scaling-pool-memory-type=0
    # 0=NvBufSurfTransformCompute_Default 1=NvBufSurfTransformCompute_GPU 2=NvBufSurfTransformCompute_VIC
scaling-pool-compute-hw=0
    # Scaling Interpolation method
    # 0=NvBufSurfTransformInter_Nearest 1=NvBufSurfTransformInter_Bilinear 2=NvBufSurfTransformInter_Algo1
    # 3=NvBufSurfTransformInter_Algo2 4=NvBufSurfTransformInter_Algo3 5=NvBufSurfTransformInter_Algo4
    # 6=NvBufSurfTransformInter_Default
scaling-filter=0
    # custom library .so path having custom functionality
custom-lib-path=/home/orkais/orkais/examples/orkais_ulg_split/nvdspreprocess_lib/libcustom2d_preprocess.so
    # custom tensor preparation function name having predefined input/outputs
    # check the default custom library nvdspreprocess_lib for more info
custom-tensor-preparation-function=CustomTensorPreparation

[user-configs]
   # Below parameters get used when using default custom library nvdspreprocess_lib
   # network scaling factor
pixel-normalization-factor=1
   # mean file path in ppm format
#mean-file=
   # array of offsets for each channel
#offsets=

[group-0]
src-ids=0
custom-input-transformation-function=CustomTransformation
process-on-roi=0
#process-on-all-objects=0
#roi-params-src-0=0;0;100;100
#draw-roi=0
#input-object-min-width=100
#input-object-min-height=100
  1. In the config of nvinferserver 2, input_tensor_from_meta needs to be set. With this configuration, nvinferserver will use the tensors directly instead of doing its own preprocessing. Please refer to /opt/nvidia/deepstream/deepstream/sources/apps/sample_apps/deepstream-3d-action-recognition/config_triton_infer_primary_2d_action.txt.

Yes, I added it:

input_tensor_from_meta {
  is_first_dim_batch: true
}

CONCLUSION
Unfortunately, if I probe nvinferserver 2, I get the output tensor of the first nvinferserver (even though output_tensor_meta: true is set in the config of nvinferserver 2).
Also, on my display I can't see the segmentation mask that should be the output of nvinferserver 2. (If I use the non-split model, I can see the segmentation mask; that pipeline works fine.)

Could you help me diagnose this further?

Thank you

You can dump the inference results generated by model 1 to a file, then test the tensors with a third-party lib, nvinfer, or the Triton gRPC sample.

You can dump the inference results generated by model 1 to a file, then test the tensors with a third-party lib, nvinfer, or the Triton gRPC sample.

  1. I dump the output tensor of pgie1 [64, 90, 160] → OK
  2. I use the gRPC Python client for pgie2 → I need to add a batch dimension to make it work (dynamic batching in the Triton model config) → [1, 64, 90, 160] → OK (a sketch of this client call is shown below)
  3. I get the answer from the Triton call to pgie2 [1, 5, 720, 1280] → OK

Then I compare the values given by the output of the complete (non-split) model and the output of pgie2 → OK, same tensor content for the same frame.
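
A minimal sketch of that gRPC client test with the tritonclient Python package; the dumped file name and the server address are assumptions:

# Feed the dumped pgie1 tensor (64, 90, 160) to model_part2 over gRPC, adding
# the batch dimension that Triton expects. File name and address are assumptions.
import numpy as np
import tritonclient.grpc as grpcclient

mid = np.load("model_part1_output.npy")          # (64, 90, 160) FP32 dump of pgie1
mid = np.expand_dims(mid, axis=0)                # (1, 64, 90, 160)

client = grpcclient.InferenceServerClient(url="192.168.121.2:8001")  # Triton on Orin 2 (assumed)
inp = grpcclient.InferInput("/conv3_1/Conv_output_0", list(mid.shape), "FP32")
inp.set_data_from_numpy(mid.astype(np.float32))
out = client.infer("model_part2", [inp]).as_numpy("output")          # (1, 5, 720, 1280)
print(out.shape, out.reshape(-1)[:10])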

Conclusion:

  • The first part of the pipeline is working (from video decoding to the pgie1 output tensor)
  • pgie2 is fine as well
  • I am missing the glue in between (using nvdspreprocess).

How can I debug further ?
Thank you

You can dump the inference results of model 2, then compare them with the output of the gRPC Python client. If the two are the same, the issue should be related to the postprocessing of nvinferserver.

If I do that using pgie_src_pad.add_probe(Gst.PadProbeType.BUFFER, probe_func, 0) on pgie2 and look for pyds.NvDsMetaType.NVDSINFER_TENSOR_OUTPUT_META, I retrieve the tensor of pgie1 (not pgie2).

If I remove

output_control {
  output_tensor_meta: true
}

from pgie2, I still get the output tensor of pgie1 in the probe attached to pgie2.

It seems that the tensor replacement inside the nvdspreprocess plugin is not doing its job…

The nvinferserver plugin and its low-level implementation are open source. You can add logs in InferGrpcClient::InferComplete to dump the inference results, then compare them with the output of the gRPC Python client. Please refer to /opt/nvidia/deepstream/deepstream-7.1/sources/libs/nvdsinferserver/README for how to build the lib.

The nvinferserver plugin and its low-level implementation are open source. You can add logs in InferGrpcClient::InferComplete to dump the inference results, then compare them with the output of the gRPC Python client. Please refer to /opt/nvidia/deepstream/deepstream-7.1/sources/libs/nvdsinferserver/README for how to build the lib.

If I do this, I get the following output:

...
Starting pipeline 

Decodebin child added: nvurisrc_bin_src_elem
INFO: TritonGrpcBackend id:2 initialized for model: model_part2
---ONE CALL---Output Tensor: output
  Shape: [1, 5, 720, 1280]
Tensor sample values:
  [0] = 1.560547
  [1] = 0.659668
  [2] = 3.677734
  [3] = 4.101562
  [4] = 4.101562
  [5] = 3.699219
  [6] = 5.207031
  [7] = 5.710938
  [8] = 5.871094
  [9] = 6.074219
WARNING: unsupported tensor order for dims to image-info, retry as kLinear
frameSeqLen:0
frameSeqLen iilegal, use default vaule 300
CustomTensorPreparation called
INFO: TritonGrpcBackend id:1 initialized for model: model_part1
---ONE CALL---Output Tensor: /conv3_1/Conv_output_0
  Shape: [1, 64, 90, 160]
Tensor sample values:
  [0] = 1.511719
  [1] = 0.609375
  [2] = 0.726562
  [3] = 1.117188
  [4] = 1.066406
  [5] = 1.063477
  [6] = 1.224609
  [7] = 1.257812
  [8] = 1.254883
  [9] = 1.254883
Opening in BLOCKING MODE 
NvMMLiteOpen : Block : BlockType = 261 
NvMMLiteBlockCreate : Block : BlockType = 261 
Decodebin child added: nvurisrc_bin_queue
Decodebin child added: nvurisrc_bin_nvvidconv_elem
Decodebin child added: nvurisrc_bin_src_cap_filter_nvvidconv
In cb_newpad
gstname= video/x-raw
features= <Gst.CapsFeatures object at 0xffff6f7c8580 (GstCapsFeatures at 0xfffeb46c3a80)>
in videoconvert caps = video/x-raw(memory:NVMM), format=(string)RGBA, framerate=(fraction)25/1, width=(int)1280, height=(int)720
nvstreammux: Successfully handled EOS for source_id=0
**PERF: {'stream0': 0.0}
---ONE CALL---Output Tensor: /conv3_1/Conv_output_0
  Shape: [1, 64, 90, 160]
Tensor sample values:
  [0] = 0.615234
  [1] = 0.064453
  [2] = 0.201172
  [3] = 0.070312
  [4] = -0.888672
  [5] = -1.169922
  [6] = 0.101562
  [7] = -0.347656
  [8] = -1.478516
  [9] = -1.234375
CustomTensorPreparation called
Found tensor meta: 1 output layers
Copying layer 0 of size 3686400 bytes
NvMMLiteBlockCreate : Block : BlockType = 1 
End-of-stream
Exiting app

Remarks:

  • The video is only one frame long (therefore only one inference is expected)
  • The first 10 values of the output tensor of model_part2 are not the expected ones
  • I don’t understand why I get the output tensor of model_part2 before model_part1
  • I don’t understand why I get 2 calls of InferGrpcClient::InferComplete for model_part1

nvinferserver does a warmup by default. Please add "disable_warmup: true" to the nvinferserver config. Please refer to /opt/nvidia/deepstream/deepstream-7.1/sources/apps/sample_apps/deepstream-3d-lidar-sensor-fusion/v2xfusion/config/triton_mode_CAPI.txt.