Profiling nccl 2.14.3 leads to segmentation fault

Forwarding an issue from PyTorch (SIGSEGV when using DDP + NCCL + nsys profiling + Pytorch 1.13 · Issue #94393 · pytorch/pytorch · GitHub) that seems to be related to nsys + nccl 2.14.3 rather than to PyTorch itself:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


def run_demo(demo_fn, world_size):
    mp.start_processes(demo_fn,
                       args=(world_size,),
                       nprocs=world_size,
                       join=True,
                       start_method="fork")

if __name__ == "__main__":
    run_demo(demo_basic, torch.cuda.device_count())

When running this script with PyTorch 1.13.1+cu117, and only under nsys profiling (~/.local/share/nsight_systems/nsys profile python test_ddp.py; I’ve been using Nsight Systems 2023.1.1, but it also happens with 2022.4.2), I get a SIGSEGV on rank 0.

Here is the GDB backtrace:

#0  0x0000000000000000 in ?? ()
No symbol table info available.
#1  0x00007ff5c8b0f1da in ?? () from /home/juh/.local/share/nsight_systems/libToolsInjection64.so
No symbol table info available.
#2  0x00007ff5c885f606 in NSYS_DL_dlvsym () from /home/juh/.local/share/nsight_systems/libToolsInjection64.so
No symbol table info available.
#3  0x00007ff58be01168 in initOnceFunc () at misc/ibvwrap.cc:94
        ibvhandle = 0x7ff518002d30
        tmp = <optimized out>
        cast = 0x7ff59acdfb08 <ibv_internal_reg_dmabuf_mr>
        __func__ = "initOnceFunc"
#4  0x00007ff5c83ae4df in __pthread_once_slow (once_control=0x7ff59acdfba4 <initOnceControl>, init_routine=0x7ff58be00ec0 <initOnceFunc()>) at pthread_once.c:116
        _buffer = {__routine = 0x7ff5c83ae530 <clear_once_control>, __arg = 0x7ff59acdfba4 <initOnceControl>, __canceltype = 0, __prev = 0x0}
        val = <optimized out>
        newval = <optimized out>
#5  0x00007ff5c83ae595 in __GI___pthread_once (once_control=once_control@entry=0x7ff59acdfba4 <initOnceControl>, init_routine=init_routine@entry=0x7ff58be00ec0 <initOnceFunc()>) at pthread_once.c:143
        val = <optimized out>
#6  0x00007ff58be01847 in wrap_ibv_symbols () at misc/ibvwrap.cc:139
No locals.
#7  0x00007ff58bde3349 in ncclIbInit (logFunction=<optimized out>) at transport/net_ib.cc:150
        shownIbHcaEnv = 0
        __func__ = <optimized out>
        nIbDevs = <optimized out>
        devices = <optimized out>
        userIbEnv = <optimized out>
        userIfs = <optimized out>
        searchNot = <optimized out>
        searchExact = <optimized out>
        nUserIfs = <optimized out>
        d = <optimized out>
        context = <optimized out>
        nPorts = <optimized out>
        devAttr = <optimized out>
        port = <optimized out>
        portAttr = <optimized out>
        res = <optimized out>
        line = <optimized out>
        addrline = <optimized out>
        d = <optimized out>
#8  0x00007ff58bdbb84a in netGetState (state=<synthetic pointer>, i=1) at net.cc:245
        ndev = 32757
#9  ncclNetInit (comm=comm@entry=0x555d647315b0) at net.cc:273
        res = <optimized out>
        state = <optimized out>
        i = 1
        netName = <optimized out>
        ok = false
        __func__ = <optimized out>
#10 0x00007ff58bd7e687 in commAlloc (comret=comret@entry=0x555d64123530, ndev=ndev@entry=1, rank=rank@entry=0) at init.cc:323
        res = <optimized out>
        comm = 0x555d647315b0
        __func__ = "commAlloc"
#11 0x00007ff58bd8363e in ncclCommInitRankFunc (job_=<optimized out>) at init.cc:1088
        job = <optimized out>
        newcomm = 0x555d64123530
        comm = 0x555d647315b0
        nranks = 1
        commId = {internal = "\002\000\245\355\254\032-6", '\000' <repeats 119 times>}
        myrank = 0
        cudaDev = <optimized out>
        res = ncclSuccess
        __func__ = "ncclCommInitRankFunc"
#12 0x00007ff58bd7bd68 in ncclAsyncJobMain (arg=0x555d29662100) at group.cc:62
        job = 0x555d29662100
        __func__ = "ncclAsyncJobMain"
#13 0x00007ff5c88b4c0d in ?? () from /home/juh/.local/share/nsight_systems/libToolsInjection64.so
No symbol table info available.
#14 0x00007ff5c83a5609 in start_thread (arg=<optimized out>) at pthread_create.c:477
        ret = <optimized out>
        pd = <optimized out>
        unwind_buf = {cancel_jmp_buf = {{jmp_buf = {140691213436672, 2100541623046917983, 140732780327214, 140732780327215, 140732780327216, 140691213434304, -2104046464448442529, -2104008895042138273}, mask_was_saved = 0}}, priv = {pad = {0x0, 0x0, 0x0, 0x0}, data = {prev = 0x0, cleanup = 0x0, canceltype = 0}}}
        not_first_call = 0
#15 0x00007ff5c816e133 in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95
No locals.

The backtrace leads me to believe there is some kind of incompatibility between this version of nccl and nsys. Unfortunately, this is the nccl version that PyTorch currently ships in its binaries, so it would be very interesting to see if it could be fixed on the nsys side.

@afroger

Hi juh,

I was able to reproduce the issue using the 1.13.1-cuda11.6-cudnn8-runtime PyTorch container.

Interestingly, the problem doesn’t occur when “libibverbs.so” from the “libibverbs-dev” Debian package isn’t installed on the system. NCCL loads and interacts with that library at runtime; I haven’t looked at what it’s used for yet. If the library is absent, NCCL fails to load it and continues to function correctly.
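
For reference, a quick way to check whether that unversioned library name is loadable on a given system (my assumption being that this is the condition under which NCCL’s IB init path, and hence the crash, is reached) is a small ctypes probe. This is purely illustrative and not part of the reproducer:

import ctypes

# Try to dlopen the unversioned name installed by libibverbs-dev.
try:
    ctypes.CDLL("libibverbs.so")
    print("libibverbs.so is loadable; NCCL will pick it up at runtime")
except OSError:
    print("libibverbs.so not found; NCCL skips it and keeps working")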

Now that I have a reproducer, I’ll look into the issue and post more once I have some updates.

The problem was identified and a fix is under review. The upcoming version of Nsight Systems will contain the fix.

Also, a side note that’s unrelated to the issue: I see the code sample uses start_processes with the start_method set to fork. If you want Nsight Systems to work reliably, you should stay away from the fork and forkserver start methods and use spawn. See this comment to understand why the fork-without-exec idiom heavily complicates things for the tool.
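
For illustration, a minimal change to the run_demo function from the reproducer that switches to spawn (just a sketch, untested against this exact setup) would be:

def run_demo(demo_fn, world_size):
    # "spawn" launches fresh interpreter processes instead of forking,
    # which is the model Nsight Systems expects for reliable injection.
    mp.start_processes(demo_fn,
                       args=(world_size,),
                       nprocs=world_size,
                       join=True,
                       start_method="spawn")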

That’s amazing, thanks for identifying and fixing it so quickly. Do you know about the release cycle for Nsight Systems?

Also, agreed on fork/spawn: the crash happened with all start methods, and fork was just the easiest to debug with GDB.

Our next website release should be 2023.2.1 and will be available by the end of March.

