Error or incomprehension, MMA PTX mixed precision bfloat16 RTX 3080

As an example demonstrating a “full” M=16, N=8, K=8 matrix multiply, let’s perform the following op:

D = A*B + C

We will set C to all zeros, and choose these values for A and B:

``````
           A                           B                           D
 0  1  1  1  1  1  1  1      1  1  1  1  1  1  1  1      9  9  9  9  9  9  9  9
 1  1  1  1  1  1  1  1      3  3  3  3  3  3  3  3     10 10 10 10 10 10 10 10
 2  1  1  1  1  1  1  1      1  1  1  1  1  1  1  1     11 11 11 11 11 11 11 11
 3  1  1  1  1  1  1  1      1  1  1  1  1  1  1  1     12 12 12 12 12 12 12 12
 4  1  1  1  1  1  1  1      1  1  1  1  1  1  1  1     13 13 13 13 13 13 13 13
 5  1  1  1  1  1  1  1      1  1  1  1  1  1  1  1     14 14 14 14 14 14 14 14
 6  1  1  1  1  1  1  1      1  1  1  1  1  1  1  1     15 15 15 15 15 15 15 15
 7  1  1  1  1  1  1  1  *   1  1  1  1  1  1  1  1  =  16 16 16 16 16 16 16 16
 8  1  1  1  1  1  1  1                                 17 17 17 17 17 17 17 17
 9  1  1  1  1  1  1  1                                 18 18 18 18 18 18 18 18
10  1  1  1  1  1  1  1                                 19 19 19 19 19 19 19 19
11  1  1  1  1  1  1  1                                 20 20 20 20 20 20 20 20
12  1  1  1  1  1  1  1                                 21 21 21 21 21 21 21 21
13  1  1  1  1  1  1  1                                 22 22 22 22 22 22 22 22
14  1  1  1  1  1  1  1                                 23 23 23 23 23 23 23 23
15  1  1  1  1  1  1  1                                 24 24 24 24 24 24 24 24
``````

Before going any further, I would encourage you to convince yourself that the above linear algebra is correct.

The following code implements that. Note that, for simplicity, I have replaced bfloat16 with fp16. This should not matter conceptually for understanding how the op works:

``````
$ cat t10.cu
#include <mma.h>
#include <cuda_fp16.h>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <stdio.h>

__global__ void mma_fp16_acc_fp32(float *out) {
  float c[4] = {0., 0., 0., 0.};
  float d[4] = {0., 0., 0., 0.};
  half a[4] = {1., 1., 1., 1.};
  half b[2] = {1., 1.};
  // the above would set our input matrices to all 1
  // now let's modify some values
  if (threadIdx.x%4 == 0) {
    // set the first column of A to be 0, 1, 2, 3, ... 15
    a[0] = threadIdx.x/4; a[2] = threadIdx.x/4 + 8;
    // set the second row of B to 3,3,3, ... 3
    b[1] = 3;}
  // view the fragments as the 32-bit registers (two halves each) that mma.sync expects
  unsigned const *A = reinterpret_cast<unsigned const *>(&a);
  unsigned const *B = reinterpret_cast<unsigned const *>(&b);
  float const *C = reinterpret_cast<float const *>(&c);
  float *D = reinterpret_cast<float *>(&d);
  asm(
    "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
    "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
    : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
    :
    "r"(A[0]), "r"(A[1]),
    "r"(B[0]),
    "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
  );
  // write D to out as a row-major 16x8 matrix:
  // d[0],d[1] -> row threadIdx.x/4, columns 2*(threadIdx.x%4) and 2*(threadIdx.x%4)+1
  // d[2],d[3] -> the same columns, 8 rows further down
  memcpy(out+threadIdx.x*2, D, 8);
  memcpy(out+8*8+threadIdx.x*2, D+2, 8);
}

int main() {
  float* h_C = (float*)malloc(16*8*sizeof(float));
  float* d_C;
  cudaMalloc(&d_C, 16*8*sizeof(float));
  mma_fp16_acc_fp32<<<1, 32>>>(d_C);
  cudaMemcpy(h_C, d_C, 16*8*sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < 16; i++){
    for (int j = 0; j < 8; j++) std::cout << h_C[i*8+j] << " ";
    std::cout << std::endl;}
  cudaFree(d_C);
  free(h_C);
}
$ nvcc -o t10 t10.cu -arch=sm_75
$ cuda-memcheck ./t10
========= CUDA-MEMCHECK
9 9 9 9 9 9 9 9
10 10 10 10 10 10 10 10
11 11 11 11 11 11 11 11
12 12 12 12 12 12 12 12
13 13 13 13 13 13 13 13
14 14 14 14 14 14 14 14
15 15 15 15 15 15 15 15
16 16 16 16 16 16 16 16
17 17 17 17 17 17 17 17
18 18 18 18 18 18 18 18
19 19 19 19 19 19 19 19
20 20 20 20 20 20 20 20
21 21 21 21 21 21 21 21
22 22 22 22 22 22 22 22
23 23 23 23 23 23 23 23
24 24 24 24 24 24 24 24
========= ERROR SUMMARY: 0 errors
$
``````

To understand the `if` statement in the kernel code (and likewise the final `memcpy` statements), which selects specific matrix rows and columns by picking particular elements of the fragments distributed across the warp, I encourage you to study the charts showing fragment organization in the PTX documentation.

I would also point out that this code is only meant to illustrate the behavior of the selected op. I’m not suggesting this is how you would write an efficient, general-purpose matrix-matrix multiply routine; for that, I would refer you to CUTLASS.

