DeepStream with nvinferserver (Python backend, Triton mode)

Please provide complete information as applicable to your setup.

• Hardware Platform (Jetson / GPU) 4090
• DeepStream Version 7.0
• JetPack Version (valid for Jetson only)
• TensorRT Version 8.6.1
• NVIDIA GPU Driver Version (valid for GPU only)
• Issue Type( questions, new requirements, bugs)
• How to reproduce the issue ? (This is for bugs. Including which sample app is using, the configuration files content, the command line used and other details for reproducing)
• Requirement details( This is for new requirement. Including the module name-for which plugin or for which sample application, the function description)

Related question: https://forums.developer.nvidia.com/t/discrepancy-between-pytorch-and-deepstream-inference-when-deploying-a-custom-reid-model/346211/5

I use the Python backend with Triton via nvinferserver (instead of nvinfer) so that I can customize the preprocessing and output to be equivalent to PyTorch. When testing Triton directly through the HTTP API it works fine, but when integrating with DeepStream I couldn't find a sample to follow and encountered the following error.
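For reference, a standalone HTTP check of the kind described above might look like the sketch below (the server URL and the crop size are illustrative assumptions; the tensor names and dtypes follow the config.pbtxt further down):

import numpy as np
import tritonclient.http as httpclient

# Hypothetical standalone test against the "mgn" model over HTTP.
client = httpclient.InferenceServerClient(url="localhost:8000")

# One dynamically shaped crop, (batch, H, W, 3), FP32 as declared in config.pbtxt.
crop = np.random.rand(1, 300, 399, 3).astype(np.float32)

inp = httpclient.InferInput("input", list(crop.shape), "FP32")
inp.set_data_from_numpy(crop)

result = client.infer(
    model_name="mgn",
    inputs=[inp],
    outputs=[httpclient.InferRequestedOutput("output")],
)
print(result.as_numpy("output").shape)  # expected: (1, 2048)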

The steps I followed are as follows:

model.py

import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import triton_python_backend_utils as pb_utils
import torch
from torchvision.ops import roi_align

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


class TritonPythonModel:
    def initialize(self, args):
        engine_path = "/app/models/mgn/1/model.plan"
        with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        self.context = self.engine.create_execution_context()

        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        # Get I/O names
        self.input_name = None
        self.output_name = None
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self.input_name = name
            else:
                self.output_name = name
        if not self.input_name or not self.output_name:
            raise ValueError("Missing input/output tensor")

        # We'll set input shape dynamically per batch size
        self.max_batch_size = 4  # must match config.pbtxt
        self.fixed_hw = (384, 128)

        # Get output shape for batch=1
        self.context.set_input_shape(self.input_name, (1, 3, *self.fixed_hw))
        self.single_output_shape = self.context.get_tensor_shape(self.output_name)  # e.g., (2048,)

        self.input_dtype = trt.nptype(self.engine.get_tensor_dtype(self.input_name))
        self.output_dtype = trt.nptype(self.engine.get_tensor_dtype(self.output_name))

        self.stream = cuda.Stream()

        # Pre-allocate max memory (for batch=4)
        max_input_elements = self.max_batch_size * 3 * self.fixed_hw[0] * self.fixed_hw[1]
        max_output_elements = self.max_batch_size * trt.volume(self.single_output_shape)

        self.input_host = np.empty(max_input_elements, dtype=self.input_dtype)
        self.output_host = np.empty(max_output_elements, dtype=self.output_dtype)

        self.input_device = cuda.mem_alloc(self.input_host.nbytes)
        self.output_device = cuda.mem_alloc(self.output_host.nbytes)

        self.context.set_tensor_address(self.input_name, int(self.input_device))
        self.context.set_tensor_address(self.output_name, int(self.output_device))

    def preprocess_batch(self, batch_img_arrays):
        """
        Preprocess a batch of images.
        Args:
            batch_img_arrays: list of numpy arrays, each (H, W, 3), uint8
        Returns:
            numpy array of shape (B, 3, 384, 128), dtype float32
        """
        preprocessed = []
        for img_array in batch_img_arrays:
            img = img_array.astype(np.float32) / 255.0
            frame = torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW
            frame = frame.sub_(self.mean).div_(self.std)    # normalize

            w, h = frame.shape[2], frame.shape[1]
            cbboxes = np.array([[0, 0, w, h]], dtype=np.float32)
            boxes = torch.cat([torch.zeros(1, 1), torch.from_numpy(cbboxes)], dim=1)
            crop = roi_align(frame.unsqueeze(0), boxes, output_size=self.fixed_hw)
            preprocessed.append(crop.squeeze(0).numpy())

        return np.stack(preprocessed, axis=0).astype(self.input_dtype)

    def execute(self, requests):
        responses = []
        for request in requests:
            input_tensor = pb_utils.get_input_tensor_by_name(request, "input")
            batch_input = input_tensor.as_numpy()  # shape: (B, H, W, 3)

            if batch_input.ndim == 3:
                # Single image (no batch dim) — should not happen if max_batch_size > 0
                batch_input = np.expand_dims(batch_input, axis=0)

            batch_size = batch_input.shape[0]
            if batch_size > self.max_batch_size:
                err = pb_utils.TritonError(f"Batch size {batch_size} exceeds max {self.max_batch_size}")
                responses.append(pb_utils.InferenceResponse(error=err))
                continue

            try:
                # Preprocess entire batch
                input_data = self.preprocess_batch([batch_input[i] for i in range(batch_size)])
                # input_data shape: (B, 3, 384, 128)

                # Set dynamic input shape
                actual_input_shape = (batch_size, 3, *self.fixed_hw)
                if not self.context.set_input_shape(self.input_name, actual_input_shape):
                    raise ValueError(f"Failed to set input shape {actual_input_shape}")

                actual_output_shape = self.context.get_tensor_shape(self.output_name)
                expected_output_elements = trt.volume(actual_output_shape)

                if input_data.size > self.input_host.size:
                    raise ValueError("Input data too large for pre-allocated buffer")

                # Copy to host buffer
                np.copyto(self.input_host[:input_data.size], input_data.ravel())

                # GPU copy and inference
                cuda.memcpy_htod_async(self.input_device, self.input_host, self.stream)
                self.context.execute_async_v3(stream_handle=self.stream.handle)
                cuda.memcpy_dtoh_async(self.output_host, self.output_device, self.stream)
                self.stream.synchronize()

                # Reshape output
                embedding = self.output_host[:expected_output_elements].copy().reshape(actual_output_shape)
                embedding = embedding.astype(np.float32)

                out_tensor = pb_utils.Tensor("output", embedding)
                responses.append(pb_utils.InferenceResponse(output_tensors=[out_tensor]))

            except Exception as e:
                err = pb_utils.TritonError(f"Batch inference failed: {str(e)}")
                responses.append(pb_utils.InferenceResponse(error=err))

        return responses

    def finalize(self):
        self.stream.synchronize()
        self.input_device.free()
        self.output_device.free()
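
For context, the model repository layout implied by the paths above (engine at /app/models/mgn/1/model.plan) would presumably be:

/app/models/
└── mgn/
    ├── config.pbtxt
    └── 1/
        ├── model.py
        └── model.plan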

config.pbtxt

name: "mgn"
backend: "python"
max_batch_size: 4  # or whatever value you want

input [
  {
    name: "input"
    data_type: TYPE_FP32
    dims: [ -1, -1, 3 ]  # H, W, C — Triton automatically adds the batch dim at the beginning
  }
]

output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [2048]
  }
]

instance_group {
  kind: KIND_GPU
  count: 1
  gpus: 0
}
version_policy: { specific: { versions: [1]}}

DeepStream ReID config

infer_config {
  unique_id: 2
  gpu_ids: [0]
  max_batch_size: 4
  
  backend {
    triton {
      model_name: "mgn"
      version: -1
      grpc {
        url: "192.168.1.90:32001"
        enable_cuda_buffer_sharing: true
      }
    }
  }

  preprocess {
    network_format: IMAGE_FORMAT_RGB 
    tensor_order: TENSOR_ORDER_NHWC
    maintain_aspect_ratio: 0
    frame_scaling_hw: FRAME_SCALING_HW_DEFAULT
    frame_scaling_filter: 0
    normalize {
        scale_factor: 1.0
        channel_offsets: [0.0,0.0,0.0]
    }
  }

  postprocess {
    labelfile_path: "labels.txt"
    detection {
      num_detected_classes: 1
      custom_parse_bbox_func: "NvDsInferParseYolo"
    }
  }

  extra {
    copy_input_to_host_buffers: false
  }
  custom_lib {
    path : "/nvdsinfer_custom_impl_mgn_triton/libnvdsinfer_custom_impl_Yolo.so"
  }
} 

input_control {
  process_mode : PROCESS_MODE_CLIP_OBJECTS
  operate_on_gie_id: 1
  operate_on_class_ids: [0]
  interval : 0
}

output_control {
  output_tensor_meta: true
}
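
Since output_control sets output_tensor_meta: true, the raw embedding is expected to be read back from the object user meta in a pad probe downstream. A minimal sketch of that (assuming the pyds Python bindings and the 2048-dim "output" layer from config.pbtxt; names and sizes should be adjusted to the actual model) might look like:

import ctypes
import numpy as np
import pyds

def read_embedding(obj_meta):
    # Walk this object's user meta and pull out the SGIE tensor output.
    l_user = obj_meta.obj_user_meta_list
    while l_user is not None:
        user_meta = pyds.NvDsUserMeta.cast(l_user.data)
        if user_meta.base_meta.meta_type == pyds.NvDsMetaType.NVDSINFER_TENSOR_OUTPUT_META:
            tensor_meta = pyds.NvDsInferTensorMeta.cast(user_meta.user_meta_data)
            for i in range(tensor_meta.num_output_layers):
                layer = pyds.get_nvds_LayerInfo(tensor_meta, i)
                ptr = ctypes.cast(pyds.get_ptr(layer.buffer), ctypes.POINTER(ctypes.c_float))
                return np.ctypeslib.as_array(ptr, shape=(2048,)).copy()  # assumed output size
        l_user = l_user.next
    return None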
  

My DeepStream image is based on nvcr.io/nvidia/deepstream:7.0-triton-multiarch, so it has TensorRT 8.6.1.
For the Triton server with the Python backend, my Dockerfile looks like this:

# Use the Triton server image that ships with the Python backend (version 24.08)
# FROM nvcr.io/nvidia/tritonserver:24.08-py3
FROM nvcr.io/nvidia/tritonserver:23.10-py3

RUN apt-get update && apt-get install ffmpeg libsm6 libxext6  -y
 
# ==================================================================
# python lib
# ------------------------------------------------------------------
RUN PIP_INSTALL="pip --no-cache-dir install --upgrade" && \
    $PIP_INSTALL \
    opencv-python \
    yacs \
    pillow \
    unidecode
 
 
# RUN pip install onnxruntime
    
# RUN pip install --upgrade tritonclient[all]
 
# RUN pip --no-cache-dir install \
#     protobuf==3.20.0 \
#     numpy==1.23.5
 
# RUN PIP_INSTALL="pip --no-cache-dir install --upgrade" && \
#     $PIP_INSTALL \
#     torch torchvision torchaudio
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126

 
ENV LC_ALL C.UTF-8

requirements.txt

opencv-python
#tensorrt==10.3.0
tensorrt==8.6.1
pycuda
onnx 
onnxslim 
onnxruntime
tritonclient[all]
protobuf==3.20.0
numpy==1.23.5
yacs
pillow
unidecode

I am using nvcr.io/nvidia/tritonserver:23.10-py3, as referenced in the DeepStream 7.0 release notes.

Are there any mistakes in my steps, or does DeepStream support Triton with the Python backend? If so, could you provide me with an example link for reference? Thanks

I suspect my config is wrong. While debugging, I found that the available samples (if any) typically operate on full frames (e.g., 1920x1080, static shape), but I am using the model as an SGIE (back-to-back). The shape of each cropped bounding box varies, and Triton over HTTP handles these dynamic shapes fine, for example 300x399x3 or 100x292x3.

However, when used with DeepStream, it seems that a fixed shape is required. Suppose I set

input [
  {
    name: "input"
    data_type: TYPE_FP32
    dims: [400, 450, 3]  # H, W, C — Triton automatically adds the batch dim at the beginning
  }
]

to some arbitrary value, but then change it to

input [
  {
    name: "input"
    data_type: TYPE_FP32
    dims: [-1, -1, 3]  # H, W, C — Triton automatically adds the batch dim at the beginning
  }
]

According to the logs, it was able to enter the Python backend as follows.

Is it possible to use a fixed arbitrary shape in config.pbtxt, but adjust the logic inside the Python backend so that it still works with DeepStream?

I can fix this using

parameters: {
  key: "FORCE_CPU_ONLY_INPUT_TENSORS"
  value: {
    string_value: "yes"
  }
}

The pipeline runs, but it looks like the input received by Triton has been resized to (400, 450). How can I stop the DeepStream pipeline from resizing the objects detected after nvinfer (PGIE), and instead forward only the cropped bbox (dynamic shape) to Triton, with support for dynamic batch and dynamic shape?
I want to use the Python backend to receive dynamic input; if it gets resized before reaching Triton, the results will be wrong and the whole approach becomes meaningless.
Thanks

Sorry for the late reply. Is this still a DeepStream issue that needs support? Thanks!
Please refer to this FAQ for the Python backend in DeepStream.

As the documentation states, "To manage memory efficiency and keep a clean interface, the Gst-nvinferserver plugin's default preprocessing cannot be disabled. Color conversion, datatype conversion, input scaling and object cropping continue to work natively in nvds_infer_server." So if native normalization is not needed, you can set scale_factor to 1.0.
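
For reference, a pass-through native normalization in the nvinferserver preprocess block (matching the values already used in the ReID config above) keeps the mean/std normalization in model.py as the only normalization applied:

preprocess {
  network_format: IMAGE_FORMAT_RGB
  tensor_order: TENSOR_ORDER_NHWC
  normalize {
    scale_factor: 1.0              # pass pixel values (0-255) through unchanged
    channel_offsets: [0.0, 0.0, 0.0]
  }
}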