Inter-node NVSHMEM and IB problem

I am having some trouble with the inter-node NVSHMEM environment setup and the RDMA environment setup.

I installed nv_peer_mem and gdrcopy based on this doc (NVSHMEM Installation Guide — nvshmem 2.10.1 documentation (nvidia.com)) on catalyst-cluster.cs.cmu.edu.

Both lsmod | grep nv_peer_mem and depmod -n | grep -i gdrdrv produce output, which means I installed nv_peer_mem and gdrcopy successfully.

I also wrote a demo to test the RDMA connection: one node is the server and the other node is the client. The log indicates the RDMA connection is OK. However, if I run ibping -S, the output is:

ibwarn: [446273] mad_rpc_open_port: can't open UMAD port ((null):0)
ibping: iberror: failed: Failed to open '(null)' port '0'

If I run sudo ibping -S, the output is:

ibwarn: [453537] _do_madrpc: recv failed: Connection timed out
ibwarn: [453537] mad_rpc_rmpp: _do_madrpc failed; dport (Lid 172)

If I run rdma-server on one node and rdma-client -s ip on the other node, there is no problem.
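
For reference, a quick way to sanity-check the IB link itself (assuming the standard infiniband-diags utilities are installed) is something like:

ibstat      # HCA/port state, LID, and link rate for each device
ibstatus    # per-port summary; state should be ACTIVE and physical state LinkUp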

Then I try to run my code to test inter-node NVSHMEM. The command is: nvshmrun -np 2 --host ip1,ip2 ./worker > nvshmem.log 2>&1  # change the ip to your ip addresses

#include <stdio.h>
#include "mpi.h"
#include "nvshmem.h"
#include "nvshmemx.h"
#include <unistd.h>

#define CUDA_CHECK(stmt)                                  \
do {                                                      \
    cudaError_t result = (stmt);                          \
    if (cudaSuccess != result) {                          \
        fprintf(stderr, "[%s:%d] CUDA failed with %s \n", \
         __FILE__, __LINE__, cudaGetErrorString(result)); \
        exit(-1);                                         \
    }                                                     \
} while (0)

__global__ void simple_shift(int *destination) {
    int mype = nvshmem_my_pe();
    int npes = nvshmem_n_pes();
    int peer = (mype + 1) % npes;

    nvshmem_int_p(destination, mype, peer);
}

int main (int argc, char *argv[]) {
    int mype_node, msg;
    cudaStream_t stream;
    int rank, nranks;
    char hostname[256];
    gethostname(hostname, 256);
    
    MPI_Comm mpi_comm = MPI_COMM_WORLD;
    nvshmemx_init_attr_t attr;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &nranks);
    printf("Rank: %d, Hostname: %s and nranks:%d \n", rank, hostname,nranks);

    attr.mpi_comm = &mpi_comm;
    nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr);
    mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);

    CUDA_CHECK(cudaSetDevice(mype_node));
    CUDA_CHECK(cudaStreamCreate(&stream));
    int *destination = (int *) nvshmem_malloc (sizeof(int));

    simple_shift<<<1, 1, 0, stream>>>(destination);
    nvshmemx_barrier_all_on_stream(stream);
    CUDA_CHECK(cudaMemcpyAsync(&msg, destination, sizeof(int),
                cudaMemcpyDeviceToHost, stream));

    CUDA_CHECK(cudaStreamSynchronize(stream));
    printf("%d: received message %d and hostname:%s\n", nvshmem_my_pe(), msg, hostname);

    nvshmem_free(destination);
    nvshmem_fence();
    nvshmem_finalize();
    MPI_Finalize();
    return 0;
}
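
For reference, building an MPI-bootstrapped NVSHMEM example like this needs relocatable device code and linking against NVSHMEM; a compile line of roughly this shape works (the GPU arch, install paths, and MPI link flags are placeholders that depend on the local setup):

nvcc -rdc=true -ccbin g++ -gencode=arch=compute_80,code=sm_80 \
    -I $NVSHMEM_HOME/include -I $MPI_HOME/include worker.cu -o worker \
    -L $NVSHMEM_HOME/lib -lnvshmem -lnvidia-ml -lcuda -lcudart \
    -L $MPI_HOME/lib -lmpi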

The error becomes:

/home/xxxx/nvshmem_src_2.10.1-3/src/modules/transport/common/transport_ib_common.cpp:84: NULL value mem registration failed 

/home/xxxx/nvshmem_src_2.10.1-3/src/modules/transport/ibrc/ibrc.cpp:500: non-zero status: 2 Unable to register memory handle.
[yyyy:446187:0:446187] Caught signal 11 (Segmentation fault: address not mapped to object at address 0x10)
/home/xxxx/nvshmem_src_2.10.1-3/src/modules/transport/common/transport_ib_common.cpp:84: NULL value mem registration failed 

/home/xxxx/nvshmem_src_2.10.1-3/src/modules/transport/ibrc/ibrc.cpp:500: non-zero status: 2 Unable to register memory handle.
[yyyy:446188:0:446188] Caught signal 11 (Segmentation fault: address not mapped to object at address 0x10)

Is there any problem with my NVSHMEM or RDMA setup?

/home/xxxx/nvshmem_src_2.10.1-3/src/modules/transport/common/transport_ib_common.cpp:84: NULL value mem registration failed
/home/xxxx/nvshmem_src_2.10.1-3/src/modules/transport/ibrc/ibrc.cpp:500: non-zero status: 2 Unable to register memory handle.
[yyyy:446187:0:446187] Caught signal 11 (Segmentation fault: address not mapped to object at address 0x10)

This is indicative of RDMA memory registration failing during NVSHMEM runtime initialization. Intuitively, if you are not using DMA-BUF, a successful GPU memory registration depends on a few factors - appropriate installation of the nvidia-peermem module, a matching MOFED installed on the host, etc. Depending on the HW/OS, you may need to pick one: Linux InfiniBand Drivers
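
A quick way to check those prerequisites on each node (module names below are the common defaults and may differ depending on your stack) is something like:

lsmod | grep -E 'nvidia_peermem|nv_peer_mem'                   # is a peer-memory module loaded?
ofed_info -s                                                   # MOFED version, if installed
nvidia-smi --query-gpu=driver_version --format=csv,noheader    # GPU driver version
dmesg | grep -i -E 'peermem|peer_mem|ib_core' | tail -n 50     # kernel-side registration errors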

Can you share the output for both commands?

If I run rdma-server on one node and rdma-client -s ip on the other node, there is no problem.

Unless you selected GPU memory as the target, this is not a close approximation of the IBRC transport in NVSHMEM (when it comes to memory registration). You may want to look into running this benchmark first: GitHub - linux-rdma/perftest: Infiniband Verbs Performance Tests with gpudirect capability, to make sure the RDMA stack, GPU driver, and their associated SW/HW dependencies are met before running the NVSHMEM multi-node examples over IB.
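
If perftest was built with CUDA support, a GPUDirect-style bandwidth test between the two nodes would look roughly like this (device name, GPU index, and server IP are placeholders):

# on the server node
ib_write_bw -d mlx5_0 --use_cuda=0

# on the client node
ib_write_bw -d mlx5_0 --use_cuda=0 <server_ip>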

My output for lsmod | grep nv_peer_mem and depmod -n | grep -i gdrdrv:

Both commands produce output, which means I installed nv_peer_mem and gdrcopy successfully.

Sorry for the delay in responding, and thank you for sharing details on the installed modules. A few more follow-ups:

  • Can you provide insight into the CUDA driver version you are using, via nvidia-smi? AFAIK, the off-the-tree nv_peer_mem project is considered deprecated for CUDA driver r470 and above. The NVSHMEM doc may need to be updated to reflect that. You may need to install the nvidia-peermem module (distributed as part of the GPU driver) and remove nv_peer_mem following the instructions here, if you are using the newer CUDA driver/toolkit: 1. Overview — GPUDirect RDMA 12.4 documentation

  • Can you also share the output of /var/log/kern.log or dmesg, grepping for any messages from the nv_peer_mem, nvidia_peermem and ib_core modules? Assuming you are using Mellanox NICs, there is a good chance that your MOFED and peermem modules are not compatible, and your memory registration is failing because of that. Please also share the timestamps of when ib_core.ko and nvidia-peermem.ko were installed.

  • Can you share the MOFED version (ofed_info -s) and kernel version (uname -r)? A combined command sketch is shown after this list.
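
Something like the following should collect all of the above in one pass (module names are the usual defaults):

nvidia-smi                                          # CUDA driver version
uname -r                                            # kernel version
ofed_info -s                                        # MOFED version, if installed
lsmod | grep -E 'nvidia_peermem|nv_peer_mem'        # which peer-memory module is loaded
dmesg -T | grep -i -E 'peermem|peer_mem|ib_core'    # kernel messages from the relevant modules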

1) Hello, my driver version is: NVIDIA-SMI 550.54.15, Driver Version: 550.54.15, CUDA Version: 12.4.
2) Could you tell me how to install the nvidia-peermem module?
3) My command cat /var/log/kern.log gives:
cat: /var/log/kern.log: No such file or directory
4) The output of ofed_info -s:
-bash: ofed_info: command not found
5) The output of uname -r:
4.18.0-372.32.1.el8_6.x86_64

dmesg.log (948.4 KB)

The dmesg log is attached above.

If it is convenient for you, could you please provide a Dockerfile so I can use it to launch a container?

I wrote a PyTorch multi-node script to test the IB connection:

import os
import torch
import torch.distributed as dist

def main():
    os.environ['MASTER_ADDR'] = '172.16.1.11'
    os.environ['MASTER_PORT'] = '12348'

    dist.init_process_group(backend='nccl', init_method='env://')

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Create a tensor
    tensor = torch.tensor([1, 2, 3]).cuda()

    if rank == 0:
        dist.send(tensor, dst=1)
    elif rank == 1:
        received_tensor = torch.empty_like(tensor)
        dist.recv(received_tensor, src=0)
        print(f'Rank {rank} received: {received_tensor}')

    dist.barrier()

    dist.destroy_process_group()

if __name__ == '__main__':
    main()

I set node 0 as the master node, and its IB address is 172.16.1.11.
The node 0 command:

export NCCL_SOCKET_IFNAME=ib0
export NCCL_IB_DISABLE=0
export NCCL_DEBUG=INFO
torchrun --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr=172.16.1.11 --master_port=12348 train.py

The node 1 command:

export NCCL_SOCKET_IFNAME=ib0
export NCCL_IB_DISABLE=0
export NCCL_DEBUG=INFO


torchrun --nproc_per_node=1 --nnodes=2 --node_rank=1 --master_addr=172.16.1.11 --master_port=12348 train.py

The error log is:

base) [ubuntu@cloudlab-0-11 ~]$ export NCCL_DEBUG=INFO
(base) [ubuntu@cloudlab-0-11 ~]$ torchrun --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr=172.16.1.11 --master_port=12348 train.py
cloudlab-0-11:512935:512935 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to ib0
cloudlab-0-11:512935:512935 [0] NCCL INFO Bootstrap : Using ib0:172.16.1.11<0>
cloudlab-0-11:512935:512935 [0] NCCL INFO NET/Plugin : dlerror=libnccl-net.so: cannot open shared object file: No such file or directory No plugin found (libnccl-net.so), using internal implementation
cloudlab-0-11:512935:512935 [0] NCCL INFO cudaDriverVersion 12040
NCCL version 2.19.3+cuda12.3
cloudlab-0-11:512935:512943 [0] NCCL INFO NCCL_IB_DISABLE set by environment to 0.
cloudlab-0-11:512935:512943 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to ib0
cloudlab-0-11:512935:512943 [0] NCCL INFO NET/IB : Using [0]mlx5_0:1/IB [RO]; OOB ib0:172.16.1.11<0>
cloudlab-0-11:512935:512943 [0] NCCL INFO Using non-device net plugin version 0
cloudlab-0-11:512935:512943 [0] NCCL INFO Using network IB
cloudlab-0-11:512935:512943 [0] NCCL INFO comm 0x8c123a0 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 1000 commId 0xd0cad364d9dcd57 - Init START
cloudlab-0-11:512935:512943 [0] NCCL INFO Setting affinity for GPU 0 to ff000000,00000000,ff000000
cloudlab-0-11:512935:512943 [0] NCCL INFO Channel 00/02 :    0   1
cloudlab-0-11:512935:512943 [0] NCCL INFO Channel 01/02 :    0   1
cloudlab-0-11:512935:512943 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] -1/-1/-1->0->1
cloudlab-0-11:512935:512943 [0] NCCL INFO P2P Chunksize set to 131072
cloudlab-0-11:512935:512943 [0] NCCL INFO Channel 00/0 : 1[0] -> 0[0] [receive] via NET/IB/0
cloudlab-0-11:512935:512943 [0] NCCL INFO Channel 01/0 : 1[0] -> 0[0] [receive] via NET/IB/0
cloudlab-0-11:512935:512943 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[0] [send] via NET/IB/0
cloudlab-0-11:512935:512943 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[0] [send] via NET/IB/0

cloudlab-0-11:512935:512945 [0] misc/ibvwrap.cc:190 NCCL WARN Call to ibv_create_cq failed with error Cannot allocate memory
cloudlab-0-11:512935:512945 [0] NCCL INFO transport/net_ib.cc:520 -> 2
cloudlab-0-11:512935:512945 [0] NCCL INFO transport/net_ib.cc:647 -> 2
cloudlab-0-11:512935:512945 [0] NCCL INFO transport/net.cc:677 -> 2
cloudlab-0-11:512935:512943 [0] NCCL INFO transport/net.cc:304 -> 2
cloudlab-0-11:512935:512943 [0] NCCL INFO transport.cc:148 -> 2
cloudlab-0-11:512935:512943 [0] NCCL INFO init.cc:1117 -> 2
cloudlab-0-11:512935:512943 [0] NCCL INFO init.cc:1396 -> 2
cloudlab-0-11:512935:512943 [0] NCCL INFO group.cc:64 -> 2 [Async thread]
cloudlab-0-11:512935:512935 [0] NCCL INFO group.cc:418 -> 2
cloudlab-0-11:512935:512935 [0] NCCL INFO group.cc:95 -> 2

cloudlab-0-11:512935:512945 [0] misc/ibvwrap.cc:190 NCCL WARN Call to ibv_create_cq failed with error Cannot allocate memory
cloudlab-0-11:512935:512945 [0] NCCL INFO transport/net_ib.cc:520 -> 2
cloudlab-0-11:512935:512945 [0] NCCL INFO transport/net_ib.cc:647 -> 2
cloudlab-0-11:512935:512945 [0] NCCL INFO transport/net.cc:677 -> 2
cloudlab-0-11:512935:512945 [0] NCCL INFO misc/socket.cc:47 -> 3
cloudlab-0-11:512935:512945 [0] NCCL INFO misc/socket.cc:58 -> 3
cloudlab-0-11:512935:512945 [0] NCCL INFO misc/socket.cc:773 -> 3
cloudlab-0-11:512935:512945 [0] NCCL INFO proxy.cc:1374 -> 3
cloudlab-0-11:512935:512945 [0] NCCL INFO proxy.cc:1415 -> 3

cloudlab-0-11:512935:512945 [0] proxy.cc:1557 NCCL WARN [Proxy Service 0] Failed to execute operation Connect from rank 0, retcode 3
Traceback (most recent call last):
  File "/home/ubuntu/train.py", line 36, in <module>
    main()
  File "/home/ubuntu/train.py", line 23, in main
    dist.send(tensor, dst=1)
  File "/home/ubuntu/anaconda3/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/anaconda3/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 1660, in send
    default_pg.send([tensor], dst, tag).wait()
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1691, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.19.3
ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. 
Last error:
Call to ibv_create_cq failed with error Cannot allocate memory
cloudlab-0-11:512935:512935 [0] NCCL INFO comm 0x8c123a0 rank 0 nranks 2 cudaDev 0 busId 1000 - Abort COMPLETE
[2024-04-18 04:01:04,615] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 512935) of binary: /home/ubuntu/anaconda3/bin/python
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/ubuntu/anaconda3/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/anaconda3/lib/python3.11/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/home/ubuntu/anaconda3/lib/python3.11/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/home/ubuntu/anaconda3/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/anaconda3/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-04-18_04:01:04
  host      : cloudlab-0-11.eth
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 512935)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
(base) [ubuntu@cloudlab-0-11 ~]$ 

It seems my IB has some problem? @arnavg

For 2, if you have installed the GPU driver and CUDA toolkit from apt-get or yum previously, you should have nvidia-peermem available under /lib/modules/. To install or load it, just run modprobe nvidia-peermem and then lsmod | grep nvidia_peermem to confirm it is loaded.
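
In shell form, and assuming a systemd-based distro if you want the module loaded at every boot:

sudo modprobe nvidia-peermem                 # load the module now
lsmod | grep nvidia_peermem                  # confirm it is loaded
echo nvidia-peermem | sudo tee /etc/modules-load.d/nvidia-peermem.conf   # optional: auto-load at boot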

For 4, this indicates you don't have the Mellanox OFED package on your node.

  • To download the drivers, please follow the instructions here
  • To install OFED, please follow the instructions here

For 4, could I install the driver and OFED in Docker? Thanks.

It needs to be installed on bare metal. Here are its HW and SW requirements: https://docs.nvidia.com/networking/display/mlnxofedv23102131lts/hardware+and+software+requirements

For 2, I have installed the GPU driver and CUDA toolkit (CUDA Toolkit 12.2 Downloads | NVIDIA Developer).

Then I enter /lib/modules/

and perform modprobe nvidia-peermem; the output is modprobe: ERROR: could not insert 'nvidia_peermem': Operation not permitted.

The output of sudo modprobe nvidia-peermem is modprobe: ERROR: could not insert 'nvidia_peermem': Invalid argument.

modprobe: ERROR: could not insert 'nvidia_peermem': Invalid argument

  • Please share the dmesg -T output when you run this command, to understand this error better.
  • Please share the output of this command to confirm the .ko exists:
    find /lib/modules -name "nvidia*.ko"

dmegs.log (1.1 MB)
The dmesg -T output is very long, so I am sharing the log file.

For find /lib/modules -name "nvidia*.ko", the output is empty.

But the path /lib/modules/4.18.0-372.32.1.el8_6.x86_64/extra
contains nvidia-drm.ko.xz, nvidia.ko.xz, nvidia-modeset.ko.xz, nvidia-peermem.ko.xz, and nvidia-uvm.ko.xz.
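
Since the modules on this system are xz-compressed, a pattern that also matches .ko.xz (and modinfo, which handles compressed modules) confirms what is actually installed:

find /lib/modules/$(uname -r) -name "nvidia*.ko*"
modinfo nvidia-peermem      # resolves the module name even for a compressed .ko.xz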

Hello,

At least one problem is detailed in this section of the GPUDirect RDMA documentation here, under the note at the end of the section.

If you install MOFED after the GPU driver, you will need to reinstall the GPU driver. That is why the nvidia-peermem module is failing to load for you.
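
A rough sketch of the recovery order (the exact installer commands depend on the distro and how the driver was packaged, so treat these as placeholders):

# 1. Install MOFED first (e.g. the mlnxofedinstall script from the MLNX_OFED bundle), then reboot.
# 2. Reinstall the NVIDIA GPU driver so nvidia-peermem is built against the new OFED ib_core.
# 3. Load and verify the module:
sudo modprobe nvidia-peermem
lsmod | grep nvidia_peermem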

If it is convenient for you, could you please provide a Dockerfile so I can use it to launch a container?
Unfortunately, we don't have a Dockerfile available for NVSHMEM today. Also, I don't know that we plan to add installation of kernel modules into the container definition. FWICT, this is generally frowned upon due to the effects experienced outside the container and the maintainability issues related to kernel module versioning.

I can say that NVSHMEM containers including all of our non-kernel dependencies are actively under consideration for prioritization within our roadmap, but I can't give a specific release at this time.

Hi @sethh @arnavg,
here is my script for installing NVSHMEM. The document I followed is NVSHMEM Installation Guide — nvshmem 2.10.1 documentation (nvidia.com).

First I install UCX with the following commands:


git clone https://github.com/openucx/ucx.git
cd ucx
./autogen.sh
./contrib/configure-release --prefix=/home/lambda7xx/UCX --enable-mt  --with-dm
make -j 32
make install #sudo make install 

Then I install NVSHMEM for inter-node use:


 export NVSHMEM_USE_GDRCOPY=1
  export NVSHMEM_MPI_SUPPORT=1
  export NVSHMEM_UCX_SUPPORT=1
  export NVSHMEM_USE_NCCL=1
  export NVSHMEM_PREFIX==/home/lambda7xx/nvshmem
  export UCX_HOME=/home/lambda7xx/UCX
  mkdir build
  cd build 
  cmake  -DNVSHMEM_PREFIX=/home/lambda7xx/nvshmem -DNVSHMEM_IBRC_SUPPORT=1 -DNVSHMEM_UCX_SUPPORT=1 -DNVSHMEM_IBGDA_SUPPORT=1 -DNVSHMEM_MPI_SUPPORT=1     -DNVSHMEM_MPI_IS_OMPI=1    ..
  make -j 128
  make install #sudo make install 
  
  
  #set the .bashrc
  echo 'export PATH=/home/lambda7xx/nvshmem/bin:$PATH' >> ~/.bashrc
  echo 'export LD_LIBRARY_PATH=/home/lambda7xx/nvshmem/lib:$LD_LIBRARY_PATH' >> ~/.bashrc

Is there any mistake in my commands?

Since you are enabling NVSHMEM_USE_GDRCOPY, NVSHMEM_IBRC_SUPPORT and NVSHMEM_IBGDA_SUPPORT, how did you install the corresponding dependencies?

The order of installation of those remaining dependencies is:

  • MOFED and reboot the device/node for the NIC firmware to bootstrap
  • Install CUDA driver and toolkit, which should install nvidia_peermem
  • Install GDRCOPY

Link to these dependencies are also in the SW requirements:
https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html#software-requirements
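
Once those are in place, a quick check on each node that the pieces NVSHMEM relies on are present (module names are the usual defaults) could look like:

ofed_info -s                               # MOFED installed
lsmod | grep -E 'nvidia_peermem|gdrdrv'    # peer-memory and gdrcopy kernel modules loaded
nvidia-smi                                 # GPU driver / CUDA driver version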

Install gdrcopy:

  1. Build and install:

git clone https://github.com/NVIDIA/gdrcopy.git
cd gdrcopy
make all -j 128
make install # sudo make install
sudo ./insmod.sh

  2. Test if gdrcopy is installed:

depmod -n | grep -i gdrdrv

If the output is empty, it means gdrdrv is not loaded into the kernel.
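
A more direct check that the module is actually loaded, plus a functional test with the benchmark binaries gdrcopy ships (the names vary by gdrcopy version; newer releases prefix them with gdrcopy_), might be:

lsmod | grep gdrdrv      # gdrdrv loaded into the running kernel
gdrcopy_sanity           # basic API sanity test, if installed into PATH
gdrcopy_copybw           # bandwidth test exercising the GPU BAR mapping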

Your steps to install gdrcopy look correct to me. Please make sure the order of installation of MOFED, NVIDIA Peermem, and CUDA is as described in the post above; the order of installation is important, as others have pointed out in this thread.