I now believe this is an NVCC bug. Here is a minimal program that reproduces the issue:
#include <cstdint>
// Minimal reproducer for NVCC's handling of a double-buffered wgmma pipeline.
//
// Requires sm_90a (wgmma.* PTX instructions). Intended launch: presumably a
// single block of exactly 128 threads, i.e. one warpgroup. wgmma is
// warpgroup-collective, and save() below assumes blockDim.x == 128. blockIdx is
// not used, so multiple blocks would all read and write the same `result`
// elements.
//
// result : global array that must hold at least 8192 unsigned values. Every
//          index in [0, 8192) is copied into shared memory at the start and
//          overwritten with the shared buffer contents at the end.
// iters  : number of (result1, result2) accumulator pairs to compute. The
//          values 0 and 1 both produce a single pair.
//
// NOTE(review): with a runtime trip count (`iters`) for the main loop, NVCC
// reportedly warns and inserts extra synchronization between the wgmma
// instructions. Using a compile-time trip count plus `#pragma unroll` on that
// loop reportedly avoids it.
__launch_bounds__(128, 1)
__global__ void test(unsigned *result, unsigned iters) {
// 8192 x 32-bit words = 32 KiB of shared memory, split into two 16 KiB halves.
// The result1 groups read their operands from buf1; the result2 groups read
// theirs from buf2.
__shared__ unsigned buffer[8192];
unsigned *buf1 = &buffer[0], *buf2 = &buffer[4096];
// Convert the generic pointers into shared-state-space addresses. The results
// are held in pointer variables but are NOT dereferenceable generic pointers.
// They are only used below, truncated to 32 bits, as descriptor start addresses.
unsigned *bufs1, *bufs2;
asm ("cvta.to.shared.u64 %0,%1;\n":"=l"(bufs1):"l"(buf1));
asm ("cvta.to.shared.u64 %0,%1;\n":"=l"(bufs2):"l"(buf2));
// Block-stride copy of result[0..8192) from global into shared memory.
unsigned startaddr = threadIdx.x;
for (; startaddr < 8192; startaddr += blockDim.x) {
buffer[startaddr] = result[startaddr];
}
__syncthreads();
// const unsigned input_swizzle_bytes = 128;
// The descriptor start address is encoded in 16-byte units, so adding 2
// advances the operand start by 32 bytes. For int8 (K-major) operands that is
// K = 32 elements, i.e. one m64n64k32 step.
const unsigned input_stride_16B = 2;
// Leading and stride byte offsets for the descriptors (in bytes, before encoding).
const unsigned input_lbo = 16;
const unsigned input_sbo = 1024u;
// 32-bit shared addresses of the two operand regions.
const unsigned buf1addr = (unsigned)reinterpret_cast<std::uintptr_t>(bufs1);
const unsigned buf2addr = (unsigned)reinterpret_cast<std::uintptr_t>(bufs2);
// Descriptor field encoding: keep the low 18 bits of a byte quantity and
// express it in 16-byte units, giving a 14-bit field value.
auto descr_encode = [](unsigned inp)->std::size_t { return (inp & 0x3FFFF) >> 4; };
// Common descriptor bits: encoded LBO at bit 16, encoded SBO at bit 32, and
// swizzle mode 1 (128B swizzle) at bit 62. The encoded start address is OR'd
// into the low bits below.
// NOTE(review): the base-offset field is left at 0. That presumably requires
// each operand start to be 1024-byte aligned for 128B swizzle, but `buffer`
// carries no explicit __align__. Confirm.
const std::size_t basic_descr = (descr_encode(input_lbo) << 16) |
(descr_encode(input_sbo) << 32) | (std::size_t(1) << 62 /* swizzle 128B */);
// Each 16 KiB region holds A at offset 0 and B at offset 8192 bytes.
// A is 64 x 128 int8 and B is 64 x 128 int8, 8192 bytes each.
const std::size_t descr1_a = basic_descr | descr_encode(buf1addr);
const std::size_t descr1_b = basic_descr | descr_encode(buf1addr + 8192);
const std::size_t descr2_a = basic_descr | descr_encode(buf2addr);
const std::size_t descr2_b = basic_descr | descr_encode(buf2addr + 8192);
// Issues one m64n64k128 int8 x int8 -> s32 product as four wgmma m64n64k32
// instructions and commits them as one wgmma-group.
//   result         : 32 s32 accumulator registers per thread (output).
//   descr_a/descr_b: shared-memory matrix descriptors for A and B. They are
//                    advanced by 32 bytes (K = 32) between instructions.
// The first wgmma uses scale-d = 0 (D = A*B), hence the "=r" outputs. The next
// three use scale-d = 1 (D = A*B + D), hence "+r".
// The accumulators are only valid after a wgmma.wait_group that covers this
// group; the asm constraints by themselves do not express that.
// NOTE(review): the lambda parameter `result` shadows the kernel parameter
// `result` (the global pointer).
auto commit_m64n64k128 = [&](unsigned (&result)[32],
std::size_t descr_a, std::size_t descr_b)->void {
// Order earlier generic-proxy shared-memory writes against the async-proxy
// reads done by wgmma.
asm volatile ("fence.proxy.async.shared::cta;\n":::"memory");
// Order earlier register accesses against the wgmma.mma_async below.
asm volatile ("wgmma.fence.sync.aligned;\n":::"memory");
// K slice 0: scale-d = 0 overwrites the accumulators.
asm volatile ("wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 "
"{%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,"
"%20,%21,%22,%23,%24,%25,%26,%27,%28,%29,%30,%31},%32,%33,0;\n"
:"=r"(result[0]),"=r"(result[1]),"=r"(result[2]),"=r"(result[3]),
"=r"(result[4]),"=r"(result[5]),"=r"(result[6]),"=r"(result[7]),
"=r"(result[8]),"=r"(result[9]),"=r"(result[10]),"=r"(result[11]),
"=r"(result[12]),"=r"(result[13]),"=r"(result[14]),"=r"(result[15]),
"=r"(result[16]),"=r"(result[17]),"=r"(result[18]),"=r"(result[19]),
"=r"(result[20]),"=r"(result[21]),"=r"(result[22]),"=r"(result[23]),
"=r"(result[24]),"=r"(result[25]),"=r"(result[26]),"=r"(result[27]),
"=r"(result[28]),"=r"(result[29]),"=r"(result[30]),"=r"(result[31])
:"l"(descr_a),"l"(descr_b):"memory");
/* descr_a and descr_b are matrix descriptors, not accumulators */
descr_a += input_stride_16B;
descr_b += input_stride_16B;
// K slice 1: accumulate (scale-d = 1).
asm volatile ("wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 "
"{%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,"
"%20,%21,%22,%23,%24,%25,%26,%27,%28,%29,%30,%31},%32,%33,1;\n"
:"+r"(result[0]),"+r"(result[1]),"+r"(result[2]),"+r"(result[3]),
"+r"(result[4]),"+r"(result[5]),"+r"(result[6]),"+r"(result[7]),
"+r"(result[8]),"+r"(result[9]),"+r"(result[10]),"+r"(result[11]),
"+r"(result[12]),"+r"(result[13]),"+r"(result[14]),"+r"(result[15]),
"+r"(result[16]),"+r"(result[17]),"+r"(result[18]),"+r"(result[19]),
"+r"(result[20]),"+r"(result[21]),"+r"(result[22]),"+r"(result[23]),
"+r"(result[24]),"+r"(result[25]),"+r"(result[26]),"+r"(result[27]),
"+r"(result[28]),"+r"(result[29]),"+r"(result[30]),"+r"(result[31])
:"l"(descr_a),"l"(descr_b):"memory");
descr_a += input_stride_16B;
descr_b += input_stride_16B;
// K slice 2: accumulate (scale-d = 1).
asm volatile ("wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 "
"{%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,"
"%20,%21,%22,%23,%24,%25,%26,%27,%28,%29,%30,%31},%32,%33,1;\n"
:"+r"(result[0]),"+r"(result[1]),"+r"(result[2]),"+r"(result[3]),
"+r"(result[4]),"+r"(result[5]),"+r"(result[6]),"+r"(result[7]),
"+r"(result[8]),"+r"(result[9]),"+r"(result[10]),"+r"(result[11]),
"+r"(result[12]),"+r"(result[13]),"+r"(result[14]),"+r"(result[15]),
"+r"(result[16]),"+r"(result[17]),"+r"(result[18]),"+r"(result[19]),
"+r"(result[20]),"+r"(result[21]),"+r"(result[22]),"+r"(result[23]),
"+r"(result[24]),"+r"(result[25]),"+r"(result[26]),"+r"(result[27]),
"+r"(result[28]),"+r"(result[29]),"+r"(result[30]),"+r"(result[31])
:"l"(descr_a),"l"(descr_b):"memory");
descr_a += input_stride_16B;
descr_b += input_stride_16B;
// K slice 3: accumulate (scale-d = 1).
asm volatile ("wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 "
"{%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,"
"%20,%21,%22,%23,%24,%25,%26,%27,%28,%29,%30,%31},%32,%33,1;\n"
:"+r"(result[0]),"+r"(result[1]),"+r"(result[2]),"+r"(result[3]),
"+r"(result[4]),"+r"(result[5]),"+r"(result[6]),"+r"(result[7]),
"+r"(result[8]),"+r"(result[9]),"+r"(result[10]),"+r"(result[11]),
"+r"(result[12]),"+r"(result[13]),"+r"(result[14]),"+r"(result[15]),
"+r"(result[16]),"+r"(result[17]),"+r"(result[18]),"+r"(result[19]),
"+r"(result[20]),"+r"(result[21]),"+r"(result[22]),"+r"(result[23]),
"+r"(result[24]),"+r"(result[25]),"+r"(result[26]),"+r"(result[27]),
"+r"(result[28]),"+r"(result[29]),"+r"(result[30]),"+r"(result[31])
:"l"(descr_a),"l"(descr_b):"memory");
// Close the wgmma-group so that wgmma.wait_group can track it.
asm volatile ("wgmma.commit_group.sync.aligned;\n":::"memory");
};
// Stores each thread's 32 accumulator registers to
// start[i * blockDim.x + threadIdx.x] for i in [0, 32), then barriers the block.
// With 128 threads this writes 4096 consecutive words starting at `start`.
auto save = [](const unsigned (&result)[32], unsigned *start)->void {
//blockDim.x == 128
#pragma unroll
for (int i = 0; i < 32; ++i) start[i * blockDim.x + threadIdx.x] = result[i];
__syncthreads();
};
// Two accumulator sets, alternated so that one wgmma-group can run while the
// other set is stored.
unsigned result1[32], result2[32];
commit_m64n64k128(result1, descr1_a, descr1_b);
// Steady state: `wait_group 1` waits until at most one committed group is
// still pending. The older group, whose accumulators are saved next, is then
// complete.
// NOTE(review): save() writes to buffer[0..4096) and buffer[4096..8192). Those
// are the same shared regions the descr1_* and descr2_* descriptors read, so
// later groups multiply previously stored results. That is fine for a
// reproducer, but the output is not a meaningful GEMM.
for (unsigned i = 1; i < iters; ++i) {
commit_m64n64k128(result2, descr2_a, descr2_b);
asm volatile ("wgmma.wait_group.sync.aligned 1;\n":::"memory");
save(result1, buffer);
commit_m64n64k128(result1, descr1_a, descr1_b);
asm volatile ("wgmma.wait_group.sync.aligned 1;\n":::"memory");
save(result2, buffer + 4096);
}
// Drain:
// 1. Issue the final result2 group.
// 2. Store result1 once it is complete.
// 3. Wait for all groups (wait_group 0) and store result2.
commit_m64n64k128(result2, descr2_a, descr2_b);
asm volatile ("wgmma.wait_group.sync.aligned 1;\n":::"memory");
save(result1, buffer);
asm volatile ("wgmma.wait_group.sync.aligned 0;\n":::"memory");
save(result2, buffer + 4096);
// Block-stride copy of the whole shared buffer back to global `result`.
for (startaddr = threadIdx.x; startaddr < 8192; startaddr += blockDim.x) {
result[startaddr] = buffer[startaddr];
}
}
If I change the trip count of the main `for (unsigned i = 1; i < iters; ++i)` loop in the code above to an integer constant (and add `#pragma unroll` before that loop), the compiler stops emitting the warnings and no longer inserts extra synchronization instructions between the wgmma instructions.