Question about profiling nccl kernels with Nsight Compute

Hi,
I would like to profile NCCL kernels and get some detailed metrics using Nsight Compute, but it always hangs. Can anybody give me some information about this? Thanks.

Add more details:
tested on NGC container: nvcr.io/nvidia/pytorch:21.07-py3
application: nccl-tests/build/all_reduce_perf

PS: The same issue has been reported on GitHub, but there is no conclusion yet.


Nsight Compute serializes kernel launches across all profiled processes. If a kernel waits for other concurrent processes (or kernels) it will not be able to make forward progress and the profiling will hang. So such applications cannot be profiled using Nsight Compute.

Hi Sanjiv,
Thanks. And is there any plan to make ncu support NCCL kernel profiling?

Yes, we are looking into supporting these types of applications in the future, but there is no definite timeline for such support to be released, yet.

hi there,
it seems we are hitting the same issue (Profiling all_reduce_perf with Nsight hangs · Issue #101 · NVIDIA/nccl-tests · GitHub). Are there any updates on the timeline for supporting this?

thanks!

Nsight Compute 2022.1 includes a new Range Replay feature to support profiling mandatory concurrent kernels (such as NCCL all-reduce). Range replay requires you to mark explicit ranges of kernels (and CUDA API calls) for profiling, using either the cu(da)ProfilerStart/Stop API or NVTX. A single result for the entire range is then collected, with the limitation that data is only collected for kernels from the first CUDA context found within the range. Also, it only works for ranges covering a single process.

Note that NCCL all-reduce kernels are not yet fully supported with this version of range replay, meaning that profiling can still hang intermittently. Still, it will work in many cases.

For the NCCL all_reduce_perf test, a possible range is in common.cu, lines 621ff:

// Performance Benchmark
auto start = std::chrono::high_resolution_clock::now();
cudaProfilerStart();
for (int iter = 0; iter < iters; iter++) {
  if (agg_iters>1) NCCLCHECK(ncclGroupStart());
  for (int aiter = 0; aiter < agg_iters; aiter++) {
    TESTCHECK(startColl(args, type, op, root, in_place, iter*agg_iters+aiter));
  }
  if (agg_iters>1) NCCLCHECK(ncclGroupEnd());
}
cudaProfilerStop();

Hi felix_dt,

I just followed your steps:

  1. installed the latest Nsight Compute, 2022.1.1
  2. added cudaProfilerStart() and cudaProfilerStop() before and after startColl()

but it still hangs. Why? Are there any other steps I am missing?

NVIDIA (R) Nsight Compute Command Line Profiler
Copyright (c) 2018-2022 NVIDIA Corporation
Version 2022.1.1.0 (build 30914944) (public-release)

command: ncu --set full -f -o all_reduce_2_ranks_2M ./build/all_reduce_perf -g 2 -n 1 -w 0 -b 2M -e 2M -c 0

nThread 1 nGpus 2 minBytes 2097152 maxBytes 2097152 step: 1048576(bytes) warmup iters: 0 iters: 1 validation: 0

Using devices
==PROF== Connected to process 13632 (githubs/nccl-tests/build/all_reduce_perf)
Rank 0 Pid 13632 on da870a356542 device 0 [0x54] NVIDIA A100-SXM-80GB
Rank 1 Pid 13632 on da870a356542 device 1 [0x5a] NVIDIA A100-SXM-80GB

                                                   out-of-place                       in-place
   size         count      type   redop     time   algbw   busbw  error     time   algbw   busbw  error
    (B)    (elements)                       (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)

==PROF== Profiling "ncclKernel_AllReduce_RING_LL_…" - 1:

BTW, I was testing on nvcr.io/nvidia/pytorch:22.03-py3, with the latest ncu and NCCL 2.12.

And after testing all the kernels, I found that only broadcast_perf could be profiled. Is that true?

Thanks

below is the code diff:

diff --git a/src/common.cu b/src/common.cu
index 05f814d..a6e7f58 100644
--- a/src/common.cu
+++ b/src/common.cu
@@ -10,6 +10,7 @@
 #include <getopt.h>
 #include <libgen.h>
 #include "cuda.h"
+#include <cuda_profiler_api.h>
 
 int test_ncclVersion = 0; // init'd with ncclGetVersion()
 
@@ -596,8 +597,8 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
 }
 
   // Sync
-  TESTCHECK(startColl(args, type, op, root, in_place, 0));
-  TESTCHECK(completeColl(args));
+  // TESTCHECK(startColl(args, type, op, root, in_place, 0));
+  // TESTCHECK(completeColl(args));
 
   Barrier(args);
 
@@ -617,6 +618,8 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
 
   // Performance Benchmark
   auto start = std::chrono::high_resolution_clock::now();
+  PRINT("cudaProfilerStart\n");
+  cudaProfilerStart();
   for (int iter = 0; iter < iters; iter++) {
     if (agg_iters>1) NCCLCHECK(ncclGroupStart());
     for (int aiter = 0; aiter < agg_iters; aiter++) {
@@ -624,6 +627,8 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
   }
   if (agg_iters>1) NCCLCHECK(ncclGroupEnd());
 }
+  cudaProfilerStop();
+  PRINT("cudaProfilerStop\n");

Hi!

You should explicitly specify the range replay option when running the profiler.
For example:
ncu --replay-mode range ./build/all_reduce_perf -g 2 -n 1 -w 0 -b 2M -e 2M -c 0

At least it worked for me.

But I needed to profile an application with network communication, so I compiled all_reduce_perf with MPI=1 and ran it with 2 MPI processes:

mpirun -n 2 ncu --target-processes all --replay-mode range ./all_reduce_perf -g 3 -n 1 -c 0

This is where it hangs:

#       size         count      type   redop     time   algbw   busbw  error     time   algbw   busbw  error
#        (B)    (elements)                       (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
==PROF== Profiling "range" - 1: ==PROF== Profiling "range" - 1:

Can I somehow profile the NCCL kernels when using MPI?


Hi,

Thank you so much. It works after adding the '--replay-mode range' option.

As for MPI profiling, I haven't tried it before.

Nsight Compute range replay across different processes is not yet supported, unfortunately. We will be looking into this for a future release. You could consider Nsight Systems' GPU metric sampling functionality to get some limited metric values sampled over time.
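If you want to try that, a command along these lines should enable GPU metric sampling in Nsight Systems for the single-process case (this is just my guess at a reasonable invocation; check the Nsight Systems documentation for your version and adjust the arguments to your setup):

nsys profile --gpu-metrics-device=all ./build/all_reduce_perf -g 2 -n 1 -c 0

Note that this gives device-level metrics sampled over time rather than per-kernel results.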

Hi felix,

With Nsight Compute 2022.1, I can now profile NCCL kernels with the Range Replay feature.

However, I found that there are no metrics about Peer Memory in the "Memory Workload" chart.

Could you help confirm? Am I missing any options, or does the tool just not support it yet?

Thanks

Nsight Compute range replay across different processes is now supported using the new app-range replay mode, starting with Nsight Compute version 2023.1 (CUDA 12.1). The new app-range replay mode profiles ranges without API capture by relaunching the entire application multiple times. After setting an appropriate range (using the profiler start/stop API or NVTX ranges), applications using NCCL can now be profiled with --replay-mode app-range.
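For example, with the cudaProfilerStart/Stop range from the common.cu patch earlier in this thread still in place, a command along these lines should work for the nccl-tests case (paths and arguments taken from the earlier posts; adjust to your setup):

ncu --target-processes all --replay-mode app-range -o all_reduce_app_range ./build/all_reduce_perf -g 2 -n 1 -c 0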

Hi,
I would like to profile NCCL kernels in PyTorch code using Nsight Compute. I am able to profile nccl-tests/build/all_reduce_perf using --replay-mode app-range, but for the PyTorch code it always hangs.
My code is the following:

"""run.py:"""
#!/usr/bin/env python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
    """Simple collective communication."""
    torch.cuda.set_device(rank)

    tensor = torch.ones(1).cuda()
    torch.cuda.cudart().cudaProfilerStart()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    torch.cuda.cudart().cudaProfilerStop()
    print("Rank ", rank, " has data ", tensor.cpu()[0])

def init_process(rank, size, fn, backend="nccl"):
    """Initialize the distributed environment."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == "__main__":
    size = torch.cuda.device_count()
    processes = []
    mp.set_start_method("spawn")

    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

The command to run the code is ncu --target-processes all --replay-mode app-range python run.py. Can anyone give me some insights on this? Thanks.

From this snippet, it looks like you are attempting to profile all of the ranks, each of which marks its own range, as they run in parallel, which could cause the hang. Can you try calling torch.cuda.cudart().cudaProfilerStart()/Stop() for a single rank only, with something like:

if rank == 0:
    torch.cuda.cudart().cudaProfilerStart()
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
if rank == 0:
    torch.cuda.cudart().cudaProfilerStop()
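For reference, here is a minimal sketch of how the run() function from your run.py could look with that guard (everything else unchanged; this is an illustration, not a verified fix):

def run(rank, size):
    """Simple collective communication, profiling only rank 0."""
    torch.cuda.set_device(rank)

    tensor = torch.ones(1).cuda()
    # Only rank 0 marks the profiling range, so Nsight Compute sees a single
    # range instead of one concurrent range per rank.
    if rank == 0:
        torch.cuda.cudart().cudaProfilerStart()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    if rank == 0:
        torch.cuda.cudart().cudaProfilerStop()
    print("Rank ", rank, " has data ", tensor.cpu()[0])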

Hi,
Thank you so much. It works when profiling only rank 0.

Hi,
I have questions about the profiling results of PyTorch distributed data parallel training with 2 GPUs (connected with NVLink), which I got by running:

ncu --target-processes all --replay-mode app-range --set nvlink --metrics nvlrx__bytes.sum,nvltx__bytes.sum,nvlrx__bytes_data_user.sum,nvltx__bytes_data_user.sum,pcie__read_bytes.sum,pcie__write_bytes.sum,nvltx__bytes.sum.per_second,nvlrx__bytes.sum.per_second python main_nsysprofiler.py -a resnext101_32x8d --dist-url 'tcp://127.0.0.1:13421' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 -b 128 --epochs 1 /data/datasets/imagenet

The results are as follows; I ran the command twice to get the results for each GPU.


I would like to know how nvltx/nvlrx__bytes.sum.per_second is calculated. I know the numerator is nvlrx/nvltx__bytes.sum, but I don't know the denominator. Is it the time spent executing the range? Or is it the time the data is being transferred on the NVLink fabric? And how is it related to the nccl_allreduce kernel time? I know I can get the kernel execution time from Nsight Systems, but I am not sure how to get it using Nsight Compute.
The bandwidth/utilization is very low in my results. Is this a normal value, as expected? Hope someone can shed some light on it. Thanks!

The denominator is the wall-clock time it took to execute the range from begin() to end(). This value is collected with the metric “gpu__time_duration.sum”. Because you need to use ranges, Nsight Compute doesn’t have metrics for individual kernels. In this case, the best way to get the nccl_allreduce kernel time is probably from Nsight Systems. Or if your range happens to only contain that kernel, the range time in Nsight Compute may be close. With respect to the bandwidth utilization, that’s a difficult question to answer and dependent on the application.
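As a purely illustrative example with made-up numbers: if nvltx__bytes.sum for the range were 2 GB and gpu__time_duration.sum for the range were 200 ms, then nvltx__bytes.sum.per_second would come out to about 2 GB / 0.2 s = 10 GB/s, independent of how long the bytes actually spent in flight on the NVLink fabric.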

Hi,
Thanks for your reply. That makes sense; I will use the nccl_allreduce kernel time from Nsight Systems. I have one more question related to the metrics in the nvlink set: why are there received/transmitted overhead bytes? Is that related to the algorithm used by the NCCL library? It seems the overhead is around 80% of the useful data.

The overhead bytes are protocol overhead for using NVLink, and are not specific to NCCL. It's hard to say why the ratio is what it is; perhaps the algorithm is only sending small amounts of data per transmission. You may need to talk with the NCCL team or dig deeper into the perf analysis.