PTXAS: mysterious warning for wgmma.mma_async instruction serialization

I encountered a strange warning when compiling a GEMM kernel for Hopper cards.

ptxas info    : (C7514) Potential Performance Loss: wgmma.mma_async instructions are serialized due to non wgmma instructions reading accumulator registers of  a wgmma between start and end of the pipeline stage in the function

I then disassembled the executable and saw that PTXAS had inserted “WARPGROUP.DEPBAR.LE+WARPGROUP.ARRIVE” between some QGMMA instructions, even though NO OTHER INSTRUCTIONS ACCESS THE ACCUMULATOR REGISTERS there.



So the question is: why does PTXAS decide to add synchronization instructions between them (and emit this warning)? The CUDA toolchain version is 12.8.
Here are the compilation flags:

-arch=sm_90a -Xcompiler -march=native -Xptxas -v -keep -DNDEBUG

See the filed bug report (NVIDIA Developer login required) for details.

I now believe this is a bug in NVCC. Here is a simple program that reproduces the issue:

#include <cstdint>

__launch_bounds__(128, 1)
__global__ void test(unsigned *result, unsigned iters) {
  __shared__ unsigned buffer[8192];
  unsigned *buf1 = &buffer[0], *buf2 = &buffer[4096];
  unsigned *bufs1, *bufs2;
  asm ("cvta.to.shared.u64 %0,%1;\n":"=l"(bufs1):"l"(buf1));
  asm ("cvta.to.shared.u64 %0,%1;\n":"=l"(bufs2):"l"(buf2));
  unsigned startaddr = threadIdx.x;
  for (; startaddr < 8192; startaddr += blockDim.x) {
    buffer[startaddr] = result[startaddr];
  }
  __syncthreads();
  // const unsigned input_swizzle_bytes = 128;
  const unsigned input_stride_16B = 2;
  const unsigned input_lbo = 16;
  const unsigned input_sbo = 1024u;
  const unsigned buf1addr = (unsigned)reinterpret_cast<std::uintptr_t>(bufs1);
  const unsigned buf2addr = (unsigned)reinterpret_cast<std::uintptr_t>(bufs2);
  auto descr_encode = [](unsigned inp)->std::size_t { return (inp & 0x3FFFF) >> 4; };
  const std::size_t basic_descr = (descr_encode(input_lbo) << 16) |
    (descr_encode(input_sbo) << 32) | (std::size_t(1) << 62 /* swizzle 128B */);
  const std::size_t descr1_a = basic_descr | descr_encode(buf1addr);
  const std::size_t descr1_b = basic_descr | descr_encode(buf1addr + 8192);
  const std::size_t descr2_a = basic_descr | descr_encode(buf2addr);
  const std::size_t descr2_b = basic_descr | descr_encode(buf2addr + 8192);
  auto commit_m64n64k128 = [&](unsigned (&result)[32],
    std::size_t descr_a, std::size_t descr_b)->void {

    asm volatile ("fence.proxy.async.shared::cta;\n":::"memory");
    asm volatile ("wgmma.fence.sync.aligned;\n":::"memory");
    asm volatile ("wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 "
      "{%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,"
      "%20,%21,%22,%23,%24,%25,%26,%27,%28,%29,%30,%31},%32,%33,0;\n"
     :"=r"(result[0]),"=r"(result[1]),"=r"(result[2]),"=r"(result[3]),
      "=r"(result[4]),"=r"(result[5]),"=r"(result[6]),"=r"(result[7]),
      "=r"(result[8]),"=r"(result[9]),"=r"(result[10]),"=r"(result[11]),
      "=r"(result[12]),"=r"(result[13]),"=r"(result[14]),"=r"(result[15]),
      "=r"(result[16]),"=r"(result[17]),"=r"(result[18]),"=r"(result[19]),
      "=r"(result[20]),"=r"(result[21]),"=r"(result[22]),"=r"(result[23]),
      "=r"(result[24]),"=r"(result[25]),"=r"(result[26]),"=r"(result[27]),
      "=r"(result[28]),"=r"(result[29]),"=r"(result[30]),"=r"(result[31])
     :"l"(descr_a),"l"(descr_b):"memory");
    /* descr_a and descr_b are matrix descriptors, not accumulators */
    descr_a += input_stride_16B;
    descr_b += input_stride_16B;
    asm volatile ("wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 "
      "{%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,"
      "%20,%21,%22,%23,%24,%25,%26,%27,%28,%29,%30,%31},%32,%33,1;\n"
     :"+r"(result[0]),"+r"(result[1]),"+r"(result[2]),"+r"(result[3]),
      "+r"(result[4]),"+r"(result[5]),"+r"(result[6]),"+r"(result[7]),
      "+r"(result[8]),"+r"(result[9]),"+r"(result[10]),"+r"(result[11]),
      "+r"(result[12]),"+r"(result[13]),"+r"(result[14]),"+r"(result[15]),
      "+r"(result[16]),"+r"(result[17]),"+r"(result[18]),"+r"(result[19]),
      "+r"(result[20]),"+r"(result[21]),"+r"(result[22]),"+r"(result[23]),
      "+r"(result[24]),"+r"(result[25]),"+r"(result[26]),"+r"(result[27]),
      "+r"(result[28]),"+r"(result[29]),"+r"(result[30]),"+r"(result[31])
     :"l"(descr_a),"l"(descr_b):"memory");
    descr_a += input_stride_16B;
    descr_b += input_stride_16B;
    asm volatile ("wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 "
      "{%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,"
      "%20,%21,%22,%23,%24,%25,%26,%27,%28,%29,%30,%31},%32,%33,1;\n"
     :"+r"(result[0]),"+r"(result[1]),"+r"(result[2]),"+r"(result[3]),
      "+r"(result[4]),"+r"(result[5]),"+r"(result[6]),"+r"(result[7]),
      "+r"(result[8]),"+r"(result[9]),"+r"(result[10]),"+r"(result[11]),
      "+r"(result[12]),"+r"(result[13]),"+r"(result[14]),"+r"(result[15]),
      "+r"(result[16]),"+r"(result[17]),"+r"(result[18]),"+r"(result[19]),
      "+r"(result[20]),"+r"(result[21]),"+r"(result[22]),"+r"(result[23]),
      "+r"(result[24]),"+r"(result[25]),"+r"(result[26]),"+r"(result[27]),
      "+r"(result[28]),"+r"(result[29]),"+r"(result[30]),"+r"(result[31])
     :"l"(descr_a),"l"(descr_b):"memory");
    descr_a += input_stride_16B;
    descr_b += input_stride_16B;
    asm volatile ("wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 "
      "{%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,"
      "%20,%21,%22,%23,%24,%25,%26,%27,%28,%29,%30,%31},%32,%33,1;\n"
     :"+r"(result[0]),"+r"(result[1]),"+r"(result[2]),"+r"(result[3]),
      "+r"(result[4]),"+r"(result[5]),"+r"(result[6]),"+r"(result[7]),
      "+r"(result[8]),"+r"(result[9]),"+r"(result[10]),"+r"(result[11]),
      "+r"(result[12]),"+r"(result[13]),"+r"(result[14]),"+r"(result[15]),
      "+r"(result[16]),"+r"(result[17]),"+r"(result[18]),"+r"(result[19]),
      "+r"(result[20]),"+r"(result[21]),"+r"(result[22]),"+r"(result[23]),
      "+r"(result[24]),"+r"(result[25]),"+r"(result[26]),"+r"(result[27]),
      "+r"(result[28]),"+r"(result[29]),"+r"(result[30]),"+r"(result[31])
     :"l"(descr_a),"l"(descr_b):"memory");
    asm volatile ("wgmma.commit_group.sync.aligned;\n":::"memory");
  };
  auto save = [](const unsigned (&result)[32], unsigned *start)->void {
    //blockDim.x == 128
    #pragma unroll
    for (int i = 0; i < 32; ++i) start[i * blockDim.x + threadIdx.x] = result[i];
    __syncthreads();
  };

  unsigned result1[32], result2[32];
  commit_m64n64k128(result1, descr1_a, descr1_b);
  for (unsigned i = 1; i < iters; ++i) {
    commit_m64n64k128(result2, descr2_a, descr2_b);
    asm volatile ("wgmma.wait_group.sync.aligned 1;\n":::"memory");
    save(result1, buffer);
    commit_m64n64k128(result1, descr1_a, descr1_b);
    asm volatile ("wgmma.wait_group.sync.aligned 1;\n":::"memory");
    save(result2, buffer + 4096);
  }
  commit_m64n64k128(result2, descr2_a, descr2_b);
  asm volatile ("wgmma.wait_group.sync.aligned 1;\n":::"memory");
  save(result1, buffer);
  asm volatile ("wgmma.wait_group.sync.aligned 0;\n":::"memory");
  save(result2, buffer + 4096);

  for (startaddr = threadIdx.x; startaddr < 8192; startaddr += blockDim.x) {
    result[startaddr] = buffer[startaddr];
  }
}

When the trip count of the main “for” loop over iters (around line 98 of the code above) is changed to an integer constant and “#pragma unroll” is written before that loop, the compiler stops emitting the warning and no longer inserts extra synchronization instructions between the wgmma instructions.
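The loop-bound change described above can be sketched as follows. This is only an illustration: kIters, steady_state_unrolled and the four callables are hypothetical names that stand in for the steady-state loop and the commit_m64n64k128 / save lambdas of the reproducer, and I have not verified that ptxas treats this factored form exactly like the inline loop.

```cuda
// Sketch, compile with -arch=sm_90a. All names here are hypothetical.
// Issue1/Issue2 stand in for commit_m64n64k128(result1, ...) and
// commit_m64n64k128(result2, ...); Save1/Save2 stand in for
// save(result1, buffer) and save(result2, buffer + 4096).
template <unsigned kIters, typename Issue1, typename Issue2,
          typename Save1, typename Save2>
__device__ __forceinline__ void steady_state_unrolled(
    Issue1 issue1, Issue2 issue2, Save1 save1, Save2 save2) {
  // The trip count is a compile-time constant, so the loop can be fully
  // unrolled and no loop back-edge is left around the wgmma groups.
  #pragma unroll
  for (unsigned i = 1; i < kIters; ++i) {
    issue2();  // start the group that writes result2
    asm volatile("wgmma.wait_group.sync.aligned 1;\n" ::: "memory");
    save1();   // result1's group has completed
    issue1();  // start the next group that writes result1
    asm volatile("wgmma.wait_group.sync.aligned 1;\n" ::: "memory");
    save2();   // result2's group has completed
  }
}

// Possible use inside the kernel, replacing the runtime-bounded loop:
//   steady_state_unrolled<8>(
//       [&] { commit_m64n64k128(result1, descr1_a, descr1_b); },
//       [&] { commit_m64n64k128(result2, descr2_a, descr2_b); },
//       [&] { save(result1, buffer); },
//       [&] { save(result2, buffer + 4096); });
```

The obvious cost is that the iteration count must be known at compile time, which is often not the case in a real GEMM main loop.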

Generally speaking, the NVCC compiler fails to analyze the dependencies between wgmma instructions correctly when it encounters even a simple loop with a runtime bound, which is annoying. This issue does not exist for other asynchronous instructions such as “cp.async” (with “cp.async.wait_group N”).
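For comparison, here is a minimal sketch of the kind of cp.async pipeline referred to above: a loop with a runtime bound that keeps one copy group in flight using cp.async.wait_group 1. The kernel name, the 128-thread tile size and the in/out layout are assumptions made for illustration and are not code from this thread; I have not inspected the SASS it produces.

```cuda
#include <cstddef>
#include <cuda_runtime.h>

// Hypothetical comparison kernel (sm_80 or newer). Launch with 128 threads
// per block; `in` and `out` each hold tiles * 128 elements.
__global__ void cp_async_double_buffer(const unsigned* in, unsigned* out,
                                       unsigned tiles) {
  __shared__ unsigned stage[2][128];
  const unsigned tid = threadIdx.x;

  // Start the asynchronous copy of one element per thread and commit it
  // as its own group.
  auto issue = [&](unsigned tile) {
    const unsigned dst =
        (unsigned)__cvta_generic_to_shared(&stage[tile & 1][tid]);
    const std::size_t src =
        __cvta_generic_to_global(in + (std::size_t)tile * 128 + tid);
    asm volatile("cp.async.ca.shared.global [%0], [%1], 4;\n"
                 :: "r"(dst), "l"(src) : "memory");
    asm volatile("cp.async.commit_group;\n" ::: "memory");
  };

  if (tiles == 0) return;
  issue(0);
  for (unsigned t = 0; t < tiles; ++t) {  // runtime trip count
    if (t + 1 < tiles) {
      issue(t + 1);  // prefetch the next tile into the other stage
      // At most one group may stay pending, so tile t has landed.
      asm volatile("cp.async.wait_group 1;\n" ::: "memory");
    } else {
      asm volatile("cp.async.wait_group 0;\n" ::: "memory");
    }
    // Each thread reads back only the element it copied itself, so no
    // block-wide barrier is needed in this sketch.
    out[(std::size_t)t * 128 + tid] = stage[t & 1][tid];
  }
}
```

The structure mirrors the wgmma reproducer (issue the next group, wait until at most one group is pending, consume the older result), which is why the difference in compiler behavior is surprising.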

This issue still exists in CUDA 13.0.

You might want to file a bug and refer to this thread.

A bug has already been filed, as indicated in comment #2 in this thread. The development team was looking at it internally as recently as August 6th. It doesn’t appear to be resolved yet. However, there is no indication (that I can see) in the bug referring back to this thread. I’ve added a note to the bug referencing this thread.

Hi, thanks for reporting NVBUG 5431330. Below is the latest update from the engineering team.

Thank you for reporting this case to us. As you’ve already discovered in your other report, changing the first wgmma.wait_group.sync.aligned 1 in the loop to wgmma.wait_group.sync.aligned 0 can serve as a workaround for the observed behavior.

There is indeed a diagnostic issue in our compiler: it is currently too conservative in detecting register usage within wgmma loops. Specifically, it conservatively assumes that certain registers from wgmma are not yet ready to be accessed in the loop, even when the corresponding operations have completed through the mix of “wgmma.wait_group.sync.aligned” with count 1 and the other wgmma groups.

Switching to wgmma.wait_group.sync.aligned 0 ensures that all wgmma operations are completed, making it the safest way to guarantee the correctness of the results. Until we relax the diagnostic behavior in the compiler, this remains the recommended approach when using wgmma.

We appreciate your feedback and will keep you updated as we address this issue. Thanks again for reaching out.

Best,

Yuki
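The workaround from the engineering update above can be sketched like this. The names wgmma_wait_group and steady_state_wait0 and the four callables are hypothetical; the callables stand in for the commit_m64n64k128 and save lambdas of the reproducer, and only the first wait in the loop body differs from the original.

```cuda
// Sketch, compile with -arch=sm_90a. All names here are hypothetical.
// The group count of wgmma.wait_group must be an immediate, hence the
// template parameter and the "n" constraint.
template <int N>
__device__ __forceinline__ void wgmma_wait_group() {
  asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}

// Issue1/Issue2 stand in for commit_m64n64k128(result1, ...) and
// commit_m64n64k128(result2, ...); Save1/Save2 stand in for
// save(result1, buffer) and save(result2, buffer + 4096).
template <typename Issue1, typename Issue2, typename Save1, typename Save2>
__device__ __forceinline__ void steady_state_wait0(
    unsigned iters, Issue1 issue1, Issue2 issue2, Save1 save1, Save2 save2) {
  for (unsigned i = 1; i < iters; ++i) {
    issue2();
    wgmma_wait_group<0>();  // was 1 in the reproducer: drain every group
    save1();
    issue1();
    wgmma_wait_group<1>();  // unchanged; result2's group finished above
    save2();
  }
}
```

Note the trade-off: waiting for count 0 also waits for the group that was just issued for result2, so that group no longer overlaps with saving result1 and part of the software pipelining is lost.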
