Matrix multiplication: two kernels with nearly identical assembly but different performance?

Hello.

I’m writing a matrix multiplication kernel, mostly based on what I learned from this NVIDIA blog and this blog.

I tried to match the performance of the warp-tiled kernel from the latter blog, but did not succeed. To find out what was wrong, I deliberately modified my own code until the SASS of the two kernels was essentially identical (apart from the constant-memory offsets).

Yet on my device (RTX 3060 12 GB), my kernel is still about 10% slower than the reference kernel on 4096x4096 matrices.

This is my code after the modification:

#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>   // exit, EXIT_FAILURE
#include <cassert>

constexpr uint BLOCK_WORK_X = 128;
constexpr uint BLOCK_WORK_Y = 128;
constexpr uint THREAD_WORK_X = 8;
constexpr uint THREAD_WORK_Y = 8;
constexpr uint WARP_WORK_X = 32;
constexpr uint WARP_WORK_Y = 64;
constexpr uint WARP_X = 4;
constexpr uint WARP_Y = 8;
constexpr uint K = 8;
constexpr uint BLOCK_SIZE = 256;

// Abort with a message if a CUDA API call fails.
#define CUDA_CHECK(err) do { \
    cudaError_t result = (err); \
    if (result != cudaSuccess) { \
        fprintf(stderr, "CUDA Error: %s in %s at line %d\n", \
                cudaGetErrorString(result), __FILE__, __LINE__); \
        exit(EXIT_FAILURE); \
    } \
} while (0)

__global__  __launch_bounds__(256) void matmul_general(float *A, float *B, float *C, int m, int k, int n) {
    uint tidx = threadIdx.x;
    const uint threadRow = (tidx % 32) / 4;
    const uint threadCol = (tidx % 32) % 4;

    const uint warpRow = (tidx / 32) / (BLOCK_WORK_X / WARP_WORK_X);
    const uint warpCol = (tidx / 32) % (BLOCK_WORK_X / WARP_WORK_X);

    const uint loadRowA = tidx / (K / 4);
    const uint loadColA = tidx % (K / 4);

    const uint loadRowB = tidx / (BLOCK_WORK_X / 4);
    const uint loadColB = tidx % (BLOCK_WORK_X / 4);
    
    const uint gridRow = blockIdx.x;
    const uint gridCol = blockIdx.y;


    A += gridRow * BLOCK_WORK_Y * k;
    B += gridCol * BLOCK_WORK_X;
    C += (warpRow * WARP_WORK_Y + gridRow * BLOCK_WORK_Y) * n + gridCol * BLOCK_WORK_X + warpCol * WARP_WORK_X;

    float regA[THREAD_WORK_Y] = {0}, regB[THREAD_WORK_X] = {0};
    float res[THREAD_WORK_Y][THREAD_WORK_X] = {0};

    float __shared__ sA[K * BLOCK_WORK_Y], sB[K * BLOCK_WORK_X];

    for (uint i = 0; i < k; i += K) {
        float4 loadA = reinterpret_cast<float4*>(&A[loadRowA * k + loadColA * 4])[0];
        // float4 loadA = A4[(loadRowA) * k / 4 + loadColA];
        // Transpose the A tile while storing it to shared memory.
        sA[(loadColA * 4 + 0) * BLOCK_WORK_Y + loadRowA] = loadA.x;
        sA[(loadColA * 4 + 1) * BLOCK_WORK_Y + loadRowA] = loadA.y;
        sA[(loadColA * 4 + 2) * BLOCK_WORK_Y + loadRowA] = loadA.z;
        sA[(loadColA * 4 + 3) * BLOCK_WORK_Y + loadRowA] = loadA.w;

        reinterpret_cast<float4*>(&sB[loadRowB * BLOCK_WORK_X + loadColB * 4])[0] =
            reinterpret_cast<float4*>(&B[loadRowB * n + loadColB * 4])[0];

        __syncthreads();

        for (uint p = 0; p < K; p++) {
            for (uint t = 0; t < 2; t++)
                for (uint y = 0; y < 4; y++) {
                    regA[t * 4 + y] = sA[p * BLOCK_WORK_Y  + warpRow * WARP_WORK_Y + t * 32 + threadRow * 4 + y];
                }

            for (uint t = 0; t < 2; t++)
                for (uint x = 0; x < 4; x++) {
                    regB[t * 4 + x] = sB[p * BLOCK_WORK_X + warpCol * WARP_WORK_X + t * 16 + threadCol * 4 + x];
                }
                

            for (uint t1 = 0; t1 < 2; t1++)
                for (uint t2 = 0; t2 < 2; t2++)
                    for (uint y = 0; y < 4; y++) {
                        for (uint x = 0; x < 4; x++) {
                            res[t1 * 4 + y][t2 * 4 + x] += regB[t2 * 4 + x] * regA[t1 * 4 + y];
                        }
                    }

        }
         
        A += K;
        B += K * n;
        __syncthreads();
    }

    for (uint t1 = 0; t1 < 2; t1++)
        for (uint t2 = 0; t2 < 2; t2++) {
            float* Ct = C + (t1 * 32) * n + (t2 * 16);
            for (uint y = 0; y < 4; y++) {
                for (uint x = 0; x < 4; x += 4) {
                    reinterpret_cast<float4*>(&Ct[(threadRow * 4 + y) * n + (threadCol * 4 +  x)])[0] = 
                        make_float4(
                            res[y + t1 * 4][t2 * 4 + 0],
                            res[y + t1 * 4][t2 * 4 + 1],
                            res[y + t1 * 4][t2 * 4 + 2],
                            res[y + t1 * 4][t2 * 4 + 3]
                        );
                }
            }
        }
}

And here is the code I am comparing against, taken from the blog:

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda_runtime.h>

// Abort with a message if a CUDA API call fails.
#define CUDA_CHECK(err) do { \
    cudaError_t result = (err); \
    if (result != cudaSuccess) { \
        fprintf(stderr, "CUDA Error: %s in %s at line %d\n", \
                cudaGetErrorString(result), __FILE__, __LINE__); \
        exit(EXIT_FAILURE); \
    } \
} while (0)

#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
const int WARPSIZE = 32; // warpSize is not constexpr

/*
 * @tparam BM The threadblock size for M dimension SMEM caching.
 * @tparam BN The threadblock size for N dimension SMEM caching.
 * @tparam BK The threadblock size for K dimension SMEM caching.
 * @tparam WM M dim of continuous tile computed by each warp
 * @tparam WN N dim of continuous tile computed by each warp
 * @tparam WMITER The number of subwarp tiling steps in M dimension.
 * @tparam WNITER The number of subwarp tiling steps in N dimension.
 * @tparam TM The per-thread tile size for M dimension.
 * @tparam TN The per-thread tile size for N dimension.
 */

const uint NUM_THREADS = 256;

__global__ void __launch_bounds__(NUM_THREADS)
    sgemmWarptiling(int M, int N, int K,float *A, float *B,
                   float *C) {
  const uint BN = 128;
  const uint BM = 128;
  const uint BK = 8;
  const uint WN = 32;
  const uint WM = 64;
  const uint WNITER = 2;
  const uint TN = 4;
  const uint TM = 4;

  const uint cRow = blockIdx.y;
  const uint cCol = blockIdx.x;

  // Placement of the warp in the threadblock tile
  const uint warpIdx = threadIdx.x / WARPSIZE; // the warp this thread is in
  const uint warpCol = warpIdx % (BN / WN);
  const uint warpRow = warpIdx / (BN / WN);

  // size of the warp subtile
  constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
  constexpr uint WSUBM = WM / WMITER; // 64/2=32
  constexpr uint WSUBN = WN / WNITER; // 32/2=16

  // Placement of the thread in the warp subtile
  const uint threadIdxInWarp = threadIdx.x % WARPSIZE;         // [0, 31]
  const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); // i%(16/4)
  const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); // i/4

  // allocate space for the current blocktile in SMEM
  __shared__ float As[BK * BM];
  __shared__ float Bs[BK * BN];

  // Move blocktile to beginning of A's row and B's column
  A += cRow * BM * K;
  B += cCol * BN;
  // Move C_ptr to warp's output tile
  C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN;

  // calculating the indices that this thread will load into SMEM
  // we'll load 128bit / 32bit = 4 elements per thread at each step
  const uint innerRowA = threadIdx.x / (BK / 4);
  const uint innerColA = threadIdx.x % (BK / 4);
  constexpr uint rowStrideA = (NUM_THREADS * 4) / BK;
  const uint innerRowB = threadIdx.x / (BN / 4);
  const uint innerColB = threadIdx.x % (BN / 4);
  constexpr uint rowStrideB = NUM_THREADS / (BN / 4);

  // allocate thread-local cache for results in registerfile
  float threadResults[WMITER * TM * WNITER * TN] = {0.0};
  // we cache into registers on the warptile level
  float regM[WMITER * TM] = {0.0};
  float regN[WNITER * TN] = {0.0};

  // outer-most loop over block tiles
  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
    // for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
      float4 tmp = reinterpret_cast<float4 *>(
          &A[(innerRowA + 0) * K + innerColA * 4])[0];
      // transpose A while storing it
      As[(innerColA * 4 + 0) * BM + innerRowA ] = tmp.x;
      As[(innerColA * 4 + 1) * BM + innerRowA ] = tmp.y;
      As[(innerColA * 4 + 2) * BM + innerRowA ] = tmp.z;
      As[(innerColA * 4 + 3) * BM + innerRowA ] = tmp.w;
    // }

    // for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) {
      reinterpret_cast<float4 *>(
          &Bs[(innerRowB + 0) * BN + innerColB * 4])[0] =
          reinterpret_cast<float4 *>(
              &B[(innerRowB + 0) * N + innerColB * 4])[0];
    // }
    __syncthreads();

    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
      // populate registers for whole warptile
      for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
        for (uint i = 0; i < TM; ++i) {
          regM[wSubRowIdx * TM + i] =
              As[dotIdx * BM + warpRow * WM + wSubRowIdx * WSUBM +
                 threadRowInWarp * TM + i];
        }
      }
      for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
        for (uint i = 0; i < TN; ++i) {
          regN[wSubColIdx * TN + i] =
              Bs[(dotIdx) * BN + warpCol * WN + wSubColIdx * WSUBN +
                 threadColInWarp * TN + i];
        }
      }

      // execute warptile matmul
      for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
        for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
          // calculate per-thread results
          for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
            for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
              threadResults[(wSubRowIdx * TM + resIdxM) * WNITER * TN +
                            (wSubColIdx * TN) + resIdxN] +=
                  regM[wSubRowIdx * TM + resIdxM] *
                  regN[wSubColIdx * TN + resIdxN];
            }
          }
        }
      }
    }
    A += BK;     // move BK columns to right
    B += BK * N; // move BK rows down
    __syncthreads();
  }

  // write out the results
  for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
    for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
      // move C pointer to current warp subtile
      float *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN;
      for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
        for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
          // load C vector into registers
          float4 tmp;
          const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) +
                        wSubColIdx * TN + resIdxN;
          tmp.x = threadResults[i + 0];
          tmp.y = threadResults[i + 1];
          tmp.z = threadResults[i + 2];
          tmp.w = threadResults[i + 3];
          // write back
          reinterpret_cast<float4 *>(
              &C_interim[(threadRowInWarp * TM + resIdxM) * N +
                         threadColInWarp * TN + resIdxN])[0] = tmp;
        }
      }
    }
  }
}

Profiling shows that my kernel has a worse L2 hit rate, even though the assembly of the two kernels is nearly identical. Could you help me verify this behaviour on your device? If you know what is going on, please tell me. Thank you in advance!

I found the issue: I accidentally swapped blockIdx.x and blockIdx.y:

const uint gridRow = blockIdx.x;
const uint gridCol = blockIdx.y;

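where the row index should come from y and the column index from x. That means block (x, y) previously computed the output tile at row x, column y instead of the tile at row y, column x. The results are still correct for square matrices, because the grid is then square and every output tile is still computed exactly once.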
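Why the swap went unnoticed can be checked with a small host-side C++ sketch. It mirrors the two index mappings and tests whether every output tile is produced exactly once. The helper names (`tileForBlock`, `coversAllTiles`) and the assumed launch layout (gridDim.x counts column tiles, gridDim.y counts row tiles) are illustrative assumptions, not code from the thread.

```cpp
#include <cassert>
#include <vector>

// Output tile (row, col) computed by one thread block.
struct Tile { unsigned row, col; };

// Mirrors gridRow/gridCol in the kernel; swapped == true reproduces the bug.
Tile tileForBlock(unsigned bx, unsigned by, bool swapped) {
    return swapped ? Tile{bx, by}   // gridRow = blockIdx.x, gridCol = blockIdx.y
                   : Tile{by, bx};  // gridRow = blockIdx.y, gridCol = blockIdx.x
}

// Assumed launch: gridDim.x = colTiles, gridDim.y = rowTiles.
// Returns true if each of the rowTiles x colTiles output tiles is computed exactly once.
bool coversAllTiles(unsigned rowTiles, unsigned colTiles, bool swapped) {
    std::vector<int> hits(rowTiles * colTiles, 0);
    for (unsigned by = 0; by < rowTiles; ++by)
        for (unsigned bx = 0; bx < colTiles; ++bx) {
            Tile t = tileForBlock(bx, by, swapped);
            if (t.row >= rowTiles || t.col >= colTiles) return false;  // tile out of range
            ++hits[t.row * colTiles + t.col];
        }
    for (int h : hits)
        if (h != 1) return false;  // some tile missed or computed twice
    return true;
}
```

For a square 32x32 grid (4096x4096 with 128x128 tiles) the swapped mapping is just a transpose of the block-to-tile assignment, so it still covers every tile; for a non-square grid it goes out of range.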

The question now is why this makes the kernel slower. I observed that blocks are launched in row-major order, i.e. blockIdx.x varies fastest. In the correct kernel, consecutively launched blocks therefore share the same blockIdx.y and load the same row panel of A, which is likely to still be in L2. In the wrong kernel, consecutively launched blocks instead share the same column panel of B.
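The reuse pattern described above can be sketched on the host. Assuming blocks are dispatched in launch order with blockIdx.x varying fastest (the order observed in the post; CUDA does not guarantee any dispatch order), this counts how many distinct panels of A and B a window of consecutively launched blocks reads. The function `panelsInWindow` and its parameters are hypothetical.

```cpp
#include <cassert>
#include <set>

struct PanelCounts { unsigned aPanels, bPanels; };

// Assumption: linear launch index = blockIdx.y * gridX + blockIdx.x.
// gridRow selects the 128-row panel of A, gridCol the 128-column panel of B.
PanelCounts panelsInWindow(unsigned gridX, unsigned window, bool swapped) {
    std::set<unsigned> aRows, bCols;
    for (unsigned linear = 0; linear < window; ++linear) {
        unsigned bx = linear % gridX;
        unsigned by = linear / gridX;
        unsigned gridRow = swapped ? bx : by;  // buggy kernel uses blockIdx.x
        unsigned gridCol = swapped ? by : bx;
        aRows.insert(gridRow);
        bCols.insert(gridCol);
    }
    return {static_cast<unsigned>(aRows.size()), static_cast<unsigned>(bCols.size())};
}
```

With gridX = 32 and a window of 16 blocks, the correct mapping reads 1 panel of A and 16 of B, while the swapped mapping reads 16 panels of A and 1 of B.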

But loading from A is more costly: each row of the A tile contributes only 8 consecutive floats (32 bytes), so a warp's load is split into many separate 32-byte transactions. Loading from B, on the other hand, is perfectly coalesced. My hypothesis is that L2 hits on many small transactions (from A) are worth more than L2 hits on fewer large ones (from B), because then fewer of the transactions that actually have to fetch data from global memory are issued.
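To make the transaction argument concrete, here is a host-side C++ sketch that counts the distinct 32-byte sectors and 128-byte cache lines touched by one warp's float4 load in the two tile loads. It assumes 4-byte floats, a row length of 4096 elements, a 128-byte-aligned base address, and lanes 0..31 of one warp; `warpFootprint` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

struct Footprint { std::size_t sectors, lines; };

// elemsPerRow: consecutive floats loaded per tile row
// (8 for the A tile, 128 for the B tile in the kernels above).
// ld: row length of the global matrix in elements.
Footprint warpFootprint(std::size_t elemsPerRow, std::size_t ld) {
    const std::size_t threadsPerRow = elemsPerRow / 4;  // one float4 per thread
    std::set<std::size_t> sectors, lines;
    for (std::size_t lane = 0; lane < 32; ++lane) {
        std::size_t row = lane / threadsPerRow;
        std::size_t col = (lane % threadsPerRow) * 4;
        std::size_t byteAddr = (row * ld + col) * sizeof(float);
        // A 16-byte aligned float4 never straddles a 32-byte sector.
        sectors.insert(byteAddr / 32);
        lines.insert(byteAddr / 128);
    }
    return {sectors.size(), lines.size()};
}
```

Under these assumptions both loads touch 16 sectors per warp; the difference is that the A load spreads them over 16 cache lines, while the B load packs them into 4.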

To test this, I changed the load from A to be perfectly coalesced, ignoring the correctness of the result:

reinterpret_cast<float4*>(&sA[(loadRowB) * BLOCK_WORK_X + loadColB * 4])[0] = 
reinterpret_cast<const float4*>(&A[loadRowB * n + loadColB * 4])[0];
      
reinterpret_cast<float4*>(&sB[(loadRowB) * BLOCK_WORK_X + loadColB * 4])[0] = 
reinterpret_cast<const float4*>(&B[loadRowB * n + loadColB * 4])[0];

This time, the correct and the wrong kernel perform the same. Is this a plausible explanation for the difference?

On recent architectures, coalescing at 32-byte sector granularity already gives you the full bandwidth. Full 128-byte accesses only have an additional advantage for the cache working set (roughly, how much useful data can be stored in the cache), and they influence performance if the next data is prefetched (e.g. 64 bytes are loaded for each 32-byte access).

But shouldn't there still be some overhead for a higher number of transactions? I set up a benchmark that copies one matrix to another. In one kernel, each block copies an 8x128 submatrix; in the other, each block copies a 128x8 submatrix. The first kernel is still slightly faster (around 1%). Keep in mind that in the matmul kernel, each load has to wait for the previous load and computation to complete, so the slight overhead might add up because of this dependency. I don’t currently know how to verify this in the actual matmul kernel.

Yes, there can be a 1% difference between 32-byte and 128-byte accesses. In your first post, you reported about a 10% difference (before correcting the swapped blockIdx.x and blockIdx.y). If you used 16-byte instead of 32-byte accesses, it would be half (!) as fast (unless the caches can compensate). So 1% is already a fine-grained optimization IMHO. But 1% could also come from a specific access pattern or, as you suspected, from the order in which the work is assigned.

I mean that the 1% is purely the difference between 32-byte and 128-byte operations. In the matmul kernels, loading from matrix A uses 32-byte transactions and loading from B uses 128-byte transactions. In the corrected kernel, the loads from A hit in L2 more often, so the kernel issues fewer of the “more expensive” (by 1%) loads to global memory. That 1% may then translate into 10% because of the dependency between successive loads. This is my reasoning behind the difference. What do you think?

If, e.g., memory transactions are one part of a kernel and two variants of them differ by 1%, then the difference for the whole kernel is at most 1%.
(It is related to the so-called Amdahl’s Law.)
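Written out, the bound is simple arithmetic: if a fraction f of the runtime is spent in a part that becomes s times slower, the new runtime relative to the old one is (1 - f) + f * s, which can never exceed s. A minimal C++ sketch (the helper name `relativeRuntime` is hypothetical):

```cpp
#include <cassert>
#include <cmath>

// f: fraction of total runtime spent in the affected part (0..1).
// s: slowdown factor of that part (1.01 means 1% slower).
// Returns new total runtime divided by old total runtime.
double relativeRuntime(double f, double s) {
    return (1.0 - f) + f * s;
}
```

Even if the whole runtime were spent in the 1%-slower loads (f = 1), the kernel would be only 1% slower; dependencies between the loads do not multiply that factor.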

Try using Nsight Compute to find the actual reason for the difference in execution time. It shows what the threads are waiting for, the throughput and hit rate of every cache level, and the transaction sizes.

Actually, there are several dependent load operations. Because they are dependent, that 1% will add up to a larger figure.

Nsight Compute (ncu) shows that the wrong kernel has a worse L2 read hit rate (43% vs. 48%), more stalls on the global load from A and fewer on the global load from B, but more stalls overall.

If the read hit rate is 43% vs. 48%, the difference is much more than 1%. If, for example, an L2 hit is 4x faster than a global memory access, the difference is about 6%: (.43 / 4 + .57) / (.48 / 4 + .52) = 0.6775 / 0.64 ≈ 1.059.
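The same back-of-the-envelope model as a C++ sketch, so the L2-to-DRAM speed ratio can be varied. Assumption (taken from the reply): a miss costs 1 time unit and an L2 hit costs 1 / l2Speedup units; the result covers only memory-access time, not the whole kernel. Both function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Average cost per access: hits are l2Speedup times cheaper than misses.
double avgAccessCost(double hitRate, double l2Speedup) {
    return hitRate / l2Speedup + (1.0 - hitRate);
}

// Relative memory-access time of the variant with the lower hit rate.
double slowdown(double hitWrong, double hitRight, double l2Speedup) {
    return avgAccessCost(hitWrong, l2Speedup) / avgAccessCost(hitRight, l2Speedup);
}
```

With hit rates of 0.43 and 0.48, this gives about 1.059 for a 4x ratio, about 1.033 for 2x and about 1.020 for 1.5x, so the estimate shrinks quickly as the L2 advantage gets smaller.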

That the computation instructions are dependent should not matter. If the kernel is memory-bound, memory bandwidth is the deciding factor: as soon as new data arrives from memory, the compute units stop stalling.

I don’t think the kernel is memory-bound; in fact, its memory throughput is only around 100 GB/s, well below the RTX 3060’s theoretical peak of about 360 GB/s. But you are right that the dependency should not matter, my mistake.
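A rough arithmetic-intensity estimate supports the "not memory-bound" reading. This C++ sketch counts the bytes the tile loads request from L2/DRAM (before any cache hits) and the FLOPs performed, assuming square N x N fp32 matrices, BM = BN = 128, BK = 8 as in both kernels, N divisible by 128, and ignoring the final store of C. `tiledSgemmTraffic` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>

struct Traffic { std::uint64_t bytes, flops; };

Traffic tiledSgemmTraffic(std::uint64_t N) {
    const std::uint64_t BM = 128, BN = 128, BK = 8;
    const std::uint64_t blocks = (N / BM) * (N / BN);
    const std::uint64_t kSteps = N / BK;
    // Per k-step, each block loads a BM x BK tile of A and a BK x BN tile of B.
    const std::uint64_t bytesPerStep = (BM * BK + BK * BN) * sizeof(float);
    // ...and performs BM * BN * BK multiply-adds, counted as 2 flops each.
    const std::uint64_t flopsPerStep = 2 * BM * BN * BK;
    return {blocks * kSteps * bytesPerStep, blocks * kSteps * flopsPerStep};
}
```

For N = 4096, this is 4 GiB of requested tile data for about 137 GFLOP, i.e. 32 FLOP per requested byte; L2 hits reduce the share that actually reaches DRAM even further.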

According to this paper, L2 on my GPU is only about 1.5x faster than global memory; maybe we can assume 2x if we account for cache misses on global loads. So the difference in global-load cost should be smaller than what you calculated, and L2 misses should not be the sole reason for the slowdown. Or am I missing something?

So perhaps your kernel is not bandwidth-bound, and the limiting factor is memory latency rather than bandwidth?