Complete minimal PTX example for: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32

Here is my minimal example (it doesn’t work):


// tensor_mma_kernel.ptx
.version 8.4
.target sm_89
.address_size 64

// tensor_mma_kernel: loads 16-bit values from A_ptr and B_ptr, issues one m16n8k16 mma with zeroed accumulators, and stores four f32 results to C_ptr.
.entry tensor_mma_kernel(
    .param.u64.ptr.global.align 16 A_ptr,   // Pointer to matrix A in global memory
    .param.u64.ptr.global.align 16 B_ptr,   // Pointer to matrix B in global memory
    .param.u64.ptr.global.align 16 C_ptr    // Pointer to matrix C in global memory
)
{
    // Register declarations: 64-bit pointers, f16 operands for A and B, f32 accumulators and results
    .reg .u64 %ra, %rb, %rc;              // Base pointers to A, B, C
    .reg .u64 %rb0, %rb1;                 // Declared but never referenced in this kernel
    .reg .f16 %Ra0, %Ra1, %Ra2, %Ra3;     // A operands, one f16 each. NOTE(review): fragment docs describe four .f16x2 registers; .b32 is reported to assemble
    .reg .f16 %Rb0, %Rb1;                 // B operands, one f16 each. NOTE(review): fragment docs describe two .f16x2 registers; .b32 is reported to assemble
    .reg .f32 %Rc0, %Rc1, %Rc2, %Rc3;     // Accumulator inputs (C operand of the mma)
    .reg .f32 %Rd0, %Rd1, %Rd2, %Rd3;     // Results (D operand of the mma)

    // Copy the three kernel parameters into the base-pointer registers
    ld.param.u64 %ra, [A_ptr];
    ld.param.u64 %rb, [B_ptr];
    ld.param.u64 %rc, [C_ptr];

    // Zero the accumulators so the mma computes A * B + 0
    mov.f32 %Rc0, 0.0;
    mov.f32 %Rc1, 0.0;
    mov.f32 %Rc2, 0.0;
    mov.f32 %Rc3, 0.0;

    // Load four consecutive 16-bit values of A (byte offsets 0, 2, 4, 6); no cvt follows, and the address does not depend on any thread index
    ld.global.b16 %Ra0, [%ra];
    ld.global.b16 %Ra1, [%ra + 2];
    ld.global.b16 %Ra2, [%ra + 4];
    ld.global.b16 %Ra3, [%ra + 6];

    // Load two consecutive 16-bit values of B (byte offsets 0 and 2); no cvt follows, and the address does not depend on any thread index
    ld.global.b16 %Rb0, [%rb];
    ld.global.b16 %Rb1, [%rb + 2];

    // MMA operation: D = A * B + C. NOTE(review): ptxas reports "Arguments mismatch" here, presumably because A and B are .f16 rather than packed 32-bit registers
    mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 // this is line 42 
      {%Rd0, %Rd1, %Rd2, %Rd3},
      {%Ra0, %Ra1, %Ra2, %Ra3},
      {%Rb0, %Rb1},
      {%Rc0, %Rc1, %Rc2, %Rc3};

    // Store the four f32 results at byte offsets 0, 4, 8, 12 from C_ptr (same addresses regardless of thread index)
    st.global.f32 [%rc + 0], %Rd0;
    st.global.f32 [%rc + 4], %Rd1;
    st.global.f32 [%rc + 8], %Rd2;
    st.global.f32 [%rc + 12], %Rd3;

    // Exit the kernel
    exit;
}

When I run:
ptxas -arch=sm_89 tensor_mma_kernel.ptx -o tensor_mma_kernel.cubin

I get:
ptxas tensor_mma_kernel.ptx, line 42; error : Arguments mismatch for instruction ‘mma’
ptxas fatal : Ptx assembly aborted due to errors

What am I doing wrong?

The PTX manual gives as example:
https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma

// Quoted as-is from the PTX manual. NOTE(review): %Rc and %Rd are declared with <2> but four of each are used below, and A/B are declared .f16 although the fragment description calls for .f16x2 registers.
.reg .f16 %Ra<4>, %Rb<2>;
.reg .f32 %Rc<2>, %Rd<2>;
mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
  {%Rd0, %Rd1, %Rd2, %Rd3},
  {%Ra0, %Ra1, %Ra2, %Ra3},
  {%Rb0, %Rb1},
  {%Rc0, %Rc1, %Rc2, %Rc3};

Perhaps your registers have to be declared as arrays with the <2> or <4> signifier?

CUTLASS, on the other hand, uses uint32_t instead of fp16 operands in its asm block:

    // Excerpt: a, b, c and d are defined outside the lines shown here.
    // The A and B operands are viewed as arrays of 32-bit words; C and D are viewed as arrays of float.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    // Inline PTX for the same mma instruction: D[0..3] are "=f" outputs (%0-%3),
    // A[0..3] and B[0..1] are passed with the "r" constraint (%4-%9), and
    // C[0..3] with the "f" constraint (%10-%13).
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32  {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
        "{%10,%11,%12,%13};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));

I’ve tried declaring the registers both ways; it doesn’t seem to matter. For brevity I’ll use the <> syntax in this reply.

I can get it to compile by declaring the %Ra<4> and %Rb<2> registers as .b32. This implies that it is perhaps expecting the %Ra<4> and %Rb<2> registers to actually be packed f16x2?

One would think NVIDIA could afford to hire a few engineers to meticulously document this stuff instead of having some half-baked LLM auto-generate the examples…

https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-mma-m16n8k16-with-floating-point-type

A: A vector expression containing four .f16x2 registers, with each register containing two .f16 / .bf16 elements from the matrix A.

B: A vector expression containing two .f16x2 registers, with each register containing two .f16 / .bf16 elements from the matrix B.

The examples may be wrong (which is bad), but it is at least documented elsewhere that the registers have to be packed.

I would rather assume human error; I do not think that the examples were generated by an LLM.

2 Likes