First, I initialize a distributed shared memory barrier (mapped from one CTA):
```cpp
// initialization of the distributed shared memory barrier (mapped from one CTA)
if (threadIdx.x == 0 && blockRankInCluster == 0) {
    cuda::ptx::mbarrier_init(bar, 2); // tried 1 instead of 2 too
    cuda::ptx::fence_mbarrier_init(cuda::ptx::sem_release, cuda::ptx::scope_cluster);
    cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_release, cuda::ptx::scope_cluster,
                                         cuda::ptx::space_cluster,
                                         reinterpret_cast<uint64_t*>(bar), 1); // tried without this too
}
cluster.sync();
```
Then I initiate the TMA copy using the tensor map:
```cpp
if (threadIdx.x == 0 && blockRankInCluster == 0) {
    const uint16_t ctaMask = 0b1; // tried 0b11 too
    cuda::ptx::cp_async_bulk_tensor(cuda::ptx::space_cluster, cuda::ptx::space_global,
                                    reinterpret_cast<void*>(s_mem), tensorMap, coords,
                                    reinterpret_cast<uint64_t*>(bar), ctaMask);
}
```
After other asynchronous work, I wait like this:
```cpp
// The cuda::ptx:: functions above don't return any state. How can I know which
// state value to start with? Is state a token like the one non-cluster TMA
// operations use?
if (threadIdx.x == 0 && blockRankInCluster == 0) {
    while (!cuda::ptx::mbarrier_try_wait(cuda::ptx::sem_acquire, cuda::ptx::scope_cluster,
                                         reinterpret_cast<uint64_t*>(bar), state)) {
        __nanosleep(10);
    }
}
```
but it never gets out of this while-loop.

I used ctaMask = 0b1 with a two-block cluster launch (statically declared via __cluster_dims__(2, 1, 1)). state is incremented after the wait, but that point is never reached because of the infinite loop. What can cause a try-wait to never return true?
If this works, I will convert it to a multi-CTA wait version to make the synchronization more efficient. For now I am only trying to get it to run correctly, and I have no idea how TMA knows the destinations in the other CTAs' shared memory. Can it infer their destinations from just the calling CTA's destination parameter?
The tensor map encoding is correct, the tile data in gmem is aligned and works for non-cluster TMA operations, and the distributed shared memory barrier and all data pointers are aligned to 32 and 128 bytes respectively. No error is returned; it just keeps spinning inside the while-loop.
Also, I would appreciate it if someone could discuss the details of how the mbarrier works and how a cluster initiates TMA for all of its CTAs, whether from a single CTA or from multiple CTAs. For example, more CTAs in the cluster would require at least one arrival per CTA, right? Then maybe a try-wait per CTA instead of cluster.sync(). I guess the expected arrive count is (1 from the TMA) + (1 for each thread calling try-wait).
I'm trying to do this because issuing a single TMA request from one CTA and then distributing the tile from that CTA to the other CTAs in the cluster (through a manual distributed-shared-memory copy) has very limited bandwidth, only about 2-3 TB/s on an H100. Compared to the TMA bandwidth, that is too low, so a multicast could potentially be 5-10x faster (or reach the peak write bandwidth of the destination shared memory).