"invalid argument" while calling a customized cuda kernel function

I'm trying to implement an attention kernel in CUDA C++. My kernel function is shown below:

typedef float realtype;

// flash attention version 1 is parallel over the head_num and batch size dimension
// kernel should be launched with config <<<([head_num]), (BLOCKSIZE * BLOCKSIZE)>>>
template<const int BLOCKSIZE, const int SEQLEN, const int HSIZE>
__global__ void flashattention_kernel(realtype *Q, realtype *K, realtype *V, realtype *O, int seq_len, int head_size)
{
    __shared__ realtype QBlock[BLOCKSIZE * HSIZE];
    __shared__ realtype KBlock[BLOCKSIZE * HSIZE];
    __shared__ realtype VBlock[BLOCKSIZE * HSIZE];
    __shared__ realtype TBlock[BLOCKSIZE * BLOCKSIZE];

    realtype * Q_dst = Q;

    realtype rowmax_old[SEQLEN / BLOCKSIZE] = {0.0}, rowmax_new[SEQLEN / BLOCKSIZE] = {0.0};
    realtype rowsum_old[SEQLEN / BLOCKSIZE] = {0.0}, rowsum_new[SEQLEN / BLOCKSIZE] = {0.0};

    int head_id = blockIdx.x;

    
    int row_threadblock = threadIdx.x / BLOCKSIZE;
    int col_threadblock = threadIdx.x % BLOCKSIZE;
    
    for(int j = 0; j < seq_len; j += BLOCKSIZE)
    {
        // load K, V into KBlock, VBlock (head_id should be taken into consideration)
        for(int offset = 0; offset < head_size; offset+=BLOCKSIZE)
        {
            KBlock[row_threadblock * head_size + col_threadblock + offset] = K[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
            VBlock[row_threadblock * head_size + col_threadblock + offset] = V[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
        }
        for(int i = 0; i < SEQLEN / BLOCKSIZE; i++)
        {
            // load Q into QBlock (head_id should be taken into consideration)
            for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
            {
                QBlock[row_threadblock * head_size + col_threadblock + offset] = Q[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
            }

            // matmul Q and K.T and store the result into TBlock
            // non-coherent access to shared memory?
            realtype tmp = 0.0;
            for(int k = 0; k < head_size; k++)
            {
                tmp += QBlock[row_threadblock * head_size + k] * KBlock[col_threadblock * head_size + k];
            }
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] = tmp / sqrtf(_Float32(seq_len));
            // printf("%d %d, %f\n", row_threadblock, col_threadblock, TBlock[row_threadblock * BLOCKSIZE + col_threadblock]);

            // calculate row max
            rowmax_new[i] = -6666.66;
            for(int idx = 0; idx < BLOCKSIZE; idx++)
            {
                if(TBlock[row_threadblock * BLOCKSIZE + idx] >= rowmax_new[i])
                {
                    rowmax_new[i] = TBlock[row_threadblock * BLOCKSIZE + idx];
                }
            }
            if(rowmax_old[i] >= rowmax_new[i] && j != 0)
            {
                rowmax_new[i] = rowmax_old[i];
            }
            // 
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] = exp(TBlock[row_threadblock * BLOCKSIZE + col_threadblock] - rowmax_new[i]);
            rowsum_new[i] = rowsum_old[i] * exp(rowmax_old[i] - rowmax_new[i]);
            // printf("%d %d, %f, %f, %f, %f, %f\n", row_threadblock, col_threadblock, rowmax_new[i], rowmax_old[i], rowsum_new[i], rowsum_old[i], TBlock[row_threadblock * BLOCKSIZE + col_threadblock]);
            for(int idx = 0; idx < BLOCKSIZE; idx++)
            {
                rowsum_new[i] += TBlock[row_threadblock * BLOCKSIZE + idx];
            }
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] /= rowsum_new[i];
            // printf("%d %d, %f, %f, %f, %f, %f\n", row_threadblock, col_threadblock, rowmax_new[i], rowmax_old[i], rowsum_new[i], rowsum_old[i], TBlock[row_threadblock * BLOCKSIZE + col_threadblock]);

            // calculate OBlock
            for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
            {
                tmp = 0.0;
                for(int sumidx = 0; sumidx < BLOCKSIZE; sumidx++)
                {
                    tmp += TBlock[row_threadblock * BLOCKSIZE + sumidx] * VBlock[sumidx * head_size + col_threadblock + offset];
                }
                
                O[(row_threadblock + i * BLOCKSIZE + head_id * seq_len) * head_size + col_threadblock + offset] = (rowsum_old[i] / rowsum_new[i]) * exp(rowmax_old[i] - rowmax_new[i]) * O[(row_threadblock + i * BLOCKSIZE + head_id * seq_len) * head_size + col_threadblock + offset] + tmp;
            }
            
            // update rowmax_old and rowsum_old
            rowmax_old[i] = rowmax_new[i];
            rowsum_old[i] = rowsum_new[i];

            // advance Q
            Q += (BLOCKSIZE * head_size);
        }

        // put Q_ptr back to the original position
        Q = Q_dst;

        // advance K and V
        K += (BLOCKSIZE * head_size);
        V += (BLOCKSIZE * head_size);
    }
}

And here is main.cu:

int main()
{
    const int seq_len =1024, head_size = 384, head_num = 1;
    cudaError_t err;

    realtype *Q, *K, *V, *S, *O, *O_CUDA;
    realtype *DEVICE_Q, *DEVICE_K, *DEVICE_V, *DEVICE_O;

    Q = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    K = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    V = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    S = (realtype *)malloc(sizeof(realtype) * seq_len * seq_len * head_num);
    O = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    O_CUDA = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);

    generateRandomMatrix(Q, seq_len * head_num, head_size, 0);
    generateRandomMatrix(K, seq_len * head_num, head_size, 1);
    generateRandomMatrix(V, seq_len * head_num, head_size, 2);
    

    CHECK(cudaMalloc(&DEVICE_Q, sizeof(realtype) * seq_len * head_size * head_num));
    CHECK(cudaMalloc(&DEVICE_K, sizeof(realtype) * seq_len * head_size * head_num));
    CHECK(cudaMalloc(&DEVICE_V, sizeof(realtype) * seq_len * head_size * head_num));
    CHECK(cudaMalloc(&DEVICE_O, sizeof(realtype) * seq_len * head_size * head_num));

    CHECK(cudaMemcpy(DEVICE_Q, Q, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(DEVICE_K, K, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(DEVICE_V, V, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyHostToDevice));
    

    cudaEvent_t start_event, stop_event;
    float elapsed_time = 0.0f;
    cudaEventCreate(&start_event);
    cudaEventCreate(&stop_event);


    //call cpu attention and record elapsed time
    cudaEventRecord(start_event);

    for(int i = 0; i < head_num; i++)
    {
        int offset1 = seq_len * head_size * i;
        int offset2 = seq_len * seq_len * i;
        cpu_attention(Q + offset1, K + offset1, V + offset1, S + offset2, O + offset1, seq_len, head_size);
    }
    

    cudaEventRecord(stop_event);
    cudaEventSynchronize(stop_event);
    cudaEventElapsedTime(&elapsed_time, start_event, stop_event);
    printf("Elapsed time in cpu function: %.4fms\n", elapsed_time);

    //call flash attention kernel and record elapsed time
    cudaEventRecord(start_event);


    err = cudaGetLastError();
    if (err != cudaSuccess) { 
        std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl; 
        exit(EXIT_FAILURE); 
    } 

    const int BLOCKSIZE = 8;
    flashattention_kernel<BLOCKSIZE, seq_len, head_size><<<head_num, dim3(BLOCKSIZE * BLOCKSIZE)>>>(DEVICE_Q, DEVICE_K, DEVICE_V, DEVICE_O, seq_len, head_size);

    err = cudaGetLastError();
    if (err != cudaSuccess) { 
        std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl; 
        exit(EXIT_FAILURE); 
    } 


    cudaDeviceSynchronize();
    cudaEventRecord(stop_event);
    cudaEventSynchronize(stop_event);
    cudaEventElapsedTime(&elapsed_time, start_event, stop_event);
    printf("Elapsed time in cuda kernel: %.4fms\n", elapsed_time);

    cudaMemcpy(O_CUDA, DEVICE_O, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyDeviceToHost);

    valMatrix(O, O_CUDA, seq_len * head_num, head_size);

    cudaFree(DEVICE_Q);
    cudaFree(DEVICE_K);
    cudaFree(DEVICE_V);
    cudaFree(DEVICE_O);

    free(Q);
    free(K);
    free(V);
    free(S);
    free(O);
    free(O_CUDA);
    return 0;
}

Functions like generateRandomMatrix generate a random matrix, and valMatrix verifies whether two matrices are identical. They are host functions, not kernels, and work well in other code; I don't think they cause the error, so I won't bother you with them. The thing is, if I launch the kernel with BLOCKSIZE=16 I get CUDA error: invalid argument at main.cu:78:

int main()
{
    const int seq_len =1024, head_size = 384, head_num = 1;
    cudaError_t err;

    realtype *Q, *K, *V, *S, *O, *O_CUDA;
    realtype *DEVICE_Q, *DEVICE_K, *DEVICE_V, *DEVICE_O;

   ...

    const int BLOCKSIZE = 16;
    flashattention_kernel<BLOCKSIZE, seq_len, head_size><<<head_num, dim3(BLOCKSIZE * BLOCKSIZE)>>>(DEVICE_Q, DEVICE_K, DEVICE_V, DEVICE_O, seq_len, head_size);

    err = cudaGetLastError();
    if (err != cudaSuccess) { 
        std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl; 
        exit(EXIT_FAILURE); 
    } 
    ...
}

BTW, this error appears at runtime, not at compilation, and only when I set BLOCKSIZE=16. I am positive the error comes from flashattention_kernel because of the cudaGetLastError() call right after the launch. If I set BLOCKSIZE=8 or head_size=96, there is no error and the CPU and CUDA results are identical.
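In case it helps, here is a small check I can drop into main() right before the launch (a sketch using the standard cudaFuncGetAttributes and cudaDeviceGetAttribute runtime calls) to compare what the BLOCKSIZE=16 instantiation requests per block against the device's default limits:

// Sketch: query the resource requirements of the failing instantiation and the
// device's default per-block shared-memory limit.
cudaFuncAttributes attr;
CHECK(cudaFuncGetAttributes(&attr, flashattention_kernel<16, 1024, 384>));

int max_smem_per_block = 0;
CHECK(cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlock, 0));

printf("static shared memory required by the kernel: %zu bytes\n", attr.sharedSizeBytes);
printf("max threads per block for this kernel      : %d\n", attr.maxThreadsPerBlock);
printf("default shared memory limit per block      : %d bytes\n", max_smem_per_block);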
For your information, my device is a Tesla V100 with 16 GB of memory and the compiler is nvcc; see nvcc -V below.

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0

Any help will be appreciated!

Please provide a complete reproducer that can be compiled and executed.

utils.cuh

#ifndef UTILS4CUDA
#define UTILS4CUDA
#include <cuda.h>
#include <cstdlib>
#include <cmath>
#include <iostream>
#include <stdio.h>

#define CHECK(err) { \
    cudaError_t e = err; \
    if (e != cudaSuccess) { \
        std::cerr << "CUDA error: " << cudaGetErrorString(e) << " at " << __FILE__ << ":" << __LINE__ << std::endl; \
        exit(EXIT_FAILURE); \
    } \
}

#define CDIV(a, b) (((a) + (b) - 1) / (b))

typedef float realtype;

void generateRandomMatrix(realtype *A, int m, int n, long long seed = 0)
{
    srand(seed);
    for(int i = 0; i < m; i++)
    {
        for(int j = 0; j < n; j++)
        {
            A[i * n + j] = float(rand() % 1000) / 1000;
        }
    }
}

void onesMatrix(realtype *A, int m, int n, realtype value = 1.0)
{
    for(int i = 0; i < m * n; i++)
    {
        A[i] = value;
    }
}

void print_matrix(realtype *A, int m, int n)
{
    printf("\n=================================================================\n");
    for(int i = 0; i < m; i++)
    {
        for(int j = 0; j < n; j++)
        {
            printf("%f, ", A[i * n + j]);
        }
        printf("\n");
    }
    printf("=================================================================\n");
}

void valMatrix(realtype *A, realtype *B, int m, int n)
{
    // float err = 0.0f;
    for(int i = 0; i < m; i++)
    {
        for(int j = 0; j < n; j++)
        {
            if(fabs(A[i * n + j] - B[i * n + j]) >= 1e-4)
            {
                printf("Difference occurred at index [%d, %d]\n", i, j);
                return ;
            }
        }
    }
    printf("Matrix Vertified\n");
}

#endif

fakernel.cuh

#ifndef FLASHATTENTION_KERNEL_CUH_
#define FLASHATTENTION_KERNEL_CUH_

#include <cuda.h>

typedef float realtype;

// flash attention version 1 is parallel over the head_num and batch size dimension
// kernel should be launched with config <<<([head_num]), (BLOCKSIZE * BLOCKSIZE)>>>
template<const int BLOCKSIZE, const int SEQLEN, const int HSIZE>
__global__ void flashattention_kernel(realtype *Q, realtype *K, realtype *V, realtype *O, int seq_len, int head_size)
{
    __shared__ realtype QBlock[BLOCKSIZE * HSIZE];
    __shared__ realtype KBlock[BLOCKSIZE * HSIZE];
    __shared__ realtype VBlock[BLOCKSIZE * HSIZE];
    __shared__ realtype TBlock[BLOCKSIZE * BLOCKSIZE];

    realtype * Q_dst = Q;

    realtype rowmax_old[SEQLEN / BLOCKSIZE] = {0.0}, rowmax_new[SEQLEN / BLOCKSIZE] = {0.0};
    realtype rowsum_old[SEQLEN / BLOCKSIZE] = {0.0}, rowsum_new[SEQLEN / BLOCKSIZE] = {0.0};

    int head_id = blockIdx.x;

    
    int row_threadblock = threadIdx.x / BLOCKSIZE;
    int col_threadblock = threadIdx.x % BLOCKSIZE;
    
    for(int j = 0; j < seq_len; j += BLOCKSIZE)
    {
        // load K, V into KBlock, VBlock (head_id should be taken into consideration)
        for(int offset = 0; offset < head_size; offset+=BLOCKSIZE)
        {
            KBlock[row_threadblock * head_size + col_threadblock + offset] = K[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
            VBlock[row_threadblock * head_size + col_threadblock + offset] = V[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
        }
        for(int i = 0; i < SEQLEN / BLOCKSIZE; i++)
        {
            // load Q into QBlock (head_id should be taken into consideration)
            for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
            {
                QBlock[row_threadblock * head_size + col_threadblock + offset] = Q[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
            }

            // matmul Q and K.T and store the result into TBlock
            // non-coherent access to shared memory?
            realtype tmp = 0.0;
            for(int k = 0; k < head_size; k++)
            {
                tmp += QBlock[row_threadblock * head_size + k] * KBlock[col_threadblock * head_size + k];
            }
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] = tmp / sqrtf(_Float32(seq_len));
            // printf("%d %d, %f\n", row_threadblock, col_threadblock, TBlock[row_threadblock * BLOCKSIZE + col_threadblock]);

            // calculate row max
            rowmax_new[i] = -6666.66;
            for(int idx = 0; idx < BLOCKSIZE; idx++)
            {
                if(TBlock[row_threadblock * BLOCKSIZE + idx] >= rowmax_new[i])
                {
                    rowmax_new[i] = TBlock[row_threadblock * BLOCKSIZE + idx];
                }
            }
            if(rowmax_old[i] >= rowmax_new[i] && j != 0)
            {
                rowmax_new[i] = rowmax_old[i];
            }
            // 
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] = exp(TBlock[row_threadblock * BLOCKSIZE + col_threadblock] - rowmax_new[i]);
            rowsum_new[i] = rowsum_old[i] * exp(rowmax_old[i] - rowmax_new[i]);
            // printf("%d %d, %f, %f, %f, %f, %f\n", row_threadblock, col_threadblock, rowmax_new[i], rowmax_old[i], rowsum_new[i], rowsum_old[i], TBlock[row_threadblock * BLOCKSIZE + col_threadblock]);
            for(int idx = 0; idx < BLOCKSIZE; idx++)
            {
                rowsum_new[i] += TBlock[row_threadblock * BLOCKSIZE + idx];
            }
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] /= rowsum_new[i];
            // printf("%d %d, %f, %f, %f, %f, %f\n", row_threadblock, col_threadblock, rowmax_new[i], rowmax_old[i], rowsum_new[i], rowsum_old[i], TBlock[row_threadblock * BLOCKSIZE + col_threadblock]);

            // calculate OBlock
            for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
            {
                tmp = 0.0;
                for(int sumidx = 0; sumidx < BLOCKSIZE; sumidx++)
                {
                    tmp += TBlock[row_threadblock * BLOCKSIZE + sumidx] * VBlock[sumidx * head_size + col_threadblock + offset];
                }
                
                O[(row_threadblock + i * BLOCKSIZE + head_id * seq_len) * head_size + col_threadblock + offset] = (rowsum_old[i] / rowsum_new[i]) * exp(rowmax_old[i] - rowmax_new[i]) * O[(row_threadblock + i * BLOCKSIZE + head_id * seq_len) * head_size + col_threadblock + offset] + tmp;
            }
            
            // update rowmax_old and rowsum_old
            rowmax_old[i] = rowmax_new[i];
            rowsum_old[i] = rowsum_new[i];

            // advance Q
            Q += (BLOCKSIZE * head_size);
        }

        // put Q_ptr back to the original position
        Q = Q_dst;

        // advance K and V
        K += (BLOCKSIZE * head_size);
        V += (BLOCKSIZE * head_size);
    }
}


void cpu_attention(realtype *Q, realtype *K, realtype *V, realtype *S, realtype *O, int seq_len, int head_size)
{
    // matrix multiplication for Q and K.T
    for(int i = 0; i < seq_len; i++)
    {
        for(int j = 0; j < seq_len; j++)
        {
            realtype tmp = 0.0;
            for(int k = 0; k < head_size; k++)
            {
                tmp += Q[i * head_size + k] * K[j * head_size + k];
            }
            S[i * seq_len + j] = tmp / sqrt(seq_len);
        }
    }
    // for(int i = 0; i < 384; i++)
    // {
    //     printf("%f ", S[i]);
    // }
    // softmax for matrix S
    for(int row = 0; row < seq_len; row++)
    {
        realtype rowmax = -9999.9;
        for(int col = 0; col < seq_len; col++)
        {
            if(S[row * seq_len + col] >= rowmax)
            {
                rowmax = S[row * seq_len + col];
            }
        }
        realtype rowsum = 0.0;
        for(int col = 0; col < seq_len; col++)
        {
            S[row * seq_len + col] = exp(S[row * seq_len + col] - rowmax);
            rowsum += S[row * seq_len + col];
        }
        for(int col = 0; col < seq_len; col++)
        {
            S[row * seq_len + col] /= rowsum;
        }
    }
    // for(int i = 0; i < 384; i++)
    // {
    //     printf("%f ", S[i]);
    // }

    // matrix multiplication for S and V
    for(int i = 0; i < seq_len; i++)
    {
        for(int j = 0; j < head_size; j++)
        {
            realtype tmp = 0.0;
            for(int k = 0; k < seq_len; k++)
            {
                tmp += S[i * seq_len + k] * V[k * head_size + j];
            }
            O[i * head_size + j] = tmp;
        }
    }
}

#endif

main.cu

#include <cuda.h>
#include <iostream>
#include "fakernel.cuh"
#include "utils.cuh"

int main()
{
    const int seq_len =1024, head_size = 384, head_num = 1;
    // const int seq_len = 1024, head_size = 384, head_num = 1;
    cudaError_t err;

    realtype *Q, *K, *V, *S, *O, *O_CUDA;
    realtype *DEVICE_Q, *DEVICE_K, *DEVICE_V, *DEVICE_O;

    Q = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    K = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    V = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    S = (realtype *)malloc(sizeof(realtype) * seq_len * seq_len * head_num);
    O = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    O_CUDA = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);

    generateRandomMatrix(Q, seq_len * head_num, head_size, 0);
    generateRandomMatrix(K, seq_len * head_num, head_size, 1);
    generateRandomMatrix(V, seq_len * head_num, head_size, 2);

    // onesMatrix(Q, seq_len, head_size);
    // onesMatrix(K, seq_len, head_size);
    // onesMatrix(V, seq_len, head_size);
    

    CHECK(cudaMalloc(&DEVICE_Q, sizeof(realtype) * seq_len * head_size * head_num));
    CHECK(cudaMalloc(&DEVICE_K, sizeof(realtype) * seq_len * head_size * head_num));
    CHECK(cudaMalloc(&DEVICE_V, sizeof(realtype) * seq_len * head_size * head_num));
    CHECK(cudaMalloc(&DEVICE_O, sizeof(realtype) * seq_len * head_size * head_num));

    CHECK(cudaMemcpy(DEVICE_Q, Q, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(DEVICE_K, K, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(DEVICE_V, V, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyHostToDevice));
    

    cudaEvent_t start_event, stop_event;
    float elapsed_time = 0.0f;
    cudaEventCreate(&start_event);
    cudaEventCreate(&stop_event);


    //call cpu attention and record elapsed time
    cudaEventRecord(start_event);

    for(int i = 0; i < head_num; i++)
    {
        int offset1 = seq_len * head_size * i;
        int offset2 = seq_len * seq_len * i;
        cpu_attention(Q + offset1, K + offset1, V + offset1, S + offset2, O + offset1, seq_len, head_size);
    }
    

    cudaEventRecord(stop_event);
    cudaEventSynchronize(stop_event);
    cudaEventElapsedTime(&elapsed_time, start_event, stop_event);
    printf("Elapsed time in cpu function: %.4fms\n", elapsed_time);

    //call flash attention kernel and record elapsed time
    cudaEventRecord(start_event);


    err = cudaGetLastError();
    if (err != cudaSuccess) { 
        std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl; 
        exit(EXIT_FAILURE); 
    } 

    const int BLOCKSIZE = 16;
    flashattention_kernel<BLOCKSIZE, seq_len, head_size><<<head_num, dim3(BLOCKSIZE * BLOCKSIZE)>>>(DEVICE_Q, DEVICE_K, DEVICE_V, DEVICE_O, seq_len, head_size);

    err = cudaGetLastError();
    if (err != cudaSuccess) { 
        std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl; 
        exit(EXIT_FAILURE); 
    } 



    cudaDeviceSynchronize();
    cudaEventRecord(stop_event);
    cudaEventSynchronize(stop_event);
    cudaEventElapsedTime(&elapsed_time, start_event, stop_event);
    printf("Elapsed time in cuda kernel: %.4fms\n", elapsed_time);

    cudaMemcpy(O_CUDA, DEVICE_O, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyDeviceToHost);

    valMatrix(O, O_CUDA, seq_len * head_num, head_size);
    // for(int i = 0; i < 16; i++)
    // {
    //     printf("%f %f %f %f\n", Q[i + 100], K[i + 100], V[i + 100], S[i + 100]);
    // }

    // print_matrix(Q, seq_len, head_size);
    // print_matrix(K, seq_len, head_size);
    // print_matrix(V, seq_len, head_size);
    // print_matrix(O, seq_len, head_size);
    // print_matrix(O_CUDA, seq_len, head_size);


    cudaFree(DEVICE_Q);
    cudaFree(DEVICE_K);
    cudaFree(DEVICE_V);
    cudaFree(DEVICE_O);

    free(Q);
    free(K);
    free(V);
    free(S);
    free(O);
    free(O_CUDA);
    return 0;
}

compilation commands and ptxas output

nvcc -o main  main.cu
nvcc -Xptxas "-v" -arch=native -g -lineinfo main.cu -o main
ptxas info    : 0 bytes gmem
ptxas info    : Compiling entry function '_Z21flashattention_kernelILi16ELi1024ELi384EEvPfS0_S0_S0_ii' for 'sm_70'
ptxas info    : Function properties for _Z21flashattention_kernelILi16ELi1024ELi384EEvPfS0_S0_S0_ii
    512 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 72 registers, used 0 barriers, 512 bytes cumulative stack size, 74752 bytes smem, 392 bytes cmem[0]

Your kernel uses 74752 bytes of shared memory. You are not allowed to declare that amount as static shared memory in the kernel. You need to use dynamic shared memory instead, and call cudaFuncSetAttribute on your kernel to allow a dynamic shared memory size above 48 KB.
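For reference, with BLOCKSIZE = 16 and head_size = 384 the four __shared__ arrays add up to 4 * (3 * 16 * 384 + 16 * 16) = 74752 bytes per block, which is above the default 48 KB per-block limit but within the 96 KB a V100 can provide with the opt-in; with BLOCKSIZE = 8 the footprint is only 37120 bytes, which is why that configuration launches fine. Below is a minimal sketch of the change, keeping the kernel body and the names from your question unchanged; treat it as an outline rather than a drop-in replacement.

// Kernel side: replace the four static __shared__ arrays with one dynamically
// sized buffer and partition it by hand. The rest of the body is unchanged.
template<const int BLOCKSIZE, const int SEQLEN, const int HSIZE>
__global__ void flashattention_kernel(realtype *Q, realtype *K, realtype *V, realtype *O, int seq_len, int head_size)
{
    extern __shared__ realtype smem[];
    realtype *QBlock = smem;                          // BLOCKSIZE * HSIZE elements
    realtype *KBlock = QBlock + BLOCKSIZE * HSIZE;    // BLOCKSIZE * HSIZE elements
    realtype *VBlock = KBlock + BLOCKSIZE * HSIZE;    // BLOCKSIZE * HSIZE elements
    realtype *TBlock = VBlock + BLOCKSIZE * HSIZE;    // BLOCKSIZE * BLOCKSIZE elements

    // ... kernel body exactly as in the question ...
}

// Host side: opt the instantiation in to more than 48 KB of dynamic shared memory
// (a V100 allows up to 96 KB per block), then pass the size as the third launch parameter.
const int BLOCKSIZE = 16;
int smem_bytes = (int)(sizeof(realtype) * (3 * BLOCKSIZE * head_size + BLOCKSIZE * BLOCKSIZE));
CHECK(cudaFuncSetAttribute(flashattention_kernel<BLOCKSIZE, seq_len, head_size>,
                           cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
flashattention_kernel<BLOCKSIZE, seq_len, head_size>
    <<<head_num, BLOCKSIZE * BLOCKSIZE, smem_bytes>>>(DEVICE_Q, DEVICE_K, DEVICE_V, DEVICE_O, seq_len, head_size);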

Thank you so much, the problem is fixed by using dynamic shared memory.
I originally wanted to set BLOCKSIZE = 32, but I got a "uses too much shared data" error during compilation, so I assumed that exceeding the shared memory limit would always show up at compile time. I guess I should be more careful about managing shared memory.
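(For the record, BLOCKSIZE = 32 with head_size = 384 would need 4 * (3 * 32 * 384 + 32 * 32) = 151552 bytes per block, which is more than even the 96 KB a V100 can provide with the opt-in, so that configuration cannot work on this device without changing the tiling.)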