Hello.
I'm writing a matrix multiplication kernel, learned mostly from this NVIDIA blog and this blog.
I tried to match the performance of the warp-tiled kernel from the latter blog, but did not succeed. To find out what was wrong, I deliberately modified my own code to the point where the SASS of the two kernels is essentially identical (apart from the indices into constant memory).
Yet on my device (RTX 3060 12GB), my kernel is still about 10% slower than the target kernel on 4096x4096 matrices.
This is my code after the modification:
#include <cuda_runtime.h>
#include <stdio.h>
#include <cassert>
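// tiling parameters: 128x128 block tile, 64 (M) x 32 (N) warp tile,
// 8x8 accumulators per thread, K-tile depth 8, 256 threads per block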
constexpr uint BLOCK_WORK_X = 128;
constexpr uint BLOCK_WORK_Y = 128;
constexpr uint THREAD_WORK_X = 8;
constexpr uint THREAD_WORK_Y = 8;
constexpr uint WARP_WORK_X = 32;
constexpr uint WARP_WORK_Y = 64;
constexpr uint WARP_X = 4;
constexpr uint WARP_Y = 8;
constexpr uint K = 8;
constexpr uint BLOCK_SIZE = 256;
#define CUDA_CHECK(err) { \
cudaError_t result = (err); \
if (result != cudaSuccess) { \
fprintf(stderr, "CUDA Error: %s in %s at line %d\n", \
cudaGetErrorString(result), __FILE__, __LINE__); \
exit(EXIT_FAILURE); \
} \
}
__global__ __launch_bounds__(256) void matmul_general(float *A, float *B, float *C, int m, int k, int n) {
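// each block has 8 warps (2 rows x 4 cols); within a warp, lanes form an
// 8 x 4 grid, and each lane accumulates an 8x8 tile of C (four 4x4 sub-tiles)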
uint tidx = threadIdx.x;
const uint threadRow = (tidx % 32) / 4;
const uint threadCol = (tidx % 32) % 4;
const uint warpRow = (tidx / 32) / (BLOCK_WORK_X / WARP_WORK_X);
const uint warpCol = (tidx / 32) % (BLOCK_WORK_X / WARP_WORK_X);
const uint loadRowA = tidx / (K / 4);
const uint loadColA = tidx % (K / 4);
const uint loadRowB = tidx / (BLOCK_WORK_X / 4);
const uint loadColB = tidx % (BLOCK_WORK_X / 4);
const uint gridRow = blockIdx.x;
const uint gridCol = blockIdx.y;
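// advance A to this block's row panel, B to its column panel, and C to this warp's output tile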
A += gridRow * BLOCK_WORK_Y * k;
B += gridCol * BLOCK_WORK_X;
C += (warpRow * WARP_WORK_Y + gridRow * BLOCK_WORK_Y) * n + gridCol * BLOCK_WORK_X + warpCol * WARP_WORK_X;
float regA[THREAD_WORK_Y] = {0}, regB[THREAD_WORK_X] = {0};
float res[THREAD_WORK_Y][THREAD_WORK_X] = {0};
__shared__ float sA[K * BLOCK_WORK_Y], sB[K * BLOCK_WORK_X];
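// sweep the shared dimension k in tiles of depth K (= 8)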
for (uint i = 0; i < k; i += K) {
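// vectorized 128-bit load from A, stored transposed into sA so the p-loop below reads it contiguously along M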
float4 loadA = reinterpret_cast<float4*>(&A[loadRowA * k + loadColA * 4])[0];
// float4 loadA = A4[(loadRowA) * k / 4 + loadColA];
sA[(loadColA * 4 + 0) * BLOCK_WORK_Y + loadRowA] = loadA.x;
sA[(loadColA * 4 + 1) * BLOCK_WORK_Y + loadRowA] = loadA.y;
sA[(loadColA * 4 + 2) * BLOCK_WORK_Y + loadRowA] = loadA.z;
sA[(loadColA * 4 + 3) * BLOCK_WORK_Y + loadRowA] = loadA.w;
reinterpret_cast<float4*>(&sB[loadRowB * BLOCK_WORK_X + loadColB * 4])[0] =
reinterpret_cast<float4*>(&B[loadRowB * n + loadColB * 4])[0];
__syncthreads();
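// for each K-slice p: gather this thread's A and B fragments into registers, then accumulate their outer product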
for (uint p = 0; p < K; p++) {
for (uint t = 0; t < 2; t++)
for (uint y = 0; y < 4; y++) {
regA[t * 4 + y] = sA[p * BLOCK_WORK_Y + warpRow * WARP_WORK_Y + t * 32 + threadRow * 4 + y];
}
for (uint t = 0; t < 2; t++)
for (uint x = 0; x < 4; x++) {
regB[t * 4 + x] = sB[p * BLOCK_WORK_X + warpCol * WARP_WORK_X + t * 16 + threadCol * 4 + x];
}
for (uint t1 = 0; t1 < 2; t1++)
for (uint t2 = 0; t2 < 2; t2++)
for (uint y = 0; y < 4; y++) {
for (uint x = 0; x < 4; x++) {
res[t1 * 4 + y][t2 * 4 + x] += regB[t2 * 4 + x] * regA[t1 * 4 + y];
}
}
}
A += K;
B += K * n;
__syncthreads();
}
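// epilogue: write every 4x4 register sub-tile to C with a single float4 store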
for (uint t1 = 0; t1 < 2; t1++)
for (uint t2 = 0; t2 < 2; t2++) {
float* Ct = C + (t1 * 32) * n + (t2 * 16);
for (uint y = 0; y < 4; y++) {
for (uint x = 0; x < 4; x += 4) {
reinterpret_cast<float4*>(&Ct[(threadRow * 4 + y) * n + (threadCol * 4 + x)])[0] =
make_float4(
res[y + t1 * 4][t2 * 4 + 0],
res[y + t1 * 4][t2 * 4 + 1],
res[y + t1 * 4][t2 * 4 + 2],
res[y + t1 * 4][t2 * 4 + 3]
);
}
}
}
}
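For completeness, this is the grid mapping my indexing implies (gridRow = blockIdx.x, so grid.x walks the row tiles of C). This is a sketch, not my exact host code:

// launch sketch: grid.x covers row tiles, grid.y covers column tiles
dim3 grid(m / BLOCK_WORK_Y, n / BLOCK_WORK_X); // assumes m, n are multiples of 128
matmul_general<<<grid, BLOCK_SIZE>>>(A, B, C, m, k, n);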
And this is the code I am comparing against, taken from the blog:
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#define CUDA_CHECK(err) { \
cudaError_t result = (err); \
if (result != cudaSuccess) { \
fprintf(stderr, "CUDA Error: %s in %s at line %d\n", \
cudaGetErrorString(result), __FILE__, __LINE__); \
exit(EXIT_FAILURE); \
} \
}
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
const int WARPSIZE = 32; // warpSize is not constexpr
/*
* @tparam BM The threadblock size for M dimension SMEM caching.
* @tparam BN The threadblock size for N dimension SMEM caching.
* @tparam BK The threadblock size for K dimension SMEM caching.
* @tparam WM M dim of continuous tile computed by each warp
* @tparam WN N dim of continuous tile computed by each warp
* @tparam WMITER The number of subwarp tiling steps in M dimension.
* @tparam WNITER The number of subwarp tiling steps in N dimension.
* @tparam TM The per-thread tile size for M dimension.
* @tparam TN The per-thread tile size for N dimension.
*/
const uint NUM_THREADS = 256;
__global__ void __launch_bounds__(NUM_THREADS)
sgemmWarptiling(int M, int N, int K, float *A, float *B, float *C) {
const uint BN = 128;
const uint BM = 128;
const uint BK = 8;
const uint WN = 32;
const uint WM = 64;
const uint WNITER = 2;
const uint TN = 4;
const uint TM = 4;
const uint cRow = blockIdx.y;
const uint cCol = blockIdx.x;
// Placement of the warp in the threadblock tile
const uint warpIdx = threadIdx.x / WARPSIZE; // the warp this thread is in
const uint warpCol = warpIdx % (BN / WN);
const uint warpRow = warpIdx / (BN / WN);
// size of the warp subtile
constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
constexpr uint WSUBM = WM / WMITER; // 64/2=32
constexpr uint WSUBN = WN / WNITER; // 32/2=16
// Placement of the thread in the warp subtile
const uint threadIdxInWarp = threadIdx.x % WARPSIZE; // [0, 31]
const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); // i%(16/4)
const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); // i/4
// allocate space for the current blocktile in SMEM
__shared__ float As[BK * BM];
__shared__ float Bs[BK * BN];
// Move blocktile to beginning of A's row and B's column
A += cRow * BM * K;
B += cCol * BN;
// Move C_ptr to warp's output tile
C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN;
// calculating the indices that this thread will load into SMEM
// we'll load 128bit / 32bit = 4 elements per thread at each step
const uint innerRowA = threadIdx.x / (BK / 4);
const uint innerColA = threadIdx.x % (BK / 4);
constexpr uint rowStrideA = (NUM_THREADS * 4) / BK;
const uint innerRowB = threadIdx.x / (BN / 4);
const uint innerColB = threadIdx.x % (BN / 4);
constexpr uint rowStrideB = NUM_THREADS / (BN / 4);
// allocate thread-local cache for results in registerfile
float threadResults[WMITER * TM * WNITER * TN] = {0.0};
// we cache into registers on the warptile level
float regM[WMITER * TM] = {0.0};
float regN[WNITER * TN] = {0.0};
// outer-most loop over block tiles
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
// for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
float4 tmp = reinterpret_cast<float4 *>(
&A[(innerRowA + 0) * K + innerColA * 4])[0];
// transpose A while storing it
As[(innerColA * 4 + 0) * BM + innerRowA ] = tmp.x;
As[(innerColA * 4 + 1) * BM + innerRowA ] = tmp.y;
As[(innerColA * 4 + 2) * BM + innerRowA ] = tmp.z;
As[(innerColA * 4 + 3) * BM + innerRowA ] = tmp.w;
// }
// for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) {
reinterpret_cast<float4 *>(
&Bs[(innerRowB + 0) * BN + innerColB * 4])[0] =
reinterpret_cast<float4 *>(
&B[(innerRowB + 0) * N + innerColB * 4])[0];
// }
__syncthreads();
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// populate registers for whole warptile
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
for (uint i = 0; i < TM; ++i) {
regM[wSubRowIdx * TM + i] =
As[dotIdx * BM + warpRow * WM + wSubRowIdx * WSUBM +
threadRowInWarp * TM + i];
}
}
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
for (uint i = 0; i < TN; ++i) {
regN[wSubColIdx * TN + i] =
Bs[(dotIdx) * BN + warpCol * WN + wSubColIdx * WSUBN +
threadColInWarp * TN + i];
}
}
// execute warptile matmul
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
// calculate per-thread results
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
threadResults[(wSubRowIdx * TM + resIdxM) * WNITER * TN +
(wSubColIdx * TN) + resIdxN] +=
regM[wSubRowIdx * TM + resIdxM] *
regN[wSubColIdx * TN + resIdxN];
}
}
}
}
}
A += BK; // move BK columns to right
B += BK * N; // move BK rows down
__syncthreads();
}
// write out the results
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
// move C pointer to current warp subtile
float *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN;
for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
// load C vector into registers
float4 tmp;
const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) +
wSubColIdx * TN + resIdxN;
tmp.x = threadResults[i + 0];
tmp.y = threadResults[i + 1];
tmp.z = threadResults[i + 2];
tmp.w = threadResults[i + 3];
// write back
reinterpret_cast<float4 *>(
&C_interim[(threadRowInWarp * TM + resIdxM) * N +
threadColInWarp * TN + resIdxN])[0] = tmp;
}
}
}
}
}
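For reference, the blog kernel's indexing (cCol = blockIdx.x, cRow = blockIdx.y) implies the transposed grid mapping; again a sketch of the launch rather than the blog's exact host code:

// launch sketch: here grid.x walks column tiles of C (BN = BM = 128 inside the kernel)
dim3 gridDim(CEIL_DIV(N, 128), CEIL_DIV(M, 128));
sgemmWarptiling<<<gridDim, NUM_THREADS>>>(M, N, K, A, B, C);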
Profiling shows that my kernel has a worse L2 hit rate, yet the assembly of the two kernels is nearly the same. Can you help me verify this behaviour on your device? If you know what is going on, please tell me. Thank you in advance!
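If you want to reproduce this, a minimal host harness along these lines should be enough (a sketch, not my exact code: the buffers are left uninitialized since only timing matters; warmup runs followed by cudaEvent timing):

int m = 4096, k = 4096, n = 4096;
float *A, *B, *C;
CUDA_CHECK(cudaMalloc(&A, sizeof(float) * m * k));
CUDA_CHECK(cudaMalloc(&B, sizeof(float) * k * n));
CUDA_CHECK(cudaMalloc(&C, sizeof(float) * m * n));
dim3 grid(m / BLOCK_WORK_Y, n / BLOCK_WORK_X);
for (int i = 0; i < 5; i++) // warmup
    matmul_general<<<grid, BLOCK_SIZE>>>(A, B, C, m, k, n);
cudaEvent_t start, stop;
CUDA_CHECK(cudaEventCreate(&start));
CUDA_CHECK(cudaEventCreate(&stop));
CUDA_CHECK(cudaEventRecord(start));
for (int i = 0; i < 50; i++) // timed runs
    matmul_general<<<grid, BLOCK_SIZE>>>(A, B, C, m, k, n);
CUDA_CHECK(cudaEventRecord(stop));
CUDA_CHECK(cudaEventSynchronize(stop));
float ms = 0;
CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop));
// 2*m*n*k FLOPs per GEMM; ms/50 is the per-launch time in milliseconds
printf("avg %.3f ms, %.1f GFLOP/s\n", ms / 50, 2.0 * m * n * k / (ms / 50 * 1e6));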