How many tensor cores are used to execute the wmma.mma.sync.aligned.{alayout}.{blayout}.m16n16k16 instruction?

Hi NVIDIA expert:

How many tensor cores are used to execute the wmma.mma.sync.aligned.{alayout}.{blayout}.m16n16k16 instruction on Volta and on other architectures?

The SASS instructions are as follows:

// Volta

/*01a0*/ HMMA.884.F16.F16.STEP0 R20, R12.reuse.ROW, R16.reuse.COL, RZ ; /* 0x000000100c147236 */
/*01b0*/ HMMA.884.F16.F16.STEP1 R22, R12.ROW, R16.COL, RZ ; /* 0x000000100c167236 */
/*01c0*/ HMMA.884.F16.F16.STEP0 R12, R14.reuse.ROW, R18.reuse.COL, R20 ; /* 0x000000120e0c7236 */
/*01d0*/ HMMA.884.F16.F16.STEP1 R14, R14.ROW, R18.COL, R22 ; /* 0x000000120e0e7236 */
/*01e0*/ HMMA.884.F16.F16.STEP0 R12, R4.reuse.ROW, R8.reuse.COL, R12 ; /* 0x00000008040c7236 */
/*01f0*/ HMMA.884.F16.F16.STEP1 R14, R4.ROW, R8.COL, R14 ; /* 0x00000008040e7236 */
/*0210*/ HMMA.884.F16.F16.STEP0 R12, R6.reuse.ROW, R10.reuse.COL, R12 ; /* 0x0000000a060c7236 */
/*0230*/ HMMA.884.F16.F16.STEP1 R14, R6.ROW, R10.COL, R14 ;

//Ampere
/*0130*/ HMMA.16816.F16 R12, R4.reuse, R12, RZ ; /* 0x0000000c040c723c */
/*0140*/ HMMA.16816.F16 R14, R4, R14, RZ ; /* 0x0000000e040e723c */

//Hopper
/*0160*/ HMMA.16816.F16 R12, R4, R12, RZ ; /* 0x0000000c040c723c */
/*0170*/ HMMA.16816.F16 R14, R4, R14, RZ ;

The wmma shape is 16x16x16 and the hmma shape is 8x8x4 or 16x8x16, but the tensor core only supports 4x4x4, 8x4x8 or 8x4x16. So I’m confused: how many tensor cores are needed to execute the wmma instruction?

Each TC (tensor core) SASS instruction will require a TC unit to operate on, per warp. The number of TC units per SM (and their throughput) varies by GPU architecture.

So when an HMMA instruction is issued warp-wide, it will occupy a TC unit for that issue slot/cycle. On subsequent cycles, another HMMA instruction can be issued to that unit, whether from the same warp in the case of independent instructions, or from a different warp assigned to that SM, based on the distribution of warps to warp schedulers/subpartitions.

Some per-architecture accounting can be found here.

To a first-order approximation, the throughput of a particular op (in ops/s) will be the number of flops per clock for that architecture’s TC unit, multiplied by the total number of TC units in the GPU and by the GPU clock rate, divided by the number of flops per op for that instruction.

I use the following code to verify the conclusion that “only one tensor core executes the wmma instruction” on H100.

#include <cstdio>
#include <cuda.h>
#include <cuda_runtime.h>
#include <mma.h>      
#include <nvml.h>     

using namespace nvcuda;

#ifndef NWARP
#define NWARP 1          
#endif

#ifndef ITERS
#define ITERS 200000     
#endif

constexpr int M = 16;
constexpr int N = 16;
constexpr int K = 16;

unsigned int get_current_sm_clock_MHz(int device_id)
{
    nvmlReturn_t result;
    nvmlDevice_t device;

    result = nvmlInit();
    if (result != NVML_SUCCESS) {
        std::printf("NVML init failed: %s\n", nvmlErrorString(result));
        return 0;
    }

    result = nvmlDeviceGetHandleByIndex(device_id, &device);
    if (result != NVML_SUCCESS) {
        std::printf("nvmlDeviceGetHandleByIndex failed: %s\n", nvmlErrorString(result));
        nvmlShutdown();
        return 0;
    }

    unsigned int sm_clock;
    result = nvmlDeviceGetClockInfo(device, NVML_CLOCK_SM, &sm_clock);
    if (result != NVML_SUCCESS) {
        std::printf("nvmlDeviceGetClockInfo failed: %s\n", nvmlErrorString(result));
        nvmlShutdown();
        return 0;
    }

    nvmlShutdown();
    return sm_clock;  // MHz
}

// =========================================================
// Kernel
// =========================================================
__global__ void wmma_tc_bench(unsigned long long* timing_out, float* sink)
{
    int warp_id = threadIdx.x / warpSize;
    int lane_id = threadIdx.x % warpSize;

    if (warp_id >= NWARP) return;

    using fragA_t = wmma::fragment<wmma::matrix_a, M, N, K, half,  wmma::row_major>;
    using fragB_t = wmma::fragment<wmma::matrix_b, M, N, K, half,  wmma::col_major>;
    using fragC_t = wmma::fragment<wmma::accumulator, M, N, K, float>;

    fragA_t a;
    fragB_t b;
    fragC_t c;

    wmma::fill_fragment(c, 0.0f);
    for (int i = 0; i < a.num_elements; ++i) a.x[i] = __float2half(1.0f);
    for (int i = 0; i < b.num_elements; ++i) b.x[i] = __float2half(1.0f);

    __syncthreads();

    unsigned long long start = 0, end = 0;
    if (lane_id == 0)
        start = clock64();

    #pragma unroll 1
    for (int it = 0; it < ITERS; ++it) {
        wmma::mma_sync(c, a, b, c);
    }

    if (lane_id == 0) {
        float acc = 0.f;
        for (int i = 0; i < c.num_elements; ++i)
            acc += c.x[i];
        sink[warp_id] = acc;

        end = clock64();
        timing_out[warp_id] = end - start;
    }
}

// =========================================================
// Host
// =========================================================
int main()
{
    int gpu_id = 0;
    cudaGetDevice(&gpu_id);
    std::printf("Using GPU %d\n", gpu_id);

    unsigned int sm_clock_MHz = get_current_sm_clock_MHz(gpu_id);
    if (sm_clock_MHz == 0) {
        std::printf("Failed to get SM clock from NVML.\n");
        return -1;
    }
    std::printf("SM Clock = %u MHz\n", sm_clock_MHz);

    unsigned long long* d_timing = nullptr;
    float* d_sink = nullptr;

    cudaMalloc(&d_timing, sizeof(unsigned long long) * NWARP);
    cudaMalloc(&d_sink, sizeof(float) * NWARP);

    dim3 block_dim(NWARP * 32);
    dim3 grid_dim(1);

    wmma_tc_bench<<<grid_dim, block_dim>>>(d_timing, d_sink);
    cudaDeviceSynchronize();

    unsigned long long h_timing[NWARP];
    float h_sink[NWARP];

    cudaMemcpy(h_timing, d_timing,
               sizeof(unsigned long long) * NWARP,
               cudaMemcpyDeviceToHost);
    cudaMemcpy(h_sink, d_sink,
               sizeof(float) * NWARP,
               cudaMemcpyDeviceToHost);

    cudaFree(d_timing);
    cudaFree(d_sink);

    std::printf("===============================================\n");
    for (int w = 0; w < NWARP; ++w) {
        std::printf("Warp %d cycles = %llu, sink = %f\n", w, h_timing[w], h_sink[w]);
    }
    std::printf("===============================================\n");

    double sm_freq_Hz = static_cast<double>(sm_clock_MHz) * 1e6;

    for (int w = 0; w < NWARP; ++w) {
        double cycles  = static_cast<double>(h_timing[w]);
        double seconds = cycles / sm_freq_Hz;

        double fma_per_mma = static_cast<double>(M) * N * K; // 4096
        double total_fma   = fma_per_mma * ITERS;

        double tfma_s = total_fma / seconds / 1e12;
        // TFLOPS: 1 FMA = 2 FLOPs
        double tflops = 2.0 * tfma_s;

        std::printf("Warp %d Throughput:\n", w);
        std::printf("  TFMA/s = %.3f\n", tfma_s);
        std::printf("  TFLOPS = %.3f\n", tflops);
    }

    std::printf("===============================================\n");
    return 0;
}

and the result is:

Using GPU 0
SM Clock = 1590 MHz
===============================================
Warp 0 cycles = 6600835, sink = 25600000.000000
===============================================
Warp 0 Throughput:
  TFMA/s = 0.197
  TFLOPS = 0.395
===============================================

But the theoretical value of TFMA/s for each TC is:

(8×4×16 = 512 FMA ops/clk, which one H100 TC unit delivers) × (1590 × 10^6 Hz) / 10^12 = 0.81408 TFMA/s

This is greater than the measured 0.197. What is going on here?

I mentioned SASS in my response. To find out “What is going on here?” I believe it would be instructive to study the SASS; that is what I would do.

First, you may discover there is not a 1:1 correspondence between the CUDA C++ intrinsic wmma::mma_sync() and SASS HMMA instructions. Second, you will (I think) discover that you don’t really have a significantly long sequence of back-to-back HMMA instructions, which would be needed to get close to the peak theoretical throughput. And #pragma unroll 1 is not helping your case there, either. I don’t know why you would choose that for benchmarking.

Where does the 8x4x16 come from? In your source code you set the size to 16x16x16.

You could use Nsight Compute to show the bottleneck.

To be precise, the H100 TC delivers 512 FMA ops/clk (the H100 whitepaper says: “The new fourth-generation Tensor Core architecture in H100 delivers double the raw dense and sparse matrix math throughput per SM, clock-for-clock, compared to A100”).

Where does this come from?

On all architectures since Ampere there is one TC per SM Partition. So a single warp will use one tensor core.

You cannot necessarily feed one SASS instruction per cycle. It depends on matrix size, data type and architecture.

Your 512 FMA for 16-bit floats per tensor core on H100 is the published value.

I would test with many warps in one block and assume 4 tensor cores (as the warps are distributed across the 4 partitions within the SM). You’d better hide all kinds of latencies, compared to testing with a single warp.

You need to read the whitepaper again carefully.

Sorry, you mean the HMMA SASS instruction will require a TC unit. If I want to verify this conclusion, would you teach me how to design an effective program for that?

When the WMMA instruction is compiled into SASS, it contains two HMMA instructions, so it requires two Tensor Cores (TCs). However, one HMMA instruction involves 2048 floating-point FMA operations, while a single Tensor Core only delivers 512 FMA operations per clock on H100. Would you like to tell me how this works?

You are talking about the physical TC with 512 = 8×4×16 FMA ops/cycle/TC. Okay, understood; I just would have used the 512 directly.

Only one TC has access to the registers of any given warp.

To really prove it, you would have to show which warps interfere with each other and which are independent, and you would find out there are 4 distinct groups for the 4 partitions on one SM.

The 4096 (not 2048) ops are done by feeding the one TC over 8 cycles. A 16x16x16 multiplication can be mathematically split into 8 independent 8x4x16 matrix-matrix multiplications. (Okay, not fully independent; some input data is reused.)

The various dissecting papers measured the timing of the tensor core instructions. There you can see how many cycles the various TC instructions need, including the latency and throughput of the TC pipelines.

But as Robert_Crovella said, each HMMA SASS instruction requires a TC unit, and the WMMA instruction is decomposed into 2 HMMA instructions when compiled to SASS. So the WMMA 16x16x16 instruction needs 2 TC units, is this right?

One TC unit can sequentially execute several HMMA instructions.

Since Ampere there is only one TC unit per SM partition (and 4 partitions per SM). Each warp is resident on one of the partitions.

We are not in FPGA country, where each instruction is mapped to physical circuits, but in processor-with-software country, where the program is executed iteratively and suitable execution units are chosen for each instruction. When they are finished, the next instruction can use them.

The instructions are the SASS instructions (or even smaller units, as you mentioned 4096 vs. 512 FMA per cycle), not the WMMA instruction:

The matrix-matrix multiplication does not have to be computed as one operation. Mathematically it comprises 16x16 = 256 independent vector-vector (dot-product) multiplications, one such vector multiplication for each result element.

There is possibly a huge amount of information to cover here; I won’t be able to cover it all here; some research may be necessary.

Furthermore, I don’t believe NVIDIA documents TC specifics to the level that might be needed to answer every possible question.

The GPU is broken into SMs. In modern GPUs, each SM has multiple subpartitions (SMSPs). Warps are statically distributed to the SMSPs, and each SMSP has a collection of execution resources, including functional units, warp schedulers, register file, etc.

As already discussed, the CUDA C++ intrinsic TC op here (wmma::mma_sync...) seems to be decomposed by the compiler into two HMMA SASS instructions. Those instructions appear one after the other in the SASS code, and therefore in the instruction stream.

The warp schedulers in a modern SMSP are not dual-issue capable. Therefore in a single clock cycle they can only issue a single instruction. Furthermore, this is not an out-of-order machine; only the “topmost” instruction in the stream is eligible to be issued, in any particular cycle.

As a mental model, all functional units in a SM are pipelined. They can all accept a new instruction of a particular type/category serviced by that functional unit type, in any given clock cycle. (Yes, there are exceptions to this. It’s a mental model, and it is widely true. Not universally true. For example, some combinations of GPU/functional unit/instruction type may only be able to be issued once every other clock cycle.)

So when we have two HMMA instructions at different points in the instruction stream, it stands to reason that they do not necessarily require 2 TC units in order to issue. They must be issued in separate clock cycles, and based on the pipeline description already given, in those separate clock cycles, 2 HMMA instructions could be issued to the same TC unit. One gets issued in one clock cycle, and another gets issued in some later subsequent cycle.

And I am fairly convinced that is what happens here. I haven’t studied all the extant reverse-engineering papers, and I’m not suggesting that I can quote chapter and verse of NVIDIA documentation that explains exactly the mechanics by which TC instructions may be issued to TC functional units. If someone wants to correct me, please do so.

Your existing approach seems a reasonable path to me. I don’t have other suggestions. I have not studied your code in great detail, nor have I thought about this question in great detail. But with a back-to-back sequence of HMMA instructions, I would expect the dispatch rate for that sequence to be consistent with a calculation along the lines you have shown. I haven’t tried to verify your calculations, either. I would encourage you to take note of the two comments I gave already. In particular, although it comes pretty close, you have not achieved a long back-to-back sequence of HMMA instructions.

I haven’t studied the breakdown of flops to HMMA instructions carefully. If the TC unit can deliver 512 FMA ops/clk, and the given HMMA instruction (or the governing C++ intrinsic) would imply greater than 512 FMA ops per instruction, then clearly those cannot be issued back-to-back. I haven’t worked through all those calculations, but this could be a factor that impacts the observed “issue rate” which is what your code seems to determine. For example if the observed issue rate is lower by a factor of 2, but something consistent with peak throughput is delivered, then the instruction cannot be issued back-to-back. AFAIK these sorts of considerations are undocumented by NVIDIA.

Thank you for your detailed response, it has been very helpful. Thanks again.

Thanks for your help

From the SMSP perspective:
On V100, an HMMA.884.F16.STEP[0..1] instruction can be issued every 4 cycles. Each step instruction performs 512 multiply-add operations. An HMMA.884.F32.STEP[0..3] instruction can be issued every 2 cycles, but each step does half the work.

On A100, a HMMA.16816 instruction (which does 2048 mul-add operations) can be issued every 8 cycles, i.e. for the next seven cycles the tensor core pipeline will not accept another instruction. You can confirm this by running the following fragment (in CuAssembler notation):

      [B------:R-:W-:-:S01]  CS2R.32 R4, SR_CLOCKLO ;           // R4 = t0
      [B------:R-:W-:-:S01]  HMMA.16816.F16 R44, R32, R40, RZ ;       //issued at t0 + 1
      [B------:R-:W-:-:S01]  CS2R.32 R5, SR_CLOCKLO ;           // R5 = t0 + 2
      [B------:R-:W-:-:S01]  HMMA.16816.F16 R46, R32, R42, RZ ;       //issued at t0 + 9
      [B------:R-:W-:-:S01]  CS2R.32 R6, SR_CLOCKLO ;           // R6 = t0 + 10

Despite the stall cycles being set to one (S01 in the control word), the second HMMA gets delayed until 8 cycles after the first, not two (with the CS2R issued in between).
You can issue other instructions between the two HMMAs; however, it’s important to realise that each HMMA needs to read up to 10 registers (4 for A, 2 for B, 4 for C), and the register file can only sustain two reads per cycle (one even, one odd), which doesn’t leave much for other instructions. That’s why good .reuse is important!

On Hopper you can’t achieve peak TC performance with HMMA. You need to use HGMMA (PTX: wgmma) for that (issued across the whole warpgroup), where either B or both A and B are sourced from shared memory.


Very good post!

Shouldn’t the numbers in the first paragraph be different? 8x8x4 = 256, not 1024? So with 2 steps, 128 FMA operations per step, and with 4 steps half of that = 64?

On the other hand, a Volta TC can do 64 FMA per cycle, or 256 in 4 cycles, and there are (on Volta and Turing) 2 TCs per SM partition. If both are fed by a single SASS instruction, one would get 512 FMA in 4 cycles; if only one is fed, each would need 8 cycles for 512 FMA and they would work in an interleaved way.

No: unintuitively, HMMA.884 does four 8x8x4 computations, so it’s really 4 × (8×8×4) = 1024.
