Discrepancy between theoretical occupancy and achieved occupancy depending on ThreadsPerBlock

I am trying to understand the significance of ThreadsPerBlock. I have run trials on a Volta V100 GPU (80 SMs).
My kernel, shown below, is a simple one used to measure global memory bandwidth. It accepts an array of N (here 1024*1024) floats. Since the kernel does not consume many registers, I obtain a theoretical occupancy of 100% for block sizes 64, 128, 256 and 512.

__global__ void offset(float* a)
{
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  a[i] = a[i] + 1;
}

For the block sizes 64, 128, 256 and 512, the active blocks per SM are 32, 16, 8 and 4, respectively (the maximum is 32 blocks per SM on V100). In all these cases the total number of warps remains the same, as does the number of active warps per SM (here 64). However, the achieved occupancies are 13.51%, 33.68%, 87.16% and 85.13% for block sizes 64, 128, 256 and 512, respectively, and the corresponding bandwidths are 290.698, 573.394, 726.744 and 722.543 GB/s.
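(For reference, the theoretical figures above can also be queried at runtime with the CUDA occupancy API; a minimal sketch, reusing the same offset kernel:)

#include <cstdio>

__global__ void offset(float* a)
{
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  a[i] = a[i] + 1;
}

int main(){
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  for (int blockSize = 64; blockSize <= 512; blockSize *= 2){
    int blocksPerSM = 0;
    // maximum number of resident blocks of this kernel per SM for this block size
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocksPerSM, offset, blockSize, 0);
    float occ = 100.0f * blocksPerSM * blockSize / prop.maxThreadsPerMultiProcessor;
    printf("block size %4d: %2d blocks/SM, theoretical occupancy %.0f%%\n",
           blockSize, blocksPerSM, occ);
  }
}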

I believe there is no overhead associated with context switching between warps belonging to different thread blocks, yet I observe a different phenomenon. Could someone please explain the reason for this discrepancy? Thanks.

It would appear to be block-scheduling overhead, which would be consistent with the occupancy observation.

I don’t have a V100 at the moment, but here is a test case running on my L4 GPU. We first handle the entire data set using your kernel, with a grid sized to match the data set, for varying block sizes. I see a similar pattern, where the very small block size (64) gives noticeably worse execution time than the other three. Then I switch to a grid-stride loop methodology, where only one load of blocks is deposited on the GPU (all blocks in the grid fit on the GPU simultaneously) and does all the work, thereby reducing the impact of block scheduling. In this case we still use the same 4 block sizes, but the grid is sized to match the GPU capacity, not the data set. The execution times of all 4 variants are within 5% of each other, in this case:

# cat t275.cu
__global__ void offset(float* a)
{
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  a[i] = a[i] + 1;
}

__global__ void offset_i(float* a, const int ds)
{
  for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < ds; i+=gridDim.x*blockDim.x)
    a[i] = a[i] + 1;
}

int main(){

  const int ds = 1024*1024;
  float *a;
  cudaMalloc(&a, sizeof(a[0])*ds);
  cudaMemset(a, 0, sizeof(a[0])*ds);
  offset<<<ds/512, 512>>>(a); // warm-up
  offset_i<<<ds/512, 512>>>(a, ds); // warm-up
  cudaDeviceSynchronize();
  // data-set-sized grids: one thread per element, block sizes 64, 128, 256, 512
  for (int i = 64; i < 1024; i*=2){
    offset<<<ds/i, i>>>(a);
    cudaDeviceSynchronize();}
  // grid-stride kernel with GPU-capacity-sized grids: 58 SMs on this L4, 1536
  // resident threads per SM, so 58*i blocks of bs threads fill the GPU exactly once
  int bs = 512;
  for (int i = 3; i < 48; i*=2){
    offset_i<<<58*i, bs>>>(a, ds);
    bs /=2;
    cudaDeviceSynchronize();}
}
# nvcc -o t275 t275.cu -arch=sm_89
# nsys nvprof --print-gpu-trace ./t275

WARNING: t275 and any of its children processes will be profiled.

Generating '/tmp/nsys-report-4635.qdstrm'
[1/3] [========================100%] report8.nsys-rep
[2/3] [========================100%] report8.sqlite
[3/3] Executing 'cuda_gpu_trace' stats report

 Start (ns)   Duration (ns)  CorrId   GrdX   GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Bytes (MB)  Throughput (MBps)  SrcMemKd  DstMemKd     Device      Ctx  Strm           Name
 -----------  -------------  ------  ------  ----  ----  ----  ----  ----  -------  ------------  ------------  ----------  -----------------  --------  --------  -------------  ---  ----  ----------------------
 655,959,365          4,928     119                                                                                  4.194        851,116.556  Device              NVIDIA L4 (0)    1     7  [CUDA memset]
 656,105,957          6,720     120   2,048     1     1   512     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  offset(float *)
 656,138,661          6,464     121   2,048     1     1   512     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  offset_i(float *, int)
 656,157,125         10,976     123  16,384     1     1    64     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  offset(float *)
 656,177,957          6,368     125   8,192     1     1   128     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  offset(float *)
 656,194,309          6,112     127   4,096     1     1   256     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  offset(float *)
 656,209,317          6,144     129   2,048     1     1   512     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  offset(float *)
 656,224,677          5,120     131     174     1     1   512     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  offset_i(float *, int)
 656,238,821          5,024     133     348     1     1   256     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  offset_i(float *, int)
 656,253,029          5,024     135     696     1     1   128     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  offset_i(float *, int)
 656,267,461          5,216     137   1,392     1     1    64     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  offset_i(float *, int)

Generated:
    /root/bobc/report8.nsys-rep
    /root/bobc/report8.sqlite
#

You won’t be able to see the same result with that exact code on your V100; the grid sizes in the second kernel would need to be adjusted to match your GPU capacity. For a V100 with 80 SMs it would be like this:

for (int i = 4; i < 64; i*=2){
  offset_i<<<80*i, bs>>>(a, ds);
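More generally (a sketch I haven’t run on a V100), you could query the SM count and resident-thread limit rather than hard-coding them, so the same sizing works on any GPU:

cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
int bs = 512;
// one full "wave" of blocks: enough to fill every SM to its resident-thread limit
// (80 * 2048/512 = 320 blocks on V100, 58 * 1536/512 = 174 blocks on L4)
int grid = prop.multiProcessorCount * (prop.maxThreadsPerMultiProcessor / bs);
offset_i<<<grid, bs>>>(a, ds);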

Hi @Robert_Crovella, thank you for replying. I shall test this grid-stride loop methodology and get back. I have a few pertinent queries about this:

  1. Why is there a block-scheduling overhead? I know there is some cost associated with allocating and freeing resources (at the start and end of a block). But I thought 32 active thread blocks per SM implied that the thread blocks had already been scheduled on that SM, with resources allocated beforehand. Thus this SM, for a block size of 64, would have 64 active warps in total and there should be no cost in context switching. But for block size 64 the achieved occupancy is just 13.51%, i.e. only about 8.6 active warps, implying each warp scheduler (a V100 SM has 4 sub-partitions) has only about 2 active warps.

First, I said it appears to be block scheduling overhead. I am reporting observations and conjecture.

Second, let’s not bring context-switching into it. The GPU can do zero-cycle context switching from warp to warp for resident warps. This is something different.

A kernel launch has a grid associated with it, and that grid has a number of blocks. There is an entity in the GPU called the block scheduler or CWD (Compute Work Distributor, sometimes referred to as the CUDA work distributor) that takes these blocks and deposits them on SMs. However, each SM has a limit on what it can hold. The limit can be expressed in several ways, such as a number of threads or a number of warps.

Although this is not specified anywhere that I know of, we can presume the behavior of the CWD is something like this:

  • at kernel launch, the CWD will deposit blocks from the launched grid onto the SMs, until all SMs are “full”. At that point, any remaining blocks must wait in a queue of some sort; there is no space for them to be resident on the SMs.
  • Eventually (presumably) one or more blocks executing on an SM finish their work and “retire”. This leaves openings (i.e. available resources) to support depositing new blocks, from the queue, onto that SM. The CWD observes this and selects additional block(s) to deposit there.

This process continues until the queue is empty.

The time period between when a block retires and when a new block is deposited in its place is what I am referring to as block-scheduling overhead. I don’t know how long it is, or even that it exists, but if we accept the above conjecture, we can presume that the minimum time for that activity is not zero (unlike the warp context-switch time). If it is not zero, and if many blocks need to be scheduled over the duration of a grid/kernel execution, then the block-scheduling “overhead” could be noticeable/measurable.

Likewise, if block-scheduling overhead exists, there might be ways to minimize it, fundamentally by minimizing the number of times a block has to be scheduled for a kernel or grid to complete its work. The grid-stride loop tends to minimize the number of times a block has to be scheduled, while still allowing the kernel to complete its work (i.e. for the grid to “complete”), and (generally) still allowing “full” occupancy, or maximal utilization of the GPU.

This is observation and conjecture, not formally specified by NVIDIA, that I know of.
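To illustrate the idea concretely, here is a sketch (using the offset_i kernel and the variables from the code above) of sizing the grid so that every block can be resident at once, meaning each block only needs to be scheduled one time; the occupancy API is used because it also accounts for the kernel’s register and shared-memory usage:

int devId = 0, numSMs = 0, blocksPerSM = 0;
cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId);
// how many blocks of bs threads of offset_i can be resident on one SM
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocksPerSM, offset_i, bs, 0);
// every block in this grid fits on the GPU simultaneously; each block is
// scheduled once and then strides through the whole data set
offset_i<<<numSMs * blocksPerSM, bs>>>(a, ds);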


@Robert_Crovella,
I have run the trials on the V100 with and without the grid-stride loop methodology. As you suggested, in the grid-stride loop case the bandwidth for block sizes 64, 128, 256, 512 and 1024 is more or less the same, so this may indeed be the effect of block-scheduling overhead.

However, contrary to expectation, the simple straightforward kernel (one addition per thread) achieves the peak bandwidth in comparison to the grid-stride loop methodology, and I cannot see any probable reason for this. Please let me know your views on this. Thanks.

The results are as follows:

Device: Tesla V100-SXM2-32GB
maxThreadsPerMultiProcessor: 2048
multiProcessorCount: 80
Transfer size (MB): 120
Single Precision
Simple Kernel results
Block Size, Grid Size, Bandwidth (GB/s):
        64,    491520, 291.874207
       128,    245760, 580.136169
       256,    122880, 737.028259
       512,     61440, 737.028259
      1024,     30720, 732.421875

Grid Stride Loop Kernel results
Block Size, Grid Size, Bandwidth (GB/s):
        64,      2560, 689.338257
       128,      1280, 681.322693
       256,       640, 701.721558
       512,       320, 693.417114
      1024,       160, 691.371704

I am attaching partial code below for reference.

template <typename T>
__global__ void simpleKernel(T* a)
{
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  a[i] = a[i] + 1;
}

template <typename T>
__global__ void gridStrideLoopKernel(T* a, const int n)
{
  for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < n; i+=gridDim.x*blockDim.x)
    a[i] = a[i] + 1;
}

for (int blockSize = 64; blockSize <= 1024; blockSize = blockSize << 1) {
    int gridSize = (n+blockSize -1)/blockSize;
    checkCuda( cudaMemset(d_a, 0, n * sizeof(T)) );

    checkCuda( cudaEventRecord(startEvent,0) );
    simpleKernel<<<gridSize, blockSize>>>(d_a);
    checkCuda( cudaEventRecord(stopEvent,0) );
    checkCuda( cudaEventSynchronize(stopEvent) );

    checkCuda( cudaEventElapsedTime(&ms, startEvent, stopEvent) );
    printf("%10d, %9d, %f\n", blockSize, gridSize, 2*nMB/ms);
  }

for (int blockSize = 64; blockSize <= 1024; blockSize = blockSize << 1) {
    // grid sized to GPU capacity (one wave of resident blocks), capped at the data-set size
    int gridSize = std::min(numSM * maxThreadsPerMultiProcessor/blockSize, (n+blockSize -1)/blockSize);
    checkCuda( cudaMemset(d_a, 0, n * sizeof(T)) );

    checkCuda( cudaEventRecord(startEvent,0) );
    gridStrideLoopKernel<<<gridSize, blockSize>>>(d_a, n);
    checkCuda( cudaEventRecord(stopEvent,0) );
    checkCuda( cudaEventSynchronize(stopEvent) );

    checkCuda( cudaEventElapsedTime(&ms, startEvent, stopEvent) );
    printf("%10d, %9d, %f\n", blockSize, gridSize, 2*nMB/ms);
  }

Sorry, I don’t have a V100 to test on at this point. My code already demonstrates higher bandwidth for the grid-stride case. And the discrepancy that you are now chasing is ~5%. There are potentially a lot of things that would have to be investigated to track down its origin, such as compiler settings, CUDA version, event overhead vs. profiler timing, caching effects, etc., and I would need a V100 to test on to be certain of anything I might find, if I were looking for a 5% discrepancy.
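If you want to reduce the influence of event and launch overhead on a measurement this short, one option (a sketch using the variables from your snippet; reps is an arbitrary count I am introducing) is to time many launches between a single pair of events and divide:

const int reps = 100;  // arbitrary repetition count
checkCuda( cudaEventRecord(startEvent,0) );
for (int r = 0; r < reps; r++)
  simpleKernel<<<gridSize, blockSize>>>(d_a);   // or gridStrideLoopKernel<<<...>>>(d_a, n)
checkCuda( cudaEventRecord(stopEvent,0) );
checkCuda( cudaEventSynchronize(stopEvent) );
checkCuda( cudaEventElapsedTime(&ms, startEvent, stopEvent) );
printf("%10d, %9d, %f\n", blockSize, gridSize, 2*nMB/(ms/reps));

This amortizes launch and event overhead across the repetitions, at the cost of measuring steady-state rather than single-launch behavior; the profiler’s per-kernel durations are another way to cross-check.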


Just to be clear, since it is only implied above:

The simple kernel is very short (just one read + write per thread), so achieved occupancy can go down if new blocks are not scheduled onto the SMs fast enough, whereas each block of the grid-stride loop kernel runs for somewhat longer.

So the benefit of the grid-stride loop here is not the clever indexing per se, but mostly that it avoids very short-lived blocks.
