How does 4x4 mma at tensor core level translate to 16x16 mma at warp level?

I have recently started using tensor core operations with CUDA, and am a bit confused about how we arrive at a 16x16 matrix size at the warp level. I am on a V100, and my understanding is that each tensor core can perform a 4x4 matrix multiplication. Within an SM, a warp scheduler has access to only two tensor cores, so I don't understand how using two tensor cores allows quadrupling the size of the matrix. Can someone please explain what I am missing?

I guess you might be thinking that a 16x16 matrix multiply on a V100 happens all at once, in a single instruction or a single push of a button, or something like that. It sort of appears that way at the CUDA C++ intrinsic level, where it is a single wmma::mma_sync call.

It does not, at the machine code level. Remember that CUDA C++ is a compiled language. Something that is elegantly expressed at the C++ level may decompose into multiple steps at the machine level.
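To put rough numbers on the question, here is a back-of-envelope sketch in plain C++. It assumes the figures from NVIDIA's Volta whitepaper (each tensor core performs one 4x4x4 matrix FMA, i.e. 64 scalar FMAs, per clock, and an SM's 8 tensor cores are spread over 4 warp schedulers) and ignores latency, pipelining and operand delivery. The function and constant names are made up for illustration.

```cpp
// Back-of-envelope model, not a simulation. Assumed figures (Volta whitepaper):
// each tensor core does one 4x4x4 matrix FMA (64 scalar FMAs) per clock, and
// an SM's 8 tensor cores are split across 4 sub-partitions (warp schedulers).
constexpr int kFmaPerTensorCorePerClock = 4 * 4 * 4;  // 64
constexpr int kTensorCoresPerScheduler = 8 / 4;       // 2

// Scalar fused multiply-adds in an m x n x k matrix multiply-accumulate.
constexpr int fmaCount(int m, int n, int k) { return m * n * k; }

// Lower bound on clocks the two tensor cores behind one warp scheduler need
// for an m x n x k MMA, ignoring latency and pipelining.
constexpr int minClocks(int m, int n, int k) {
    const int perClock = kFmaPerTensorCorePerClock * kTensorCoresPerScheduler;  // 128
    return (fmaCount(m, n, k) + perClock - 1) / perClock;  // round up
}
```

Under those assumptions a 16x16x16 multiply is 4096 FMAs, which is at least 32 clocks of work for the two tensor cores behind one scheduler. The warp-level operation is therefore necessarily spread over many clocks and several instructions rather than done in one shot.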

On V100, I believe the basic op performed by the TC unit is an m=8, n=8, k=4 (884) multiply, and it looks like each one is issued as a two-step process (perhaps; I haven't given this lots of study). Getting to the full 16x16 multiply would then require 4 of these. Clever carrying of the intermediate sums in registers from one op to the next should allow stitching the ops together with minimum effort, I think.
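The stitching idea can be illustrated on the CPU. The sketch below is ordinary C++, not tensor core code. It builds a 16x16x16 product out of 8x8x4 pieces while carrying partial sums in an accumulator, and checks the result against a plain triple loop. Grouping the pieces into 4 warp-level steps of four 8x8 tiles each is an assumption based on the PTX ISA description of mma.m8n8k4 on sm_70, where each quad-pair of threads computes one 8x8x4 product. How wmma actually maps fragment elements onto threads is unspecified, and the helper names here are hypothetical.

```cpp
// Plain C++ illustration of stitching a 16x16x16 product from 8x8x4 pieces.
// Not tensor core code; the grouping of pieces is an assumption.
constexpr int N = 16;

struct Mat { float v[N][N]; };

// Hypothetical test data: small integers, so float sums are exact.
constexpr Mat makeMat(int seed) {
    Mat m{};
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j)
            m.v[i][j] = static_cast<float>((i * 7 + j * 3 + seed) % 5 - 2);
    return m;
}

// Reference result: C = A * B with a plain triple loop.
constexpr Mat naiveMul(const Mat& a, const Mat& b) {
    Mat c{};
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j)
            for (int k = 0; k < N; ++k)
                c.v[i][j] += a.v[i][k] * b.v[k][j];
    return c;
}

// Hypothetical "basic op": the 8x8 tile of acc at (i0, j0) gets
// A[i0..i0+7][k0..k0+3] * B[k0..k0+3][j0..j0+7] added to it.
constexpr void mma884(Mat& acc, const Mat& a, const Mat& b, int i0, int j0, int k0) {
    for (int i = i0; i < i0 + 8; ++i)
        for (int j = j0; j < j0 + 8; ++j)
            for (int k = k0; k < k0 + 4; ++k)
                acc.v[i][j] += a.v[i][k] * b.v[k][j];
}

// 4 warp-level steps along K; in each step the four 8x8 output tiles
// (assumed: one per quad-pair) are updated. Partial sums stay in acc.
constexpr Mat stitchedMul(const Mat& a, const Mat& b) {
    Mat acc{};                                 // starts at zero, like RZ in the SASS
    for (int k0 = 0; k0 < N; k0 += 4)          // 4 steps: k = 0-3, 4-7, 8-11, 12-15
        for (int i0 = 0; i0 < N; i0 += 8)      // 2 tile rows ...
            for (int j0 = 0; j0 < N; j0 += 8)  // ... x 2 tile columns
                mma884(acc, a, b, i0, j0, k0);
    return acc;
}

// Element-wise equality of two matrices.
constexpr bool equal(const Mat& x, const Mat& y) {
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j)
            if (x.v[i][j] != y.v[i][j]) return false;
    return true;
}
```

That makes sixteen 8x8x4 pieces in total, grouped as four steps of four. The result matches the plain loop because every output element still receives all 16 of its products, just delivered in K-slices of 4.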

You can get an idea by compiling the code and inspecting the SASS.

For example, take the code here, change the loop in the kernel to have a limit of 1 instead of 20, then compile the code for V100 (e.g. nvcc -arch=sm_70) and inspect the SASS with cuobjdump -sass. You will discover that the matrix multiply itself is handled, at least in part, by this sequence of SASS instructions:

        /*01d0*/                   HMMA.884.F16.F16.STEP0 R20, R4.reuse.ROW, R8.reuse.COL, RZ ;   /* 0x0000000804147236 */
                                                                                                  /* 0x0c1fe800000004ff */
        /*01e0*/                   HMMA.884.F16.F16.STEP1 R22, R4.ROW, R8.COL, RZ ;               /* 0x0000000804167236 */
                                                                                                  /* 0x000f6800000084ff */
        /*01f0*/                   HMMA.884.F16.F16.STEP0 R4, R6.reuse.ROW, R10.reuse.COL, R20 ;  /* 0x0000000a06047236 */
                                                                                                  /* 0x0e0fe80000000414 */
        /*0200*/                   HMMA.884.F16.F16.STEP1 R6, R6.ROW, R10.COL, R22 ;              /* 0x0000000a06067236 */
                                                                                                  /* 0x000f680000008416 */
        /*0210*/                   HMMA.884.F16.F16.STEP0 R4, R12.reuse.ROW, R16.reuse.COL, R4 ;  /* 0x000000100c047236 */
                                                                                                  /* 0x0e2fe80000000404 */
        /*0220*/                   HMMA.884.F16.F16.STEP1 R6, R12.ROW, R16.COL, R6 ;              /* 0x000000100c067236 */
                                                                                                  /* 0x000f680000008406 */
        /*0230*/                   HMMA.884.F16.F16.STEP0 R4, R14.reuse.ROW, R18.reuse.COL, R4 ;  /* 0x000000120e047236 */
                                                                                                  /* 0x0e0f680000000404 */
        /*0240*/                   HMMA.884.F16.F16.STEP1 R6, R14.ROW, R18.COL, R6 ;              /* 0x000000120e067236 */

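So each “basic” HMMA.884 operation appears to be issued as two SASS instructions (STEP0 and STEP1), and the overall 16x16 multiply appears to require 4 of these two-step sequences, 8 instructions total. Each sequence reads new A and B registers but uses the previous sequence's results as its accumulator input (RZ, i.e. zero, for the first one), which is the register carrying of partial sums mentioned above.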
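The 8-instruction count can be reproduced with simple tiling arithmetic. This sketch assumes, again from the PTX description of m8n8k4, that one STEP0/STEP1 pair covers a 16x16x4 slice at warp level (four quad-pairs, each 8x8x4). That shape, the MmaStep struct and instructionCount are illustrative assumptions, not anything exposed by CUDA.

```cpp
// Tile covered by one warp-level tensor core step, and how many SASS
// instructions that step is issued as. Illustrative model only.
struct MmaStep {
    int m, n, k;       // shape covered by one step
    int instructions;  // SASS instructions per step
};

// SASS instructions needed to cover an m x n x k MMA with the given step,
// assuming every dimension is a multiple of the step's shape.
constexpr int instructionCount(int m, int n, int k, MmaStep s) {
    return (m / s.m) * (n / s.n) * (k / s.k) * s.instructions;
}

// Assumption: on sm_70 with an F16 accumulator, one HMMA.884 STEP0/STEP1
// pair covers a 16x16x4 slice at warp level (four quad-pairs x 8x8x4).
constexpr MmaStep kVoltaF16{16, 16, 4, 2};
```

With that assumption, K=16 splits into 4 slices, each issued as 2 instructions, for 8 in total, which matches the listing above.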

I don’t plan to describe in detail how the component multiplies are assembled into a final result, but the process of performing a matrix multiply using sub-matrices and partial results is well described elsewhere [1] [2].

As an aside, the code I mentioned looks very different at the SASS level if you compile for a newer architecture such as sm_80 or sm_90. Those TC units do an m=16, n=8, k=16 (16816) multiply in a single SASS instruction, so only 2 instructions are needed for the full 16x16 multiply.
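For comparison, here is a sketch of the sm_80 decomposition. With an m16n8k16 instruction, M and K are already fully covered, so the 16x16 output simply splits into two 8-column halves. Two independent instructions compute those halves, and no accumulator is passed from one to the other. As before, this is plain C++, and mma16816 and splitNMul are hypothetical stand-ins for the hardware operation.

```cpp
// Plain C++ illustration of the sm_80 split: 16x16x16 = two 16x8x16 ops.
// The function names are hypothetical stand-ins, not a CUDA API.
constexpr int N = 16;

struct Mat { float v[N][N]; };

// Hypothetical test data: small integers, so float sums are exact.
constexpr Mat makeMat(int seed) {
    Mat m{};
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j)
            m.v[i][j] = static_cast<float>((i * 7 + j * 3 + seed) % 5 - 2);
    return m;
}

// Reference result: C = A * B with a plain triple loop.
constexpr Mat naiveMul(const Mat& a, const Mat& b) {
    Mat c{};
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j)
            for (int k = 0; k < N; ++k)
                c.v[i][j] += a.v[i][k] * b.v[k][j];
    return c;
}

// Stand-in for one m16n8k16 op: output columns n0..n0+7 get the full
// 16x16 A times B[0..15][n0..n0+7] added to them (all of K at once).
constexpr void mma16816(Mat& acc, const Mat& a, const Mat& b, int n0) {
    for (int i = 0; i < 16; ++i)
        for (int j = n0; j < n0 + 8; ++j)
            for (int k = 0; k < 16; ++k)
                acc.v[i][j] += a.v[i][k] * b.v[k][j];
}

// Two independent ops, one per 8-column half; neither reads the other's result.
constexpr Mat splitNMul(const Mat& a, const Mat& b) {
    Mat c{};
    mma16816(c, a, b, 0);  // columns 0-7
    mma16816(c, a, b, 8);  // columns 8-15
    return c;
}

// Element-wise equality of two matrices.
constexpr bool equal(const Mat& x, const Mat& y) {
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j)
            if (x.v[i][j] != y.v[i][j]) return false;
    return true;
}
```

In tiling terms that is (16/16) x (16/8) x (16/16) = 2 instructions, compared with the 8 chained instructions on sm_70.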
