Help with ldmatrix instruction

I’m very new to inline PTX instruction.
This is the code using inline ldmatrix and mma instruction.
However, the ldmatrixs to load ra and rb behave unexpectedly.
s_b and s_a are all 1s, but the result in ra and rb are not 1s.



Can anyone help me debug this?

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdint.h>

#define CHECK_CUDA(call) \
    if ((call) != cudaSuccess) { \
        fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, cudaGetErrorString(call)); \
        exit(EXIT_FAILURE); \
    }

#define OFFSET(i, j, ld) (((i) * (ld)) + (j))
#define FETCH_FLOAT4(ptr) (*reinterpret_cast<float4 *>(&(ptr)))

#define WARP_SIZE 32

template <const int BM, const int BN>
__global__ void mma_m16n8k16_ptx2(half* A, half* B, int M, int N, int K, float* C) {
  const int WARP_N = BN / 8;
  const int WRAP_ID = threadIdx.x / WARP_SIZE;
  const int LANE_ID = threadIdx.x % WARP_SIZE;
  const int WARP_ROW = WRAP_ID / WARP_N;
  const int WARP_COL = WRAP_ID % WARP_N;

  const int C_THREAD_ROW = LANE_ID / 4;
  const int C_THREAD_COL = LANE_ID % 4;

  float c[4] = {0.0, 0.0, 0.0, 0.0};
  __shared__ half s_a[BM][16];
  __shared__ half s_b[16][BN];

  uint32_t ra[4];
  uint32_t rb[2];

  const int WARP_ROW_OFFSET_A = WARP_ROW * 16;
  const int WARP_COL_OFFSET_B = WARP_COL * 8;
  const int WARP_ROW_OFFSET_C = WARP_ROW * 16 + blockIdx.y * BM;
  const int WARP_COL_OFFSET_C = WARP_COL * 8 + blockIdx.x * BN;

  const int BLOCK_ROW_OFFSET = blockIdx.y * BM;
  const int BLOCK_COL_OFFSET = blockIdx.x * BN;

  const int NUM_ELEMENT_PER_THREAD = sizeof(float4) / sizeof(half);
  const int ROW_OFFSET_A = threadIdx.x * NUM_ELEMENT_PER_THREAD / 16;
  const int COL_OFFSET_A = (threadIdx.x * NUM_ELEMENT_PER_THREAD) % 16;
  const int ROW_STRIDE_A = blockDim.x * NUM_ELEMENT_PER_THREAD / 16;
  const int ROW_OFFSET_B = threadIdx.x * NUM_ELEMENT_PER_THREAD / BN;
  const int COL_OFFSET_B = (threadIdx.x * NUM_ELEMENT_PER_THREAD) % BN;
  const int ROW_STRIDE_B = blockDim.x * NUM_ELEMENT_PER_THREAD / BN;

  for (int i = 0; i < K; i += 16) {
    for (int offset = 0; offset < BM; offset += ROW_STRIDE_A) {
      if (offset + ROW_OFFSET_A < BM) {
        FETCH_FLOAT4(s_a[offset + ROW_OFFSET_A][COL_OFFSET_A]) =
            FETCH_FLOAT4(A[OFFSET(BLOCK_ROW_OFFSET + offset + ROW_OFFSET_A, COL_OFFSET_A + i, K)]);
      }
    }

    for (int offset = 0; offset < 16; offset += ROW_STRIDE_B) {
      if (offset + ROW_OFFSET_B < 16) {
        FETCH_FLOAT4(s_b[offset + ROW_OFFSET_B][COL_OFFSET_B]) =
            FETCH_FLOAT4(B[OFFSET(i + offset + ROW_OFFSET_B, BLOCK_COL_OFFSET + COL_OFFSET_B, N)]);
      }
    }
    __syncthreads();

    uint32_t addr_a = __cvta_generic_to_shared(&s_a[WARP_ROW_OFFSET_A + LANE_ID % 16][(LANE_ID / 16) * 8]);
    uint32_t addr_b = __cvta_generic_to_shared(&s_b[LANE_ID % 16][WARP_COL_OFFSET_B]);

    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
                 : "=r"(ra[0]), "=r"(ra[1]), "=r"(ra[2]), "=r"(ra[3])
                 : "r"(addr_a));

    if (LANE_ID < 16) {
      asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];"
                   : "=r"(rb[0]), "=r"(rb[1])
                   : "r"(addr_b));
    }

    asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        " { %0, %1, %2, %3 }, "
        " { %4, %5, %6, %7 }, "
        " { %8, %9 }, "
        " { %0, %1, %2, %3 };"
        : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
        : "r"(ra[0]), "r"(ra[1]), "r"(ra[2]), "r"(ra[3]), "r"(rb[0]), "r"(rb[1]));

    __syncthreads();
  }

  C[OFFSET(WARP_ROW_OFFSET_C + C_THREAD_ROW, WARP_COL_OFFSET_C + C_THREAD_COL * 2, N)] = c[0];
  C[OFFSET(WARP_ROW_OFFSET_C + C_THREAD_ROW, WARP_COL_OFFSET_C + C_THREAD_COL * 2 + 1, N)] = c[1];
  C[OFFSET(WARP_ROW_OFFSET_C + C_THREAD_ROW + 8, WARP_COL_OFFSET_C + C_THREAD_COL * 2, N)] = c[2];
  C[OFFSET(WARP_ROW_OFFSET_C + C_THREAD_ROW + 8, WARP_COL_OFFSET_C + C_THREAD_COL * 2 + 1, N)] =
      c[3];
}

int main() {
    int M = 128;
    int N = 128;
    int K = 128;

    size_t size_A = M * K * sizeof(half);
    size_t size_B = K * N * sizeof(half);
    size_t size_C = M * N * sizeof(float);

    half *h_A = (half *)malloc(size_A);
    half *h_B = (half *)malloc(size_B);
    float *h_C = (float *)malloc(size_C);

    for (int i = 0; i < M * K; ++i)
        h_A[i] = __float2half(1.0f);
    for (int i = 0; i < K * N; ++i)
        h_B[i] = __float2half(1.0f);

    half *d_A, *d_B;
    float *d_C;
    CHECK_CUDA(cudaMalloc(&d_A, size_A));
    CHECK_CUDA(cudaMalloc(&d_B, size_B));
    CHECK_CUDA(cudaMalloc(&d_C, size_C));

    CHECK_CUDA(cudaMemcpy(d_A, h_A, size_A, cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_B, h_B, size_B, cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemset(d_C, 0, size_C));

    const int BM = 32;
    const int BN = 32;
    const int WARP_NUM = (BM * BN) / (16 * 8);
    const dim3 block_dim(WARP_NUM * WARP_SIZE);
    const dim3 grid_dim(N / BN, M / BM);
    mma_m16n8k16_ptx2<BM, BN><<<grid_dim, block_dim>>>(d_A, d_B, M, N, K, d_C);
    CHECK_CUDA(cudaDeviceSynchronize());

    CHECK_CUDA(cudaMemcpy(h_C, d_C, size_C, cudaMemcpyDeviceToHost));

    printf("C[0][0] = %f\n", h_C[OFFSET(0, 0, N)]);
    printf("C[1][0] = %f\n", h_C[OFFSET(1, 0, N)]);
    printf("C[0][1] = %f\n", h_C[OFFSET(0, 1, N)]);

    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
    free(h_A);
    free(h_B);
    free(h_C);

    return 0;
}

This is undefined behaviour. All threads in the warp need to execute the ldmatrix instruction (since all of them receive two elements of each 8x8 tile)

I think .aligned compared to .synchronized even demands exactly the same instruction (same program pointer value).