Is there a way to inspect the time cost of each individual CUDA block?

I’m profiling a flash attention kernel. In this kernel, each CUDA block processes a different amount of work.

Some CUDA blocks exit early, while others have to do a lot of work.

I want to know how imbalanced the kernel is. Is there a way to inspect this?

You can inspect this imbalance with the Nsight Compute metrics sm__cycles_elapsed.min, sm__cycles_elapsed.max, and sm__cycles_elapsed.avg. They report the minimum, maximum, and average number of elapsed cycles across the Streaming Multiprocessors (SMs), respectively.

You can use the following command to collect these metrics:

ncu --metrics sm__cycles_elapsed.min,sm__cycles_elapsed.max,sm__cycles_elapsed.avg <your-application>

GREAT!!! That’s exactly what I need!!! Thank you so much!!!

Is this metric accurate enough?

There seems to be something wrong.

I’m trying this metric on an H100, and this is my CUDA code:

__global__ void test() {
        int bid = blockIdx.x;
        int n = bid * 1000000;
        float x = 10;

        // blocks 0-8 exit immediately; only block 9 runs the long loop
        if(bid < 9) return;
        for(int i = 0; i < n; i++) {
                x = x * (i + 1);
        }
}

int main() {
        cudaFree(0);
        test<<<10, 1>>>();
        cudaDeviceSynchronize();
        return 0;
}

I compile it using nvcc -g -G test.cu and then run

ncu --metrics sm__cycles_elapsed.min,sm__cycles_elapsed.max,sm__cycles_elapsed.avg  ./a.out

I get the following result:

==PROF== Profiling "test()" - 0: 0%....50%....100% - 1 pass
==PROF== Disconnected from process 294566
[294566] a.out@127.0.0.1
  test() (10, 1, 1)x(1, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    ---------------------- ----------- -------------
    Metric Name            Metric Unit  Metric Value
    ---------------------- ----------- -------------
    sm__cycles_elapsed.avg       cycle 2700004918.58
    sm__cycles_elapsed.max       cycle    2700004928
    sm__cycles_elapsed.min       cycle    2700004912
    ---------------------- ----------- -------------

Is that normal?

The kernel has 9 short blocks and 1 long CUDA block, but the result shows almost no difference between the minimum and maximum.

By the way, how can I find all the available metrics and their meanings? I looked in the manual, but I can’t find this metric.

You can query all the available metrics with ncu --query-metrics. You will get something like this:

Device NVIDIA A100-SXM4-80GB (GA100)
--------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------
Metric Name                                                                 Metric Description
--------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------
dram__bytes                                                                 # of bytes accessed in DRAM
dram__bytes_read                                                            # of bytes read from DRAM
dram__bytes_write                                                           # of bytes written to DRAM
dram__cycles_active                                                         # of cycles where DRAM was active
dram__cycles_active_read                                                    # of cycles where DRAM was active for reads
dram__cycles_active_write                                                   # of cycles where DRAM was active for writes
dram__cycles_elapsed                                                        # of elapsed DRAM memory clock cycles
dram__cycles_in_frame                                                       # of cycles in user-defined frame
dram__cycles_in_region                                                      # of cycles in user-defined region
dram__sectors                                                               # of sectors accessed in DRAM
dram__sectors_read                                                          # of sectors read from DRAM
......

Sorry, I made a mistake. sm__cycles_elapsed measures the cycles of the SMs regardless of whether they are busy or idle. The correct metric for this case is sm__cycles_active.

You can use the following command to collect these metrics:

ncu --metrics sm__cycles_active.min,sm__cycles_active.max,sm__cycles_active.avg  ./a.out

Thank you! This metric seems to work.

NCU does not currently support per-warp or per-block timing. This can be done via manual instrumentation of the kernel.

__global__ void test(..., uint64_t* durations)
{
    uint64_t cycles_start = clock64();

    // kernel body

    uint64_t cycles_end = clock64();

    // limit the write to 1 thread per block
    if (threadIdx.x == 0) {
        uint64_t duration = cycles_end - cycles_start;
        // output per warp or per block
        uint32_t warpidx  = calc_flat_warp_idx();   // index for per-warp output
        uint32_t blockidx = calc_flat_block_idx();  // index for per-block output

        durations[blockidx] = duration;
    }
}

calc_flat_{warp,block}_idx() may need to handle 1D - 3D grids and blocks. The overhead increases with the number of dimensions.
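
For reference, a minimal sketch of such helpers for 1D-3D grids and blocks might look like this (the helper names come from the snippet above; uint32_t assumes <cstdint> is included):

__device__ __forceinline__ uint32_t calc_flat_block_idx()
{
    // flatten a 1D-3D grid into a linear block index
    return blockIdx.x
         + blockIdx.y * gridDim.x
         + blockIdx.z * gridDim.x * gridDim.y;
}

__device__ __forceinline__ uint32_t calc_flat_warp_idx()
{
    // flatten a 1D-3D block into a linear thread index, then map to a global warp index
    uint32_t tid = threadIdx.x
                 + threadIdx.y * blockDim.x
                 + threadIdx.z * blockDim.x * blockDim.y;
    uint32_t threads_per_block = blockDim.x * blockDim.y * blockDim.z;
    uint32_t warps_per_block   = (threads_per_block + 31) / 32;
    return calc_flat_block_idx() * warps_per_block + tid / 32;
}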

If capturing per warp, then I would recommend accepting the undefined behavior of having every lane in the warp write to the same address. This is a race condition, but it is generally the lowest-overhead method with high accuracy.

If capturing per block, then it is often necessary to do a __syncthreads() before reading the end timestamp. This will not work for all kernels, so it is also possible to rely on the same undefined behavior and have all warps write to the same per-block slot.
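
As a rough sketch of the two variants (continuing the snippet above, with warp_durations and block_durations as assumed output buffers):

    // per warp: every lane in the warp writes the same value to the warp's slot (benign race)
    uint64_t duration = cycles_end - cycles_start;
    warp_durations[calc_flat_warp_idx()] = duration;

    // per block: synchronize so the end timestamp covers the slowest warp,
    // then let a single thread write the block's duration
    __syncthreads();
    uint64_t block_end = clock64();
    if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
        block_durations[calc_flat_block_idx()] = block_end - cycles_start;
    }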

The code above uses clock64(). I would recommend either using the PTX %clock register to get a 32-bit value and handling roll-over, or using PTX %globaltimer_lo. If using %globaltimer_lo, the output is in 32 ns resolution; however, the default update frequency is 1 MHz on pre-GH100 GPUs. On pre-GH100 you will want to run under Nsight Systems or Nsight Compute, as the tool increases the update frequency to 31.25 MHz (32 ns).
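
For illustration, reading these registers via inline PTX could look like this (a sketch; %clock is the 32-bit SM cycle counter and %globaltimer_lo is the lower 32 bits of the nanosecond global timer):

__device__ __forceinline__ uint32_t read_clock32()
{
    uint32_t c;
    asm volatile("mov.u32 %0, %%clock;" : "=r"(c));          // 32-bit SM cycle counter, rolls over
    return c;
}

__device__ __forceinline__ uint32_t read_globaltimer_lo()
{
    uint32_t t;
    asm volatile("mov.u32 %0, %%globaltimer_lo;" : "=r"(t)); // lower 32 bits of the global timer (ns)
    return t;
}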

Hi, it seems this code is used to record the time cost for each warp.

But even if the time cost of each block varies a lot, the SM load may still be balanced, right?

I benchmarked a kernel with a lot of blocks, each with a different workload, but the load imbalance is very small.

I wonder, is there any documentation on the block scheduling strategy? I mean, how does the GPU dispatch blocks to each SM?

sm__cycles_active.{avg, max, min}
smsp__cycles_active.{avg, max, min}
sm__warps_active.{avg, max, min}

are likely the best metrics for finding imbalance.
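
For example, following the same command pattern as earlier in the thread (./a.out stands in for your application):

ncu --metrics sm__cycles_active.avg,sm__cycles_active.max,sm__cycles_active.min,\
smsp__cycles_active.avg,smsp__cycles_active.max,smsp__cycles_active.min,\
sm__warps_active.avg,sm__warps_active.max,sm__warps_active.min ./a.out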

Yes, block duration can vary significantly, especially if you have a high number of blocks per SM and many waves. This does not by itself indicate a load imbalance.

The code above can quickly be modified to either (a) collect warp or block durations, or (b) write out the start timestamp, end timestamp, and SM ID (smid) to help find the cause of the imbalance.
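
A possible sketch of variant (b), reusing the hypothetical calc_flat_block_idx() helper from above and assuming a records buffer with three entries per block:

__device__ __forceinline__ uint32_t read_smid()
{
    uint32_t smid;
    asm("mov.u32 %0, %%smid;" : "=r"(smid));   // ID of the SM this block is resident on
    return smid;
}

__global__ void test(uint64_t* records /*, other kernel arguments */)
{
    uint64_t cycles_start = clock64();

    // kernel body

    uint64_t cycles_end = clock64();

    if (threadIdx.x == 0) {
        uint32_t blockidx = calc_flat_block_idx();
        records[3 * blockidx + 0] = cycles_start;   // start timestamp
        records[3 * blockidx + 1] = cycles_end;     // end timestamp
        records[3 * blockidx + 2] = read_smid();    // which SM ran this block
    }
}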

Thank you. I see the load imbalance using this metric. It works.
