Correct usage of mbarriers on Ampere / Ada

Hi All,

I am trying to understand what would be the correct way to use mbarriers on Ampere for tracking multiple cp.async instructions per thread.

I wrote this very small program, where each thread copies 16 contiguous floats, using 4 16 cp.async instructions and calls cp.async.mbarrier.noinc.arrive.

Now if I understand correctly, after all the threads (32 in my case), call that instruction, at some point the expected arrival count of the mbarrier object to decrement to 33 (initial value) - 32 = 1.

Then thread 0 calls mbarrier.arrive to query the state (thus decrementing the pending arrival count to 1 - 1 = 0), and ONLY thread_0 has a spin loop to check if the phase bit has flipped.

Then all the threads copy the contents of shared memory to global memory. I however face incorrect results and it completely hangs under compute-sanitizer. Where am I going wrong ?

__global__ void kernel(float* a, float* b) {
    __shared__ uint64_t mbar;
    __shared__ alignas(128) float smem[16 * 32];
    auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&smem));
    auto mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&mbar));

    if (threadIdx.x == 0) {
        asm volatile("mbarrier.init.shared.b64 [%0], %1;\n\t" :: "r"(mbar_addr), "r"(blockDim.x + 1));
    }
    __syncthreads();

    #pragma unroll
    for (int i = 0; i < 4; i++) {
        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::
                     "r"(smem_addr + (threadIdx.x * 16 + i * 4) * 4), "l"(a + threadIdx.x * 16 + i * 4) : "memory");
    }

    asm volatile("cp.async.mbarrier.arrive.noinc.shared.b64 [%0]; \n" :: "r"(mbar_addr));
    __syncthreads();
    if (threadIdx.x == 0) {
    int64_t state;
    asm volatile("mbarrier.arrive.release.cta.shared.b64 %0, [%1];\n" : "=l"(state) : "r"(mbar_addr));
    int32_t expected_parity = static_cast<uint32_t>(((state >> 63) ^ 1) & 1);
    int complete = 0;
    do {
        asm volatile("{\n\t"
                     ".reg .pred %%p;\n\t"
                     "mbarrier.test_wait.parity.shared.b64 %%p, [%1], %2;\n\t"
                     "selp.u32 %0, 1, 0, %%p;\n\t"
                     "}\n\t"
                     : "=r"(complete)
                     : "r"(mbar_addr), "r"(expected_parity)
                     : "memory");
    } while (!complete);
}

    __syncthreads();


    for (int i = threadIdx.x; i < 16 * 32; i += blockDim.x) {
        b[i] = smem[i];
    }
}

int main() {
    float *a;
    float* b;
    cudaMallocManaged(&a, 16 * 32 * sizeof(float));
    cudaMallocManaged(&b, 16 * 32 * sizeof(float));
    for (int i = 0; i < 16 * 32; i++) {
        a[i] = i + 1;
        b[i] = 0;
    }
    kernel<<<1, 32>>>(a, b);
    cudaDeviceSynchronize();
    for (int i = 0; i < 16 * 32; i++) {
        if (a[i] != b[i]) {
            printf("%f %f %d\n", a[i], b[i], i);
            throw std::runtime_error("failed");
        }
    }
    std::cout << "passed" << std::endl;
}


Thanks