Why is the performance of the TF32 tensor cores so poor?

Hello, NVIDIA experts:

I want to test the performance of the TF32 tensor cores, so I created two tests with cuBLAS and enabled TF32 through this cuBLAS call:

cublasSetMathMode(blas_handle, CUBLAS_TF32_TENSOR_OP_MATH)

I tested it on two types of CUDA GPU: an RTX A4000 and an RTX 3090.

First, let's look at the parameters of the RTX A4000:

Device 0: "NVIDIA RTX A4000"
  CUDA Driver Version / Runtime Version          11.6 / 11.3
  CUDA Capability Major/Minor version number:    8.6
  Total amount of global memory:                 16109 MBytes (16891379712 bytes)
  (48) Multiprocessors, (128) CUDA Cores/MP:     6144 CUDA Cores
  GPU Max Clock rate:                            1560 MHz (1.56 GHz)
  Memory Clock rate:                             7001 Mhz
  Memory Bus Width:                              256-bit
  L2 Cache Size:                                 4194304 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(131072), 2D=(131072, 65536), 3D=(16384, 16384, 16384)
  Maximum Layered 1D Texture Size, (num) layers  1D=(32768), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(32768, 32768), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total shared memory per multiprocessor:        102400 bytes
  Total number of registers available per block: 65536

And here is the test result:

Second, let's look at the parameters of the RTX 3090:

Device 0: "NVIDIA GeForce RTX 3090"
  CUDA Driver Version / Runtime Version          11.6 / 11.3
  CUDA Capability Major/Minor version number:    8.6
  Total amount of global memory:                 24268 MBytes (25447170048 bytes)
  (82) Multiprocessors, (128) CUDA Cores/MP:     10496 CUDA Cores
  GPU Max Clock rate:                            1695 MHz (1.70 GHz)
  Memory Clock rate:                             9751 Mhz
  Memory Bus Width:                              384-bit
  L2 Cache Size:                                 6291456 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(131072), 2D=(131072, 65536), 3D=(16384, 16384, 16384)
  Maximum Layered 1D Texture Size, (num) layers  1D=(32768), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(32768, 32768), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total shared memory per multiprocessor:        102400 bytes
  Total number of registers available per block: 65536
  Warp size:                                     32

And here is the test result:

An important assumption about the TF32 tensor cores in Ampere:
each SM holds 4 tensor cores, and each tensor core can execute 128 TF32 FMAs per cycle.
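
Under this assumption, the implied peak would be SMs x tensor cores per SM x FMAs per tensor core x 2 FLOPs x clock. A quick sketch of that arithmetic (the per-tensor-core rate is exactly the number I am unsure about):

#include <stdio.h>

// Peak TFLOPS implied by the assumption above: sms * tc/SM * FMA/tc * 2 FLOPs/FMA * clock
static double peak_tflops(int sms, int tc_per_sm, int fma_per_tc, double clock_ghz)
{
    return sms * tc_per_sm * (double)fma_per_tc * 2.0 * clock_ghz / 1000.0; // TFLOPS
}

int main(void)
{
    printf("RTX A4000: %.1f TFLOPS\n", peak_tflops(48, 4, 128, 1.560)); // ~76.7
    printf("RTX 3090 : %.1f TFLOPS\n", peak_tflops(82, 4, 128, 1.695)); // ~142.3
    return 0;
}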

Based on the tests above and this assumption, I observe the following:
1). In terms of CUDA-core utilization, the RTX 3090 is better than the RTX A4000;
2). In terms of tensor-core utilization, the RTX A4000 is better than the RTX 3090;
3). Both the RTX 3090 and the RTX A4000 do poorly on the tensor cores.

So, I have the following questions:
1). Is my assumption about the TF32 tensor cores correct?
2). Why is the RTX 3090 better than the RTX A4000 on the CUDA cores, while at the same time the RTX A4000 is better than the RTX 3090 on the tensor cores?
3). If my assumption is correct, why are both the RTX 3090 and the RTX A4000 poor on the tensor cores?

Would anyone be willing to share the secret with me?

What did you actually time, and how did you compute the GFLOPS rate? I went over the post twice, but that information does not seem to be there. If you post minimal reproducer code, we could all be on the same page.

Generally speaking, given the amount of parallelism available in these high-end GPUs, I am pretty sure you want matrices bigger than 8K x 8K if the goal is to demonstrate maximum computational throughput. 8K x 8K is what we used to benchmark high-end GPUs in 2014 when I retired from NVIDIA. You might want to keep doubling the dimensions of the matrices until the GFLOPS rate no longer increases.

If this exercise uses SGEMM calls, I wonder whether it could become partially limited by the memory hierarchy if tensor cores are used. You may want to use the CUDA profiler to double check that the performance of the code is largely bound by computational throughput.

I found that the performance of my GPU is very poor in a real application,
so I built some basic benchmarks to narrow down the underlying problem.
Yes, I have posted my minimal reproducer:
base_test.zip (4.5 KB)

I just modify a few key parameters in main.cu to compute the expected GFLOPS, like this:

  TargetMacModel model(128, 48, 1560, 7001, 256);
  TargetMacModel::KERNEL_COST cost;

128 is cuda_cores_per_sm
48 is the number of SMs
1560 is the GPU clock (MHz)
7001 is the memory clock (MHz)
256 is the memory bus width (256-bit)

The implementation of TargetMacModel is in this file: ./common/mac_utils.h

I compute the achieved GFLOPS through this call:

model.cacl_gemm_mac_usage(M, N, K, ((double)elapsedTimeInMs / (double)loop) * 1000000.0, cost);
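
mac_utils.h is not reproduced here, but roughly the model computes achieved versus peak GFLOPS along these lines (a hypothetical sketch with illustrative names, not the actual implementation):

#include <cstdio>

struct GemmUsage {
    double achieved_gflops;
    double peak_gflops;
    double usage;   // achieved / peak
};

// M, N, K: GEMM dimensions; elapsed_us: kernel time in microseconds;
// cores_per_sm, sm_count, clock_mhz: the constructor parameters above.
static GemmUsage gemm_mac_usage(int M, int N, int K, double elapsed_us,
                                int cores_per_sm, int sm_count, double clock_mhz)
{
    GemmUsage u;
    double flops = 2.0 * M * N * K;                                     // 1 FMA = 2 FLOPs
    u.achieved_gflops = flops / (elapsed_us * 1.0e3);                   // us -> GFLOPS
    u.peak_gflops = cores_per_sm * sm_count * 2.0 * clock_mhz / 1.0e3;  // FP32 FMA peak
    u.usage = u.achieved_gflops / u.peak_gflops;
    return u;
}

int main()
{
    // Hypothetical elapsed time, just to show the calculation
    GemmUsage u = gemm_mac_usage(4096, 4096, 4096, 10000.0, 128, 48, 1560.0);
    printf("achieved %.1f GFLOPS, peak %.1f GFLOPS, usage %.1f%%\n",
           u.achieved_gflops, u.peak_gflops, 100.0 * u.usage);
    return 0;
}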

And I simply comment this call in or out to switch between FP32 and TF32:

  cublasSetMathMode(blas, CUBLAS_TF32_TENSOR_OP_MATH);
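
Instead of commenting the call in and out, one could also switch the mode explicitly and check the returned status; a small sketch:

#include <cublas_v2.h>
#include <cstdio>

// Enable or disable TF32 tensor op math on a cuBLAS handle and report failures.
static void set_tf32(cublasHandle_t handle, bool enable)
{
    cublasMath_t mode = enable ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
    cublasStatus_t st = cublasSetMathMode(handle, mode);
    if (st != CUBLAS_STATUS_SUCCESS)
        printf("cublasSetMathMode failed: %d\n", (int)st);
}

int main()
{
    cublasHandle_t handle;
    cublasCreate(&handle);
    set_tf32(handle, true);    // TF32 tensor op math
    set_tf32(handle, false);   // back to default FP32 math
    cublasDestroy(handle);
    return 0;
}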

You may want to use the CUDA profiler to double check that the performance of the code is largely bound by computational throughput

Yes, you are right; I should check the profile to figure out what is going on.

Sorry for the delayed response. The linked .zip file is missing some header files, so I created my own minimalist single-file benchmark, sgemm_bench.cu:

#include <stdio.h>
#include <stdlib.h>
#include <cublas_v2.h>

// A routine to give access to a high precision timer on most systems.
#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
    LARGE_INTEGER t;
    static double oofreq;
    static int checkedForHighResTimer;
    static BOOL hasHighResTimer;

    if (!checkedForHighResTimer) {
        hasHighResTimer = QueryPerformanceFrequency (&t);
        oofreq = 1.0 / (double)t.QuadPart;
        checkedForHighResTimer = 1;
    }
    if (hasHighResTimer) {
        QueryPerformanceCounter (&t);
        return (double)t.QuadPart * oofreq;
    } else {
        return (double)GetTickCount() * 1.0e-3;
    }
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
double second (void)
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_sec + (double)tv.tv_usec * 1.0e-6;
}
#else
#error unsupported platform
#endif

int main (void)
{
    cublasOperation_t transa = CUBLAS_OP_N;
    cublasOperation_t transb = CUBLAS_OP_T;
    int m = 16384;
    int n = 16384;
    int k = 16384;
    float alpha = 1.0f;
    float beta = 0.0f;

    double start, stop, elapsed, tflop;
    cublasStatus_t stat = CUBLAS_STATUS_SUCCESS;
    cublasHandle_t handle;
    cublasCreate (&handle);

//    stat = cublasSetMathMode (handle, CUBLAS_TF32_TENSOR_OP_MATH);
    printf ("stat = %d\n", stat);

    int lda = (transa == CUBLAS_OP_N) ? max (1, m) : max (1, k);
    int ldb = (transb == CUBLAS_OP_N) ? max (1, k) : max (1, n);
    int ldc = max (1, m);
    int ka = (transa == CUBLAS_OP_N) ? k : m;
    int kb = (transb == CUBLAS_OP_N) ? n : k;
    
    size_t Asz = (size_t)lda * ka * sizeof (float);
    size_t Bsz = (size_t)ldb * kb * sizeof (float);
    size_t Csz = (size_t)ldc * n  * sizeof (float);
    float *A_d = 0, *B_d = 0, *C_d = 0;
    cudaMalloc ((void**)&A_d, Asz);
    cudaMalloc ((void**)&B_d, Bsz);
    cudaMalloc ((void**)&C_d, Csz);
    
    float *A = 0, *B = 0, *C = 0;
    A = (float*) malloc (Asz);
    B = (float*) malloc (Bsz);
    C = (float*) malloc (Csz);
    for (int i = 0; i < lda * ka; i++) A [i] = 1.0f;
    for (int i = 0; i < ldb * kb; i++) B [i] = 2.0f;
    cudaMemcpy (A_d, A, Asz, cudaMemcpyHostToDevice);
    cudaMemcpy (B_d, B, Bsz, cudaMemcpyHostToDevice);
    cudaMemset (C_d, 0xff, Csz); // initialize with NaN
    
    cudaDeviceSynchronize();
    for (int i = 0; i < 3; i++) {
        start = second();
        stat = cublasSgemm (handle, transa, transb, m, n, k, &alpha, A_d, lda, B_d, ldb, &beta, C_d, ldc);
        cudaDeviceSynchronize();
        stop = second();
    }
    cudaMemcpy (C, C_d, Csz, cudaMemcpyDeviceToHost);
//    for (int i = 0; i < m * n; i++) printf ("%15.8e \n", C[i]);

    elapsed = stop - start;
    tflop = 2.0e-12 * m * n *k;
    printf ("stat = %d\n", stat);
    printf ("elapsed = %.6f seconds\n", elapsed);
    printf ("TFLOPS  = %.6f\n", tflop / elapsed);
    return EXIT_SUCCESS;
}

I have a Quadro RTX 4000 with compute capability 7.5 and therefore compiled with

nvcc -arch=sm_75 -o sgemm_bench.exe sgemm_bench.cu cublas.lib

Running GPU-Z I can see that the GPU clock is boosted to 1725 MHz when the app is running, so the theoretical FP32 throughput is 7.95 TFLOPS. What I measure with the app for matrices of size 16K x 16K across all four transpose mode variants is as follows:

N, N  7.65 TFLOPS
N, T  7.76 TFLOPS
T, N  7.32 TFLOPS
T, T  7.62 TFLOPS

The measured performance is very close to the theoretical limit in all cases. The fact that the N,T variant is the fastest variant is expected, as this data arrangement causes the least amount of overhead in a GEMM computation.

Note that the Quadro RTX 4000 is a mid-range GPU; you would want to increase the matrix size to at least 32K x 32K when testing with a high-end GPU.
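
For reference, the 7.95 TFLOPS figure is just cores x 2 FLOPs per FMA x boost clock; a quick sketch of the arithmetic, assuming the published count of 2304 CUDA cores for the Quadro RTX 4000:

#include <stdio.h>

int main(void)
{
    double cores = 2304.0;      // Quadro RTX 4000 CUDA cores (from published specs)
    double clock_ghz = 1.725;   // boost clock observed with GPU-Z
    double tflops = cores * 2.0 * clock_ghz / 1000.0;        // 2 FLOPs per FMA
    printf("theoretical FP32 peak: %.2f TFLOPS\n", tflops);  // ~7.95
    return 0;
}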

Thanks for your fast response.
I've tested your code on the RTX 3090 only; the results without tensor cores are as follows:
8192x8192x8192:
elapsed = 0.040332 seconds
TFLOPS = 27.261467
16384x16384x16384:
elapsed = 0.325122 seconds
TFLOPS = 27.054756
32768x32768x32768:
elapsed = 2.585470 seconds
TFLOPS = 27.217003

So the results from your code are similar to mine.

The TF32 (tensor core) results are also similar; I won't print them here.

So I'm very confused about why the performance is so poor.

At the same time, I compared the specifications of the RTX 3090 and the A100. There is an important difference: the global memory of the RTX 3090 is GDDR6X, while the A100 uses HBM.
The memory bandwidth of the 3090 is about 936 GB/s, while the A100 reaches about 1.9 TB/s.
So I think the memory bandwidth of the 3090 and the A4000 cannot sustain their theoretical FP32 (and TF32) throughput. Is that right?
Is there any other way to verify this?
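
For example, would a roofline-style estimate like the following sketch be a reasonable check? It optimistically assumes each matrix is streamed through DRAM only once, which a real kernel only approximates:

#include <stdio.h>

int main(void)
{
    double N = 16384.0;
    double flops = 2.0 * N * N * N;     // FMA counted as 2 FLOPs
    double bytes = 3.0 * N * N * 4.0;   // lower bound: read A and B, write C, FP32
    double intensity = flops / bytes;   // FLOP per byte
    double dram_gbps = 936.0;           // RTX 3090 spec memory bandwidth
    printf("arithmetic intensity : %.0f FLOP/byte\n", intensity);
    printf("DRAM-limited rate    : %.1f TFLOPS (real kernels move more data)\n",
           intensity * dram_gbps / 1000.0);
    return 0;
}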

(1) Are you using CUDA 12.2 for your experiments? From historical observation it takes quite some time for the myriad GEMM kernels to become fully optimized based on hand-crafted assembly code.

(2) As I said, it is possible that GEMM becomes memory bandwidth limited as the FLOPS keep growing faster than memory bandwidth. There is a 5x gap in computational throughput between the Quadro RTX 4000 and the RTX 3090, but only a 2.3x gap in memory throughput. I don’t have access to high-end hardware, so you will have to investigate this hypothesis yourself with the help of the CUDA profiler.

(3) I assume you have double-checked that your high-end GPUs boost to the expected GPU clocks? The clock speed of modern GPUs is widely variable based on environmental factors, the biggest of which is often cooling.

1). My toolchain is CUDA 11.8.
2). Yes, I'll check this through NCU.
3). Take the RTX 3090, for example: its base frequency is 1400 MHz and its boost frequency is about 1700 MHz (1695 as read from deviceQuery). I compute the peak utilization with the boost frequency, and I don't think the difference between base and boost frequency is significant. I'm not sure how to understand "double-checked"?

https://en.wiktionary.org/wiki/double_check

The GPU boost clock of 1695 MHz for the RTX 3090 is the nominal boost clock. Opportunistically, depending on environmental factors it can boost to the 1800 MHz to 1900 MHz range from what I read on the internet. Special vendor-tuned variants can reach 2000 MHz.

The clock is adjusted dynamically with latencies in the millisecond range. Under Windows there is a nice graphical utility called GPU-Z which allows the real-time tracking of the GPU clocks. For my Quadro RTX 4000 the nominal boost clock is 1545 MHz, but this is clearly exceeded when doing SGEMM, per the above data.

If my hunch about memory throughput limitations is correct, higher boost clocks wouldn’t help. On my GPU I can see a difference (higher boost clock → higher TFLOPS)

Thank you.

Hello njuffa,
I've checked the profile with NCU on my RTX A4000 (testing 4096x4096x4096). Would you be willing to help me figure out the reason for the poor performance?
The Pipe Utilization section shows the following:

The FMA utilization is about 50%, which I think matches my test result.
NCU also tells me the following:

[Warning] LSU is the highest-utilized pipeline (87.9%). It executes load/store memory operations. The pipeline is over-utilized and likely a performance bottleneck.

I don't know why the LSU is the bottleneck; my guess was that the shared memory latency is very poor. So I wrote a test for shared memory latency and bandwidth, and the result is as follows:

shared memory accessed: 2097152 byte
duration: 18766 cycles
shared memory bandwidth per SM (measured): 111.752747 byte/cycle
shared memory bandwidth per SM (theoretical): 128 byte/cycle
standard clock frequency: 1560 MHz
SM: 48
whole chip shared memory bandwidth (theoretical): 9584.639648 GB/s
shared memory latency 23 cycles

The measured shared memory bandwidth is close to the theoretical value, even though the shared memory latency is 23 cycles.
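
For reference, a minimal sketch of the kind of shared memory latency test I mean (not my exact code): it times a single-thread pointer chase in shared memory with clock64(), so the result includes a little loop overhead.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void smem_latency(unsigned int *out, int iters)
{
    __shared__ unsigned int buf[1024];
    // Build a dependency chain: each load's address comes from the previous load.
    for (int i = threadIdx.x; i < 1024; i += blockDim.x)
        buf[i] = (i + 1) & 1023;
    __syncthreads();

    if (threadIdx.x == 0) {
        unsigned int idx = 0;
        long long start = clock64();
        for (int i = 0; i < iters; i++)
            idx = buf[idx];                  // serialized, latency-bound loads
        long long stop = clock64();
        out[0] = idx;                        // keep the chase from being optimized away
        out[1] = (unsigned int)((stop - start) / iters);  // ~cycles per dependent load
    }
}

int main()
{
    unsigned int *out_d, out_h[2];
    cudaMalloc((void**)&out_d, 2 * sizeof(unsigned int));
    smem_latency<<<1, 32>>>(out_d, 1 << 20);
    cudaMemcpy(out_h, out_d, 2 * sizeof(unsigned int), cudaMemcpyDeviceToHost);
    printf("approx. shared memory latency: %u cycles\n", out_h[1]);
    cudaFree(out_d);
    return 0;
}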

I'm very confused: how should I understand NCU's warning about the LSU?

I then continued to the "Memory Workload Analysis" section in NCU:

The "% Peak" value is 53.5 for "Shared Load" and 2.11 for "Shared Store". Obviously, these values are low. Another important item in that table is Bank Conflicts. Are those "% Peak" values low because of the bank conflicts?

Then I continued to the "Scheduler Statistics" section:

"Issued Warp Per Scheduler" is 0.71; I don't think this value meets expectations.

Then I continued to the "Warp State Statistics" section:

The top entry is "Stall Not Selected" rather than "Stall Long Scoreboard", so I don't think global memory bandwidth is the main reason for the poor performance. Is that right?

Finally, I checked the "Instruction Statistics" section:

The top entries are FFMA and LDS: the FFMA count is 2,417,483,648 and the LDS count is 269,156,352, so FFMA/LDS < 10. I therefore think the achievable shared memory bandwidth cannot support higher performance. Given that the shared memory latency is 23 cycles, I would expect FFMA/LDS to need to be > 20 for the RTX A4000 to make full use of its FFMA capability. Is that right?

Conclusion:
I think the main reason for the poor performance is shared memory rather than global memory. I don't have access to other CUDA GPUs, so I don't know what a reasonable value for shared memory latency is. Would you be willing to tell me?

njuffa, would you be willing to check my analysis?

That seems like plausible reasoning to me, but I am not an expert on analyzing the memory hierarchy in this fashion and I am very tired now. NVIDIA has experts that can perform such analysis. Consider posting the question in the profiler sub-forum; that might be a better place to ask. Generally speaking, the floating-point intensity of the code may be too low to cover all applicable latency in the memory hierarchy.

A latency of 23 cycles for a shared memory access seems eminently plausible to me, or more generally around 20 cycles. GPUs are built as throughput machines, not ninja-tuned for low latency.

That strikes me as a bit low. Where does that number come from?

like this:

float chip_bandwidth = float(sm) * bw_theoretical * clk / 1000;

This is the theoretical bandwidth: 48 SMs * 128 bytes/cycle * 1560 MHz, which gives 9584.639648 GB/s.

Please post the full NCU report. Based on the pipeline utilization, the kernel is executing full FP32, not TF32. If TF32 were being used, the Tensor FP pipe utilization would be high.

The bandwidth calculation for shared memory is correct. A latency-limited application would not have high LSU bandwidth. The LSU pipe includes shared, global, and local memory accesses.

Hello Greg,
My profile was based on FP32, not TF32.
I think I should figure out the FP32 bottleneck first, so I didn't profile TF32.
I've uploaded the full profile generated by NCU; please take a look.

Some important assumptions about the TF32 tensor cores:
In a previous post I estimated that "each SM holds 4 tensor cores, and each tensor core can execute 128 TF32 FMAs per cycle". Is that right?
Now I feel the A4000 is 64 TF32 FMAs per cycle and the 3090 is 32 TF32 FMAs per cycle. Is that right? I'm not sure.
report.ncu-rep (6.2 MB)

The A4000 (GA104) based GPU has a sustained throughput of 0.5 Load Store Unit (LSU) instructions/cycle and 1 LSU wavefront/cycle. The ampere_sgemm_128x32_nn is utilizing shared memory and global memory very efficiently; however, the performance is limited by the LSU.

The GPU Speed of Light SOL SM and SOL Memory have the same value (87.14%). For CC 7.0 - 9.0 the LSU instruction throughput and the LSU request throughput are the same. The former is part of sm__instruction_throughput and the latter the gpu__compute_memory_throughput.

sm__inst_executed_pipe_lsu.sum approx= SharedMemory::Instructions[Total] + L1/TEX Cache::Instructions[Total]
437,551,104 = 353,140,736 + 84,410,368

The report did not maintain the value of sm__cycles_elapsed.sum but it can be approximated as gpc__cycles_elapsed.max x 48 (SM count) = 20,955,367 x 48 = 1,005,857,616

sm__inst_executed_pipe_lsu.avg.peak_sustained = 0.5

437,551,104 / (1,005,857,616 x 0.5) = 87%

CC 7.0 (GV100), CC 8.0 (GA100) and CC 9.0 (GH100) can sustain 1 LSU instruction/cycle/SM.

For FP32 and TF32 the critical issue is shared memory instruction throughput. The kernel is already efficiently using LDS.128 and STS.128 for many of the accesses.


Hi Greg,
Thank you so much for such a clear description; you are a good teacher.
But I still have some questions about the data:

The A4000 (GA104) based GPU has a sustained throughput of 0.5 Load Store Unit (LSU) instructions/cycle and 1 LSU wavefront/cycle

  1. Are the 0.5 and 1 figures hardware limits or runtime results? If they are hardware limits, would you tell me where to find this data?
  2. How should I understand "wavefront/cycle"? What is a wavefront, and why does this concept only appear in the "Memory Workload Analysis" section?
  3. How do I calculate the bandwidth of shared memory? I printed my formula in a previous post, like this:
float chip_bandwidth = float(sm) * bw_theoretical * clk / 1000;

I think my formula is wrong if the LSU throughput is 0.5 instructions/cycle:
the chip_bandwidth should then be 0.5 * 48 * 128 * 1560 MHz, i.e. 0.5 * 9584.639648 GB/s.
But you said my formula is correct. :(
So I don't know how to understand the relationship between the LSU throughput and wavefronts/cycle.

  4. There are many bank conflicts in the "Memory Workload Analysis" table, but you said "ampere_sgemm_128x32_nn is utilizing shared memory and global memory very efficiently". Why? How should I understand these bank conflicts?

Would you help me sort out these concepts? Or where can I find the relevant reference documents?

  1. Are the 0.5 and 1 figures hardware limits or runtime results? If they are hardware limits, would you tell me where to find this data?

The sustained rate of LSU instructions and L1TEX wavefronts is a hardware limit.

  2. How should I understand "wavefront/cycle"? What is a wavefront, and why does this concept only appear in the "Memory Workload Analysis" section?

A wavefront is the unit of work passed through the L1TEX pipes, including:

  • LSUIN to the shared memory pipe
  • LSUIN to the tagged global/local/dsmem pipe
  • TEXIN to the tagged texture/surface pipe

The Kernel Profiling Guide :: Nsight Compute Documentation section on Hardware Model provides an overview of the L1TEX unit and a high-level diagram.

The Kernel Profiling Guide :: Nsight Compute Documentation section on Metrics Decoder provides the following information on wavefronts:

wavefront - Unique “work package” generated at the end of the processing stage for requests. All work items of a wavefront are processed in parallel, while work items of different wavefronts are serialized and processed on different cycles. At least one wavefront is generated for each request.

A simplified model for the processing in L1TEX for Volta and newer architectures can be described as follows: When an SM executes a global or local memory instruction for a warp, a single request is sent to L1TEX. This request communicates the information for all participating threads of this warp (up to 32). For local and global memory, based on the access pattern and the participating threads, the request requires to access a number of cache lines, and sectors within these cache lines. The L1TEX unit has internally multiple processing stages operating in a pipeline.

A wavefront is the maximum unit that can pass through that pipeline stage per cycle. If not all cache lines or sectors can be accessed in a single wavefront, multiple wavefronts are created and sent for processing one by one, i.e. in a serialized manner. Limitations of the work within a wavefront may include the need for a consistent memory space, a maximum number of cache lines that can be accessed, as well as various other reasons. Each wavefront then flows through the L1TEX pipeline and fetches the sectors handled in that wavefront. The given relationships of the three key values in this model are requests:sectors is 1:N, wavefronts:sectors 1:N, and requests:wavefronts is 1:N.

A wavefront is described as a (work) package that can be processed at once, i.e. there is a notion of processing one wavefront per cycle in L1TEX. Wavefronts therefore represent the number of cycles required to process the requests, while the number of sectors per request is a property of the access pattern of the memory instruction for all participating threads. For example, it is possible to have a memory instruction that requires 4 sectors per request in 1 wavefront. However, you can also have a memory instruction having 4 sectors per request, but requiring 2 or more wavefronts.
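
To make this concrete, here is a small sketch that could be profiled with Nsight Compute; the sector counts in the comments are what the simplified model above predicts for these access patterns, not guaranteed exact values:

#include <cuda_runtime.h>

// A warp loads/stores 32 consecutive floats: 1 request, ~4 sectors, typically 1 wavefront.
__global__ void coalesced(float *out, const float *in, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i];
}

// Each lane of the warp touches a different 32-byte sector: 1 request,
// ~32 sectors, which requires multiple wavefronts to process.
__global__ void strided(float *out, const float *in, int n)
{
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * 32;
    if (i < n) out[i] = in[i];
}

int main()
{
    const int n = 1 << 22;
    float *in, *out;
    cudaMalloc((void**)&in,  n * sizeof(float));
    cudaMalloc((void**)&out, n * sizeof(float));
    coalesced<<<n / 256, 256>>>(out, in, n);
    strided<<<n / (256 * 32), 256>>>(out, in, n);
    cudaDeviceSynchronize();
    cudaFree(in);
    cudaFree(out);
    return 0;
}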

  3. How do I calculate the bandwidth of shared memory? I printed my formula in a previous post, like this:

The bandwidth is based upon the read and/or write bandwidth, which is 128 B/cycle. The LSUIN interface is only 16 threads/cycle. In order to obtain maximum throughput, LDS.64 or STS.64 have to be used.
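
A back-of-the-envelope sketch of the shared memory bytes/cycle implied by those two figures (16 threads/cycle into LSUIN, 128 B/cycle data path) for the different LDS widths:

#include <stdio.h>

int main(void)
{
    int bytes_per_thread[] = {4, 8, 16};            // LDS.32, LDS.64, LDS.128
    for (int i = 0; i < 3; i++) {
        double bpc = 16.0 * bytes_per_thread[i];    // 16 threads issued per cycle
        if (bpc > 128.0) bpc = 128.0;               // capped by the 128 B/cycle data path
        printf("LDS.%-3d -> %.0f bytes/cycle\n", bytes_per_thread[i] * 8, bpc);
    }
    return 0;
}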

The following GTC presentations provide additional information:


Hi Greg,
You noted that "The sustained rate of LSU instructions and L1TEX wavefronts is a hardware limit", but I searched the Ampere whitepaper and could not find anything about this.
Furthermore, I searched for a description of the "LSUIN interface is only 16 threads/cycle" statement and could not find that in the Ampere whitepaper either.
Would you tell me where I can find these limits for a specific CUDA architecture?
You said:

CC 7.0 (GV100), CC 8.0 (GA100) and CC 9.0 (GH100) can sustain 1 LSU instruction/cycle/SM

How about other architectures, for example GP102, GP107, etc.?
So, how do I find these hardware limits in NVIDIA's documents?

In the previous post you pointed out: "sm__inst_executed_pipe_lsu.avg.peak_sustained = 0.5"
Yes, I think this clue is very important for the hardware limits above, but I'm confused by it; I think it conflicts with the "Memory Workload Analysis" table (screenshot attached):

You see, the "% Peak" Total value is 63.13%. How should I understand the difference?

The GPU whitepapers do not contain the level of micro-architecture detail that you are requesting. The limits are disclosed in the metrics peak_sustained value.

GP10x is a significantly different architecture. I would advise you to review the metrics (exposed in earlier versions of the tools) and to use microbenchmarks.

In the Shared Memory Table the % Peak column is the % peak throughput for the wavefronts.
The metric for Shared Load is l1tex__data_pipe_lsu_wavefronts_mem_shared_op_ld.sum.pct_of_peak_sustained_elapsed
The Total row is the summation of the rows.


OK, thanks for your patience.