Custom FP4 CUDA Kernel - 129 TFLOPS on DGX Spark with Pre-Quantized Weight Cache

I’m training a vision-language model (GPT-OSS-4.2B-Vision) on a DGX Spark and wanted to speed up training. The Spark has FP4 tensor cores on SM121 - lower precision, faster math, should be a straightforward win. I tried Triton’s tl.dot_scaled as a quick fix. A couple of hours later I found out it falls back to software emulation on SM121 - 100x slower than BF16. Great.

So I went all in and wrote a custom FP4 GEMM kernel on top of CUTLASS 3.8. Along the way I discovered FP4 doesn’t actually help training - no backward pass. But what came out of it is something I haven’t seen anywhere else for consumer Blackwell: a standalone FP4 GEMM library with a pre-quantized weight cache that hits 85-129 TFLOPS on the Spark.

The idea: quantize weights once at model load; only quantize activations on the fly per call. Integrated into a full transformer (GPT-OSS-4.2B, 24 layers, 288 GEMM calls per forward pass), it runs 1.3-2.3x faster than BF16 at inference-relevant batch sizes with 4x memory savings. Tested on both 4.2B and 20B models - the 20B drops from 43.4 GB to 4.0 GB with FP4 weights (10.8x compression). No dependency on vLLM, TRT-LLM, or sglang - just a library you can call from any Python code.

Full source is open: GitHub - VincentKaufmann/fp4-cuda-kernel: Custom FP4 GEMM kernel for DGX Spark / RTX 50 Series (SM120/SM121). 143 TFLOPS, 5-9x faster than BF16. Built on CUTLASS 3.8.


Why This Library Exists

No existing path gives you hardware FP4 on SM121 as a standalone library:

| Approach | Result on SM121 |
| --- | --- |
| cuBLAS FP4 | Not available (no cublasLtMatmul FP4 support) |
| Triton tl.dot_scaled | Software fallback - 100x slower than BF16 |
| Gluon / tcgen05 | SM121 lacks TMEM/tcgen05 (datacenter SM100 only) |
| bitsandbytes NF4 | Software dequant, no tensor core acceleration |
| CUTLASS Example 79a | Works! But a standalone binary, not a library |
| vLLM / TensorRT-LLM | Require their full serving stack - can’t just call a GEMM |
| shanjiaz/gpt-oss-120b-nvfp4-modelopt | Exists on HuggingFace but broken - segfaults in sglang (sglang #16595) |

Avarok has done good work on FP4 inference via vLLM Docker on Spark, but that requires running inside vLLM’s serving framework. I wanted a standalone library I could call from any Python code - training scripts, custom inference loops, evaluation harnesses - without being locked to a specific serving stack.

Also worth noting: stock MXFP4 weights (E2M1 + E8M0 scales) cannot be directly used on SM121 tensor cores. SM121 uses NVFP4 format with UE4M3 scale factors, not MXFP4’s E8M0. The CUTLASS GEMM expects data in this format with a specific interleaved scale layout (SfKMajorAtom). Any quantized model checkpoints in MXFP4 format need re-quantization to NVFP4 + CUTLASS layout before the hardware will accept them.
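To make the re-quantization requirement concrete, here’s a rough numpy sketch of the MXFP4 → NVFP4 step. This is illustrative only - it works on decoded values, keeps the scale as a float instead of rounding it to UE4M3, and skips the interleaved CUTLASS layout entirely:

```python
import numpy as np

# E2M1 magnitudes indexed by the low 3 bits of an FP4 code
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def mxfp4_block_to_nvfp4(codes, e8m0_exp):
    """Dequantize one 32-element MXFP4 block (E2M1 codes sharing a
    power-of-two E8M0 scale), then requantize it as two 16-element
    NVFP4 blocks, each with its own non-power-of-two scale."""
    signs = np.where(codes & 0x8, -1.0, 1.0)
    vals = signs * E2M1_VALUES[codes & 0x7] * 2.0 ** (int(e8m0_exp) - 127)
    out = []
    for blk in vals.reshape(-1, 16):          # NVFP4 block size is 16
        scale = np.abs(blk).max() / 6.0       # 6.0 = max |E2M1| value
        # A real implementation would round `scale` to UE4M3 and write it
        # into the interleaved SfKMajorAtom layout; kept as float here.
        out.append((scale, blk / scale if scale > 0 else blk))
    return out
```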


What I Built

Two modes of operation:

Dynamic mode - both A and B quantized every call (for training-like workloads):

from fp4_gemm import fp4_linear
output = fp4_linear(x, weight, bias)  # Drop-in F.linear replacement

Cached mode - weights pre-quantized once, only activations quantized per call (for inference):

from fp4_gemm import fp4_quantize, fp4_cached_linear

# Quantize weights once at model load (milliseconds, one-time cost)
cache = fp4_quantize(weight)

# Every inference call - only activations quantized on the fly
output = fp4_cached_linear(x, cache)  # 85-129 TFLOPS

Cached mode is the main contribution. In a transformer inference scenario weight matrices don’t change between tokens - quantizing them every forward pass is pure waste. fp4_quantize() converts BF16 weights to packed FP4 E2M1 + UE4M3 scales in CUTLASS’s interleaved layout, stores them on device, and every subsequent GEMM skips B quantization entirely.


Performance - Honest Numbers

Benchmarked on DGX Spark GB10 (SM121, 128 GB unified LPDDR5x, 273 GB/s):

| Size (M x N x K) | FP4 Dynamic | FP4 Cached | BF16 F.linear | Float32 mm |
| --- | --- | --- | --- | --- |
| 256 x 2880 x 2880 | 0.447 ms (10 TF) | 0.050 ms (85 TF) | 0.118 ms | 0.64 ms |
| 512 x 2880 x 2880 | 0.450 ms (19 TF) | 0.100 ms (85 TF) | 0.101 ms | 0.85 ms |
| 2048 x 2880 x 7680 | 1.993 ms (46 TF) | 0.702 ms (129 TF) | 1.190 ms | 7.12 ms |
| 2048 x 7680 x 2880 | 1.726 ms (53 TF) | 0.766 ms (118 TF) | 1.073 ms | 6.95 ms |
| 4096 x 2880 x 2880 | 1.478 ms (46 TF) | 1.089 ms (62 TF) | 0.752 ms | 5.12 ms |

What the numbers say:

  • FP4 Cached vs BF16 F.linear: 1.4-2.4x faster at small-to-medium M (prefill, small batch inference). At M=4096 BF16 F.linear pulls ahead - the GEMM is compute-bound enough that BF16 cuBLAS saturates the tensor cores and FP4’s activation quantization overhead becomes a tax.
  • FP4 Cached vs Float32 torch.mm: 5-13x faster. The original version of this post claimed “5-9x faster than BF16 cuBLAS” - that comparison was actually against float32 torch.mm, not BF16 F.linear. Corrected here.
  • FP4 Dynamic (both quantized) is slow: 10-53 TFLOPS. Quantizing both matrices every call is dominated by the quantization kernel, not the GEMM. Dynamic mode only makes sense if both matrices change every call (rare in inference).
  • The real win at large M: even when FP4 isn’t faster in wall-clock time, weights are stored as FP4 (4x smaller than BF16). For large models on memory-constrained hardware, fitting the model in memory at all is the bottleneck.

Accuracy: 0.991 Pearson correlation vs float32 reference, ~1.2% mean relative error. FP4 E2M1 has only 16 encodings (15 distinct values), so there’s real quantization noise - but for inference on a model trained in higher precision the quality impact is minimal. It’s the same FP4 format NVIDIA uses in their own ModelOpt/TensorRT-LLM FP4 quantization.
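If you want to reproduce the error characteristics without the kernel, here is a self-contained numpy sketch of the same kind of measurement: naive per-16-element block quantization of both operands, then Pearson correlation against a full-precision GEMM. Numbers will differ somewhat from the kernel’s:

```python
import numpy as np

E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
GRID = np.concatenate([-E2M1[:0:-1], E2M1])   # all 15 representable values

def fp4_roundtrip(x, block=16):
    """Quantize-dequantize: scale each 16-element block so max|x| maps
    to 6.0, then round every value to the nearest representable point."""
    xb = x.reshape(-1, block)
    scale = np.abs(xb).max(axis=1, keepdims=True) / 6.0
    scale[scale == 0] = 1.0
    idx = np.abs((xb / scale)[..., None] - GRID).argmin(axis=-1)
    return (GRID[idx] * scale).reshape(x.shape)

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 256))
B = rng.standard_normal((256, 64))
ref = A @ B
approx = fp4_roundtrip(A) @ fp4_roundtrip(B)
pearson = np.corrcoef(ref.ravel(), approx.ravel())[0, 1]
# Correlation stays close to 1.0 despite the coarse 15-value grid,
# because per-element rounding noise largely averages out over K=256
```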


End-to-End Model Benchmarks

Kernel-level numbers are one thing. Running FP4 through a full transformer with varying weight dimensions per layer is where it actually gets hard.

GPT-OSS-4.2B (Dense MoE - 4 experts, all active)

GPT-OSS-4.2B (24 layers, hidden=2880, 4 experts all-active, vocab=201088) has 12 linear ops per layer with 6 different dimension combinations - 288 GEMM calls per forward pass. The naive approach of calling fp4_gemm_init() per call caused cudaMalloc/cudaFree cycles every time dimensions changed, making FP4 4x slower than BF16.

The fix: fp4_gemm_prealloc(max_M, max_N, max_K) - allocate device buffers for the maximum dimensions once at startup. The GEMM init now checks if existing buffers are sufficient and skips reallocation. Combined with removing per-call cudaDeviceSynchronize() (single sync at end of forward pass) this gave an 8x improvement:
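The pre-allocation pattern is simple: scan the model for the largest GEMM extents before the first forward pass. A sketch - the shapes below are illustrative, not the exact GPT-OSS layer list, and fp4_gemm_prealloc is the library call described above:

```python
def max_gemm_dims(weight_shapes, max_tokens):
    """Return the (max_M, max_N, max_K) buffer sizes to hand to
    fp4_gemm_prealloc() once at startup, so no later GEMM call ever
    triggers a cudaMalloc/cudaFree cycle mid-forward-pass."""
    max_N = max(n for n, _ in weight_shapes)
    max_K = max(k for _, k in weight_shapes)
    return max_tokens, max_N, max_K

# Illustrative (N, K) weight shapes; real code would scan model.named_parameters()
shapes = [(2880, 2880), (7680, 2880), (2880, 7680)]
M, N, K = max_gemm_dims(shapes, max_tokens=4 * 512)   # batch=4, seq=512
# fp4_gemm_prealloc(M, N, K)   # once at startup; every call reuses the buffers
```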

| Config | FP4 Cached | BF16 F.linear | Speedup |
| --- | --- | --- | --- |
| batch=1, seq=64 | 25.9 ms (21 TF) | 51.0 ms (11 TF) | 2.0x FP4 wins |
| batch=1, seq=128 | 25.6 ms (42 TF) | 59.5 ms (18 TF) | 2.3x FP4 wins |
| batch=4, seq=64 | 34.6 ms (62 TF) | 53.8 ms (40 TF) | 1.6x FP4 wins |
| batch=4, seq=128 | 57.4 ms (75 TF) | 76.1 ms (57 TF) | 1.3x FP4 wins |
| batch=1, seq=256 | 36.2 ms (59 TF) | 57.5 ms (37 TF) | 1.6x FP4 wins |
| batch=4, seq=512 | 290.6 ms | 298.4 ms | ~tied |

Model loading stats:

  • GPT-OSS-4.2B: 4,186M params, 8,371 MB in BF16
  • FP4 weight cache: 2,083 MB (4.0x compression)
  • Quantization time: 0.22s total across all layers (9ms per layer)
  • Pre-allocation: ~50ms for max buffers

GPT-OSS-20B (Sparse MoE - top-4 of 32 experts)

I also built a full inference engine with KV cache, YaRN RoPE, top-k MoE routing, attention sinks (GPT-OSS-specific), and the custom gated activation - and ran both models end-to-end.

| Model | Mode | Decode Speed | Prefill Speed | GPU Memory | Text Quality |
| --- | --- | --- | --- | --- | --- |
| GPT-OSS-4.2B | BF16 | 20.2 tok/s | 90 tok/s | ~8 GB | Coherent |
| GPT-OSS-20B | BF16 | 16.8 tok/s | 41 tok/s | 43.4 GB | Coherent |
| GPT-OSS-20B | FP4 (raw) | 26.9 tok/s | 45 tok/s | 4.0 GB | Gibberish |
| GPT-OSS-20B | NVFP4 (ModelOpt PTQ) | 9.7 tok/s* | - | 41.8 GB** | Coherent |

* ModelOpt simulated quantization adds overhead (quant+dequant every forward pass). With packed FP4 via our CUTLASS kernel, expect ~27 tok/s.
** Simulated quant stores BF16 + metadata. Packed FP4 weights would be ~4 GB.

The 20B model goes from 43.4 GB to 4.0 GB - 10.8x memory reduction. FP4 decode is 1.6x faster than BF16. The entire 20B fits in the VRAM budget of an RTX 4060.

ModelOpt PTQ validates FP4 text quality. I ran NVIDIA’s ModelOpt (v0.41.0) PTQ calibration on the 20B with 128 samples and NVFP4_DEFAULT_CFG. Calibration took 171 seconds. The result:

Prompt: "Explain how neural networks learn in simple terms."
NVFP4 Response: "Neural networks learn by adjusting the weights of connections
between neurons. This process is called backpropagation. Initially, the
network makes predictions based on random weights."

Clean, coherent text at FP4 precision. Compare to raw FP4 without calibration: "вычай(out nhẹamataDeARRY'ét" - gibberish.

What ModelOpt does: inserts TensorQuantizer modules that simulate FP4 quantization in every linear layer’s forward pass. The max calibration algorithm finds the optimal _amax (maximum activation) per tensor so that FP4 rounding minimizes error. The calibrated model stores these _amax values alongside BF16 weights. For deployment, the weights get packed to actual FP4 + UE4M3 scales using our kernel’s format.

Why raw FP4 fails but calibrated FP4 works: raw FP4 uses naive max-scaling (scale = max(|weight|) / 6.0). Calibrated FP4 uses data-aware scaling that accounts for the actual distribution of activations flowing through each layer - outliers are clipped optimally rather than distorting the entire scale range.
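The effect is easy to demonstrate in isolation. A toy numpy experiment - not ModelOpt’s actual algorithm, just a grid search over the clipping threshold on one weight vector with a single outlier:

```python
import numpy as np

E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
GRID = np.concatenate([-E2M1[:0:-1], E2M1])   # all 15 representable values

def fake_quant(x, amax):
    """Quantize-dequantize to FP4 with |x| clipped at amax (scale = amax/6)."""
    scale = amax / 6.0
    xc = np.clip(x, -amax, amax) / scale
    return GRID[np.abs(xc[:, None] - GRID).argmin(axis=1)] * scale

rng = np.random.default_rng(0)
w = rng.standard_normal(4096)
w[0] = 25.0                                   # one outlier weight

naive = fake_quant(w, np.abs(w).max())        # outlier dictates the scale
# "Calibrated": search for the amax that minimizes quantization MSE
cands = np.linspace(1.0, np.abs(w).max(), 64)
best = min(cands, key=lambda a: ((w - fake_quant(w, a)) ** 2).mean())
mse_naive = ((w - naive) ** 2).mean()
mse_calib = ((w - fake_quant(w, best)) ** 2).mean()
# Clipping the outlier optimally gives far lower error than letting it
# stretch the scale so that most weights round to zero
```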

Important caveat on raw FP4: without ModelOpt calibration, FP4 quantization noise compounds through 24 sequential layers. Individual GEMMs have 0.991 Pearson correlation, but the error accumulates - layers 0-5 maintain 0.98-0.99 Pearson vs BF16, but by layer 20+ the signal degrades. GPT-OSS’s custom activation function (gate * sigmoid(gate * 1.702) with clamping at ±7.0) amplifies this more than standard SiLU. For clean text generation, run ModelOpt PTQ calibration first. For GEMM benchmarks and memory savings, raw weights work fine.

Lesson learned: buffer pre-allocation is not optional for multi-layer inference. This is table stakes in TRT-LLM (they pre-allocate everything at engine build time). CUDA Graphs would eliminate the remaining Python dispatch overhead (worth roughly another 60% based on community benchmarks), but pre-allocation alone was the single biggest win - 8x from one optimization.


Where FP4 Shines - Use Cases Beyond LLMs

The compounding noise issue is specific to autoregressive LLM generation through deep sequential layers without QAT. Many workloads don’t have this problem:

| Use Case | Compounding? | Speed | Memory Win | Why It Works |
| --- | --- | --- | --- | --- |
| QAT models (ModelOpt) | Handled by QAT | 85-129 TF | 4x | Weights trained to be FP4-robust |
| Diffusion models (FLUX, SD3) | No | 85-129 TF | 4x | Single forward pass per denoise step, no autoregressive chain |
| Embeddings (BERT, CLIP) | Minimal | 85-129 TF | 4x | Output only needs to be directionally correct (cosine similarity) |
| Vision (ViT, YOLO, SigLIP) | No | Slower* | 3.6x | Memory savings; BF16 faster at typical ViT dimensions |
| Recommendation (DLRM) | No | 85-129 TF | 4x | Shallow dense layers, memory-bound on embedding tables |
| Audio (Whisper encoder) | No | 85-129 TF | 4x | Encoder is a single forward pass |
| Scientific HPC matmul | N/A | 85-129 TF | 4x | Any large matmul where ~1% error is acceptable |
| Raw LLM (no QAT) | Yes | Fast but noisy | 4x | Speed + memory savings; use with QAT for quality |

Benchmarked on real model dimensions (DGX Spark, single GEMMs):

| Workload | FP4 ms | FP4 TF | BF16 ms | Speedup |
| --- | --- | --- | --- | --- |
| FLUX.1 MMDiT attn (2048 patches) | 0.29 ms | 133 TF | 0.53 ms | 1.82x FP4 |
| BERT-Large MLP (seq=512) | 0.03 ms | 127 TF | 0.07 ms | 2.06x FP4 |
| DiT-XL MLP (1024 patches) | 0.09 ms | 119 TF | 0.12 ms | 1.30x FP4 |
| DiT-XL attn (1024 patches) | 0.03 ms | 94 TF | 0.04 ms | 1.28x FP4 |
| ViT-Large MLP (batch=256) | 0.02 ms | 94 TF | 0.03 ms | 1.39x FP4 |

The FLUX.1 number is the headline: 133 TFLOPS at 1.82x BF16 speed on a consumer GPU. Diffusion models have large M (many patches/tokens per denoising step) which is exactly where FP4 excels.

Real SigLIP-SO400M benchmark - actual model weights, not synthetic tensors:

I loaded the full SigLIP-SO400M-patch14-384 (428M params, 27 ViT layers, 165 linear layers), replaced all 164 quantizable linears with FP4, and benchmarked end-to-end inference on random 384x384 images:

| Batch | Patches | BF16 ms | FP4 ms | Speedup | BF16 img/s | FP4 img/s |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | 729 | 18.8 | 20.4 | 0.92x | 53.3 | 48.9 |
| 4 | 2,916 | 61.1 | 87.7 | 0.70x | 65.5 | 45.6 |
| 8 | 5,832 | 127.5 | 196.6 | 0.65x | 62.7 | 40.7 |
| 16 | 11,664 | 265.0 | 411.2 | 0.64x | 60.4 | 38.9 |
| 32 | 23,328 | 528.1 | 822.2 | 0.64x | 60.6 | 38.9 |

Weight memory: 842 MB BF16 → 237 MB FP4 (3.6x smaller). Raw FP4 accuracy: Pearson 0.61, cosine similarity 0.61 - significant degradation without calibration (same pattern as LLM weights). With ModelOpt PTQ calibration, expect accuracy to recover to ~0.99+ while keeping the memory savings.

Honest takeaway on vision models: BF16 is faster on SigLIP at all batch sizes. SigLIP’s dimensions (1152 hidden, 4304 MLP intermediate) are in the range where BF16 cuBLAS saturates the tensor cores, and the overhead of FP4 activation quantization per-layer doesn’t pay for itself. The memory savings remain significant - 3.6x compression lets you run larger batches or multiple models in the same memory budget.

The memory story is often bigger than the speed story. FLUX.1-dev (12B params) on a 4GB GPU. A 20B LLM in 4 GB. SigLIP in 237 MB instead of 842 MB. Embedding models at 4x density. Even when FP4 isn’t faster in wall-clock time, fitting the model in memory at all is frequently the bottleneck on consumer hardware.

ModelOpt PTQ path (tested): I ran ModelOpt 0.41.0’s NVFP4 PTQ calibration on GPT-OSS-20B - 128 calibration samples, 171 seconds on DGX Spark, zero training required. The calibrated model generates clean text at FP4 precision. ModelOpt’s NVFP4_DEFAULT_CFG uses E2M1 weights + E4M3 block scales (block size 16) - the exact same NVFP4 format our CUTLASS kernel expects. Pipeline: pip install --no-deps nvidia-modelopt → mtq.quantize(model, NVFP4_DEFAULT_CFG, forward_loop=calibrate) → pack weights to our kernel’s format → 129 TFLOPS inference. Script included in the repo (modelopt_ptq.py).


Pre-Quantized Weight Cache API

Full API with context manager support:

from fp4_gemm import fp4_quantize, fp4_cached_linear, FP4WeightCache

# === Basic usage ===
cache = fp4_quantize(weight)        # weight: [N, K] BF16 on CUDA
output = fp4_cached_linear(x, cache) # x: [..., K] BF16 -> output: [..., N]
cache.free()                         # release GPU memory

# === Context manager (auto-cleanup) ===
with FP4WeightCache(weight, bias=bias) as cache:
    output = cache.forward(x)

# === Quantize all layers at model load ===
caches = {}
for name, param in model.named_parameters():
    if 'weight' in name and param.dim() == 2:
        caches[name] = fp4_quantize(param.data)
# Now use caches[name].forward(x) instead of F.linear(x, param)

Memory footprint: FP4 weights are 4x smaller than BF16. A 2880x2880 weight matrix drops from 15.8 MB (BF16) to 4.0 MB (FP4 packed + UE4M3 scales).

Auto-padding is handled internally - if your dimensions aren’t multiples of 128 the library pads/unpads transparently.


Technical Details

Architecture overview:

Cached path (inference):
  Load time:  BF16 Weight [N, K] --> GPU Quantize --> FP4 [N, K/2] + UE4M3 scales  (stored)
  Each call:  BF16 Activation [M, K] --> GPU Quantize --> FP4 [M, K/2] + UE4M3 scales
                                                                    |
                                                                    v
                                                         CUTLASS Block-Scaled GEMM
                                                         (mma.sync.aligned.block_scale)
                                                                    |
                                                                    v
                                                         BF16 Output D [M, N]

CUTLASS configuration (identical to Example 79a):

using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;  // NVFP4
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;  // NVFP4
using TileShape = Shape<_128, _128, _128>;   // Threadblock tile
using ClusterShape = Shape<_1, _1, _1>;      // No multicast (SM121)

GPU FP4 Quantization Kernel (quantize_bf16_to_fp4_kernel) - not present in stock CUTLASS. One CUDA thread per 16-element scale block:

Per thread (handles 16 consecutive elements):
  1. Read 16 BF16 values from source matrix
  2. Compute max |value| across the block
  3. Scale factor = max / 6.0, converted to UE4M3 (unsigned 8-bit float: 4 exponent bits with bias 7, 3 mantissa bits)
  4. For each value: divide by scale, round to nearest FP4 E2M1
  5. Pack pairs into bytes (low nibble = even index)
  6. Write scale to CUTLASS interleaved SfKMajorAtom position
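Steps 1-5 can be written as a tiny numpy reference. Step 6 (the interleaved scale placement) is what the next section covers, and the UE4M3 rounding of the scale is omitted here:

```python
import numpy as np

E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_block(vals):
    """Reference for one 16-element block: compute the block scale,
    round each value to the nearest E2M1 magnitude, put the sign in
    bit 3, and pack two codes per byte (low nibble = even index)."""
    amax = np.abs(vals).max()
    scale = amax / 6.0 if amax > 0 else 1.0
    mag_idx = np.abs(np.abs(vals / scale)[:, None] - E2M1).argmin(axis=1)
    codes = mag_idx | np.where(vals < 0, 0x8, 0)
    packed = (codes[0::2] | (codes[1::2] << 4)).astype(np.uint8)
    return scale, packed
```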

Scale Factor Layout - the hardest part. CUTLASS uses an interleaved SfKMajorAtom that requires CuTe’s flat coordinate decomposition to index correctly:

Atom Shape:  ((32, 4), (SFVecSize, 4))
Atom Stride: ((16, 4), (0,         1))

For coordinate (row, k_block):
  index = (row % 32) * 16
        + ((row / 32) % 4) * 4
        + (row / 128) * row_tile_stride
        + (k_block % 4) * 1
        + (k_block / 4) * k_tile_stride

I reverse-engineered this from CuTe’s internal flat decomposition. Manual hierarchical coordinates (as I initially tried) produce different indices and corrupt ~10% of output elements. The flat decomposition is the only path that matches what the GEMM kernel expects.
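Transcribed directly into code, with the tile strides left as parameters since they depend on the matrix extents. A useful sanity check falls straight out of the formula: within a single atom (rows 0-127, k_blocks 0-3) the 512 scale slots must each be hit exactly once:

```python
def sf_index(row, k_block, row_tile_stride, k_tile_stride):
    """Flat index into the interleaved SfKMajorAtom scale buffer for
    coordinate (row, k_block), following the decomposition above."""
    return ((row % 32) * 16
            + ((row // 32) % 4) * 4
            + (row // 128) * row_tile_stride
            + (k_block % 4) * 1
            + (k_block // 4) * k_tile_stride)

# Within one 128-row x 4-k-block atom, the indices form a permutation
atom = {sf_index(r, k, 0, 0) for r in range(128) for k in range(4)}
assert atom == set(range(512))
```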

C API exposed as a shared library:

// Pre-allocate buffers for max dimensions (call once at startup)
int fp4_gemm_prealloc(int max_M, int max_N, int max_K);

// Dynamic (both quantized per call)
int fp4_gemm_run(A_bf16, B_bf16, C_bf16, D_bf16, M, N, K, alpha, beta);

// Pre-quantize weights once
void* fp4_quantize_weights(weight_bf16, N, K);

// Run with cached weights (only A quantized, async - no sync)
int fp4_gemm_run_cached(A_bf16, cache_handle, C_bf16, D_bf16, M, alpha, beta);

// Explicit sync (call once after all GEMMs in a forward pass)
void fp4_gemm_sync();

// Cleanup
void fp4_weight_cache_free(cache_handle);

Interactive Inference Engine

Beyond the GEMM library, I also wrote a full custom inference engine for GPT-OSS that supports both the 4.2B (dense MoE) and 20B (sparse MoE) models. This was necessary to understand how FP4 behaves in real generation - things I had to reverse-engineer from the HuggingFace source that aren’t documented:

  • Attention sinks: GPT-OSS has a per-head learned scalar concatenated to attention logits as an extra softmax column, then dropped after softmax. Without this, the attention distribution is completely wrong.
  • Custom gated activation: Not standard SiLU. Uses interleaved gate/up split (even/odd indices), gate * sigmoid(gate * 1.702) with clamping at ±7.0, and (up + 1) * glu. This is more sensitive to quantization noise than vanilla SiLU.
  • Sparse MoE routing: Top-4 of 32 experts with per-token scatter-gather and post-selection softmax.
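Since the gated activation is easy to get subtly wrong, here is a numpy sketch of it as described above. I clamp both streams at ±limit per the description; double-check the exact clamp placement against the HuggingFace source, which may apply it slightly differently per stream:

```python
import numpy as np

def gpt_oss_glu(x, alpha=1.702, limit=7.0):
    """GPT-OSS gated activation: interleaved gate/up split (even/odd
    indices), clamp at +/-limit, glu = gate * sigmoid(alpha * gate),
    output = (up + 1) * glu. Output has half the feature dim of x."""
    gate, up = x[..., 0::2], x[..., 1::2]
    gate = np.clip(gate, -limit, limit)
    up = np.clip(up, -limit, limit)
    glu = gate / (1.0 + np.exp(-alpha * gate))   # gate * sigmoid(alpha*gate)
    return (up + 1.0) * glu
```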

The engine includes YaRN RoPE, KV cache, alternating sliding/full attention, and streaming generation - all written from scratch to isolate the FP4 behavior from any framework overhead. Source is in fp4_generate.py.


Where to Use FP4 - Quick Summary

Speed wins:

  • Diffusion (FLUX.1, SD3) - 1.82x faster, 133 TFLOPS. Large token counts per step, single forward pass. This is the best FP4 use case.
  • LLM decode at small batch - 1.6-2.3x faster. Memory-bound regime where FP4’s 4x smaller reads dominate.

Memory wins (everywhere):

  • 20B LLM in 4 GB instead of 43 GB. Run models that literally don’t fit otherwise.
  • Stack multiple models in one GPU - SigLIP + LLM + YOLO in a pipeline that would OOM with BF16.
  • Embedding serving - 4x more models per GPU for retrieval/RAG.

The community angle: anyone on an RTX 5090/5080 or Spark who wants to run bigger models than their VRAM allows. The fp4_quantize() + fp4_cached_linear() API is two lines to drop into any model. Pair with ModelOpt PTQ for quality.


Known Issues and Caveats

Raw FP4 produces gibberish on LLMs. Without ModelOpt calibration, quantization noise compounds through 24+ sequential transformer layers. Each individual GEMM has 0.991 Pearson correlation vs BF16, but the errors accumulate layer after layer. By the final layer the signal is degraded enough that text generation produces nonsense. This is not a kernel bug - it’s the fundamental limitation of naive FP4 quantization on deep networks. The fix is ModelOpt PTQ calibration (171 seconds, zero training) which produces clean coherent text.

ModelOpt PTQ is simulated quantization, not packed FP4. ModelOpt inserts TensorQuantizer modules that quant/dequant on every forward pass. The saved model is still BF16 weights + _amax calibration metadata (41.8 GB for 20B, not 4 GB). This adds overhead - 9.7 tok/s vs 16.8 tok/s BF16 baseline. The full pipeline to get both speed AND quality: ModelOpt calibration → extract _amax scales → pack weights to FP4 using our CUTLASS kernel format → 27+ tok/s with clean text. The packing step is not yet automated in the repo but the format is compatible.

BF16 is faster than FP4 on vision models. Tested on real SigLIP-SO400M (428M params, 164 linear layers). BF16 wins at all batch sizes (0.64-0.92x). SigLIP’s dimensions (1152 hidden, 4304 MLP) are in the range where BF16 cuBLAS already saturates the tensor cores. FP4 still saves 3.6x memory (842 MB → 237 MB) but is not a speed win here.

FP4 loses at large batch sizes. At M=4096+ (large batch inference, long sequences), BF16 F.linear pulls ahead because the GEMM becomes compute-bound and FP4’s activation quantization overhead becomes a tax. FP4 wins at small-to-medium M (batch=1-4, seq=64-256) where memory bandwidth is the bottleneck.

No backward pass. FP4 tensor cores on SM121 only support forward (A*B=D). No gradient computation. This is an inference-only library.


What This Is and What It Isn’t

What it is:

  • A standalone FP4 GEMM library for SM120/SM121 (RTX 5090, RTX 5080, DGX Spark, etc.)
  • Framework-independent - works from any Python code via ctypes, no dependency on vLLM/TRT-LLM/sglang
  • The only library I’m aware of that gives you a callable FP4 GEMM on consumer Blackwell without a full serving stack
  • A pre-quantized weight cache that eliminates redundant quantization in inference loops
  • Useful for any matrix-multiply workload - LLMs, diffusion, embedding, vision, recommendation, scientific computing
  • Compatible with ModelOpt PTQ/QAT weights (same NVFP4 E2M1 + UE4M3 format)

What it isn’t:

  • A training accelerator - BF16 F.linear is faster at large batch sizes and FP4 has no backward pass. I originally built this thinking it would speed up training. It doesn’t. It’s for inference.
  • A replacement for TRT-LLM/vLLM if you’re already using those - they have their own optimized FP4 paths for SM100. On SM121 specifically this library fills a gap.
  • A magic bullet for all GEMM sizes - at M=4096+ BF16 cuBLAS wins on wall-clock time. FP4 still saves 4x memory.
  • A magic bullet for LLM text quality - raw (non-QAT) FP4 weights produce degraded text generation due to error compounding through deep transformer layers. Use ModelOpt PTQ calibration for production quality. The GEMM speed and memory savings are real regardless.

Open Source

Repository: GitHub - VincentKaufmann/fp4-cuda-kernel: Custom FP4 GEMM kernel for DGX Spark / RTX 50 Series (SM120/SM121). 143 TFLOPS, 5-9x faster than BF16. Built on CUTLASS 3.8.

| File | Description |
| --- | --- |
| fp4_gemm_lib.cu | CUDA source - CUTLASS GEMM + GPU quantization + cached weight API + pre-allocation |
| fp4_gemm.py | Python wrapper - dynamic, cached, auto-padding, batching, context managers |
| fp4_inference_demo.py | End-to-end GPT-OSS-4.2B inference benchmark (FP4 vs BF16) |
| generate/fp4_generate.py | Full interactive generation engine - supports 4.2B + 20B, KV cache, YaRN RoPE, sparse MoE |
| build.sh | Build script - auto-detects GPU, clones CUTLASS 3.8, compiles |

Build:

git clone https://github.com/VincentKaufmann/fp4-cuda-kernel.git
cd fp4-cuda-kernel
./build.sh  # auto-detects sm_120a or sm_121a

Or manually:

nvcc -arch=sm_121a -shared -Xcompiler -fPIC -O2 --expt-relaxed-constexpr \
  -I cutlass/include -I cutlass/tools/util/include -I cutlass/examples/common \
  -o libfp4gemm.so fp4_gemm_lib.cu

Requirements: SM120 or SM121 GPU, CUDA 12.8+ (12.9+ for SM121), CUTLASS 3.8+, Python 3.8+, PyTorch with CUDA.

I built this because I needed it. If you have a DGX Spark, RTX 5090, or RTX 5080 and want hardware FP4 without spinning up a full serving framework, this is the path that works. The fp4_inference_demo.py shows how to integrate it into a real model - quantize all weights at load time, pre-allocate buffers, run async GEMMs, sync once at the end.

Happy to discuss implementation details, limitations, or benchmarks at other matrix sizes.

Nice work!

Both sm_120a and sm_121a support
mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::2X.f32.e2m1.e2m1.f32.ue8m0 and
mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3
since nvcc 12.9.0 and
mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue8m0
since nvcc 13.1.
(All three compile to a single OMMA instruction, so are supported natively)

So with the latest CUDA toolkit MXFP4 (E2M1 with E8M0 scales) is supported with block sizes 16 and 32 (before 13.1 only 32 is supported). NVFP4 only supports block size 16.

There is a place on these forums for project reports like this; it is here.

I can move the post there if you wish. Or we can leave it here.

Robert, you can move it if you want :)

The headline is more Custom FP4 CUDA Kernel than DGX Spark, No?

Yes, I can see both. I’ll leave it.