Hey, I’m trying to profile some community matrix-transpose implementations to better understand the metrics in ncu. Both the input and the output matrix are stored in row-major format (and bound to PyTorch tensors).
Part I
Impl A
The first implementation is:
__global__ void mat_transpose_f32_col2row_kernel(
    float *x, float *y, const int row, const int col) {
  const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int global_row = global_idx / col;
  const int global_col = global_idx % col;
  const int in_idx = global_idx;                      // coalesced read (row-major input)
  const int out_idx = global_col * row + global_row;  // strided write, stride = row
  if (global_idx < row * col) {
    y[out_idx] = x[in_idx];
  }
}
This version does coalesced reads on the input matrix but strided writes on the output matrix.
Impl B
__global__ void mat_transpose_f32_row2col_kernel(
    float *x, float *y, const int row, const int col) {
  const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int global_col = global_idx / row;
  const int global_row = global_idx % row;
  const int in_idx = global_row * col + global_col;  // strided read, stride = col
  const int out_idx = global_idx;                    // coalesced write (row-major output)
  if (global_idx < row * col) {
    y[out_idx] = x[in_idx];
  }
}
The second does it the opposite way: strided reads on the input matrix and coalesced writes on the output matrix.
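For completeness, both kernels are launched one thread per element over a flat 1-D grid. Below is a minimal launcher sketch; the 256-thread block size and the wrapper name are assumptions I'm making for this post, not necessarily what the original benchmark uses.
// Launch sketch for Impl A (Impl B is launched the same way).
// Assumption: 256 threads per block, one thread per matrix element.
void mat_transpose_f32_col2row(torch::Tensor x, torch::Tensor y) {
  const int row = x.size(0);
  const int col = x.size(1);
  const int n = row * col;
  const int block = 256;                     // assumed block size
  const int grid = (n + block - 1) / block;  // enough blocks to cover all elements
  mat_transpose_f32_col2row_kernel<<<grid, block>>>(
      x.data_ptr<float>(), y.data_ptr<float>(), row, col);
  CUDA_CHECK(cudaGetLastError());
}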
Description
I expected them to have similar performance (although both are suboptimal); however, Impl B is notably faster than Impl A (about 2x).
DRAM
Looking at the DRAM profile, I saw something weird. Both kernels read 4 MiB from DRAM, but Impl B (strided reads) achieved higher DRAM throughput than Impl A (see below; Impl A is the Baseline and Impl B is the Current report):
So my Q1 is: why did strided reads yield higher DRAM throughput here than coalesced reads?
I also tried to interpret the L2 and L1 cache statistics.
L2
As for the L2 cache, it's understandable that strided reads/writes lead to more sectors per request and lower request efficiency. One metric I cannot understand is L2 Fabric Total: why would Impl A cause misses in the L2 partitions, and how would that affect overall performance?
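To make the sectors-per-request point concrete, here is a small host-side model I wrote just for this post (not part of the benchmark) that counts how many distinct 32-byte sectors a single 32-thread warp touches for a contiguous versus a strided float access; the 1024x1024 shape is my assumption, chosen to match the 4 MiB of data.
#include <cstdio>
#include <set>

// Count the distinct 32-byte sectors touched by one warp (32 threads),
// each thread accessing one 4-byte float, with a given stride in elements.
int sectors_per_warp(long base_elem, long stride_elems) {
  std::set<long> sectors;
  for (int lane = 0; lane < 32; ++lane) {
    long byte_addr = (base_elem + lane * stride_elems) * 4;  // 4 bytes per float
    sectors.insert(byte_addr / 32);                          // 32-byte sector index
  }
  return static_cast<int>(sectors.size());
}

int main() {
  const long stride = 1024;  // assumed row/col length (1024x1024 floats = 4 MiB)
  // Impl A: contiguous reads (stride 1), writes strided by `row` elements.
  // Impl B: reads strided by `col` elements, contiguous writes.
  printf("contiguous access: %d sectors per warp\n", sectors_per_warp(0, 1));       // -> 4
  printf("stride-%ld access: %d sectors per warp\n", stride, sectors_per_warp(0, stride));  // -> 32
  return 0;
}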
L1
Comparing the L1 and L2 cache statistics, I expected global_load_sectors * (1 - hit_rate) == l1_load_sectors (check the red and blue boxes), but there is still a gap between them. Meanwhile, Impl B's L1 cache hit rate drops dramatically compared with Impl A's. Why does it behave like this?
Part II
Impl C
I also tried to implement Impl A using 2-D indices (Impl C), again with coalesced row reads. I expected Impl A and Impl C to have the same performance, or even the same SASS code, since the mapping between threads and elements is the same. However, I see a similar issue as in Part I. Could you please provide some insights? (The launch configuration I use is sketched after the kernel.)
__global__ void mat_transpose_f32_col2row2d_kernel(
    float *x, float *y, const int row, const int col) {
  const int global_x = blockIdx.x * blockDim.x + threadIdx.x;  // column index
  const int global_y = blockIdx.y * blockDim.y + threadIdx.y;  // row index
  int in_idx = global_y * col + global_x;   // coalesced read along a row
  int out_idx = global_x * row + global_y;  // strided write, stride = row
  if (global_x < col && global_y < row) {
    y[out_idx] = x[in_idx];
  }
}
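For reference, this is roughly how I launch Impl C. The 32x8 block shape is an assumption for illustration only (it may differ from the original benchmark); blockDim.x runs along the columns, so it determines how many consecutive elements a warp reads.
// Launch sketch for Impl C. Assumption: 2-D blocks with blockDim.x along the
// columns; the 32x8 shape is illustrative, not necessarily the original setting.
void mat_transpose_f32_col2row2d(torch::Tensor x, torch::Tensor y) {
  const int row = x.size(0);
  const int col = x.size(1);
  dim3 block(32, 8);
  dim3 grid((col + block.x - 1) / block.x,
            (row + block.y - 1) / block.y);
  mat_transpose_f32_col2row2d_kernel<<<grid, block>>>(
      x.data_ptr<float>(), y.data_ptr<float>(), row, col);
  CUDA_CHECK(cudaGetLastError());
}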
Impl D
Just out of curiosity, I also implemented the above using CuTe (without any advanced features), and none of the above issues appear; the kernel is even faster than the plain CUDA C version. How is this possible, and does NVCC provide special optimizations for CuTe?
template <typename T, int BLK_M, int BLK_N, typename ThreadLayoutA,
          typename ThreadLayoutB>
__global__ void mat_transpose_cute_reg_kernel(const T *pA, T *pB, int M, int N,
                                              ThreadLayoutA tA,
                                              ThreadLayoutB tB) {
  int tx = threadIdx.x;
  int bx = blockIdx.x, by = blockIdx.y;
  auto mA =
      make_tensor(make_gmem_ptr(pA),
                  make_layout(make_shape(M, N), GenRowMajor{}));  // (M, N)
  auto mB =
      make_tensor(make_gmem_ptr(pB),
                  make_layout(make_shape(N, M), GenRowMajor{}));  // (N, M)
  auto gA = local_tile(mA, make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
                       make_coord(bx, by));  // (BM, BN)
  auto gB = local_tile(mB, make_shape(Int<BLK_N>{}, Int<BLK_M>{}),
                       make_coord(by, bx));  // (BN, BM)
  auto cA = local_tile(make_identity_tensor(mA.shape()),
                       make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
                       make_coord(bx, by));  // (BM, BN)
  Tensor tAgA = local_partition(gA, tA, tx);
  Tensor tBgB = local_partition(gB, tB, tx);
  Tensor tAcA = local_partition(cA, tA, tx);
  Tensor tApA = make_tensor<bool>(tAcA.shape(), tAcA.stride());
  CUTE_UNROLL
  for (int i = 0; i < size<0>(tApA); i++) {
    CUTE_UNROLL
    for (int j = 0; j < size<1>(tApA); j++) {
      tApA(i, j) = get<0>(tAcA(i, j)) < M && get<1>(tAcA(i, j)) < N;
    }
  }
  copy_if(tApA, tAgA, tBgB);
}

void mat_transpose_cute_row2col_reg(torch::Tensor x, torch::Tensor y) {
  const int BM = UNIT_BLK_SIZE;
  const int BN = UNIT_BLK_SIZE;
  const int M = x.size(0);
  const int N = x.size(1);
  auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenColMajor{});
  auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenRowMajor{});
  static_assert(size(tA) == size(tB));
  dim3 block(size(tA));
  dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
  mat_transpose_cute_reg_kernel<float, BM, BN, decltype(tA), decltype(tB)>
      <<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA, tB);
  CUDA_CHECK(cudaGetLastError());
}

void mat_transpose_cute_col2row_reg(torch::Tensor x, torch::Tensor y) {
  const int BM = UNIT_BLK_SIZE;
  const int BN = UNIT_BLK_SIZE;
  const int M = x.size(0);
  const int N = x.size(1);
  auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{});
  auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{});
  static_assert(size(tA) == size(tB));
  dim3 block(size(tA));
  dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
  mat_transpose_cute_reg_kernel<float, BM, BN, decltype(tA), decltype(tB)>
      <<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA, tB);
  CUDA_CHECK(cudaGetLastError());
}
Profile and code
Sorry if my questions seem dumb or don't provide enough detail. The original code and profiles are attached in case you need them:
code_and_profile.zip (4.2 MB)