Why does my actual measured count of shared memory load/store instructions differ from the theoretical count? How can I explain and verify this difference?

"When analyzing the shared memory access instructions of my custom CUDA SGEMM kernel, I found that the load instructions number approximately 167,772,160, while the store instructions number 16,777,216. Theoretically, I expected the ratio to be 32:1, but in reality, it’s only 10:1. What could be causing this difference?

The store instruction count is what I expect; the discrepancy is in the loads.

template <unsigned int BLOCK_SIZE, unsigned int STRIDE>
__global__ void cuda_sgemm(float *A_ptr, float *B_ptr, float *C_ptr, const int M, const int N, const int K)
{
    constexpr int STEP = BLOCK_SIZE * STRIDE;  // each block computes a STEP x STEP tile of C
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    float *A_ptr_start = A_ptr + STEP * blockIdx.y * K;  // first row of this block's slice of A
    float *B_ptr_start = B_ptr + STEP * blockIdx.x;      // first column of this block's slice of B

    __shared__ float A_shared[STEP][STEP];
    __shared__ float B_shared[STEP][STEP];
    float C_value[STRIDE][STRIDE] = {0.0f};  // per-thread STRIDE x STRIDE register tile
    for (int s = 0; s < K; s += STEP)
    {
        // stage one STEP x STEP tile of A and of B into shared memory
        // (2 * STRIDE * STRIDE shared stores per thread per K-tile)
        for (int i = 0; i < STRIDE; ++i) {
            for (int j = 0; j < STRIDE; ++j) {
                A_shared[ty + i * BLOCK_SIZE][tx + j * BLOCK_SIZE] =
                    A_ptr_start[(ty + i * BLOCK_SIZE) * K + (tx + j * BLOCK_SIZE) + s];
                B_shared[ty + i * BLOCK_SIZE][tx + j * BLOCK_SIZE] =
                    B_ptr_start[(ty + i * BLOCK_SIZE + s) * N + (tx + j * BLOCK_SIZE)];
            }
        }
        __syncthreads();
        // accumulate into the register tile: as written, 2 * STRIDE * STRIDE * STEP
        // shared loads per thread per K-tile
        //#pragma unroll
        for (int i = 0; i < STRIDE; ++i) {
            for (int j = 0; j < STRIDE; ++j) {
                for (int k = 0; k < STEP; ++k) {
                    C_value[i][j] += A_shared[ty + i * BLOCK_SIZE][k] * B_shared[k][tx + j * BLOCK_SIZE];
                }
            }
        }
        __syncthreads();
    }
    float *C_ptr_start = C_ptr + N * blockIdx.y * STEP + blockIdx.x * STEP;

    // write the register tile back to global memory
    for (int i = 0; i < STRIDE; ++i) {
        for (int j = 0; j < STRIDE; ++j) {
            C_ptr_start[(ty + i * BLOCK_SIZE) * N + (tx + j * BLOCK_SIZE)] = C_value[i][j];
        }
    }
}


void MMult_v1(int m, int n, int k, float *d_A,
              float *d_B, float *d_C) {
    constexpr int BLOCK = 16;
    const int STRIDE = 2;
    dim3 block(BLOCK, BLOCK);
    dim3 grid((m + BLOCK - 1) / BLOCK / STRIDE, (n + BLOCK - 1) / BLOCK / STRIDE);
    cuda_sgemm<BLOCK, STRIDE><<<grid, block>>>(d_A, d_B, d_C, m, n, k);
}
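
For reference, here is the arithmetic behind my 32:1 expectation (BLOCK_SIZE = 16, STRIDE = 2, so STEP = 32), assuming each shared-memory reference in the source compiles to exactly one instruction. Per thread, per K-tile:

    shared stores:  2 * STRIDE * STRIDE        = 2 * 2 * 2      = 8
    shared loads:   2 * STRIDE * STRIDE * STEP = 2 * 2 * 2 * 32 = 256
    expected ratio:                              256 : 8        = 32 : 1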

It’s being affected by compiler loop unrolling (and the instruction/load optimization that unrolling affords). This is fairly typical modern compiler behavior.

According to my testing, if I place an unroll directive like this (preventing any unrolling of the k-loop):

#pragma unroll 1
            for (int k = 0; k < STEP; ++k) {
                C_value[i][j] += A_shared[ty + i * BLOCK_SIZE][k] * B_shared[k][tx + j * BLOCK_SIZE];

Then I observe the 32:1 ratio that you expect.

If I omit the directive, or alternatively fully unroll the 32-iteration k-loop with:

#pragma unroll 32
            for (int k = 0; k < STEP; ++k) {
                C_value[i][j] += A_shared[ty + i * BLOCK_SIZE][k] * B_shared[k][tx + j * BLOCK_SIZE];

Then I observe the 10:1 ratio. I don’t offer a complete analysis, but I do observe that in the unroll 1 case, the SASS includes no instances of the LDS.128 instruction.

In the unroll 32 case, the SASS includes 16 instances of the LDS.128 instruction. Each LDS.128 loads four floats (128 bits) in a single instruction, so this reduces the number of shared-load instructions to a value well below the unroll 1 case.
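
A rough accounting is consistent with both observations, assuming the full unroll lets the compiler keep each loaded value in a register and reuse it across the 2 x 2 register tile (I have not traced every instruction in the SASS to confirm this). Per thread, per K-tile:

    A_shared: STRIDE rows * STEP values = 2 * 32 = 64 floats, consecutive in k,
              vectorized into 64 / 4 = 16 LDS.128 (matching the grep output below)
    B_shared: STEP * STRIDE columns = 32 * 2 = 64 floats, each a full shared row apart,
              so they remain scalar: 64 LDS
    loads  = 16 + 64 = 80 instructions
    stores = 2 * STRIDE * STRIDE = 8 instructions
    ratio  = 80 : 8 = 10 : 1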

# cat t430.cu
template <unsigned int BLOCK_SIZE, unsigned int STRIDE>
__global__ void cuda_sgemm(float *A_ptr, float *B_ptr, float *C_ptr, const int M, const int N, const int K)
{
    constexpr int STEP = BLOCK_SIZE * STRIDE;
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    float *A_ptr_start = A_ptr + STEP * blockIdx.y * K;
    float *B_ptr_start = B_ptr + STEP * blockIdx.x;

    __shared__ float A_shared[STEP][STEP];
    __shared__ float B_shared[STEP][STEP];
    float C_value[STRIDE][STRIDE] = {0.0f};
    for (int s = 0; s < K; s += STEP)
    {
        for (int i = 0; i < STRIDE; ++i) {
            for (int j = 0; j < STRIDE; ++j) {
                A_shared[ty + i * BLOCK_SIZE][tx + j * BLOCK_SIZE] =
                    A_ptr_start[(ty + i * BLOCK_SIZE) * K + (tx + j * BLOCK_SIZE) + s];
                B_shared[ty + i * BLOCK_SIZE][tx + j * BLOCK_SIZE] =
                    B_ptr_start[(ty + i * BLOCK_SIZE + s) * N + (tx + j * BLOCK_SIZE)];
            }
        }
        __syncthreads();
        for (int i = 0; i < STRIDE; ++i) {
            for (int j = 0; j < STRIDE; ++j) {
#ifdef USE_32
        #pragma unroll 32
#else
        #pragma unroll 1
#endif
                for (int k = 0; k < STEP; ++k) {
                    C_value[i][j] += A_shared[ty + i * BLOCK_SIZE][k] * B_shared[k][tx + j * BLOCK_SIZE];
                }
            }
        }
        __syncthreads();
    }
    float *C_ptr_start = C_ptr + N * blockIdx.y * STEP + blockIdx.x * STEP;

    for (int i = 0; i < STRIDE; ++i) {
        for (int j = 0; j < STRIDE; ++j ) {
            C_ptr_start[(ty + i * BLOCK_SIZE) * N + (tx + j * BLOCK_SIZE)] = C_value[i][j];
        }
    }

}


void MMult_v1(int m, int n, int k, float *d_A,
              float *d_B, float *d_C) {
    constexpr int BLOCK = 16;
    const int STRIDE = 2;
    dim3 block(BLOCK, BLOCK);
    dim3 grid((m + BLOCK - 1) / BLOCK / STRIDE, (n + BLOCK - 1) / BLOCK / STRIDE);
    cuda_sgemm<BLOCK, STRIDE><<<grid, block>>>(d_A, d_B, d_C, m, n, k);
}

int main(){

  int m = 1024;
  int n = 1024;
  int k = 1024;
  float *d_A, *d_B, *d_C;
  cudaMalloc(&d_A, sizeof(d_A[0])*m*k);
  cudaMalloc(&d_B, sizeof(d_B[0])*k*n);
  cudaMalloc(&d_C, sizeof(d_C[0])*m*n);
  MMult_v1(m, n, k, d_A, d_B, d_C);
  cudaDeviceSynchronize();
}
# nvcc -o t430 t430.cu -arch=sm_89 -lineinfo
# cuobjdump -sass ./t430 |grep LDS.128
# nvcc -o t430 t430.cu -arch=sm_89 -lineinfo -DUSE_32
# cuobjdump -sass ./t430 |grep LDS.128
        /*0550*/                   LDS.128 R8, [R21] ;                               /* 0x0000000015087984 */
        /*0570*/                   LDS.128 R4, [R21+0x800] ;                         /* 0x0008000015047984 */
        /*0620*/                   LDS.128 R16, [R21+0x10] ;                         /* 0x0000100015107984 */
        /*0650*/                   LDS.128 R12, [R21+0x810] ;                        /* 0x00081000150c7984 */
        /*07b0*/                   LDS.128 R8, [R21+0x20] ;                          /* 0x0000200015087984 */
        /*07e0*/                   LDS.128 R4, [R21+0x820] ;                         /* 0x0008200015047984 */
        /*0930*/                   LDS.128 R12, [R21+0x30] ;                         /* 0x00003000150c7984 */
        /*0950*/                   LDS.128 R16, [R21+0x830] ;                        /* 0x0008300015107984 */
        /*0b00*/                   LDS.128 R8, [R21+0x840] ;                         /* 0x0008400015087984 */
        /*0b10*/                   LDS.128 R4, [R21+0x40] ;                          /* 0x0000400015047984 */
        /*0c90*/                   LDS.128 R12, [R21+0x50] ;                         /* 0x00005000150c7984 */
        /*0cc0*/                   LDS.128 R16, [R21+0x850] ;                        /* 0x0008500015107984 */
        /*0e10*/                   LDS.128 R4, [R21+0x60] ;                          /* 0x0000600015047984 */
        /*0e30*/                   LDS.128 R8, [R21+0x860] ;                         /* 0x0008600015087984 */
        /*0ff0*/                   LDS.128 R12, [R21+0x70] ;                         /* 0x00007000150c7984 */
        /*1000*/                   LDS.128 R16, [R21+0x870] ;                        /* 0x0008700015107984 */
#
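
As a further check, the dynamically executed shared load/store instruction counts can be read directly from Nsight Compute rather than inferred from the static SASS; something along these lines should work (metric names as in recent Nsight Compute versions; adjust if yours differ):

# ncu --metrics smsp__inst_executed_op_shared_ld.sum,smsp__inst_executed_op_shared_st.sum ./t430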