How exactly does the Tensor Core compute this?

Hello~

When I read the section on matrix fragments for mma with Tensor Cores, it said: “A warp executing mma.m8n8k4 with .f16 floating point type will compute 4 MMA operations of shape .m8n8k4.”

First, I don't understand why it should be 4 MMA operations of shape .m8n8k4.

Second, in the MMA computation 1 part detailed in the document, the registers of thread T0 hold the data of row 0 of the A matrix, and T0 also holds the first 4 elements of the first row of the B matrix. According to the matrix calculation, a0 in T0 should only be multiplied by b0, b1, b2, b3 in T0, so do a1, a2, a3 not participate in computation 1? How exactly is MMA computation 1 calculated?

The TC unit of cc7.0 devices was designed to have that kind of hardware behavior, for at least one of the supported hardware paths. “Why is it that way?” questions can be difficult to answer. You will find with a bit of searching that the performance characteristics of that particular tensor op form vary significantly depending on GPU arch, so it seems clear to me that the GPU designers had different ideas as GPU development and TC development progressed.

The calculation performs four m8n8k4 matrix-matrix multiplies. There is only one correct behavior for that statement. Beyond that, I don’t know of any detailed specifications for TC unit behavior.

Here are some questions that may be of interest: 1 2 3

Thank you for your reply. But I think that if the documentation makes such statements, it needs to explain them clearly; otherwise the documentation does not make sense.

It seems to me that matrices A and B are loaded by all 4 quad pairs, which results in duplicated (and unnecessary) data loads.
Also, the 4 quad pairs seem to calculate exactly the same 8x8 result, which is also confusing. Is only one of the 4 quad pairs' results kept and the other three discarded?

The only guess I can come up with is that it might be hard to add “mask” logic at warp level for a Tensor Core instruction to keep 3/4 of the warp idle, so the design sacrifices efficiency for simplicity.

There was an academic (3rd-party) paper analyzing the m8n8k4 behavior, including what each of the 4 steps does.

Is it that relevant now? 7.0 Tensor Cores are rather outdated, and newer architectures support this matrix size in a slower compatibility mode.

It computes 4 separate results.

But based on the figure in 9.7.15.4.1. Matrix Fragments for mma.m8n8k4 with .f16 floating point type, each of the four MMA operations loads an 8x4 A matrix and a 4x8 B matrix. That is exactly the shape of source matrices A and B, so it looks to me as if each MMA operation loads the full A and B.
How can it calculate 4 separate parts if it loads the full (and same) A and B?

Those 4 mma operation register layouts correspond to 4 separate A and B matrices of the appropriate shape(s). They are not all referring to the same A and B matrices.

Thank you for your explanation. I wonder whether there is a figure illustrating the fragment-to-thread mapping?
I still can't get the mapping right:
(1) Source matrix A has shape 8x4.
(2) For MMA operation 1, each thread of a QP loads four elements, which together also form an 8x4 shape.
(3) For MMA operations 2/3/4, same as (2).

How does each QP load separate part of A and B?

Besides, if we calculate the FLOPs, an 8x8x4 operation requires 512 FMA, but calculating 4 of 8x8x4 needs 2048; where does the increased FLOPs come from?

I’ve read the document again and drew a figure to illustrate how the fragments map to each QP.
Let’s suppose A is row-major and B is column-major, A has shape 8x4, and B has shape 4x8.

The row and column of A matrix fragment can be computed as:

row =            %laneid % 4          if %laneid < 16
                (%laneid % 4) + 4     otherwise

col =            i                    for ai where i = {0,..,3}

The rule for B

The row and column of B matrix fragment can be computed as:

row =       i                 for bi   where i = {0,..,3}

col =      %laneid % 4        if %laneid < 16
          (%laneid % 4) + 4   otherwise

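Therefore, drawing this out in a figure: within each QP, the first quad (e.g. threads 0-3) holds the upper half of A (rows 0-3) and the left half of B (columns 0-3), while the second quad (e.g. threads 16-19) holds the lower half of A (rows 4-7) and the right half of B (columns 4-7). Together, each QP holds a complete A and a complete B.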

Perhaps this example will help. First, let’s point out that the PTX instruction (mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16) does not actually load data from global or shared memory; it works out of registers. The diagrams (e.g. Figures 22-26 in the PTX guide) represent the mapping of registers (in each thread) to elements of A, B, C.

It is as if the operation being performed is:

C1 = A1xB1  (handled by threads 0-3, and 16-19)
C2 = A2xB2  (handled by threads 4-7, and 20-23)
C3 = A3xB3  (handled by threads 8-11, and 24-27)
C4 = A4xB4  (handled by threads 12-15, and 28-31)

We will call the above items “computations” 1,2,3, or 4 (to be consistent with PTX nomenclature)

The following example demonstrates this:

# cat t261.cu
#include <cuda_fp16.h>
#include <iostream>
#include <mma.h>
using namespace nvcuda;

__global__ void mma_test(half* C1, half* C2, half *C3, half *C4)
{
    int lane = threadIdx.x % 32;
    uint out[4] = { 0 };

    {
        uint MultiA[2] = { 0 };
        uint MultiB[2] = { 0 };

        half* test1 = reinterpret_cast<half*>(MultiA);
        half* test2 = reinterpret_cast<half*>(MultiB);
        if ((lane < 4) || ((lane > 15) && (lane < 20))) { // row major matrix A1 from PTX figure 22
          test1[0] = 1.0; // you could have just as easily loaded these values from any location in global memory or shared memory
          test1[1] = 1.0;
          test1[2] = 1.0;
          test1[3] = 1.0;}
        if (((lane > 3) && (lane < 8)) || ((lane > 19) && (lane < 24))) { // row major matrix A2 from PTX figure 22
          test1[0] = 2.0;
          test1[1] = 2.0;
          test1[2] = 2.0;
          test1[3] = 2.0;}
        if (((lane > 7) && (lane < 12)) || ((lane > 23) && (lane < 28))) { // row major matrix A3 from PTX figure 22
          test1[0] = 3.0;
          test1[1] = 3.0;
          test1[2] = 3.0;
          test1[3] = 3.0;}
        if (((lane > 11) && (lane < 16)) || (lane > 27)) { // row major matrix A4 from PTX figure 22
          test1[0] = 4.0;
          test1[1] = 4.0;
          test1[2] = 4.0;
          test1[3] = 4.0;}

        // loading B1 - B4, keeping it simple - but the same structure as above could be used to load the B1-B4 "separately"
        test2[0] = 1.0;
        test2[1] = 1.0;
        test2[2] = 1.0;
        test2[3] = 1.0;

        asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 "
                     "{ %0, %1, %2, %3 },"
                     "{ %4, %5 },"
                     "{ %6, %7 },"
                     "{ %8, %9, %10, %11 };\n"
                     : "=r"(out[0]), "=r"(out[1]), "=r"(out[2]), "=r"(out[3])
                     : "r"(MultiA[0]), "r"(MultiA[1]),
                     "r"(MultiB[0]), "r"(MultiB[1]),
                     "r"(out[0]), "r"(out[1]), "r"(out[2]), "r"(out[3]));
    }
    if (lane < 4) { // C1 from PTX figure 26
      reinterpret_cast<uint4*>(C1)[lane] = reinterpret_cast<uint4*>(out)[0];}
    if ((lane > 15) && (lane < 20)) { // C1 from PTX figure 26
      reinterpret_cast<uint4*>(C1)[lane+4-16] = reinterpret_cast<uint4*>(out)[0];}
    if ((lane > 3) && (lane < 8)) { // C2 from PTX figure 26
      reinterpret_cast<uint4*>(C2)[lane-4] = reinterpret_cast<uint4*>(out)[0];}
    if ((lane > 19) && (lane < 24)) { // C2 from PTX figure 26
      reinterpret_cast<uint4*>(C2)[lane+4-20] = reinterpret_cast<uint4*>(out)[0];}
    if ((lane > 7) && (lane < 12)) { // C3 from PTX figure 26
      reinterpret_cast<uint4*>(C3)[lane-8] = reinterpret_cast<uint4*>(out)[0];}
    if ((lane > 23) && (lane < 28)) { // C3 from PTX figure 26
      reinterpret_cast<uint4*>(C3)[lane+4-24] = reinterpret_cast<uint4*>(out)[0];}
    if ((lane > 11) && (lane < 16)) { // C4 from PTX figure 26
      reinterpret_cast<uint4*>(C4)[lane-12] = reinterpret_cast<uint4*>(out)[0];}
    if (lane > 27) { // C4 from PTX figure 26
      reinterpret_cast<uint4*>(C4)[lane+4-28] = reinterpret_cast<uint4*>(out)[0];}
}

int main(int argc, char *argv[])
{
    half* h_C1 = (half*)malloc(sizeof(half) * 8 * 8);
    half* h_C2 = (half*)malloc(sizeof(half) * 8 * 8);
    half* h_C3 = (half*)malloc(sizeof(half) * 8 * 8);
    half* h_C4 = (half*)malloc(sizeof(half) * 8 * 8);
    half *d_C1, *d_C2, *d_C3, *d_C4;
    cudaMalloc(&d_C1, sizeof(half) * 8 * 8);
    cudaMalloc(&d_C2, sizeof(half) * 8 * 8);
    cudaMalloc(&d_C3, sizeof(half) * 8 * 8);
    cudaMalloc(&d_C4, sizeof(half) * 8 * 8);

    mma_test<<<1, 32>>>(d_C1, d_C2, d_C3, d_C4);

    cudaMemcpy(h_C1, d_C1, sizeof(half) * 8 * 8, cudaMemcpyDeviceToHost);
    cudaMemcpy(h_C2, d_C2, sizeof(half) * 8 * 8, cudaMemcpyDeviceToHost);
    cudaMemcpy(h_C3, d_C3, sizeof(half) * 8 * 8, cudaMemcpyDeviceToHost);
    cudaMemcpy(h_C4, d_C4, sizeof(half) * 8 * 8, cudaMemcpyDeviceToHost);
    std::cout << "C1: " << std::endl;
    for (int i = 0; i < 8*8; i++) {
        std::cout << __half2float(h_C1[i]) << " ";
    }
    std::cout << std::endl;
    std::cout << "C2: " << std::endl;
    for (int i = 0; i < 8*8; i++) {
        std::cout << __half2float(h_C2[i]) << " ";
    }
    std::cout << std::endl;
    std::cout << "C3: " << std::endl;
    for (int i = 0; i < 8*8; i++) {
        std::cout << __half2float(h_C3[i]) << " ";
    }
    std::cout << std::endl;
    std::cout << "C4: " << std::endl;
    for (int i = 0; i < 8*8; i++) {
        std::cout << __half2float(h_C4[i]) << " ";
    }
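    std::cout << std::endl;
    // release device and host buffers
    cudaFree(d_C1); cudaFree(d_C2); cudaFree(d_C3); cudaFree(d_C4);
    free(h_C1); free(h_C2); free(h_C3); free(h_C4);
}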
# nvcc -o t261 t261.cu -arch=sm_70
# compute-sanitizer ./t261
========= COMPUTE-SANITIZER
C1:
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
C2:
8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
C3:
12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12
C4:
16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16
========= ERROR SUMMARY: 0 errors
# cuobjdump -sass ./t261

Fatbin elf code:
================
arch = sm_70
code version = [1,7]
host = linux
compile_size = 64bit

        code for sm_70

Fatbin elf code:
================
arch = sm_70
code version = [1,7]
host = linux
compile_size = 64bit

        code for sm_70
                Function : _Z8mma_testP6__halfS0_S0_S0_
        .headerflags    @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM70 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM70)"
        /*0000*/                   IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ;                      /* 0x00000a00ff017624 */
                                                                                                /* 0x000fc400078e00ff */
        /*0010*/              @!PT SHFL.IDX PT, RZ, RZ, RZ, RZ ;                                /* 0x000000fffffff389 */
                                                                                                /* 0x000fe200000e00ff */
        /*0020*/                   S2R R0, SR_TID.X ;                                           /* 0x0000000000007919 */
                                                                                                /* 0x000e220000002100 */
        /*0030*/                   CS2R R6, SRZ ;                                               /* 0x0000000000067805 */
                                                                                                /* 0x000fe2000001ff00 */
        /*0040*/                   IMAD.MOV.U32 R3, RZ, RZ, 0x3c003c00 ;                        /* 0x3c003c00ff037424 */
                                                                                                /* 0x000fe400078e00ff */
        /*0050*/                   IMAD.MOV.U32 R13, RZ, RZ, 0x10 ;                             /* 0x00000010ff0d7424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0060*/                   LOP3.LUT R9, R0.reuse, 0x1c, RZ, 0xc0, !PT ;                 /* 0x0000001c00097812 */
                                                                                                /* 0x041fe400078ec0ff */
        /*0070*/                   LOP3.LUT R0, R0, 0x1f, RZ, 0xc0, !PT ;                       /* 0x0000001f00007812 */
                                                                                                /* 0x000fe400078ec0ff */
        /*0080*/                   LOP3.LUT R2, R9.reuse, 0x10, RZ, 0xfc, !PT ;                 /* 0x0000001009027812 */
                                                                                                /* 0x040fe400078efcff */
        /*0090*/                   ISETP.NE.AND P1, PT, R9, 0x10, PT ;                          /* 0x000000100900780c */
                                                                                                /* 0x000fc40003f25270 */
        /*00a0*/                   PRMT R2, R2, 0x9910, RZ ;                                    /* 0x0000991002027816 */
                                                                                                /* 0x000fe400000000ff */
        /*00b0*/                   ISETP.GT.U32.AND P2, PT, R0, 0x3, P1 ;                       /* 0x000000030000780c */
                                                                                                /* 0x000fe40000f44070 */
        /*00c0*/                   ISETP.NE.AND P4, PT, R2.reuse, 0x14, PT ;                    /* 0x000000140200780c */
                                                                                                /* 0x040fe40003f85270 */
        /*00d0*/                   ISETP.NE.AND P5, PT, R2, 0x18, PT ;                          /* 0x000000180200780c */
                                                                                                /* 0x000fe20003fa5270 */
        /*00e0*/                   IMAD.MOV.U32 R2, RZ, RZ, 0x3c003c00 ;                        /* 0x3c003c00ff027424 */
                                                                                                /* 0x000fe200078e00ff */
        /*00f0*/                   ISETP.GE.U32.AND P0, PT, R0.reuse, 0x1c, PT ;                /* 0x0000001c0000780c */
                                                                                                /* 0x040fe40003f06070 */
        /*0100*/                   ISETP.GT.U32.AND P6, PT, R0, 0x3, PT ;                       /* 0x000000030000780c */
                                                                                                /* 0x000fc40003fc4070 */
        /*0110*/                   ISETP.NE.AND P3, PT, R9.reuse, 0xc, !P0 ;                    /* 0x0000000c0900780c */
                                                                                                /* 0x040fe40004765270 */
        /*0120*/              @!P2 IMAD.MOV.U32 R6, RZ, RZ, 0x3c003c00 ;                        /* 0x3c003c00ff06a424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0130*/                   IADD3 R8, R0.reuse, -0x10, RZ ;                              /* 0xfffffff000087810 */
                                                                                                /* 0x040fe20007ffe0ff */
        /*0140*/              @!P2 IMAD.MOV.U32 R7, RZ, RZ, 0x3c003c00 ;                        /* 0x3c003c00ff07a424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0150*/                   ISETP.NE.AND P2, PT, R9.reuse, 0x4, PT ;                     /* 0x000000040900780c */
                                                                                                /* 0x040fe20003f45270 */
        /*0160*/              @!P4 IMAD.MOV.U32 R6, RZ, RZ, 0x40004000 ;                        /* 0x40004000ff06c424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0170*/                   IADD3 R10, R0.reuse, -0x14, RZ ;                             /* 0xffffffec000a7810 */
                                                                                                /* 0x040fe20007ffe0ff */
        /*0180*/              @!P4 IMAD.MOV.U32 R7, RZ, RZ, 0x40004000 ;                        /* 0x40004000ff07c424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0190*/                   IADD3 R12, R0, -0x18, RZ ;                                   /* 0xffffffe8000c7810 */
                                                                                                /* 0x000fe20007ffe0ff */
        /*01a0*/              @!P5 IMAD.MOV.U32 R6, RZ, RZ, 0x42004200 ;                        /* 0x42004200ff06d424 */
                                                                                                /* 0x000fe200078e00ff */
        /*01b0*/                   ISETP.NE.AND P4, PT, R9.reuse, 0xc, PT ;                     /* 0x0000000c0900780c */
                                                                                                /* 0x040fe20003f85270 */
        /*01c0*/              @!P5 IMAD.MOV.U32 R7, RZ, RZ, 0x42004200 ;                        /* 0x42004200ff07d424 */
                                                                                                /* 0x000fe200078e00ff */
        /*01d0*/                   ISETP.NE.AND P5, PT, R9, 0x8, PT ;                           /* 0x000000080900780c */
                                                                                                /* 0x000fe20003fa5270 */
        /*01e0*/              @!P3 IMAD.MOV.U32 R6, RZ, RZ, 0x44004400 ;                        /* 0x44004400ff06b424 */
                                                                                                /* 0x000fc400078e00ff */
        /*01f0*/              @!P3 IMAD.MOV.U32 R7, RZ, RZ, 0x44004400 ;                        /* 0x44004400ff07b424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0200*/                   ISETP.NE.AND P3, PT, R9, 0x14, PT ;                          /* 0x000000140900780c */
                                                                                                /* 0x000fe20003f65270 */
        /*0210*/                   IMAD.WIDE R10, R10, R13, c[0x0][0x170] ;                     /* 0x00005c000a0a7625 */
                                                                                                /* 0x000fc800078e020d */
        /*0220*/                   HMMA.884.F16.F16.STEP0 R4, R6.reuse.ROW, R2.reuse.COL, RZ ;  /* 0x0000000206047236 */
                                                                                                /* 0x0c0fe800000004ff */
        /*0230*/                   HMMA.884.F16.F16.STEP1 R6, R6.ROW, R2.COL, RZ ;              /* 0x0000000206067236 */
                                                                                                /* 0x000b6400000084ff */
        /*0240*/                   IMAD.WIDE.U32 R2, R0, R13, c[0x0][0x160] ;                   /* 0x0000580000027625 */
                                                                                                /* 0x020fd600078e000d */
        /*0250*/              @!P6 STG.E.128.SYS [R2], R4 ;                                     /* 0x000000040200e386 */
                                                                                                /* 0x0001e2000010ed00 */
        /*0260*/                   ISETP.NE.AND P6, PT, R9, 0x18, PT ;                          /* 0x000000180900780c */
                                                                                                /* 0x000fe20003fc5270 */
        /*0270*/                   IMAD.WIDE R8, R8, R13.reuse, c[0x0][0x168] ;                 /* 0x00005a0008087625 */
                                                                                                /* 0x080fe400078e020d */
        /*0280*/              @!P1 STG.E.128.SYS [R2+-0xc0], R4 ;                               /* 0xffff400402009386 */
                                                                                                /* 0x0001e4000010ed00 */
        /*0290*/                   IMAD.WIDE R12, R12, R13, c[0x0][0x178] ;                     /* 0x00005e000c0c7625 */
                                                                                                /* 0x000fc800078e020d */
        /*02a0*/              @!P2 STG.E.128.SYS [R8+0xc0], R4 ;                                /* 0x0000c0040800a386 */
                                                                                                /* 0x0001e8000010ed00 */
        /*02b0*/              @!P3 STG.E.128.SYS [R8], R4 ;                                     /* 0x000000040800b386 */
                                                                                                /* 0x0001e8000010ed00 */
        /*02c0*/              @!P5 STG.E.128.SYS [R10+0xc0], R4 ;                               /* 0x0000c0040a00d386 */
                                                                                                /* 0x0001e8000010ed00 */
        /*02d0*/              @!P6 STG.E.128.SYS [R10], R4 ;                                    /* 0x000000040a00e386 */
                                                                                                /* 0x0001e8000010ed00 */
        /*02e0*/              @!P4 STG.E.128.SYS [R12+0xc0], R4 ;                               /* 0x0000c0040c00c386 */
                                                                                                /* 0x0001e2000010ed00 */
        /*02f0*/              @!P0 EXIT ;                                                       /* 0x000000000000894d */
                                                                                                /* 0x000fea0003800000 */
        /*0300*/                   STG.E.128.SYS [R12], R4 ;                                    /* 0x000000040c007386 */
                                                                                                /* 0x000fe2000010ed00 */
        /*0310*/                   EXIT ;                                                       /* 0x000000000000794d */
                                                                                                /* 0x000fea0003800000 */
        /*0320*/                   BRA 0x320;                                                   /* 0xfffffff000007947 */
                                                                                                /* 0x000fc0000383ffff */
        /*0330*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0340*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0350*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0360*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0370*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
                ..........



Fatbin ptx code:
================
arch = sm_70
code version = [8,2]
host = linux
compile_size = 64bit
compressed
#

Observations:

  • we have an if-sequence that sets the registers for each thread. This if-sequence is broken down by computation number. For computation 1, we are loading the A1 matrix with values of 1.0. For computation 2, we are loading the A2 matrix with values of 2.0. Likewise for computations 3 and 4. For all computations, we are loading the Bn matrix with 1.0. Therefore we expect C1 to be all 4s, C2 all 8s, C3 all 12s, and C4 all 16s.
  • rather than setting register values directly from constants, we could have loaded the registers from any shared memory location or any global memory location. This means that computations 1 through 4 are working on independent data sources. This is most directly evident from the register footprint, but it is also true because we could have loaded these registers from anywhere.
  • to answer your question from this thread, we can see in the SASS dump output that the PTX mma instruction for m8n8k4 compiles to two SASS instructions (in the sm_70 case, anyway), one labelled STEP0 and the other labelled STEP1 (HMMA.884.F16.F16.STEP0 / HMMA.884.F16.F16.STEP1). Since it is now evident that these two SASS instructions together compute all 4 independent computations, we must conclude that they provide a sufficient number of FMA ops for all 4 computations. You could paste this code into godbolt to see the same thing.

I think you are getting confused between FLOPs and FMA. An FMA constitutes two FLOPs. For m8n8k4 matrix-multiply, the number of required FMA operations is 8x8x4 = 256. The number of required FLOPs is 512.

Since the 4 computations require 256 FMA (512 FLOPs) each, we must conclude that the two-instruction SASS sequence (STEP0/STEP1) provides the necessary work, i.e. 4x256 = 1024 FMA, or 4x512 = 2048 FLOPs. Beyond that, I’m not sure how to answer the question “where does the increased FLOPs come from?” The FLOPs are provided by the Tensor Core unit. Since the code provably produces the correct results, we must conclude that the two-instruction SASS sequence is enough to provide those FLOPs.
