I found the root cause:
`main` ← `RobTand:fix/sm120-k64-blockscaled-tma-layout`, opened 05:02 PM, 20 Mar 26 UTC
## Summary
Two fixes that enable K=64 tile shapes for block-scaled MoE GEMM on … SM120 (RTX 5090/PRO 6000) and SM121 (DGX Spark GB10). Without these, K=64 tiles produce invalid TMA descriptors or overflow scale factor layout computations.
### 1. TMA zero-stride basis handling (`copy_traits_sm90_tma.hpp`)
When K=64 with SFVectorSize=32, scale factor folding creates a broadcast dimension with zero stride in `fill_tma_gmem_shape_stride()`. The existing code passes this to `basis_get()`, which produces invalid TMA descriptors. Fix: detect zero-stride basis via `is_constant<0>` and emit `shape=1, stride=0`.
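To see why `shape=1, stride=0` is a safe stand-in for a broadcast mode, compare the set of addresses each layout touches. The Python sketch below is a toy model, not CuTe code. `footprint` and `collapse_broadcast_modes` are hypothetical helpers, and a layout is modeled as plain integer (shape, stride) tuples rather than CuTe's `ScaledBasis` strides.

```python
from itertools import product

def footprint(shape, stride):
    """Return the set of linear offsets a (shape, stride) layout touches."""
    return {
        sum(i * s for i, s in zip(idx, stride))
        for idx in product(*(range(n) for n in shape))
    }

def collapse_broadcast_modes(shape, stride):
    """Toy version of the fix: any zero-stride (broadcast) mode becomes shape=1, stride=0."""
    new_shape, new_stride = [], []
    for n, s in zip(shape, stride):
        if s == 0:
            new_shape.append(1)   # a broadcast mode revisits the same address n times
            new_stride.append(0)
        else:
            new_shape.append(n)
            new_stride.append(s)
    return tuple(new_shape), tuple(new_stride)

# Hypothetical scale-factor view: 2 broadcast copies x 16 contiguous elements.
shape, stride = (2, 16), (0, 1)
flat_shape, flat_stride = collapse_broadcast_modes(shape, stride)
print(flat_shape, flat_stride)  # (1, 16) (0, 1)
```

The collapsed layout touches exactly the same offsets as the original, so no scale factors are lost. The sketch only shows that the collapse preserves the addressed elements; whether every such descriptor is valid for TMA is the question raised in caveat 1 below.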
### 2. Scale factor block size clamping (`sm120_blockscaled_mma_builder.inl`)
K=64 with SFVectorSize=32 gives `NumSFAlongK=2`, but `Blk_SF=4`. The existing division `Blk_SF/MMA_NSF` overflows. Fix: clamp effective block size (`EffBlk_SF`) to `min(NumSFAlongK, Blk_SF)` and conditionally fold into `kBasicBlock` to keep the TMA layout flat.
### Impact
Together these enable K=64 CTA shapes (`[128,128,64]`, `[128,256,64]`, `[256,128,64]`) which achieve 7-11 pipeline stages vs 2 with K=128, giving **~2x single-user decode throughput** on SM120/SM121 for NVFP4 MoE models.
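The stage-count gain follows from the per-stage SMEM footprint scaling linearly with K. Below is a rough back-of-the-envelope sketch, assuming 4-bit operands, 1-byte scale factors, SFVectorSize=32 as in this PR, and a hypothetical 32 KB carveout for the epilogue and barriers. `bytes_per_stage` and `max_stages` are illustrative helpers, not CUTLASS APIs, and the sketch does not reproduce the exact 7-11 vs 2 stage counts above, which come from CUTLASS's own stage-count logic and tile layouts.

```python
def bytes_per_stage(tile_m, tile_n, tile_k, sf_vector_size=32,
                    bits_per_elem=4, bytes_per_sf=1):
    """Rough mainloop SMEM for one pipeline stage of a block-scaled FP4 GEMM tile."""
    a_bytes = tile_m * tile_k * bits_per_elem // 8  # FP4 A tile
    b_bytes = tile_n * tile_k * bits_per_elem // 8  # FP4 B tile
    # One scale factor per sf_vector_size elements along K, per row of A and per column of B.
    sf_bytes = (tile_m + tile_n) * (tile_k // sf_vector_size) * bytes_per_sf
    return a_bytes + b_bytes + sf_bytes

def max_stages(smem_budget, reserved, per_stage):
    """Number of stages that fit after reserving SMEM for non-mainloop use."""
    return (smem_budget - reserved) // per_stage

SM120_SMEM = 99 * 1024  # 99 KB on SM120/SM121, per the post
RESERVED = 32 * 1024    # hypothetical carveout for epilogue, barriers, etc.

for k in (128, 64):
    per_stage = bytes_per_stage(128, 128, k)
    print(f"K={k}: {per_stage} B/stage, {max_stages(SM120_SMEM, RESERVED, per_stage)} stages")
```

Halving K halves the bytes each stage needs, so roughly twice as many stages fit in the same 99 KB budget.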
### Behavior for existing tile sizes
The EffBlk_SF clamping only activates when `NumSFAlongK < Blk_SF` (i.e., small K values). For K >= 128 the code is identical to the current behavior:
| K | SFVectorSize | NumSFAlongK | Blk_SF | EffBlk_SF | FoldSF? | Behavior |
|---|---|---|---|---|---|---|
| 64 | 32 | 2 | 4 | **2** | yes | **New: clamped and folded** |
| 128 | 32 | 4 | 4 | 4 | no | Unchanged |
| 256 | 32 | 8 | 4 | 4 | no | Unchanged |
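The table can be reproduced with a small model of the clamp. This is a sketch under stated assumptions: `sf_block_config` is a hypothetical helper, `MMA_NSF=1` is assumed because it is consistent with both the table and caveat 4, and the fold predicate is reconstructed as `NumSFAlongK < Blk_SF && EffBlk_SF > MMA_NSF`. The real builder may express it differently.

```python
def sf_block_config(k, sf_vector_size=32, blk_sf=4, mma_nsf=1):
    """Toy model of the EffBlk_SF clamp; mma_nsf=1 is an assumption, not read from CUTLASS."""
    num_sf_along_k = k // sf_vector_size
    # Clamp the SF block so it never spans more scale factors than the tile has along K.
    eff_blk_sf = min(num_sf_along_k, blk_sf)
    clamped = num_sf_along_k < blk_sf
    # Fold into kBasicBlock only on the clamped path, and only if more than MMA_NSF SFs remain.
    fold_sf = clamped and eff_blk_sf > mma_nsf
    return num_sf_along_k, eff_blk_sf, fold_sf

for k in (32, 64, 128, 256):
    print(k, sf_block_config(k))
```

Under the same `MMA_NSF=1` assumption, the unclamped `Blk_SF/MMA_NSF` would lay out 4 scale factors along K, while a K=64 tile only has 2, which is the overflow the clamp avoids. The K=32 row reproduces caveat 4: no folding.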
### Testing
Tested on DGX Spark (SM121, 128 GB unified LPDDR5X) with:
- **Nemotron-3-Super-120B-A12B-NVFP4** — 24 tok/s (up from ~12 tok/s without K=64)
- **Qwen3.5-122B-A10B-NVFP4** — 26 tok/s
### Caveats for reviewers
1. **`copy_traits_sm90_tma.hpp` is core CuTe infrastructure** used by all TMA operations across all architectures. The `if constexpr` branch only triggers for compile-time zero-stride constants (`Int<0>`), so there is no runtime cost and no change to existing non-zero-stride code paths. However, this file has broad blast radius — please verify that `shape=1, stride=0` is the universally correct interpretation for a zero-stride TMA basis (we believe broadcast is the only valid case).
2. **Only tested on SM121 (DGX Spark).** The EffBlk_SF clamping is architecture-independent and should work on SM120 (RTX 5090) and SM100 (datacenter Blackwell) if K=64 tiles are used there, but we have not validated those configurations.
3. **The multi-mode branch** (`tma_i_rank != 1`) in `fill_tma_gmem_shape_stride` also calls `basis_get` and could theoretically encounter zero strides. That branch has existing `gcd` and `!= 0` guards that may already handle it, but we only patched the rank-1 path.
4. **K=32 is untested.** `NumSFAlongK=1` with `Blk_SF=4` would produce `EffBlk_SF=1` and `FoldSFIntoBasicBlock=false` (since `1 > MMA_NSF` is false). This takes a different path than K=64 and is unverified.
### Related work
- **FlashInfer** [flashinfer-ai/flashinfer#2786](https://github.com/flashinfer-ai/flashinfer/pull/2786) — adds K=64 tile shapes to FlashInfer's kernel generation and dispatch. Depends on these CUTLASS fixes for correctness.
- **CUTLASS** [#3096](https://github.com/NVIDIA/cutlass/issues/3096) — original issue describing SM120 grouped GEMM failures
- **CUTLASS** [#3120](https://github.com/NVIDIA/cutlass/pull/3120) — our companion PR excluding SM12x from E2M1 PTX (separate issue, same hardware)
cc @brandonmmusic-max — your FlashInfer PR #2786 adds the K=64 shapes but the underlying CUTLASS TMA and scale factor layout fixes aren't in CUTLASS 4.4.2 yet. This PR provides those. Happy to coordinate or withdraw if you're working on the CUTLASS side separately.
SM120/SM121 (RTX 50 series, DGX Spark) have only 99 KB of SMEM, versus 228 KB on SM100. So the real problem isn't instruction incompatibility: the K=128 block-scaled MoE GEMM tiles compile but overflow SMEM at runtime on SM120, and the K=64 tiles that would fit can't compile yet because of the two unfixed CUTLASS bugs above.