Optimizing PTX mma ops on Volta to surpass wmma

I am writing a kernel that uses wmma, and it performs reasonably well, but online blogs and other forum posts seem to hint that PTX mma ops can provide much better performance. However, my implementation seems lacklustre, since it fails to even match wmma (it's about 10% slower). Any ideas why it's slower, and how I can achieve the faster performance everyone else hints at? Below are my MMA functions and structs for reference:

#define TILE_SIZE 16 // or 32 I keep varying this based on needs

// Row-major A fragment for the Volta m8n8k4 MMA shape: 4 half values per thread.
struct fragment8x4 {
    static constexpr unsigned int ept = 4; // elements held per thread
    half f[ept];

    // Row owned by this lane; lanes 16..31 are offset down by 4 rows.
    __forceinline__ static __device__ unsigned int get_row(unsigned int lane_id){
        unsigned int quad_row = lane_id % 4;
        return (lane_id < 16) ? quad_row : quad_row + 4;
    }

    // Element i sits in column i (i must be < ept).
    __forceinline__ static __device__ unsigned int get_col(unsigned int i){
        return i;
    }
};

// Row-major B fragment for the Volta m8n8k4 MMA shape: 4 half values per thread.
struct fragment4x8 {
    static constexpr unsigned int ept = 4; // elements held per thread
    half f[ept];

    // Every lane maps into one of rows 0..3.
    __forceinline__ static __device__ unsigned int get_row(unsigned int lane_id){
        return lane_id % 4;
    }

    // Lanes 16..31 own the upper half of the 8-wide column range (i must be < ept).
    __forceinline__ static __device__ unsigned int get_col(unsigned int i, unsigned int lane_id){
        return (lane_id < 16) ? i : i + 4;
    }
};

// fp16 accumulator fragment for the Volta m8n8k4 MMA shape: 8 half values per
// thread, zero-initialized so it can feed the first mma step as the C operand.
struct accumulator8x8fp16 {
    static constexpr unsigned int ept = 8; // elements held per thread
    half f[ept] = {__float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f),
                   __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f)};

    // Row owned by this lane; lanes 16..31 are offset down by 4 rows.
    __forceinline__ static __device__ unsigned int get_row(unsigned int lane_id){
        unsigned int quad_row = lane_id % 4;
        return (lane_id < 16) ? quad_row : quad_row + 4;
    }

    // Element i sits in column i (i must be < ept).
    __forceinline__ static __device__ unsigned int get_col(unsigned int i){
        return i;
    }
};

// Computes a 16x16x16 half-precision tile product by issuing four m8n8k4
// mma.sync steps along K, accumulating into (and returning) acc_frag.
// A and B are row-major 16x16 tiles with leading dimension TILE_SIZE.
// Requires SM70 (Volta) for mma.sync.aligned.m8n8k4 with f16 accumulate.
// NOTE(review): the start_row_a/start_col_b quadrant selection encodes
// Volta's quad-pair lane layout for m8n8k4 — confirm against the PTX ISA
// "Matrix Fragments for mma.m8n8k4" tables.
__device__ accumulator8x8fp16 mma_16x16x16_fp16(const half* A, const half* B, unsigned int lane_id, accumulator8x8fp16 acc_frag){
    constexpr int LDM = TILE_SIZE; // Leading dimension of both input tiles

    fragment8x4 a_frag;
    fragment4x8 b_frag;

    // Which 8x8 quadrant of the 16x16 tile this lane's quad-pair works on.
    int start_row_a = ((lane_id % 16) < 8) ? 0 : 8;
    int start_col_b = ((lane_id % 8) < 4) ? 0 : 8;

    int row_a = start_row_a + a_frag.get_row(lane_id);

    // View the half arrays as packed 32-bit registers (two halves per "r"
    // asm operand). NOTE(review): this cast type-puns half[] through int* —
    // works in practice with nvcc, but reinterpreting via unsigned and
    // memcpy-style packing would be the strictly well-defined form.
    const int* af_regs = (const int *)a_frag.f;
    const int* bf_regs = (const int *)b_frag.f;
    int* acc_regs = (int *)acc_frag.f;

    // March along K in steps of 4 — the k-extent of one m8n8k4 mma.
    for (int p = 0; p < 4; p++){
        int start_col_a_and_row_b = p*4;
        int row_b = start_col_a_and_row_b + b_frag.get_row(lane_id);

        // Gather this step's A and B elements one half at a time.
        // NOTE(review): these scalar loads serialize with the mma below and
        // are a likely reason this underperforms wmma, which uses wide
        // (128-bit) fragment loads — check Nsight Compute memory metrics.
        for (int i = 0; i < 4; i++){
            int col_a = start_col_a_and_row_b + a_frag.get_col(i);
            int col_b = start_col_b + b_frag.get_col(i, lane_id);

            a_frag.f[i] = A[row_a * LDM + col_a];
            b_frag.f[i] = B[row_b * LDM + col_b];
        }

        // Perform D = A*B + D; the "+r" constraints alias D onto C so the
        // accumulator registers are both input and output.
        asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
               "{%0, %1, %2, %3}, "
               "{%4, %5}, "
               "{%6, %7}, "
               "{%0, %1, %2, %3};\n"
               : "+r"(acc_regs[0]), "+r"(acc_regs[1]), "+r"(acc_regs[2]), "+r"(acc_regs[3])
               : "r"(af_regs[0]), "r"(af_regs[1]),
                 "r"(bf_regs[0]), "r"(bf_regs[1])
        );
    }

    return acc_frag;
}

// Scatters a thread's accumulator fragment into a 16x16 row-major shared
// memory tile, using the same quadrant/lane mapping the MMA produced it with.
__device__ void ldmatrix_accfp16_to_smem(half* smem, const accumulator8x8fp16 &acc_frag, unsigned int lane_id){
    constexpr int LDM = TILE_SIZE;

    // Quadrant offsets mirror those chosen when the accumulator was built.
    const int base_row = ((lane_id % 16) < 8) ? 0 : 8;
    const int base_col = ((lane_id % 8) < 4) ? 0 : 8;

    const int row = base_row + acc_frag.get_row(lane_id);
    half* dst_row = smem + row * LDM + base_col;

    for (int i = 0; i < 8; ++i){
        dst_row[acc_frag.get_col(i)] = acc_frag.f[i];
    }
}

Maybe you now have data that calls that assertion into question.

wmma is just a convenience API around mma, so mma should be at least as fast — in theory.

However, wmma exists at the PTX level, so there could (in theory) be internal optimizations when it is converted to SASS. (So one could argue it is NOT merely an API around mma.)

Have a look at the central SASS code of both variants.

Use Nsight Compute to compare the actual bottlenecks of the two variants — often the memory accesses are slower than the Tensor Core computations.