Good day, everyone.
I am trying to implement matrix multiplication for my academic programming language, No Saved Kaleidoscope, and I want to adapt it to work with FlashAttention.
Here is what I have done so far. I noticed that mma_sync leads to unexpected behavior when the matrix dimensions exceed the 16x16 fragment size that wmma expects. This is likely due to load_matrix_sync: as far as I can tell, it loads a full tile from the __half * tensor even when that reads past the tensor's bounds. I worked around this by first loading the tensors into shared memory and performing the zero padding there.
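In other words, every element is staged through a bounds guard like this (a simplified sketch of what the kernel below does; load_or_zero is just an illustrative name):

#include <cuda_fp16.h>

// Illustrative helper: fetch one element of the B x C tensor in half
// precision, or return zero when (row, col) falls outside it. Staging
// every element through a guard like this is what guarantees that
// load_matrix_sync only ever sees a fully valid 16x16 tile.
__device__ __half load_or_zero(const float *x, int row, int col,
                               int B, int C) {
    return (row < B && col < C) ? __float2half(x[row * C + col])
                                : __float2half(0.0f);
}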
I also read "How to use WMMA efficiently", and it seems to me that I should overlap the shared-memory staging with load_matrix_sync and mma_sync, but I do not know whether that is possible given the __syncthreads() calls the staging requires. Concretely, I have double buffering in mind, as sketched below.
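Here, load_tile is a hypothetical stand-in for the zero-padded staging loop of my kernel further down, and the sketch is reduced to one warp per block, so it illustrates the schedule rather than being working code: the copies for tile t+1 are issued before the MMA on tile t, and only one __syncthreads() per iteration remains.

#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// Hypothetical helper: the zero-padded global->smem staging loop from my
// kernel below, copying the x and w tiles for iteration t.
__device__ void load_tile(__half *x_buf, __half *w_buf, const float *x,
                          const float *w, int t, int B, int C, int OC);

template<int WMMA_T>
__global__ void wmma_double_buffered(const float *x, const float *w,
                                     float *out, int B, int C, int OC) {
    __shared__ __half x_buf[2][WMMA_T * WMMA_T];
    __shared__ __half w_buf[2][WMMA_T * WMMA_T];

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> x_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> w_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> y_frag;
    wmma::fill_fragment(y_frag, 0.0f);

    int num_tiles = (C + WMMA_T - 1) / WMMA_T;
    load_tile(x_buf[0], w_buf[0], x, w, 0, B, C, OC);  // prologue: stage tile 0
    __syncthreads();

    for (int t = 0; t < num_tiles; ++t) {
        int cur = t % 2, nxt = 1 - cur;
        if (t + 1 < num_tiles)                         // issue next tile's loads first,
            load_tile(x_buf[nxt], w_buf[nxt], x, w, t + 1, B, C, OC);
        wmma::load_matrix_sync(x_frag, x_buf[cur], WMMA_T);  // ...so they overlap the MMA
        wmma::load_matrix_sync(w_frag, w_buf[cur], WMMA_T);
        wmma::mma_sync(y_frag, x_frag, w_frag, y_frag);
        __syncthreads();  // nxt is now complete and cur is free for reuse
    }
    // ... store y_frag to out as in my kernel below ...
}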
That said, is there a way to avoid the round trip through shared memory and still get correct zero padding? Are there any other optimizations I could apply to make it faster?
#include <mma.h>
#include <cuda_fp16.h>

using namespace nvcuda;

template<int WMMA_T, int X_WARPS, int Y_WARPS>
__global__ void wmma_mult_kernel(const float *x, const float *w,
                                 float *out, const int B, const int C, const int OC) {
    int laneId = (threadIdx.y * blockDim.x + threadIdx.x) % warpSize;
    int mw = laneId / WMMA_T;          // which group of tile rows this lane helps copy
    int ml = laneId % WMMA_T;          // column within the tile this lane copies
    int warp_y = threadIdx.y;          // warp row within the block
    int warp_x = threadIdx.x / 32;     // warp column within the block

    const uint32_t warpX{(blockIdx.x * blockDim.x + threadIdx.x) / warpSize}; // global tile index along OC
    const uint32_t warpY{blockIdx.y * blockDim.y + threadIdx.y};              // global tile index along B
    // warpX == blockIdx.x * X_WARPS + warp_x
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> x_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> w_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> y_frag;
    wmma::fill_fragment(y_frag, 0.0f);

    // Dynamic smem layout: [float out staging | __half x tiles | __half w tiles].
    extern __shared__ float smem[];
    float *out_smem = smem;
    __half *hsmem = reinterpret_cast<__half*>(smem + Y_WARPS*WMMA_T*(X_WARPS*WMMA_T));
    __half *x_smem = hsmem;                              // Y_WARPS tiles of WMMA_T x WMMA_T
    __half *w_smem = hsmem + Y_WARPS*WMMA_T*WMMA_T;      // X_WARPS tiles of WMMA_T x WMMA_T
    for (int tile = 0; tile < C; tile += WMMA_T)
    {
        // Stage the x and w tiles in shared memory, writing zeros for any
        // element outside the tensors so load_matrix_sync only ever reads
        // a fully valid, zero-padded WMMA_T x WMMA_T tile.
        #pragma unroll
        for (int i = 0; i < 2; ++i)
        {
            // Each warp copies (warpSize/WMMA_T)*2 rows of a tile; the warps
            // along the other block dimension cooperate to cover the rest.
            int row_aux1 = warp_x*(warpSize/WMMA_T)*2 + mw*2 + i;
            int row_aux2 = warp_y*(warpSize/WMMA_T)*2 + mw*2 + i;
            if (row_aux1 < WMMA_T)
            {
                if ((warpY*WMMA_T + row_aux1) < B && (tile + ml) < C)
                    x_smem[(warp_y*WMMA_T + row_aux1)*WMMA_T + ml] = __float2half(x[(warpY*WMMA_T + row_aux1)*C + tile + ml]);
                else
                    x_smem[(warp_y*WMMA_T + row_aux1)*WMMA_T + ml] = __float2half(0.0f);
            }
            if (row_aux2 < WMMA_T)
            {
                if ((warpX*WMMA_T + row_aux2) < OC && (tile + ml) < C)
                    w_smem[(warp_x*WMMA_T + row_aux2)*WMMA_T + ml] = __float2half(w[(warpX*WMMA_T + row_aux2)*C + tile + ml]);
                else
                    w_smem[(warp_x*WMMA_T + row_aux2)*WMMA_T + ml] = __float2half(0.0f);
            }
        }
        __syncthreads();  // tiles must be complete before any warp reads them
        if ((warpY*WMMA_T) < B && (warpX*WMMA_T) < OC)
        {
            wmma::load_matrix_sync(x_frag, x_smem + warp_y*WMMA_T*WMMA_T, WMMA_T);
            wmma::load_matrix_sync(w_frag, w_smem + warp_x*WMMA_T*WMMA_T, WMMA_T);
            wmma::mma_sync(y_frag, x_frag, w_frag, y_frag);
        }
        __syncthreads();  // tiles must not be overwritten while any warp still reads them
    }
    if ((warpY*WMMA_T) < B && (warpX*WMMA_T) < OC)
    {
        // Stage this warp's accumulator tile in shared memory, then write only
        // the in-bounds elements back to global memory.
        float *_out = out_smem + warp_y*WMMA_T*(X_WARPS*WMMA_T) + warp_x*WMMA_T;
        wmma::store_matrix_sync(_out, y_frag, X_WARPS*WMMA_T, wmma::mem_row_major);
        __syncwarp();  // each warp reads back only its own tile, so a warp-level
                       // barrier suffices; __syncthreads() would be unsafe here
                       // because warps that failed the bounds check skip this branch
        #pragma unroll
        for (int t = 0; t < (WMMA_T*WMMA_T + 31) / 32; ++t)  // integer ceil-div; no std::ceil on the device
        {
            int tile_idx = t*warpSize + laneId;
            int row = tile_idx / WMMA_T;
            int col = tile_idx % WMMA_T;
            if ((warpY*WMMA_T + row) < B && (warpX*WMMA_T + col) < OC && row < WMMA_T)
                out[(warpY*WMMA_T + row)*OC + warpX*WMMA_T + col] = _out[row*(X_WARPS*WMMA_T) + col];
        }
    }
}
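For completeness, a launch configuration consistent with the indexing above would look like the following (a sketch; WMMA_T = 16 with 4x4 warps per block is an assumption, and the shared-memory size just mirrors the pointer arithmetic at the top of the kernel):

// Hypothetical launch: X_WARPS warps along x (each 32 threads wide) and
// Y_WARPS warps along y, with dynamic smem for the float output staging
// area plus the two half-precision input tile regions.
constexpr int WMMA_T = 16, X_WARPS = 4, Y_WARPS = 4;

dim3 block(X_WARPS * 32, Y_WARPS);
dim3 grid((OC + WMMA_T * X_WARPS - 1) / (WMMA_T * X_WARPS),
          (B  + WMMA_T * Y_WARPS - 1) / (WMMA_T * Y_WARPS));

size_t smem_bytes =
      Y_WARPS * WMMA_T * X_WARPS * WMMA_T * sizeof(float)      // out_smem
    + (Y_WARPS + X_WARPS) * WMMA_T * WMMA_T * sizeof(__half);  // x_smem + w_smem

wmma_mult_kernel<WMMA_T, X_WARPS, Y_WARPS>
    <<<grid, block, smem_bytes>>>(x, w, out, B, C, OC);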