Good evening, all. I am attempting to compile a minimal CUTLASS GEMM example in a PyTorch project. I want to write a simple CUTLASS kernel and be able to call it from PyTorch.
I have a minimal example that refuses to compile.
setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

setup(
    ### Minimal Example Args
    name="gorby",
    install_requires=["torch >= 2.1", "pybind11"],
    ### PyTorch C++/CUDA Examples
    ext_modules=[
        #CUDAExtension(
        #    name="gorby_vector_add", sources=["gorby_vector_add.cu"]
        #),
        #CUDAExtension(
        #    name="gorby_nvtx", sources=["gorby_nvtx.cpp"]
        #),
        CUDAExtension(
            name="gorby_sdpa", sources=["gorby_sdpa.cu"],
            include_dirs=['C:\\Users\\neliopou\\Documents\\PhD\\Github\\cutlass\\include']
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
gorby_sdpa.cu
// CUTLASS
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm_array.h"
#include "cutlass/gemm/device/gemm_batched.h"

cudaError_t cutlass_array_sgemm(
    int m,
    int n,
    int k,
    float alpha,
    float const * const *A,
    int lda,
    float const * const *B,
    int ldb,
    float * const *C,
    int ldc,
    float beta,
    int batch_count) {

  using Gemm = cutlass::gemm::device::GemmArray<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor
  >;

  Gemm gemm_op;

  cutlass::Status status = gemm_op({
    {m, n, k},
    A, lda,
    B, ldb,
    C, ldc,
    C, ldc,
    {alpha, beta},
    batch_count
  });

  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  return cudaSuccess;
}
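For checking results once the extension does build, here is a minimal pure-Python sketch of what I understand the call above to compute: for each batch entry, C = alpha * A * B + beta * C, with all matrices stored column-major and C used as both source and destination. The function name batched_gemm_reference is my own and not part of CUTLASS; the batches are plain Python lists standing in for the device pointer arrays.

```python
def batched_gemm_reference(m, n, k, alpha, A, lda, B, ldb, C, ldc, beta):
    """Update every C[b] in place: C = alpha * A @ B + beta * C.

    A, B, C are lists of flat lists (one per batch entry), column-major,
    so element (row, col) of a matrix with leading dimension ld lives
    at index row + col * ld.
    """
    for a, b, c in zip(A, B, C):
        for col in range(n):
            for row in range(m):
                acc = 0.0
                for p in range(k):
                    acc += a[row + p * lda] * b[p + col * ldb]
                c[row + col * ldc] = alpha * acc + beta * c[row + col * ldc]
    return C
```

This is only meant for comparing small outputs by hand; it is far too slow for anything else.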
I run MAX_JOBS=4 python setup.py develop to compile and get the following output:
build.txt (173.8 KB)
I would like to point out that this example code is taken from cutlass/examples/05_batched_gemm/batched_gemm.cu at main · NVIDIA/cutlass (github.com).
The errors seem to indicate some type problem originating from CUTLASS v3.4.1, and I am not sure how to proceed.
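One thing that may be worth ruling out: the CUTLASS 3.x README states that it requires a C++17 host compiler, and the setup.py above passes no language-standard flags of its own. Below is a minimal sketch of a helper that builds the keyword arguments for CUDAExtension with those flags set explicitly. The helper name cutlass_extension_kwargs and the CUTLASS_PATH environment variable are my own inventions, and I have not verified whether BuildExtension already adds equivalent flags on Windows, so this is a guess at a fix rather than a confirmed one.

```python
import os


def cutlass_extension_kwargs(cutlass_root, is_windows=(os.name == "nt")):
    """Return include_dirs and extra_compile_args for a CUTLASS-based extension.

    cutlass_root is the path to a CUTLASS checkout (hypothetical layout:
    headers under <cutlass_root>/include).
    """
    include_dir = os.path.join(cutlass_root, "include")
    if is_windows:
        # MSVC: request C++17 and make __cplusplus report the real standard.
        cxx_flags = ["/std:c++17", "/Zc:__cplusplus"]
        # nvcc: request C++17 and forward the MSVC switch to the host compiler.
        nvcc_flags = ["-std=c++17", "-Xcompiler", "/Zc:__cplusplus"]
    else:
        cxx_flags = ["-std=c++17"]
        nvcc_flags = ["-std=c++17"]
    return {
        "include_dirs": [include_dir],
        "extra_compile_args": {"cxx": cxx_flags, "nvcc": nvcc_flags},
    }


# Intended use inside setup.py (assumes CUTLASS_PATH is set by the user):
# CUDAExtension(
#     name="gorby_sdpa",
#     sources=["gorby_sdpa.cu"],
#     **cutlass_extension_kwargs(os.environ["CUTLASS_PATH"]),
# )
```

Reading the checkout location from an environment variable instead of hard-coding it also keeps the build script portable between machines.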
Here is my system information in case it is helpful:
- Windows Version 10.0.19045 Build 19045
- GPU: NVIDIA RTX 3090 Ti
- CUDA 11.8
- Compiler: MSVC 2019 14.29.30133
- PyTorch 2.2.2+cu118