Hi, my understanding of the `mma` instruction in PTX is (please tell me if I'm wrong):
- it is a per-warp instruction
- it needs specific elements loaded into registers of each thread within the target warp
- the size of the multiplication is fixed (there is a limited set of shapes to choose from); a sketch follows this list
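To make the first two points concrete, here is a minimal CUDA sketch of how I understand such an instruction is issued through inline PTX (the wrapper name and fragment arrays are my own; it assumes sm_80+ and that the operand fragments are already distributed across the warp's registers in the layout the PTX ISA prescribes):

```cuda
// Warp-wide tensor-core MAC:
//   D(16x8, f32) = A(16x16, f16) * B(16x8, f16) + C(16x8, f32)
// Every thread of the warp must reach this call together, each holding its
// own slice of A (4 regs of packed f16), B (2 regs), and the accumulator (4 f32 regs).
__device__ void mma_m16n8k16(float d[4], const unsigned a[4],
                             const unsigned b[2], const float c[4]) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
```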
I have the compiled PTX of a kernel and picked out one `mma` instruction (listing abridged):

```ptx
//
// Generated by LLVM NVPTX Back-End
//
.version 8.4
.target sm_89
.address_size 64
// .globl matmul_kernel
.extern .shared .align 16 .b8 global_smem[];
.visible .entry matmul_kernel(
.param .u64 matmul_kernel_param_0,
.param .u64 matmul_kernel_param_1,
.param .u64 matmul_kernel_param_2,
.param .u32 matmul_kernel_param_3,
.param .u32 matmul_kernel_param_4,
.param .u32 matmul_kernel_param_5,
.param .u32 matmul_kernel_param_6,
.param .u32 matmul_kernel_param_7,
.param .u32 matmul_kernel_param_8
)
.maxntid 128, 1, 1
{
...
ldmatrix.sync.aligned.m8n8.x4.shared.b16 { %r3100, %r3101, %r3102, %r3103 }, [ %r561 + 0 ];
...
ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 { %r3084, %r3085, %r3086, %r3087 }, [ %r581 + 0 ];
...
mov.f32 %f2306, 0f00000000;
mov.b32 %r3107, 2;
mov.b32 %r3106, 0;
shl.b32 %r2885, %r100, 1;
shl.b32 %r2894, %r101, 1;
shl.b32 %r2895, %r102, 1;
shl.b32 %r2896, %r103, 1;
shl.b32 %r2897, %r104, 1;
shl.b32 %r2898, %r105, 1;
shl.b32 %r2899, %r106, 1;
mov.u32 %r3104, %r765;
mov.u32 %r3105, %r758;
mov.f32 %f2307, %f2306;
mov.f32 %f2308, %f2306;
mov.f32 %f2309, %f2306;
...
mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %f2306, %f2307, %f2308, %f2309 }, { %r3100, %r3101, %r3102, %r3103 }, { %r3084, %r3085 }, { %f2306, %f2307, %f2308, %f2309 };
}
```
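For reference, my understanding is that the `ldmatrix` lines above correspond to an inline-PTX wrapper roughly like this sketch (naming is mine; it assumes sm_75+ and a generic pointer that actually points into shared memory; for the `.x4` form, lanes 0-7 supply the row addresses of the first 8x8 tile, lanes 8-15 the second, and so on, and every lane receives four 32-bit registers holding its eight b16 values):

```cuda
// Loads four 8x8 b16 tiles from shared memory into per-thread registers.
__device__ void ldmatrix_x4(unsigned r[4], const void* smem_row) {
    // Convert the generic pointer to a shared-memory address for ldmatrix.
    unsigned addr = static_cast<unsigned>(__cvta_generic_to_shared(smem_row));
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
        : "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3])
        : "r"(addr));
}
```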
Now my questions are:
- by the instruction, it is a $D_{16,8} = A_{16,16}\,B_{16,8} + D_{16,8}$ multiplication with A, B in f16 and D in f32. But the actual load instruction only loads 4 registers for A. Does that mean each thread in the warp loads only 8 elements (since a register is 32 bits wide and each f16 element is 16 bits), so that the 32 threads together load 8×32 = 256 = 16×16 elements of A? And is it the same for matrix B and the result matrix D (D being 4×32 = 128 = 16×8 elements in f32)? See the worked arithmetic after this list.
- does the tag `sync` mean the target warp will wait until the tensor core finishes the matrix multiply before continuing?
- it can be seen that the PTX version is 8.4, and the official doc shows that 8.4 only supports the shape `m16n8k64` for sparse and `m16n8k32` for dense; there is no entry for `m16n8k16`. Or does `m16n8k64` mean it supports any shape with m ≤ 16, n ≤ 8, and k ≤ 64?
- as mentioned above, what is the meaning of dense and sparse for a matrix here?
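For what it's worth, here is the register-count arithmetic behind my first question (just the element counts implied by the m16n8k16 shape, divided across the 32 lanes of a warp and matched to the registers in the listing):

```
A: 16x16 = 256 f16 elements; /32 threads = 8 f16/thread = 4 x 32-bit regs -> { %r3100..%r3103 }
B: 16x8  = 128 f16 elements; /32 threads = 4 f16/thread = 2 x 32-bit regs -> { %r3084, %r3085 }
D: 16x8  = 128 f32 elements; /32 threads = 4 f32/thread = 4 x 32-bit regs -> { %f2306..%f2309 }
```

This would also explain why the `mma` consumes only two of the four registers that the `.x4` `ldmatrix` produced for B: presumably the other pair feeds a neighboring `mma` in the k loop.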