I know a problem with the “mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32” instruction, which I am not having with "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 ".
Code for "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 ":
#include <mma.h>
#include <iostream>
#include <stdio.h>
__global__ void mma_fp16_acc_fp32() {
float c[4] = {0., 0., 0., 0.};
float d[4] = {0., 0., 0., 0.};
float a32[4] = {1., 1., 1., 1.};
float b32[2] = {1., 1.};
uint32_t A[4];
uint32_t B[2];
// A
asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(A[0]) : "f"(a32[0]), "f"(a32[1]));
asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(A[1]) : "f"(a32[2]), "f"(a32[3]));
// B
asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(B[0]) : "f"(b32[0]), "f"(b32[1]));
float const *C = reinterpret_cast<float const *>(&c);
float *D = reinterpret_cast<float *>(&d);
asm(
"mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
:
"r"(A[0]), "r"(A[1]),
"r"(B[0]),
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
);
printf("%i, %f, %f, %f, %f \n", threadIdx.x, D[0], D[1], D[2],D[3]);
}
int main() {
mma_fp16_acc_fp32<<<1, 32>>>();
cudaDeviceSynchronize();
}
result:
0, 8.000000, 8.000000, 8.000000, 8.000000
1, 8.000000, 8.000000, 8.000000, 8.000000
2, 8.000000, 8.000000, 8.000000, 8.000000
3, 8.000000, 8.000000, 8.000000, 8.000000
4, 8.000000, 8.000000, 8.000000, 8.000000
5, 8.000000, 8.000000, 8.000000, 8.000000
6, 8.000000, 8.000000, 8.000000, 8.000000
7, 8.000000, 8.000000, 8.000000, 8.000000
8, 8.000000, 8.000000, 8.000000, 8.000000
9, 8.000000, 8.000000, 8.000000, 8.000000
10, 8.000000, 8.000000, 8.000000, 8.000000
11, 8.000000, 8.000000, 8.000000, 8.000000
12, 8.000000, 8.000000, 8.000000, 8.000000
13, 8.000000, 8.000000, 8.000000, 8.000000
14, 8.000000, 8.000000, 8.000000, 8.000000
15, 8.000000, 8.000000, 8.000000, 8.000000
16, 8.000000, 8.000000, 8.000000, 8.000000
17, 8.000000, 8.000000, 8.000000, 8.000000
18, 8.000000, 8.000000, 8.000000, 8.000000
19, 8.000000, 8.000000, 8.000000, 8.000000
20, 8.000000, 8.000000, 8.000000, 8.000000
21, 8.000000, 8.000000, 8.000000, 8.000000
22, 8.000000, 8.000000, 8.000000, 8.000000
23, 8.000000, 8.000000, 8.000000, 8.000000
24, 8.000000, 8.000000, 8.000000, 8.000000
25, 8.000000, 8.000000, 8.000000, 8.000000
26, 8.000000, 8.000000, 8.000000, 8.000000
27, 8.000000, 8.000000, 8.000000, 8.000000
28, 8.000000, 8.000000, 8.000000, 8.000000
29, 8.000000, 8.000000, 8.000000, 8.000000
30, 8.000000, 8.000000, 8.000000, 8.000000
31, 8.000000, 8.000000, 8.000000, 8.000000
Code for "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 ":
#include <mma.h>
#include <iostream>
#include <stdio.h>
__global__ void mma_fp16_acc_fp32() {
float c[4] = {0., 0., 0., 0.};
float d[4] = {0., 0., 0., 0.};
float a32[8] = {1., 1., 1., 1., 1., 1., 1., 1.};
float b32[4] = {1., 1., 1., 1.};
uint32_t A[4];
uint32_t B[2];
// A
asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(A[0]) : "f"(a32[0]), "f"(a32[1]));
asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(A[1]) : "f"(a32[2]), "f"(a32[3]));
asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(A[2]) : "f"(a32[4]), "f"(a32[5]));
asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(A[3]) : "f"(a32[6]), "f"(a32[7]));
// B
asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(B[0]) : "f"(b32[0]), "f"(b32[1]));
asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(B[1]) : "f"(b32[2]), "f"(b32[3]));
float const *C = reinterpret_cast<float const *>(&c);
float *D = reinterpret_cast<float *>(&d);
asm(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
" { %0, %1, %2, %3 }, "
" { %4, %5, %6, %7 }, "
" { %8, %9 }, "
" { %10, %11, %12, %13 };"
:
"=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
:
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
"r"(B[0]), "r"(B[1]),
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
);
printf("%i, %f, %f, %f, %f \n", threadIdx.x, D[0], D[1], D[2],D[3]);
}
int main() {
mma_fp16_acc_fp32<<<1, 32>>>();
cudaDeviceSynchronize();
}
Result:
0, 56.250000, 56.250000, 56.250000, 56.250000
1, 56.250000, 56.250000, 56.250000, 56.250000
2, 56.250000, 56.250000, 56.250000, 56.250000
3, 56.250000, 56.250000, 56.250000, 56.250000
4, 56.250000, 56.250000, 56.250000, 56.250000
5, 56.250000, 56.250000, 56.250000, 56.250000
6, 56.250000, 56.250000, 56.250000, 56.250000
7, 56.250000, 56.250000, 56.250000, 56.250000
8, 56.250000, 56.250000, 56.250000, 56.250000
9, 56.250000, 56.250000, 56.250000, 56.250000
10, 56.250000, 56.250000, 56.250000, 56.250000
11, 56.250000, 56.250000, 56.250000, 56.250000
12, 56.250000, 56.250000, 56.250000, 56.250000
13, 56.250000, 56.250000, 56.250000, 56.250000
14, 56.250000, 56.250000, 56.250000, 56.250000
15, 56.250000, 56.250000, 56.250000, 56.250000
16, 56.250000, 56.250000, 56.250000, 56.250000
17, 56.250000, 56.250000, 56.250000, 56.250000
18, 56.250000, 56.250000, 56.250000, 56.250000
19, 56.250000, 56.250000, 56.250000, 56.250000
20, 56.250000, 56.250000, 56.250000, 56.250000
21, 56.250000, 56.250000, 56.250000, 56.250000
22, 56.250000, 56.250000, 56.250000, 56.250000
23, 56.250000, 56.250000, 56.250000, 56.250000
24, 56.250000, 56.250000, 56.250000, 56.250000
25, 56.250000, 56.250000, 56.250000, 56.250000
26, 56.250000, 56.250000, 56.250000, 56.250000
27, 56.250000, 56.250000, 56.250000, 56.250000
28, 56.250000, 56.250000, 56.250000, 56.250000
29, 56.250000, 56.250000, 56.250000, 56.250000
30, 56.250000, 56.250000, 56.250000, 56.250000
31, 56.250000, 56.250000, 56.250000, 56.250000
if I am not wrong the result is 56.25 then I should get 16, and I do not understand where the error is, Thank you for your help