Problem with the instruction "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"

I have a problem with the "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" instruction that I do not have with "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32".

Code for "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 ":

#include <mma.h>
#include <iostream>
#include <stdio.h>

// Smoke test for a single warp-level m16n8k8 MMA with bf16 inputs and f32
// accumulation. Each thread contributes 1.0 for every A and B element it owns
// and 0.0 for every C element, issues one mma.sync, and prints its four D values.
// mma.sync.aligned is a warp-collective instruction; the bf16 m16n8k8 shape
// needs sm_80 or newer according to the PTX ISA.
__global__ void mma_fp16_acc_fp32() {
    // FP32 source values for this thread's share of A (4 values) and B (2 values).
    const float lhs[4] = {1.0f, 1.0f, 1.0f, 1.0f};
    const float rhs[2] = {1.0f, 1.0f};

    // Each 32-bit register carries two bf16 values produced by one cvt.
    uint32_t a_reg0, a_reg1, b_reg0;
    asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(a_reg0) : "f"(lhs[0]), "f"(lhs[1]));
    asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(a_reg1) : "f"(lhs[2]), "f"(lhs[3]));
    asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(b_reg0) : "f"(rhs[0]), "f"(rhs[1]));

    // Accumulator input is all zeros, so D = A * B.
    const float zero = 0.0f;
    float out0 = 0.0f, out1 = 0.0f, out2 = 0.0f, out3 = 0.0f;
    asm(
      "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=f"(out0), "=f"(out1), "=f"(out2), "=f"(out3)
      : "r"(a_reg0), "r"(a_reg1),
        "r"(b_reg0),
        "f"(zero), "f"(zero), "f"(zero), "f"(zero)
    );

    // One line per thread: lane index followed by the four accumulator outputs.
    printf("%i, %f, %f, %f, %f \n", threadIdx.x, out0, out1, out2, out3);
}

// Launches the kernel on one block of 32 threads (a single warp) and waits for
// it to finish. Returns 0 on success; on any CUDA error prints the error string
// to stderr and returns 1 instead of exiting silently with no output.
int main() {
    mma_fp16_acc_fp32<<<1, 32>>>();

    // A kernel launch returns nothing; launch-time failures (for example, no
    // kernel image built for this device) are only visible via cudaGetLastError().
    cudaError_t status = cudaGetLastError();
    if (status == cudaSuccess) {
        // Errors raised while the kernel runs surface at the next synchronizing call.
        status = cudaDeviceSynchronize();
    }
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(status));
        return 1;
    }
    return 0;
}

result:

0, 8.000000, 8.000000, 8.000000, 8.000000 
1, 8.000000, 8.000000, 8.000000, 8.000000 
2, 8.000000, 8.000000, 8.000000, 8.000000 
3, 8.000000, 8.000000, 8.000000, 8.000000 
4, 8.000000, 8.000000, 8.000000, 8.000000 
5, 8.000000, 8.000000, 8.000000, 8.000000 
6, 8.000000, 8.000000, 8.000000, 8.000000 
7, 8.000000, 8.000000, 8.000000, 8.000000 
8, 8.000000, 8.000000, 8.000000, 8.000000 
9, 8.000000, 8.000000, 8.000000, 8.000000 
10, 8.000000, 8.000000, 8.000000, 8.000000 
11, 8.000000, 8.000000, 8.000000, 8.000000 
12, 8.000000, 8.000000, 8.000000, 8.000000 
13, 8.000000, 8.000000, 8.000000, 8.000000 
14, 8.000000, 8.000000, 8.000000, 8.000000 
15, 8.000000, 8.000000, 8.000000, 8.000000 
16, 8.000000, 8.000000, 8.000000, 8.000000 
17, 8.000000, 8.000000, 8.000000, 8.000000 
18, 8.000000, 8.000000, 8.000000, 8.000000 
19, 8.000000, 8.000000, 8.000000, 8.000000 
20, 8.000000, 8.000000, 8.000000, 8.000000 
21, 8.000000, 8.000000, 8.000000, 8.000000 
22, 8.000000, 8.000000, 8.000000, 8.000000 
23, 8.000000, 8.000000, 8.000000, 8.000000 
24, 8.000000, 8.000000, 8.000000, 8.000000 
25, 8.000000, 8.000000, 8.000000, 8.000000 
26, 8.000000, 8.000000, 8.000000, 8.000000 
27, 8.000000, 8.000000, 8.000000, 8.000000 
28, 8.000000, 8.000000, 8.000000, 8.000000 
29, 8.000000, 8.000000, 8.000000, 8.000000 
30, 8.000000, 8.000000, 8.000000, 8.000000 
31, 8.000000, 8.000000, 8.000000, 8.000000

Code for "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 ":

#include <mma.h>
#include <iostream>
#include <stdio.h>

// Issues one warp-level mma.sync m16n8k16 (operand types declared as f16, f32
// accumulation). Each thread supplies 1.0 for every A and B element it owns and
// 0.0 for every C element, then prints its four D values.
//
// NOTE(review): the operands are packed with cvt.rn.bf16x2.f32 (bf16), but the
// mma instruction below declares .f16 operands. bf16 1.0 has the bit pattern
// 0x3F80, which read as f16 is 1.875, so each output is
// 16 * 1.875 * 1.875 = 56.25 rather than 16. Packing with cvt.rn.f16x2.f32
// would feed real f16 ones. The code is left exactly as posted.
__global__ void mma_fp16_acc_fp32() {
    float c[4] = {0., 0., 0., 0.};   // accumulator input (C), all zeros
    float d[4] = {0., 0., 0., 0.};   // accumulator output (D)
    float a32[8] = {1., 1., 1., 1., 1., 1., 1., 1.};   // this thread's A values in FP32
    float b32[4] = {1., 1., 1., 1.};                   // this thread's B values in FP32
    uint32_t A[4];   // packed A operand: two 16-bit values per register
    uint32_t B[2];   // packed B operand: two 16-bit values per register

    // A
    // Each cvt packs two FP32 values into one 32-bit register as bf16x2.
    asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(A[0]) : "f"(a32[0]), "f"(a32[1]));
    asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(A[1]) : "f"(a32[2]), "f"(a32[3]));
    asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(A[2]) : "f"(a32[4]), "f"(a32[5]));
    asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(A[3]) : "f"(a32[6]), "f"(a32[7]));
    // B
    asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(B[0]) : "f"(b32[0]), "f"(b32[1]));
    asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(B[1]) : "f"(b32[2]), "f"(b32[3]));

    // Plain float views of the accumulator arrays, used as the asm operands.
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);
    // Operand order: D (4 x f32 out), A (4 x b32), B (2 x b32), C (4 x f32 in).
    asm(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        " { %0, %1, %2, %3 }, "
        " { %4, %5, %6, %7 }, "
        " { %8, %9 }, "
        " { %10, %11, %12, %13 };"
        :
        "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        :
        "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
        "r"(B[0]), "r"(B[1]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
    );
    // One line per thread: lane index followed by the four accumulator outputs.
    printf("%i, %f, %f, %f, %f \n", threadIdx.x, D[0], D[1], D[2],D[3]);
}

// Launches the kernel on one block of 32 threads (a single warp) and waits for
// it to finish. Returns 0 on success; on any CUDA error prints the error string
// to stderr and returns 1 instead of exiting silently with no output.
int main() {
    mma_fp16_acc_fp32<<<1, 32>>>();

    // A kernel launch returns nothing; launch-time failures (for example, no
    // kernel image built for this device) are only visible via cudaGetLastError().
    cudaError_t status = cudaGetLastError();
    if (status == cudaSuccess) {
        // Errors raised while the kernel runs surface at the next synchronizing call.
        status = cudaDeviceSynchronize();
    }
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(status));
        return 1;
    }
    return 0;
}

Result:

0, 56.250000, 56.250000, 56.250000, 56.250000 
1, 56.250000, 56.250000, 56.250000, 56.250000 
2, 56.250000, 56.250000, 56.250000, 56.250000 
3, 56.250000, 56.250000, 56.250000, 56.250000 
4, 56.250000, 56.250000, 56.250000, 56.250000 
5, 56.250000, 56.250000, 56.250000, 56.250000 
6, 56.250000, 56.250000, 56.250000, 56.250000 
7, 56.250000, 56.250000, 56.250000, 56.250000 
8, 56.250000, 56.250000, 56.250000, 56.250000 
9, 56.250000, 56.250000, 56.250000, 56.250000 
10, 56.250000, 56.250000, 56.250000, 56.250000 
11, 56.250000, 56.250000, 56.250000, 56.250000 
12, 56.250000, 56.250000, 56.250000, 56.250000 
13, 56.250000, 56.250000, 56.250000, 56.250000 
14, 56.250000, 56.250000, 56.250000, 56.250000 
15, 56.250000, 56.250000, 56.250000, 56.250000 
16, 56.250000, 56.250000, 56.250000, 56.250000 
17, 56.250000, 56.250000, 56.250000, 56.250000 
18, 56.250000, 56.250000, 56.250000, 56.250000 
19, 56.250000, 56.250000, 56.250000, 56.250000 
20, 56.250000, 56.250000, 56.250000, 56.250000 
21, 56.250000, 56.250000, 56.250000, 56.250000 
22, 56.250000, 56.250000, 56.250000, 56.250000 
23, 56.250000, 56.250000, 56.250000, 56.250000 
24, 56.250000, 56.250000, 56.250000, 56.250000 
25, 56.250000, 56.250000, 56.250000, 56.250000 
26, 56.250000, 56.250000, 56.250000, 56.250000 
27, 56.250000, 56.250000, 56.250000, 56.250000 
28, 56.250000, 56.250000, 56.250000, 56.250000 
29, 56.250000, 56.250000, 56.250000, 56.250000 
30, 56.250000, 56.250000, 56.250000, 56.250000 
31, 56.250000, 56.250000, 56.250000, 56.250000

If I am not mistaken, the result should be 16 (each output element is a sum of 16 products of 1 × 1), but I get 56.25 instead. I do not understand where the error is. Thank you for your help.

why are you converting to bf16:

"cvt.rn.bf16x2.f32 %0, %1, %2;\n" 
        ^^^^

for an instruction that expects f16:

"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                                       ^^^ ^^^

?

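Those are two different data types. The bf16 encoding of 1.0 is 0x3F80; interpreted as f16, that bit pattern is 1.875, so each output becomes 16 × 1.875 × 1.875 = 56.25 instead of 16. Converting with "cvt.rn.f16x2.f32" instead gives the mma instruction real f16 values.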

1 Like

I am very sorry about that. Thank you for the help.