I’m trying to implement an attention kernel in CUDA C++. My kernel function is below:
typedef float realtype;
// flash attention version 1 is parallel over the head_num and batch size dimension
// kernel should be launched with config <<<([head_num]), (BLOCKSIZE * BLOCKSIZE)>>>
template<const int BLOCKSIZE, const int SEQLEN, const int HSIZE>
__global__ void flashattention_kernel(realtype *Q, realtype *K, realtype *V, realtype *O, int seq_len, int head_size)
{
    __shared__ realtype QBlock[BLOCKSIZE * HSIZE];
    __shared__ realtype KBlock[BLOCKSIZE * HSIZE];
    __shared__ realtype VBlock[BLOCKSIZE * HSIZE];
    __shared__ realtype TBlock[BLOCKSIZE * BLOCKSIZE];
    realtype *Q_dst = Q;
    realtype rowmax_old[SEQLEN / BLOCKSIZE] = {0.0}, rowmax_new[SEQLEN / BLOCKSIZE] = {0.0};
    realtype rowsum_old[SEQLEN / BLOCKSIZE] = {0.0}, rowsum_new[SEQLEN / BLOCKSIZE] = {0.0};
    int head_id = blockIdx.x;
    int row_threadblock = threadIdx.x / BLOCKSIZE;
    int col_threadblock = threadIdx.x % BLOCKSIZE;
    for(int j = 0; j < seq_len; j += BLOCKSIZE)
    {
        // load K, V into KBlock, VBlock (head_id should be taken into consideration)
        for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
        {
            KBlock[row_threadblock * head_size + col_threadblock + offset] = K[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
            VBlock[row_threadblock * head_size + col_threadblock + offset] = V[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
        }
        for(int i = 0; i < SEQLEN / BLOCKSIZE; i++)
        {
            // load Q into QBlock (head_id should be taken into consideration)
            for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
            {
                QBlock[row_threadblock * head_size + col_threadblock + offset] = Q[(row_threadblock + head_id * seq_len) * head_size + col_threadblock + offset];
            }
            // matmul Q and K.T and store the result into TBlock
            // non-coherent access to shared memory?
            realtype tmp = 0.0;
            for(int k = 0; k < head_size; k++)
            {
                tmp += QBlock[row_threadblock * head_size + k] * KBlock[col_threadblock * head_size + k];
            }
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] = tmp / sqrtf(_Float32(seq_len));
            // printf("%d %d, %f\n", row_threadblock, col_threadblock, TBlock[row_threadblock * BLOCKSIZE + col_threadblock]);
            // calculate row max
            rowmax_new[i] = -6666.66;
            for(int idx = 0; idx < BLOCKSIZE; idx++)
            {
                if(TBlock[row_threadblock * BLOCKSIZE + idx] >= rowmax_new[i])
                {
                    rowmax_new[i] = TBlock[row_threadblock * BLOCKSIZE + idx];
                }
            }
            if(rowmax_old[i] >= rowmax_new[i] && j != 0)
            {
                rowmax_new[i] = rowmax_old[i];
            }
            //
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] = exp(TBlock[row_threadblock * BLOCKSIZE + col_threadblock] - rowmax_new[i]);
            rowsum_new[i] = rowsum_old[i] * exp(rowmax_old[i] - rowmax_new[i]);
            // printf("%d %d, %f, %f, %f, %f, %f\n", row_threadblock, col_threadblock, rowmax_new[i], rowmax_old[i], rowsum_new[i], rowsum_old[i], TBlock[row_threadblock * BLOCKSIZE + col_threadblock]);
            for(int idx = 0; idx < BLOCKSIZE; idx++)
            {
                rowsum_new[i] += TBlock[row_threadblock * BLOCKSIZE + idx];
            }
            TBlock[row_threadblock * BLOCKSIZE + col_threadblock] /= rowsum_new[i];
            // printf("%d %d, %f, %f, %f, %f, %f\n", row_threadblock, col_threadblock, rowmax_new[i], rowmax_old[i], rowsum_new[i], rowsum_old[i], TBlock[row_threadblock * BLOCKSIZE + col_threadblock]);
            // calculate OBlock
            for(int offset = 0; offset < head_size; offset += BLOCKSIZE)
            {
                tmp = 0.0;
                for(int sumidx = 0; sumidx < BLOCKSIZE; sumidx++)
                {
                    tmp += TBlock[row_threadblock * BLOCKSIZE + sumidx] * VBlock[sumidx * head_size + col_threadblock + offset];
                }
                O[(row_threadblock + i * BLOCKSIZE + head_id * seq_len) * head_size + col_threadblock + offset] = (rowsum_old[i] / rowsum_new[i]) * exp(rowmax_old[i] - rowmax_new[i]) * O[(row_threadblock + i * BLOCKSIZE + head_id * seq_len) * head_size + col_threadblock + offset] + tmp;
            }
            // update rowmax_old and rowsum_old
            rowmax_old[i] = rowmax_new[i];
            rowsum_old[i] = rowsum_new[i];
            // advance Q
            Q += (BLOCKSIZE * head_size);
        }
        // put Q_ptr back to the original position
        Q = Q_dst;
        // advance K and V
        K += (BLOCKSIZE * head_size);
        V += (BLOCKSIZE * head_size);
    }
}
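For reference, the per-block static shared memory implied by the four __shared__ arrays works out as follows (assuming realtype stays float). The helper below is just a hypothetical sketch to show the arithmetic, not part of my kernel:

#include <cstddef>
#include <cstdio>

// Hypothetical helper: bytes of static shared memory the kernel declares per block,
// for the same template parameters (assuming realtype == float, i.e. 4 bytes).
template<int BLOCKSIZE, int HSIZE>
constexpr std::size_t staticSmemBytes()
{
    return sizeof(float) * (3 * BLOCKSIZE * HSIZE      // QBlock, KBlock, VBlock
                            + BLOCKSIZE * BLOCKSIZE);  // TBlock
}

int main()
{
    std::printf("%zu\n", staticSmemBytes<8, 384>());   // 37120 bytes
    std::printf("%zu\n", staticSmemBytes<16, 384>());  // 74752 bytes
    std::printf("%zu\n", staticSmemBytes<16, 96>());   // 19456 bytes
    return 0;
}

The three instantiations correspond to the configurations I describe further down: BLOCKSIZE = 8 with head_size = 384, BLOCKSIZE = 16 with head_size = 384, and BLOCKSIZE = 16 with head_size = 96.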
And here is main.cu:
int main()
{
    const int seq_len = 1024, head_size = 384, head_num = 1;
    cudaError_t err;
    realtype *Q, *K, *V, *S, *O, *O_CUDA;
    realtype *DEVICE_Q, *DEVICE_K, *DEVICE_V, *DEVICE_O;
    Q = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    K = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    V = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    S = (realtype *)malloc(sizeof(realtype) * seq_len * seq_len * head_num);
    O = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    O_CUDA = (realtype *)malloc(sizeof(realtype) * seq_len * head_size * head_num);
    generateRandomMatrix(Q, seq_len * head_num, head_size, 0);
    generateRandomMatrix(K, seq_len * head_num, head_size, 1);
    generateRandomMatrix(V, seq_len * head_num, head_size, 2);
    CHECK(cudaMalloc(&DEVICE_Q, sizeof(realtype) * seq_len * head_size * head_num));
    CHECK(cudaMalloc(&DEVICE_K, sizeof(realtype) * seq_len * head_size * head_num));
    CHECK(cudaMalloc(&DEVICE_V, sizeof(realtype) * seq_len * head_size * head_num));
    CHECK(cudaMalloc(&DEVICE_O, sizeof(realtype) * seq_len * head_size * head_num));
    CHECK(cudaMemcpy(DEVICE_Q, Q, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(DEVICE_K, K, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(DEVICE_V, V, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyHostToDevice));
    cudaEvent_t start_event, stop_event;
    float elapsed_time = 0.0f;
    cudaEventCreate(&start_event);
    cudaEventCreate(&stop_event);
    // call cpu attention and record elapsed time
    cudaEventRecord(start_event);
    for(int i = 0; i < head_num; i++)
    {
        int offset1 = seq_len * head_size * i;
        int offset2 = seq_len * seq_len * i;
        cpu_attention(Q + offset1, K + offset1, V + offset1, S + offset2, O + offset1, seq_len, head_size);
    }
    cudaEventRecord(stop_event);
    cudaEventSynchronize(stop_event);
    cudaEventElapsedTime(&elapsed_time, start_event, stop_event);
    printf("Elapsed time in cpu function: %.4fms\n", elapsed_time);
    // call flash attention kernel and record elapsed time
    cudaEventRecord(start_event);
    err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl;
        exit(EXIT_FAILURE);
    }
    const int BLOCKSIZE = 8;
    flashattention_kernel<BLOCKSIZE, seq_len, head_size><<<head_num, dim3(BLOCKSIZE * BLOCKSIZE)>>>(DEVICE_Q, DEVICE_K, DEVICE_V, DEVICE_O, seq_len, head_size);
    err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl;
        exit(EXIT_FAILURE);
    }
    cudaDeviceSynchronize();
    cudaEventRecord(stop_event);
    cudaEventSynchronize(stop_event);
    cudaEventElapsedTime(&elapsed_time, start_event, stop_event);
    printf("Elapsed time in cuda kernel: %.4fms\n", elapsed_time);
    cudaMemcpy(O_CUDA, DEVICE_O, sizeof(realtype) * seq_len * head_size * head_num, cudaMemcpyDeviceToHost);
    valMatrix(O, O_CUDA, seq_len * head_num, head_size);
    cudaFree(DEVICE_Q);
    cudaFree(DEVICE_K);
    cudaFree(DEVICE_V);
    cudaFree(DEVICE_O);
    free(Q);
    free(K);
    free(V);
    free(S);
    free(O);
    free(O_CUDA);
    return 0;
}
Functions like generateRandomMatrix generate a random matrix, and valMatrix verifies that two matrices are the same. They are not kernel functions and work fine in other code, so I don’t think they are the source of the error and will not paste their real implementations here.
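In case their shapes matter, here are simplified stand-ins based only on how they are called in main.cu (these are illustrative, not my real code):

#include <cstdio>
#include <cstdlib>
#include <cmath>

typedef float realtype;  // same typedef as above

// Illustrative stand-in: fill a rows x cols matrix with pseudo-random values, seeded by `seed`.
void generateRandomMatrix(realtype *m, int rows, int cols, int seed)
{
    srand(seed);
    for(int i = 0; i < rows * cols; i++)
        m[i] = (realtype)rand() / (realtype)RAND_MAX;
}

// Illustrative stand-in: report whether two rows x cols matrices agree element-wise within a tolerance.
void valMatrix(realtype *a, realtype *b, int rows, int cols)
{
    for(int i = 0; i < rows * cols; i++)
        if(fabsf(a[i] - b[i]) > 1e-3f)
        {
            printf("mismatch at index %d: %f vs %f\n", i, a[i], b[i]);
            return;
        }
    printf("matrices match\n");
}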
The problem is that when I launch the kernel with BLOCKSIZE = 16, I get a runtime error, CUDA error: invalid argument at main.cu:78, from the following code:
int main()
{
    const int seq_len = 1024, head_size = 384, head_num = 1;
    cudaError_t err;
    realtype *Q, *K, *V, *S, *O, *O_CUDA;
    realtype *DEVICE_Q, *DEVICE_K, *DEVICE_V, *DEVICE_O;
    ...
    const int BLOCKSIZE = 16;
    flashattention_kernel<BLOCKSIZE, seq_len, head_size><<<head_num, dim3(BLOCKSIZE * BLOCKSIZE)>>>(DEVICE_Q, DEVICE_K, DEVICE_V, DEVICE_O, seq_len, head_size);
    err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl;
        exit(EXIT_FAILURE);
    }
    ...
}
By the way, this error appears at runtime, not at compilation, and it only appears when I set BLOCKSIZE = 16. I am positive the error comes from flashattention_kernel because of the cudaGetLastError() check right after the launch. If I set BLOCKSIZE = 8, or set head_size = 96, there is no error and the results from the CPU and CUDA versions are identical.
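In case it helps narrow things down, here is a small diagnostic I plan to add right before the launch for the failing configuration; it reuses BLOCKSIZE, seq_len, head_size and the CHECK macro from main.cu, and it is only a sketch I have not run yet:

// Diagnostic sketch: compare what the compiled kernel requests with what the device allows.
cudaFuncAttributes attr;
CHECK(cudaFuncGetAttributes(&attr, flashattention_kernel<BLOCKSIZE, seq_len, head_size>));

int dev = 0, smem_per_block = 0, smem_optin = 0;
CHECK(cudaGetDevice(&dev));
CHECK(cudaDeviceGetAttribute(&smem_per_block, cudaDevAttrMaxSharedMemoryPerBlock, dev));
CHECK(cudaDeviceGetAttribute(&smem_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));

printf("kernel static shared mem : %zu bytes\n", attr.sharedSizeBytes);
printf("kernel local mem / thread: %zu bytes\n", attr.localSizeBytes);
printf("kernel registers / thread: %d\n", attr.numRegs);
printf("device shared mem / block: %d bytes (opt-in max %d)\n", smem_per_block, smem_optin);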
For your information, my device is a Tesla V100 with 16 GB of memory and my compiler is nvcc; see the nvcc -V output below.
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
Any help will be appreciated!