Compilation error when using dynamic shared memory

I’m trying to implement a custom CUDA operator and expose it to Python with pybind11. Here is my code:
flashattention_cuda_001.cu

#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>


// flash attention version 1 is parallel over the head_num and batch size dimension
// kernel should be launched with config <<<([head_num]), (BLOCKSIZE * BLOCKSIZE)>>>
template<class realtype>
__global__ void flashattention_kernel(realtype *Q, realtype *K, realtype *V, realtype *O, const int seq_len, const int head_size, const int BLOCKSIZE)
{
    extern __shared__ realtype share_mem[];
    realtype *QBlock = share_mem;
    realtype *KBlock = share_mem + BLOCKSIZE * head_size;
    realtype *VBlock = share_mem + BLOCKSIZE * head_size * 2;
    realtype *TBlock = share_mem + BLOCKSIZE * head_size * 3;

    realtype * Q_dst = Q;

    realtype rowmax_old[32] = {0.0}, rowmax_new[32] = {0.0};
    realtype rowsum_old[32] = {0.0}, rowsum_new[32] = {0.0};

    int head_id = blockIdx.x;

    
    int row_threadblock = threadIdx.x / BLOCKSIZE;
    int col_threadblock = threadIdx.x % BLOCKSIZE;
    
    for(int j = 0; j < seq_len; j += BLOCKSIZE)
    {
        // load K, V into KBlock, VBlock (head_id should be taken into consideration)
        for(int offset = 0; offset < head_size; offset+=BLOCKSIZE)
        {
            KBlock[row_threadblock * head_size + col_threadblock + offset] = K[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
            VBlock[row_threadblock * head_size + col_threadblock + offset] = V[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
        }
        for(int i = 0; i < seq_len / BLOCKSIZE; i++)
        {
            // load Q into QBlock (head_id should be taken into consideration)
            for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
            {
                QBlock[row_threadblock * head_size + col_threadblock + offset] = Q[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
            }

            // matmul Q and K.T and store the result into TBlock
            // non-coherent access to shared memory?
            realtype tmp = 0.0;
            for(int k = 0; k < head_size; k++)
            {
                tmp += QBlock[row_threadblock * head_size + k] * KBlock[col_threadblock * head_size + k];
            }
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] = tmp / sqrtf(_Float32(seq_len));
        
            // calculate row max
            rowmax_new[i] = -6666.66;
            for(int idx = 0; idx < BLOCKSIZE; idx++)
            {
                if(TBlock[row_threadblock * BLOCKSIZE + idx] >= rowmax_new[i])
                {
                    rowmax_new[i] = TBlock[row_threadblock * BLOCKSIZE + idx];
                }
            }
            if(rowmax_old[i] >= rowmax_new[i] && j != 0)
            {
                rowmax_new[i] = rowmax_old[i];
            }
            // 
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] = exp(TBlock[row_threadblock * BLOCKSIZE + col_threadblock] - rowmax_new[i]);
            rowsum_new[i] = rowsum_old[i] * exp(rowmax_old[i] - rowmax_new[i]);
            for(int idx = 0; idx < BLOCKSIZE; idx++)
            {
                rowsum_new[i] += TBlock[row_threadblock * BLOCKSIZE + idx];
            }
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] /= rowsum_new[i];
            // calculate OBlock
            for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
            {
                tmp = 0.0;
                for(int sumidx = 0; sumidx < BLOCKSIZE; sumidx++)
                {
                    tmp += TBlock[row_threadblock * BLOCKSIZE + sumidx] * VBlock[sumidx * head_size + col_threadblock + offset];
                }
                
                O[(row_threadblock + i * BLOCKSIZE + head_id * seq_len) * head_size + col_threadblock + offset] = (rowsum_old[i] / rowsum_new[i]) * exp(rowmax_old[i] - rowmax_new[i]) * O[(row_threadblock + i * BLOCKSIZE + head_id * seq_len) * head_size + col_threadblock + offset] + tmp;
            }
            
            // update rowmax_old and rowsum_old
            rowmax_old[i] = rowmax_new[i];
            rowsum_old[i] = rowsum_new[i];

            // advance Q
            Q += (BLOCKSIZE * head_size);
        }

        // put Q_ptr back to the original position
        Q = Q_dst;

        // advance K and V
        K += (BLOCKSIZE * head_size);
        V += (BLOCKSIZE * head_size);
    }
}

// Q.shape = (num_heads, seq_len, head_size)
torch::Tensor flashattention_cuda_001(torch::Tensor Q, torch::Tensor K, torch::Tensor V)
{
    torch::Tensor output = torch::empty_like(Q);
    const int BLOCKSIZE = 32;
    int num_heads = Q.sizes()[0];
    int seq_len = Q.sizes()[1];
    int head_size = Q.sizes()[2];

    int shared_mem_size = 64 + 4 * (BLOCKSIZE * head_size * 3 + BLOCKSIZE * BLOCKSIZE);

    AT_DISPATCH_FLOATING_TYPES(Q.type(), "flashattention_cuda_001", [&]{
        flashattention_kernel<scalar_t><<<num_heads, BLOCKSIZE * BLOCKSIZE, shared_mem_size>>>(
            Q.data_ptr<scalar_t>(),
            K.data_ptr<scalar_t>(),
            V.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            seq_len, head_size, BLOCKSIZE
        );
    });
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("flashattention_cuda_001", &flashattention_cuda_001, "Flash attention cuda version 0.01");
}

setup.py

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name = 'flashattention_cuda',
    ext_modules = [
        CUDAExtension('flashattention_cuda', [
            'flashattention_cuda_001.cu',
        ]),
    ],
    cmdclass = {
        'build_ext': BuildExtension
    }
)

As you can see, I use dynamic shared memory at the beginning of the kernel function. When I run python setup.py install I get this error:

FAILED: /home/zgt/cudac/flashattention/build/temp.linux-x86_64-3.10/flashattention_cuda_001.o 
/usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /home/zgt/cudac/flashattention/build/temp.linux-x86_64-3.10/flashattention_cuda_001.o.d -I/home/zgt/venv/cuda/lib/python3.10/site-packages/torch/include -I/home/zgt/venv/cuda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/zgt/venv/cuda/lib/python3.10/site-packages/torch/include/TH -I/home/zgt/venv/cuda/lib/python3.10/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zgt/venv/cuda/include -I/usr/include/python3.10 -c -c /home/zgt/cudac/flashattention/flashattention_cuda_001.cu -o /home/zgt/cudac/flashattention/build/temp.linux-x86_64-3.10/flashattention_cuda_001.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flashattention_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_70,code=compute_70 -gencode=arch=compute_70,code=sm_70 -std=c++17
/home/zgt/cudac/flashattention/flashattention_cuda_001.cu(11): error: declaration is incompatible with previous "share_mem" (declared at line 11)
      extern __attribute__((shared)) realtype share_mem[];
                                              ^
          detected during instantiation of "void flashattention_kernel(realtype *, realtype *, realtype *, realtype *, int, int, int) [with realtype=float]" at line 114

/home/zgt/cudac/flashattention/flashattention_cuda_001.cu(11): warning #20042-D: a host variable("share_mem") redeclared with __shared__
      extern __attribute__((shared)) realtype share_mem[];
                                              ^
          detected during instantiation of "void flashattention_kernel(realtype *, realtype *, realtype *, realtype *, int, int, int) [with realtype=float]" at line 114

Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"

According to the error information, I redeclared a host variable share_mem. Is this the wrong way to declare dynamic shared memory? Please tell me how to fix this, thank you so much.

Multiple instantiations of a (e.g. templated) kernel with “similar” extern shared declarations can become problematic. I’m not suggesting this is identical to your problem; they are obviously somewhat different. But it outlines the general issue.

I don’t know exactly why you are getting multiple instantiations, but it’s clear from the error message that you are.

Perhaps your compilation process is instantiating it with a template argument of float (that is obviously the case) and also a template argument of some other type. I haven’t unpacked the torch/pybind stuff enough to determine any of that.
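
(For what it’s worth, AT_DISPATCH_FLOATING_TYPES is a likely source of the second instantiation: it roughly expands to a switch over the floating-point dtypes and compiles the lambda body once per type, something like the sketch below, so the kernel template gets stamped out for both float and double in the same translation unit.)

// rough sketch of what AT_DISPATCH_FLOATING_TYPES(Q.type(), ..., [&]{ ... }) expands to
switch (Q.scalar_type()) {
    case at::ScalarType::Double: { using scalar_t = double; /* lambda body */ break; }
    case at::ScalarType::Float:  { using scalar_t = float;  /* lambda body */ break; }
    default: /* error: unsupported dtype */ ;
}
// => flashattention_kernel<double> and flashattention_kernel<float> are both instantiated,
//    each containing its own "extern __shared__ realtype share_mem[]" with a different type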

I guess I would suggest trying this:

extern __shared__ unsigned char my_smem[];
realtype *share_mem = (realtype *)my_smem;

and see if that makes the error disappear. If it does, you are barking up the right tree.

If that makes the error go away then, if it were me, I would probably just proceed with that. However, it raises questions around alignment and type-punning; if those are a concern for you, you would need to come up with another realization.
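
For concreteness, dropped into the kernel in the question the workaround would look something like this (untested sketch; the rest of the kernel body stays the same):

template<class realtype>
__global__ void flashattention_kernel(realtype *Q, realtype *K, realtype *V, realtype *O, const int seq_len, const int head_size, const int BLOCKSIZE)
{
    // one untyped dynamic shared buffer; the symbol is the same for every
    // instantiation of the template, so there is no conflicting redeclaration
    extern __shared__ unsigned char my_smem[];
    realtype *share_mem = reinterpret_cast<realtype *>(my_smem);

    realtype *QBlock = share_mem;
    realtype *KBlock = share_mem + BLOCKSIZE * head_size;
    realtype *VBlock = share_mem + BLOCKSIZE * head_size * 2;
    realtype *TBlock = share_mem + BLOCKSIZE * head_size * 3;
    // ... rest of the kernel unchanged ...
}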

EDIT: For the benefit of future readers in a time crunch, I have edited my response here to remove an errant usage of __shared__ which is mentioned below.

Another method to avoid this issue is demonstrated in the CUDA sample codes, e.g. here.
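
For reference, the pattern in the samples is roughly this (paraphrased from the reduction sample, not verbatim): a small templated helper wraps the extern declaration, with a specialization per element type so each type gets its own distinctly named extern array.

// generic case: reuse a single int-typed extern array and cast it
template <class T>
struct SharedMemory {
    __device__ T *getPointer() {
        extern __shared__ int __smem[];
        return (T *)__smem;
    }
};

// double gets its own, properly typed extern array, so the float and
// double instantiations never redeclare the same symbol
template <>
struct SharedMemory<double> {
    __device__ double *getPointer() {
        extern __shared__ double __smem_d[];
        return __smem_d;
    }
};

// inside a templated kernel:
//   realtype *share_mem = SharedMemory<realtype>().getPointer();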

I managed to fix this shared memory issue thanks to your advice. Here is what I did.
At first I tried:

extern __shared__ unsigned char my_smem[];
__shared__ realtype *share_mem = (realtype *)my_smem;

which raised an error:

error: initializer not allowed for __shared__ variable 
      __attribute__((shared)) realtype *share_mem = (realtype *)my_smem;
                                        ^

Obviously you can’t initialize a __shared__ variable. So I just removed the __shared__ attribute:

extern __shared__ unsigned char my_smem[];
realtype *share_mem = (realtype *)my_smem;

which compiled successfully. Anyway, thank you for your help.

Yes, I should not have included the __shared__ decorator on that line. That would have created a pointer that lived in (statically allocated) shared memory, which is not what’s wanted (and you can’t initialize it…)
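
In other words, the distinction is roughly this (sketch, inside the templated kernel where realtype is the template parameter):

extern __shared__ unsigned char my_smem[];   // the dynamic shared buffer itself

realtype *share_mem = (realtype *)my_smem;   // ordinary per-thread pointer (register/local)
                                             // pointing into shared memory -- what's wanted

__shared__ realtype *share_mem_bad;          // a pointer that itself lives in statically
                                             // allocated shared memory; it cannot have an
                                             // initializer, and isn't what's needed here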