Triton Inference Server + vLLM Backend on the NVIDIA Jetson AGX Orin 64GB Developer Kit

Greetings to all,

Below is a link to a post on how to use the NVIDIA Triton Inference Server together with vLLM to run Large Language Models on the NVIDIA Jetson AGX Orin 64GB Developer Kit.

Thanks to @johnnynunez and @dusty_nv for their help.

Best regards,
Shakhizat


Hi shahizat, can we use the NVIDIA Triton Inference Server together with the TensorRT-LLM backend to run Large Language Models on the NVIDIA Jetson AGX Orin?

What does the docker image look like?

Hi @Zhi, sorry for the late reply. I believe so, as you can get the pip wheel for TensorRT-LLM from here: jp6/cu126 index.

Hello, yes. You can get the pip wheel for vLLM from here: jp6/cu126 index. Then follow the docs here: GitHub - triton-inference-server/vllm_backend. I believe option 3, using the Triton iGPU Docker images, is the right one. Hope it helps.

Thank you. I will give this a shot!

Which JetPack version were you using? I am getting nothing but errors trying to run vLLM. I am on JetPack 6.2.

Hello @calexander1, my JetPack version is 6.2. Since my Triton Docker image is outdated, here are instructions on how to run Triton using the vLLM backend on the NVIDIA Jetson Orin. I hope this helps you build the Docker image. I can confirm that it works; I just tested it on the NVIDIA Jetson AGX Orin.

Step 1: Create the Model Repository Structure

└── vllm_model/           # This directory acts as the model's root.
    ├── 1/                # Represents the version of your model. Triton loads models from versioned subdirectories.
    │   └── model.json    # Contains the configuration specific to the vLLM model.
    └── config.pbtxt      # The main configuration file for the Triton model.
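
If you prefer to create this layout from the shell, a minimal sketch (run from the directory that will hold the model repository):

# Create the versioned model directory and empty config files
mkdir -p vllm_model/1
touch vllm_model/1/model.json vllm_model/config.pbtxt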

Step 2: Create the model.json file inside the vllm_model/1/ directory.

{
  "model": "meta-llama/Llama-3.2-1B-Instruct",
  "disable_log_requests": true,
  "gpu_memory_utilization": 0.5,
  "max_num_seqs": 64
}

Step 3: Create the config.pbtxt file in the vllm_model/ directory.

backend: "vllm"
instance_group [
  {
    count: 1
    kind: KIND_MODEL
  }
]

Step 4: Start the Triton Inference Server Docker Container

docker run --rm -it --net host --shm-size=2g \
    --ulimit memlock=-1 --ulimit stack=67108864 --runtime nvidia \
    -v $(pwd)/vllm_model:/opt/tritonserver/model_repository/vllm_model \
    nvcr.io/nvidia/tritonserver:25.05-py3-igpu
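
Optionally, if you want downloaded model weights to survive container restarts, you can also mount the host's Hugging Face cache into the container. This is just a variant of the same command and assumes the default cache locations (~/.cache/huggingface on the host, /root/.cache/huggingface inside the container):

docker run --rm -it --net host --shm-size=2g \
    --ulimit memlock=-1 --ulimit stack=67108864 --runtime nvidia \
    -v $(pwd)/vllm_model:/opt/tritonserver/model_repository/vllm_model \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    nvcr.io/nvidia/tritonserver:25.05-py3-igpu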

Step 5: Once inside the Docker container, you may need to install libopenblas-dev

apt update
apt install libopenblas-dev

Step 6: Install the necessary Python packages within the running Triton container from http://jetson.webredirect.org/jp6/cu128.

pip install vllm torch triton --extra-index-url http://jetson.webredirect.org/jp6/cu128 --trusted-host jetson.webredirect.org
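
To sanity-check that the Jetson wheels installed correctly before wiring up the backend, you can run a quick import test inside the container (optional):

# Should print the torch and vLLM versions and report CUDA as available
python3 -c "import torch, vllm; print(torch.__version__, torch.cuda.is_available(), vllm.__version__)"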

Step 7: Deploy the vLLM Triton Backend (Inside the Docker Container)

mkdir -p /opt/tritonserver/backends/vllm
git clone https://github.com/triton-inference-server/vllm_backend.git /tmp/vllm_backend
cp -r /tmp/vllm_backend/src/* /opt/tritonserver/backends/vllm
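
You can confirm that the backend files, including model.py which Triton loads, are in place:

ls /opt/tritonserver/backends/vllm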

Step 8: Log in to Hugging Face (Inside the Docker Container)

huggingface-cli login --token <your_huggingface_token>
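
As a non-interactive alternative, huggingface_hub also reads the HF_TOKEN environment variable, so exporting it before starting the server should work as well (assuming your huggingface_hub version honors it):

export HF_TOKEN=<your_huggingface_token>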

Step 9: Start the Triton Server

tritonserver --model-repository /opt/tritonserver/model_repository

Expected Output Logs:

I0616 19:13:19.717710 1100 pinned_memory_manager.cc:277] "Pinned memory pool is created at '0x1021b4000' with size 268435456"
I0616 19:13:19.717973 1100 cuda_memory_manager.cc:107] "CUDA memory pool is created on device 0 with size 67108864"
I0616 19:13:19.724295 1100 model_lifecycle.cc:473] "loading: vllm_model:1"
INFO 06-16 19:13:24 [__init__.py:244] Automatically detected platform cuda.
I0616 19:13:30.781886 1100 python_be.cc:2289] "TRITONBACKEND_ModelInstanceInitialize: vllm_model_0_0 (MODEL device 0)"
INFO 06-16 19:13:35 [__init__.py:244] Automatically detected platform cuda.
INFO 06-16 19:13:53 [config.py:822] This model supports multiple tasks: {'generate', 'reward', 'embed', 'score', 'classify'}. Defaulting to 'generate'.
INFO 06-16 19:13:53 [arg_utils.py:1643] Engine in background thread is experimental on VLLM_USE_V1=1. Falling back to V0 Engine.
WARNING 06-16 19:13:53 [arg_utils.py:1469] Chunked prefill is enabled by default for models with max_model_len > 32K. Chunked prefill might not work with some features or models. If you encounter any issues, please disable by launching with --enable-chunked-prefill=False.
INFO 06-16 19:13:53 [config.py:2176] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 06-16 19:13:53 [api_server.py:267] Started engine process with PID 1234
tokenizer_config.json: 100%|████████████████████████████████████████| 54.5k/54.5k [00:00<00:00, 820kB/s]
tokenizer.json: 100%|████████████████████████████████████████| 9.09M/9.09M [00:00<00:00, 12.2MB/s]
special_tokens_map.json: 100%|████████████████████████████████████████| 296/296 [00:00<00:00, 992kB/s]
INFO 06-16 19:13:57 [__init__.py:244] Automatically detected platform cuda.
INFO 06-16 19:14:00 [llm_engine.py:231] Initializing a V0 LLM engine (v0.9.2) with config: model='meta-llama/Llama-3.2-1B-Instruct', speculative_config=None, tokenizer='meta-llama/Llama-3.2-1B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=meta-llama/Llama-3.2-1B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":[],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":false,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":64,"local_cache_dir":null}, use_cached_outputs=True, 
generation_config.json: 100%|████████████████████████████████████████| 189/189 [00:00<00:00, 592kB/s]
INFO 06-16 19:14:05 [cuda.py:308] Using Flash Attention backend.
INFO 06-16 19:14:06 [parallel_state.py:1065] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO 06-16 19:14:06 [model_runner.py:1171] Starting to load model meta-llama/Llama-3.2-1B-Instruct...
INFO 06-16 19:14:07 [weight_utils.py:292] Using model weights format ['*.safetensors']
model.safetensors: 100%|████████████████████████████████████████| 2.47G/2.47G [00:40<00:00, 61.3MB/s]
INFO 06-16 19:14:48 [weight_utils.py:308] Time spent downloading weights for meta-llama/Llama-3.2-1B-Instruct: 41.349721 seconds
INFO 06-16 19:14:49 [weight_utils.py:345] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.70it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.70it/s]

INFO 06-16 19:14:49 [default_loader.py:272] Loading weights took 0.64 seconds
INFO 06-16 19:14:50 [model_runner.py:1203] Model loading took 2.3185 GiB and 43.583469 seconds
INFO 06-16 19:14:51 [worker.py:294] Memory profiling takes 0.89 seconds
INFO 06-16 19:14:51 [worker.py:294] the current vLLM instance can use total_gpu_memory (61.37GiB) x gpu_memory_utilization (0.50) = 30.68GiB
INFO 06-16 19:14:51 [worker.py:294] model weights take 2.32GiB; non_torch_memory takes 1.56GiB; PyTorch activation peak memory takes 0.40GiB; the rest of the memory reserved for KV Cache is 26.40GiB.
INFO 06-16 19:14:51 [executor_base.py:113] # cuda blocks: 54075, # CPU blocks: 8192
INFO 06-16 19:14:51 [executor_base.py:118] Maximum concurrency for 131072 tokens per request: 6.60x
INFO 06-16 19:14:57 [model_runner.py:1513] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes: 100%|████████████████████████████████████████| 11/11 [00:08<00:00,  1.36it/s]
INFO 06-16 19:15:05 [model_runner.py:1671] Graph capturing finished in 8 secs, took 0.08 GiB
INFO 06-16 19:15:05 [llm_engine.py:429] init engine (profile, create kv cache, warmup model) took 15.85 seconds
I0616 19:15:07.087176 1100 model_lifecycle.cc:849] "successfully loaded 'vllm_model'"
I0616 19:15:07.087548 1100 server.cc:604] 
+------------------+------+
| Repository Agent | Path |
+------------------+------+
+------------------+------+

I0616 19:15:07.087656 1100 server.cc:631] 
+---------+-------------------------------------------------------+-----------------------------------------------------------------------------------------------+
| Backend | Path                                                  | Config                                                                                        |
+---------+-------------------------------------------------------+-----------------------------------------------------------------------------------------------+
| python  | /opt/tritonserver/backends/python/libtriton_python.so | {"cmdline":{"auto-complete-config":"true","backend-directory":"/opt/tritonserver/backends","m |
|         |                                                       | in-compute-capability":"5.300000","default-max-batch-size":"4"}}                              |
| vllm    | /opt/tritonserver/backends/vllm/model.py              | {"cmdline":{"auto-complete-config":"true","backend-directory":"/opt/tritonserver/backends","m |
|         |                                                       | in-compute-capability":"5.300000","default-max-batch-size":"4"}}                              |
+---------+-------------------------------------------------------+-----------------------------------------------------------------------------------------------+

I0616 19:15:07.087750 1100 server.cc:674] 
+------------+---------+--------+
| Model      | Version | Status |
+------------+---------+--------+
| vllm_model | 1       | READY  |
+------------+---------+--------+

I0616 19:15:07.087967 1100 tritonserver.cc:2598] 
+----------------------------------+------------------------------------------------------------------------------------------------------------------------------+
| Option                           | Value                                                                                                                        |
+----------------------------------+------------------------------------------------------------------------------------------------------------------------------+
| server_id                        | triton                                                                                                                       |
| server_version                   | 2.58.0                                                                                                                       |
| server_extensions                | classification sequence model_repository model_repository(unload_dependents) schedule_policy model_configuration system_shar |
|                                  | ed_memory cuda_shared_memory binary_tensor_data parameters statistics trace logging                                          |
| model_repository_path[0]         | model_repository                                                                                                             |
| model_control_mode               | MODE_NONE                                                                                                                    |
| strict_model_config              | 0                                                                                                                            |
| model_config_name                |                                                                                                                              |
| rate_limit                       | OFF                                                                                                                          |
| pinned_memory_pool_byte_size     | 268435456                                                                                                                    |
| cuda_memory_pool_byte_size{0}    | 67108864                                                                                                                     |
| min_supported_compute_capability | 5.3                                                                                                                          |
| strict_readiness                 | 1                                                                                                                            |
| exit_timeout                     | 30                                                                                                                           |
| cache_enabled                    | 0                                                                                                                            |
+----------------------------------+------------------------------------------------------------------------------------------------------------------------------+

I0616 19:15:07.092091 1100 grpc_server.cc:2562] "Started GRPCInferenceService at 0.0.0.0:8001"
I0616 19:15:07.092406 1100 http_server.cc:4806] "Started HTTPService at 0.0.0.0:8000"
I0616 19:15:07.133522 1100 http_server.cc:358] "Started Metrics Service at 0.0.0.0:8002"
INFO 06-16 19:23:12 [metrics.py:417] Avg prompt throughput: 3.4 tokens/s, Avg generation throughput: 33.5 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 06-16 19:23:24 [metrics.py:417] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 7.4 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 06-16 19:23:34 [metrics.py:417] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
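
Before sending a request, you can confirm that both the server and the model report ready via Triton's standard health endpoints; both commands should print 200:

curl -s -o /dev/null -w "%{http_code}\n" localhost:8000/v2/health/ready
curl -s -o /dev/null -w "%{http_code}\n" localhost:8000/v2/models/vllm_model/ready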

Step 10: Send a Request to the Triton Server

curl -X POST localhost:8000/v2/models/vllm_model/generate -d \
  '{
      "text_input": "Compose a poem that explains the concept of recursion in programming.",
      "parameters":
            {
              "stream": false,
              "max_tokens": 256
            }
  }'
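
If you want token-by-token output instead of a single response, the vLLM backend also exposes a streaming variant of the generate endpoint; a sketch using the same model as above:

curl -X POST localhost:8000/v2/models/vllm_model/generate_stream -d \
  '{
      "text_input": "Compose a poem that explains the concept of recursion in programming.",
      "parameters":
            {
              "stream": true,
              "max_tokens": 256
            }
  }'

The response is returned as a stream of server-sent events, one text_output chunk per event.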

Thank you for taking the time to reply with such a detailed response. I was actually able to use the latest image here: https://hub.docker.com/r/dustynv/vllm/. From there, I installed tritonserver following their steps for Jetson installation. Thank you again!

