I'm trying to implement a custom CUDA operator and expose it to Python with pybind11. Here is my code:
flashattention_cuda_001.cu
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
// flash attention version 1 is parallel over the head_num and batch size dimension
// kernel should be launched with config <<<([head_num]), (BLOCKSIZE * BLOCKSIZE)>>>
template<class realtype>
__global__ void flashattention_kernel(realtype *Q, realtype *K, realtype *V, realtype *O, const int seq_len, const int head_size, const int BLOCKSIZE)
{
    extern __shared__ realtype share_mem[];
    realtype *QBlock = share_mem;
    realtype *KBlock = share_mem + BLOCKSIZE * head_size;
    realtype *VBlock = share_mem + BLOCKSIZE * head_size * 2;
    realtype *TBlock = share_mem + BLOCKSIZE * head_size * 3;
    realtype *Q_dst = Q;
    realtype rowmax_old[32] = {0.0}, rowmax_new[32] = {0.0};
    realtype rowsum_old[32] = {0.0}, rowsum_new[32] = {0.0};
    int head_id = blockIdx.x;
    int row_threadblock = threadIdx.x / BLOCKSIZE;
    int col_threadblock = threadIdx.x % BLOCKSIZE;
    for(int j = 0; j < seq_len; j += BLOCKSIZE)
    {
        // load K, V into KBlock, VBlock (head_id should be taken into consideration)
        for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
        {
            KBlock[row_threadblock * head_size + col_threadblock + offset] = K[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
            VBlock[row_threadblock * head_size + col_threadblock + offset] = V[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
        }
        for(int i = 0; i < seq_len / BLOCKSIZE; i++)
        {
            // load Q into QBlock (head_id should be taken into consideration)
            for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
            {
                QBlock[row_threadblock * head_size + col_threadblock + offset] = Q[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
            }
            // matmul Q and K.T and store the result into TBlock
            // non-coherent access to shared memory?
            realtype tmp = 0.0;
            for(int k = 0; k < head_size; k++)
            {
                tmp += QBlock[row_threadblock * head_size + k] * KBlock[col_threadblock * head_size + k];
            }
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] = tmp / sqrtf(_Float32(seq_len));
            // calculate row max
            rowmax_new[i] = -6666.66;
            for(int idx = 0; idx < BLOCKSIZE; idx++)
            {
                if(TBlock[row_threadblock * BLOCKSIZE + idx] >= rowmax_new[i])
                {
                    rowmax_new[i] = TBlock[row_threadblock * BLOCKSIZE + idx];
                }
            }
            if(rowmax_old[i] >= rowmax_new[i] && j != 0)
            {
                rowmax_new[i] = rowmax_old[i];
            }
            //
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] = exp(TBlock[row_threadblock * BLOCKSIZE + col_threadblock] - rowmax_new[i]);
            rowsum_new[i] = rowsum_old[i] * exp(rowmax_old[i] - rowmax_new[i]);
            for(int idx = 0; idx < BLOCKSIZE; idx++)
            {
                rowsum_new[i] += TBlock[row_threadblock * BLOCKSIZE + idx];
            }
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] /= rowsum_new[i];
            // calculate OBlock
            for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
            {
                tmp = 0.0;
                for(int sumidx = 0; sumidx < BLOCKSIZE; sumidx++)
                {
                    tmp += TBlock[row_threadblock * BLOCKSIZE + sumidx] * VBlock[sumidx * head_size + col_threadblock + offset];
                }
                O[(row_threadblock + i * BLOCKSIZE + head_id * seq_len) * head_size + col_threadblock + offset] = (rowsum_old[i] / rowsum_new[i]) * exp(rowmax_old[i] - rowmax_new[i]) * O[(row_threadblock + i * BLOCKSIZE + head_id * seq_len) * head_size + col_threadblock + offset] + tmp;
            }
            // update rowmax_old and rowsum_old
            rowmax_old[i] = rowmax_new[i];
            rowsum_old[i] = rowsum_new[i];
            // advance Q
            Q += (BLOCKSIZE * head_size);
        }
        // put Q_ptr back to the original position
        Q = Q_dst;
        // advance K and V
        K += (BLOCKSIZE * head_size);
        V += (BLOCKSIZE * head_size);
    }
}
// Q.shape = (num_heads, seq_len, head_size)
torch::Tensor flashattention_cuda_001(torch::Tensor Q, torch::Tensor K, torch::Tensor V)
{
    torch::Tensor output = torch::empty_like(Q);
    const int BLOCKSIZE = 32;
    int num_heads = Q.sizes()[0];
    int seq_len = Q.sizes()[1];
    int head_size = Q.sizes()[2];
    int shared_mem_size = 64 + 4 * (BLOCKSIZE * head_size * 3 + BLOCKSIZE * BLOCKSIZE);
    AT_DISPATCH_FLOATING_TYPES(Q.type(), "flashattention_cuda_001", [&]{
        flashattention_kernel<scalar_t><<<num_heads, BLOCKSIZE * BLOCKSIZE, shared_mem_size>>>(
            Q.data_ptr<scalar_t>(),
            K.data_ptr<scalar_t>(),
            V.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            seq_len, head_size, BLOCKSIZE
        );
    });
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("flashattention_cuda_001", &flashattention_cuda_001, "Flash attention cuda version 0.01");
}
setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
    name = 'flashattention_cuda',
    ext_modules = [
        CUDAExtension('flashattention_cuda', [
            'flashattention_cuda_001.cu',
        ]),
    ],
    cmdclass = {
        'build_ext': BuildExtension
    }
)
As you can see, I declare dynamic shared memory at the beginning of the kernel function. When I run python setup.py install, I get the following error:
FAILED: /home/zgt/cudac/flashattention/build/temp.linux-x86_64-3.10/flashattention_cuda_001.o
/usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /home/zgt/cudac/flashattention/build/temp.linux-x86_64-3.10/flashattention_cuda_001.o.d -I/home/zgt/venv/cuda/lib/python3.10/site-packages/torch/include -I/home/zgt/venv/cuda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/zgt/venv/cuda/lib/python3.10/site-packages/torch/include/TH -I/home/zgt/venv/cuda/lib/python3.10/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zgt/venv/cuda/include -I/usr/include/python3.10 -c -c /home/zgt/cudac/flashattention/flashattention_cuda_001.cu -o /home/zgt/cudac/flashattention/build/temp.linux-x86_64-3.10/flashattention_cuda_001.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flashattention_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_70,code=compute_70 -gencode=arch=compute_70,code=sm_70 -std=c++17
/home/zgt/cudac/flashattention/flashattention_cuda_001.cu(11): error: declaration is incompatible with previous "share_mem" (declared at line 11)
extern __attribute__((shared)) realtype share_mem[];
^
detected during instantiation of "void flashattention_kernel(realtype *, realtype *, realtype *, realtype *, int, int, int) [with realtype=float]" at line 114
/home/zgt/cudac/flashattention/flashattention_cuda_001.cu(11): warning #20042-D: a host variable("share_mem") redeclared with __shared__
extern __attribute__((shared)) realtype share_mem[];
^
detected during instantiation of "void flashattention_kernel(realtype *, realtype *, realtype *, realtype *, int, int, int) [with realtype=float]" at line 114
Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"
According to the error information, I have redeclared a host variable share_mem. Is that the wrong way to declare dynamic shared memory? Please tell me how to fix this, thank you so much.
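In case it helps narrow things down, here is the smallest standalone sketch I could come up with that seems to hit the same diagnostic. My guess at the cause (please correct me if I'm wrong): AT_DISPATCH_FLOATING_TYPES compiles the kernel for both float and double in one translation unit, so the same extern __shared__ name ends up declared with two different element types. The file and kernel names below are just for illustration.
minimal_repro.cu
#include <cuda_runtime.h>

// Same pattern as share_mem in flashattention_kernel: an extern __shared__
// array whose element type is the template parameter.
template<class T>
__global__ void copy_kernel(const T *in, T *out, int n)
{
    extern __shared__ T smem[];
    int i = threadIdx.x;
    if(i < n)
    {
        smem[i] = in[i];
        out[i] = smem[i];
    }
}

int main()
{
    float *f = nullptr;
    double *d = nullptr;
    // Launching with two element types forces two instantiations of the
    // kernel, and with them two declarations of smem with different types,
    // which appears to be where nvcc reports
    // "declaration is incompatible with previous".
    copy_kernel<float><<<1, 32, 32 * sizeof(float)>>>(f, f, 32);
    copy_kernel<double><<<1, 32, 32 * sizeof(double)>>>(d, d, 32);
    return cudaDeviceSynchronize();
}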