How to optimize my CUDA code?

I have a simple program that I wrote to measure the real performance of my GPU, but the result falls short of my expectations. I don’t know how to explain it or how to optimize the program, so I hope NVIDIA’s experts can help me.

The details of my GPU are as follows:
Device 0: “NVIDIA RTX A4000”
CUDA Driver Version / Runtime Version 11.6 / 11.3
CUDA Capability Major/Minor version number: 8.6
Total amount of global memory: 16109 MBytes (16891379712 bytes)
(48) Multiprocessors, (128) CUDA Cores/MP: 6144 CUDA Cores
GPU Max Clock rate: 1560 MHz (1.56 GHz)
Memory Clock rate: 7001 Mhz
Memory Bus Width: 256-bit
L2 Cache Size: 4194304 bytes
Maximum Texture Dimension Size (x,y,z) 1D=(131072), 2D=(131072, 65536), 3D=(16384, 16384, 16384)
Maximum Layered 1D Texture Size, (num) layers 1D=(32768), 2048 layers
Maximum Layered 2D Texture Size, (num) layers 2D=(32768, 32768), 2048 layers
Total amount of constant memory: 65536 bytes
Total amount of shared memory per block: 49152 bytes
Total shared memory per multiprocessor: 102400 bytes
Total number of registers available per block: 65536
Warp size: 32
Maximum number of threads per multiprocessor: 1536
Maximum number of threads per block: 1024
Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
Max dimension size of a grid size (x,y,z): (2147483647, 65535, 65535)
Maximum memory pitch: 2147483647 bytes
Texture alignment: 512 bytes
Concurrent copy and kernel execution: Yes with 2 copy engine(s)

OK, let me introduce my program. It is very simple and just executes FFMA instructions. The details are as follows:

  1. each thread processes 4 floats (float4);
  2. there are 256 threads in each block;
  3. there are 2 blocks on each SM;
  4. the formula in each thread is very simple: C += A * B;
  5. each thread reads 4 floats for A and 4 floats for B from global memory, and writes 4 floats for C to global memory;
  6. each thread repeats the above formula for 2048 rounds, like this:
    for(int i = 0 ; i < 2048 ; i ++) C += A * B;
  7. I implemented the loop body in PTX, which I think keeps nvcc from optimizing it. I just want to stop nvcc from rewriting my code as C = 2048 * (A * B);
  8. I define 8 float4 accumulators in every thread to avoid a dependency on C, like this:
    float4 A = read_f4(ptr_A), B = read_f4(ptr_B);
    float4 C0, C1, C2, C3, C4, C5, C6, C7;  // 8 independent accumulators
    int loop = 2048 >> 5;                   // 64 iterations x 32 statements = 2048 rounds
    for(int i = 0 ; i < loop ; i ++){
    C0 += A * B;
    C1 += A * B;
    C2 += A * B;
    C3 += A * B;
    C4 += A * B;
    C5 += A * B;
    C6 += A * B;
    C7 += A * B;
    C0 += A * B;
    C1 += A * B;
    C2 += A * B;
    C3 += A * B;
    C4 += A * B;
    C5 += A * B;
    C6 += A * B;
    C7 += A * B;
    C0 += A * B;
    C1 += A * B;
    C2 += A * B;
    C3 += A * B;
    C4 += A * B;
    C5 += A * B;
    C6 += A * B;
    C7 += A * B;
    C0 += A * B;
    C1 += A * B;
    C2 += A * B;
    C3 += A * B;
    C4 += A * B;
    C5 += A * B;
    C6 += A * B;
    C7 += A * B;
    }
    C0 += C1;
    C2 += C3;
    C4 += C5;
    C6 += C7;
    C0 += C2;
    C4 += C6;
    C0 += C4;
    store_f4(ptr_C, C0);

The peak should be about 10T FMA per second (x2 = 20 TFLOPS), but my program reaches only about 6.5T, i.e. about 65% of peak.
I tried blocksPerSM values of 2, 4, and 8 and threadPerBlock values of 128, 256, and 512. Unfortunately, the results are all very similar, about 60% to 65%.
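
For reference, the theoretical peak can be worked out from the deviceQuery output above. This is a minimal sketch, assuming the usual convention that each CUDA core retires one FMA per clock at the listed maximum clock, and that one FMA counts as two floating-point operations. The function names are made up for illustration.

```cpp
#include <cassert>

// Peak FMA rate in units of 1e12 FMA/s, assuming one FMA per core per clock.
double peak_tfma(int sm_count, int cores_per_sm, double clock_ghz)
{
    return sm_count * cores_per_sm * clock_ghz * 1e-3;
}

// Peak FLOPS counts each FMA as two operations (multiply and add).
double peak_tflops(int sm_count, int cores_per_sm, double clock_ghz)
{
    return 2.0 * peak_tfma(sm_count, cores_per_sm, clock_ghz);
}

// Fraction of peak achieved; both arguments must use the same unit.
double fraction_of_peak(double measured, double peak)
{
    assert(peak > 0.0);
    return measured / peak;
}
```

With the values listed above (48 SMs, 128 cores per SM, 1.56 GHz), peak_tfma gives about 9.58 and peak_tflops about 19.2, so 6.5T FMA/s would be roughly 68% of peak at that clock. The clock the GPU actually sustains under load may differ from the listed maximum.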

Then I profiled my program with NCU. Under “Roofline Analysis” it tells me: “The ratio of peak float (fp32) to double (fp64) performance on this device is 64:1. The kernel achieved 61% of this device’s fp32 peak performance and 0% of its fp64 peak performance.”

I think my program avoids memory access and register dependencies, so I don’t know why it reaches only 61% of peak.

I tried to read the profile information in NCU, but I still cannot find the reason for the poor performance.

I’ve uploaded my program and the NCU profile:
base_mac.tar.gz (3.7 KB)
repoprt.ncu-rep (12.6 MB)

Would anyone be willing to teach me? I think the key is reading the profile, but I cannot understand it.

Implementing the loop body in PTX to keep nvcc from optimizing it is highly unlikely to be a good idea. The CUDA compiler is based on LLVM, an extremely powerful framework for code transformations, i.e. optimizations. If you run into the compiler optimizing away code that you don’t want optimized away, create dependencies that prevent that from happening. Your chosen approach for measuring peak FP32 throughput appears to be the common method of using independent dot products. You would want to sum these dot products at the end and write the result to global memory to avoid dead code elimination.

Instruction caches on GPUs tend to be pretty small, and your massive loop may exceed the size of the instruction cache, which based on past experience may cost 3% of performance.
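
To get a feel for how large an unrolled loop body becomes, one can estimate its code size. This is a rough sketch, assuming 16 bytes per SASS instruction (which matches the 0x10 address stride in the SASS listings later in this thread) and ignoring loop overhead and loads/stores. The instruction cache size is not stated anywhere in this thread, so the sketch only computes the footprint; the names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Assumed encoding size of one SASS instruction, in bytes.
const std::size_t kBytesPerInstruction = 16;

// Approximate code size of a loop body that has `instr_per_iteration`
// instructions per iteration and is unrolled `unroll` times.
std::size_t unrolled_bytes(std::size_t unroll, std::size_t instr_per_iteration)
{
    assert(unroll > 0 && instr_per_iteration > 0);
    return unroll * instr_per_iteration * kBytesPerInstruction;
}
```

For the loop above, one iteration holds 32 float4 statements, i.e. 128 scalar FFMAs, so unrolled_bytes(1, 128) is 2048 bytes. If all 64 iterations were unrolled, unrolled_bytes(64, 128) would be 131072 bytes (128 KB).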

It is easier to fill up the SMs as much as possible using relatively fine granularity, e.g. use 128 threads per thread block as a starting point.
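
The effect of block size on how full an SM can get can be sketched with a little arithmetic. The sketch below considers only the thread-count and block-count limits and ignores registers and shared memory. The limit of 1536 threads per SM comes from the deviceQuery output above; the limit of 16 resident blocks per SM is an assumption for compute capability 8.6 and does not appear in that output. Function names are made up.

```cpp
#include <algorithm>
#include <cassert>

// Thread blocks that can be resident on one SM, looking only at the
// thread-count and block-count limits.
int resident_blocks_per_sm(int threads_per_block, int max_threads_per_sm,
                           int max_blocks_per_sm)
{
    assert(threads_per_block > 0);
    return std::min(max_threads_per_sm / threads_per_block, max_blocks_per_sm);
}

// Fraction of the SM's thread slots that are filled at that block size.
double thread_occupancy(int threads_per_block, int max_threads_per_sm,
                        int max_blocks_per_sm)
{
    int blocks = resident_blocks_per_sm(threads_per_block, max_threads_per_sm,
                                        max_blocks_per_sm);
    return static_cast<double>(blocks * threads_per_block) / max_threads_per_sm;
}
```

Under these assumptions, block sizes of 128, 256, and 512 threads can all fill the 1536 thread slots (12, 6, and 3 resident blocks), while 1024 threads per block allows only one resident block and leaves a third of the slots empty.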

You may be better off using dot products of floats instead of float4s. The latter is likely to result in higher register pressure.

You are unlikely to achieve more than 85% of theoretical FP32 throughput, as the ptxas compiler is unlikely to produce a perfect instruction scheduling with perfect register assignment. So there will be bubbles in the pipeline caused by register bank conflicts and execution pipe contention.

[Later: ]

Below is a simple test scaffold for measuring FP32 throughput. It is currently configured for the 9-year-old low-end GPU in my web-browsing machine, for which it achieves 86% of theoretical peak FP32 throughput. Increase MAX_BLOCKS, REPS, ITER to adapt to your hardware. Then vary POLY_DEPTH to see how throughput changes.

#include <stdlib.h>
#include <stdio.h>

#define MAX_BLOCKS      (65520)
#define THREADS_PER_BLK (128)
#define LEN             (MAX_BLOCKS * 1024)
#define POLY_DEPTH      (512)
#define REPS            (2)
#define ITER            (10)

#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
    LARGE_INTEGER t;
    static double oofreq;
    static int checkedForHighResTimer;
    static BOOL hasHighResTimer;

    if (!checkedForHighResTimer) {
        hasHighResTimer = QueryPerformanceFrequency (&t);
        oofreq = 1.0 / (double)t.QuadPart;
        checkedForHighResTimer = 1;
    }
    if (hasHighResTimer) {
        QueryPerformanceCounter (&t);
        return (double)t.QuadPart * oofreq;
    } else {
        return (double)GetTickCount() * 1.0e-3;
    }
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
double second (void)
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_sec + (double)tv.tv_usec * 1.0e-6;
}
#else
#error unsupported platform
#endif

// Macro to catch CUDA errors in CUDA runtime calls
#define CUDA_SAFE_CALL(call)                                          \
do {                                                                  \
    cudaError_t err = call;                                           \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString(err) );       \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

// Macro to catch CUDA errors in kernel launches
#define CHECK_LAUNCH_ERROR()                                          \
do {                                                                  \
    /* Check synchronous errors, i.e. pre-launch */                   \
    cudaError_t err = cudaGetLastError();                             \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString(err) );       \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
    /* Check asynchronous errors, i.e. kernel failed (ULF) */         \
    err = cudaDeviceSynchronize();                                    \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString( err) );      \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

__global__ void kernel (const float * __restrict__ src, 
                        float * __restrict__ dst, int len)
{
    int stride = gridDim.x * blockDim.x;
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    for (int i = tid; i < len; i += stride) {
        float p = src[i] + 1.000001f;
        float q = src[i] + 1.000002f;
        for (int k = 0; k < REPS; k++) {
#pragma unroll POLY_DEPTH
            for (int j = 0; j < POLY_DEPTH; j++) {
                p = fmaf (p, p, 1.000001f);
                q = fmaf (q, q, 1.000002f);
            }
        }
        dst[i] = p + q;
    }
}    

int main (int argc, char *argv[])
{
    double start, stop, nbr_of_fma;
    float *d_a, *d_b;

    /* Allocate memory on device */
    CUDA_SAFE_CALL (cudaMalloc((void**)&d_a, sizeof(d_a[0]) * LEN));
    CUDA_SAFE_CALL (cudaMalloc((void**)&d_b, sizeof(d_b[0]) * LEN));
    
    /* Initialize device memory */
    CUDA_SAFE_CALL (cudaMemset(d_a, 0x00, sizeof(d_a[0]) * LEN)); // zero

    /* Compute execution configuration */
    dim3 dimBlock(THREADS_PER_BLK);
    int threadBlocks = (LEN + (dimBlock.x - 1)) / dimBlock.x;
    dim3 dimGrid(threadBlocks);
    
    printf ("burn: using %d threads per block, %d blocks, %f GB used\n", 
            dimBlock.x, dimGrid.x, 2*1e-9*LEN*sizeof(d_a[0]));

    start = second();
    for (int k = 0; k < ITER; k++) {
        kernel<<<dimGrid,dimBlock>>>(d_a, d_b, LEN);
        CHECK_LAUNCH_ERROR();
    }
    stop = second();
    nbr_of_fma = (2.0 * POLY_DEPTH * REPS + 3.0) * LEN * ITER;
    printf ("flop=%13.6e  elapsed=%.5f sec  throughput=%.5f FP32 GFLOPS\n", 
            nbr_of_fma * 2, stop-start, nbr_of_fma * 2 * 1e-9 / (stop - start));

    CUDA_SAFE_CALL (cudaFree(d_a));
    CUDA_SAFE_CALL (cudaFree(d_b));

    return EXIT_SUCCESS;
}

That is very kind, thanks for your code.
I ran it and got above 90% of peak, which is great.
However, I found a difference between my poor code and your good code.
The FMA statements in your code look like this:

                p = fmaf (p, p, 1.000001f);
                q = fmaf (q, q, 1.000002f);

whereas my code looks like this:

c += a * b;

So I modified your code like this:

__global__ void kernel (const float * __restrict__ src, 
                        float * __restrict__ dst, int len)
{
    int stride = gridDim.x * blockDim.x;
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    for (int i = tid; i < len; i += stride) {
        float p = src[i] + 1.000001f;
        float q = src[i] + 1.000002f;
        float r = 0.0f;
        for (int k = 0; k < REPS; k++) {
            #pragma unroll(512)
            for (int j = 0; j < POLY_DEPTH; j++) {
                r = fmaf (p, q, r);
            }
        }
        dst[i] = r * 0.0001;
    }
}   

and changed the FMA count accordingly:

 nbr_of_fma = (POLY_DEPTH * REPS + 3.0) * LEN * ITER;

With this change, the performance is about 50% of peak.
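
The bookkeeping behind these percentages can be written out as follows. This is a small sketch that mirrors the nbr_of_fma and throughput formulas from the scaffold; the function names and the `chains` parameter (the number of FMAs per inner-loop iteration, 2 in the original kernel and 1 in the modified one) are introduced here for illustration only.

```cpp
#include <cassert>

// FMA count in the style of the scaffold: `chains` FMAs per inner-loop
// iteration, plus 3 extra operations per element for the arithmetic
// outside the loop (counted as FMAs, as the scaffold does).
double count_fma(double chains, double poly_depth, double reps,
                 double len, double iter)
{
    return (chains * poly_depth * reps + 3.0) * len * iter;
}

// Throughput in GFLOPS, counting each FMA as 2 floating-point operations.
double gflops(double nbr_of_fma, double elapsed_seconds)
{
    assert(elapsed_seconds > 0.0);
    return nbr_of_fma * 2.0 * 1e-9 / elapsed_seconds;
}
```

With the scaffold's defaults (POLY_DEPTH 512, REPS 2, LEN 65520 * 1024, ITER 10), count_fma gives 1376066764800 for two chains and 689039769600 for one. The modified kernel therefore does about half the FMA work per element for the same loads, stores, and loop overhead, which may account for part of the drop. Also, `r * 0.0001` uses a double constant, so that one multiply per element is carried out in FP64.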

For comparison, I changed blockPerSM in my own program to 64 and changed the loop body like this:

  #pragma unroll(128)
  for(int i = 0 ; i < loop ; i ++){
    vec_C.x = fmaf (vec_A.x, vec_B.x, vec_C.x);
    vec_C.y = fmaf (vec_A.y, vec_B.y, vec_C.y);
    vec_C.z = fmaf (vec_A.z, vec_B.z, vec_C.z);
    vec_C.w = fmaf (vec_A.w, vec_B.w, vec_C.w);
  }
  *ptr_C = vec_C;

The performance is about 60%.

Another modification:

  loop = loop >> 1;
  #pragma unroll(128)
  for(int i = 0 ; i < loop ; i ++){
    vec_A.x = fmaf (vec_A.x, vec_A.x, 0.0001f);
    vec_A.y = fmaf (vec_A.y, vec_A.y, 0.0001f);
    vec_A.z = fmaf (vec_A.z, vec_A.z, 0.0001f);
    vec_A.w = fmaf (vec_A.w, vec_A.w, 0.0001f);

    vec_B.x = fmaf (vec_B.x, vec_B.x, 0.0002f);
    vec_B.y = fmaf (vec_B.y, vec_B.y, 0.0002f);
    vec_B.z = fmaf (vec_B.z, vec_B.z, 0.0002f);
    vec_B.w = fmaf (vec_B.w, vec_B.w, 0.0002f);
  }
  vec_C.x = vec_A.x + vec_B.x;
  vec_C.y = vec_A.y + vec_B.y;
  vec_C.z = vec_A.z + vec_B.z;
  vec_C.w = vec_A.w + vec_B.w;

  *ptr_C = vec_C;

This version also reaches about 90%.
However, the third argument of fmaf is a constant here, so I feel this result (90%) does not demonstrate the real performance.

I think the key is “register reuse” and data dependency, but your code also has a dependency (a = a * a + 1.0001f) and still performs well. Why?

I checked the SASS.
The FMAs in your code compile to this:

        /*0130*/                   FFMA R6, R6, R6, 1.0000009536743164062 ;       /* 0x3f80000806067423 */
                                                                                  /* 0x000fe20000000006 */
        /*0140*/                   FFMA R7, R7, R7, 1.0000020265579223633 ;       /* 0x3f80001107077423 */
                                                                                  /* 0x000fc60000000007 */
        /*0150*/                   FFMA R6, R6, R6, 1.0000009536743164062 ;       /* 0x3f80000806067423 */
                                                                                  /* 0x000fe20000000006 */
        /*0160*/                   FFMA R7, R7, R7, 1.0000020265579223633 ;       /* 0x3f80001107077423 */
                                                                                  /* 0x000fc60000000007 */

The FMAs in my code compile to this:

        /*0190*/                   FFMA R17, R4, R8, R12 ;                      /* 0x0000000804117223 */
                                                                                /* 0x020fe2000000000c */
        /*01a0*/                   FFMA R12, R5, R9, R13 ;                      /* 0x00000009050c7223 */
                                                                                /* 0x000fe2000000000d */
        /*01b0*/                   FFMA R13, R6, R10, R14 ;                     /* 0x0000000a060d7223 */
                                                                                /* 0x000fe2000000000e */
        /*01c0*/                   FFMA R14, R7, R11, R15 ;                     /* 0x0000000b070e7223 */
                                                                                /* 0x000fe2000000000f */
        /*01d0*/                   FFMA R17, R4, R8, R17 ;                      /* 0x0000000804117223 */
                                                                                /* 0x000fe20000000011 */
        /*01e0*/                   FFMA R12, R5, R9, R12 ;                      /* 0x00000009050c7223 */
                                                                                /* 0x000fe2000000000c */
        /*01f0*/                   FFMA R13, R6, R10, R13 ;                     /* 0x0000000a060d7223 */
                                                                                /* 0x000fe2000000000d */
        /*0200*/                   FFMA R14, R7, R11, R14 ;                     /* 0x0000000b070e7223 */
                                                                                /* 0x000fe2000000000e */
        /*0210*/                   FFMA R17, R4, R8, R17 ;                      /* 0x0000000804117223 */
                                                                                /* 0x000fe20000000011 */
        /*0220*/                   FFMA R12, R5, R9, R12 ;                      /* 0x00000009050c7223 */
                                                                                /* 0x000fe2000000000c */
        /*0230*/                   FFMA R13, R6, R10, R13 ;                     /* 0x0000000a060d7223 */
                                                                                /* 0x000fe2000000000d */
        /*0240*/                   FFMA R14, R7, R11, R14 ;                     /* 0x0000000b070e7223 */
                                                                                /* 0x000fe2000000000e */

As far as I can tell:

  1. there is no data dependency in your SASS;
  2. there are data dependencies in my SASS (this is just my guess).

So, could you teach me how to avoid these dependencies, if they really exist?

The code I posted above is an ad-hoc adaption of some code I have had sitting around for quite a few years. I seem to recall that I chose the particular arrangement of FMAs used so as to minimize register bank conflicts, but I do not know for sure.

Having been retired for almost a decade, I am a hobbyist these days who will, often on a whim, explore some issue for an extended afternoon, then forget the exploratory code once my curiosity is satisfied: “The journey is the reward”. I rarely keep notes on what I tried and why.

Sustaining three-input operations at full speed is a challenge in all processor architectures due to the tremendous bandwidth required (3 read ports, 1 write port on the register file). The problem is exacerbated by multi-issue capability. A common way to boost register file bandwidth on average is to use register banks, each of which provides fewer read (and possibly, write) ports. To my knowledge, all NVIDIA GPUs use a (publicly undocumented) scheme of this nature to boost practically available bandwidth. Bank conflicts causing pipeline bubbles may occur intra-instruction or inter-instruction in case of multi-issue capability.

Thanks for your clear explanation.
I see: R4, R8, and R12 conflict (register bank: 4 % 4 == 0, 8 % 4 == 0, 12 % 4 == 0, so they are all in bank 0), and likewise R5/R9/R13, R6/R10/R14, and R7/R11/R15 all conflict. So that is why the performance is poor, as you explained, right?

        /*0190*/                   FFMA R17, R4, R8, R12 ;                      /* 0x0000000804117223 */
                                                                                /* 0x020fe2000000000c */
        /*01a0*/                   FFMA R12, R5, R9, R13 ;                      /* 0x00000009050c7223 */
                                                                                /* 0x000fe2000000000d */
        /*01b0*/                   FFMA R13, R6, R10, R14 ;                     /* 0x0000000a060d7223 */
                                                                                /* 0x000fe2000000000e */
        /*01c0*/                   FFMA R14, R7, R11, R15 ; 

But I don’t know how to solve it, because the SASS is generated by nvcc.
Maybe there is some trick I don’t know. Could you give me further advice?

The nvcc compiler is aware of performance issues related to register usage. While it’s certainly possible that this can be improved, it’s also possible that this is the best tradeoff of choices about register usage.

It should be evident in a large dependent sequence like this that register usage changes will have side effects. Changes you make to address performance on one instruction may have a negative performance impact elsewhere in the dependent chain.

I don’t know of “tricks” to tell nvcc to reorganize its register usage. You could try playing with very crude, coarse controls like -maxrregcount switch to the compiler.

The other options I know of are:

  • modify the source code and study the result changes at the SASS level
  • based on the knowledge of SASS behavior that you have acquired so far, look at the SASS code holistically, that is for a complete function, not just a single instruction or line of code, and see if you can come up with a holistic register usage pattern that avoids the “banks” here (if those are actually the banks).

Based on the 2nd option, you could file a bug to request study by the compiler team, if you think you can do better than the compiler. But you would need a well-documented example, showing a holistic solution. Even then, there are probably knowledge gaps that make this sort of approach difficult.

To quote Wikipedia:

The problem of optimal register allocation is NP-complete. As a consequence, compilers employ heuristic techniques to approximate its solution.

I think it is reasonable to assume that the CUDA compiler engineers in charge of ptxas are fully aware of the latest developments in the field and that heuristics that consider the general problem constrained by GPU-specific restrictions (such as calling conventions, register aggregation for 64-bit operations or vector loads/stores, dual issue, and register banks) are in place. It would also be reasonable to assume that after 15 years of development, this fundamental building block of a compiler is mature.

That does not mean there could not be room for improvement, just that the burden of demonstrating noticeably improved performance from superior register allocation for specific cases lies with a prospective bug filer.

I had only heard that the register bank is “%4”, but I could not find this in NVIDIA’s public documentation.
So I first need to confirm the register bank rule. Could you point me to some documentation on it?
My compiler is nvcc 11.8, and the architecture is sm_86.

I am not aware that the details of the GPU register file organization are disclosed in official CUDA documentation. I also cannot find any relevant details in papers from people who have explored GPU microarchitectures with targeted microbenchmarks. An older paper by NVIDIA subject matter experts,

Mark Gebhart, Stephen W. Keckler, Brucek Khailany, Ronny Krashinsky, William J. Dally, “Unifying Primary Cache, Scratch, and Register File Memories in a Throughput Processor”

gives some details, but I have no idea how these relate to newer architectures:

Each MRF bank is 16 bytes wide with 4 bytes allocated to the same-named architectural register for threads in each of the 4 SIMT lanes in the cluster. Each bank has a capacity of 8KB, providing a total of 256KB of register file capacity per SM. Registers are interleaved across the register file banks to minimize bank conflicts. Instructions that access multiple values from the same bank incur a cycle of delay for each access beyond the first. The operand buffering between the MRF and the execution units represents interconnect and pipeline storage for operands that may be fetched from the MRF on different cycles. Stalls due to bank conflicts are rare and can be minimized with compiler techniques.

Note that the use of vector types, such as float4, tends to impose an additional burden on register allocation. That is why earlier in this thread, I suggested using scalar computation for the exercise at hand (maximize floating-point throughput), and did so in the sample code I posted.

Thanks for your patience.

In practical terms, I see the risk of becoming mired in microarchitectural details that have no bearing on 99% of real-life CUDA code out there. Yes, it is cool to figure out how to get close to the theoretical peak FLOPS, but the resulting code often has very little similarity with code people write to address actual use cases. Not much has changed in that regard since I showed how to get 1 GFLOPS out of the AMD K7 (Athlon) processor ca. 1999.

Outside of special scenarios, much forward progress can be made by simply relying on the CUDA compiler and using feedback provided by the CUDA profiler.

“relying on the CUDA compiler and using feedback provided by the CUDA profiler”: yes, you are right.
But my dilemma is that I cannot get any useful clues from NCU’s profile information.

You might want to start with the recommendations from the CUDA Best Practices guide before delving into the profiler. All profilers these days are very sophisticated utilities that one needs to spend some quality time with to get the full benefit. So give it some time, keep on experimenting and exploring.

When profiler use first became common some thirty years ago, their functionality was very limited and interacting with profilers was less overwhelming than today. There is a bit of a trade-off between ease of use and depth of analysis, I would say.

One source of information is pages 6-8 of Dissecting Volta. The same authors wrote Dissecting Ampere, where they state on page 29 that the Ampere register layout is the same.
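
A candidate bank rule can be checked against a SASS listing with a few lines of code. The sketch below assumes that registers are interleaved across banks by register index and that an FFMA with three register sources hits a conflict only when all three sources fall into the same bank. The bank count is left as a parameter so that both the “%4” guess from earlier in this thread and a 2-bank layout can be tried. Neither the bank count nor the conflict rule is confirmed by official NVIDIA documentation, and the function names are hypothetical.

```cpp
#include <cassert>

// Bank of a register under the assumed interleaving:
// register index modulo the number of banks.
int register_bank(int reg_index, int num_banks)
{
    assert(reg_index >= 0 && num_banks > 0);
    return reg_index % num_banks;
}

// True if all three source registers of an FFMA map to the same bank,
// which is the conflict condition assumed in this sketch.
bool ffma_bank_conflict(int src_a, int src_b, int src_c, int num_banks)
{
    int bank = register_bank(src_a, num_banks);
    return register_bank(src_b, num_banks) == bank &&
           register_bank(src_c, num_banks) == bank;
}
```

Applied to the SASS posted earlier in this thread, the first four FFMAs (for example sources R4, R8, R12) conflict under both the 2-bank and the 4-bank assumption, while the following ones (for example sources R4, R8, R17) do not conflict under either. The check does not apply to instructions such as FFMA R6, R6, R6, 1.0000009536743164062, which read a single register and an immediate.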

Thanks for your help.