Profiling fails on more than one GPU device

Hi,
We are trying to profile a PyTorch application on multiple devices, but there is a problem with the multi-GPU run.
The command is:

$ncu_path -c 200 --metrics launch__grid_size \
--cache-control none --target-processes all \
-f -o run2 python3 main.py

Then the output is:

Starting Training Loop...
==PROF== Profiling "vectorized_elementwise_kernel" - 25 (26/200): 0%....50%....100% - 1 pass
==PROF== Profiling "CatArrayBatchedCopy_aligned16..." - 26 (27/200): 0%....50%....100% - 1 pass
==PROF== Profiling "ncclKernel_Broadcast_RING_LL_..." - 27 (28/200): 0%....50%
==WARNING== Launching the workload is taking more time than expected. If this continues to hang, 
terminate the profile and re-try by profiling the range of all related launches 
using '--replay-mode range'. 
See https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#replay for more details.

We see this error on an HGX machine (a 16-GPU system with V100s) and on a machine with two RTX 4090 devices. If we limit the program to run on one device, everything is fine.
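For reference, the single-device case can be forced, for example, by exposing only one GPU to the process (just a sketch; the output file name is arbitrary):

CUDA_VISIBLE_DEVICES=0 $ncu_path -c 200 --metrics launch__grid_size \
--cache-control none -f -o run_single python3 main.py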

Using --replay-mode range doesn’t help either; in that case, nothing is profiled at all!
The command is:

$ncu_path -c 200 --metrics launch__grid_size \
--replay-mode range --cache-control none --target-processes all \
-f -o run2 python3 main.py

The output is:

Starting Training Loop...
==PROF== Target process 27 terminated before first instrumented API call.
[0/1][0/99]	Loss_D: 1.8108	Loss_G: 2.3480	D(x): 0.3327	D(G(z)): 0.3594 / 0.1407
[0/1][1/99]	Loss_D: 3.7184	Loss_G: 2.7405	D(x): 0.9820	D(G(z)): 0.9565 / 0.1020
MORE PROGRAM OUTPUT
...
==PROF== Target process 65 terminated before first instrumented API call.
==PROF== Disconnected from process 21
==WARNING== No ranges were profiled.

It is a bit strange, because launch__grid_size is a basic launch metric and I am fairly sure it doesn’t require any replay. I recall that only some metrics, e.g. cache metrics, need multiple replay passes because of how they are collected.

The ncu version is 2023.2.2.0 (build 33188574) (public-release).

Any idea about that? How can we debug further?

The hang is likely caused by two kernels that normally need to run concurrently being serialized by the profiler. Using a range-based replay mode is the right approach. However, both of these modes (range and app-range) require you to explicitly define ranges in the application. Once you do that, you should see ranges being profiled by the tool. If possible, app-range is preferred, as it doesn’t have limitations on the supported APIs and doesn’t require memory save/restore. You can refer to the documentation on the two modes.

I tested app-range and the result is the same as with range: nothing is profiled.
I read the documentation, but it is not clear to me at what level the code has to be modified. The high-level PyTorch code is Python and the libraries are in C++. I am sure I am not the first one trying to profile a PyTorch app in a multi-GPU environment.

Regarding cu(da)ProfilerStart: since profiling with one GPU is fine, I assume those calls are already embedded in the code and libraries, so the two-GPU error is not related to that. Do you agree?

Regarding --nvtx --nvtx-include <expression> [--nvtx-include <expression>]: what exactly is the expression? Is it the kernel name? If so, how can I determine which kernels should be in the range? I have the usage text, but I really don’t understand the domain and range names:

Usage of --nvtx-include and --nvtx-exclude:
  ncu --nvtx --nvtx-include "Domain A@Range A"
    Profile kernels wrapped inside start/end range 'Range A' of 'Domain A'
  ncu --nvtx --nvtx-exclude "Range A]"
    Profile all kernels except kernels wrapped inside push/pop range 'Range A' of <default domain> at the top of the stack.
  ncu --nvtx --nvtx-include "Range A" --nvtx-exclude "Range B"
    Profile kernels wrapped inside start/end range 'Range A' but not inside 'Range B' of <default domain>

Since the original error happens at

==PROF== Profiling "ncclKernel_Broadcast_RING_LL_..." - 27 (28/200): 0%....50%
==WARNING== Launching the workload is taking more time than expected. If this continues to hang, 
terminate the profile and re-try by profiling the range of all related launches 

Does that mean I have to specify the option like ncu --nvtx --nvtx-include "ncclKernel_Broadcast_RING*"?

The names of the kernels in the range don’t matter for the --nvtx-include expression; what matters is the name of the range itself. Assuming you want to define the range using PyTorch itself, it would be similar to

torch.cuda.nvtx.range_push("my_range")
# execute CUDA/PyTorch kernels here
torch.cuda.nvtx.range_pop()

and profile with

ncu --replay-mode (app-)range --nvtx --nvtx-include "my_range/" ...

As shown here, the / is needed to denote that you want to match push/pop ranges (as opposed to start/end ranges). If you don’t want to use NVTX, you can instead replace it with

torch.cuda.cudart().cudaProfilerStart()
torch.cuda.cudart().cudaProfilerStop()

The exact way to call these APIs will depend on the framework you are using; consult your framework’s documentation for further information.
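As a rough sketch in PyTorch (the names model, batch, loss_fn, and target below are placeholders, not from your code), bracketing one training iteration could look like

# placeholders for your actual training-loop objects
torch.cuda.cudart().cudaProfilerStart()
output = model(batch)            # all launches between start and stop form the range
loss = loss_fn(output, target)
loss.backward()
torch.cuda.cudart().cudaProfilerStop()

and the profile could then be collected with something like

ncu --replay-mode app-range --cache-control none --target-processes all -f -o run2 python3 main.py

i.e. without the --nvtx options, since the range is then defined through the profiler start/stop API.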


We tried to apply the ranges inside the code, like this:

        torch.cuda.nvtx.range_push("init_D")
        netD.zero_grad()
        .....
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_push("forward_D")
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        ....
        torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_push("step_D")
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        ...
        torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_push("init_G")
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_push("forward_and_step_G")
        errG = criterion(output, label)
        ...
        optimizerG.step()
        torch.cuda.nvtx.range_pop()

Using the command:

$ncu_path -c 200 \
--nvtx --nvtx-include "init_D/*/forward_and_step_G" \
--metrics launch__grid_size \
--replay-mode app-range \
--cache-control none \
--target-processes all \
python3 main.py

The output is:

==PROF== Target process 27 terminated before first instrumented API call.
[0/1][0/99]	Loss_D: 1.8108	Loss_G: 2.3481	D(x): 0.3327	D(G(z)): 0.3594 / 0.1407
...
[0/1][98/99]	Loss_D: 1.1188	Loss_G: 10.2879	D(x): 0.9319	D(G(z)): 0.5484 / 0.0006
==PROF== Target process 65 terminated before first instrumented API call.
==PROF== Disconnected from process 21
==WARNING== No ranges were profiled.

Does that mean we have to keep narrowing things down by defining more ranges?
Is there a way to better pinpoint the problem and see which kernel is causing it, something like a call stack?

You are pushing and popping range init_D, followed by pushing and popping range forward_and_step_G. Then you tell ncu to profile the range that has both init_D and forward_and_step_G on the stack simultaneously, but no such range exists.

If you want to profile both ranges in your current setup, you have to provide multiple --nvtx-include parameters, each with the respective range name.
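For example, keeping the rest of your command unchanged, the filter could look like this (a sketch):

ncu --replay-mode app-range --nvtx \
--nvtx-include "init_D/" --nvtx-include "forward_and_step_G/" \
--metrics launch__grid_size --cache-control none --target-processes all \
-f -o run2 python3 main.py

A range is profiled if it matches any of the include expressions.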


The problem still exists. Let me explain with a test. The command is:

$ncu_path -c 200 \
--nvtx --nvtx-include "init_D/" \
--metrics launch__grid_size \
--replay-mode app-range --cache-control none --target-processes all
...

And the range in the code is:

    torch.cuda.nvtx.range_push("init_D")
    netD.zero_grad()
    real_cpu = data[0].to(device)
    b_size = real_cpu.size(0)
    label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
    torch.cuda.nvtx.range_pop()

But in the resulting report, the grid size shows as n/a; in fact, there is no column for launch__grid_size at all.

The program output is:

Starting Training Loop...
==PROF== Target process 27 terminated before first instrumented API call.
==PROF== Profiling "range" - 0 (1/200): Application replay pass 1
==PROF== Profiling "range" - 1 (2/200): Application replay pass 1
==PROF== Profiling "range" - 2 (3/200): Application replay pass 1
==PROF== Profiling "range" - 3 (4/200): Application replay pass 1
==PROF== Profiling "range" - 4 (5/200): Application replay pass 1
...

Do you have any idea about that?

I would argue that the problem you described previously is in fact resolved. Originally, your issue was that no range had been profiled at all, which is no longer the case. Your new concern is that there are no launch metrics for the range, which is expected: since a range can contain any number of workloads, it is not assigned a single grid size at this point. We are looking into improving this in future releases. You should be able to collect other metrics from the basic or full set for these ranges.
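For example (a sketch; the output file name is arbitrary), replacing the single launch metric with a predefined set:

ncu --set basic --replay-mode app-range --nvtx --nvtx-include "init_D/" \
--cache-control none --target-processes all -f -o run3 python3 main.py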

OK, I agree that the original question has been answered. I will look into the details further and create a new topic if needed. Thanks.
