Fastest Tiled WMMA for Matrices of Any Size?

A good day to everyone.

I am trying to create a matrix multiplication kernel for No Saved Kaleidoscope, my academic coding language, and I want to adapt it to work with flash attention.

Here is what I have done so far. I noticed that mma_sync leads to unexpected behavior when the matrix dimensions are not multiples of the 16x16 WMMA fragment size. This is likely due to load_matrix_sync: I believe this operation reads from the __half * source even when the access goes beyond the tensor's bounds. I worked around this by staging the tensors in shared memory and doing the zero padding there.

Also, I read How to use WMMA efficiently, and it feels to me that I should overlap the shared-memory loading with load_matrix_sync and mma_sync, but I do not know whether this is possible given the required __syncthreads calls.

That said, is there a way to avoid the round trip through shared memory and still get the correct zero padding? Are there any other optimizations I can make to speed it up?

template<int WMMA_T, int X_WARPS, int Y_WARPS>
__global__ void wmma_mult_kernel(const float *x, const float *w,
                      float *out, const int B, const int C, const int OC) {

  int laneId = ( threadIdx.y * blockDim.x + threadIdx.x) % warpSize;
  int mw = laneId / WMMA_T;
  int ml = laneId % WMMA_T;

  int warp_y = threadIdx.y;
  int warp_x = (threadIdx.x / 32);


  const uint32_t warpX{(blockIdx.x * blockDim.x + threadIdx.x) / warpSize};  // OC
  const uint32_t warpY{blockIdx.y * blockDim.y + threadIdx.y};               // B

  // warpX = (oc*X_WARPS + warp_x)




  wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> x_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> w_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> y_frag;



  wmma::fill_fragment(y_frag, 0.0f);


  extern __shared__ float smem[];
  float *out_smem = smem;
  __half *hsmem = reinterpret_cast<__half*>(smem + Y_WARPS*WMMA_T*(X_WARPS*WMMA_T));

  __half *x_smem     = hsmem;
  __half *w_smem     = hsmem + Y_WARPS*WMMA_T*(WMMA_T);
  
  
  

#pragma unroll
  for (int tile=0; tile<C; tile+=WMMA_T)
  {

#pragma unroll
    for (int i=0; i<2; ++i)
    {
      // warp * mw_size * i_size + mw*i_size + i
      int row_aux1 = warp_x*((int)(warpSize/WMMA_T))*2 + mw*2+i;
      int row_aux2 = warp_y*((int)(warpSize/WMMA_T))*2 + mw*2+i;
      
      if (row_aux1<WMMA_T)
      {
        if ((warpY*WMMA_T+row_aux1)<B && (tile+ml)<C)
          x_smem[(warp_y*WMMA_T+row_aux1)*WMMA_T + ml] = __float2half(*(x + (warpY*WMMA_T+row_aux1)*C + tile+ml));
        else
          x_smem[(warp_y*WMMA_T+row_aux1)*WMMA_T + ml] = 0;
      }

      if (row_aux2<WMMA_T)
      {
        if ((warpX*WMMA_T+row_aux2)<OC && (tile+ml)<C)
          w_smem[(warp_x*WMMA_T+row_aux2)*WMMA_T + ml] = __float2half(*(w + (warpX*WMMA_T+row_aux2)*C + tile+ml));
        else
          w_smem[(warp_x*WMMA_T+row_aux2)*WMMA_T + ml] = 0;
      }
    }
    

    __syncthreads();


    if ((warpY*WMMA_T)<B && (warpX*WMMA_T)<OC)
    {
      wmma::load_matrix_sync(x_frag, x_smem+warp_y*WMMA_T*WMMA_T, WMMA_T);
      wmma::load_matrix_sync(w_frag, w_smem+warp_x*WMMA_T*WMMA_T, WMMA_T);
      


      wmma::mma_sync(y_frag, x_frag, w_frag, y_frag);
    }
    
    __syncthreads();
  }


  if ((warpY*WMMA_T)<B && (warpX*WMMA_T)<OC && (warp_y*WMMA_T)<B && (warp_x*WMMA_T)<OC)
  { 

    float *_out = out_smem + warp_y*WMMA_T*(X_WARPS*WMMA_T) + warp_x*WMMA_T;
    wmma::store_matrix_sync(_out, y_frag, X_WARPS*WMMA_T, wmma::mem_row_major);

    __syncthreads();
    
#pragma unroll
    for (int tile=0; tile<std::ceil((WMMA_T*WMMA_T)/(float)warpSize); ++tile)
    {
      int tile_idx = tile*warpSize + laneId;

      int row = tile_idx / WMMA_T;
      int col = tile_idx % WMMA_T;

      if((warpY*WMMA_T+row)<B  &&  (warpX*WMMA_T+col)<OC && row<WMMA_T)
        out[(warpY*WMMA_T+row)*OC + warpX*WMMA_T+col] = _out[row*(X_WARPS*WMMA_T)+col];

    }
  }
}

You can always run mma manually (in an asm instruction). Look at the cutlass header files.

e.g.

  unsigned const & A = reinterpret_cast<unsigned const &>(a);
  unsigned const & B = reinterpret_cast<unsigned const &>(b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

in cutlass/include/cutlass/arch/mma_sm75.h at main · NVIDIA/cutlass · GitHub

It just takes the data from a local variable (in the end a register), which can be padded any way you want.
Detailed description in the PTX manual.
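
For the fp16 inputs with fp32 accumulation used in the kernel above, the corresponding shape on sm_80 and newer is m16n8k16. A minimal sketch, assuming the helper name and the fragment arrays are placeholders (the element-to-thread mapping of a, b, c, d is documented in the PTX ISA under "Matrix fragments for mma.m16n8k16"):

__device__ inline void mma_m16n8k16_f16_f32(float d[4],            // 4 floats of D per thread
                                            unsigned const a[4],   // 8 halves of A per thread (packed)
                                            unsigned const b[2],   // 4 halves of B per thread (packed)
                                            float const c[4]) {    // 4 floats of C per thread
#if __CUDA_ARCH__ >= 800
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
        "r"(b[0]), "r"(b[1]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
}

Because the fragments live entirely in registers, out-of-range elements can simply be written as zero when you fill a and b before the call.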

Thank you for your answer. While researching asm instructions, I found two articles of interest: WMMA Extension and Acceleration of Tensor-Product. However, they follow a global → shared → register pattern. The latter specifies that it pads the tensors in shared memory, and that loading directly from global memory into registers is slower. They also explain a bit about shared-memory bank conflicts, which I have not been able to implement yet because I don't know which access pattern would satisfy both the global-memory loads and the fragment loads.

In the meantime, I read Flash Attention 3 and understood more about the pingpong schedule and wgmma. My GPU is an RTX 4090, so I think I cannot run wgmma, but pingpong should work. I have implemented it (shown in the code below), but the pingpong WMMA kernel runs 100k 1024x1024x1024 matrix multiplications in 11 seconds, whereas my previous code (the first kernel in this post) does it in 8.9 seconds and cublasSgemm in fp32 takes only 3.5 seconds. Loading from global memory directly into registers with padding also took 11 seconds.

Code listing:

template<int WMMA_T, int X_WARPS, int Y_WARPS>
__global__ void wmma_pingpong(const float *x, const float *w,
                      float *out, const int B, const int C, const int OC) {

  int laneId = ( threadIdx.y * blockDim.x + threadIdx.x) % warpSize;
  int mw = laneId / WMMA_T;
  int ml = laneId % WMMA_T;

  int warp_y = threadIdx.y;
  int warp_x = (threadIdx.x / warpSize);

  int s=2;
  int circular_smem_counter=0;


  uint32_t warpX;                                                     // OC
  uint32_t warpY{blockIdx.y * blockDim.y + threadIdx.y};               // B

  // warpX = (oc*X_WARPS + warp_x)





  wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> x_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> w_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> y_frag;

  //using FRAG_T = wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major>;

  wmma::fill_fragment(y_frag, 0.0f);


  extern __shared__ float smem[];
  float *out_smem = smem;
  __half *hsmem = reinterpret_cast<__half*>(smem + Y_WARPS*WMMA_T*(X_WARPS*WMMA_T));

  __half *x_smem_base  = hsmem;
  __half *w_smem_base  = hsmem + (Y_WARPS*WMMA_T)*WMMA_T;
  
  __half *x_smem, *w_smem;
  





  if (warp_x>=4)
  {
    warp_x-=4;
    warpX = (blockIdx.x*(blockDim.x/2))/warpSize + warp_x;
    
    
  
    for (int tile=0; tile<C; tile+=WMMA_T)
    {

      int tgt_smem = circular_smem_counter % s;


      x_smem = x_smem_base + tgt_smem*((Y_WARPS+X_WARPS)*WMMA_T*WMMA_T);
      w_smem = w_smem_base + tgt_smem*((Y_WARPS+X_WARPS)*WMMA_T*WMMA_T);

      

  
      for (int i=0; i<2; ++i)
      {
        // warp*mw_size*i_size + mw*i_size + i
        int row_aux1 = warp_x*((int)(warpSize/WMMA_T))*2 + mw*2+i;
        int row_aux2 = warp_y*((int)(warpSize/WMMA_T))*2 + mw*2+i;
        
        if (row_aux1<WMMA_T)
        {
          if ((warpY*WMMA_T+row_aux1)<B && (tile+ml)<C)
            x_smem[(warp_y*WMMA_T+row_aux1)*WMMA_T + ml] = __float2half(*(x + (warpY*WMMA_T+row_aux1)*C + tile+ml));
          else
            x_smem[(warp_y*WMMA_T+row_aux1)*WMMA_T + ml] = 0;
        }

        if (row_aux2<WMMA_T)
        {
          if ((warpX*WMMA_T+row_aux2)<OC && (tile+ml)<C)
            w_smem[(warp_x*WMMA_T+row_aux2)*WMMA_T + ml] = __float2half(*(w + (warpX*WMMA_T+row_aux2)*C + tile+ml));
          else
            w_smem[(warp_x*WMMA_T+row_aux2)*WMMA_T + ml] = 0;
        }
      }
      

      
      asm volatile("bar.sync 0, 1024;"); // producer waits consumer
      
      asm volatile("bar.arrive 1, 1024;"); // producer ends


      circular_smem_counter++;
      
      

      // if ((blockIdx.x+blockIdx.y+warp_x+warp_y+laneId)==0)
      //   printf("Producer finished, tile: %d/%d.\n", tile, C);




      
      //asm volatile("bar.sync 2, 512;");
      // __syncthreads();

    }

    // if ((blockIdx.x+blockIdx.y+warp_x+warp_y+laneId)==0)
    //   printf("Producer exits.\n");

    return; // return is a must, otherwise the if below bugs
  }
  else if (warp_x<4)
  {
    warpX = (blockIdx.x*(blockDim.x/2))/warpSize + warp_x;


    asm volatile("bar.arrive 0, 1024;");

  
    for (int tile=0; tile<C; tile+=WMMA_T)
    {


      int tgt_smem = circular_smem_counter % s;

      x_smem = x_smem_base + tgt_smem*((Y_WARPS+X_WARPS)*WMMA_T*WMMA_T);
      w_smem = w_smem_base + tgt_smem*((Y_WARPS+X_WARPS)*WMMA_T*WMMA_T);


      


      // if ((blockIdx.x+blockIdx.y+warp_x+warp_y+laneId)==0)
      //   printf("\t\t\t\t\tConsumer wait %d.\n", tile);

      asm volatile("bar.sync 1, 1024;"); // consumer waits producer

      // if ((blockIdx.x+blockIdx.y+warp_x+warp_y+laneId)==0)
      //   printf("\t\t\t\t\tConsumer go %d.\n", tile);




      if ((warpY*WMMA_T)<B && (warpX*WMMA_T)<OC)
      {


        wmma::load_matrix_sync(x_frag, x_smem+warp_y*WMMA_T*WMMA_T, WMMA_T);
        wmma::load_matrix_sync(w_frag, w_smem+warp_x*WMMA_T*WMMA_T, WMMA_T);



        wmma::mma_sync(y_frag, x_frag, w_frag, y_frag);


      }
      

      asm volatile("bar.arrive 0, 1024;"); // consumer ends

      // if ((blockIdx.x+blockIdx.y+warp_x+warp_y+laneId)==0)
      //   printf("\t\t\t\t\tConsumer finished, tile: %d.\n", tile);
      
      
      circular_smem_counter++;
    }


    // if ((blockIdx.x+blockIdx.y+warp_x+warp_y+laneId)==0)
    //   printf("\t\tConsumer exits.\n");


    if ((warpY*WMMA_T)<B && (warpX*WMMA_T)<OC && (warp_y*WMMA_T)<B && (warp_x*WMMA_T)<OC)
    { 
      // if ((blockIdx.x+blockIdx.y+warp_x+warp_y+laneId)==0)
      //   printf("\t\t ------ NOW STORE OUTPUT ------\n");

      float *_out = out_smem + warp_y*WMMA_T*(X_WARPS*WMMA_T) + warp_x*WMMA_T;
      wmma::store_matrix_sync(_out, y_frag, X_WARPS*WMMA_T, wmma::mem_row_major);

      // if ((blockIdx.x+blockIdx.y+warp_x+warp_y+laneId)==0)
      //   printf("\t\t ------ post wmma store ------\n");


      
      
  
      for (int tile=0; tile<std::ceil((WMMA_T*WMMA_T)/(float)warpSize); ++tile)
      {
        int tile_idx = tile*warpSize + laneId;

        int row = tile_idx / WMMA_T;
        int col = tile_idx % WMMA_T;
        


        // if ((blockIdx.y+ warp_y+laneId)==0)
        //   printf("warpX: %d\t warpX offset: %d\t OC: %d\n", warpX, warpX*WMMA_T, OC);
        // if ((blockIdx.x+ warp_x+laneId)==0)
        //   printf("warpY: %d, out: %f\n", warpY);

        if((warpY*WMMA_T+row)<B  &&  (warpX*WMMA_T+col)<OC  &&  row<WMMA_T)
          out[(warpY*WMMA_T+row)*OC + warpX*WMMA_T+col] = _out[row*(X_WARPS*WMMA_T)+col];

      }
    }
    // if ((blockIdx.x+blockIdx.y+warp_x+warp_y+laneId)==0)
    //     printf("\t\t ------ post tiled ------\n");
    
  }
}

[…]
Kernel call:


    constexpr int num_warps_x{4};
    constexpr int num_warps_y{4};
    

    constexpr int WMMA_T{16};
    dim3 block_size_pp(num_warps_x * WARP_SIZE*2, num_warps_y);
    dim3 grid_size(std::ceil((OC + (num_warps_x*WMMA_T - 1)) / (float)(num_warps_x*WMMA_T)), std::ceil((B + (num_warps_y*WMMA_T - 1)) / (float)(num_warps_y*WMMA_T)));

    
    
    int shared_mem_pp   = (num_warps_y*WMMA_T*WMMA_T*num_warps_x)*sizeof(float) + 2*(num_warps_x+num_warps_y)*WMMA_T*WMMA_T*sizeof(__half);
    
    wmma_pingpong<WMMA_T,num_warps_x,num_warps_y><<<grid_size, block_size_pp, shared_mem_pp, stream>>>(x->tensor_ptr, W, out, B, C, OC);

I have tested the accuracy of these kernels by training neural networks: an MLP on MNIST and IMDB sentiment analysis using attention (the attention run sometimes collapses to 50% on the 2 classes because the attention implementation still has some flaws, but it also reaches 82%). The results are within about 0.5% accuracy of their standard fp32 implementations.

With that, I need help with two things:

  • How can I load into shared memory in a way that avoids bank conflicts both when writing and when reading (when I load the fragments)?
  • What am I missing in the pingpong schedule that makes it slower than its standard counterpart?

One trick to avoid bank conflicts at both write and read time is not to transfer just one element per thread at a time, but several in a for loop. Then it is possible to order the threads' accesses to shared memory differently between the write side and the read side.
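
A complementary, commonly used technique is to pad (skew) the shared-memory row stride so that consecutive tile rows start in different banks; load_matrix_sync accepts the padded stride as its leading dimension. A minimal single-warp sketch, assuming B = OC = 16 and C a multiple of 16 (the names and sizes are illustrative, not your kernel):

#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

constexpr int WMMA_T  = 16;
constexpr int SMEM_LD = WMMA_T + 8;   // padded leading dimension, in halves (16 extra bytes per row)

// One warp computes a single 16x16 tile of out = x * w^T.
__global__ void padded_smem_wmma_demo(const float *x, const float *w, float *out, int C) {
  __shared__ __align__(32) __half x_smem[WMMA_T * SMEM_LD];
  __shared__ __align__(32) __half w_smem[WMMA_T * SMEM_LD];
  __shared__ __align__(32) float  y_smem[WMMA_T * WMMA_T];

  wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> x_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> w_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> y_frag;
  wmma::fill_fragment(y_frag, 0.0f);

  const int lane = threadIdx.x;                     // launched with a single 32-thread warp
  for (int tile = 0; tile < C; tile += WMMA_T) {
    // Stage one 16x16 tile of x and w; each lane writes 8 elements per matrix,
    // using the padded stride on the shared-memory side.
    for (int i = 0; i < 8; ++i) {
      int idx = i * 32 + lane;
      int row = idx / WMMA_T, col = idx % WMMA_T;
      x_smem[row * SMEM_LD + col] = __float2half(x[row * C + tile + col]);
      w_smem[row * SMEM_LD + col] = __float2half(w[row * C + tile + col]);
    }
    __syncwarp();

    // The same padded stride is passed when the fragments are read back.
    wmma::load_matrix_sync(x_frag, x_smem, SMEM_LD);
    wmma::load_matrix_sync(w_frag, w_smem, SMEM_LD);
    wmma::mma_sync(y_frag, x_frag, w_frag, y_frag);
    __syncwarp();
  }

  wmma::store_matrix_sync(y_smem, y_frag, WMMA_T, wmma::mem_row_major);
  __syncwarp();
  for (int i = 0; i < 8; ++i) {
    int idx = i * 32 + lane;
    out[idx] = y_smem[idx];                         // out is 16x16, row-major
  }
}

With several warps per block the same skew applies to each warp's tile, at the cost of a bit more shared memory; the skew amount (kept a multiple of 16 bytes for __half, as load_matrix_sync requires) and XOR swizzling of the column index are further knobs to experiment with.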

To get matrix multiplications really fast, a lot of optimizations have to be made. WMMA, compared to MMA, is the simpler interface: it mixes shared-memory accesses with the actual computation, so I would not expect to get the same speed as cublas out of it, let alone better.

Perhaps you want to achieve too much at once. cublas embodies a lot of techniques, which one can learn and master, and even afterwards you would only start to optimize your program and try out different variations.