Synchronizing Threads Between Loading Q/K and V in WASP

In WASP, we can’t use __syncthreads(), but my requirement is: first, load Q and K, then I want to sync, and finally load V. How should I write this sync?

Is this correct?

cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);

I don’t understand the concept of NamedBarrier. Could anyone explain if this is the right approach?

In FA2 you should load Q first (as 128-bit transactions from global memory to shared memory, SMEM).

Typically they then insert:

  • “cp.async.commit_group”
  • “cp.async.wait_group 0”

Then sync the whole thread block:

  • cg::this_thread_block().sync()

K and V are loaded inside a for loop that iterates over warp fragments (typically 16x16 fragments), after Q has been loaded from global memory into the SMEM buffer. Each load is followed by a commit-group instruction, and finally:

  • “cp.async.wait_group 1”
  • cg::this_thread_block().sync()
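
To make the ordering above concrete, here is a minimal sketch of it, not the actual FA2 code. It assumes the pipeline primitives from <cuda_pipeline.h> (which, as far as I know, lower to cp.async, cp.async.commit_group and cp.async.wait_group on sm_80 and newer). The kernel name, the 128-thread block and the one-float4-per-thread "tiles" are made up for illustration.

```cuda
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cuda_pipeline.h>
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

constexpr int kThreads = 128;  // hypothetical block size: one float4 per thread

__global__ void load_q_then_kv(const float4* q, const float4* k,
                               const float4* v, float* out) {
    __shared__ float4 sQ[kThreads];
    __shared__ float4 sK[kThreads];
    __shared__ float4 sV[kThreads];
    cg::thread_block block = cg::this_thread_block();
    const int t = threadIdx.x;
    const int n = (t + 1) % kThreads;  // read a neighbour's slot so the sync matters

    // Q: one 16-byte (128-bit) async copy per thread, then commit and wait_group 0.
    __pipeline_memcpy_async(&sQ[t], &q[t], sizeof(float4));
    __pipeline_commit();
    __pipeline_wait_prior(0);
    block.sync();                      // Q is now visible to the whole block
    const float qx = sQ[n].x;

    // K and V: each copy gets its own commit group.
    __pipeline_memcpy_async(&sK[t], &k[t], sizeof(float4));
    __pipeline_commit();
    __pipeline_memcpy_async(&sV[t], &v[t], sizeof(float4));
    __pipeline_commit();

    __pipeline_wait_prior(1);          // all but the newest group (V) are done, so K landed
    block.sync();
    const float qk = qx * sK[n].x;

    __pipeline_wait_prior(0);          // now V has landed as well
    block.sync();
    out[t] = qk * sV[n].x;
}

int main() {
    const size_t bytes = kThreads * sizeof(float4);
    std::vector<float4> hq(kThreads, make_float4(1.f, 0.f, 0.f, 0.f));
    std::vector<float4> hk(kThreads, make_float4(2.f, 0.f, 0.f, 0.f));
    std::vector<float4> hv(kThreads, make_float4(3.f, 0.f, 0.f, 0.f));
    std::vector<float> hout(kThreads, 0.f);

    float4 *q, *k, *v;
    float* out;
    cudaMalloc(&q, bytes);
    cudaMalloc(&k, bytes);
    cudaMalloc(&v, bytes);
    cudaMalloc(&out, kThreads * sizeof(float));
    cudaMemcpy(q, hq.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(k, hk.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(v, hv.data(), bytes, cudaMemcpyHostToDevice);

    load_q_then_kv<<<1, kThreads>>>(q, k, v, out);
    if (cudaDeviceSynchronize() != cudaSuccess) return 1;
    cudaMemcpy(hout.data(), out, kThreads * sizeof(float), cudaMemcpyDeviceToHost);

    // Every thread computed 1 * 2 * 3 from its neighbour's slots.
    for (float x : hout) {
        if (x != 6.f) { printf("mismatch\n"); return 1; }
    }
    printf("ok\n");
    cudaFree(q); cudaFree(k); cudaFree(v); cudaFree(out);
    return 0;
}
```

The point of wait_group 1 is that the most recently committed group (V here) may still be in flight while K is already being consumed. Every sync in this sketch is a full-block sync, which is exactly what is not available once the warps are specialized.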

Thanks! But it seems you are discussing FA2, without WASP?
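
On the NamedBarrier concept itself: as far as I understand it, cutlass::arch::NamedBarrier::sync(num_threads, id) boils down to the PTX instruction bar.sync id, num_threads. It blocks until num_threads threads (a multiple of the warp size) have reached the barrier with that ID, so only the participating warps have to call it, unlike __syncthreads(). Under that reading, a count of NumMmaThreads + cutlass::NumThreadsPerWarp would mean "all MMA threads plus one producer warp". Below is a rough sketch of the "load Q, sync, load V" ordering using raw inline PTX instead of CUTLASS, so it compiles on its own. The kernel name, the barrier ID 1, and the split into one producer warp, two consumer warps and one uninvolved warp are all hypothetical; it does not model TMA or FA3's actual pipeline.

```cuda
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int kWarp = 32;
constexpr int kProducerThreads = kWarp;      // warp 0 loads data
constexpr int kConsumerThreads = 2 * kWarp;  // warps 1-2 consume it
constexpr int kBlockThreads = 4 * kWarp;     // warp 3 never touches the barrier
constexpr int kBarrierId = 1;                // hypothetical ID; 0 is what __syncthreads() uses

// Sync only `num_threads` threads on hardware barrier `id`.
__device__ __forceinline__ void named_barrier_sync(int id, int num_threads) {
    asm volatile("bar.sync %0, %1;" : : "r"(id), "r"(num_threads) : "memory");
}

__global__ void ws_load_q_then_v(const float* q, const float* v, float* out) {
    __shared__ float sQ[kConsumerThreads];
    __shared__ float sV[kConsumerThreads];
    const int t = threadIdx.x;
    const int warp = t / kWarp;
    const int participants = kProducerThreads + kConsumerThreads;  // 96 threads

    if (warp == 0) {
        // Producer warp: write Q, sync, then write V, sync again.
        for (int i = t; i < kConsumerThreads; i += kWarp) sQ[i] = q[i];
        named_barrier_sync(kBarrierId, participants);  // Q is ready
        for (int i = t; i < kConsumerThreads; i += kWarp) sV[i] = v[i];
        named_barrier_sync(kBarrierId, participants);  // V is ready
    } else if (warp <= 2) {
        // Consumer warps: wait for Q, use it, then wait for V.
        const int c = t - kProducerThreads;
        const int n = (c + 1) % kConsumerThreads;  // read a slot written by another thread
        named_barrier_sync(kBarrierId, participants);
        const float acc = sQ[n];
        named_barrier_sync(kBarrierId, participants);
        out[c] = acc + sV[n];
    }
    // Warp 3 falls through without syncing; the barrier count excludes it.
}

int main() {
    std::vector<float> hq(kConsumerThreads), hv(kConsumerThreads, 100.f);
    std::vector<float> hout(kConsumerThreads, 0.f);
    for (int i = 0; i < kConsumerThreads; ++i) hq[i] = static_cast<float>(i);

    const size_t bytes = kConsumerThreads * sizeof(float);
    float *q, *v, *out;
    cudaMalloc(&q, bytes);
    cudaMalloc(&v, bytes);
    cudaMalloc(&out, bytes);
    cudaMemcpy(q, hq.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(v, hv.data(), bytes, cudaMemcpyHostToDevice);

    ws_load_q_then_v<<<1, kBlockThreads>>>(q, v, out);
    if (cudaDeviceSynchronize() != cudaSuccess) return 1;
    cudaMemcpy(hout.data(), out, bytes, cudaMemcpyDeviceToHost);

    // Each consumer c should hold q[(c + 1) % 64] + 100.
    for (int c = 0; c < kConsumerThreads; ++c) {
        const float expected = hq[(c + 1) % kConsumerThreads] + 100.f;
        if (hout[c] != expected) { printf("mismatch at %d\n", c); return 1; }
    }
    printf("ok\n");
    cudaFree(q); cudaFree(v); cudaFree(out);
    return 0;
}
```

The same barrier ID is reused for both phases, which works because the hardware barrier resets once the expected number of threads has arrived. Whether QueryEmpty with that particular thread count is the right barrier for your kernel depends on which warps actually issue the Q/K and V loads, which I can't tell from the single line you posted.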