Questions about the mma instruction in NVIDIA PTX

Hi, my understanding of the mma instruction in PTX is (please tell me if I'm wrong):

  1. it is a per-warp instruction
  2. it needs specific matrix elements loaded into the registers of each thread within the target warp (see the sketch after this list)
  3. the shape of the multiplication is fixed (there is a limited set of shapes to choose from)
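
To make my mental model concrete, here is roughly how I picture driving this instruction from CUDA with inline PTX. The helper name and the separate c/d arrays are my own assumption, not taken from the compiled kernel below; one warp computes D[16,8] = A[16,16] x B[16,8] + C[16,8], with each of the 32 threads holding a fragment of every operand:

#include <cuda_fp16.h>

// Hypothetical warp-level helper: f16 inputs, f32 accumulators.
// Per thread: A = 4 x .b32 regs (8 f16), B = 2 x .b32 regs (4 f16),
// C/D = 4 x .f32 regs. All 32 threads of the warp must execute this.
__device__ void mma_m16n8k16(float d[4], const unsigned a[4],
                             const unsigned b[2], const float c[4]) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}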

I have a compiled PTX listing and picked out one mma instruction (plus the ldmatrix loads that feed it):

//
// Generated by LLVM NVPTX Back-End
//

.version 8.4
.target sm_89
.address_size 64

	// .globl	matmul_kernel
.extern .shared .align 16 .b8 global_smem[];

.visible .entry matmul_kernel(
	.param .u64 matmul_kernel_param_0,
	.param .u64 matmul_kernel_param_1,
	.param .u64 matmul_kernel_param_2,
	.param .u32 matmul_kernel_param_3,
	.param .u32 matmul_kernel_param_4,
	.param .u32 matmul_kernel_param_5,
	.param .u32 matmul_kernel_param_6,
	.param .u32 matmul_kernel_param_7,
	.param .u32 matmul_kernel_param_8
)
.maxntid 128, 1, 1
{
    ...
    ldmatrix.sync.aligned.m8n8.x4.shared.b16 { %r3100, %r3101, %r3102, %r3103 }, [ %r561 + 0 ];
    ...
    ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 { %r3084, %r3085, %r3086, %r3087 }, [ %r581 + 0 ];
    ...
	mov.f32 	%f2306, 0f00000000;
	mov.b32 	%r3107, 2;
	mov.b32 	%r3106, 0;
	shl.b32 	%r2885, %r100, 1;
	shl.b32 	%r2894, %r101, 1;
	shl.b32 	%r2895, %r102, 1;
	shl.b32 	%r2896, %r103, 1;
	shl.b32 	%r2897, %r104, 1;
	shl.b32 	%r2898, %r105, 1;
	shl.b32 	%r2899, %r106, 1;
	mov.u32 	%r3104, %r765;
	mov.u32 	%r3105, %r758;
	mov.f32 	%f2307, %f2306;
	mov.f32 	%f2308, %f2306;
	mov.f32 	%f2309, %f2306;
    ...
    mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %f2306, %f2307, %f2308, %f2309 }, { %r3100, %r3101, %r3102, %r3103 }, { %r3084, %r3085 }, { %f2306, %f2307, %f2308, %f2309 };
}
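
The ldmatrix.x4 instructions above appear to be what fill the four A-fragment registers. My (assumed) CUDA equivalent of that load would look like the wrapper below; __cvta_generic_to_shared converts the generic pointer into a shared-memory address, and the 32-bit truncation is the usual pattern since shared addresses fit in 32 bits:

// Hypothetical wrapper for the ldmatrix above: the warp cooperatively
// loads four 8x8 tiles of b16 from shared memory, leaving each thread
// with 4 x .b32 registers of fragment data.
__device__ void ldmatrix_x4(unsigned r[4], const void* smem_tile) {
    unsigned addr = static_cast<unsigned>(__cvta_generic_to_shared(smem_tile));
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3])
        : "r"(addr));
}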

Now my questions are:

  1. by the instruction name, this is a D[16,8] = A[16,16] B[16,8] + D[16,8] multiplication with A, B in f16 and D in f32. But the actual load instruction only loads 4 registers for A. Does that mean each thread in the warp holds only 8 elements of A (a register is 32 bits wide, so it packs two f16 values), and the 32 threads together hold 8x32 = 256 = 16x16 elements of A? And similarly for matrix B (2 registers = 4 f16 per thread, so 4x32 = 128 = 16x8) and the result matrix D (4 f32 registers per thread, so 4x32 = 128 = 16x8)? See the arithmetic check after these questions.
  2. does the .sync suffix mean that the threads of the issuing warp wait until the tensor core finishes the matrix multiplication before continuing?
  3. the PTX version here is 8.4, and the official doc seems to say that 8.4 only supports the shape m16n8k64 for sparse and m16n8k32 for dense matrices, with no mention of m16n8k16. Or does m16n8k64 mean that any shape with m <= 16, n <= 8, and k <= 64 is supported?
  4. as mentioned above, what do dense and sparse mean for a matrix here?
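
For question 1, this is the arithmetic check I did (my own reasoning, written as static asserts so the compiler verifies the element counts; a .b32 register packs two f16 values):

// Per-thread fragment sizes implied by the register counts in the PTX.
constexpr int WARP = 32;
static_assert(16 * 16 / WARP == 8, "A: 8 f16/thread = 4 .b32 regs");
static_assert(16 * 8 / WARP == 4, "B: 4 f16/thread = 2 .b32 regs");
static_assert(16 * 8 / WARP == 4, "D: 4 f32/thread = 4 .f32 regs");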