I am writing a kernel that uses wmma, and it performs reasonably well, but online blogs and other forum posts seem to hint at PTX mma ops providing much better performance. However, my implementation seems lacklustre since it fails to match up even to wmma (it's about 10% slower). Any ideas why it’s slower and how I can achieve the faster performance everyone else hints at? Below are my MMA functions and structs for reference:
#define TILE_SIZE 16 // or 32 I keep varying this based on needs
// Row-major A fragment: each lane owns 4 contiguous half elements of one row.
struct fragment8x4 {
    static constexpr unsigned int ept = 4; // elements per thread
    half f[ept];
    // Row owned by this lane: lanes 0-15 cover rows 0-3, lanes 16-31 cover rows 4-7.
    __forceinline__ static __device__ unsigned int get_row(unsigned int lane_id){
        return (lane_id % 4) + ((lane_id < 16) ? 0u : 4u);
    }
    // Column of the i-th element held by this lane (i < ept).
    __forceinline__ static __device__ unsigned int get_col(unsigned int i){
        return i;
    }
};
// Row-major B fragment: each lane owns 4 contiguous half elements of one row.
struct fragment4x8 {
    static constexpr unsigned int ept = 4; // elements per thread
    half f[ept];
    // Row held by this lane within the 4-row slice.
    __forceinline__ static __device__ unsigned int get_row(unsigned int lane_id){
        return lane_id & 3u;
    }
    // Column of the i-th element (i < ept); upper half-warp is offset by 4 columns.
    __forceinline__ static __device__ unsigned int get_col(unsigned int i, unsigned int lane_id){
        return i + ((lane_id < 16) ? 0u : 4u);
    }
};
// FP16 accumulator fragment: each lane owns 8 half elements, zero-initialized.
struct accumulator8x8fp16 {
    static constexpr unsigned int ept = 8; // elements per thread
    half f[ept] = {__float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f),
                   __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f)};
    // Row owned by this lane: lanes 0-15 cover rows 0-3, lanes 16-31 cover rows 4-7
    // (same mapping as fragment8x4::get_row).
    __forceinline__ static __device__ unsigned int get_row(unsigned int lane_id){
        return (lane_id & 3u) + ((lane_id < 16) ? 0u : 4u);
    }
    // Column of the i-th accumulator element held by this lane (i < ept).
    __forceinline__ static __device__ unsigned int get_col(unsigned int i){
        return i;
    }
};
// Warp-cooperative 16x16x16 half-precision tile product: returns acc_frag + A*B,
// built from four k-slices of the PTX mma.sync.m8n8k4 instruction.
//   A, B     : row-major 16x16 half tiles, leading dimension TILE_SIZE.
//   lane_id  : this thread's lane within the warp (0-31).
//   acc_frag : running fp16 accumulator, passed and returned by value.
//
// NOTE(review): m8n8k4.f16 is the Volta-era mma shape; on Turing/Ampere and newer
// it is not a native tensor-core shape and is compiled to slower emulation — this
// is the most likely reason the kernel trails wmma there. On SM75+ prefer
// m16n8k8 / m16n8k16 together with ldmatrix for the operand loads. TODO confirm
// the target architecture.
// NOTE(review): each lane issues 8 scalar half loads per k-slice below; whether
// A/B live in shared or global memory is not visible here, but these narrow,
// strided accesses (vs. ldmatrix / vectorized loads) are another probable perf
// sink — verify with Nsight Compute.
__device__ accumulator8x8fp16 mma_16x16x16_fp16(const half* A, const half* B, unsigned int lane_id, accumulator8x8fp16 acc_frag){
constexpr int LDM = TILE_SIZE; // Leading dimension
fragment8x4 a_frag;
fragment4x8 b_frag;
// Quadrant selection: each group of lanes works on one 8x8 quadrant of the
// 16x16 output (top/bottom rows of A, left/right columns of B).
int start_row_a = ((lane_id % 16) < 8) ? 0 : 8;
int start_col_b = ((lane_id % 8) < 4) ? 0 : 8;
int row_a = start_row_a + a_frag.get_row(lane_id);
// Reinterpret the half[4]/half[8] fragment storage as packed .b32 registers for
// the inline asm (2 halves per 32-bit register).
const int* af_regs = (const int *)a_frag.f;
const int* bf_regs = (const int *)b_frag.f;
int* acc_regs = (int *)acc_frag.f;
// Walk the k dimension in four slices of 4 (4 * 4 = 16 = K).
for (int p = 0; p < 4; p++){
int start_col_a_and_row_b = p*4;
int row_b = start_col_a_and_row_b + b_frag.get_row(lane_id);
// Gather this lane's 4 A elements and 4 B elements for the current k-slice.
for (int i = 0; i < 4; i++){
int col_a = start_col_a_and_row_b + a_frag.get_col(i);
int col_b = start_col_b + b_frag.get_col(i, lane_id);
a_frag.f[i] = A[row_a * LDM + col_a];
b_frag.f[i] = B[row_b * LDM + col_b];
}
// perform mma
// D = A*B + C with D aliased onto C ("+r" on the 4 accumulator registers);
// A and B each occupy 2 .b32 registers (4 halves), C/D occupy 4 (8 halves).
asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
"{%0, %1, %2, %3}, "
"{%4, %5}, "
"{%6, %7}, "
"{%0, %1, %2, %3};\n"
: "+r"(acc_regs[0]), "+r"(acc_regs[1]), "+r"(acc_regs[2]), "+r"(acc_regs[3])
: "r"(af_regs[0]), "r"(af_regs[1]),
"r"(bf_regs[0]), "r"(bf_regs[1])
);
}
return acc_frag;
}
// Writes one lane's 8 accumulator elements into a row-major TILE_SIZE x TILE_SIZE
// half tile in shared memory, using the same 8x8-quadrant lane mapping as the
// mma helper. (Despite the "ldmatrix" prefix, this is a store; name kept for
// call-site compatibility.)
__device__ void ldmatrix_accfp16_to_smem(half* smem, const accumulator8x8fp16 &acc_frag, unsigned int lane_id){
    constexpr int LDM = TILE_SIZE;
    const int quad_row = ((lane_id % 16) < 8) ? 0 : 8; // top or bottom 8x8 quadrant
    const int quad_col = ((lane_id % 8) < 4) ? 0 : 8;  // left or right 8x8 quadrant
    half* row_base = smem + (quad_row + acc_frag.get_row(lane_id)) * LDM;
    for (int e = 0; e < 8; e++){
        row_base[quad_col + acc_frag.get_col(e)] = acc_frag.f[e];
    }
}