2x slower kernel if the inner dimension is divisible by 16/32

Hi all,

I wrote a simple kernel that, for each row of matrix A (of shape MxK), sums the K elements of that row and stores the result in every column of the corresponding row of matrix C (of shape MxN).

__global__ void simple_kernel(float* A, float* C, const int M, const int N, const int K) {
    const int row = threadIdx.x + blockDim.x * blockIdx.x;
    const int col = threadIdx.y + blockDim.y * blockIdx.y;
    if (row < M && col < N) {
        float acc = 0.0;
        for (int p = 0; p < K; p++) {
            acc += A[row * K + p];
        }
        C[row * N + col] = acc;
    }
}

However, I’ve observed that the kernel’s performance drops significantly when the inner dimension K is divisible by 32 or 16, with a 2x slowdown in the former case and a 1.5x slowdown in the latter. I understand that the memory access pattern in the kernel is not optimal and lacks coalescing, but I’m curious why the divisibility of the inner dimension K has such a notable impact on performance. Despite reviewing the CUDA Programming Guide and Kernel Profiling Guide, I haven’t found a clear explanation for this phenomenon. The issue isn’t related to shared memory bank conflicts, as the kernel doesn’t use shared memory. It also doesn’t seem to be explained by tile or wave quantization effects (please correct me if I’m mistaken). After profiling the kernel, I noticed that when K is divisible by 32 there is 2x-3x more data transfer both between L1 and L2 and between L2 and device memory.
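To make the access pattern concrete, here is a small host-side sketch of my own (illustration only, not part of the benchmark). With the 16x16 block used below, one warp reads 16 distinct rows of A at each inner-loop iteration p, and those reads are spaced K * sizeof(float) bytes apart, so whether K is divisible by 32 (i.e. whether the row stride is a multiple of the 128-byte cache line) decides whether all 16 reads share the same alignment relative to lines and 32-byte sectors:

#include <cstdio>

// Illustration only: byte offsets into A read by the 16 distinct rows covered
// by one warp of a 16x16 block (blockIdx = (0,0)) at inner-loop iteration p.
int main() {
    const int p = 0;
    const int Ks[2] = {3200, 3201};   // KDIM from the code below vs. KDIM + 1
    for (int i = 0; i < 2; ++i) {
        const int K = Ks[i];
        printf("K = %d (row stride = %d bytes)\n", K, K * 4);
        for (int row = 0; row < 16; ++row) {
            long off = (long)(row * K + p) * 4;
            printf("  row %2d: byte offset %8ld, offset in 128B line %3ld, sector %ld\n",
                   row, off, off % 128, (off % 128) / 32);
        }
    }
    return 0;
}

This only visualizes the alignment; it doesn’t by itself explain the extra traffic.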


Still I don’t understand why this happens. I’m eager to deepen my understanding of CUDA and NVIDIA hardware. Any insights would be greatly appreciated! Thanks in advance!

The full code is listed below and was tested on RTX 3090:

#include <cstdlib>
#include <stdio.h>

#define NITER 1
#define BLOCKSIZE 16
#define MDIM BLOCKSIZE * 200
#define NDIM BLOCKSIZE * 200
// the kernel is 2-3x faster if KDIM not divisible by 16/32
// change KDIM to (BLOCKSIZE*200 + 1/2/3/4/5/6/7/8/9/10/11/12/13/14/15) to see the effect
#define KDIM BLOCKSIZE * 200

inline int cdiv(const int a, const int b) { return (a + b - 1) / b; }

void init_random(float* data, int size) {
    for (int i = 0; i < size; ++i)
        data[i] = rand() / (float)RAND_MAX;
}

void init_const(float* data, const int size, const float value) {
    for (int i = 0; i < size; ++i)
        data[i] = value;
}

__global__ void simple_kernel(float* A, float* C, const int M, const int N, const int K) {
    const int row = threadIdx.x + blockDim.x * blockIdx.x;
    const int col = threadIdx.y + blockDim.y * blockIdx.y;
    if (row < M && col < N) {
        float acc = 0.0;
        for (int p = 0; p < K; p++) {
            acc += A[row * K + p];
        }
        C[row * N + col] = acc;
    }
}

int main() {
    const int M = MDIM, N = NDIM, K = KDIM;
    const int A_nelem = M * K;
    const int C_nelem = M * N;
    const int A_memsize = sizeof(float) * A_nelem;
    const int C_memsize = sizeof(float) * C_nelem;

    float *A_host, *C_host;
    cudaMallocHost(&A_host, A_memsize);
    cudaMallocHost(&C_host, C_memsize);
    init_random(A_host, A_nelem);
    init_const(C_host, C_nelem, 0.0);

    float *A_device, *C_device;
    cudaMalloc(&A_device, A_memsize);
    cudaMalloc(&C_device, C_memsize);
    cudaMemcpy(A_device, A_host, A_memsize, cudaMemcpyHostToDevice);
    cudaMemcpy(C_device, C_host, C_memsize, cudaMemcpyHostToDevice);

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    dim3 threads(BLOCKSIZE, BLOCKSIZE);
    dim3 grid(cdiv(M, BLOCKSIZE), cdiv(N, BLOCKSIZE));

    float avg_elapsed_time_ms = 0.0;
    float min_elapsed_time_ms = 1e30f;
    float max_elapsed_time_ms = 0.0;
    float elapsed_time_ms;

    for (int i = 0; i < NITER; i++) {
        cudaEventRecord(start);

        simple_kernel<<<grid, threads>>>(A_device, C_device, M, N, K);

        cudaEventRecord(stop);
        cudaEventSynchronize(stop);
        cudaEventElapsedTime(&elapsed_time_ms, start, stop);
        min_elapsed_time_ms =
            elapsed_time_ms < min_elapsed_time_ms ? elapsed_time_ms : min_elapsed_time_ms;
        max_elapsed_time_ms =
            elapsed_time_ms > max_elapsed_time_ms ? elapsed_time_ms : max_elapsed_time_ms;
        avg_elapsed_time_ms += elapsed_time_ms;
    }

    avg_elapsed_time_ms = avg_elapsed_time_ms / (float)NITER;
    const double FLOP = 2 * (double)M * (double)N * (double)K;
    const double GFLOP = FLOP * 1e-9f;
    const double GFLOPS_AVG = GFLOP / (avg_elapsed_time_ms * 1e-3);
    const double GFLOPS_MIN = GFLOP / (max_elapsed_time_ms * 1e-3);
    const double GFLOPS_MAX = GFLOP / (min_elapsed_time_ms * 1e-3);
    printf("AVG GFLOPS = %.2f\n", GFLOPS_AVG);
    printf("MAX GFLOPS = %.2f\n", GFLOPS_MAX);
    printf("MIN GFLOPS = %.2f\n", GFLOPS_MIN);
    printf("AVG exec. time = %.2fms\n", avg_elapsed_time_ms);
    printf("MAX exec. time = %.2fms\n", max_elapsed_time_ms);
    printf("MIN exec. time = %.2fms\n", min_elapsed_time_ms);

    cudaFreeHost(A_host);
    cudaFreeHost(C_host);
    cudaFree(A_device);
    cudaFree(C_device);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return 0;
}
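It can be built and run with something like this (the file name colsum.cu is arbitrary; sm_86 corresponds to the RTX 3090):

nvcc -O3 -arch=sm_86 -o colsum colsum.cu
./colsum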

Now look for possible causes of that. I would expect to find higher cache miss rates. If so, what kind of misses occur? Of the three classes (compulsory misses, capacity misses, and conflict misses), the last one is the most likely here. If so, compare the kernel’s memory access pattern with the cache characteristics (in particular set associativity, line length, sectoring, and replacement policy).
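To make that comparison concrete, here is a toy decomposition of a global-memory byte address under an assumed cache geometry. The 128-byte line with four 32-byte sectors is public for recent GPUs; the set count and the plain modulo indexing are assumptions, since NVIDIA does not document them. Plugging in the addresses your kernel generates for different K, p, and candidate set counts is the kind of cross-check I mean:

#include <cstdio>

const long LINE_BYTES   = 128;   // documented for recent GPUs
const long SECTOR_BYTES = 32;    // documented for recent GPUs
const long NUM_SETS     = 64;    // assumed, for illustration only

// Where a byte address would land in a classic set-associative cache
// with simple modulo set indexing (the indexing scheme is also assumed).
void decompose(long byte_addr) {
    long line   = byte_addr / LINE_BYTES;
    long sector = (byte_addr % LINE_BYTES) / SECTOR_BYTES;
    long set    = line % NUM_SETS;
    printf("addr %8ld -> line %6ld, sector %ld, set %2ld\n", byte_addr, line, sector, set);
}

int main() {
    const int K = 3200, p = 0;            // KDIM from the original code, first iteration
    for (long row = 0; row < 16; ++row)   // the 16 rows read by one warp of a 16x16 block
        decompose((row * K + p) * 4L);
    return 0;
}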

This kind of memory hierarchy analysis isn’t in any way specific to CUDA and GPUs.

@njuffa thanks for your answer! Yes, I see 2x-3x higher cache miss rates in both L1 and L2. Please correct me if I’m mistaken, but NVIDIA doesn’t officially provide information about cache internals such as set associativity, sectoring, or the replacement policy. Do you by any chance have some good resources on this topic to read?

NVIDIA provides some of that information. You may need to hunt for it in the official docs, then try to get the rest of it from whitepapers published by third parties, which are based on running micro-benchmarks to tease out that information. Some additional information may be exposed through NVIDIA’s patent applications (check with your legal department before reading patents).

I am not claiming that this is an ideal state of affairs.

@njuffa you also mentioned different cache miss classes: compulsory, capacity and conflict. Is there a profiler that can classify cache misses?

Don’t know; off-hand it seems unlikely that a profiler based on sampling HW event counters could distinguish these, but I have never built a profiler. For GPUs, I usually stop at findings like “more data transfers” and have long given up on trying to develop a deep understanding of the hardware.

You could ask for further advice in the sub-forum(s) dedicated to the Nsight profiler. That is where the profiler experts are.

@njuffa Perhaps you could offer dangreen some best-practice advice on how to address, or at least test for, each of the cache miss classes, as you seem to have better insight into cache details than I do:

E.g. one class might vanish with different data stride sizes, another with a smaller overall working set, …
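A minimal sketch of the stride test (my own, under the assumption that padding the row stride of A while keeping the logical K fixed isolates stride/alignment effects from the amount of data read; the lda parameter is my addition):

__global__ void strided_kernel(const float* A, float* C,
                               int M, int N, int K, int lda) {
    const int row = threadIdx.x + blockDim.x * blockIdx.x;
    const int col = threadIdx.y + blockDim.y * blockIdx.y;
    if (row < M && col < N) {
        float acc = 0.0f;
        for (int p = 0; p < K; p++)
            acc += A[row * lda + p];   // row stride is lda, not K
        C[row * N + col] = acc;
    }
}

int main() {
    const int M = 3200, N = 3200, K = 3200, pad = 8;          // pad = 8 floats = 32 bytes
    float *A, *C;
    cudaMalloc(&A, (size_t)M * (K + pad) * sizeof(float));    // allocate with the padded stride
    cudaMalloc(&C, (size_t)M * N * sizeof(float));
    dim3 threads(16, 16);
    dim3 grid((M + 15) / 16, (N + 15) / 16);
    strided_kernel<<<grid, threads>>>(A, C, M, N, K, K);        // original stride
    strided_kernel<<<grid, threads>>>(A, C, M, N, K, K + pad);  // padded stride, same K
    cudaDeviceSynchronize();
    return 0;
}

If the K-divisible-by-32 penalty disappears with the padded stride (e.g. when timing both launches with nsys as above), that would point toward conflict-like behavior rather than the sheer amount of data read.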

My memory is too hazy to provide recommendations. The last time I dealt with deep internals of memory sub-systems was when I was involved with building x86 processors at AMD, almost 25 years ago. We did a lot of modelling at various levels of fidelity to the hardware, and models are great for extracting details of cache miss behavior.

NVIDIA GPUs are a particularly thankless target for trying to figure out what is going on inside the hardware, as so little information is made publicly available. Other than the fact that on recent architectures cache lines are 128 bytes comprising four 32-byte sectors, I am not aware of any published caching details.

When I needed to deal with memory optimizations while working at NVIDIA, I would contact one particular senior DevTech engineer who had spent a year working in the hardware architecture team and therefore was an expert on these matters. NVIDIA’s phenomenal DevTech organization has experts on just about any aspect of GPU compute performance, but outside of GPU Developer Conferences that is a resource that is mostly accessible to companies and research institutions, not individual CUDA programmers.

M=N=256, CUDA 12.2, L4 GPU

                                   Kernel duration (us)
threadblock dims                    K=256                K=255
                        L1:    enabled   disabled   enabled   disabled
      (32,32)                      261        594        76        607
      (16,16)                       82        232        29        234
      (32, 4)                      148        451        47        452
      ( 4,32)                       20         55        13         55

test case:

# cat t237.cu
template <int W>
__global__ void simple_kernel(float* A, float* C, const int M, const int N, const int K) {
    const int row = threadIdx.x + blockDim.x * blockIdx.x;
    const int col = threadIdx.y + blockDim.y * blockIdx.y;
    if (row < M && col < N) {
        float acc = 0.0;
        for (int p = 0; p < K; p++) {
            acc += A[row * K + p];
        }
        C[row * N + col] = acc;
    }
}

const int my_M = 256;
const int my_K = 256;
const int my_N = 256;

int main(){

  float *A, *C;
  cudaMalloc(&A, my_M*my_K*sizeof(A[0]));
  cudaMalloc(&C, my_M*my_N*sizeof(C[0]));

  {
    dim3 block(32,32);
    dim3 grid((my_M+block.x-1)/block.x, (my_N+block.y-1)/block.y);
    simple_kernel<0><<<grid,block>>>(A, C, my_M, my_N, my_K);
    simple_kernel<1><<<grid,block>>>(A, C, my_M, my_N, my_K-1);
  }
  {
    dim3 block(16,16);
    dim3 grid((my_M+block.x-1)/block.x, (my_N+block.y-1)/block.y);
    simple_kernel<2><<<grid,block>>>(A, C, my_M, my_N, my_K);
    simple_kernel<3><<<grid,block>>>(A, C, my_M, my_N, my_K-1);
  }
  {
    dim3 block(32,4);
    dim3 grid((my_M+block.x-1)/block.x, (my_N+block.y-1)/block.y);
    simple_kernel<4><<<grid,block>>>(A, C, my_M, my_N, my_K);
    simple_kernel<5><<<grid,block>>>(A, C, my_M, my_N, my_K-1);
  }
  {
    dim3 block(4,32);
    dim3 grid((my_M+block.x-1)/block.x, (my_N+block.y-1)/block.y);
    simple_kernel<6><<<grid,block>>>(A, C, my_M, my_N, my_K);
    simple_kernel<7><<<grid,block>>>(A, C, my_M, my_N, my_K-1);
  }
  cudaDeviceSynchronize();
}
# nvcc -o t237 t237.cu
# compute-sanitizer ./t237
========= COMPUTE-SANITIZER
========= ERROR SUMMARY: 0 errors
# nsys profile --stats=true ./t237
Generating '/tmp/nsys-report-f027.qdstrm'
[1/8] [========================100%] report87.nsys-rep
[2/8] [========================100%] report87.sqlite
[3/8] Executing 'nvtx_sum' stats report
SKIPPED: /root/bobc/report87.sqlite does not contain NV Tools Extension (NVTX) data.
[4/8] Executing 'osrt_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)     Med (ns)    Min (ns)   Max (ns)    StdDev (ns)        Name
 --------  ---------------  ---------  ------------  -----------  --------  -----------  ------------  --------------
     49.8      172,597,170         11  15,690,651.8  3,029,676.0     7,117  100,130,612  29,526,036.1  poll
     49.0      169,622,884        476     356,350.6     14,957.5     1,034   79,206,725   3,648,460.3  ioctl
      0.6        2,040,196         27      75,562.8     12,693.0    10,462    1,261,226     238,299.2  mmap64
      0.2          768,760         44      17,471.8     16,549.5     7,040       25,315       4,681.2  open64
      0.1          317,206          9      35,245.1     34,460.0    25,265       57,787      10,015.4  sem_timedwait
      0.1          266,634          2     133,317.0    133,317.0   123,983      142,651      13,200.3  pthread_create
      0.1          192,860         31       6,221.3      6,191.0     2,102       18,822       3,203.7  fopen
      0.0          163,680         14      11,691.4      5,127.5     2,594       71,177      17,694.2  mmap
      0.0           91,274         48       1,901.5         67.0        58       87,907      12,678.0  fgets
      0.0           71,025         25       2,841.0      2,622.0     1,557        6,540       1,032.9  fclose
      0.0           63,357          6      10,559.5     11,642.5       340       15,857       5,345.7  fread
      0.0           60,115         53       1,134.2      1,050.0       717        4,090         451.3  fcntl
      0.0           37,747          6       6,291.2      6,397.5     2,665       10,948       2,748.2  open
      0.0           31,525          5       6,305.0      6,505.0     4,038        7,922       1,448.7  munmap
      0.0           28,279         10       2,827.9      2,637.0     1,430        5,552       1,076.6  write
      0.0           25,208         13       1,939.1      1,333.0     1,067        4,036         986.0  read
      0.0           19,050          2       9,525.0      9,525.0     6,455       12,595       4,341.6  socket
      0.0           14,562          1      14,562.0     14,562.0    14,562       14,562           0.0  connect
      0.0            9,295          1       9,295.0      9,295.0     9,295        9,295           0.0  pipe2
      0.0            6,698          7         956.9        897.0       837        1,333         171.4  dup
      0.0            2,601          1       2,601.0      2,601.0     2,601        2,601           0.0  bind
      0.0            1,498          1       1,498.0      1,498.0     1,498        1,498           0.0  listen

[5/8] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)   Max (ns)     StdDev (ns)            Name
 --------  ---------------  ---------  ------------  ------------  --------  -----------  -------------  ----------------------
     99.4      186,226,467          2  93,113,233.5  93,113,233.5     6,131  186,220,336  131,673,327.1  cudaMalloc
      0.4          668,160          8      83,520.0      25,185.0    20,648      491,815      165,050.1  cudaLaunchKernel
      0.3          512,357          1     512,357.0     512,357.0   512,357      512,357            0.0  cudaDeviceSynchronize
      0.0            1,849          1       1,849.0       1,849.0     1,849        1,849            0.0  cuModuleGetLoadingMode

[6/8] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                             Name
 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  -----------------------------------------------------------
     38.6          261,408          1  261,408.0  261,408.0   261,408   261,408          0.0  void simple_kernel<(int)0>(float *, float *, int, int, int)
     21.9          147,840          1  147,840.0  147,840.0   147,840   147,840          0.0  void simple_kernel<(int)4>(float *, float *, int, int, int)
     12.2           82,400          1   82,400.0   82,400.0    82,400    82,400          0.0  void simple_kernel<(int)2>(float *, float *, int, int, int)
     11.2           76,096          1   76,096.0   76,096.0    76,096    76,096          0.0  void simple_kernel<(int)1>(float *, float *, int, int, int)
      6.9           46,592          1   46,592.0   46,592.0    46,592    46,592          0.0  void simple_kernel<(int)5>(float *, float *, int, int, int)
      4.3           29,312          1   29,312.0   29,312.0    29,312    29,312          0.0  void simple_kernel<(int)3>(float *, float *, int, int, int)
      3.0           20,353          1   20,353.0   20,353.0    20,353    20,353          0.0  void simple_kernel<(int)6>(float *, float *, int, int, int)
      1.8           12,512          1   12,512.0   12,512.0    12,512    12,512          0.0  void simple_kernel<(int)7>(float *, float *, int, int, int)

[7/8] Executing 'cuda_gpu_mem_time_sum' stats report
SKIPPED: /root/bobc/report87.sqlite does not contain GPU memory data.
[8/8] Executing 'cuda_gpu_mem_size_sum' stats report
SKIPPED: /root/bobc/report87.sqlite does not contain GPU memory data.
Generated:
    /root/bobc/report87.nsys-rep
    /root/bobc/report87.sqlite
# nvcc -o t237 t237.cu -Xptxas -dlcm=cg
# nsys profile --stats=true ./t237
Generating '/tmp/nsys-report-1ccb.qdstrm'
[1/8] [========================100%] report88.nsys-rep
[2/8] [========================100%] report88.sqlite
[3/8] Executing 'nvtx_sum' stats report
SKIPPED: /root/bobc/report88.sqlite does not contain NV Tools Extension (NVTX) data.
[4/8] Executing 'osrt_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)   Max (ns)    StdDev (ns)        Name
 --------  ---------------  ---------  ------------  ------------  --------  -----------  ------------  --------------
     55.8      300,206,377         13  23,092,798.2  10,376,152.0     7,537  100,135,268  35,695,659.7  poll
     39.8      214,011,861        476     449,604.8      15,295.0     1,025   79,363,354   4,202,963.9  ioctl
      3.6       19,515,763         33     591,386.8       6,262.0     2,020   19,228,085   3,345,628.3  fopen
      0.4        2,023,610         27      74,948.5      12,756.0    10,590    1,246,837     235,535.9  mmap64
      0.1          789,119         44      17,934.5      16,668.5     6,918       28,945       4,560.0  open64
      0.1          432,213          9      48,023.7      42,368.0    36,214       80,272      13,599.1  sem_timedwait
      0.1          273,863          2     136,931.5     136,931.5   115,264      158,599      30,642.5  pthread_create
      0.0          162,936         14      11,638.3       5,317.5     2,385       71,835      17,814.2  mmap
      0.0          102,595         26       3,946.0       3,403.5     1,512       13,175       2,563.6  fclose
      0.0           84,344         48       1,757.2          65.0        58       81,052      11,688.7  fgets
      0.0           67,908         55       1,234.7       1,040.0       728        4,808         719.5  fcntl
      0.0           50,675          4      12,668.8       9,316.5       490       31,552      14,099.8  fwrite
      0.0           43,345          6       7,224.2       6,738.5     2,548       11,776       3,558.2  open
      0.0           32,648         13       2,511.4       2,027.0     1,136        5,857       1,308.7  read
      0.0           32,634          6       5,439.0       5,659.0     2,732        8,362       2,092.1  munmap
      0.0           29,119         10       2,911.9       2,881.0     1,810        4,786         814.1  write
      0.0           28,650          5       5,730.0       5,458.0       228       16,267       6,534.9  fread
      0.0           18,589          2       9,294.5       9,294.5     5,335       13,254       5,599.6  socket
      0.0           16,088          1      16,088.0      16,088.0    16,088       16,088           0.0  connect
      0.0            9,720          1       9,720.0       9,720.0     9,720        9,720           0.0  pipe2
      0.0            6,700          7         957.1         919.0       898        1,098          77.5  dup
      0.0            3,870          1       3,870.0       3,870.0     3,870        3,870           0.0  fflush
      0.0            2,702          1       2,702.0       2,702.0     2,702        2,702           0.0  bind
      0.0            1,702          1       1,702.0       1,702.0     1,702        1,702           0.0  listen

[5/8] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)    Max (ns)     StdDev (ns)            Name
 --------  ---------------  ---------  ------------  ------------  ---------  -----------  -------------  ----------------------
     59.7      185,525,232          2  92,762,616.0  92,762,616.0      5,878  185,519,354  131,177,836.9  cudaMalloc
     39.5      122,581,749          8  15,322,718.6      21,853.5     20,248  122,418,002   43,273,029.4  cudaLaunchKernel
      0.8        2,530,117          1   2,530,117.0   2,530,117.0  2,530,117    2,530,117            0.0  cudaDeviceSynchronize
      0.0            1,845          1       1,845.0       1,845.0      1,845        1,845            0.0  cuModuleGetLoadingMode

[6/8] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                             Name
 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  -----------------------------------------------------------
     22.7          607,361          1  607,361.0  607,361.0   607,361   607,361          0.0  void simple_kernel<(int)1>(float *, float *, int, int, int)
     22.2          594,625          1  594,625.0  594,625.0   594,625   594,625          0.0  void simple_kernel<(int)0>(float *, float *, int, int, int)
     16.8          451,552          1  451,552.0  451,552.0   451,552   451,552          0.0  void simple_kernel<(int)5>(float *, float *, int, int, int)
     16.8          450,753          1  450,753.0  450,753.0   450,753   450,753          0.0  void simple_kernel<(int)4>(float *, float *, int, int, int)
      8.7          234,496          1  234,496.0  234,496.0   234,496   234,496          0.0  void simple_kernel<(int)3>(float *, float *, int, int, int)
      8.7          232,159          1  232,159.0  232,159.0   232,159   232,159          0.0  void simple_kernel<(int)2>(float *, float *, int, int, int)
      2.1           55,360          1   55,360.0   55,360.0    55,360    55,360          0.0  void simple_kernel<(int)6>(float *, float *, int, int, int)
      2.0           54,816          1   54,816.0   54,816.0    54,816    54,816          0.0  void simple_kernel<(int)7>(float *, float *, int, int, int)

[7/8] Executing 'cuda_gpu_mem_time_sum' stats report
SKIPPED: /root/bobc/report88.sqlite does not contain GPU memory data.
[8/8] Executing 'cuda_gpu_mem_size_sum' stats report
SKIPPED: /root/bobc/report88.sqlite does not contain GPU memory data.
Generated:
    /root/bobc/report88.nsys-rep
    /root/bobc/report88.sqlite
#

The “lens” of the L1 seems important. I think for this work it would be good to also think about the tail-effect, which I have not done.

L4 has 48MB of L2 cache, and the data sizes here are A: 256KB, C: 256KB. So they are well below the L2 cache size, but significantly above the L1 available in a single SM (128KB).

Hi @Robert_Crovella! Thanks for your answer! May I ask what the purpose of the template parameter <int W> is? From what I see, it’s currently unused.

It’s so I can readily distinguish which kernel is which in the profiler output.

@Robert_Crovella got it! May I ask what “lens of the L1 cache” means? Cache lines?

Yes, I agree tail effect is probably irrelevant.

I have also observed that for certain threadblock dimensions, the difference is much less than 2X:

# cat t237.cu
template <int W, bool cache=true>
__global__ void simple_kernel(const float* A, float* C, const int M, const int N, const int K) {
    const int row = threadIdx.x + blockDim.x * blockIdx.x;
    const int col = threadIdx.y + blockDim.y * blockIdx.y;
    if (row < M && col < N) {
        float acc = 0.0;
        for (int p = 0; p < K; p++) {
                if (cache)
                  acc += __ldca(A+row * K + p);
                else
                  acc += __ldcg(A+row * K + p);
        }
        __stcg(C+row * N + col, acc);
    }
}
const int nBLK = 58;
const int nTY  =  6;
const int my_M = nBLK*nTY;
const int my_K = 256;
const int my_N = 256;

int main(){

  float *A, *C;
  cudaMalloc(&A, my_M*my_K*sizeof(A[0]));
  cudaMalloc(&C, my_M*my_N*sizeof(C[0]));

  {
    dim3 block(2,256);
    dim3 grid((my_M+block.x-1)/block.x, (my_N+block.y-1)/block.y);
    simple_kernel<10><<<grid,block>>>(A, C, my_M, my_N, my_K);
    simple_kernel<11><<<grid,block>>>(A, C, my_M, my_N, my_K-1);
  }
  {
    dim3 block(2,256);
    dim3 grid((my_M+block.x-1)/block.x, (my_N+block.y-1)/block.y);
    simple_kernel<10><<<grid,block>>>(A, C, my_M, my_N, my_K);
    simple_kernel<11><<<grid,block>>>(A, C, my_M, my_N, my_K-1);
  }
  {
    dim3 block(1,256);
    dim3 grid((my_M+block.x-1)/block.x, (my_N+block.y-1)/block.y);
    simple_kernel<12><<<grid,block>>>(A, C, my_M, my_N, my_K);
    simple_kernel<13><<<grid,block>>>(A, C, my_M, my_N, my_K-1);
  }
  {
    dim3 block(1,256);
    dim3 grid((my_M+block.x-1)/block.x, (my_N+block.y-1)/block.y);
    simple_kernel<12><<<grid,block>>>(A, C, my_M, my_N, my_K);
    simple_kernel<13><<<grid,block>>>(A, C, my_M, my_N, my_K-1);
  }
  cudaDeviceSynchronize();
}
# nvcc -o t237 t237.cu -arch=sm_89
# nsys profile --stats=true ./t237
Generating '/tmp/nsys-report-cde6.qdstrm'
[1/8] [========================100%] report110.nsys-rep
[2/8] [========================100%] report110.sqlite
[3/8] Executing 'nvtx_sum' stats report
SKIPPED: /root/bobc/report110.sqlite does not contain NV Tools Extension (NVTX) data.
[4/8] Executing 'osrt_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)     Med (ns)    Min (ns)   Max (ns)    StdDev (ns)        Name
 --------  ---------------  ---------  ------------  -----------  --------  -----------  ------------  --------------
     55.7      242,341,680        476     509,121.2     15,123.5     1,028   78,839,663   4,972,162.2  ioctl
     39.5      171,656,292         11  15,605,117.5  3,092,482.0     7,984  100,137,445  29,532,787.8  poll
      3.8       16,688,380         29     575,461.4      5,732.0     2,132   16,503,382   3,063,376.7  fopen
      0.5        2,021,772         27      74,880.4     13,587.0    10,830    1,227,848     231,742.5  mmap64
      0.2          785,293         44      17,847.6     16,679.0     6,886       34,620       4,872.8  open64
      0.1          406,045          9      45,116.1     40,409.0    35,800       72,431      11,317.3  sem_timedwait
      0.1          264,262          2     132,131.0    132,131.0   117,244      147,018      21,053.4  pthread_create
      0.0          166,956         14      11,925.4      5,077.5     2,845       71,508      17,759.0  mmap
      0.0           83,401         48       1,737.5         66.5        58       80,022      11,539.8  fgets
      0.0           72,390         23       3,147.4      3,215.0     1,546        5,083         916.6  fclose
      0.0           54,942         51       1,077.3      1,050.0       726        1,837         199.3  fcntl
      0.0           41,188          6       6,864.7      6,553.0     2,622       10,872       3,101.7  open
      0.0           32,772         13       2,520.9      2,120.0     1,364        5,257       1,237.2  read
      0.0           32,553          5       6,510.6      5,861.0     3,537        9,517       2,300.6  munmap
      0.0           29,439         10       2,943.9      2,835.0     1,449        5,023       1,027.0  write
      0.0           18,770          2       9,385.0      9,385.0     5,370       13,400       5,678.1  socket
      0.0           18,477          1      18,477.0     18,477.0    18,477       18,477           0.0  fread
      0.0           15,551          1      15,551.0     15,551.0    15,551       15,551           0.0  connect
      0.0            9,587          1       9,587.0      9,587.0     9,587        9,587           0.0  pipe2
      0.0            6,481          7         925.9        912.0       812        1,048          82.1  dup
      0.0            2,636          1       2,636.0      2,636.0     2,636        2,636           0.0  bind
      0.0            1,680          1       1,680.0      1,680.0     1,680        1,680           0.0  listen

[5/8] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)   Max (ns)     StdDev (ns)            Name
 --------  ---------------  ---------  ------------  ------------  --------  -----------  -------------  ----------------------
     99.8      185,167,551          2  92,583,775.5  92,583,775.5     6,111  185,161,440  130,924,588.7  cudaMalloc
      0.2          367,995          8      45,999.4      14,902.0     5,387      261,326       87,665.0  cudaLaunchKernel
      0.0           34,169          1      34,169.0      34,169.0    34,169       34,169            0.0  cudaDeviceSynchronize
      0.0            1,772          1       1,772.0       1,772.0     1,772        1,772            0.0  cuModuleGetLoadingMode

[6/8] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                     Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ---------------------------------------------------------------------------
     28.9           35,680          2  17,840.0  17,840.0    17,024    18,656      1,154.0  void simple_kernel<(int)10, (bool)1>(const float *, float *, int, int, int)
     25.0           30,880          2  15,440.0  15,440.0    14,848    16,032        837.2  void simple_kernel<(int)11, (bool)1>(const float *, float *, int, int, int)
     23.4           28,896          2  14,448.0  14,448.0    14,368    14,528        113.1  void simple_kernel<(int)13, (bool)1>(const float *, float *, int, int, int)
     22.6           27,840          2  13,920.0  13,920.0    13,888    13,952         45.3  void simple_kernel<(int)12, (bool)1>(const float *, float *, int, int, int)

[7/8] Executing 'cuda_gpu_mem_time_sum' stats report
SKIPPED: /root/bobc/report110.sqlite does not contain GPU memory data.
[8/8] Executing 'cuda_gpu_mem_size_sum' stats report
SKIPPED: /root/bobc/report110.sqlite does not contain GPU memory data.
Generated:
    /root/bobc/report110.nsys-rep
    /root/bobc/report110.sqlite
#

I didn’t mean anything by “lens” except that the L1 is playing a role, somehow. Since the L1 is smaller than the L2, and since the data set size is larger than a single L1 size, the L1 “focuses” attention in certain areas.

I don’t have an explanation, just mentioning observations here.

@Robert_Crovella Not only does the L1 cache impact performance, but the L2 cache does as well. As you can see from my images, data transfer between the L2 cache and main memory is almost twice as large if the inner dimension is divisible by 32. It would be helpful if an NVIDIA engineer could comment on this. In any case, thank you for your answer and for your involvement!
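For anyone who wants the traffic numbers without my screenshots: Nsight Compute can report DRAM traffic per kernel; something like the following should work (the metric names can be double-checked with ncu --query-metrics, and ./colsum stands for whatever the binary is called):

ncu --metrics dram__bytes_read.sum,dram__bytes_write.sum ./colsum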