What is the best way to re-use a tensor core C fragment as the A or B input of the next operation when their types differ?

Dear Community,
Lately we are having an efficiency issue when using the CUDA WMMA tensor core matrix multiplication for a research project. To put into context, we need to perform a sequence of WMMA operations within the same kernel, re-using the resulting C fragment of the last operation, now as an A or B matrix for the next WMMA operation. When we use FP16 WMMA there is no issue, because after an mma_sync completes, in the following lines of code the C fragment can be stored in a 16x16 shared memory buffer with store_matrix_sync, then loaded again as an “A” or “B” fragment using load_matrix_sync.

The newer MMA data types, INT8 and INT4, could really boost performance for what we are doing, but an MMA in these types produces its result as an INT32 fragment. For now we are forced to dedicate a separate INT32 shared memory buffer to save the fragment, manually cast every value into another INT8/INT4 buffer, and only then load the result back as an A or B fragment for the next MMA. This longer pipeline makes the whole process slower than FP16 and uses more shared memory. In theory these types could be much more efficient if it weren't for these technical difficulties.
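A minimal sketch of the INT8 pipeline we use today (names are made up; the saturating narrowing cast is one possible policy, not the only one):

```cuda
#include <mma.h>
using namespace nvcuda;

// Chained 16x16x16 INT8 WMMA (sm_72+): the INT32 accumulator must be spilled,
// manually narrowed to INT8 in a second buffer, and only then reloaded as A.
__global__ void chained_wmma_int8(const signed char* A, const signed char* B0,
                                  const signed char* B1, int* D) {
    __shared__ int         c_i32[16 * 16]; // INT32 staging buffer for C
    __shared__ signed char c_i8[16 * 16];  // extra INT8 buffer the cast forces on us

    wmma::fragment<wmma::matrix_a, 16, 16, 16, signed char, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, signed char, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, int> c_frag;

    wmma::load_matrix_sync(a_frag, A, 16);
    wmma::load_matrix_sync(b_frag, B0, 16);
    wmma::fill_fragment(c_frag, 0);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

    // Step 1: spill the INT32 accumulator
    wmma::store_matrix_sync(c_i32, c_frag, 16, wmma::mem_row_major);
    __syncwarp();

    // Step 2: manual narrowing cast, 256 elements spread over 32 lanes
    for (int i = threadIdx.x % 32; i < 16 * 16; i += 32)
        c_i8[i] = (signed char)max(-128, min(127, c_i32[i])); // saturate to int8
    __syncwarp();

    // Step 3: only now can the result feed the next MMA as an A operand
    wmma::load_matrix_sync(a_frag, c_i8, 16);
    wmma::load_matrix_sync(b_frag, B1, 16);
    wmma::fill_fragment(c_frag, 0);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(D, c_frag, 16, wmma::mem_row_major);
}
```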

Does CUDA allow a more efficient way of re-using the C fragment as A or B when its data type differs?

You can see the element layout per thread in the ptx documentation: PTX ISA 8.3

For mma int8, each thread holds 2 elements of C, each in a separate 32-bit register.
The A/B input of mma int8, however, requires a single 32-bit register that packs 4 8-bit elements.
This transformation cannot be achieved by a simple cast: the bytes a thread needs for its packed input register generally live in other threads' accumulator registers.

Thanks striker159,
EDIT: sorry, I had misread your comment, but now I see the point. So do you think the INT8 inefficiency caused by the manual conversions is inevitable in the current state of tensor cores?

Regarding the FP16 approach: although the fragments share the same element type, our current solution still passes the data of the C fragment through a shared memory copy before it reaches an A or B fragment. What we would really like is a way to move the data directly from C to A/B, but it feels as if the C++ template design does not allow it. Does anyone know whether this direct fragment transfer is achievable at the C/C++ level, or, failing that, at the PTX level?
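To illustrate why the direct route seems closed: the tempting element-wise copy between fragments is expressible, but (as far as I can tell) meaningless, because the per-thread element order of each fragment kind is unspecified:

```cuda
#include <mma.h>
using namespace nvcuda;

__global__ void naive_fragment_copy() {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
    wmma::fill_fragment(c_frag, __float2half(1.0f));

    // This compiles, but it is NOT a valid C -> A conversion: the mapping of
    // matrix elements to x[i] is unspecified and may differ between fragment
    // kinds (a_frag.num_elements need not even equal c_frag.num_elements), so
    // nothing guarantees element i of one corresponds to element i of the other.
    for (int i = 0; i < c_frag.num_elements && i < a_frag.num_elements; ++i)
        a_frag.x[i] = c_frag.x[i];
}
```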

Can you show the exact fragment types you are using?

Currently, with the FP16 approach, we have these fragments

   wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
   wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
   wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;

I don’t think there is a robust way with wmma to move the contents of C to A without materializing the matrix in shared memory. The problem is that the storage layout of wmma fragments is unspecified.

For mma with specified mapping of elements to threads, you might be able to rearrange the data using warp shuffles instead of shared memory. I did not look into this approach.
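As a purely conceptual sketch of that idea, each lane could fetch the accumulator values it needs from their owning lanes with __shfl_sync and pack them into one operand register. Note that the (src_lane, which-register) mapping below is a placeholder, NOT the real layout; the actual mapping has to be read off the per-thread fragment layouts in the PTX ISA documentation:

```cuda
// Repack two INT32 accumulator registers (c0, c1 per lane) into one packed
// INT8 operand register using warp shuffles instead of shared memory.
__device__ unsigned repack_c_to_a(int c0, int c1) {
    unsigned packed = 0;
    for (int b = 0; b < 4; ++b) {
        // Placeholder mapping from (lane, byte) of the operand register to
        // (source lane, source register) of the C fragment -- NOT the real one:
        int src_lane = (threadIdx.x % 32) ^ b;   // placeholder
        int use_c1   = b & 1;                    // placeholder
        int v  = __shfl_sync(0xffffffff, use_c1 ? c1 : c0, src_lane);
        int s8 = max(-128, min(127, v));         // saturate to int8
        packed |= (unsigned)(unsigned char)(signed char)s8 << (8 * b);
    }
    return packed;
}
```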