Switch from "sm90_xmma_gemm.._cublas" / "void cutlass::Kernel<cutlass_80_tensorop.." kernels with CUDA 12.1 to "nvjet_tst..." kernels with CUDA 12.8

I am running a few ML models using PyTorch with torch.compile on top of it, on a system with H100 NVL GPUs.

With PyTorch built from source using CUDA 12.1:

(pytorch-test) abhishek@chisel-8:/usr/local/cuda-12.8/include$ python3
Python 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__config__.show()) 
PyTorch built with:
  - GCC 11.4
  - C++ Version: 201703
  - Intel(R) MKL-DNN v3.3.6 (Git Hash 86e6af5974177e513fd3fee58425e1063e7f1361)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9.2
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=1 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DTMP_LIBKINETO_NANOSECOND -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, FORCE_FALLBACK_CUDA_MPI=1, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.4.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=OFF, USE_MKLDNN=ON, USE_MPI=ON, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, 


The ML model uses a void cutlass::Kernel<cutlass_80_tensorop… kernel (6.1 us; cudaGraphLaunch delay: 0 s).


The ML model uses an sm90_xmma_gemm_…_cublas kernel (13.2 us; cudaGraphLaunch delay: 28 us).


With the same PyTorch built from source using CUDA 12.8:

(abhishek-pytorch) abhishek@chisel-8:/disk2/abhishek/pytorch$ python3
Python 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__config__.show()) 
PyTorch built with:
  - GCC 11.4
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2025.0.1-Product Build 20241031 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.3.6 (Git Hash 86e6af5974177e513fd3fee58425e1063e7f1361)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.8
  - NVCC architecture flags: -gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9.2  (built against CUDA 12.1)
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.8, CUDNN_VERSION=8.9.2, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=1 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DTMP_LIBKINETO_NANOSECOND -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, FORCE_FALLBACK_CUDA_MPI=1, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.4.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=OFF, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=ON, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, 


An nvjet_tst… kernel (5.2 us) appears in place of the 6.1 us cutlass kernel.
The nvjet_tst… kernel has a cudaGraphLaunch delay of 5.5 us, compared to the 0 s delay with the cutlass kernel.


Again, an nvjet_tst… kernel (10.5 us) without a preceding memset appears in place of the sm90_xmma_gemm… kernel (13.2 us), which had a preceding memset.
The nvjet_tst… kernel has a cudaGraphLaunch delay of 98.6 us, compared to the 28.3 us delay with the sm90_xmma_gemm… kernel.
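
To put the numbers above side by side, here is a small sketch that tabulates them. The timings are the ones quoted in this post; the "case 1" / "case 2" labels and the kernel-plus-delay sum are only my own simplification for comparison, not something the profiler reports.

```python
# Timings quoted in this post, in microseconds:
# (kernel duration, cudaGraphLaunch delay).
# The 28.3 us figure is the more precise value quoted for the
# sm90_xmma_gemm kernel; the case labels are hypothetical names.
measurements = {
    "case 1": {
        "CUDA 12.1 cutlass": (6.1, 0.0),
        "CUDA 12.8 nvjet_tst": (5.2, 5.5),
    },
    "case 2": {
        "CUDA 12.1 sm90_xmma_gemm": (13.2, 28.3),
        "CUDA 12.8 nvjet_tst": (10.5, 98.6),
    },
}


def summarize(case):
    """Return (label, kernel_us, delay_us, kernel_us + delay_us) rows."""
    rows = []
    for label, (kernel_us, delay_us) in case.items():
        # The sum is a rough comparison aid only (an assumption of this sketch).
        rows.append((label, kernel_us, delay_us, round(kernel_us + delay_us, 1)))
    return rows


for name, case in measurements.items():
    print(name)
    for label, kernel_us, delay_us, total_us in summarize(case):
        print(f"  {label:<26} kernel={kernel_us:5.1f} us  "
              f"delay={delay_us:5.1f} us  sum={total_us:5.1f} us")
```

In both cases the nvjet_tst… kernel itself is shorter, while the cudaGraphLaunch delay is larger.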


With the CUDA 12.1 build, torch.compile results in calls to the "sm90_xmma_gemm.._cublas" / "void cutlass::Kernel<cutlass_80_tensorop.." kernels.
With the CUDA 12.8 build, they are replaced by "nvjet_tst…" kernels.

Why is that? (I have seen the posts here and here.) I just want to understand a bit more: how do these "nvjet_tst…" kernels differ from the cutlass and sm90_xmma_gemm variants?
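
For anyone who wants to check this on their own build, a minimal sketch along these lines could be run under both environments to see which GEMM kernel family gets dispatched. It is not my actual model: the matrix shapes, the fp16 dtype, the default torch.compile mode and the helper names (kernel_family, profile_compiled_matmul) are all assumptions made for illustration, and a plain matmul may not select the same kernels as the full model does.

```python
# Name fragments taken from the kernel names quoted in this post.
FAMILIES = {
    "nvjet": "nvjet",
    "sm90_xmma_gemm": "sm90_xmma_gemm (cuBLAS)",
    "cutlass": "cutlass",
}


def kernel_family(kernel_name):
    """Return a family label for a kernel name, or None if no fragment matches."""
    for fragment, label in FAMILIES.items():
        if fragment in kernel_name:
            return label
    return None


def profile_compiled_matmul(m=1024, n=1024, k=1024):
    """Profile one compiled matmul and return the matching kernel names.

    Shapes and dtype are placeholders, not the values from my model.
    """
    import torch
    from torch.profiler import profile, ProfilerActivity

    a = torch.randn(m, k, device="cuda", dtype=torch.float16)
    b = torch.randn(k, n, device="cuda", dtype=torch.float16)
    fn = torch.compile(lambda x, y: x @ y)

    fn(a, b)  # warm-up call so compilation is not part of the profile
    torch.cuda.synchronize()

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        fn(a, b)
        torch.cuda.synchronize()

    # Keep only events whose name matches one of the families above.
    return sorted({evt.key for evt in prof.key_averages() if kernel_family(evt.key)})


if __name__ == "__main__":
    try:
        import torch
        has_cuda = torch.cuda.is_available()
    except ImportError:
        has_cuda = False

    if has_cuda:
        print("CUDA runtime:", torch.version.cuda)
        for kernel_name in profile_compiled_matmul():
            print(kernel_family(kernel_name), "->", kernel_name)
    else:
        print("CUDA-enabled PyTorch not available; skipping the profile run.")
```

Comparing the printed kernel names between the CUDA 12.1 and CUDA 12.8 builds would show whether the switch also happens for a standalone matmul or only inside the full model.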