As an example demonstrating a "full" M=16, N=8, K=8 matrix multiply, let's perform the following op:
D = A*B + C
We will set C to all zeros, and we will choose these values for A and B (giving the result D shown on the right):
        A                       B                          D

 0 1 1 1 1 1 1 1       1 1 1 1 1 1 1 1        9  9  9  9  9  9  9  9
 1 1 1 1 1 1 1 1       3 3 3 3 3 3 3 3       10 10 10 10 10 10 10 10
 2 1 1 1 1 1 1 1       1 1 1 1 1 1 1 1       11 11 11 11 11 11 11 11
 3 1 1 1 1 1 1 1       1 1 1 1 1 1 1 1       12 12 12 12 12 12 12 12
 4 1 1 1 1 1 1 1       1 1 1 1 1 1 1 1       13 13 13 13 13 13 13 13
 5 1 1 1 1 1 1 1       1 1 1 1 1 1 1 1       14 14 14 14 14 14 14 14
 6 1 1 1 1 1 1 1       1 1 1 1 1 1 1 1       15 15 15 15 15 15 15 15
 7 1 1 1 1 1 1 1   *   1 1 1 1 1 1 1 1   =   16 16 16 16 16 16 16 16
 8 1 1 1 1 1 1 1                             17 17 17 17 17 17 17 17
 9 1 1 1 1 1 1 1                             18 18 18 18 18 18 18 18
10 1 1 1 1 1 1 1                             19 19 19 19 19 19 19 19
11 1 1 1 1 1 1 1                             20 20 20 20 20 20 20 20
12 1 1 1 1 1 1 1                             21 21 21 21 21 21 21 21
13 1 1 1 1 1 1 1                             22 22 22 22 22 22 22 22
14 1 1 1 1 1 1 1                             23 23 23 23 23 23 23 23
15 1 1 1 1 1 1 1                             24 24 24 24 24 24 24 24
Before going any further, I would encourage you to convince yourself that the above linear algebra is correct.
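(A quick sanity check: every entry of D is the dot product of a row of A with a column of B. Row i of A is (i, 1, 1, 1, 1, 1, 1, 1), and every column of B is (1, 3, 1, 1, 1, 1, 1, 1), so each entry in row i of D is i*1 + 1*3 + 6*1 = i + 9, which is why the rows of D read 9 through 24.)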
The following code implements this. Note that for simplicity I have removed the use of bfloat16 and am using fp16 instead; this should not matter conceptually for understanding how the op works:
$ cat t10.cu
#include <mma.h>
#include <cuda_fp16.h>
#include <iostream>
#include <stdio.h>
__global__ void mma_fp16_acc_fp32(float *out) {
  float c[4] = {0., 0., 0., 0.};
  float d[4] = {0., 0., 0., 0.};
  half a[4] = {1., 1., 1., 1.};
  half b[2] = {1., 1.};
  // the above sets our input matrices to all 1
  // now let's modify some values
  if (threadIdx.x%4 == 0) {
    // lanes with threadIdx.x%4 == 0 hold columns 0 and 1 of A and rows 0 and 1 of B
    // set the first column of A to 0, 1, 2, ..., 15
    // (a[0] sits in row threadIdx.x/4, a[2] in row threadIdx.x/4 + 8)
    a[0] = threadIdx.x/4; a[2] = threadIdx.x/4 + 8;
    // set the second row of B to 3, 3, 3, ..., 3 (b[1] holds row 1 of B)
    b[1] = 3;
  }
  unsigned const *A = reinterpret_cast<unsigned const *>(&a);
  unsigned const *B = reinterpret_cast<unsigned const *>(&b);
  float const *C = reinterpret_cast<float const *>(&c);
  float *D = reinterpret_cast<float *>(&d);
  asm(
    "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
    "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
    : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
    :
      "r"(A[0]), "r"(A[1]),
      "r"(B[0]),
      "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
  );
  // each lane writes its four d values: d[0],d[1] go to row threadIdx.x/4,
  // d[2],d[3] to row threadIdx.x/4 + 8 (8 bytes = two floats per memcpy)
  memcpy(out+threadIdx.x*2, D, 8);
  memcpy(out+8*8+threadIdx.x*2, D+2, 8);
}
int main() {
  float* h_C = (float*)malloc(16*8*sizeof(float));
  float* d_C;
  cudaMalloc(&d_C, 16*8*sizeof(float));
  mma_fp16_acc_fp32<<<1, 32>>>(d_C);  // exactly one warp
  cudaDeviceSynchronize();
  cudaMemcpy(h_C, d_C, 16*8*sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < 16; i++){
    for (int j = 0; j < 8; j++) std::cout << h_C[i*8+j] << " ";
    std::cout << std::endl;
  }
}
$ nvcc -o t10 t10.cu -arch=sm_75
$ cuda-memcheck ./t10
========= CUDA-MEMCHECK
9 9 9 9 9 9 9 9
10 10 10 10 10 10 10 10
11 11 11 11 11 11 11 11
12 12 12 12 12 12 12 12
13 13 13 13 13 13 13 13
14 14 14 14 14 14 14 14
15 15 15 15 15 15 15 15
16 16 16 16 16 16 16 16
17 17 17 17 17 17 17 17
18 18 18 18 18 18 18 18
19 19 19 19 19 19 19 19
20 20 20 20 20 20 20 20
21 21 21 21 21 21 21 21
22 22 22 22 22 22 22 22
23 23 23 23 23 23 23 23
24 24 24 24 24 24 24 24
========= ERROR SUMMARY: 0 errors
$
To understand the if statement in the kernel code (and likewise the final memcpy statements), which selects specific matrix rows and columns by picking out particular elements of the fragments distributed across the warp, I encourage you to study the charts indicating fragment organization in the documentation.
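To make that mapping more concrete, here is a minimal standalone host-side sketch (my own illustration, not something from the docs; it compiles with any C++ compiler). It prints which rows and columns of A, B, and D each lane's fragment elements correspond to, following the m16n8k8 fragment layout charts in the PTX ISA documentation, and it makes visible that lanes with threadIdx.x%4 == 0 hold column 0 of A (in a[0] and a[2]) and row 1 of B (in b[1]):
#include <cstdio>

int main() {
  for (int lane = 0; lane < 32; lane++) {
    int groupID = lane / 4;            // which group of 4 lanes this lane is in
    int threadID_in_group = lane % 4;  // position within the group
    // A (row-major, f16): a[0],a[1] sit in row groupID, a[2],a[3] in row groupID + 8,
    //                     at columns threadID_in_group*2 and threadID_in_group*2 + 1
    // B (col-major, f16): b[0],b[1] sit in rows threadID_in_group*2 and
    //                     threadID_in_group*2 + 1, at column groupID
    // C/D (f32, m16n8): same row/column pattern as A
    printf("lane %2d: A/D rows %2d,%2d cols %d,%d  |  B rows %d,%d col %d\n",
           lane, groupID, groupID + 8,
           threadID_in_group * 2, threadID_in_group * 2 + 1,
           threadID_in_group * 2, threadID_in_group * 2 + 1, groupID);
  }
  return 0;
}
The same arithmetic explains the output memcpy indexing: a lane's d[0],d[1] land at linear offset (threadIdx.x/4)*8 + (threadIdx.x%4)*2, which simplifies to threadIdx.x*2.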
I would also point out that this code is intended only to illustrate the behavior of the selected op. I'm not suggesting it is how you would write a bulk, efficient matrix-matrix multiply routine; for that, I would refer you to CUTLASS.