It has been pretty tricky to get two DGX Sparks working together as a server for inference models.
I have posted three repos on GitHub that provide the same scripts for setup, checkout, and benchmarking across two DGX Spark servers for vLLM, SGLang, and TensorRT-LLM.
Have you tried to run gpt-oss-120b in the cluster using SGLang? What performance are you getting?
The last time I tried, it crashed during inference, so I switched my efforts towards vLLM.
A cursory look at your scripts doesn't show anything substantially different from how I launched it, so I'm curious if you managed to get it working after all (or maybe I need to re-download gpt-oss-120b, as it was giving me some trouble with vLLM too recently).
I have now standardized on a common benchmarking approach across all three so that I can do comparisons.
Here are the GPT-OSS-120B results across the three platforms:
===================================================================
Benchmark Results
===================================================================
Test Configuration:
Platform: SGLang
Model: openai/gpt-oss-120b
Num Prompts: 100
Concurrency: 32
Dataset: ShareGPT_V3
Throughput Metrics:
Duration: 52.29s
Requests/sec: 1.91
Output tok/s: 242.76
Total tok/s: 476.98
Latency Metrics:
Mean Latency: 15360.49 ms
P50 Latency: 12673.84 ms
P99 Latency: 22591.62 ms
Mean TTFT: 123.45 ms
Request Statistics:
Completed: 100/100
Total Input Tokens: 12247
Total Output Tokens: 12694
===================================================================
Benchmark Results
===================================================================
Test Configuration:
Platform: TensorRT-LLM
Model: openai/gpt-oss-120b
Num Prompts: 100
Concurrency: 32
Dataset: ShareGPT_V3
Throughput Metrics:
Duration: 198.49s
Requests/sec: 0.50
Output tok/s: 62.13
Total tok/s: 128.60
Latency Metrics:
Mean Latency: 54130.01 ms
P50 Latency: 61797.33 ms
P99 Latency: 67569.62 ms
Mean TTFT: 468.73 ms
Request Statistics:
Completed: 100/100
Total Input Tokens: 13192
Total Output Tokens: 12333
===================================================================
Benchmark Results
===================================================================
Test Configuration:
Platform: vLLM
Model: openai/gpt-oss-120b
Num Prompts: 100
Concurrency: 32
Dataset: ShareGPT_V3
Throughput Metrics:
Duration: 43.47s
Requests/sec: 2.30
Output tok/s: 292.46
Total tok/s: 598.38
Latency Metrics:
Mean Latency: 12761.38 ms
P50 Latency: 12519.12 ms
P99 Latency: 15023.92 ms
Mean TTFT: 100.56 ms
Request Statistics:
Completed: 100/100
Total Input Tokens: 13299
Total Output Tokens: 12714
vLLM is the fastest, but it takes about 12 minutes to load the model and be ready to accept requests; SGLang is ready the fastest, in about 3-4 minutes, and TensorRT-LLM is ready in about 5-6 minutes.
vLLM will be the fastest to load a model if you use fastsafetensors.
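If you want to try it, I believe recent vLLM builds expose this through the load format; the flag name and the extra package below are from memory, so treat them as assumptions and double-check against vllm serve --help:

pip install fastsafetensors          # assumption: the loader needs this extra package
vllm serve openai/gpt-oss-120b \
  --load-format fastsafetensors      # assumed flag name; verify with `vllm serve --help`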
But I'm surprised by your results - something is not right here. What are you getting on a single request?
On a single node, SGLang was significantly faster than vLLM for me - 52 t/s vs 36 t/s. So either the cluster implementation is botched, or there is something else going on.
My vLLM benchmarks for 100 requests are slightly higher than yours, so I assume vLLM is working properly; there is definitely something wrong with the SGLang setup.
Ah, I see now. You set --disable-cuda-graph by default. I guess you are getting the same crash as me otherwise:
RuntimeError: CUDART error: invalid resource handle
During handling of the above exception, another exception occurred:
That explains the performance drop.
Correct… getting it to work was goal #1; sorting out performance comes next.
Performance is the tricky part usually :)
Now we have CUDA graphs running with small batches, and SGLang matches the vLLM performance:
===================================================================
Benchmark Results
===================================================================
Test Configuration:
Platform: SGLang
Model: openai/gpt-oss-120b
Num Prompts: 100
Concurrency: 32
Dataset: ShareGPT_V3
Throughput Metrics:
Duration: 43.38s
Requests/sec: 2.31
Output tok/s: 292.70
Total tok/s: 567.73
Latency Metrics:
Mean Latency: 12875.18 ms
P50 Latency: 12879.93 ms
P99 Latency: 14413.43 ms
Mean TTFT: 101.79 ms
Request Statistics:
Completed: 100/100
Total Input Tokens: 11930
Total Output Tokens: 12696
What CUDA graph batch size was that?
It is set to 32 currently
I was able to bump it up to 256 and set the memory fraction to 0.85, and that is driving a bit more performance:
===================================================================
Benchmark Results
===================================================================
Test Configuration:
Platform: SGLang
Model: openai/gpt-oss-120b
Num Prompts: 100
Concurrency: 32
Dataset: ShareGPT_V3
Throughput Metrics:
Duration: 38.90s
Requests/sec: 2.57
Output tok/s: 323.21
Total tok/s: 660.93
Latency Metrics:
Mean Latency: 11703.57 ms
P50 Latency: 12012.12 ms
P99 Latency: 12543.92 ms
Mean TTFT: 92.44 ms
Request Statistics:
Completed: 100/100
Total Input Tokens: 13138
Total Output Tokens: 12574
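For reference, the only things changing between these runs are the SGLang launch flags below; the rest of the launch command (elided as ...) stays the same:

# earlier runs avoided the crash by disabling CUDA graphs entirely:
#   --disable-cuda-graph
# this run: CUDA graphs on, capped at batch size 256, with a larger memory fraction
python -m sglang.launch_server ... \
  --cuda-graph-max-bs 256 \
  --mem-fraction-static 0.85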
Unless you are using some other Docker image, or have uncommitted changes to your repository, I have no idea why it works on your end.
Apparently, it crashed because of this bug: [fix] Only enable flashinfer all reduce fusion by default for single-node servers by leejnau · Pull Request #12724 · sgl-project/sglang · GitHub
The fix was merged a day AFTER the sglang:spark container was created, so it's not there.
I applied the patch manually, and it loaded just fine with CUDA graphs and didn't crash.
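In case anyone wants to do the same, this is roughly how a patch like that can be applied inside the spark container (a sketch, not my exact commands; the patch strip level and install location can differ on your setup):

# find where pip installed sglang and apply the PR #12724 diff on top of it
cd "$(pip show sglang | awk '/^Location:/{print $2}')"
curl -L https://github.com/sgl-project/sglang/pull/12724.diff -o /tmp/12724.diff
patch -p2 < /tmp/12724.diff   # repo paths start with python/sglang/, hence -p2; adjust if needed
# then restart the server on both nodes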
Unfortunately, the main sglang:latest Docker image doesn't contain the Spark/Blackwell-specific optimizations for MXFP4 MoE, and yes, there is an open issue for that too: [Bug] Fail to run gpt-oss with FlashInfer MXFP4 moe kernel on 5090 · Issue #13061 · sgl-project/sglang · GitHub
I'm not getting the same performance as you on 100 requests / 32 concurrent, but single-request performance is pretty good: I'm getting 75 t/s (vs 52 t/s single-node SGLang or 55 t/s dual-node vLLM).
BTW, I suggest you add single-request performance to your benchmarks (just run it first to avoid caching prompts), because it's the best indication of end-user experience - this is what you get when you chat with an LLM.
Here are my numbers:
vllm bench serve --backend vllm --model openai/gpt-oss-120b --endpoint /v1/completions --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1 --port 30000
============ Serving Benchmark Result ============
Successful requests: 1
Failed requests: 0
Benchmark duration (s): 1.58
Total input tokens: 12
Total generated tokens: 119
Request throughput (req/s): 0.63
Output token throughput (tok/s): 75.12
Peak output token throughput (tok/s): 74.00
Peak concurrent requests: 1.00
Total Token throughput (tok/s): 82.69
---------------Time to First Token----------------
Mean TTFT (ms): 48.16
Median TTFT (ms): 48.16
P99 TTFT (ms): 48.16
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 13.01
Median TPOT (ms): 13.01
P99 TPOT (ms): 13.01
---------------Inter-token Latency----------------
Mean ITL (ms): 13.01
Median ITL (ms): 12.94
P99 ITL (ms): 23.73
==================================================
Output tokens per second
400 +---------------------------------------------------------------------+
| |
350 | * |
| ** * |
| * * ** * * * * ** ** * *** |
300 | * ******* *** * * * **** **** * *** * |
| * * ** **** **** * **** **** * ***** |
250 | * * * * * **** * *** * * * * * *** |
| * * * * * * * * * * * * * ** |
200 | * * * * **** * |
| * * * * |
| * * * * |
150 | * * * * |
| * * * ** |
100 | * ** * |
| * * * |
| * * * |
50 | * * * |
| * * * |
0 +---------------------------------------------------------------------+
0 10 20 30 40 50 60 70 80 90 100
Concurrent requests per second
40 +----------------------------------------------------------------------+
| |
35 | * * * * * * |
| * * ******************* ** *** ** * |
|********* * ** ** ** |
30 | ** |
| ** |
25 | ** |
| * |
20 | *** |
| ** |
| ****** |
15 | *** |
| * |
10 | * |
| * |
| * |
5 | ** |
| ** |
0 +----------------------------------------------------------------------+
0 10 20 30 40 50 60 70 80 90 100
============ Serving Benchmark Result ============
Successful requests: 100
Failed requests: 0
Maximum request concurrency: 32
Benchmark duration (s): 108.98
Total input tokens: 22946
Total generated tokens: 21376
Request throughput (req/s): 0.92
Output token throughput (tok/s): 196.14
Peak output token throughput (tok/s): 352.00
Peak concurrent requests: 36.00
Total Token throughput (tok/s): 406.70
---------------Time to First Token----------------
Mean TTFT (ms): 4468.81
Median TTFT (ms): 435.18
P99 TTFT (ms): 15387.41
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 158.53
Median TPOT (ms): 128.79
P99 TPOT (ms): 438.55
---------------Inter-token Latency----------------
Mean ITL (ms): 119.13
Median ITL (ms): 94.17
P99 ITL (ms): 828.34
==================================================
I am currently using the lmsysorg/sglang:spark container and not the latest.
I will add a single-prompt option to the benchmarks. Thanks for the suggestion.
Then I have no idea why it works on your side without failing with the error.
I was able to launch only after I applied the patch referenced in my post.
Can you post the startup logs, especially the config part?
[2025-12-03 20:30:35] server_args=ServerArgs(model_path='openai/gpt-oss-120b', tokenizer_path='openai/gpt-oss-120b', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='0.0.0.0', port=30000, grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, dtype='bfloat16', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, mem_fraction_static=0.8, max_running_requests=None, max_queued_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', enable_priority_scheduling=False, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=True, radix_eviction_policy='lru', device='cuda', tp_size=2, pp_size=1, pp_max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=639353969, constrained_json_whitespace_pattern=None, constrained_json_disable_any_whitespace=False, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, otlp_traces_endpoint='localhost:4317', api_key=None, served_model_name='openai/gpt-oss-120b', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser='gpt-oss', tool_call_parser='gpt-oss', tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', load_watch_interval=0.1, prefill_round_robin_balance=False, dist_init_addr='192.168.177.12:20000', nnodes=2, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_eviction_policy='lru', lora_backend='csgmv', max_lora_chunk_size=16, attention_backend='triton', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, nsa_prefill_backend='flashmla_sparse', nsa_decode_backend='fa3', speculative_algorithm=None, speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_draft_load_format=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0,
speculative_token_map=None, speculative_attention_mode='prefill', speculative_moe_runner_backend=None, speculative_ngram_min_match_window_size=1, speculative_ngram_max_match_window_size=12, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, ep_size=1, moe_a2a_backend='none', moe_runner_backend='triton_kernel', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, kt_amx_weight_path=None, kt_amx_method='AMXINT4', kt_cpuinfer=None, kt_threadpool_count=2, kt_num_gpu_experts=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=256, cuda_graph_bs=[1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=4096, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, allow_auto_truncate=False,
enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=-1, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, custom_weight_loader=, weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, decrypted_config_file=None, decrypted_draft_config_file=None)
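For reference, that ServerArgs dump corresponds roughly to the following launch command on node 0 (a reconstruction from the values above, not the literal script; node 1 uses --node-rank 1):

python -m sglang.launch_server \
  --model-path openai/gpt-oss-120b \
  --tp-size 2 --nnodes 2 --node-rank 0 \
  --dist-init-addr 192.168.177.12:20000 \
  --host 0.0.0.0 --port 30000 \
  --dtype bfloat16 \
  --attention-backend triton \
  --reasoning-parser gpt-oss --tool-call-parser gpt-oss \
  --mem-fraction-static 0.8 \
  --cuda-graph-max-bs 256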
Oh, it looks like you are using your own benchmark script. Can you run vllm bench serve against SGLang using this command?
vllm bench serve \
--backend vllm \
--model openai/gpt-oss-120b \
--endpoint /v1/completions \
--dataset-name sharegpt \
--dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
--num-prompts 100 \
--port 30000
Here are my results:
Output tokens per second
800 +---------------------------------------------------------------------+
|* |
700 |* |
|** |
| ** |
600 | * * |
| *** |
500 | * *** |
| * **** |
400 | * *** ** |
| ** ** ** |
| * ** ** |
300 | ** ***** *** |
| * ** ** |
200 | * **** |
| ** |
| ****** |
100 | * |
| * |
0 +---------------------------------------------------------------------+
0 10 20 30 40 50 60 70
Concurrent requests per second
100 +---------------------------------------------------------------------+
|* |
| * |
| * |
80 | * |
| * |
| **** |
| ****** |
60 | ** |
| **** |
| ***** |
40 | ** |
| *** |
| ******* |
| ***** |
20 | *** |
| ****** |
| *** |
| ****** |
0 +---------------------------------------------------------------------+
0 10 20 30 40 50 60 70
============ Serving Benchmark Result ============
Successful requests: 100
Failed requests: 0
Benchmark duration (s): 61.19
Total input tokens: 22946
Total generated tokens: 21564
Request throughput (req/s): 1.63
Output token throughput (tok/s): 352.43
Peak output token throughput (tok/s): 737.00
Peak concurrent requests: 100.00
Total Token throughput (tok/s): 727.44
---------------Time to First Token----------------
Mean TTFT (ms): 301.17
Median TTFT (ms): 347.17
P99 TTFT (ms): 376.33
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 119.75
Median TPOT (ms): 124.01
P99 TPOT (ms): 151.29
---------------Inter-token Latency----------------
Mean ITL (ms): 108.54
Median ITL (ms): 117.29
P99 ITL (ms): 143.00
==================================================
I will run the vllm benchmark.
I switched to using the same benchmark so that it is consistent across vLLM, SGLang, and TensorRT-LLM.
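For the 100-prompt runs I keep that command and cap concurrency at 32 to match my earlier tables; I am assuming --max-concurrency is the right flag here, since it is what produces the "Maximum request concurrency: 32" line in your output:

vllm bench serve \
  --backend vllm \
  --model openai/gpt-oss-120b \
  --endpoint /v1/completions \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-prompts 100 \
  --max-concurrency 32 \
  --port 30000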