How exactly does the Tensor Core compute this?

Perhaps this example will help. First, let’s point out that the PTX instruction (`mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16`) does not actually load data from global or shared memory. It works out of registers. The diagrams (e.g. Figures 22-26 in the PTX guide) represent the mapping of registers (in each thread) to elements of A, B, and C.

It is as if the operation being performed is:

C1 = A1xB1  (handled by threads 0-3, and 16-19)
C2 = A2xB2  (handled by threads 4-7, and 20-23)
C3 = A3xB3  (handled by threads 8-11, and 24-27)
C4 = A4xB4  (handled by threads 12-15, and 28-31)

We will call the above items “computations” 1, 2, 3, or 4 (to be consistent with PTX nomenclature).

The following example demonstrates this:

# cat t261.cu
#include <cuda_fp16.h>
#include <iostream>
#include <mma.h>
using namespace nvcuda;

// Demonstrates the register-level data layout of the warp-wide PTX instruction
//   mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16      (requires sm_70+)
//
// One warp performs four independent 8x8x4 products ("computations" 1-4):
//   Cn = An x Bn (+ a zero accumulator), where the lanes involved are
//     computation 1: lanes  0-3  and 16-19
//     computation 2: lanes  4-7  and 20-23
//     computation 3: lanes  8-11 and 24-27
//     computation 4: lanes 12-15 and 28-31
//   i.e. computation index (0-based) = (lane % 16) / 4.
//
// Every A element of computation n is n (1.0 .. 4.0) and every B element is
// 1.0, so each element of Cn is the sum of k = 4 products: 4*n.
//
// Launch: intended for a single warp (<<<1, 32>>>). Additional warps would
// repeat the same work and store identical values to the same addresses.
// Outputs: C1..C4 each receive 8x8 = 64 halves, row-major. Each lane stores one
// full 8-element row as a single 16-byte (uint4) store, so every Cn must be
// 16-byte aligned (cudaMalloc allocations satisfy this).
__global__ void mma_test(half* C1, half* C2, half *C3, half *C4)
{
    const int lane = threadIdx.x % 32;
    const int computation = (lane >> 2) & 3;   // 0..3, equals (lane % 16) / 4

    // A and B fragments: each is two .f16x2 registers (4 halves) per thread
    // (A-fragment layout per the PTX m8n8k4 figures). The values are built from
    // raw half bit patterns instead of writing through a reinterpret_cast'ed
    // half* into a uint array. You could just as easily load these values
    // from any location in global or shared memory.
    const unsigned a_bits =
        __half_as_ushort(__float2half(static_cast<float>(computation + 1)));
    const unsigned b_bits = __half_as_ushort(__float2half(1.0f));
    const unsigned a_pair = (a_bits << 16) | a_bits;   // two copies of the A value
    const unsigned b_pair = (b_bits << 16) | b_bits;   // two copies of 1.0

    // Accumulator C (zero) in, result D out: four .f16x2 registers = 8 halves.
    unsigned out[4] = { 0u, 0u, 0u, 0u };

    asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 "
                 "{ %0, %1, %2, %3 },"
                 "{ %4, %5 },"
                 "{ %6, %7 },"
                 "{ %8, %9, %10, %11 };\n"
                 : "=r"(out[0]), "=r"(out[1]), "=r"(out[2]), "=r"(out[3])
                 : "r"(a_pair), "r"(a_pair),
                   "r"(b_pair), "r"(b_pair),
                   "r"(out[0]), "r"(out[1]), "r"(out[2]), "r"(out[3]));

    // D layout (row-major .f16 accumulator): within each computation, lanes
    // from the low half-warp hold rows 0-3 and lanes from the high half-warp
    // hold rows 4-7, each lane owning one complete 8-element row.
    half* const C = (computation == 0) ? C1
                  : (computation == 1) ? C2
                  : (computation == 2) ? C3
                  :                      C4;
    const int row = ((lane >> 4) << 2) + (lane & 3);

    // make_uint4 avoids reinterpreting the (not necessarily 16-byte aligned)
    // local array `out` as a uint4.
    reinterpret_cast<uint4*>(C)[row] = make_uint4(out[0], out[1], out[2], out[3]);
}

// Host driver: allocates four 8x8 half matrices on the device, runs a single
// warp of mma_test, copies the results back and prints each matrix as a flat
// row-major list. Every CUDA runtime call is checked. A failure (e.g. running
// on a GPU older than sm_70) is reported on stderr with exit code 1, instead
// of printing uninitialised host memory.
int main(int argc, char *argv[])
{
    // Returns true (after printing a diagnostic) when `err` signals a failure.
    auto failed = [](cudaError_t err, const char* what) {
        if (err == cudaSuccess) return false;
        std::cerr << what << " failed: " << cudaGetErrorString(err) << std::endl;
        return true;
    };

    const int kElems = 8 * 8;                        // one 8x8 C matrix
    const size_t kBytes = sizeof(half) * kElems;
    half h_C[4][kElems];                             // host copies of C1..C4
    half* d_C[4] = { nullptr, nullptr, nullptr, nullptr };

    for (int m = 0; m < 4; ++m) {
        if (failed(cudaMalloc(&d_C[m], kBytes), "cudaMalloc")) return 1;
    }

    mma_test<<<1, 32>>>(d_C[0], d_C[1], d_C[2], d_C[3]);
    // Launch-configuration errors are only visible via cudaGetLastError();
    // execution errors surface at the (synchronizing) cudaMemcpy below.
    if (failed(cudaGetLastError(), "mma_test launch")) return 1;

    for (int m = 0; m < 4; ++m) {
        if (failed(cudaMemcpy(h_C[m], d_C[m], kBytes, cudaMemcpyDeviceToHost),
                   "cudaMemcpy")) return 1;
    }

    for (int m = 0; m < 4; ++m) {
        std::cout << "C" << (m + 1) << ": " << std::endl;
        for (int i = 0; i < kElems; i++) {
            std::cout << __half2float(h_C[m][i]) << " ";
        }
        std::cout << std::endl;
    }

    for (int m = 0; m < 4; ++m) {
        cudaFree(d_C[m]);
    }
    return 0;
}
# nvcc -o t261 t261.cu -arch=sm_70
# compute-sanitizer ./t261
========= COMPUTE-SANITIZER
C1:
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
C2:
8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
C3:
12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12
C4:
16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16
========= ERROR SUMMARY: 0 errors
# cuobjdump -sass ./t261

Fatbin elf code:
================
arch = sm_70
code version = [1,7]
host = linux
compile_size = 64bit

        code for sm_70

Fatbin elf code:
================
arch = sm_70
code version = [1,7]
host = linux
compile_size = 64bit

        code for sm_70
                Function : _Z8mma_testP6__halfS0_S0_S0_
        .headerflags    @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM70 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM70)"
        /*0000*/                   IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ;                      /* 0x00000a00ff017624 */
                                                                                                /* 0x000fc400078e00ff */
        /*0010*/              @!PT SHFL.IDX PT, RZ, RZ, RZ, RZ ;                                /* 0x000000fffffff389 */
                                                                                                /* 0x000fe200000e00ff */
        /*0020*/                   S2R R0, SR_TID.X ;                                           /* 0x0000000000007919 */
                                                                                                /* 0x000e220000002100 */
        /*0030*/                   CS2R R6, SRZ ;                                               /* 0x0000000000067805 */
                                                                                                /* 0x000fe2000001ff00 */
        /*0040*/                   IMAD.MOV.U32 R3, RZ, RZ, 0x3c003c00 ;                        /* 0x3c003c00ff037424 */
                                                                                                /* 0x000fe400078e00ff */
        /*0050*/                   IMAD.MOV.U32 R13, RZ, RZ, 0x10 ;                             /* 0x00000010ff0d7424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0060*/                   LOP3.LUT R9, R0.reuse, 0x1c, RZ, 0xc0, !PT ;                 /* 0x0000001c00097812 */
                                                                                                /* 0x041fe400078ec0ff */
        /*0070*/                   LOP3.LUT R0, R0, 0x1f, RZ, 0xc0, !PT ;                       /* 0x0000001f00007812 */
                                                                                                /* 0x000fe400078ec0ff */
        /*0080*/                   LOP3.LUT R2, R9.reuse, 0x10, RZ, 0xfc, !PT ;                 /* 0x0000001009027812 */
                                                                                                /* 0x040fe400078efcff */
        /*0090*/                   ISETP.NE.AND P1, PT, R9, 0x10, PT ;                          /* 0x000000100900780c */
                                                                                                /* 0x000fc40003f25270 */
        /*00a0*/                   PRMT R2, R2, 0x9910, RZ ;                                    /* 0x0000991002027816 */
                                                                                                /* 0x000fe400000000ff */
        /*00b0*/                   ISETP.GT.U32.AND P2, PT, R0, 0x3, P1 ;                       /* 0x000000030000780c */
                                                                                                /* 0x000fe40000f44070 */
        /*00c0*/                   ISETP.NE.AND P4, PT, R2.reuse, 0x14, PT ;                    /* 0x000000140200780c */
                                                                                                /* 0x040fe40003f85270 */
        /*00d0*/                   ISETP.NE.AND P5, PT, R2, 0x18, PT ;                          /* 0x000000180200780c */
                                                                                                /* 0x000fe20003fa5270 */
        /*00e0*/                   IMAD.MOV.U32 R2, RZ, RZ, 0x3c003c00 ;                        /* 0x3c003c00ff027424 */
                                                                                                /* 0x000fe200078e00ff */
        /*00f0*/                   ISETP.GE.U32.AND P0, PT, R0.reuse, 0x1c, PT ;                /* 0x0000001c0000780c */
                                                                                                /* 0x040fe40003f06070 */
        /*0100*/                   ISETP.GT.U32.AND P6, PT, R0, 0x3, PT ;                       /* 0x000000030000780c */
                                                                                                /* 0x000fc40003fc4070 */
        /*0110*/                   ISETP.NE.AND P3, PT, R9.reuse, 0xc, !P0 ;                    /* 0x0000000c0900780c */
                                                                                                /* 0x040fe40004765270 */
        /*0120*/              @!P2 IMAD.MOV.U32 R6, RZ, RZ, 0x3c003c00 ;                        /* 0x3c003c00ff06a424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0130*/                   IADD3 R8, R0.reuse, -0x10, RZ ;                              /* 0xfffffff000087810 */
                                                                                                /* 0x040fe20007ffe0ff */
        /*0140*/              @!P2 IMAD.MOV.U32 R7, RZ, RZ, 0x3c003c00 ;                        /* 0x3c003c00ff07a424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0150*/                   ISETP.NE.AND P2, PT, R9.reuse, 0x4, PT ;                     /* 0x000000040900780c */
                                                                                                /* 0x040fe20003f45270 */
        /*0160*/              @!P4 IMAD.MOV.U32 R6, RZ, RZ, 0x40004000 ;                        /* 0x40004000ff06c424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0170*/                   IADD3 R10, R0.reuse, -0x14, RZ ;                             /* 0xffffffec000a7810 */
                                                                                                /* 0x040fe20007ffe0ff */
        /*0180*/              @!P4 IMAD.MOV.U32 R7, RZ, RZ, 0x40004000 ;                        /* 0x40004000ff07c424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0190*/                   IADD3 R12, R0, -0x18, RZ ;                                   /* 0xffffffe8000c7810 */
                                                                                                /* 0x000fe20007ffe0ff */
        /*01a0*/              @!P5 IMAD.MOV.U32 R6, RZ, RZ, 0x42004200 ;                        /* 0x42004200ff06d424 */
                                                                                                /* 0x000fe200078e00ff */
        /*01b0*/                   ISETP.NE.AND P4, PT, R9.reuse, 0xc, PT ;                     /* 0x0000000c0900780c */
                                                                                                /* 0x040fe20003f85270 */
        /*01c0*/              @!P5 IMAD.MOV.U32 R7, RZ, RZ, 0x42004200 ;                        /* 0x42004200ff07d424 */
                                                                                                /* 0x000fe200078e00ff */
        /*01d0*/                   ISETP.NE.AND P5, PT, R9, 0x8, PT ;                           /* 0x000000080900780c */
                                                                                                /* 0x000fe20003fa5270 */
        /*01e0*/              @!P3 IMAD.MOV.U32 R6, RZ, RZ, 0x44004400 ;                        /* 0x44004400ff06b424 */
                                                                                                /* 0x000fc400078e00ff */
        /*01f0*/              @!P3 IMAD.MOV.U32 R7, RZ, RZ, 0x44004400 ;                        /* 0x44004400ff07b424 */
                                                                                                /* 0x000fe200078e00ff */
        /*0200*/                   ISETP.NE.AND P3, PT, R9, 0x14, PT ;                          /* 0x000000140900780c */
                                                                                                /* 0x000fe20003f65270 */
        /*0210*/                   IMAD.WIDE R10, R10, R13, c[0x0][0x170] ;                     /* 0x00005c000a0a7625 */
                                                                                                /* 0x000fc800078e020d */
        /*0220*/                   HMMA.884.F16.F16.STEP0 R4, R6.reuse.ROW, R2.reuse.COL, RZ ;  /* 0x0000000206047236 */
                                                                                                /* 0x0c0fe800000004ff */
        /*0230*/                   HMMA.884.F16.F16.STEP1 R6, R6.ROW, R2.COL, RZ ;              /* 0x0000000206067236 */
                                                                                                /* 0x000b6400000084ff */
        /*0240*/                   IMAD.WIDE.U32 R2, R0, R13, c[0x0][0x160] ;                   /* 0x0000580000027625 */
                                                                                                /* 0x020fd600078e000d */
        /*0250*/              @!P6 STG.E.128.SYS [R2], R4 ;                                     /* 0x000000040200e386 */
                                                                                                /* 0x0001e2000010ed00 */
        /*0260*/                   ISETP.NE.AND P6, PT, R9, 0x18, PT ;                          /* 0x000000180900780c */
                                                                                                /* 0x000fe20003fc5270 */
        /*0270*/                   IMAD.WIDE R8, R8, R13.reuse, c[0x0][0x168] ;                 /* 0x00005a0008087625 */
                                                                                                /* 0x080fe400078e020d */
        /*0280*/              @!P1 STG.E.128.SYS [R2+-0xc0], R4 ;                               /* 0xffff400402009386 */
                                                                                                /* 0x0001e4000010ed00 */
        /*0290*/                   IMAD.WIDE R12, R12, R13, c[0x0][0x178] ;                     /* 0x00005e000c0c7625 */
                                                                                                /* 0x000fc800078e020d */
        /*02a0*/              @!P2 STG.E.128.SYS [R8+0xc0], R4 ;                                /* 0x0000c0040800a386 */
                                                                                                /* 0x0001e8000010ed00 */
        /*02b0*/              @!P3 STG.E.128.SYS [R8], R4 ;                                     /* 0x000000040800b386 */
                                                                                                /* 0x0001e8000010ed00 */
        /*02c0*/              @!P5 STG.E.128.SYS [R10+0xc0], R4 ;                               /* 0x0000c0040a00d386 */
                                                                                                /* 0x0001e8000010ed00 */
        /*02d0*/              @!P6 STG.E.128.SYS [R10], R4 ;                                    /* 0x000000040a00e386 */
                                                                                                /* 0x0001e8000010ed00 */
        /*02e0*/              @!P4 STG.E.128.SYS [R12+0xc0], R4 ;                               /* 0x0000c0040c00c386 */
                                                                                                /* 0x0001e2000010ed00 */
        /*02f0*/              @!P0 EXIT ;                                                       /* 0x000000000000894d */
                                                                                                /* 0x000fea0003800000 */
        /*0300*/                   STG.E.128.SYS [R12], R4 ;                                    /* 0x000000040c007386 */
                                                                                                /* 0x000fe2000010ed00 */
        /*0310*/                   EXIT ;                                                       /* 0x000000000000794d */
                                                                                                /* 0x000fea0003800000 */
        /*0320*/                   BRA 0x320;                                                   /* 0xfffffff000007947 */
                                                                                                /* 0x000fc0000383ffff */
        /*0330*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0340*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0350*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0360*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0370*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
                ..........



Fatbin ptx code:
================
arch = sm_70
code version = [8,2]
host = linux
compile_size = 64bit
compressed
#

Observations:

  • we have an if-sequence that loads the registers for each thread. This if-sequence is broken down by computation number. For computation 1, we are loading the A1 matrix with values of 1.0. For computation 2, we are loading the A2 matrix with values of 2.0. Likewise for computations 3 and 4. For all computations, we are loading the Bn matrix with 1.0. Each output element is the sum of k = 4 products. Therefore we expect every element of C1 to be 4, every element of C2 to be 8, every element of C3 to be 12, and every element of C4 to be 16.
  • rather than loading register values directly, we could have loaded the registers from any shared memory location or any global memory location. This means that the computations 1 through 4 are working on independent data sources. This is true most directly when we look at the register footprint, but also true when we consider that we could have loaded these registers from anywhere.
  • to answer your question from this thread, we can see in the SASS dump output that the PTX mma instruction for m8n8k4 compiles to two SASS instructions (in the sm_70 case, anyway): HMMA.884.F16.F16.STEP0 and HMMA.884.F16.F16.STEP1. Since it is evident (now) that these two SASS instructions are somehow computing all 4 independent computations, we must conclude that a sufficient number of FMA ops are provided between those 2 SASS instructions, for all 4 computations. You could paste this code into godbolt to witness the same thing.

I think you are getting confused between FLOPs and FMA. An FMA constitutes two FLOPs. For m8n8k4 matrix-multiply, the number of required FMA operations is 8x8x4 = 256. The number of required FLOPs is 512.

Since the 4 computations require 256 FMA or 512 FLOPs each, we must conclude that the two-instruction SASS sequence (STEP0/STEP1) must provide the necessary FLOPs, i.e. 4x512 = 2048. Beyond that, I’m not sure how to answer the question “where does the increased FLOPs come from?” The FLOPs are provided by the Tensor Core unit. Since the code provably creates the correct results, we must conclude that the two-instruction SASS sequence is enough to provide those FLOPs.

3 Likes