Hello everyone, I’m learning to use mma
instructions for fast GEMM. Currently I’m following this post to use mma.m8n8k4; however, I ran into an issue.
My environment:
CUDA 11.7, Ubuntu 20.04, Tesla A100.
My code:
#include "cuda.h"
#include "stdio.h"
#include "cuda_fp16.h"
#include <iostream>
// Computes a 16x16 fp16 GEMM tile D = A*B + bias using four m8n8k4 MMA
// "shards" spread across one warp (launch config must be <<<1, 32>>>).
// A is read as row-major, B as row-major, bias/C as 16x16 row-major.
// Each thread loads its A/B/accumulator fragments straight from global
// memory with ld.global.b32 and stores its 4 result registers (8 halves)
// back contiguously.
//
// NOTE(review): mma.m8n8k4 is a Volta-era shape; on sm_80 (A100) the fp16
// variant is implemented differently (emulated in SASS), and the per-thread
// fragment layout assumed by the index math below may not match the PTX ISA
// fragment tables for .f16 accumulators on this architecture — a likely
// cause of the "right values, wrong positions / half zeros" symptom
// described at the end of this post. Verify the mappings against the PTX
// ISA section on m8n8k4 fragment layouts.
__global__ void mma_m16n16k4(half* A, half* B, half* C, half* bias) {
int laneid = threadIdx.x % 32;
int shard_id = (laneid / 4) % 4; // MMA shard id, 4 8x8x4 mma shards in total
int seg_id = (laneid % 4) + 4 * (laneid / 16); // segment id inside an mma shard
// Per-thread starting element offsets into A, B, and C/bias.
// NOTE(review): with A declared as 16x4 (row stride 4), this uses a row
// stride of 8 and a K-offset of (shard_id / 2) * 4 even though K == 4 —
// the maximum index exceeds M*K, which looks out of bounds; confirm the
// intended A layout.
int A_begin_index = ((shard_id % 2) * 8 + seg_id ) * 8 + (shard_id / 2) * 4;
int B_begin_index = ((shard_id % 2) * 4 + seg_id % 4) * 16 + (shard_id / 2) * 8 + (seg_id / 4) * 4;
int C_begin_index = ((shard_id % 2) * 8 + seg_id ) * 16 + (shard_id / 2) * 8;
// Inline PTX: declare f16x2 register pairs/quads, load fragments, issue the
// MMA, and store the 4 result registers (8 contiguous halves in one row).
// NOTE(review): "%Ra0" etc. contain a literal '%' inside a GCC-style asm
// template; normally a literal '%' must be written "%%" — confirm how nvcc
// is expanding these register names.
asm volatile (
" .reg .f16x2 %Ra<2>;\n"
" .reg .f16x2 %Rb<2>;\n"
" .reg .f16x2 %Rc<4>;\n"
" .reg .f16x2 %Rd<4>;\n"
" ld.global.b32 %Ra0, [%0];\n"
" ld.global.b32 %Ra1, [%0 + 4];\n"
" ld.global.b32 %Rb0, [%1];\n"
" ld.global.b32 %Rb1, [%1 + 4];\n"
" ld.global.b32 %Rc0, [%2];\n"
" ld.global.b32 %Rc1, [%2 + 4];\n"
" ld.global.b32 %Rc2, [%2 + 8];\n"
" ld.global.b32 %Rc3, [%2 + 12];\n"
" mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%Rd0, %Rd1, %Rd2, %Rd3}, {%Ra0, %Ra1}, {%Rb0, %Rb1}, {%Rc0, %Rc1, %Rc2, %Rc3};\n"
" st.global.b32 [%3], %Rd0;\n"
" st.global.b32 [%3 + 4], %Rd1;\n"
" st.global.b32 [%3 + 8], %Rd2;\n"
" st.global.b32 [%3 + 12], %Rd3;\n"
:
: "l"(&A[A_begin_index]), "l"(&B[B_begin_index]), "l"(&bias[C_begin_index]), "l"(&C[C_begin_index])
);
// NOTE(review): redundant — the kernel runs a single warp and this is the
// last statement; there is no later shared/global access to order.
__syncthreads();
}
// Abort main with a readable message if a CUDA runtime call fails.
// (Kernel launches report config errors via cudaGetLastError and execution
// errors at the next synchronizing call, so both are checked after launch.)
#define CUDA_CHECK(call)                                                      \
    do {                                                                      \
        cudaError_t err_ = (call);                                            \
        if (err_ != cudaSuccess) {                                            \
            fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__,     \
                    cudaGetErrorString(err_));                                \
            return 1;                                                         \
        }                                                                     \
    } while (0)

// Host driver: builds the 16x16x4 GEMM inputs, runs the warp-level MMA
// kernel, prints the device result, and compares it against a CPU reference.
// Returns 0 on success, 1 on any CUDA error.
int main(int argc, char** argv) {
    (void)argc;
    (void)argv;
    constexpr int M = 16;
    constexpr int N = 16;
    constexpr int K = 4;

    // Host operands: A is MxK, B is KxN, bias/C are MxN, all row-major.
    half* A = new half[M * K];
    half* B = new half[K * N];
    half* C = new half[M * N];
    half* bias = new half[M * N];
    float* C_ref = new float[M * N];
    for (int i = 0; i < M * K; i++) A[i] = (half)(i * 0.1);
    for (int i = 0; i < K * N; i++) B[i] = (half)(i * 0.1);
    for (int i = 0; i < M * N; i++) bias[i] = (half)(0.00);

    // Device buffers.
    half *dA, *dB, *dC, *dbias;
    CUDA_CHECK(cudaMalloc(reinterpret_cast<void**>(&dA), sizeof(half) * M * K));
    CUDA_CHECK(cudaMalloc(reinterpret_cast<void**>(&dB), sizeof(half) * K * N));
    CUDA_CHECK(cudaMalloc(reinterpret_cast<void**>(&dC), sizeof(half) * M * N));
    CUDA_CHECK(cudaMalloc(reinterpret_cast<void**>(&dbias), sizeof(half) * M * N));
    CUDA_CHECK(cudaMemcpy(reinterpret_cast<void*>(dA), A, sizeof(half) * M * K, cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(reinterpret_cast<void*>(dB), B, sizeof(half) * K * N, cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(reinterpret_cast<void*>(dbias), bias, sizeof(half) * M * N, cudaMemcpyHostToDevice));

    // The kernel expects exactly one warp.
    mma_m16n16k4<<<1, 32>>>(dA, dB, dC, dbias);
    CUDA_CHECK(cudaGetLastError());       // launch-configuration errors
    CUDA_CHECK(cudaDeviceSynchronize());  // execution errors (e.g. bad address)
    CUDA_CHECK(cudaMemcpy(C, reinterpret_cast<void*>(dC), sizeof(half) * M * N, cudaMemcpyDeviceToHost));

    printf("-----------------------------------------------------------------------------------------\n");

    // CPU reference GEMM in fp32 (bias is all zeros, so it is omitted here,
    // matching the original reference computation).
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            C_ref[m * N + n] = 0.0f;
            for (int k = 0; k < K; k++) {
                C_ref[m * N + n] += ((float)A[m * K + k] * (float)B[k * N + n]);
            }
        }
    }

    // Print the device result.
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            printf("%05.2f%c", (float)C[i * N + j], (j == N - 1) ? '\n' : ' ');
        }
    }

    // Compare against the reference (it was previously computed but never
    // used). fp16 carries ~3 decimal digits, so allow for rounding error.
    float max_abs_diff = 0.0f;
    for (int i = 0; i < M * N; i++) {
        float d = (float)C[i] - C_ref[i];
        if (d < 0.0f) d = -d;
        if (d > max_abs_diff) max_abs_diff = d;
    }
    printf("max |C - C_ref| = %f\n", max_abs_diff);

    // Release device and host resources.
    CUDA_CHECK(cudaFree(dA));
    CUDA_CHECK(cudaFree(dB));
    CUDA_CHECK(cudaFree(dC));
    CUDA_CHECK(cudaFree(dbias));
    delete[] A;
    delete[] B;
    delete[] C;
    delete[] bias;
    delete[] C_ref;
    return 0;
}
Expected output:
002.24 002.30 002.36 002.42 002.48 002.54 002.60 002.66 002.72 002.78 002.84 002.90 002.96 003.02 003.08 003.14
006.08 006.30 006.52 006.74 006.96 007.18 007.40 007.62 007.84 008.06 008.28 008.50 008.72 008.94 009.16 009.38
009.92 010.30 010.68 011.06 011.44 011.82 012.20 012.58 012.96 013.34 013.72 014.10 014.48 014.86 015.24 015.62
013.76 014.30 014.84 015.38 015.92 016.46 017.00 017.54 018.08 018.62 019.16 019.70 020.24 020.78 021.32 021.86
017.60 018.30 019.00 019.71 020.40 021.10 021.80 022.50 023.20 023.90 024.60 025.30 026.00 026.70 027.40 028.10
021.44 022.30 023.16 024.03 024.88 025.74 026.60 027.46 028.32 029.18 030.04 030.90 031.76 032.62 033.48 034.34
025.27 026.29 027.32 028.34 029.35 030.37 031.39 032.42 033.44 034.46 035.47 036.49 037.51 038.54 039.56 040.57
029.12 030.30 031.48 032.67 033.84 035.02 036.19 037.38 038.56 039.74 040.92 042.09 043.28 044.47 045.64 046.82
032.96 034.30 035.64 036.99 038.32 039.66 041.00 042.35 043.69 045.03 046.36 047.70 049.04 050.39 051.73 053.06
036.80 038.30 039.80 041.31 042.80 044.30 045.80 047.31 048.81 050.31 051.80 053.30 054.80 056.31 057.81 059.30
040.64 042.30 043.96 045.63 047.28 048.94 050.60 052.27 053.93 055.59 057.24 058.90 060.56 062.23 063.89 065.54
044.48 046.30 048.12 049.95 051.76 053.58 055.39 057.23 059.05 060.86 062.68 064.49 066.32 068.15 069.97 071.78
048.32 050.30 052.29 054.28 056.24 058.22 060.20 062.19 064.17 066.15 068.12 070.10 072.08 074.07 076.05 078.02
052.15 054.29 056.44 058.59 060.71 062.85 064.99 067.14 069.28 071.42 073.55 075.68 077.83 079.98 082.12 084.25
055.99 058.29 060.59 062.91 065.19 067.49 069.79 072.10 074.40 076.70 078.99 081.28 083.59 085.90 088.20 090.49
059.84 062.30 064.76 067.24 069.68 072.14 074.60 077.07 079.54 081.99 084.44 086.89 089.36 091.84 094.29 096.74
My actual output:
02.24 02.30 02.36 02.42 02.48 02.54 02.60 02.66 07.84 08.06 08.28 08.50 08.72 08.95 09.16 09.38
09.91 10.30 10.68 11.06 11.44 11.82 12.20 12.58 18.08 18.62 19.16 19.70 20.23 20.78 21.33 21.86
17.59 18.30 19.00 19.70 20.41 21.09 21.80 22.50 28.33 29.19 30.05 30.89 31.77 32.62 33.47 34.34
25.28 26.30 27.31 28.34 29.34 30.38 31.39 32.41 38.56 39.75 40.91 42.09 43.28 44.47 45.66 46.81
32.97 34.31 35.66 37.00 38.31 39.66 41.00 42.34 48.81 50.31 51.81 53.28 54.81 56.31 57.81 59.31
40.66 42.31 43.97 45.62 47.28 48.94 50.59 52.28 59.06 60.88 62.69 64.50 66.31 68.12 69.94 71.75
48.31 50.31 52.28 54.28 56.25 58.22 60.19 62.19 69.31 71.44 73.56 75.69 77.81 80.00 82.12 84.25
56.00 58.28 60.59 62.91 65.19 67.50 69.81 72.12 79.56 82.00 84.44 86.88 89.38 91.81 94.31 96.75
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
Difference:
Half of my results are zeros, and the remaining results have correct values but appear in incorrect positions. I assume A and B are loaded into the registers correctly, since otherwise the remaining results would not have the correct values. What are the potential causes of this problem?