Understanding the behavior of ldmatrix in terms of shared memory access

As indicated by the PTX documentation, when using ldmatrix.x4, every 8 threads form a group (i.e. T0-T7, T8-T15, T16-T23, T24-T31), each executing one memory transaction. Since a memory transaction can move at most 128 bytes, loading four 8x8 uint16_t matrices requires 4x8x8x2 / 128 = 4 memory transactions.
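
A quick check of that arithmetic (plain host-side C++, merely restating the sizes above):

#include <cstdint>

// 4 matrices x 8 rows x 8 cols x 2 bytes (b16) = 512 bytes in total, moved in
// 128-byte transactions: one transaction per 8x8 matrix.
constexpr uint32_t kTotalBytes   = 4 * 8 * 8 * sizeof(uint16_t);  // 512
constexpr uint32_t kTransactions = kTotalBytes / 128;             // 4
static_assert(kTransactions == 4, "one 128-byte transaction per matrix");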

A shared memory bank conflict should occur when threads from the same group read the same bank (but different words). However, for the following code Nsight Compute reports 1 instruction, 1 request, 0 wavefronts, and 0 bank conflicts (in the “Shared Load Matrix” tab), which seems contradictory. How should I understand the behavior of the ldmatrix instruction, and how can I calculate the bank conflicts it causes?

For reference, the code was compiled and run on an RTX 2080 Ti with CUDA 12.3.

#include <cstdint>
#include <iostream>

__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
  uint32_t address;
  asm("{\n\t"
      "  .reg .u64 u64addr;\n\t"
      "  cvta.to.shared.u64 u64addr, %1;\n\t"
      "  cvt.u32.u64 %0, u64addr;\n\t"
      "}"
      : "=r"(address)
      : "l"(pointer));
  return address;
}
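
// (Aside, my addition: on CUDA 11+ the same conversion can be done with the
// __cvta_generic_to_shared intrinsic instead of inline PTX; a minimal sketch,
// assuming the pointer really points into shared memory.)
__device__ __forceinline__ uint32_t cvta_to_shared_u32_intrinsic(void *pointer) {
  return static_cast<uint32_t>(__cvta_generic_to_shared(pointer));
}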

__device__ __forceinline__ void ldmatrix_sync_aligned_m8n8_x4_b16(
    uint32_t &d0, uint32_t &d1, uint32_t &d2, uint32_t &d3,
    const uint32_t &address) {
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
      : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
      : "r"(address));
}

__global__ void ldmatrix_bank_conflict(uint16_t *value) {
  __shared__ uint16_t smem[4096];
  auto tid = threadIdx.x;
  // bank: tid * 64 * 2 / 4 % 32 = 0
  // T0 read: smem[0],   ..., smem[7]   -> bank 0 to bank 3
  // T1 read: smem[64],  ..., smem[71]  -> bank 0 to bank 3
  // T2 read: smem[128], ..., smem[135] -> bank 0 to bank 3
  // T3 read: smem[192], ..., smem[199] -> bank 0 to bank 3
  // T4 read: smem[256], ..., smem[263] -> bank 0 to bank 3
  // T5 read: smem[320], ..., smem[327] -> bank 0 to bank 3
  // T6 read: smem[384], ..., smem[391] -> bank 0 to bank 3
  // T7 read: smem[448], ..., smem[455] -> bank 0 to bank 3
  const uint32_t address =
      cvta_to_shared_u32(smem) + sizeof(uint16_t) * (64 * tid);
  for (uint32_t i = tid; i < 4096; i += blockDim.x) {
    smem[i] = i;
  }
  __syncthreads();
  ldmatrix_sync_aligned_m8n8_x4_b16(
      *reinterpret_cast<uint32_t *>(value + threadIdx.x * 2 + 0 * 2 * 32),
      *reinterpret_cast<uint32_t *>(value + threadIdx.x * 2 + 1 * 2 * 32),
      *reinterpret_cast<uint32_t *>(value + threadIdx.x * 2 + 2 * 2 * 32),
      *reinterpret_cast<uint32_t *>(value + threadIdx.x * 2 + 3 * 2 * 32),
      address);
}

int main() {
  uint16_t *d_value;
  // 32 threads x 4 fragments x 2 uint16_t each = 256 elements (512 bytes)
  cudaMalloc(&d_value, 256 * sizeof(uint16_t));
  ldmatrix_bank_conflict<<<1, 32>>>(d_value);
  cudaDeviceSynchronize();
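  // (Optional, my addition, not part of the original question: copy the 32
  // threads x 4 fragments back and print them. Fragment k of thread t was
  // stored as two uint16_t values at d_value[t * 2 + k * 64].)
  uint16_t h_value[256];
  cudaMemcpy(h_value, d_value, sizeof(h_value), cudaMemcpyDeviceToHost);
  for (int t = 0; t < 32; ++t) {
    std::cout << "T" << t << ":";
    for (int k = 0; k < 4; ++k)
      std::cout << " {" << h_value[t * 2 + k * 64] << ","
                << h_value[t * 2 + k * 64 + 1] << "}";
    std::cout << "\n";
  }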
  cudaFree(d_value);
  return 0;
}

If I recall correctly, when the LDSM instruction was added, the shared memory PM (performance monitor) signals were not updated to correctly count wavefronts and conflicts for this instruction type. This was fixed in a later chip.

                 sm7.5   sm8.6
instructions       1       1
requests           1       1
wavefronts         0      32
bank conflicts     0      28

There are Source Page counters for the instruction. In the Nsight Compute UI:

  • Go to the Source Page
  • Set View = SASS
  • Scroll down to the LDSM.16.M88.4 instruction
  • Right-click on the column header and select Column Chooser…
  • Enable L1 Wavefronts Shared Excessive
  • Enable L1 Wavefronts Shared
  • Enable L1 Wavefronts Ideal

For both chips this should show:

L1 Wavefronts Shared Excessive  28
L1 Wavefronts Shared            32
L1 Wavefronts Ideal             4
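
Those numbers are consistent with a simple model: each 8x8 matrix is one 128-byte request, so the ideal is 1 wavefront per matrix (4 total). But every row address in this code is 128-byte aligned, so all 8 rows of a group start at bank 0 and overlap in banks 0-3, and each row then needs its own wavefront: 8 per matrix, 32 total, 28 excessive. Below is a host-side sketch of that model; it is my own reconstruction for this specific layout (it does not handle bank wrap-around), not a documented hardware algorithm.

#include <cstdint>
#include <cstdio>

// Model: each of the 8 row reads of one 8x8 b16 matrix covers 16 bytes =
// 4 consecutive banks; rows whose bank ranges overlap are serviced in
// separate wavefronts.
static int wavefronts_for_one_matrix(const uint32_t row_byte_addr[8]) {
  bool done[8] = {};
  int wavefronts = 0, served = 0;
  while (served < 8) {
    uint32_t banks_used = 0;  // bitmask over the 32 banks
    for (int r = 0; r < 8; ++r) {
      if (done[r]) continue;
      uint32_t mask = 0xFu << ((row_byte_addr[r] / 4) % 32);  // banks b..b+3
      if ((banks_used & mask) == 0) {
        banks_used |= mask;
        done[r] = true;
        ++served;
      }
    }
    ++wavefronts;  // one pass over the banks = one wavefront
  }
  return wavefronts;
}

int main() {
  // Row addresses supplied by T0-T7 in the question: smem[64 * t], i.e.
  // byte offset 128 * t, so every row maps to banks 0-3.
  uint32_t rows[8];
  for (int t = 0; t < 8; ++t) rows[t] = 128u * t;
  int per_matrix = wavefronts_for_one_matrix(rows);  // -> 8
  printf("per matrix: %d, total: %d, ideal: 4, excessive: %d\n",
         per_matrix, 4 * per_matrix, 4 * per_matrix - 4);
  return 0;
}

This prints 8 wavefronts per matrix, 32 total, and 28 excessive, matching the sm8.6 column and the Source Page counters above.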

The Source Page counters are implemented via SASS patching, whereas the Shared Memory table for TU10x appears to use only HW counters.

If this is a problem for you, you can file a bug against Nsight Compute requesting that, for TU10x, the memory table cells add the sum of the Source Page counters into the current merged Load cell.

