Matrix transpose performance profile explanation

Hey, I’m trying to profile some community matrix transpose implementations to better understand the metrics in ncu. Both the input and output matrices are stored in row-major format (and are bound to PyTorch tensors).

Part I

Impl A

The first implementation is:

__global__ void mat_transpose_f32_col2row_kernel(
  float *x, float *y, const int row, const int col) {
  const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int global_row = global_idx / col;
  const int global_col = global_idx % col;
  const int in_idx = global_idx;
  const int out_idx = global_col * row + global_row;
  if (global_idx < row * col) {
    y[out_idx] = x[in_idx];
  }
}

This kernel does coalesced reads on the input matrix but strided writes on the output matrix.

Impl B

__global__ void mat_transpose_f32_row2col_kernel(
  float *x, float *y, const int row, const int col) {
  const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int global_col = global_idx / row;
  const int global_row = global_idx % row;
  const int in_idx = global_row * col + global_col;
  const int out_idx = global_idx;
  if (global_idx < row * col) {
    y[out_idx] = x[in_idx];
  }
}

The second kernel does the opposite: strided reads on the input matrix and coalesced writes on the output matrix.
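For reference, I launch both kernels with one thread per element over a 1D grid, roughly like the sketch below (the wrapper name and the block size of 256 are illustrative placeholders, not my exact launch code):

// Illustrative launcher sketch for Impl A (Impl B is launched the same way).
// The wrapper name and block size are placeholders, not the exact code.
void mat_transpose_f32_col2row(torch::Tensor x, torch::Tensor y) {
  const int row = x.size(0);
  const int col = x.size(1);
  const int n = row * col;
  const int block = 256;                     // illustrative block size
  const int grid = (n + block - 1) / block;  // one thread per element
  mat_transpose_f32_col2row_kernel<<<grid, block>>>(
      x.data_ptr<float>(), y.data_ptr<float>(), row, col);
  CUDA_CHECK(cudaGetLastError());
}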

Description

I expected them to have similar performance (although both are suboptimal); however, Impl B is notably faster than Impl A (about 2x faster).

DRAM

Looking at the DRAM section of the profile, I saw something weird. Both kernels read 4 MiB from DRAM, but Impl B (strided reads) achieved a higher throughput than Impl A (in the comparison below, Impl A is the Baseline and Impl B is the Current):

So my first question (Q1) is: why do the strided reads here yield higher DRAM throughput than the coalesced reads?

I also tried to interpret the L2 and L1 cache statistics.

L2

As for the L2 cache, it’s understandable that strided reads/writes lead to more sector accesses and lower request efficiency. One metric I cannot understand is L2 Fabric Total: why would Impl A lead to cache misses in the L2 partitions, and how would that affect overall performance?

L1

Comparing the L1 and L2 cache statistics, I expected global_load_sectors * (1 - hit_rate) == l1_load_sectors (see the red and blue boxes), but there still seems to be a gap between them. Meanwhile, Impl B’s L1 cache hit rate drops dramatically compared with Impl A. Why does it behave like this?

Part II

Impl C

I also tried to implement Impl A using 2D indices (Impl C), still doing coalesced row reads. I expected Impl A and Impl C to have the same performance, or even to compile to the same SASS, since the mapping between threads and elements is the same. However, I observed a similar issue as in Part I. Could you please provide some insights?

__global__ void mat_transpose_f32_col2row2d_kernel(
  float *x, float *y, const int row, const int col) {
  const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
  const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
  int in_idx = global_y * col + global_x;
  int out_idx = global_x * row + global_y;
  if (global_x < col && global_y < row) {
    y[out_idx] = x[in_idx];
  }
}
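The 2D launch is roughly along these lines (the 16x16 block shape and the wrapper name are placeholders for illustration, not necessarily the configuration I actually used):

// Illustrative 2D launcher sketch for Impl C.
// Block shape (16x16) and wrapper name are placeholders, not the exact code.
void mat_transpose_f32_col2row2d(torch::Tensor x, torch::Tensor y) {
  const int row = x.size(0);
  const int col = x.size(1);
  dim3 block(16, 16);  // placeholder block shape
  dim3 grid((col + block.x - 1) / block.x, (row + block.y - 1) / block.y);
  mat_transpose_f32_col2row2d_kernel<<<grid, block>>>(
      x.data_ptr<float>(), y.data_ptr<float>(), row, col);
  CUDA_CHECK(cudaGetLastError());
}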

Impl D

Just out of curiosity, I also implemented the above using CuTe (without using any advanced features), and none of the above issues appear; the kernel is even faster than plain CUDA C. How is this possible, and does NVCC provide special optimizations for CuTe?

template <typename T, int BLK_M, int BLK_N, typename ThreadLayoutA,
          typename ThreadLayoutB>
__global__ void mat_transpose_cute_reg_kernel(const T *pA, T *pB, int M, int N,
                                              ThreadLayoutA tA,
                                              ThreadLayoutB tB) {
  int tx = threadIdx.x;
  int bx = blockIdx.x, by = blockIdx.y;

  auto mA =
      make_tensor(make_gmem_ptr(pA),
                  make_layout(make_shape(M, N), GenRowMajor{}));  // (M, N)
  auto mB =
      make_tensor(make_gmem_ptr(pB),
                  make_layout(make_shape(N, M), GenRowMajor{}));  // (N, M)

  auto gA = local_tile(mA, make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
                       make_coord(bx, by));  // (BM, BN)
  auto gB = local_tile(mB, make_shape(Int<BLK_N>{}, Int<BLK_M>{}),
                       make_coord(by, bx));  // (BN, BM)
  auto cA = local_tile(make_identity_tensor(mA.shape()),
                       make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
                       make_coord(bx, by));  // (BM, BN)

  // Partition the input tile, output tile, and coordinate tile across the
  // threads of the block.
  Tensor tAgA = local_partition(gA, tA, tx);
  Tensor tBgB = local_partition(gB, tB, tx);
  Tensor tAcA = local_partition(cA, tA, tx);

  // Build a boundary predicate from the global coordinates of each element.
  Tensor tApA = make_tensor<bool>(tAcA.shape(), tAcA.stride());
  CUTE_UNROLL
  for (int i = 0; i < size<0>(tApA); i++) {
    CUTE_UNROLL
    for (int j = 0; j < size<1>(tApA); j++) {
      tApA(i, j) = get<0>(tAcA(i, j)) < M && get<1>(tAcA(i, j)) < N;
    }
  }
  // Predicated copy from the input tile to the output tile.
  copy_if(tApA, tAgA, tBgB);
}

void mat_transpose_cute_row2col_reg(torch::Tensor x, torch::Tensor y) {
  const int BM = UNIT_BLK_SIZE;
  const int BN = UNIT_BLK_SIZE;
  const int M = x.size(0);
  const int N = x.size(1);
  auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenColMajor{});
  auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenRowMajor{});
  static_assert(size(tA) == size(tB));
  dim3 block(size(tA));
  dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
  mat_transpose_cute_reg_kernel<float, BM, BN, decltype(tA), decltype(tB)>
      <<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA, tB);
  CUDA_CHECK(cudaGetLastError());
}

void mat_transpose_cute_col2row_reg(torch::Tensor x, torch::Tensor y) {
  const int BM = UNIT_BLK_SIZE;
  const int BN = UNIT_BLK_SIZE;
  const int M = x.size(0);
  const int N = x.size(1);
  auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{});
  auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{});
  static_assert(size(tA) == size(tB));
  dim3 block(size(tA));
  dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
  mat_transpose_cute_reg_kernel<float, BM, BN, decltype(tA), decltype(tB)>
      <<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA, tB);
  CUDA_CHECK(cudaGetLastError());
}

Profile and code

Sorry if my questions seem dumb or don’t provide enough detail. The original code and profiles are here in case you need them:
code_and_profile.zip (4.2 MB)

@Robert_Crovella Any insights?

I guess you are calling this from torch. Anyway, I don’t see a main() function or anything I can test. I don’t doubt that this could be tested from torch or Python, but I won’t have the time to spend on that.

I’m guessing you are on an A100 GPU, in which case a 4 MB working set is going to fit in the L2 cache. Because of that, we can expect (and your results suggest) that the data can be loaded once, without worrying about eviction.

With that in mind, a general principle with memory-bound codes is that to the extent possible, you want to get global loads kicked off in the memory pipe as soon as possible.

Questions I have for you: Which of your implementations (load row-wise col2row vs. load column-wise row2col) does that? All of the input data needs to be loaded; which one gets the load(s) scheduled as early as possible? Also, if throughput is defined as data loaded per unit time, could scheduling all loads as early as possible perhaps increase the throughput?

That last question is a bit unfair, because it presumes an outcome. However, if the load mechanism results in higher throughput for loading, then, on a memory-bound code, we can presume that results in a shorter duration. And the shorter duration is what causes the measured throughput to be higher. The answer to the higher-throughput question is the same as the answer to the performance question. We should not assume that higher throughput, the way nsight defines it, is the reason for the higher performance; rather, I would reverse those from a causality standpoint. But if we leave aside the nsight definition, and instead imagine that getting the loads scheduled as early as possible results in the earliest possible retrieval of the data, then that is also a definition of higher throughput, and I think it is possible that this viewpoint could be part of the explanation for why the uncoalesced-load version is faster.

All of my musings are predicated on the idea that the working set fits in the L2 cache. For a working set that is much much larger than the L2 cache, things might be different.

As near as I can tell for Part II, you’re saying, "I reformulated the indexing, but the row/column load directions are still the same, and I observed the same thing: loading column-wise (uncoalesced) was faster than loading row-wise." If that is a correct interpretation, then the fact that your Part II observations match your Part I observations is not surprising to me.

@Robert_Crovella Thanks for your wonderful explanation!

Sorry, I forgot to attach the kernel-launching code; you can find it here: code_and_profile_updated.zip (4.2 MB)

Which of your implementations (load row-wise col2row vs. load column-wise row2col) does that? All of the input data needs to be loaded; which one gets the load(s) scheduled as early as possible? Also, if throughput is defined as data loaded per unit time, could scheduling all loads as early as possible perhaps increase the throughput?

That logically makes sense to me: the kernel’s finish time is determined by the last STG, which depends on the preceding LDG. Thus, the kernel that kicks off its LDGs as early as possible should finish earlier.

But here is what I saw from the SASS comparison:

It seems that col2row issues the load instruction earlier (I guess NVCC did something here, because I think both kernels should otherwise be the same), which contradicts the intuition.

You can’t just look at the ordering of the instructions themselves. You also have to consider which sectors a particular instruction is fetching, and how those fetches are triggered program-wise. This is not just an instruction-ordering issue. The column-wise load triggers the load of many sectors per warp. The row-wise (coalesced) load triggers the load of relatively few sectors per warp. You also need to consider this in light of when all the LDG instructions actually issue over the entire duration of the program. That’s not easy to do with the profiler, and therefore my comments are speculative.
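To make "sectors per warp" concrete, here is a quick back-of-the-envelope sketch (ordinary host code, not profiler output) that counts how many 32-byte sectors one warp of 4-byte loads touches; the 2048-element row stride is just an example:

// Count the distinct 32-byte sectors touched by one warp of 32 threads,
// each loading one 4-byte float, for a given stride between adjacent lanes.
// Purely illustrative; not taken from the profiler.
#include <cstdio>
#include <set>

int sectors_touched(long stride_in_floats) {
  std::set<long> sectors;
  for (int lane = 0; lane < 32; lane++) {
    long byte_addr = lane * stride_in_floats * 4;  // 4 bytes per float
    sectors.insert(byte_addr / 32);                // 32-byte sectors
  }
  return (int)sectors.size();
}

int main() {
  // Coalesced (row-wise) access: adjacent lanes read adjacent floats.
  printf("row-wise   : %d sectors per warp\n", sectors_touched(1));     // 4
  // Column-wise access: adjacent lanes are one row apart (e.g. 2048 floats).
  printf("column-wise: %d sectors per warp\n", sectors_touched(2048));  // 32
}

So each column-wise load request pulls in roughly 8x as many sectors per warp as a row-wise one.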

@Robert_Crovella I see, it seems like the analysis will be much more complex if I try to figure out every detail of this. Thanks for your explanation and patience!

I also measured about a 2:1 kernel duration ratio on an L4 GPU for a 2k x 2k matrix transpose using your first two kernels (16 MB matrix size, 50 MB L2 cache). I did my measurement with nsys. AFAIK the profilers invalidate the caches at the start of a kernel launch (by default; this can be modified in Nsight Compute), which means that global loads will be an important factor.

You could start to prove or disprove my theory either by measuring in the profiler (ncu) with cache invalidation disabled, or by doing a fairly careful experiment with ordinary host-based timing where you “invalidate” the L2 cache by running (or not running) an intervening kernel that loads a different patch of memory. These could be interesting data points.

If my theory holds water, then my expectation is that with cache invalidation, the global load is going to be an important factor, leading to the 2:1 perf ratio. If the cache invalidation doesn’t happen, and the transpose kernel starts with the L2 cache fully populated with the input matrix, I’m expecting a noticeably smaller difference in performance between the two kernels. I haven’t tried it yet, though.

It seems that with invalidated caches, the performance of both kernels is similar:

# cat t372.cu
#include <iostream>
#include <time.h>
#include <sys/time.h>
#define USECPSEC 1000000ULL

unsigned long long dtime_usec(unsigned long long start=0){

    timeval tv;
    gettimeofday(&tv, 0);
    return ((tv.tv_sec*USECPSEC)+tv.tv_usec)-start;
}

__global__ void mat_transpose_f32_col2row_kernel(
  float *x, float *y, const int row, const int col) {
  const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int global_row = global_idx / col;
  const int global_col = global_idx % col;
  const int in_idx = global_idx;
  const int out_idx = global_col * row + global_row;
  if (global_idx < row * col) {
    y[out_idx] = x[in_idx];
  }
}

__global__ void mat_transpose_f32_row2col_kernel(
  float *x, float *y, const int row, const int col) {
  const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int global_col = global_idx / row;
  const int global_row = global_idx % row;
  const int in_idx = global_row * col + global_col;
  const int out_idx = global_idx;
  if (global_idx < row * col) {
    y[out_idx] = x[in_idx];
  }
}

void cput(float *i, float *o, int dim){
  for (int l = 0; l<dim; l++)
    for (int m = 0; m < dim; m++)
      o[m*dim+l] = i[l*dim+m];
}

int main(){

  const int dim = 2048;
  const int s = dim*dim;
  const int sz = sizeof(float)*s;
  float *i, *o, *d_i, *d_o, *r, *d_ti, *d_to;
  i = new float[s];
  o = new float[s];
  r = new float[s];
  cudaMalloc(&d_i, sz);
  cudaMalloc(&d_o, sz);
  cudaMalloc(&d_ti, 4*sz);
  cudaMalloc(&d_to, 4*sz);
  for (int l = 0; l < s; l++) i[l] = (float)(l+1);
  cudaMemcpy(d_i, i, sz, cudaMemcpyHostToDevice);
  unsigned long long dt = dtime_usec(0);
  mat_transpose_f32_row2col_kernel<<<s/512, 512>>>(d_i, d_o, dim, dim);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "row2col: " << dt << "us" << std::endl;
  dt = dtime_usec(0);
  mat_transpose_f32_col2row_kernel<<<s/512, 512>>>(d_i, d_o, dim, dim);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "col2row: " << dt << "us" << std::endl;
  cudaMemcpy(o, d_o, sz, cudaMemcpyDeviceToHost);
  dt = dtime_usec(0);
  mat_transpose_f32_row2col_kernel<<<s/512, 512>>>(d_i, d_o, dim, dim);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "row2col: " << dt << "us" << std::endl;
  dt = dtime_usec(0);
  mat_transpose_f32_col2row_kernel<<<s/512, 512>>>(d_i, d_o, dim, dim);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "col2row: " << dt << "us" << std::endl;
  mat_transpose_f32_row2col_kernel<<<(s*4)/512, 512>>>(d_ti, d_to, dim*2, dim*2);
  cudaDeviceSynchronize();
  dt = dtime_usec(0);
  mat_transpose_f32_row2col_kernel<<<s/512, 512>>>(d_i, d_o, dim, dim);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "row2col: " << dt << "us" << std::endl;
  mat_transpose_f32_row2col_kernel<<<(s*4)/512, 512>>>(d_ti, d_to, dim*2, dim*2);
  cudaDeviceSynchronize();
  dt = dtime_usec(0);
  mat_transpose_f32_col2row_kernel<<<s/512, 512>>>(d_i, d_o, dim, dim);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "col2row: " << dt << "us" << std::endl;
  cput(i, r, dim);
  for (int l = 0; l < s; l++) if (r[l] != o[l]) {std::cout << "oops" << std::endl; return 0;}
  cudaDeviceSynchronize();
}
# nvcc -o t372 t372.cu -arch=sm_89
# ./t372
row2col: 368us
col2row: 230us
row2col: 114us
col2row: 209us
row2col: 223us
col2row: 209us
#

So the perf discrepancy seems associated with the cached case. My previous speculation should be discounted. I don’t have an immediate explanation.

(L4, CUDA 12.2, 50MB L2 cache)