Alignment requirement for the `ldmatrix` instruction

Hello. I have a question regarding the alignment requirement for the ldmatrix instruction.

As stated in the PTX documentation for `ldmatrix`, the matrix addresses must be “naturally aligned accordingly”. The PTX ISA defines natural alignment as follows: “The address must be naturally aligned to a multiple of the access size.”

My question is: What is the specific “access size” for the ldmatrix instruction, and what exact alignment is required for the pointer we pass to this instruction?

When reading 8x8 matrices, a group of four consecutive threads loads 16 bytes. The matrix addresses must be naturally aligned accordingly.

I think the pointers have to be 16-byte aligned.

I think the alignment should be 16 bytes. I’ve written a simple GEMM kernel using CUTLASS (CuTe) and varied the shared-memory offset `SMEM_OFFSET_BYTES`, which changes the alignment of the shared-memory addresses the kernel uses. The code is below.
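What I found is that when `SMEM_OFFSET_BYTES` is a multiple of 16, the kernel works correctly; otherwise it fails with “CUDA error: misaligned address”. This matches the documentation. Each row address passed to `ldmatrix` points to one row of an 8x8 matrix of 16-bit elements (four consecutive threads each loading 4 bytes), i.e. 8 × 2 = 16 bytes. The access size, and therefore the required alignment, is therefore 16 bytes. One caveat: in this kernel the same shared-memory buffers are also written by 16-byte `cp.async` copies (`SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>`), which need 16-byte-aligned addresses as well. The error may therefore be raised by `cp.async` before any `ldmatrix` executes. To isolate `ldmatrix`, shift only the addresses it reads from. Even so, I think we can now say with confidence that the addresses passed to `ldmatrix` must be 16-byte aligned.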

The code is here:

#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <iostream>

#include <cublas_v2.h>

#include <cutlass/cutlass.h>
#include <cute/layout.hpp>
#include <cute/tensor.hpp>
#include <cute/numeric/numeric_types.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/copy_atom.hpp>

#include "common.h"

// Element type of A, B and C throughout this file.
typedef cute::half_t data_t;

// Byte offset added to the dynamic shared-memory base before the A/B staging
// buffers are carved out of it. It exists to experiment with the alignment of the
// shared-memory addresses used by the kernel; gemm3 requests this many extra bytes
// of dynamic shared memory at launch.
static constexpr size_t SMEM_OFFSET_BYTES{0};

std::pair<float, float> gemm_cublas(
  int M, int N, int K,
  data_t *A, data_t *B, data_t *C
) {
  using cuda_half = __half;
  cublasHandle_t handle;
  cublasCreate(&handle);
  cuda_half alpha = 1.0, beta = 0.0;
  CUDATimer timer;
  int num_warmup_iters = 4;
  int num_iters = 4;
  for (int i = 0; i < num_warmup_iters + num_iters; ++i) {
    if (i == num_warmup_iters) timer.start();
    // A, B, C are all row major
    // A: MxK, B: NxK, C: MxN
    cublasHgemm(
      handle,
      CUBLAS_OP_T,
      CUBLAS_OP_N,
      N, M, K,
      &alpha,
      (cuda_half*)B, K,
      (cuda_half*)A, K,
      &beta,
      (cuda_half*)C, N
    );
  }
  float elapsed_time_ms = timer.stop();
  cublasDestroy(handle);
  float flops = 2.0 * M * N * K * num_iters;
  float gflops = flops / (elapsed_time_ms/1000) / 1e9;
  return {elapsed_time_ms / num_iters, gflops};
}

using namespace cute;
// Compile-time configuration for gemm3_fwd: tile sizes plus the CuTe shared-memory
// layouts, MMA and copy atoms the kernel is built from.
//
// Template parameters:
//   T_                       element type of A, B and C.
//   TILE_M_/TILE_N_/TILE_K_  CTA tile extents along M, N and K.
//   NUM_LDST_                number of shared-memory stages (third mode of
//                            SmemLayoutA/SmemLayoutB).
//   CTA_TILE_M_/CTA_TILE_N_  size, in tiles, of the groups gemm3_fwd uses when mapping
//                            blockIdx.x to an output tile (presumably for L2 locality —
//                            confirm).
template<typename T_, int TILE_M_, int TILE_N_, int TILE_K_, int NUM_LDST_, int CTA_TILE_M_, int CTA_TILE_N_>
struct Gemm3Config {
  using T = T_;
  using TILE_M = Int<TILE_M_>;
  using TILE_N = Int<TILE_N_>;
  using TILE_K = Int<TILE_K_>;
  using NUM_LDST = Int<NUM_LDST_>;
  using CTA_TILE_M = Int<CTA_TILE_M_>;
  using CTA_TILE_N = Int<CTA_TILE_N_>;

  // Multi-stage shared-memory buffers:
  //   A: (TILE_M, TILE_K, NUM_LDST)
  //   B: (TILE_N, TILE_K, NUM_LDST)
  // Each stage is K-contiguous, and stages are stored back to back. Swizzle<3, 3>
  // permutes the offsets, presumably to avoid shared-memory bank conflicts.
  using SmemLayoutA = decltype(composition(
    Swizzle<3, 3>{},
    Layout<Shape<TILE_M, TILE_K, NUM_LDST>, Stride<TILE_K, _1, decltype(TILE_M{}*TILE_K{})>>{}
  ));
  using SmemLayoutB = decltype(composition(
    Swizzle<3, 3>{},
    Layout<Shape<TILE_N, TILE_K, NUM_LDST>, Stride<TILE_K, _1, decltype(TILE_N{}*TILE_K{})>>{}
  ));
  CUTE_STATIC_ASSERT_V(TILE_M{} == size<0>(SmemLayoutA{}));
  CUTE_STATIC_ASSERT_V(TILE_K{} == size<1>(SmemLayoutA{}));
  CUTE_STATIC_ASSERT_V(TILE_N{} == size<0>(SmemLayoutB{}));
  CUTE_STATIC_ASSERT_V(TILE_K{} == size<1>(SmemLayoutB{}));

  // 2x2 grid of SM80 16x8x16 FP16 MMA atoms, covering a 32x32x16 tile per step.
  // thr_size() of this MMA is the CTA's thread count (see __launch_bounds__ on
  // gemm3_fwd and block_shape in gemm3).
  using TiledMMA = decltype(make_tiled_mma(
    SM80_16x8x16_F16F16F16F16_TN{},
    Layout<Shape<_2, _2, _1>>{},
    Tile<_32, _32, _16>{}
  ));

  // Global->shared copy: 32x4 threads, each issuing one cp.async of a cute::uint128_t
  // (8 elements per thread, 16 bytes when T is 16-bit).
  // NOTE(review): 16-byte cp.async presumably also requires 16-byte-aligned
  // shared-memory destinations, so SMEM_OFFSET_BYTES affects this copy too — confirm
  // against the PTX ISA.
  using G2STiledCopy = decltype(make_tiled_copy(
    Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, T>{},
    Layout<Shape<_32, _4>, Stride<_4, _1>>{},
    Layout<Shape<_1, _8>>{}
  ));
  CUTE_STATIC_ASSERT_V(thr_size(TiledMMA{}) == size(G2STiledCopy{}));

  // Shared->register copies using ldmatrix (SM75_U32x4_LDSM_N), partitioned to match
  // the TiledMMA's A and B operand fragments.
  using S2RTiledCopyA = decltype(make_tiled_copy_A(
    Copy_Atom<SM75_U32x4_LDSM_N, T>{},
    TiledMMA{}
  ));
  using S2RTiledCopyB = decltype(make_tiled_copy_B(
    Copy_Atom<SM75_U32x4_LDSM_N, T>{},
    TiledMMA{}
  ));

  // Epilogue staging layout for C: an 8x64 swizzled atom tiled to (32, TILE_N, 2).
  // The last mode is a double buffer (gemm3_fwd alternates cur_buf between 0 and 1).
  // It reuses the A/B shared-memory region, hence the size check below.
  using SmemLayoutC = decltype(tile_to_shape(
    composition(
      Swizzle<3, 3, 3>{},
      Layout<Shape<_8, _64>, Stride<_64, _1>>{} // _64 = 128/2 is the number of halfs in every 128B
    ),
    Layout<Shape<_32, TILE_N, _2>>{}  // _32 is the length on the m-axis for one tile in TiledMMA
  ));
  CUTE_STATIC_ASSERT_V(TILE_N{} == size<1>(SmemLayoutC{}));
  CUTE_STATIC_ASSERT_V(cosize(SmemLayoutC{}) <= cosize(SmemLayoutA{}) + cosize(SmemLayoutB{}));

  // Register->shared copy of the MMA accumulators, 32 bits per access.
  using R2STiledCopyC = decltype(make_tiled_copy_C(
    Copy_Atom<UniversalCopy<uint32_t>, T>{},
    TiledMMA{}
  ));
  // Shared->global copy of C: 8x16 threads, 8 elements (one uint128_t) per thread.
  using S2GTiledCopy = decltype(make_tiled_copy(
    Copy_Atom<UniversalCopy<uint128_t>, T>{},
    Layout<Shape<_8, _16>, Stride<_16, _1>>{},
    Layout<Shape<_1, _8>>{}
  ));
  CUTE_STATIC_ASSERT_V(thr_size(TiledMMA{}) == size(S2GTiledCopy{}));
};

// Pipelined tensor-core GEMM kernel computing C = A * B^T, where A (M x K), B (N x K)
// and C (M x N) are row-major in global memory (same convention as gemm_cublas).
//
// Launch contract (as set up by gemm3):
//   - 1D grid with one CTA per TILE_M x TILE_N output tile. blockIdx.x is remapped to
//     a tile through groups of CTA_TILE_M x CTA_TILE_N tiles.
//   - thr_size(TiledMMA) threads per CTA (matches __launch_bounds__).
//   - Dynamic shared memory of
//     (cosize(SmemLayoutA) + cosize(SmemLayoutB)) * sizeof(T) + SMEM_OFFSET_BYTES bytes.
//     A and B are staged starting SMEM_OFFSET_BYTES into that buffer. The epilogue's
//     C staging tile starts at the unshifted base and reuses the same memory.
//
// Preconditions (not checked): the global copies are not predicated, so M, N and K
// are assumed to be multiples of TILE_M, TILE_N and TILE_K, and A/B/C must be
// suitably aligned for the 16-byte vector copies.
//
// Pipeline:
//   - Up to NUM_ON_THE_FLY_COPIES k-tiles are in flight via cp.async across NUM_LDST
//     shared-memory stages.
//   - Within a k-tile, the ldmatrix load of the next MMA fragment is issued before
//     the MMA on the current one.
template<class Config>
__launch_bounds__(decltype(thr_size((typename Config::TiledMMA){}))::value)
__global__ void gemm3_fwd(int M, int N, int K, data_t *A, data_t *B, data_t *C) {
  using namespace cute;

  using T = typename Config::T;
  using TILE_M = typename Config::TILE_M;
  using TILE_N = typename Config::TILE_N;
  using TILE_K = typename Config::TILE_K;
  using SmemLayoutA = typename Config::SmemLayoutA;
  using SmemLayoutB = typename Config::SmemLayoutB;

  // A mapping from CTA id to the tile id
  static constexpr int CTA_TILE_M = Config::CTA_TILE_M::value;
  static constexpr int CTA_TILE_N = Config::CTA_TILE_N::value;
  int num_m_tiles = CDIV(M, TILE_M{});
  int num_n_tiles = CDIV(N, TILE_N{});
  int num_tiles_in_one_cta_row = CTA_TILE_M * num_n_tiles;
  int cta_row_idx = blockIdx.x / num_tiles_in_one_cta_row;
  int cta_tile_size_in_cur_row = min(CTA_TILE_M, num_m_tiles - cta_row_idx*CTA_TILE_M) * CTA_TILE_N;
  int cta_col_idx = blockIdx.x % num_tiles_in_one_cta_row / cta_tile_size_in_cur_row;
  int idx_in_cta_tile = blockIdx.x % num_tiles_in_one_cta_row % cta_tile_size_in_cur_row;

  int cur_cta_tile_n = min(CTA_TILE_N, num_n_tiles - cta_col_idx*CTA_TILE_N);
  int tile_m_idx = cta_row_idx * CTA_TILE_M + idx_in_cta_tile / cur_cta_tile_n;
  int tile_n_idx = cta_col_idx * CTA_TILE_N + idx_in_cta_tile % cur_cta_tile_n;

  Tensor gA = make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
  Tensor gB = make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
  Tensor gC = make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));

  Tensor tile_gA = local_tile(gA, Tile<TILE_M, TILE_K>{}, make_coord(tile_m_idx, _)); // (TILE_M, TILE_K, NUM_K_TILES)
  Tensor tile_gB = local_tile(gB, Tile<TILE_N, TILE_K>{}, make_coord(tile_n_idx, _)); // (TILE_N, TILE_K, NUM_K_TILES)
  Tensor tile_gC = local_tile(gC, Tile<TILE_M, TILE_N>{}, make_coord(tile_m_idx, tile_n_idx));  // (TILE_M, TILE_N)

  // Size: (cosize(SmemLayoutA{}) + cosize(SmemLayoutB{})) elements plus SMEM_OFFSET_BYTES bytes.
  extern __shared__ T smem_buf[];
  // A and B start SMEM_OFFSET_BYTES past the shared-memory base. Both the cp.async
  // (G2S) writes and the ldmatrix (S2R) reads below address these buffers, so the
  // offset changes the alignment seen by both instructions.
  Tensor tile_sA = make_tensor(make_smem_ptr((T*)((char*)smem_buf+SMEM_OFFSET_BYTES)), SmemLayoutA{}); // (TILE_M, TILE_K, NUM_LDST)
  Tensor tile_sB = make_tensor(make_smem_ptr((T*)((char*)smem_buf+SMEM_OFFSET_BYTES) + cosize(SmemLayoutA{})), SmemLayoutB{}); // (TILE_N, TILE_K, NUM_LDST)

  typename Config::TiledMMA tiled_mma;
  ThrMMA thr_mma = tiled_mma.get_slice(threadIdx.x);
  Tensor mma_frag_rA = thr_mma.partition_fragment_A(tile_gA(_, _, 0));  // (MMA, MMA_M, MMA_K)
  Tensor mma_frag_rB = thr_mma.partition_fragment_B(tile_gB(_, _, 0));  // (MMA, MMA_N, MMA_K)
  Tensor mma_frag_rC = thr_mma.partition_fragment_C(tile_gC(_, _));  // (MMA, MMA_M, MMA_N)
  clear(mma_frag_rC);

  typename Config::G2STiledCopy g2s_tiled_copy_a;
  ThrCopy g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(threadIdx.x);
  Tensor g2s_copy_frag_gA = g2s_thr_copy_a.partition_S(tile_gA);  // (G2S, G2S_M, G2S_K, NUM_K_TILES)
  Tensor g2s_copy_frag_sA = g2s_thr_copy_a.partition_D(tile_sA);  // (G2S, G2S_M, G2S_K, NUM_LDST)

  typename Config::G2STiledCopy g2s_tiled_copy_b;
  ThrCopy g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(threadIdx.x);
  Tensor g2s_copy_frag_gB = g2s_thr_copy_b.partition_S(tile_gB);  // (G2S, G2S_N, G2S_K, NUM_K_TILES)
  Tensor g2s_copy_frag_sB = g2s_thr_copy_b.partition_D(tile_sB);  // (G2S, G2S_N, G2S_K, NUM_LDST)

  typename Config::S2RTiledCopyA s2r_tiled_copy_a;
  ThrCopy thr_s2r_copy_a = s2r_tiled_copy_a.get_slice(threadIdx.x);
  Tensor s2r_copy_frag_sA = thr_s2r_copy_a.partition_S(tile_sA);  // (S2R, S2R_M, S2R_K, NUM_LDST)
  Tensor s2r_copy_frag_rA = thr_s2r_copy_a.retile_D(mma_frag_rA);  // (S2R, S2R_M, S2R_K)

  typename Config::S2RTiledCopyB s2r_tiled_copy_b;
  ThrCopy thr_s2r_copy_b = s2r_tiled_copy_b.get_slice(threadIdx.x);
  Tensor s2r_copy_frag_sB = thr_s2r_copy_b.partition_S(tile_sB);  // (S2R, S2R_N, S2R_K, NUM_LDST)
  Tensor s2r_copy_frag_rB = thr_s2r_copy_b.retile_D(mma_frag_rB);  // (S2R, S2R_N, S2R_K)

  CUTE_STATIC_ASSERT_V(size<2>(s2r_copy_frag_sA) == size<2>(mma_frag_rA));

  int num_k_tiles = get<2>(shape(tile_gA));
  static constexpr int NUM_LDST = Config::NUM_LDST::value;
  static constexpr int NUM_MMA_STAGES = size<2>(s2r_copy_frag_rA);
  static constexpr int NUM_ON_THE_FLY_COPIES = NUM_LDST;  // NUM_LDST or NUM_LDST-1. If is NUM_LDST, then can issue one more copy instructions but needs one more sync for each tile
  static_assert(NUM_ON_THE_FLY_COPIES <= NUM_LDST);
  static_assert(NUM_MMA_STAGES > 1);

  // Prologue: start loading the first k-tiles.
  #pragma unroll
  for (int i = 0; i < NUM_ON_THE_FLY_COPIES; ++i) {
    // Skip tiles that do not exist (K shorter than NUM_ON_THE_FLY_COPIES tiles);
    // indexing past the last k-tile would read global memory out of bounds. The
    // fence is still committed unconditionally, so the cp.async group count stays in
    // step with the cp_async_wait<> calls. The main loop relies on the same
    // empty-group behaviour.
    if (i < num_k_tiles) {
      copy(g2s_tiled_copy_a, g2s_copy_frag_gA(_, _, _, i), g2s_copy_frag_sA(_, _, _, i));
      copy(g2s_tiled_copy_b, g2s_copy_frag_gB(_, _, _, i), g2s_copy_frag_sB(_, _, _, i));
    }
    cp_async_fence();
  }

  cp_async_wait<NUM_ON_THE_FLY_COPIES-1>();
  __syncthreads();

  // Load the first fragment of tile 0 (nothing to load when K has no tiles).
  if (num_k_tiles > 0) {
    copy(s2r_tiled_copy_a, s2r_copy_frag_sA(_, _, 0, 0), s2r_copy_frag_rA(_, _, 0));
    copy(s2r_tiled_copy_b, s2r_copy_frag_sB(_, _, 0, 0), s2r_copy_frag_rB(_, _, 0));
  }

  for (int i_tile = 0; i_tile < num_k_tiles; ++i_tile) {
    // At the beginning of the loop, we should have already waited for the earliest copy,
    // performed __syncthreads(), and copy the first fragment to the register
    #pragma unroll
    for (int i_frag = 0; i_frag < NUM_MMA_STAGES; ++i_frag) {
      if (i_frag == NUM_MMA_STAGES - 1) {
        if constexpr(NUM_ON_THE_FLY_COPIES == NUM_LDST) {
          // Need to synchronize all threads, to prevent that:
          // - The current thread has finished frag #i_frag's S->R copying and is about to launch the next G->S copy
          // - Another thread is still performing the S->R copying for frag #i_frag-1
          __syncthreads();
        }
        // We are about to deal with the last one fragment in the current tile
        // Launch a new G->S copy
        int target_i_tile = i_tile + NUM_ON_THE_FLY_COPIES;
        if (target_i_tile < num_k_tiles) {
          int target_i_smem_slot = target_i_tile % NUM_LDST;
          copy(g2s_tiled_copy_a, g2s_copy_frag_gA(_, _, _, target_i_tile), g2s_copy_frag_sA(_, _, _, target_i_smem_slot));
          copy(g2s_tiled_copy_b, g2s_copy_frag_gB(_, _, _, target_i_tile), g2s_copy_frag_sB(_, _, _, target_i_smem_slot));
        }

        // Wait for the next G->S (for tile i_tile+1) to finish
        cp_async_fence();
        cp_async_wait<NUM_ON_THE_FLY_COPIES-1>();
        __syncthreads();

        // Copy the first fragment of the next tile to the register, if there is a
        // next tile. On the last tile the target slot holds no valid data. Skipping
        // it also guarantees that no thread reads sA/sB after this barrier, while
        // faster threads may already be overwriting that region with sC in the
        // epilogue.
        if (i_tile + 1 < num_k_tiles) {
          copy(s2r_tiled_copy_a, s2r_copy_frag_sA(_, _, 0, (i_tile+1)%NUM_LDST), s2r_copy_frag_rA(_, _, 0));
          copy(s2r_tiled_copy_b, s2r_copy_frag_sB(_, _, 0, (i_tile+1)%NUM_LDST), s2r_copy_frag_rB(_, _, 0));
        }
      } else {
        copy(s2r_tiled_copy_a, s2r_copy_frag_sA(_, _, i_frag+1, i_tile%NUM_LDST), s2r_copy_frag_rA(_, _, i_frag+1));
        copy(s2r_tiled_copy_b, s2r_copy_frag_sB(_, _, i_frag+1, i_tile%NUM_LDST), s2r_copy_frag_rB(_, _, i_frag+1));
      }

      gemm(tiled_mma, mma_frag_rC, mma_frag_rA(_, _, i_frag), mma_frag_rB(_, _, i_frag), mma_frag_rC);
    }
  }

  // Epilogue staging buffer for C. It starts at the unshifted shared-memory base (no
  // SMEM_OFFSET_BYTES) and overlaps the A/B buffers. This is safe because no thread
  // reads sA/sB after the main loop's final __syncthreads().
  Tensor tile_sC = make_tensor(make_smem_ptr(smem_buf), typename Config::SmemLayoutC{});  // (MMA_TILE_SIZE_M, TILE_N, 2)

  typename Config::R2STiledCopyC r2s_tiled_copy_c;
  ThrCopy r2s_thr_copy_c = r2s_tiled_copy_c.get_slice(threadIdx.x);
  Tensor r2s_copy_frag_rC = r2s_thr_copy_c.retile_S(mma_frag_rC);  // (R2S, MMA_M, MMA_N)
  Tensor r2s_copy_frag_sC = r2s_thr_copy_c.partition_D(tile_sC);  // (R2S, MMA_M_PER_TILE, MMA_N, 2)

  typename Config::S2GTiledCopy s2g_tiled_copy;
  ThrCopy s2g_thr_copy = s2g_tiled_copy.get_slice(threadIdx.x);
  Tensor s2g_copy_frag_sC = s2g_thr_copy.partition_S(tile_sC);  // (S2G, S2G_M_PER_TILE, S2G_N, 2)
  Tensor s2g_copy_frag_gC = s2g_thr_copy.partition_D(tile_gC);  // (S2G, S2G_M, S2G_N)

  auto mma_m_per_tile = size<1>(r2s_copy_frag_sC);
  auto mma_m = size<1>(mma_frag_rC);
  auto s2g_m_per_tile = size<1>(s2g_copy_frag_sC);
  auto s2g_m = size<1>(s2g_copy_frag_gC);
  CUTE_STATIC_ASSERT_V(mma_m / mma_m_per_tile == s2g_m / s2g_m_per_tile);

  // Write C in row slabs through the double-buffered staging tile. The barrier in
  // each iteration separates the register->shared writes from the shared->global
  // reads. Alternating cur_buf means a buffer is only rewritten after the next
  // iteration's barrier, i.e. after every thread has finished reading it.
  #pragma unroll
  for (int r2s_start_m = 0, s2g_start_m = 0, cur_buf = 0; r2s_start_m < mma_m; r2s_start_m += mma_m_per_tile, s2g_start_m += s2g_m_per_tile, cur_buf ^= 1) {
    #pragma unroll
    for (int j = 0; j < mma_m_per_tile; ++j) {
      copy(r2s_tiled_copy_c, r2s_copy_frag_rC(_, r2s_start_m+j, _), r2s_copy_frag_sC(_, j, _, cur_buf));
    }
    __syncthreads();
    #pragma unroll
    for (int j = 0; j < s2g_m_per_tile; ++j) {
      copy(s2g_tiled_copy, s2g_copy_frag_sC(_, j, _, cur_buf), s2g_copy_frag_gC(_, s2g_start_m+j, _));
    }
  }
}

void gemm3(int M, int N, int K, data_t *A, data_t *B, data_t *C) {
  using namespace cute;
  using config = Gemm3Config<half_t, 128, 128, 32, 4, 1, 1>;
  dim3 grid_shape = CDIV(M, config::TILE_M{}) * CDIV(N, config::TILE_N{});
  dim3 block_shape = dim3(size(config::TiledMMA{}));
  auto smem_size_bytes = (cosize(config::SmemLayoutA{}) + cosize(config::SmemLayoutB{})) * sizeof(data_t);
  static bool smem_attribute_set = false;
  if (!smem_attribute_set) {
    cudaFuncSetAttribute(gemm3_fwd<config>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_bytes+SMEM_OFFSET_BYTES);
    smem_attribute_set = true;
  }
    CUDA_CHECK(cudaDeviceSynchronize());
  gemm3_fwd<config><<<grid_shape, block_shape, smem_size_bytes+SMEM_OFFSET_BYTES>>>(M, N, K, A, B, C);
    CUDA_CHECK(cudaDeviceSynchronize());
}

int main() {
  srand(0);
  int M = 4096, N = 4096, K = 1024;

  // Allocate host tensors and initialize
  data_t* A_h, *B_h;
  A_h = (data_t*)malloc(M * K * sizeof(data_t));
  B_h = (data_t*)malloc(N * K * sizeof(data_t));
  init_host_tensor(A_h, M * K);
  init_host_tensor(B_h, N * K);

  // Allocate device tensors and copy from host
  data_t* A, *B, *C, *C_std;
  CUDA_CHECK(cudaMalloc(&A, M * K * sizeof(data_t)));
  CUDA_CHECK(cudaMalloc(&B, N * K * sizeof(data_t)));
  CUDA_CHECK(cudaMalloc(&C, M * N * sizeof(data_t)));
  CUDA_CHECK(cudaMalloc(&C_std, M * N * sizeof(data_t)));
  CUDA_CHECK(cudaMemcpy(A, A_h, M * K * sizeof(data_t), cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(B, B_h, N * K * sizeof(data_t), cudaMemcpyHostToDevice));

  // Calculate standard answer by cublas
  auto [cublas_ms, cublas_gflops] = gemm_cublas(M, N, K, A, B, C_std);
  printf("CUBLAS: %.2f ms, %f gflops\n", cublas_ms, cublas_gflops);
  
  auto test_gemm = [&](auto gemm_func) {
    // Test correctness and warmup
    gemm_func(M, N, K, A, B, C);
    CUDA_CHECK(cudaDeviceSynchronize());
    if (!compare_device_device_tensor(C_std, C, M * N)) {
      printf("WARNING: Incorrect result\n");
    }

    // Warmup
    gemm_func(M, N, K, A, B, C);
    gemm_func(M, N, K, A, B, C);
    gemm_func(M, N, K, A, B, C);

    // Benchmark
    const int num_iters = 8;
    CUDA_CHECK(cudaDeviceSynchronize());
    CUDATimer timer;
    timer.start();
    for (int i = 0; i < num_iters; ++i) {
      gemm_func(M, N, K, A, B, C);
    }
    double elapsed_time_ms = timer.stop();
    elapsed_time_ms /= num_iters;
    CUDA_CHECK(cudaDeviceSynchronize());

    double flops = 2.0 * M * N * K;
    double gflops = flops / (elapsed_time_ms/1000) / 1e9;
    printf("%.2f ms, %f gflops\n", elapsed_time_ms, gflops);
    double mem_bw_gBps = (M*K*(N/128) + K*N*(M/128) + M*N) * sizeof(data_t) / (elapsed_time_ms/1000) / 1e9;
    printf("Memory bandwidth: %f GB/s\n", mem_bw_gBps);
  };

  // test_gemm(gemm1);
  // test_gemm(gemm2);
  test_gemm(gemm3);

  return 0;
}

This topic was automatically closed 14 days after the last reply. New replies are no longer allowed.