vLLM on GB10: gpt-oss-120b MXFP4 slower than SGLang/llama.cpp... what’s missing?

I’m using MXFP4 model weights and I’m specifically looking to ensure vLLM uses GB10’s native FP4 tensor-core kernels rather than slower fallbacks. I would like to get vLLM to the same performance class as the SGLang feature branch or recent llama.cpp improvements.

At the moment, vLLM runs, but it looks like we’re not hitting the native FP4 (NVFP4) / Blackwell-optimized MoE GEMM path. My working hypothesis is that vLLM is either:

  • falling back to a slower Marlin/weight-only FP4 path, or

  • not enabling the intended FlashInfer/CUTLASS group GEMM backend for MXFP4 MoE on sm_121a, or

  • losing perf due to GB10-specific toolchain/backend gating (Triton/PTXAS, attention backend selection, async scheduling / CUDA graphs, etc.).

What I’m looking for:

  1. What is the intended “fast path” on GB10 for gpt-oss-120b MXFP4 in vLLM (MoE GEMM + attention)?

  2. Which versions of vLLM / FlashInfer / Triton / PyTorch / CUDA are currently recommended on GB10 to get that fast path?

  3. Are there known backend gating or shape/padding/packing constraints on GB10 that prevent MXFP4 MoE from selecting the fastest kernels?

  4. If I want to contribute, what’s the highest-impact area:

    • enabling/validating FlashInfer/CUTLASS MXFP4 MoE on sm_121a,

    • fixing Triton toolchain/ptxas issues for sm_121a,

    • or vLLM runtime/scheduler issues (async scheduling, batch queue, CUDA graphs)?

I can run A/B benchmarks, collect logs, and provide Nsight traces. If there’s a specific checklist to confirm we’re on the “native FP4” path (expected log messages, env vars, kernels to look for), I’ll follow it and report back.
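For concreteness, here's the kind of probe I'd start from, pending that checklist (a minimal sketch: the profiler flow uses vLLM's start_profile/stop_profile with VLLM_TORCH_PROFILER_DIR, and the "marlin" / "cutlass" substring heuristics are my assumptions, not confirmed kernel names):

import os
os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/vllm_trace"  # set before vLLM starts

from vllm import LLM, SamplingParams

llm = LLM(model="openai/gpt-oss-120b", quantization="mxfp4")
llm.start_profile()
llm.generate(["The quick brown fox"], SamplingParams(max_tokens=8))
llm.stop_profile()
# Open the trace in Perfetto / chrome://tracing and search CUDA kernel names:
# "marlin" suggests the weight-only fallback path, while "cutlass" /
# "flashinfer" group-GEMM kernels suggest the intended native FP4 path.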

Thanks! I am happy to help test or upstream fixes.

1 Like

NVFP4/FP4 isn’t being “properly utilized” on DGX Spark (GB10 / sm121) in current vLLM builds, so NVFP4 quants can be slower than AWQ 4-bit on the same workload. FP4 kernels / NVFP4 paths are better optimized for sm120 (RTX 50xx / RTX Pro 6000) than for Spark’s sm121. So, the summary is: installs got way smoother (cu130 wheels + better Docker tooling + cluster scripts), but NVFP4 performance on Spark still isn’t quite there yet.

Thanks for the reply! That matches with what I’m seeing (NVFP4/MXFP4 underperforming vs AWQ 4-bit on GB10).

Do you know which specific kernel path is missing/slow on sm121 (MoE group GEMM vs attention vs packing/padding)? Also, are there recommended vLLM/FlashInfer/Triton/PyTorch/CUDA versions for Spark right now, and any upstream issues/PRs to track?

I can test patches and provide Nsight traces.

I’m also willing to try contributing, but I’m lacking a little direction. I made a patch that attempted to address the sm121 gating, but that’s clearly not enough: Run VLLM in Spark - #118 by christopher_owen

There is also this contribution, new today: feat: Add SM121/GB10 (DGX Spark) Blackwell-class GPU support by seli-equinix · Pull Request #31740 · vllm-project/vllm · GitHub

As well as this contribution, which didn’t appear well received: [Bugfix] Add SM 12.1 support by ohsono · Pull Request #31607 · vllm-project/vllm · GitHub

On the FlashInfer side, for example, I don’t understand why there are different build strings for CUDA ‘< 13.0’ and the rest, and hence what the difference is between 12.0a and 12.0f.

If you want to go deeper down the rabbit hole, you can look at the optimizations the SGLang folks made in Triton and in SGLang itself. I built Triton with their changes locally, but on its own it didn’t improve performance in vLLM, even after I managed to get the Triton backend working instead of FLASH_ATTN.

1 Like

Well, this contribution doesn’t seem to add anything, as I’m getting the same speeds right now. I commented on the post; hopefully the author responds.

This commit seems to only resolve the ‘gating’ in vLLM (like my earlier attempt tried to do). The PR did say that it leads to TRITON_ATTN, though.

I haven’t had a chance to run that patch yet. When you ran it, was it still using the Marlin kernel, or did it start using TRITON_ATTN for you? That would at least be a step forward.

That PR mentions “FlashInfer does not support SM121 yet, so TRITON_ATTN backend is used”.

I wonder how to add support for SM121 in FlashInfer? I thought the NVIDIA patches from my previous comment were meant to do that. Maybe he needs to build his own FlashInfer as well as applying his gating patches?

But this is the heart of my initial questions above: there are many factors, and I’m not confident about where to look.

Yes, it used Triton backend with pretty much the same performance, so I abandoned this route for now.

1 Like

Following up on this discussion with an update on my vLLM work for DGX Spark / SM121:

What I implemented

  • Removed some Spark-related feature gating in vLLM.

  • Wrote a CUTLASS-based attention kernel.

  • Wrote a block-scaled FP8×FP4 MoE GEMM in CUTLASS (FP8 activations × MXFP4 weights).

  • Integrated the attention kernel and MoE GEMM into vLLM.

Along the way I ran into a number of upstream rough edges; I’m planning to send PRs where appropriate (including for what I suspect is a compiler bug in nvcc when using CUTLASS on sm121).

Results so far

  • It’s still coherent!

  • The CUTLASS attention kernel yielded a ~1% TPS improvement.

  • With the MoE GEMM kernel, pp2048 throughput (llama-benchy) improved strongly, with some measurements showing a +30% increase versus my prior baseline.

  • Hopefully prefix caching or some other variable isn’t at play; my benchmarking game could improve.

Current decode/TPS regression
My path currently includes BF16 → MXFP8 activation quantization, and on decode (small M) that per-token overhead is significant. In my setup this results in about an 8% TPS drop vs the Marlin baseline, with benchmarking held constant (--max-num-seqs 2, --enforce-eager, 128k context window, 8k batched tokens).

I think there’s a straightforward next step here: move quantization out of decode by fusing it into the MoE expand/permutation path, so we don’t pay a standalone quantization pass per token.
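For reference, the standalone pass I’m paying per decode token is morally equivalent to the following (a slow PyTorch sketch of OCP-style MXFP8 block quantization, not my actual kernel; the 32-element blocks, E8M0 shared scale, and E4M3 elements follow the MX spec):

import torch

def mxfp8_quantize(x: torch.Tensor, block: int = 32):
    # Each 32-element block shares a power-of-two (E8M0) scale; elements are
    # stored as FP8 E4M3. Assumes x.numel() is a multiple of `block`.
    xb = x.reshape(-1, block)
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp_min(2.0 ** -126)
    # OCP MX: shared scale exponent = floor(log2(amax)) - emax(E4M3), emax = 8.
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8)
    q = (xb / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return q, scale  # dequant: q.float() * scale

Fused into the expert expand/permutation, the activations get read and written once instead of paying this extra full pass at small M.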

MoE GEMM still has headroom
Right now the MoE GEMM is intentionally conservative:

  • only a 128×128×128 tile variant,

  • forced a 1×1×1 cluster/shape,

  • disabled several CUTLASS/FlashInfer/TRT-LLM autotuning options (I think that code is shared), e.g. finalize fusion, swap, tile sizes.

I expect there’s meaningful performance left on the table by re-enabling those and expanding the config space.

Key takeaway / bottleneck
From my measurements, the usual “triton_attn vs flash_attn” debate (and similar kernel swaps) wasn’t the primary limiter for GPT-OSS-120B TPS in vLLM. The bigger issue appears to be the dense GEMM for lm_head: the current vLLM path ends up converting lm_head to FP32 (whereas llama.cpp keeps it natively in MXFP4), and based on profiling vLLM, that dense GEMM was, and still is, the dominant TPS bottleneck for me.
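Some back-of-envelope math on why lm_head dominates at decode (hidden=2880 and vocab≈201k are my read of the gpt-oss-120b config, and ~273 GB/s is Spark’s advertised memory bandwidth; treat all three as assumptions):

# Decode-time lm_head is a GEMV (M = 1), so it's memory-bound:
# time ≈ weight bytes moved / memory bandwidth.
hidden, vocab, bw = 2880, 201_088, 273e9
params = hidden * vocab  # ~5.8e8 weights
for name, bits in [("fp32", 32), ("bf16", 16), ("mxfp4", 4 + 8 / 32)]:
    ms = params * bits / 8 / bw * 1e3
    print(f"{name:6s} ~{ms:.1f} ms/token")
# fp32 ~8.5, bf16 ~4.2, mxfp4 ~1.1 ms/token: an FP32 lm_head alone can eat
# most of a decode step, while MXFP4 (4 bits + one shared scale per 32
# weights) cuts that traffic by ~7.5x.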

Next steps

  • Tackle the lm_head dense GEMM path (keep it MXFP4 end-to-end if possible).

  • Experiment with more MoE GEMM kernel tuning and experiment with shifting the quantization from the decode path.

  • Experiment with tree-based speculative decoding (Eagle3). (I also learned the practical tradeoffs vs chain-based speculative decoding while testing.)

  • Consider a specialized small-M path (DP4A-based GEMV).

As in my original post, I am still interested in being mentored on the highest-impact areas I could focus on.

4 Likes

Have you had a chance to look into the changes the SGLang folks made for the Spark marketing campaign? GitHub - yvbbrjdr/sglang at spark - looks like these are the key (+ enabling Triton kernels).

Thank you for the pointer. Below is my analysis of what’s driving SGLang’s success.

There’s quite a bit to learn from this. One key takeaway is that it’s not only which engine is used (in this case, they’re using Triton), but also how it’s used. Their implementation has useful ideas on both fronts.

For example, they quantize all dense layers to FP8 - not just lm_head (which I mentioned in my previous post). In addition to lm_head, they quantize the other dense layers (q_proj, k_proj, v_proj, o_proj) as well as embed_tokens.

Before learning this about SGLang, I was able to quantize lm_head to MXFP4 in vLLM. That change produced a ~10% TPS improvement over baseline (up from -8% earlier in the effort). Now that I understand SGLang’s broader approach, I’m going to try the same for these other layers and add a command-line option so users can select MXFP4 / MXFP8 / native weights for the dense layers. (I specifically plan to remove the BF16 → FP32 scaling that the current vLLM code performs, as that seems like a ‘get it to work’ shortcut rather than a desired end state.) I haven’t yet tested how the quantization impacts model quality, and I don’t really have a good strategy for this; one cheap proxy I’m considering is sketched below.
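The sketch: dump logits from the baseline and quantized builds on the same prompts and compare distributions (mean_kl is my own illustrative helper, not anything from vLLM):

import torch
import torch.nn.functional as F

def mean_kl(baseline_logits: torch.Tensor, quant_logits: torch.Tensor) -> float:
    # Mean per-position KL(baseline || quantized). A rising KL on a fixed
    # prompt set is an early warning of quality drift before benchmarks move.
    p_log = F.log_softmax(baseline_logits.float(), dim=-1)
    q_log = F.log_softmax(quant_logits.float(), dim=-1)
    return (p_log.exp() * (p_log - q_log)).sum(-1).mean().item()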

Also, I “cheated a bit” on the MXFP4 dequantization path for lm_head by reusing Marlin kernel code. I plan to implement a purpose-built GEMV in CUTLASS for this use case - there may be additional gains there as well.

For the MoE GEMM, I’m keeping the weights in their native MXFP4 but quantizing the activations to MXFP8. I should also put some thought into how to keep the activations in native BF16, as I think it would be nice to have a ‘pick your quality/speed trade-off’ user setting for all of these things.

I’m sure there’s more to learn from the SGLang implementation, so I’ll keep digging.

The raw analysis of SGLang can be found here: spark-vllm-mxfp4-docker/docs/analysis/SGLANG_ANALYSIS.md at main · christopherowen/spark-vllm-mxfp4-docker · GitHub

If anyone is interested in real-time collaboration, I can make myself available on the vLLM slack to show what I have. At some point I plan to make a nice Dockerfile (inspired by eugr) to tie this all together, but what I have now isn’t there.

2 Likes

Interesting… There are also changes in Triton that they’ve made: GitHub - yvbbrjdr/triton at spark

Once you are close to working solution, we can bake it into our community docker build as a separate build option.

2 Likes

Little status update… 48.9 tg32 TPS (+55% over the Marlin baseline), while keeping the pp2048 improvement over baseline. This was achieved by applying the learnings from SGLang and MXFP4-quantizing the projection layers (I’ve omitted embed_tokens for now).

The model is still coherent on geography, math, and coding questions.

I’m pleased with the result, but there’s more to do in order to beat the SOTA llama.cpp performance on the Spark.

2 Likes

Great work! Is the Dockerfile in your repo up to date? Does it depend on any changes in vllm pr 31740?

1 Like

Thanks!

The Dockerfile isn’t quite ready yet—right now I’m using Dockerfile.dev, with my mxfp4_v2 branches from my vLLM and FlashInfer forks bind-mounted from ~/projects/vllm and ~/projects/flashinfer. The container also includes a few convenience scripts for building FlashInfer and vLLM from /workspace.

You can jump into the container with:

docker compose -f docker-compose.dev.yml exec dev bash

For my testing, I start vLLM with:

docker exec -it vllm-dev bash -c '
export PYTHONPATH=/workspace/flashinfer:/workspace/vllm
vllm serve openai/gpt-oss-120b \
  --host 0.0.0.0 \
  --port 8000 \
  --served-model-name gpt-oss-120b \
  --quantization mxfp4 \
  --mxfp4-backend CUTLASS \
  --mxfp4-layers moe,qkv,o,lm_head \
  --attention-backend FLASHINFER \
  --tensor-parallel-size 1 \
  --gpu-memory-utilization 0.70 \
  --max-model-len 131072 \
  --max-num-seqs 2 \
  --max-num-batched-tokens 8192 \
  --enforce-eager \
  --enable-prefix-caching \
  --load-format fastsafetensors
'

Note: I’ve tried to minimize blast radius by allowing the attention backend and MoE GEMM path to be selected via CLI flags. I plan to use that for final benchmarking, and it’s a nice improvement over the previous environment-variable approach (e.g., VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS).

I initially started with Dockerfile, but pivoted to Dockerfile.dev as my main workflow. Once I’m further along, I plan to polish Dockerfile, pin repos to specific SHAs, and add a changelog for reproducibility (instead of my current patch approach).

On the vLLM side, my mxfp4_v2 branch is based on PR #31740 (a big improvement over my earlier SM121 gate-removal attempt). My latest work is available in these branches—feedback welcome.

FYI, I’ve also pivoted to improving vLLM’s tree-based speculative decoding, since that seems like the best near-term shot at a TPS boost. Right now, draft token selection is raw argmax/top-k on logits and ignores the sampling transforms (temperature, top-p, penalties, etc.) that the verifier applies, which causes unnecessary rejections because the drafts don’t match what the verifier would actually sample.
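The shape of the fix is roughly the following (an illustrative sketch of draft sampling that mirrors the verifier’s temperature/top-p transforms; this is not vLLM’s actual drafter code):

import torch

def sample_draft(logits: torch.Tensor, temperature: float = 0.7, top_p: float = 0.9) -> torch.Tensor:
    # Apply the same transforms the verifier uses, so drafts are drawn from
    # (approximately) the distribution the verifier will accept against.
    logits = logits / max(temperature, 1e-5)
    probs = torch.softmax(logits, dim=-1)
    sorted_p, idx = probs.sort(dim=-1, descending=True)
    keep = sorted_p.cumsum(-1) - sorted_p < top_p  # nucleus mask; always keeps top-1
    sorted_p = sorted_p * keep
    sorted_p = sorted_p / sorted_p.sum(-1, keepdim=True)
    choice = torch.multinomial(sorted_p, num_samples=1)
    return idx.gather(-1, choice)  # map back to vocab ids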

If this pans out, I think it could make the overall solution very competitive—before we even apply other refinements discussed in this thread (e.g., a CUTLASS GEMV/small-M path for decode, similar to what llama.cpp does, to complement the block-scaled MoE GEMM that shines in prefill).

1 Like

Small update…

Current tg32 is 50.5 TPS (+3% over the previous measurement and +60% over baseline).

My Eagle3 experiments didn’t pan out. My current hypothesis is that quantization is hurting draft quality enough that the verifier rejects most draft tokens. There’s more to investigate there, but I’m setting it aside for now.

In the meantime, I focused on plain decode performance. I made a couple small changes to allow FP8 KV cache; on its own, this didn’t move the needle (~0.1% TPS).

More interestingly, I tried running without --enforce-eager. It worked without any code changes and gave a ~1.5% TPS gain “for free.” I don’t remember why I stopped testing with CUDA graphs earlier (maybe just vllm load time and testing burden), but it might not be necessary to disable them anymore.

After pulling the latest changes for FP8 KV cache, the new command is:

docker exec -it vllm-dev bash -c '
export PYTHONPATH=/workspace/flashinfer:/workspace/vllm
vllm serve openai/gpt-oss-120b \
  --host 0.0.0.0 \
  --port 8000 \
  --served-model-name gpt-oss-120b \
  --quantization mxfp4 \
  --mxfp4-backend CUTLASS \
  --mxfp4-layers moe,qkv,o,lm_head \
  --attention-backend FLASHINFER \
  --kv-cache-dtype fp8 \
  --tensor-parallel-size 1 \
  --gpu-memory-utilization 0.70 \
  --max-model-len 131072 \
  --max-num-seqs 2 \
  --max-num-batched-tokens 8192 \
  --enable-prefix-caching \
  --load-format fastsafetensors
'

Profiling shows the following breakdown (~20.4 ms/token):

Component              Time      %
MoE GEMM (FC1 + FC2)   ~12.5 ms  61%
QKV/O GEMM             ~4.0 ms   20%
LM head                ~1.2 ms   6%
Attention              ~0.8 ms   4%
MoE routing / other    ~1.9 ms   9%

That leaves a ~13% gap (~2.6 ms) of overhead (kernel launch, routing, activation quant, RMSNorm, etc.).
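As a sanity check on the remaining headroom, a crude weight-traffic roofline (all assumptions: ~5.1B active params/token for gpt-oss-120b, MXFP4 ≈ 4.25 bits/weight, ~273 GB/s memory bandwidth on GB10):

# Lower bound on decode latency from active-weight traffic alone.
active_params, bits, bw = 5.1e9, 4.25, 273e9
ms = active_params * bits / 8 / bw * 1e3
print(f"~{ms:.1f} ms/token")  # ~9.9 ms, i.e. a ~100 TPS ceiling vs ~20.4 ms measured

So roughly half the measured step time sits above the memory floor.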

Next I should swing back to the MoE blockwise GEMM kernel; based on the profile, that’s the biggest lever. Earlier I had planned to expand the supported tiles beyond just 128x128x128, so I may try that next.

Still more work to do to beat the SGLang and llama.cpp results.

5 Likes

12.0
12.0f
12.0a

you have the explanation here: NVIDIA Blackwell and NVIDIA CUDA 12.9 Introduce Family-Specific Architecture Features | NVIDIA Technical Blog

For CUDA versions before 12.9, the Blackwell family-specific (“f”) targets aren’t supported.

You’re reading my mind! I just discovered that building for the sm_121a target didn’t include the FP4 features I needed, so I switched my container’s build target from sm_121a to sm_121f and the FP4 path started working as expected.

Do you know if there is a way to use tcgen05 (or a similar path) to do FP8 × FP4 without padding the FP4 weights? In my case the padding costs roughly ~8 KB out of ~100 KB of shared memory per CTA (~8%), which feels like a big waste. For many tile shapes, tighter FP4 packing could free enough space to fit an additional scale value.
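For what it’s worth, my read of where the ~8 KB goes (assuming a 128×128×128 tile, a single shared-memory stage, and the FP4 operand being padded into 8-bit containers for the mixed FP8 × FP4 MMA):

# Per-stage B-tile footprint for a 128x128x128 tile.
elems = 128 * 128                 # FP4 weights in the B tile
packed_kb = elems * 4 / 8 / 1024  # 8 KB if packed two-per-byte
padded_kb = elems * 8 / 8 / 1024  # 16 KB in 8-bit containers
print(padded_kb - packed_kb)      # 8.0 KB wasted per CTA per stage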

I’m feeling slightly shocked. I think this may be the fastest way to run gpt-oss-120b on the Spark now. There’s still quite a bit more performance to squeeze out, like fusing the quantization into the MoE GEMM or using CUTLASS for the dense layers.

Since it’s JIT-only now (or at least I haven’t tried it any other way), it’s a bit slow to start with all the extra tile shapes.

$ docker exec vllm-dev llama-benchy --base-url http://localhost:8000/v1 --model gpt-oss-120b --pp 2048 --tg 32 128 --runs 5 2>&1

Token indices sequence length is longer than the specified maximum sequence length for this model (159385 > 1024). Running this sequence through the model will result in indexing errors
llama-benchy (0.1.1)
Date: 2026-01-17 02:59:42
Benchmarking model: gpt-oss-120b at http://localhost:8000/v1
Error loading tokenizer: gpt-oss-120b is not a local folder and is not a valid model identifier listed on ‘https://huggingface.co/models’
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `hf auth login` or by passing `token=<your_token>`
Falling back to ‘gpt2’ tokenizer as approximation.
Loading text from cache: /root/.cache/llama-benchy/cc6a0b5782734ee3b9069aa3b64cc62c.txt
Total tokens available in text corpus: 159385
Warming up…
Warmup (User only) complete. Delta: 64 tokens (Server: 86, Local: 22)
Warmup (System+Empty) complete. Delta: 68 tokens (Server: 90, Local: 22)
Measuring latency using mode: api…
Average latency (api): 1.31 ms
Running test: pp=512, tg=32, depth=0
Running test: pp=512, tg=128, depth=0
Running test: pp=2048, tg=32, depth=0
Running test: pp=2048, tg=128, depth=0
Running test: pp=8192, tg=32, depth=0
Running test: pp=8192, tg=128, depth=0

model         test     t/s                ttfr (ms)         est_ppt (ms)      e2e_ttft (ms)
gpt-oss-120b  pp512    1854.24 ± 29.69    251.54 ± 14.67    250.23 ± 14.67    302.97 ± 15.12
gpt-oss-120b  tg32     60.02 ± 0.10
gpt-oss-120b  pp512    1782.37 ± 289.72   269.85 ± 37.75    268.54 ± 37.75    321.21 ± 38.31
gpt-oss-120b  tg128    60.07 ± 0.04
gpt-oss-120b  pp2048   4572.79 ± 109.53   392.40 ± 11.86    391.09 ± 11.86    444.66 ± 12.01
gpt-oss-120b  tg32     59.36 ± 0.08
gpt-oss-120b  pp2048   4622.49 ± 70.19    396.81 ± 7.09     395.50 ± 7.09     448.18 ± 6.86
gpt-oss-120b  tg128    59.47 ± 0.02
gpt-oss-120b  pp8192   6628.10 ± 24.46    1111.29 ± 6.26    1109.98 ± 6.26    1164.33 ± 6.24
gpt-oss-120b  tg32     57.52 ± 0.12
gpt-oss-120b  pp8192   6612.96 ± 12.42    1097.16 ± 15.20   1095.84 ± 15.20   1150.34 ± 14.91
gpt-oss-120b  tg128    57.81 ± 0.03

llama-benchy (0.1.1)
date: 2026-01-17 02:59:42 | latency mode: api

Full Benchmark Results analysis:

Context        Prefill (t/s)  Decode tg32 (t/s)  Decode tg128 (t/s)
Short (512)    1,854          60.02              60.07
Medium (2048)  4,573          59.36              59.47
Long (8192)    6,628          57.52              57.81

Key Observations:

  • ✅ Decode consistently 57-60 tok/s across all context lengths

  • ✅ Prefill scales well: 1.8K → 4.6K → 6.6K t/s as batch size increases

  • ✅ Long context (8K) only ~3% decode slowdown vs short context

vs Targets:

Engine       Decode (t/s)  Status
SGLang       52            ✅ Beat by 10-15%
llama.cpp    58            ✅ Beat at short/medium context
vLLM (this)  57-60         Winner
3 Likes

Interesting, because the way I always read it was that the “a” suffix includes all the architecture features from the family plus the GPU-specific features. IOW, 12.1a > 12.1f > 12.0f > 12.0. Is that not the case?

Wow, great job!

Is it on the FlashInfer side? Maybe we can include it in a custom build of flashinfer-jit-cache?