Bf16 slower than fp32 on A10 and A100?

I am writing some efficient CUDA kernels for a deep learning problem that doesn’t seem to fit cleanly into PyTorch. I had planned to use bf16, but I ran into cases where bf16 was slower than fp32. I kept simplifying the problem until I was down to a simple kernel that adds two arrays elementwise. Even for this simple problem, on A10 and A100 GPUs I am seeing that bf16 is slower than fp32. Surely I must be doing something wrong!

Below I walk through the code, experiments, and results. A runnable version of the code is in this repo: GitHub - forresti/cuda_bf16_benchmarks

In the following examples I add two arrays of 2^24 = 16,777,216 elements. The results are similar for larger arrays.

minimal example

First, here’s the C++ fp32 reference implementation of elementwise addition of two arrays:

void add_gold(const float* input1, const float* input2, float* output, size_t N){
   for(size_t i=0; i<N; i++){  // size_t index: no signed/unsigned mismatch with N, no overflow past 2^31
       output[i] = input1[i] + input2[i];
   }
}
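To check a bf16 kernel against something like add_gold, the reference also has to apply bf16 rounding. Below is a minimal host-side sketch, assuming bf16 values are kept as raw 16-bit patterns (uint16_t) and that the kernel's bf16 add rounds to nearest even. The helper names (float_to_bf16_rne, bf16_to_float, add_gold_bf16) are hypothetical and are not part of the original code or of cuda_bf16.h. Adding in fp32 and rounding once should be sufficient, since fp32 has more than twice bf16's precision (24 vs. 8 significand bits), which is the usual condition for the intermediate rounding to be harmless.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical helper: round a float to bf16 (round-to-nearest-even) and
// return the raw 16-bit pattern. bf16 is the top 16 bits of an fp32.
uint16_t float_to_bf16_rne(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    if ((u & 0x7fffffffu) > 0x7f800000u) {
        // NaN: truncate and set the quiet bit so the result stays a NaN
        return static_cast<uint16_t>((u >> 16) | 0x0040u);
    }
    // Add 0x7fff plus the lowest kept bit, then truncate: ties go to even
    uint32_t rounding_bias = 0x7fffu + ((u >> 16) & 1u);
    return static_cast<uint16_t>((u + rounding_bias) >> 16);
}

// Hypothetical helper: widen raw bf16 back to float (exact, just a shift)
float bf16_to_float(uint16_t h) {
    uint32_t u = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// Hypothetical bf16 analog of add_gold: add in fp32, round once to bf16
void add_gold_bf16(const uint16_t* input1, const uint16_t* input2,
                   uint16_t* output, size_t N) {
    for (size_t i = 0; i < N; i++) {
        float sum = bf16_to_float(input1[i]) + bf16_to_float(input2[i]);
        output[i] = float_to_bf16_rne(sum);
    }
}
```

For example, 0x3F80 (1.0) plus 0x4000 (2.0) gives 0x4040 (3.0); comparing a kernel's output bit-for-bit against this reference is one way to confirm it really computed in bf16.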

Next, here are simple bf16 and fp32 kernels. Each bf16 thread handles one __nv_bfloat162 (two bf16 values), while each fp32 thread handles one float.

#include <cuda_bf16.h>
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void add_kernel_bf16(
   const __nv_bfloat162* __restrict__ d_input1,
   const __nv_bfloat162* __restrict__ d_input2,
   __nv_bfloat162* __restrict__ d_output,
   const size_t N
) {
   auto block = cg::this_thread_block();
   int32_t gid = block.group_index().x;
   int32_t tid = block.thread_index().x;
   int32_t out_idx = (gid*block.size() + tid);

   // each __nv_bfloat162 packs 2 bf16 values, so this adds 2 elements
   d_output[out_idx] = d_input1[out_idx] + d_input2[out_idx];
}


void add_gpu_bf16(float4* d_input1, float4* d_input2, float4* d_output, size_t N){
   dim3 threads = {256, 1, 1};
   dim3 blocks = {N/(256*2), 1, 1};  // assume N is divisible by 512 (256 threads x 2 elements each)
   add_kernel_bf16<<<blocks, threads>>>((__nv_bfloat162*)d_input1, (__nv_bfloat162*)d_input2, (__nv_bfloat162*)d_output, N);
}

__global__ void add_kernel_fp32(
   const float* __restrict__ d_input1,
   const float* __restrict__ d_input2,
   float* __restrict__ d_output,
   const size_t N
) {
   auto block = cg::this_thread_block();
   int32_t gid = block.group_index().x;
   int32_t tid = block.thread_index().x;
   int32_t out_idx = (gid*block.size() + tid);

   // one fp32 element per thread
   d_output[out_idx] = d_input1[out_idx] + d_input2[out_idx];
}

void add_gpu_fp32(float4* d_input1, float4* d_input2, float4* d_output, size_t N){
   dim3 threads = {256, 1, 1};
   dim3 blocks = {N/256, 1, 1};
   add_kernel_fp32<<<blocks, threads>>>((float*)d_input1, (float*)d_input2, (float*)d_output, N);
}
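The naive launchers compute the grid with floor division, and neither kernel checks bounds, so every element is processed only when N is a multiple of 512 (bf16x2) or 256 (fp32). That holds for N = 2^24, so the measurements are unaffected, but the arithmetic is easy to check on the host. This is a minimal sketch with hypothetical helper names. The ceil-division variant would also need a guard in the kernel, such as `if (out_idx < n_vec)`, where n_vec (hypothetical) is the number of vector elements.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: grid size from the floor division used by the launchers
size_t blocks_floor(size_t N, size_t threads, size_t elems_per_thread) {
    return N / (threads * elems_per_thread);
}

// Hypothetical helper: round up instead, so no tail of N is skipped.
// A kernel launched this way must bounds-check its index.
size_t blocks_ceil(size_t N, size_t threads, size_t elems_per_thread) {
    size_t per_block = threads * elems_per_thread;
    return (N + per_block - 1) / per_block;
}

// Elements touched by an unguarded kernel launched with `blocks` blocks
size_t elements_covered(size_t blocks, size_t threads, size_t elems_per_thread) {
    return blocks * threads * elems_per_thread;
}
```

For N = 2^24 the bf16x2 launcher gets 32768 blocks and the fp32 launcher 65536, covering N exactly; for N = 1000 the floor-division grid would silently skip the last 488 elements.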

The results on NVIDIA A10 are:

  • run_add(): type=fp32, N: 16777216, 0.134 GB, gpu fp32 latency: 0.00492876 sec, 27 GB/s,
  • run_add(): type=bf16, N: 16777216, 0.067 GB, gpu bf16 latency: 0.00390167 sec, 17 GB/s,

So here bf16 is a bit faster than fp32 (3.9 ms vs. 4.9 ms), but the absolute performance is quite poor: 17 to 27 GB/s on a device with a memory bandwidth of 600 GB/s.
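The GB/s figures appear to be the logged size divided by the latency, with GB = 10^9 bytes. The logged 0.134 GB equals 2 × 2^24 × 4 bytes, i.e. two fp32 arrays. An elementwise add reads two arrays and writes one, so counting all three would raise every figure by 1.5x without changing the fp32-vs-bf16 comparison. A minimal sketch of that arithmetic; the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Bytes moved by output = input1 + input2: two reads plus one write per element
uint64_t add_traffic_bytes(uint64_t N, uint64_t bytes_per_elem) {
    return 3 * N * bytes_per_elem;
}

// Effective bandwidth in GB/s, using GB = 1e9 bytes
double effective_gbps(uint64_t bytes, double seconds) {
    return static_cast<double>(bytes) / seconds / 1e9;
}

// Fraction of a device's peak bandwidth (e.g. 600 GB/s for the A10)
double fraction_of_peak(double gbps, double peak_gbps) {
    return gbps / peak_gbps;
}
```

Even counting all three arrays, the first fp32 kernel reaches about 41 GB/s, under 7% of the A10's 600 GB/s.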

optimized version (but still bf16 is slower)

To get better efficiency, let’s try unrolling the code, using fewer threads, and loading aligned 128-bit data by casting to float4 (one float4 holds 8 bf16 values or 4 fp32 values).

__global__ void add_kernel_bf16_unroll(
   const float4* __restrict__ d_input1,
   const float4* __restrict__ d_input2,
   float4* __restrict__ d_output,
   const size_t N
) {
   auto block = cg::this_thread_block();
   int32_t gid = block.group_index().x;
   int32_t tid = block.thread_index().x;
   int32_t out_idx = (gid*block.size() + tid);

   float4 input1 = d_input1[out_idx];
   float4 input2 = d_input2[out_idx];
   float4 output;

   __nv_bfloat162* input1_bf16 = reinterpret_cast<__nv_bfloat162*>(&input1);
   __nv_bfloat162* input2_bf16 = reinterpret_cast<__nv_bfloat162*>(&input2);
   __nv_bfloat162* output_bf16 = reinterpret_cast<__nv_bfloat162*>(&output);

   // each of these does 2 elements at a time
   output_bf16[0] = input1_bf16[0] + input2_bf16[0];
   output_bf16[1] = input1_bf16[1] + input2_bf16[1];
   output_bf16[2] = input1_bf16[2] + input2_bf16[2];
   output_bf16[3] = input1_bf16[3] + input2_bf16[3];

   // write it all at once
   d_output[out_idx] = output;
}


void add_gpu_bf16_unroll(float4* d_input1, float4* d_input2, float4* d_output, size_t N){
   dim3 threads = {256, 1, 1};
   dim3 blocks = {N/(256*8), 1, 1};  // extra /8 because each thread makes 8 outputs
   add_kernel_bf16_unroll<<<blocks, threads>>>(d_input1, d_input2, d_output, N);
}
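The unrolled bf16 kernel relies on one 16-byte float4 load carrying 8 bf16 values, which it then views as 4 __nv_bfloat162 pairs. The host-side sketch below shows that layout with raw integers. Chunk128 is a hypothetical stand-in for float4, and the sketch assumes little-endian byte order (as on x86 and NVIDIA GPUs), where element 2k sits in the low 16 bits of word k, i.e. the .x lane of the pair.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for float4: 16 bytes, 16-byte aligned
struct alignas(16) Chunk128 {
    uint32_t words[4];  // each word plays the role of one __nv_bfloat162
};
static_assert(sizeof(Chunk128) == 16, "one 128-bit chunk");

// Pack 8 consecutive raw bf16 values exactly as they sit in memory
Chunk128 pack_bf16x8(const uint16_t (&vals)[8]) {
    Chunk128 c;
    std::memcpy(&c, vals, sizeof(c));  // same bytes, different view
    return c;
}

// Read bf16 element i (0..7) from word i/2; on a little-endian machine the
// even element is the low half and the odd element is the high half
uint16_t bf16_lane(const Chunk128& c, int i) {
    uint32_t w = c.words[i / 2];
    return static_cast<uint16_t>((i % 2 == 0) ? (w & 0xffffu) : (w >> 16));
}
```

So output_bf16[k] = input1_bf16[k] + input2_bf16[k] for k = 0..3 covers all 8 elements of the chunk, two per instruction.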

__global__ void add_kernel_fp32_unroll(
   const float4* __restrict__ d_input1,
   const float4* __restrict__ d_input2,
   float4* __restrict__ d_output,
   const size_t N
) {
   auto block = cg::this_thread_block();
   int32_t gid = block.group_index().x;
   int32_t tid = block.thread_index().x;
   int32_t out_idx = 2*(gid*block.size() + tid);

   float4 input1[2] = {d_input1[out_idx], d_input1[out_idx+1]};
   float4 input2[2] = {d_input2[out_idx], d_input2[out_idx+1]};
   float4 output[2];

   output[0].w = input1[0].w + input2[0].w;
   output[0].x = input1[0].x + input2[0].x;
   output[0].y = input1[0].y + input2[0].y;
   output[0].z = input1[0].z + input2[0].z;
   output[1].w = input1[1].w + input2[1].w;
   output[1].x = input1[1].x + input2[1].x;
   output[1].y = input1[1].y + input2[1].y;
   output[1].z = input1[1].z + input2[1].z;

   d_output[out_idx] = output[0];
   d_output[out_idx+1] = output[1];
}

void add_gpu_fp32_unroll(float4* d_input1, float4* d_input2, float4* d_output, size_t N){
   dim3 threads = {256, 1, 1};
   dim3 blocks = {N/(256*8), 1, 1};  // extra /8 because each thread makes 8 outputs
   add_kernel_fp32_unroll<<<blocks, threads>>>(d_input1, d_input2, d_output, N);
}
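Both unrolled launchers use the same grid, N/(256*8) blocks (8192 for N = 2^24), and each thread produces 8 outputs. What differs is the data per thread: a bf16 thread issues one float4 access per array, while an fp32 thread issues two. A small sketch of that arithmetic, with hypothetical helper names.

```cpp
#include <cassert>
#include <cstddef>

constexpr size_t kFloat4Bytes = 16;  // one 128-bit vector access

// 128-bit accesses per array per thread needed for `outputs` elements
constexpr size_t float4_accesses(size_t outputs, size_t bytes_per_elem) {
    return outputs * bytes_per_elem / kFloat4Bytes;
}

// Total bytes one thread moves: two input arrays read, one output written
constexpr size_t bytes_per_thread(size_t outputs, size_t bytes_per_elem) {
    return 3 * float4_accesses(outputs, bytes_per_elem) * kFloat4Bytes;
}

static_assert(float4_accesses(8, 2) == 1, "bf16: one float4 per array");
static_assert(float4_accesses(8, 4) == 2, "fp32: two float4s per array");
static_assert(bytes_per_thread(8, 2) == 48, "bf16 thread moves 48 bytes");
static_assert(bytes_per_thread(8, 4) == 96, "fp32 thread moves 96 bytes");
```

Multiplying by the 2,097,152 threads in the grid recovers 3 × N × sizeof(element) for each kernel.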

The results on NVIDIA A10 are:

  • run_add(): type=fp32, N: 16777216, 0.134 GB, latency: 0.00185661 sec, 72 GB/s,
  • run_add(): type=bf16, N: 16777216, 0.067 GB, latency: 0.00259557 sec, 26 GB/s,

I also ran these kernels on A100:

  • run_add(): type=fp32, N: 16777216, 0.134 GB, latency: 0.000708723 sec, 189 GB/s,
  • run_add(): type=bf16, N: 16777216, 0.067 GB, latency: 0.00132306 sec, 51 GB/s,

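Uh-oh. Now bf16 is a lot slower than fp32, even in latency: 2.60 ms vs. 1.86 ms on A10 and 1.32 ms vs. 0.71 ms on A100, despite moving half as many bytes. What might I be doing wrong?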

how about PyTorch?

Finally, here’s a PyTorch implementation. I would love for my custom kernel to match the bandwidth of the PyTorch version, so that I can get similarly high bandwidth in the application I am developing.

output = input1 + input2

PyTorch gets much better results on A10, and switching from fp32 to bf16 gives a ~1.8x speedup (0.48 ms to 0.27 ms):

  • dtype: torch.float32, N: 16777216, 0.134 GB, latency: 0.00048 sec, 277 GB/s
  • dtype: torch.bfloat16, N: 16777216, 0.067 GB, latency: 0.00027 sec, 244 GB/s

Also, on the A10, I find that a DeviceToDevice cudaMemcpy of an array this size runs at 243 GB/s, roughly the same as PyTorch’s add() kernels.

Final thoughts & questions

  • When I examine the PTX for my bf16 CUDA implementations, "bf16" (or even the number 16) rarely appears in the instructions. Is it possible that the data is being cast to 32-bit somewhere in the bf16 kernel?
  • How is PyTorch so much faster than my kernels? Is there something I can do in a self-contained CUDA kernel to get closer to PyTorch’s memory bandwidth?

If I wanted to compare compute throughput (FP32 add throughput vs. bf16 add throughput), I would want to remove memory considerations from my test code, or at least minimize them.

At the moment I don’t have convenient access to an A100 or A10, but I can easily get to an L4 GPU (cc 8.9), and I note that it has the same arithmetic throughput for bf16 and fp32. Therefore, to perform the same number of element additions, I would want to issue half as many bf16x2 operations as FP32 operations, and see whether the two take approximately the same time. According to my test case, they do:

# cat t223.cu
#include <cuda_bf16.h>

const int al = 8;
template <typename T>
__global__ void k(T *d, int l){

  T a[al];
  for (int i = 0; i < al; i++) a[i] = d[i];
  T r;
  for (int i = 0; i < l; i++)
    for (int j = 0; j < al; j++)
      r += a[j];
  d[threadIdx.x] = r;
}
using ft1 = float;
int main(){
  ft1 *d;
  const int bs = 256;
  const int nb = 1024;
  cudaMalloc(&d, bs * sizeof(ft1));
  const int l=1000;
  k<<<nb,bs>>>(d, l);
  cudaDeviceSynchronize();
  k<<<nb/2,bs>>>(reinterpret_cast<__nv_bfloat162 *>(d), l);
  cudaDeviceSynchronize();
}
# nvcc -o t223 t223.cu -arch=sm_89
# nsys profile --stats=true ./t223
Generating '/tmp/nsys-report-d985.qdstrm'
[1/8] [========================100%] report48.nsys-rep
[2/8] [========================100%] report48.sqlite
[3/8] Executing 'nvtx_sum' stats report
SKIPPED: /root/bobc/report48.sqlite does not contain NV Tools Extension (NVTX) data.
[4/8] Executing 'osrt_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)     Med (ns)    Min (ns)   Max (ns)    StdDev (ns)        Name
 --------  ---------------  ---------  ------------  -----------  --------  -----------  ------------  --------------
     51.8      213,610,012        476     448,760.5     14,893.0     1,033   78,773,099   4,104,860.2  ioctl
     41.5      170,863,015         11  15,533,001.4  3,030,461.0     4,832  100,130,001  29,513,044.4  poll
      5.7       23,639,646         29     815,160.2      5,600.0     2,114   23,450,998   4,353,491.9  fopen
      0.5        2,039,590         27      75,540.4     13,036.0    10,947    1,252,944     236,605.2  mmap64
      0.2          789,453         44      17,942.1     16,729.5     7,052       31,456       4,501.9  open64
      0.1          402,621          9      44,735.7     41,305.0    28,960       74,022      12,741.7  sem_timedwait
      0.1          270,289          2     135,144.5    135,144.5   113,872      156,417      30,083.9  pthread_create
      0.0          162,570         14      11,612.1      5,156.0     2,484       70,706      17,556.5  mmap
      0.0           84,152         48       1,753.2         67.0        58       80,782      11,649.5  fgets
      0.0           72,314         23       3,144.1      3,253.0     1,625        5,540         911.7  fclose
      0.0           53,886         51       1,056.6      1,033.0       730        1,740         183.6  fcntl
      0.0           40,425          6       6,737.5      6,219.5     2,922       10,447       2,908.5  open
      0.0           32,301         13       2,484.7      2,047.0     1,271        6,013       1,297.8  read
      0.0           30,735          5       6,147.0      5,975.0     4,644        7,361       1,081.0  munmap
      0.0           28,165         10       2,816.5      2,739.0     1,400        4,747         877.8  write
      0.0           17,115          2       8,557.5      8,557.5     5,540       11,575       4,267.4  socket
      0.0           16,263          1      16,263.0     16,263.0    16,263       16,263           0.0  fread
      0.0           14,185          1      14,185.0     14,185.0    14,185       14,185           0.0  connect
      0.0            9,193          1       9,193.0      9,193.0     9,193        9,193           0.0  pipe2
      0.0            6,395          7         913.6        911.0       844          987          56.0  dup
      0.0            2,345          1       2,345.0      2,345.0     2,345        2,345           0.0  bind
      0.0            1,417          1       1,417.0      1,417.0     1,417        1,417           0.0  listen

[5/8] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)       Med (ns)      Min (ns)     Max (ns)    StdDev (ns)           Name
 --------  ---------------  ---------  -------------  -------------  -----------  -----------  -----------  ----------------------
     99.7      188,588,879          1  188,588,879.0  188,588,879.0  188,588,879  188,588,879          0.0  cudaMalloc
      0.2          311,180          2      155,590.0      155,590.0      147,560      163,620     11,356.1  cudaDeviceSynchronize
      0.1          197,077          2       98,538.5       98,538.5       28,755      168,322     98,688.8  cudaLaunchKernel
      0.0            1,365          1        1,365.0        1,365.0        1,365        1,365          0.0  cuModuleGetLoadingMode

[6/8] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                Name
 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ---------------------------------
     52.9          164,352          1  164,352.0  164,352.0   164,352   164,352          0.0  void k<float>(T1 *, int)
     47.1          146,048          1  146,048.0  146,048.0   146,048   146,048          0.0  void k<__nv_bfloat162>(T1 *, int)

[7/8] Executing 'cuda_gpu_mem_time_sum' stats report
SKIPPED: /root/bobc/report48.sqlite does not contain GPU memory data.
[8/8] Executing 'cuda_gpu_mem_size_sum' stats report
SKIPPED: /root/bobc/report48.sqlite does not contain GPU memory data.
Generated:
    /root/bobc/report48.nsys-rep
    /root/bobc/report48.sqlite
#

CUDA 12.2
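The "half as many bf16x2 operations" arithmetic can be checked against the nsys kernel times. Each thread of k performs l × al additions; the fp32 launch uses nb blocks, the bf16x2 launch nb/2 blocks, and each bf16x2 add covers 2 elements. A minimal sketch using the constants from t223.cu; the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Parameters copied from t223.cu
constexpr uint64_t kBlockSize = 256;   // bs
constexpr uint64_t kNumBlocks = 1024;  // nb
constexpr uint64_t kLoops = 1000;      // l
constexpr uint64_t kUnroll = 8;        // al

// Element additions performed by one launch of k<T>
constexpr uint64_t element_adds(uint64_t blocks, uint64_t lanes_per_op) {
    return blocks * kBlockSize * kLoops * kUnroll * lanes_per_op;
}

// Additions per second, in units of 1e12 (tera-adds/s), from a time in ns
double tera_adds_per_sec(uint64_t adds, double kernel_ns) {
    return static_cast<double>(adds) / kernel_ns / 1e3;
}
```

Both launches perform 2,097,152,000 element additions. With the measured 164,352 ns (fp32) and 146,048 ns (bf16x2), that is roughly 12.8 vs. 14.4 tera-adds/s, i.e. the bf16x2 kernel takes about 11% less time for the same work.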

For reference, here is the SASS. I haven’t studied it carefully, but there don’t appear to be obvious optimizations done by the compiler that would skew this:

# cuobjdump -sass ./t223

Fatbin elf code:
================
arch = sm_89
code version = [1,7]
host = linux
compile_size = 64bit

        code for sm_89

Fatbin elf code:
================
arch = sm_89
code version = [1,7]
host = linux
compile_size = 64bit

        code for sm_89
                Function : _Z1kI14__nv_bfloat162EvPT_i
        .headerflags    @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM89 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM89)"
        /*0000*/                   IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ;    /* 0x00000a00ff017624 */
                                                                              /* 0x000fc400078e00ff */
        /*0010*/                   IMAD.MOV.U32 R8, RZ, RZ, c[0x0][0x168] ;   /* 0x00005a00ff087624 */
                                                                              /* 0x000fe200078e00ff */
        /*0020*/                   ULDC.64 UR4, c[0x0][0x118] ;               /* 0x0000460000047ab9 */
                                                                              /* 0x000fc80000000a00 */
        /*0030*/                   ISETP.GE.AND P0, PT, R8, 0x1, PT ;         /* 0x000000010800780c */
                                                                              /* 0x000fda0003f06270 */
        /*0040*/              @!P0 BRA 0x440 ;                                /* 0x000003f000008947 */
                                                                              /* 0x000fea0003800000 */
        /*0050*/                   IADD3 R12, R8, -0x1, RZ ;                  /* 0xffffffff080c7810 */
                                                                              /* 0x000fe20007ffe0ff */
        /*0060*/                   IMAD.MOV.U32 R10, RZ, RZ, c[0x0][0x160] ;  /* 0x00005800ff0a7624 */
                                                                              /* 0x000fe400078e00ff */
        /*0070*/                   IMAD.MOV.U32 R11, RZ, RZ, c[0x0][0x164] ;  /* 0x00005900ff0b7624 */
                                                                              /* 0x000fe200078e00ff */
        /*0080*/                   ISETP.GE.U32.AND P1, PT, R12, 0x3, PT ;    /* 0x000000030c00780c */
                                                                              /* 0x000fe40003f26070 */
        /*0090*/                   LOP3.LUT R8, R8, 0x3, RZ, 0xc0, !PT ;      /* 0x0000000308087812 */
                                                                              /* 0x000fe400078ec0ff */
        /*00a0*/                   LDG.E R0, [R10.64] ;                       /* 0x000000040a007981 */
                                                                              /* 0x000168000c1e1900 */
        /*00b0*/                   LDG.E R2, [R10.64+0x4] ;                   /* 0x000004040a027981 */
                                                                              /* 0x000162000c1e1900 */
        /*00c0*/                   ISETP.NE.AND P0, PT, R8, RZ, PT ;          /* 0x000000ff0800720c */
                                                                              /* 0x000fc60003f05270 */
        /*00d0*/                   LDG.E R3, [R10.64+0x8] ;                   /* 0x000008040a037981 */
                                                                              /* 0x000168000c1e1900 */
        /*00e0*/                   LDG.E R4, [R10.64+0xc] ;                   /* 0x00000c040a047981 */
                                                                              /* 0x000168000c1e1900 */
        /*00f0*/                   LDG.E R5, [R10.64+0x10] ;                  /* 0x000010040a057981 */
                                                                              /* 0x000168000c1e1900 */
        /*0100*/                   LDG.E R6, [R10.64+0x14] ;                  /* 0x000014040a067981 */
                                                                              /* 0x000168000c1e1900 */
        /*0110*/                   LDG.E R7, [R10.64+0x18] ;                  /* 0x000018040a077981 */
                                                                              /* 0x000168000c1e1900 */
        /*0120*/                   LDG.E R9, [R10.64+0x1c] ;                  /* 0x00001c040a097981 */
                                                                              /* 0x000162000c1e1900 */
        /*0130*/              @!P1 BRA 0x380 ;                                /* 0x0000024000009947 */
                                                                              /* 0x000fea0003800000 */
        /*0140*/                   IADD3 R10, -R8, c[0x0][0x168], RZ ;        /* 0x00005a00080a7a10 */
                                                                              /* 0x001fc60007ffe1ff */
        /*0150*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R0 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x020fe20000200000 */
        /*0160*/                   IADD3 R10, R10, -0x4, RZ ;                 /* 0xfffffffc0a0a7810 */
                                                                              /* 0x000fc60007ffe0ff */
        /*0170*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R2 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fe20000200002 */
        /*0180*/                   ISETP.NE.AND P1, PT, R10, RZ, PT ;         /* 0x000000ff0a00720c */
                                                                              /* 0x000fc60003f25270 */
        /*0190*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R3 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200003 */
        /*01a0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R4 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200004 */
        /*01b0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R5 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200005 */
        /*01c0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R6 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200006 */
        /*01d0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R7 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200007 */
        /*01e0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R9 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200009 */
        /*01f0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R0 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200000 */
        /*0200*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R2 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200002 */
        /*0210*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R3 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200003 */
        /*0220*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R4 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200004 */
        /*0230*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R5 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200005 */
        /*0240*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R6 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200006 */
        /*0250*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R7 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200007 */
        /*0260*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R9 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200009 */
        /*0270*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R0 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200000 */
        /*0280*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R2 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200002 */
        /*0290*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R3 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200003 */
        /*02a0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R4 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200004 */
        /*02b0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R5 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200005 */
        /*02c0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R6 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200006 */
        /*02d0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R7 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200007 */
        /*02e0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R9 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200009 */
        /*02f0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R0 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200000 */
        /*0300*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R2 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200002 */
        /*0310*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R3 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200003 */
        /*0320*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R4 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200004 */
        /*0330*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R5 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200005 */
        /*0340*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R6 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200006 */
        /*0350*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R7 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200007 */
        /*0360*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R9 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fe20000200009 */
        /*0370*/               @P1 BRA 0x150 ;                                /* 0xfffffdd000001947 */
                                                                              /* 0x000fea000383ffff */
        /*0380*/              @!P0 BRA 0x440 ;                                /* 0x000000b000008947 */
                                                                              /* 0x000fea0003800000 */
        /*0390*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R0 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x021fe20000200000 */
        /*03a0*/                   IADD3 R8, R8, -0x1, RZ ;                   /* 0xffffffff08087810 */
                                                                              /* 0x000fc60007ffe0ff */
        /*03b0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R2 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fe20000200002 */
        /*03c0*/                   ISETP.NE.AND P0, PT, R8, RZ, PT ;          /* 0x000000ff0800720c */
                                                                              /* 0x000fc60003f05270 */
        /*03d0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R3 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200003 */
        /*03e0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R4 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200004 */
        /*03f0*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R5 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200005 */
        /*0400*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R6 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200006 */
        /*0410*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R7 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fc80000200007 */
        /*0420*/                   HFMA2.BF16_V2 R11, R11, 1, 1, R9 ;         /* 0x3f803f800b0b7831 */
                                                                              /* 0x000fe20000200009 */
        /*0430*/               @P0 BRA 0x390 ;                                /* 0xffffff5000000947 */
                                                                              /* 0x000fea000383ffff */
        /*0440*/                   S2R R2, SR_TID.X ;                         /* 0x0000000000027919 */
                                                                              /* 0x020e620000002100 */
        /*0450*/                   IMAD.MOV.U32 R3, RZ, RZ, 0x4 ;             /* 0x00000004ff037424 */
                                                                              /* 0x000fc800078e00ff */
        /*0460*/                   IMAD.WIDE.U32 R2, R2, R3, c[0x0][0x160] ;  /* 0x0000580002027625 */
                                                                              /* 0x002fca00078e0003 */
        /*0470*/                   STG.E [R2.64], R11 ;                       /* 0x0000000b02007986 */
                                                                              /* 0x000fe2000c101904 */
        /*0480*/                   EXIT ;                                     /* 0x000000000000794d */
                                                                              /* 0x000fea0003800000 */
        /*0490*/                   BRA 0x490;                                 /* 0xfffffff000007947 */
                                                                              /* 0x000fc0000383ffff */
        /*04a0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*04b0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*04c0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*04d0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*04e0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*04f0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0500*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0510*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0520*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0530*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0540*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0550*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0560*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0570*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
                ..........


                Function : _Z1kIfEvPT_i
        .headerflags    @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM89 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM89)"
        /*0000*/                   MOV R1, c[0x0][0x28] ;                     /* 0x00000a0000017a02 */
                                                                              /* 0x000fc40000000f00 */
        /*0010*/                   MOV R8, c[0x0][0x168] ;                    /* 0x00005a0000087a02 */
                                                                              /* 0x000fe20000000f00 */
        /*0020*/                   ULDC.64 UR4, c[0x0][0x118] ;               /* 0x0000460000047ab9 */
                                                                              /* 0x000fc60000000a00 */
        /*0030*/                   ISETP.GE.AND P0, PT, R8, 0x1, PT ;         /* 0x000000010800780c */
                                                                              /* 0x000fda0003f06270 */
        /*0040*/              @!P0 BRA 0x440 ;                                /* 0x000003f000008947 */
                                                                              /* 0x000fea0003800000 */
        /*0050*/                   IADD3 R12, R8, -0x1, RZ ;                  /* 0xffffffff080c7810 */
                                                                              /* 0x000fe40007ffe0ff */
        /*0060*/                   MOV R10, c[0x0][0x160] ;                   /* 0x00005800000a7a02 */
                                                                              /* 0x000fe40000000f00 */
        /*0070*/                   ISETP.GE.U32.AND P1, PT, R12, 0x3, PT ;    /* 0x000000030c00780c */
                                                                              /* 0x000fe40003f26070 */
        /*0080*/                   MOV R11, c[0x0][0x164] ;                   /* 0x00005900000b7a02 */
                                                                              /* 0x000fe40000000f00 */
        /*0090*/                   LOP3.LUT R8, R8, 0x3, RZ, 0xc0, !PT ;      /* 0x0000000308087812 */
                                                                              /* 0x000fc600078ec0ff */
        /*00a0*/                   LDG.E R0, [R10.64] ;                       /* 0x000000040a007981 */
                                                                              /* 0x000162000c1e1900 */
        /*00b0*/                   ISETP.NE.AND P0, PT, R8, RZ, PT ;          /* 0x000000ff0800720c */
                                                                              /* 0x000fc60003f05270 */
        /*00c0*/                   LDG.E R2, [R10.64+0x4] ;                   /* 0x000004040a027981 */
                                                                              /* 0x000168000c1e1900 */
        /*00d0*/                   LDG.E R3, [R10.64+0x8] ;                   /* 0x000008040a037981 */
                                                                              /* 0x000168000c1e1900 */
        /*00e0*/                   LDG.E R4, [R10.64+0xc] ;                   /* 0x00000c040a047981 */
                                                                              /* 0x000168000c1e1900 */
        /*00f0*/                   LDG.E R5, [R10.64+0x10] ;                  /* 0x000010040a057981 */
                                                                              /* 0x000168000c1e1900 */
        /*0100*/                   LDG.E R6, [R10.64+0x14] ;                  /* 0x000014040a067981 */
                                                                              /* 0x000168000c1e1900 */
        /*0110*/                   LDG.E R7, [R10.64+0x18] ;                  /* 0x000018040a077981 */
                                                                              /* 0x000168000c1e1900 */
        /*0120*/                   LDG.E R9, [R10.64+0x1c] ;                  /* 0x00001c040a097981 */
                                                                              /* 0x000162000c1e1900 */
        /*0130*/              @!P1 BRA 0x380 ;                                /* 0x0000024000009947 */
                                                                              /* 0x000fea0003800000 */
        /*0140*/                   IADD3 R12, -R8, c[0x0][0x168], RZ ;        /* 0x00005a00080c7a10 */
                                                                              /* 0x000fc60007ffe1ff */
        /*0150*/                   FADD R11, R0, R11 ;                        /* 0x0000000b000b7221 */
                                                                              /* 0x021fe20000000000 */
        /*0160*/                   IADD3 R12, R12, -0x4, RZ ;                 /* 0xfffffffc0c0c7810 */
                                                                              /* 0x000fc60007ffe0ff */
        /*0170*/                   FADD R10, R2, R11 ;                        /* 0x0000000b020a7221 */
                                                                              /* 0x000fe20000000000 */
        /*0180*/                   ISETP.NE.AND P1, PT, R12, RZ, PT ;         /* 0x000000ff0c00720c */
                                                                              /* 0x000fc60003f25270 */
        /*0190*/                   FADD R11, R3, R10 ;                        /* 0x0000000a030b7221 */
                                                                              /* 0x000fc80000000000 */
        /*01a0*/                   FADD R10, R4, R11 ;                        /* 0x0000000b040a7221 */
                                                                              /* 0x000fc80000000000 */
        /*01b0*/                   FADD R11, R5, R10 ;                        /* 0x0000000a050b7221 */
                                                                              /* 0x000fc80000000000 */
        /*01c0*/                   FADD R10, R6, R11 ;                        /* 0x0000000b060a7221 */
                                                                              /* 0x000fc80000000000 */
        /*01d0*/                   FADD R10, R7, R10 ;                        /* 0x0000000a070a7221 */
                                                                              /* 0x000fc80000000000 */
        /*01e0*/                   FADD R11, R9, R10 ;                        /* 0x0000000a090b7221 */
                                                                              /* 0x000fc80000000000 */
        /*01f0*/                   FADD R11, R0, R11 ;                        /* 0x0000000b000b7221 */
                                                                              /* 0x000fc80000000000 */
        /*0200*/                   FADD R10, R2, R11 ;                        /* 0x0000000b020a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0210*/                   FADD R11, R3, R10 ;                        /* 0x0000000a030b7221 */
                                                                              /* 0x000fc80000000000 */
        /*0220*/                   FADD R10, R4, R11 ;                        /* 0x0000000b040a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0230*/                   FADD R11, R5, R10 ;                        /* 0x0000000a050b7221 */
                                                                              /* 0x000fc80000000000 */
        /*0240*/                   FADD R10, R6, R11 ;                        /* 0x0000000b060a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0250*/                   FADD R10, R7, R10 ;                        /* 0x0000000a070a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0260*/                   FADD R11, R9, R10 ;                        /* 0x0000000a090b7221 */
                                                                              /* 0x000fc80000000000 */
        /*0270*/                   FADD R11, R0, R11 ;                        /* 0x0000000b000b7221 */
                                                                              /* 0x000fc80000000000 */
        /*0280*/                   FADD R10, R2, R11 ;                        /* 0x0000000b020a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0290*/                   FADD R11, R3, R10 ;                        /* 0x0000000a030b7221 */
                                                                              /* 0x000fc80000000000 */
        /*02a0*/                   FADD R10, R4, R11 ;                        /* 0x0000000b040a7221 */
                                                                              /* 0x000fc80000000000 */
        /*02b0*/                   FADD R11, R5, R10 ;                        /* 0x0000000a050b7221 */
                                                                              /* 0x000fc80000000000 */
        /*02c0*/                   FADD R10, R6, R11 ;                        /* 0x0000000b060a7221 */
                                                                              /* 0x000fc80000000000 */
        /*02d0*/                   FADD R10, R7, R10 ;                        /* 0x0000000a070a7221 */
                                                                              /* 0x000fc80000000000 */
        /*02e0*/                   FADD R11, R9, R10 ;                        /* 0x0000000a090b7221 */
                                                                              /* 0x000fc80000000000 */
        /*02f0*/                   FADD R11, R0, R11 ;                        /* 0x0000000b000b7221 */
                                                                              /* 0x000fc80000000000 */
        /*0300*/                   FADD R10, R2, R11 ;                        /* 0x0000000b020a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0310*/                   FADD R11, R3, R10 ;                        /* 0x0000000a030b7221 */
                                                                              /* 0x000fc80000000000 */
        /*0320*/                   FADD R10, R4, R11 ;                        /* 0x0000000b040a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0330*/                   FADD R11, R5, R10 ;                        /* 0x0000000a050b7221 */
                                                                              /* 0x000fc80000000000 */
        /*0340*/                   FADD R10, R6, R11 ;                        /* 0x0000000b060a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0350*/                   FADD R10, R7, R10 ;                        /* 0x0000000a070a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0360*/                   FADD R11, R9, R10 ;                        /* 0x0000000a090b7221 */
                                                                              /* 0x000fe20000000000 */
        /*0370*/               @P1 BRA 0x150 ;                                /* 0xfffffdd000001947 */
                                                                              /* 0x000fea000383ffff */
        /*0380*/              @!P0 BRA 0x440 ;                                /* 0x000000b000008947 */
                                                                              /* 0x000fea0003800000 */
        /*0390*/                   FADD R11, R0, R11 ;                        /* 0x0000000b000b7221 */
                                                                              /* 0x021fe20000000000 */
        /*03a0*/                   IADD3 R8, R8, -0x1, RZ ;                   /* 0xffffffff08087810 */
                                                                              /* 0x000fc60007ffe0ff */
        /*03b0*/                   FADD R10, R2, R11 ;                        /* 0x0000000b020a7221 */
                                                                              /* 0x000fe20000000000 */
        /*03c0*/                   ISETP.NE.AND P0, PT, R8, RZ, PT ;          /* 0x000000ff0800720c */
                                                                              /* 0x000fc60003f05270 */
        /*03d0*/                   FADD R11, R3, R10 ;                        /* 0x0000000a030b7221 */
                                                                              /* 0x000fc80000000000 */
        /*03e0*/                   FADD R10, R4, R11 ;                        /* 0x0000000b040a7221 */
                                                                              /* 0x000fc80000000000 */
        /*03f0*/                   FADD R11, R5, R10 ;                        /* 0x0000000a050b7221 */
                                                                              /* 0x000fc80000000000 */
        /*0400*/                   FADD R10, R6, R11 ;                        /* 0x0000000b060a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0410*/                   FADD R10, R7, R10 ;                        /* 0x0000000a070a7221 */
                                                                              /* 0x000fc80000000000 */
        /*0420*/                   FADD R11, R9, R10 ;                        /* 0x0000000a090b7221 */
                                                                              /* 0x000fe20000000000 */
        /*0430*/               @P0 BRA 0x390 ;                                /* 0xffffff5000000947 */
                                                                              /* 0x000fea000383ffff */
        /*0440*/                   S2R R2, SR_TID.X ;                         /* 0x0000000000027919 */
                                                                              /* 0x020e620000002100 */
        /*0450*/                   MOV R3, 0x4 ;                              /* 0x0000000400037802 */
                                                                              /* 0x000fca0000000f00 */
        /*0460*/                   IMAD.WIDE.U32 R2, R2, R3, c[0x0][0x160] ;  /* 0x0000580002027625 */
                                                                              /* 0x002fca00078e0003 */
        /*0470*/                   STG.E [R2.64], R11 ;                       /* 0x0000000b02007986 */
                                                                              /* 0x000fe2000c101904 */
        /*0480*/                   EXIT ;                                     /* 0x000000000000794d */
                                                                              /* 0x000fea0003800000 */
        /*0490*/                   BRA 0x490;                                 /* 0xfffffff000007947 */
                                                                              /* 0x000fc0000383ffff */
        /*04a0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*04b0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*04c0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*04d0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*04e0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*04f0*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0500*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0510*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0520*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0530*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0540*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0550*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0560*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
        /*0570*/                   NOP;                                       /* 0x0000000000007918 */
                                                                              /* 0x000fc00000000000 */
                ..........



Fatbin ptx code:
================
arch = sm_89
code version = [8,2]
host = linux
compile_size = 64bit
compressed

If we stay with your original formulation, which adds two arrays in memory, I would expect this to be an entirely memory-bound problem. In that case, I would expect the break-even point to come when the number of FP32 adds equals the number of bf16x2 adds (not half, as previously). A float and a __nv_bfloat162 are both 4 bytes, so the two kernels then move the same amount of data. Once again, I see approximate equality in that case:

# cat t224.cu
#include <cuda_bf16.h>

template <typename T>
__global__ void k(const T * __restrict__ a, const T * __restrict__ b, T * __restrict__ c, int l){

  int idx = threadIdx.x+blockDim.x*blockIdx.x;
  if (idx < l)
  c[idx] = a[idx]+b[idx];
}

using ft1 = float;
const int sz = 1048576*32;
int main(){
  ft1 *a, *b, *c;
  const int bs = 256;
  const int nb = (sz+bs-1)/bs;
  cudaMalloc(&a, sz * sizeof(ft1));
  cudaMalloc(&b, sz * sizeof(ft1));
  cudaMalloc(&c, sz * sizeof(ft1));
  k<<<nb,bs>>>(a, b, c, sz);
  cudaDeviceSynchronize();
  k<<<nb,bs>>>(reinterpret_cast<__nv_bfloat162 *>(a), reinterpret_cast<__nv_bfloat162 *>(b), reinterpret_cast<__nv_bfloat162 *>(c), sz);
  cudaDeviceSynchronize();
}
# nvcc -o t224 t224.cu -arch=sm_89
# nsys profile --stats=true ./t224
Generating '/tmp/nsys-report-8d17.qdstrm'
[1/8] [========================100%] report50.nsys-rep
[2/8] [========================100%] report50.sqlite
[3/8] Executing 'nvtx_sum' stats report
SKIPPED: /root/bobc/report50.sqlite does not contain NV Tools Extension (NVTX) data.
[4/8] Executing 'osrt_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)     Med (ns)    Min (ns)   Max (ns)    StdDev (ns)        Name
 --------  ---------------  ---------  ------------  -----------  --------  -----------  ------------  --------------
     57.1      233,819,715        484     483,098.6     15,671.0     1,040   78,880,360   4,659,663.1  ioctl
     41.9      171,538,278         11  15,594,388.9  3,023,216.0     5,090  100,133,742  29,536,232.8  poll
      0.5        1,994,332         27      73,864.1     12,252.0    10,272    1,225,290     231,472.8  mmap64
      0.2          803,951         44      18,271.6     17,396.5     5,687       36,437       5,027.3  open64
      0.1          420,760          9      46,751.1     41,900.0    35,794       79,454      13,318.9  sem_timedwait
      0.1          281,937          2     140,968.5    140,968.5   119,296      162,641      30,649.5  pthread_create
      0.0          193,250         29       6,663.8      5,789.0     2,095       19,080       3,868.0  fopen
      0.0          177,396         17      10,435.1      4,874.0     2,633       71,145      16,163.0  mmap
      0.0           84,441         48       1,759.2         66.5        58       81,104      11,696.1  fgets
      0.0           82,168          8      10,271.0      6,322.5     3,720       40,943      12,452.9  munmap
      0.0           72,407         23       3,148.1      3,228.0     1,662        4,850         831.4  fclose
      0.0           55,639         51       1,091.0      1,045.0       728        1,961         210.5  fcntl
      0.0           42,516          6       7,086.0      6,955.0     2,776       11,437       3,179.5  open
      0.0           31,910         13       2,454.6      1,890.0     1,262        5,194       1,153.8  read
      0.0           29,565         10       2,956.5      2,797.0     1,507        5,392       1,057.6  write
      0.0           18,435          2       9,217.5      9,217.5     4,968       13,467       6,009.7  socket
      0.0           17,510          1      17,510.0     17,510.0    17,510       17,510           0.0  fread
      0.0           16,528          1      16,528.0     16,528.0    16,528       16,528           0.0  connect
      0.0            9,361          1       9,361.0      9,361.0     9,361        9,361           0.0  pipe2
      0.0            6,536          7         933.7        912.0       841        1,048          81.1  dup
      0.0            2,567          1       2,567.0      2,567.0     2,567        2,567           0.0  bind
      0.0            1,801          1       1,801.0      1,801.0     1,801        1,801           0.0  listen

[5/8] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)     Med (ns)    Min (ns)    Max (ns)     StdDev (ns)            Name
 --------  ---------------  ---------  ------------  -----------  ---------  -----------  -------------  ----------------------
     98.2      185,911,202          3  61,970,400.7    331,931.0    328,182  185,251,089  106,764,207.9  cudaMalloc
      1.7        3,261,579          2   1,630,789.5  1,630,789.5  1,550,148    1,711,431      114,044.3  cudaDeviceSynchronize
      0.1          223,756          2     111,878.0    111,878.0     36,570      187,186      106,501.6  cudaLaunchKernel
      0.0            2,437          1       2,437.0      2,437.0      2,437        2,437            0.0  cuModuleGetLoadingMode

[6/8] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                            Name
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ---------------------------------------------------------
     52.4        1,709,411          1  1,709,411.0  1,709,411.0  1,709,411  1,709,411          0.0  void k<__nv_bfloat162>(const T1 *, const T1 *, T1 *, int)
     47.6        1,550,915          1  1,550,915.0  1,550,915.0  1,550,915  1,550,915          0.0  void k<float>(const T1 *, const T1 *, T1 *, int)

[7/8] Executing 'cuda_gpu_mem_time_sum' stats report
SKIPPED: /root/bobc/report50.sqlite does not contain GPU memory data.
[8/8] Executing 'cuda_gpu_mem_size_sum' stats report
SKIPPED: /root/bobc/report50.sqlite does not contain GPU memory data.
Generated:
    /root/bobc/report50.nsys-rep
    /root/bobc/report50.sqlite
#

Those kernels achieve roughly 250 GB/s of bandwidth. That strikes me as reasonable for a GPU with an advertised peak theoretical bandwidth of 300 GB/s.

There are quite a few things missing from your post, in my view:

  • compile command
  • CUDA version
  • how you are timing things
  • exact runnable code, without having to add anything or change anything

You’ll note I’ve provided those things in my answer. Sorry, I don’t wish to wade through your GitHub site. It seems unnecessary to me: I believe I have demonstrated that concise and complete claims can be made within the confines of the forum here. Also, if you should take down that GitHub repo for any reason, your demonstrators disappear or become useless for future readers.

I would almost never recommend trying to understand things from PTX analysis. SASS, on the other hand, is quite reliable, albeit perhaps harder to work with. PTX goes through a compilation stage before it becomes runnable code (i.e., it gets converted to SASS). This compilation stage is not an “assembler” in the traditional sense of that term (e.g., from the 80s and 90s), but rather a highly optimizing compiler.

Thanks for the help, Robert! I will dig into your code.

I didn’t know how to concisely show everything without making a repo, but you’ve shown it’s possible!

See my replies inline:

Don’t compile with -G. This disables all optimizations and can lead to slow code. It is only meant for debugging.

For Robert’s code on A10, I get 300 GB/s. Great!

Thanks striker159 for the tip about -G!

In my code, after I remove -G, both the unrolled and not-unrolled implementations get 240 GB/s on A10! I will dig a bit deeper and see what subtle things are different between the 300 GB/s and 240 GB/s versions.

Also, I should mention that I tried initializing the input arrays a and b in Robert’s code with random numbers using curand, and I saw no slowdown.