I have to multiply two u8 matrices, A*B, but A is in col-major while B is in row-major. With wmma instructions this can be done easily by specifying the layout when loading A and B. However, it turns out that the mma instructions only accept .row.col, requiring A to be row-major and B to be col-major. I know that ‘movmatrix.sync.aligned.shape.trans.type d, a’ might help, but it only supports 16-bit data. So what should I do?
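For reference, a minimal sketch of driving such an instruction from inline PTX (the m16n8k32 shape is just one example; it needs sm_80+, and the exact per-thread fragment layout is the one defined in the PTX ISA, not anything shown here):

```
// Sketch only: D = A*B + C for a 16x8x32 u8 tile, accumulating in s32.
// The integer mma shapes exist only in .row.col form, which is the
// restriction discussed above.  Per thread: A = 4 x u32, B = 2 x u32,
// C/D = 4 x s32.
__device__ void mma_u8_16x8x32(int d[4], const unsigned a[4],
                               const unsigned b[2], const int c[4])
{
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
```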
With an mma instruction, you load the register footprint yourself. You can transpose during loading.
Also, reversing the order of multiplication is equivalent to multiplying by the transposes (with a transpose of the result). This could transpose both A and B for you, for "free".
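Written out, the identity behind this suggestion is the following (the layout reinterpretation is what makes the transposes "free"):

```latex
% A stored col-major is, byte for byte, A^\top stored row-major (and vice
% versa), so no data movement is needed for the transposes on the right.
AB = \left( B^{\top} A^{\top} \right)^{\top}
```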
Yes, I understand that, but it is quite difficult, as the u8 values are packed into u32 along the contiguous direction in memory. The only way I can see to transpose involves 4x the reads and a lot of work to recompose the bytes. Do you have any idea how wmma does that transpose?
I do not understand this. I want to compute AB, where A is col-major and B is row-major; transposing the answer means calculating B^T A^T. But without actually changing the data layout, B^T is col-major while A^T is row-major, which again is not the layout that mma requires.
Read from global memory only once, in a coalesced way (at least complete 32-byte aligned blocks), and then convert it locally with __byte_perm(), select (the three-way question-mark operator) and warp shuffle operations. One usually does not even need to use shared memory for reordering the data.
You would do it for several 32-byte blocks at once.
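A minimal sketch of that idea, assuming a 4x4 u8 tile spread over 4 consecutive lanes with one packed row per lane (the function name and the tile-to-lane mapping are illustrative, not a fragment layout prescribed by mma):

```
// 4 consecutive lanes each hold one row of a 4x4 u8 tile packed into a u32
// (byte 0 = column 0).  Afterwards, lane c of the group holds column c,
// i.e. the tile is transposed entirely in registers.
// Assumes the full warp is active (mask 0xffffffff).
__device__ unsigned transpose_u8_4x4(unsigned row)
{
    const unsigned lane = threadIdx.x & 31;   // lane id within the warp
    const unsigned base = lane & ~3u;         // first lane of the 4-lane group
    const unsigned c    = lane & 3u;          // column this lane will own

    // Read all four rows of the tile (shuffles stay inside the 4-lane group).
    unsigned r0 = __shfl_sync(0xffffffffu, row, base + 0);
    unsigned r1 = __shfl_sync(0xffffffffu, row, base + 1);
    unsigned r2 = __shfl_sync(0xffffffffu, row, base + 2);
    unsigned r3 = __shfl_sync(0xffffffffu, row, base + 3);

    // Pick byte c out of each row; selector nibble 0 = result byte 0, etc.
    unsigned lo = __byte_perm(r0, r1, 0x0040u + c * 0x0011u); // M[0][c], M[1][c]
    unsigned hi = __byte_perm(r2, r3, 0x0040u + c * 0x0011u); // M[2][c], M[3][c]
    return __byte_perm(lo, hi, 0x5410u);                      // pack the column
}
```

For a real mma operand you would choose the permutation so that each lane ends up holding exactly the bytes its fragment requires, rather than a literal column as in this sketch.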
You could also use the 16-bit ldmatrix operation, but I am not sure whether it helps in the 8-bit case combined with some post-shuffling.
Thanks for your reply! It turns out that this problem has already been addressed in FlashAttention3 for FP8, and cutlass-kernels/src/fmha-pipeline/reg2reg.h at master · ColfaxResearch/cutlass-kernels seems to give an implementation for fp8. Their technique is just what you said.
Nice to see confirmation from Cutlass that the methods we have used so far really do seem to be the fastest. They appear to have tried out different variants.
Watch out:

- For direct usage in mma, you do not have to save the results back to shared memory. Perhaps adapt the algorithm so that the right data ends up in the right threads after the permutation. You may have to combine a few mma operations, as the data loading can be more efficient when loading more than one set of data (e.g. 2 or 4 sets).
- Make sure `upper_map[threadIdx.x % 4]` is evaluated without local memory and only once, outside of the loops, as it is constant per lane number (see the sketch after this list).
- Make sure all global memory accesses are coalesced (at least to 32 bytes) and all shared memory accesses are bank-conflict-free.
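A minimal sketch of that hoisting (the `upper_map` values and the kernel around it are placeholders, not the actual reg2reg.h code): keeping the table in `__constant__` memory and reading it once per thread avoids a runtime-indexed local array that would spill to local memory.

```
// Placeholder permutation table; the real values live in reg2reg.h.
__device__ __constant__ unsigned upper_map[4] = {0, 2, 1, 3};

__global__ void permute_kernel(const unsigned* in, unsigned* out, int n)
{
    // Constant per lane: look it up once, before the loop,
    // so it stays in a register instead of being re-read every iteration.
    const unsigned sel = upper_map[threadIdx.x % 4];

    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n;
         i += blockDim.x * gridDim.x) {
        // Purely illustrative use of the hoisted selector: broadcast byte
        // 'sel' of the loaded word into all four result bytes.
        out[i] = __byte_perm(in[i], in[i], sel * 0x1111u);
    }
}
```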