cuDNN Bug Report: Conv3d Performance Regression with bfloat16/float16 on H100

Summary

Conv3d operations exhibit a severe performance regression (on the order of 16,000x slower) when using bfloat16 or float16 inputs compared to float32 on H100 GPUs, for input shapes commonly used in Vision Transformer patch embedding layers.

Environment

  • GPU: NVIDIA H100 PCIe (Compute Capability 9.0)
  • CUDA Version: 12.8
  • cuDNN Version: 9.10.2 (91002)
  • PyTorch Version: 2.9.0+cu128
  • OS: Ubuntu 24.04 (Linux)
  • Driver Version: 570.153.02

Minimal Reproducer

import torch
import time

def benchmark_conv3d(dtype, warmup=3, iterations=10):
    """Benchmark Conv3d with specified dtype."""
    # Conv3d configuration matching Qwen3-VL vision encoder patch_embed
    conv = torch.nn.Conv3d(
        in_channels=3,
        out_channels=1024,
        kernel_size=(2, 16, 16),
        stride=(2, 16, 16),
        bias=True
    ).cuda()
    
    if dtype != torch.float32:
        conv = conv.to(dtype)
    
    # Input shape: 64 images × 144 patches = 9216 batch elements
    # Each element: 3 channels × 2 temporal × 16×16 spatial
    x = torch.randn(9216, 3, 2, 16, 16, dtype=dtype, device='cuda')
    
    # Warmup
    for _ in range(warmup):
        with torch.no_grad():
            _ = conv(x)
        torch.cuda.synchronize()
    
    # Benchmark
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        with torch.no_grad():
            _ = conv(x)
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / iterations
    
    return elapsed

if __name__ == "__main__":
    print("Conv3d Benchmark: [9216, 3, 2, 16, 16] -> [9216, 1024, 1, 1, 1]")
    print("=" * 60)
    
    results = {}
    for dtype, name in [
        (torch.float32, "float32"),
        (torch.float16, "float16"),
        (torch.bfloat16, "bfloat16"),
    ]:
        try:
            results[name] = elapsed = benchmark_conv3d(dtype)
            print(f"{name:>10}: {elapsed*1000:>10.2f} ms")
        except Exception as e:
            print(f"{name:>10}: ERROR - {e}")

    # Calculate and display the regression from the timings collected above
    # (avoids re-running the multi-second half-precision benchmarks)
    if "float32" in results and "bfloat16" in results:
        print("=" * 60)
        print(f"Regression factor (bfloat16/float32): {results['bfloat16'] / results['float32']:.0f}x slower")

Observed Results

Data Type    Time per Forward Pass    Relative Performance
float32      2.2 ms                   1.0x (baseline)
float16      35,621 ms                16,191x slower
bfloat16     ~35,000 ms               ~16,000x slower

Expected Behavior

Half-precision operations (bfloat16/float16) should be at least as fast as float32 on H100, which has native Tensor Core support for these data types.

Impact

This issue affects production workloads using Vision-Language Models, specifically:

  • Qwen3-VL (Alibaba) - Uses this exact Conv3d configuration in Qwen3VLVisionPatchEmbed
  • Qwen2-VL - Same architecture
  • Any ViT-style model that processes patches through Conv3d with large batch dimensions

The workaround (forcing float32 for Conv3d) increases memory usage and prevents full utilization of H100’s half-precision capabilities.
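
For context, here is a minimal sketch of how a patch-embed layer of this kind produces the problematic input shape; the class and parameter names are illustrative, not copied from the Qwen implementations.

import torch

class PatchEmbed3D(torch.nn.Module):
    """Illustrative ViT-style patch embed; names and structure are assumptions,
    not the actual Qwen3VLVisionPatchEmbed code."""
    def __init__(self, in_channels=3, embed_dim=1024,
                 temporal_patch_size=2, patch_size=16):
        super().__init__()
        self.proj = torch.nn.Conv3d(
            in_channels, embed_dim,
            kernel_size=(temporal_patch_size, patch_size, patch_size),
            stride=(temporal_patch_size, patch_size, patch_size),
        )

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # patches: [num_patches, 3, 2, 16, 16], where num_patches is the total
        # patch count across the batch (e.g. 64 images x 144 patches = 9216),
        # so the conv sees a very large batch dimension and a 1x1x1 spatial output
        return self.proj(patches).flatten(1)   # -> [num_patches, embed_dim]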

Analysis

The issue appears to be in cuDNN’s algorithm selection heuristics. When the combination of:

  1. Large batch dimension (9216)
  2. Small spatial output (1×1×1)
  3. Half-precision dtype (bf16/fp16)

…is encountered, cuDNN selects a catastrophically slow algorithm.
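
A sketch of one way to test this hypothesis: enable PyTorch's cuDNN autotuner, which times every available algorithm for the exact shape instead of trusting the heuristic pick (results not included in this report).

import torch

# Diagnostic sketch: cudnn.benchmark=True asks cuDNN to time all available
# algorithms for this shape rather than relying on its heuristic choice.
torch.backends.cudnn.benchmark = True

conv = torch.nn.Conv3d(3, 1024, kernel_size=(2, 16, 16),
                       stride=(2, 16, 16), bias=True).cuda().bfloat16()
x = torch.randn(9216, 3, 2, 16, 16, dtype=torch.bfloat16, device="cuda")

with torch.no_grad():
    _ = conv(x)   # first call pays the autotuning cost
    torch.cuda.synchronize()
    _ = conv(x)   # subsequent calls reuse the best algorithm found
    torch.cuda.synchronize()

If the second call is fast under benchmark mode, the heuristic pick is to blame; if it is still tens of seconds, the problem is deeper than algorithm selection.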

Evidence that this is an algorithm selection issue (not a fundamental limitation):

  • float32 with identical shapes runs in milliseconds (see the table above), not tens of seconds
  • H100 Tensor Cores natively support bf16/fp16
  • The mathematical operations are identical; because the kernel size equals the stride, this Conv3d reduces to a single GEMM over flattened patches (see the sketch after this list)
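
Because the kernel size equals the stride and matches each patch's extent, the whole operation is equivalent to one GEMM over flattened patches. A sketch of that equivalence (illustrative only, not proposed as the fix):

import torch

conv = torch.nn.Conv3d(3, 1024, kernel_size=(2, 16, 16),
                       stride=(2, 16, 16), bias=True).cuda()
x = torch.randn(9216, 3, 2, 16, 16, device="cuda")

with torch.no_grad():
    ref = conv(x)                                     # [9216, 1024, 1, 1, 1]
    # The same computation expressed as one matrix multiply over flattened patches:
    w = conv.weight.reshape(1024, -1)                 # [1024, 3*2*16*16]
    gemm = x.reshape(9216, -1) @ w.t() + conv.bias    # [9216, 1024]
    print(torch.allclose(ref.flatten(1), gemm, atol=1e-3))  # expected: True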

Workaround

Keep the Conv3d weights in float32, run it with autocast disabled, and cast the output back to the working dtype:

# Slow (36 seconds): conv weights and input both in bfloat16
output = conv(x.bfloat16())

# Fast (0.07 seconds): keep the conv weights in float32, run the conv with
# autocast disabled, then cast the result back to bfloat16
with torch.autocast(device_type="cuda", enabled=False):
    output = conv(x.float()).to(torch.bfloat16)
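
In model code, one way to apply this is a small wrapper around the patch-embed conv; a minimal sketch (the class is mine, not taken from any of the affected models):

import torch

class Float32Conv3dWrapper(torch.nn.Module):
    """Illustrative workaround: keep the Conv3d weights in float32 and cast
    the output back to the caller's dtype, regardless of autocast."""
    def __init__(self, conv: torch.nn.Conv3d):
        super().__init__()
        self.conv = conv.float()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out_dtype = x.dtype
        with torch.autocast(device_type="cuda", enabled=False):
            return self.conv(x.float()).to(out_dtype)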

Request

Please investigate the algorithm selection logic for Conv3d with:

  • Large batch dimensions (>1000)
  • 1×1×1 spatial output
  • Half-precision inputs

The heuristic should select an algorithm at least as fast as the one chosen for float32 inputs with the same shapes.

Additional Information

Happy to provide additional profiling data or Nsight traces, or to run specific diagnostic commands if helpful.
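
For example, here is a sketch of how I would capture cuDNN's own engine-selection log while running the reproducer; the environment variable names are taken from the cuDNN 9 logging documentation as I understand it, and must be set before torch loads libcudnn.

import os

# Sketch: enable cuDNN API/engine logging before torch initializes cuDNN
# (variable names assumed from the cuDNN 9 logging docs)
os.environ["CUDNN_LOGLEVEL_DBG"] = "3"
os.environ["CUDNN_LOGDEST_DBG"] = "cudnn_engine_log.txt"

import torch  # imported after the env vars on purpose

conv = torch.nn.Conv3d(3, 1024, kernel_size=(2, 16, 16),
                       stride=(2, 16, 16), bias=True).cuda().bfloat16()
x = torch.randn(9216, 3, 2, 16, 16, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    _ = conv(x)
torch.cuda.synchronize()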

Disclaimer: an LLM helped me write this issue more clearly.