CUTLASS Minimal Example - error: expression must have constant value

Good evening, all. I am attempting to compile a minimal CUTLASS GEMM example in a PyTorch project. I want to write a simple CUTLASS kernel and be able to call it from PyTorch.

I have a minimal example that refuses to compile.

setup.py

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

setup(
    ### Minimal Example Args
    name="gorby",
    install_requires=["torch >= 2.1", "pybind11"],
    ### PyTorch C++/CUDA Examples
    ext_modules=[
        #CUDAExtension(
        #    name="gorby_vector_add", sources=["gorby_vector_add.cu"]
        #),
        #CUDAExtension(
        #    name="gorby_nvtx", sources=["gorby_nvtx.cpp"]
        #),
        CUDAExtension(
            name="gorby_sdpa", sources=["gorby_sdpa.cu"],
            include_dirs=['C:\\Users\\neliopou\\Documents\\PhD\\Github\\cutlass\\include']
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)

gorby_sdpa.cu

// CUTLASS
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm_array.h"
#include "cutlass/gemm/device/gemm_batched.h"

cudaError_t cutlass_array_sgemm(
  int m,
  int n,
  int k,
  float alpha,
  float const * const *A,
  int lda,
  float const * const *B,
  int ldb,
  float * const *C,
  int ldc,
  float beta,
  int batch_count) {

  using Gemm = cutlass::gemm::device::GemmArray<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor
  >;

  Gemm gemm_op;

  cutlass::Status status = gemm_op({
    {m, n, k},
    A, lda,
    B, ldb,
    C, ldc,
    C, ldc,
    {alpha, beta},
    batch_count
  });

  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  return cudaSuccess;
}
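For reference, here is what the wrapper above asks CUTLASS to compute, written as a plain-Python sketch: for every batch entry, C = alpha * A * B + beta * C, with all matrices stored column-major. The function name batched_sgemm_reference and the use of flat Python lists in place of device pointers are my own assumptions for illustration. This is not CUTLASS code and it is only useful for checking results on tiny inputs.

```python
def batched_sgemm_reference(m, n, k, alpha, A, lda, B, ldb, C, ldc, beta):
    """Hypothetical CPU reference for the array-of-pointers batched SGEMM.

    A, B and C are lists of flat lists, one per batch entry (standing in
    for the float** arguments). Every matrix is column-major, so element
    (row, col) lives at index row + col * leading_dimension.
    """
    for a, b, c in zip(A, B, C):
        for col in range(n):
            for row in range(m):
                acc = 0.0
                for p in range(k):
                    # A is m x k, B is k x n, both column-major
                    acc += a[row + p * lda] * b[p + col * ldb]
                # C is updated in place, as in the CUTLASS call (C is both source and destination)
                c[row + col * ldc] = alpha * acc + beta * c[row + col * ldc]
    return C


# Two 2x2 problems: the first multiplies by the identity on the right,
# the second multiplies by the identity on the left.
A = [[1.0, 2.0, 3.0, 4.0], [1.0, 0.0, 0.0, 1.0]]
B = [[1.0, 0.0, 0.0, 1.0], [5.0, 6.0, 7.0, 8.0]]
C = [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]
result = batched_sgemm_reference(2, 2, 2, 2.0, A, 2, B, 2, C, 2, 1.0)
```

Once the extension builds, comparing its output against something like this on small matrices is a cheap sanity check for the leading dimensions.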

I run MAX_JOBS=4 python setup.py develop to compile and get the following output:
build.txt (173.8 KB)

I would like to point out that this example code is taken from examples/05_batched_gemm/batched_gemm.cu in the NVIDIA/cutlass repository on GitHub (main branch).

The errors, such as "expression must have constant value", seem to indicate a type problem originating in the CUTLASS v3.4.1 headers, and I am not sure how to proceed.

Here is my system information in case it is helpful:

  • Windows Version 10.0.19045 Build 19045
  • GPU: NVIDIA RTX 3090 Ti
  • CUDA 11.8
  • Compiler: MSVC 2019 14.29.30133
  • PyTorch 2.2.2+cu118

You should post on the CUTLASS GitHub, if you haven’t already.
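While waiting on that, one thing that might be worth ruling out is the language standard. CUTLASS 3.x requires a C++17 host compiler, and MSVC does not report an up-to-date __cplusplus value unless /Zc:__cplusplus is passed. I have not seen the build log, so this is only a guess and not a confirmed fix. The helper name cutlass_extension_kwargs and the include path below are hypothetical. The extra_compile_args dictionary with "cxx" and "nvcc" keys is the form CUDAExtension accepts.

```python
def cutlass_extension_kwargs(cutlass_include, windows=True):
    """Hypothetical helper: build keyword arguments for CUDAExtension.

    The flags are an assumption about what the build may need, not a
    verified fix for the error in this thread.
    """
    if windows:
        # MSVC host flags; for .cu files they must reach cl.exe via nvcc's -Xcompiler
        cxx_flags = ["/std:c++17", "/Zc:__cplusplus"]
        nvcc_flags = ["-std=c++17", "-Xcompiler", "/Zc:__cplusplus"]
    else:
        cxx_flags = ["-std=c++17"]
        nvcc_flags = ["-std=c++17"]
    return {
        "name": "gorby_sdpa",
        "sources": ["gorby_sdpa.cu"],
        "include_dirs": [cutlass_include],
        "extra_compile_args": {"cxx": cxx_flags, "nvcc": nvcc_flags},
    }


# In setup.py this would be used as:
#   CUDAExtension(**cutlass_extension_kwargs("path/to/cutlass/include"))
kwargs = cutlass_extension_kwargs("path/to/cutlass/include")
```

If the errors are unchanged with these flags, the full build log will be more useful to the CUTLASS maintainers than anything else.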