Mma instruction question about memory addressing

Hello everyone, I'm learning to use mma instructions for fast GEMM. Currently I'm following this post to use mma.m8n8k4, but I've run into an issue.
My environment:
CUDA 11.7, Ubuntu 20.04, Tesla A100.
My code:

#include "cuda.h"                                                                                                                                                                                                                                                                                                                                                    
#include "stdio.h"
#include "cuda_fp16.h"
#include <iostream>

__global__ void mma_m16n16k4(half* A, half* B, half* C, half* bias) {
    int laneid = threadIdx.x % 32; 
    int shard_id = (laneid / 4) % 4;                 // MMA shard id, 4 8x8x4 mma shards in total
    int seg_id = (laneid % 4) + 4 * (laneid / 16);   // segment id inside an mma shard
    int A_begin_index = ((shard_id % 2) * 8 + seg_id    ) * 8  + (shard_id / 2) * 4;
    int B_begin_index = ((shard_id % 2) * 4 + seg_id % 4) * 16 + (shard_id / 2) * 8 + (seg_id / 4) * 4;
    int C_begin_index = ((shard_id % 2) * 8 + seg_id    ) * 16 + (shard_id / 2) * 8;

    asm volatile (
        "   .reg .f16x2 %Ra<2>;\n"
        "   .reg .f16x2 %Rb<2>;\n"
        "   .reg .f16x2 %Rc<4>;\n"
        "   .reg .f16x2 %Rd<4>;\n"
        "   ld.global.b32 %Ra0, [%0];\n"
        "   ld.global.b32 %Ra1, [%0 + 4];\n"
        "   ld.global.b32 %Rb0, [%1];\n"
        "   ld.global.b32 %Rb1, [%1 + 4];\n"
        "   ld.global.b32 %Rc0, [%2];\n"
        "   ld.global.b32 %Rc1, [%2 + 4];\n"
        "   ld.global.b32 %Rc2, [%2 + 8];\n"
        "   ld.global.b32 %Rc3, [%2 + 12];\n"
        "   mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%Rd0, %Rd1, %Rd2, %Rd3}, {%Ra0, %Ra1}, {%Rb0, %Rb1}, {%Rc0, %Rc1, %Rc2, %Rc3};\n"
        "   st.global.b32 [%3], %Rd0;\n"
        "   st.global.b32 [%3 + 4], %Rd1;\n"
        "   st.global.b32 [%3 + 8], %Rd2;\n"
        "   st.global.b32 [%3 + 12], %Rd3;\n"
        :   
        : "l"(&A[A_begin_index]), "l"(&B[B_begin_index]), "l"(&bias[C_begin_index]), "l"(&C[C_begin_index])
    );  
    __syncthreads();

}

int main(int argc, char** argv) {
    half *A, *B, *C, *bias;
    float *C_ref;
    constexpr int M = 16; 
    constexpr int N = 16; 
    constexpr int K = 4;
    A = new half[M * K]; 
    B = new half[K * N]; 
    C = new half[M * N]; 
    bias = new half[M * N]; 
    C_ref = new float[M * N]; 

    for (int i = 0; i < M * K; i++) A[i] = (half)(i * 0.1);
    for (int i = 0; i < K * N; i++) B[i] = (half)(i * 0.1);
    for (int i = 0; i < M * N; i++) bias[i] = (half)(0.00);

    half *dA, *dB, *dC, *dbias;
    cudaMalloc(reinterpret_cast<void**>(&dA), sizeof(half) * M * K); 
    cudaMalloc(reinterpret_cast<void**>(&dB), sizeof(half) * K * N); 
    cudaMalloc(reinterpret_cast<void**>(&dC), sizeof(half) * M * N); 
    cudaMalloc(reinterpret_cast<void**>(&dbias), sizeof(half) * M * N); 
    cudaMemcpy(reinterpret_cast<void*>(dA), A, sizeof(half) * M * K, cudaMemcpyHostToDevice);
    cudaMemcpy(reinterpret_cast<void*>(dB), B, sizeof(half) * K * N, cudaMemcpyHostToDevice);
    cudaMemcpy(reinterpret_cast<void*>(dbias), bias, sizeof(half) * M * N, cudaMemcpyHostToDevice);
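    // launch a single block of 32 threads, i.e. one warp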
    mma_m16n16k4<<<1, 32>>>(dA, dB, dC, dbias);
    cudaMemcpy(C, reinterpret_cast<void*>(dC), sizeof(half) * M * N, cudaMemcpyDeviceToHost);
    printf("-----------------------------------------------------------------------------------------\n");
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            C_ref[m * N + n] =  0.0;
            for (int k = 0; k < K; k++) {
                C_ref[m * N + n] += ((float)A[m * K + k] * (float)B[k * N + n]);
            }   
        }   
    }   
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            printf("%05.2f%c", (float)C[i * N + j], (j == N - 1) ? '\n' : ' ');
        }   
    }   
    return 0;
}

Expected output:

002.24 002.30 002.36 002.42 002.48 002.54 002.60 002.66 002.72 002.78 002.84 002.90 002.96 003.02 003.08 003.14
006.08 006.30 006.52 006.74 006.96 007.18 007.40 007.62 007.84 008.06 008.28 008.50 008.72 008.94 009.16 009.38
009.92 010.30 010.68 011.06 011.44 011.82 012.20 012.58 012.96 013.34 013.72 014.10 014.48 014.86 015.24 015.62
013.76 014.30 014.84 015.38 015.92 016.46 017.00 017.54 018.08 018.62 019.16 019.70 020.24 020.78 021.32 021.86
017.60 018.30 019.00 019.71 020.40 021.10 021.80 022.50 023.20 023.90 024.60 025.30 026.00 026.70 027.40 028.10
021.44 022.30 023.16 024.03 024.88 025.74 026.60 027.46 028.32 029.18 030.04 030.90 031.76 032.62 033.48 034.34
025.27 026.29 027.32 028.34 029.35 030.37 031.39 032.42 033.44 034.46 035.47 036.49 037.51 038.54 039.56 040.57
029.12 030.30 031.48 032.67 033.84 035.02 036.19 037.38 038.56 039.74 040.92 042.09 043.28 044.47 045.64 046.82
032.96 034.30 035.64 036.99 038.32 039.66 041.00 042.35 043.69 045.03 046.36 047.70 049.04 050.39 051.73 053.06
036.80 038.30 039.80 041.31 042.80 044.30 045.80 047.31 048.81 050.31 051.80 053.30 054.80 056.31 057.81 059.30
040.64 042.30 043.96 045.63 047.28 048.94 050.60 052.27 053.93 055.59 057.24 058.90 060.56 062.23 063.89 065.54
044.48 046.30 048.12 049.95 051.76 053.58 055.39 057.23 059.05 060.86 062.68 064.49 066.32 068.15 069.97 071.78
048.32 050.30 052.29 054.28 056.24 058.22 060.20 062.19 064.17 066.15 068.12 070.10 072.08 074.07 076.05 078.02
052.15 054.29 056.44 058.59 060.71 062.85 064.99 067.14 069.28 071.42 073.55 075.68 077.83 079.98 082.12 084.25
055.99 058.29 060.59 062.91 065.19 067.49 069.79 072.10 074.40 076.70 078.99 081.28 083.59 085.90 088.20 090.49
059.84 062.30 064.76 067.24 069.68 072.14 074.60 077.07 079.54 081.99 084.44 086.89 089.36 091.84 094.29 096.74

My actual output:

02.24 02.30 02.36 02.42 02.48 02.54 02.60 02.66 07.84 08.06 08.28 08.50 08.72 08.95 09.16 09.38
09.91 10.30 10.68 11.06 11.44 11.82 12.20 12.58 18.08 18.62 19.16 19.70 20.23 20.78 21.33 21.86
17.59 18.30 19.00 19.70 20.41 21.09 21.80 22.50 28.33 29.19 30.05 30.89 31.77 32.62 33.47 34.34
25.28 26.30 27.31 28.34 29.34 30.38 31.39 32.41 38.56 39.75 40.91 42.09 43.28 44.47 45.66 46.81
32.97 34.31 35.66 37.00 38.31 39.66 41.00 42.34 48.81 50.31 51.81 53.28 54.81 56.31 57.81 59.31
40.66 42.31 43.97 45.62 47.28 48.94 50.59 52.28 59.06 60.88 62.69 64.50 66.31 68.12 69.94 71.75
48.31 50.31 52.28 54.28 56.25 58.22 60.19 62.19 69.31 71.44 73.56 75.69 77.81 80.00 82.12 84.25
56.00 58.28 60.59 62.91 65.19 67.50 69.81 72.12 79.56 82.00 84.44 86.88 89.38 91.81 94.31 96.75
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00
00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00 00.00

Difference:

Half of my results are zeros, and the remaining results have the correct values but appear in the wrong positions. I guess A and B are loaded into the registers correctly, since if they were not, the non-zero results would not have the right values. What are the potential causes of this problem?
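To pin down where each value actually landed, I also put together a rough host-side check that reuses the C_ref array main() already computes (the helper name and tolerance below are arbitrary; this is only a debugging sketch, not part of the kernel):

#include <cmath>
#include <cstdio>
#include <cuda_fp16.h>

// Debugging sketch (not part of the kernel): for each element of the device
// result C, check it against the reference at the same position; if it does
// not match, search C_ref for a position whose value does match, so that
// "correct values in wrong positions" shows up as an explicit index mapping.
void map_misplaced(const half* C, const float* C_ref, int M, int N) {
    for (int i = 0; i < M * N; i++) {
        float v = (float)C[i];
        if (std::fabs(v - C_ref[i]) <= 0.05f + 0.01f * std::fabs(C_ref[i]))
            continue;                                   // already in the right place
        int found = -1;
        for (int j = 0; j < M * N; j++) {
            if (std::fabs(v - C_ref[j]) <= 0.05f + 0.01f * std::fabs(C_ref[j])) { found = j; break; }
        }
        if (found >= 0)
            printf("C[%3d] (row %2d, col %2d) holds the value expected at row %2d, col %2d\n",
                   i, i / N, i % N, found / N, found % N);
        else
            printf("C[%3d] (row %2d, col %2d) = %.2f matches no reference value\n",
                   i, i / N, i % N, v);
    }
}

Calling map_misplaced(C, C_ref, M, N); at the end of main() lists, for every misplaced or missing element, where its value was expected.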

Reply:

Here is an example that may be of interest. It will fill in the full 256 output elements for the m8n8k4 mma instruction. To see this, all you need to do is change the final printout loop in main to print 256 elements.
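The printout change itself is trivial; a minimal sketch, assuming the example copies its result back into a flat 256-element half array named C on the host (the name is just a placeholder, since the example code is not quoted here):

// Print all 256 output elements, 16 per line (assumes C holds 256 half values).
for (int i = 0; i < 256; i++) {
    printf("%05.2f%c", (float)C[i], ((i + 1) % 16 == 0) ? '\n' : ' ');
}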