cudaMalloc performance issue after p2p access is enabled

Per the blog:

Meanwhile, the run-time complexity of the cudaEnablePeerAccess API is roughly O(N * lg(N)) where N is the number of allocations made on the source device that need to be mapped to the destination device. Often this is called for each device pair to enable full bidirectional peer access, being a total O(D * D * N * lg(N)), where D is the number of devices. Also, as mentioned earlier, cudaMalloc must now map its allocations to all devices with peer access enabled. This means that the runtime complexity now scales as O(D * lg(N)).

This means that if I enable P2P access between 2 GPUs, cudaMalloc afterwards should be twice as slow, because each allocation needs to be mapped for both devices.
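
For reference, the "full bidirectional peer access" the blog refers to is typically set up with a loop over device pairs like the following (a minimal sketch of my own, not code from the blog):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    int D = 0;
    cudaGetDeviceCount(&D);
    for (int i = 0; i < D; ++i) {
        cudaSetDevice(i);
        for (int j = 0; j < D; ++j) {
            if (i == j) continue;
            int can = 0;
            cudaDeviceCanAccessPeer(&can, i, j);
            if (!can) continue;
            // Lets device i access device j's allocations; existing allocations
            // on device j are mapped now, and future cudaMallocs must also be
            // mapped for the peer device (the cost the blog describes).
            cudaError_t err = cudaDeviceEnablePeerAccess(j, 0);
            if (err != cudaSuccess && err != cudaErrorPeerAccessAlreadyEnabled)
                printf("enable %d -> %d failed: %s\n", i, j, cudaGetErrorString(err));
        }
    }
    return 0;
}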

I tried to verify this claim. If this is true, I need to adjust how I share memory between processes. This is what I do:

import os
import ctypes
import time

import torch
import torch.distributed as dist

# Load CUDA runtime library
libcudart = ctypes.CDLL("libcudart.so")

# Define function prototypes in ctypes
cudaMalloc = libcudart.cudaMalloc
cudaMalloc.restype = ctypes.c_int
cudaMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]

cudaFree = libcudart.cudaFree
cudaFree.restype = ctypes.c_int
cudaFree.argtypes = [ctypes.c_void_p]

class CudaIpcMemHandle(ctypes.Structure):
    # cudaIpcMemHandle_t is an opaque 64-byte struct (CUDA_IPC_HANDLE_SIZE)
    _fields_ = [("reserved", ctypes.c_byte * 64)]

cudaIpcMemLazyEnablePeerAccess = 1

cudaIpcGetMemHandle = libcudart.cudaIpcGetMemHandle
cudaIpcGetMemHandle.restype = ctypes.c_int
cudaIpcGetMemHandle.argtypes = [ctypes.POINTER(CudaIpcMemHandle), ctypes.c_void_p]

cudaIpcOpenMemHandle = libcudart.cudaIpcOpenMemHandle
cudaIpcOpenMemHandle.restype = ctypes.c_int
cudaIpcOpenMemHandle.argtypes = [ctypes.POINTER(ctypes.c_void_p), CudaIpcMemHandle, ctypes.c_uint]

cudaMemcpy = libcudart.cudaMemcpy
cudaMemcpy.restype = ctypes.c_int
cudaMemcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]

cudaDeviceSynchronize = libcudart.cudaDeviceSynchronize
cudaDeviceSynchronize.restype = ctypes.c_int

# Helper functions for CUDA memory management
def alloc(size):
    ptr = ctypes.c_void_p()
    result = cudaMalloc(ctypes.byref(ptr), size)
    if result != 0:
        raise Exception("cudaMalloc failed")
    return ptr

def free(ptr):
    result = cudaFree(ptr)
    if result != 0:
        raise Exception("cudaFree failed")

def synchronize():
    cudaDeviceSynchronize()


def worker(ipc=False):
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)
    # warmup device
    size_in_bytes = 1024 * 1024
    n_elements = size_in_bytes // 4
    ptr = alloc(size_in_bytes)
    if ipc:
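        # Exchange IPC handles: each rank broadcasts its own handle and opens
        # every other rank's buffer, which is what triggers peer mapping via IPC.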
        handle = CudaIpcMemHandle()
        assert cudaIpcGetMemHandle(ctypes.byref(handle), ptr) == 0
        ptrs = []
        for i in range(world_size):
            if i == rank:
                ptrs.append(ptr)
                dist.broadcast_object_list([handle], src=i)
            else:
                recv = [None]
                dist.broadcast_object_list(recv, src=i)
                recv_handle = recv[0]
                recv_ptr = ctypes.c_void_p()
                assert cudaIpcOpenMemHandle(ctypes.byref(recv_ptr), recv_handle, cudaIpcMemLazyEnablePeerAccess) == 0
                ptrs.append(recv_ptr)

    synchronize()
    dist.barrier()

    start = time.time()
    data = []
    for i in range(2000):
        data.append(alloc(size_in_bytes))
    synchronize()
    end = time.time()
    elapsed = end - start
    print(f"time for cudaMalloc: {elapsed}")
    dist.destroy_process_group()

if __name__ == "__main__":
    ipc = bool(int(os.getenv("IPC", "0")))
    print(f"ipc: {ipc}")
    worker(ipc=ipc)

Basically, I use PyTorch to broadcast the handle, and call the CUDA APIs through ctypes.

Run it with torchrun --nproc-per-node 2 test.py, with either export IPC=0 or export IPC=1. The results are:

ipc: False
ipc: False
time for cudaMalloc: 0.3364677429199219
time for cudaMalloc: 0.336867094039917

ipc: True
ipc: True
time for cudaMalloc: 0.33277273178100586
time for cudaMalloc: 0.33387041091918945

This means that after I turn on P2P access through IPC, the cudaMalloc time is roughly the same, not twice as long.

I’m running the program on a DGX-V100 machine.

I even ran the program with 4 GPUs:

ipc: False
ipc: False
ipc: False
ipc: False
time for cudaMalloc: 0.5533695220947266
time for cudaMalloc: 0.6247894763946533
time for cudaMalloc: 0.6472084522247314
time for cudaMalloc: 0.6500265598297119
ipc: True
ipc: True
ipc: True
ipc: True

time for cudaMalloc: 0.5264256000518799
time for cudaMalloc: 0.5990443229675293
time for cudaMalloc: 0.6208875179290771
time for cudaMalloc: 0.6369361877441406

Still roughly the same, not four times slower.

It seems enabling P2P access does not hurt cudaMalloc at all. Is this the general case, or is it just special to my machine?

Here is the test I ran:

$ cat t6.cu
#include <iostream>
#include <cstdlib>

#include <time.h>
#include <sys/time.h>
#define USECPSEC 1000000ULL

unsigned long long dtime_usec(unsigned long long start=0){

  timeval tv;
  gettimeofday(&tv, 0);
  return ((tv.tv_sec*USECPSEC)+tv.tv_usec)-start;
}

const int num_alloc = 100;
const int ds = 1048576*128;
int main(int argc, char *argv[]){
  cudaError_t err;
  if (argc > 1) {
    cudaSetDevice(1);
    cudaDeviceEnablePeerAccess(0, 0);
    cudaSetDevice(0);
    cudaDeviceEnablePeerAccess(1, 0);
    err = cudaGetLastError();
    std::cout << "enabling peer access: " << cudaGetErrorString(err) << std::endl;
  }
  cudaDeviceSynchronize();
  unsigned long long dt = dtime_usec(0);
  float *dev[num_alloc];
  for (int i = 0; i < num_alloc; i++) cudaMalloc(dev+i, ds);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  err = cudaGetLastError();
  std::cout << cudaGetErrorString(err) << std::endl;
  std::cout << "elapsed time: " << dt/(float)USECPSEC << std::endl;
}
$ nvcc -o t6 t6.cu
$ ./t6
no error
elapsed time: 0.007637
$ ./t6 1
enabling peer access: no error
no error
elapsed time: 0.125872
$

DGX-H100, CUDA 12.2

So I witness a substantially longer time in the peer enabled case (16x for this test).

I won’t be able to help with your pytorch test. This is a CUDA programming forum, and I consider, from my own perspective, pytorch questions to be off-topic here. I would suggest asking pytorch questions on a forum for pytorch, such as discuss.pytorch.org. There are NVIDIA experts that patrol that forum. If you want help here, in this forum, my suggestion is that you provide a CUDA test case.

To be clear: do as you wish. People post questions on a wide variety of topics here. However, based on my observation over a period of time, the questions that are most likely to get traction here are the ones that involve CUDA C++. We have separate forums for other related CUDA toolchains like CUDA Fortran. People post questions here on thrust, numba CUDA, and other related stuff, and sometimes those get answered too. thrust ships with the CUDA toolkit, and numba CUDA has nearly a 1:1 correspondence with CUDA C++. But beyond that, based on my observations, other questions generally get less traction, based on (I guess) the community that typically contributes to this forum.

Maybe someone will come along and tell you what is going on with your pytorch test. From my perspective, I see very solid evidence that cudaMalloc takes longer in the peer-enabled case.

Thanks for your answer. You explicitly enabled P2P access via cudaDeviceEnablePeerAccess, while I’m interested in the implicit P2P access set up via cudaIpcOpenMemHandle.

Here is my updated test script, a pure CUDA C++ program:

#include <iostream>
#include <unistd.h>
#include <sys/wait.h>
#include <chrono>
#include <cuda.h>
#include <cuda_runtime.h>

#define CHECK_CUDA(call) do { \
    cudaError_t err = call; \
    if (err != cudaSuccess) { \
        fprintf(stderr, "CUDA Error: %s in %s at line %d\n", cudaGetErrorString(err), __FILE__, __LINE__); \
        exit(EXIT_FAILURE); \
    } \
} while (0)

double measureMallocTime(size_t bytes, int times) {
    std::chrono::high_resolution_clock::time_point start, stop;
    std::chrono::microseconds duration(0);

    void *mem;
    start = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < times; ++i) {
        cudaMalloc(&mem, bytes);
    }
    cudaDeviceSynchronize();
    stop = std::chrono::high_resolution_clock::now();
    duration += std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
    return static_cast<double>(duration.count()) / times;
}

int main() {
    int pipefd1[2]; // Parent to child
    int pipefd2[2]; // Child to parent
    pid_t pid;
    char ack = '1';

    // Create pipes
    if (pipe(pipefd1) == -1 || pipe(pipefd2) == -1) {
        perror("pipe");
        return EXIT_FAILURE;
    }

    // Fork a process
    pid = fork();
    if (pid == -1) {
        perror("fork");
        return EXIT_FAILURE;
    }

    if (pid > 0) {
        // Parent process
        close(pipefd1[0]);
        close(pipefd2[1]);

        CHECK_CUDA(cudaSetDevice(0));

        double timeBefore = measureMallocTime(1024 * 1024, 1000);
        std::cout << "Parent (Device 0) average cudaMalloc time before IPC: " << timeBefore << " us" << std::endl;

        // Allocate memory and get IPC handle
        void *mem, *hostmem;
        CHECK_CUDA(cudaMalloc(&mem, 1024 * 1024));
        CHECK_CUDA(cudaMemset(mem, 1, 1024 * 1024));
        hostmem = malloc(1024 * 1024);
        CHECK_CUDA(cudaMemcpy(hostmem, mem, 1024 * 1024, cudaMemcpyDefault));
        std::cout << "Parent (Device 0), the first bytes are " << ((uint32_t*)(hostmem))[0] << std::endl;
        cudaIpcMemHandle_t handle;
        CHECK_CUDA(cudaIpcGetMemHandle(&handle, mem));

        // Write the handle to the pipe
        write(pipefd1[1], &handle, sizeof(cudaIpcMemHandle_t));

        // Wait for ack from the child
        read(pipefd2[0], &ack, 1);

        double timeAfter = measureMallocTime(1024 * 1024, 1000);
        std::cout << "Parent (Device 0) average cudaMalloc time after IPC: " << timeAfter << " us" << std::endl;

        // Clean up
        CHECK_CUDA(cudaFree(mem));

        close(pipefd1[1]);
        close(pipefd2[0]);

        wait(NULL); // Wait for child to exit
    } else {
        // Child process
        close(pipefd1[1]);
        close(pipefd2[0]);

        CHECK_CUDA(cudaSetDevice(1));

        double timeBefore = measureMallocTime(1024 * 1024, 1000);
        std::cout << "Child (Device 1) average cudaMalloc time before IPC: " << timeBefore << " us" << std::endl;

        // Read the IPC handle from the pipe
        cudaIpcMemHandle_t handle;
        read(pipefd1[0], &handle, sizeof(cudaIpcMemHandle_t));

        // Open IPC memory
        void *remoteMem, *hostmem;
        CHECK_CUDA(cudaIpcOpenMemHandle(&remoteMem, handle, cudaIpcMemLazyEnablePeerAccess));
        hostmem = malloc(1024 * 1024);
        CHECK_CUDA(cudaMemcpy(hostmem, remoteMem, 1024 * 1024, cudaMemcpyDefault));
        std::cout << "Child (Device 1), the first bytes are " << ((uint32_t*)(hostmem))[0] << std::endl;
        // Send ack to the parent
        write(pipefd2[1], &ack, 1);

        double timeAfter = measureMallocTime(1024 * 1024, 1000);
        std::cout << "Child (Device 1) average cudaMalloc time after IPC: " << timeAfter << " us" << std::endl;

        // Clean up
        CHECK_CUDA(cudaIpcCloseMemHandle(remoteMem));
        close(pipefd1[0]);
        close(pipefd2[1]);
    }

    return EXIT_SUCCESS;
}

The output:

Parent (Device 0) average cudaMalloc time before IPC: 207.496 us
Child (Device 1) average cudaMalloc time before IPC: 206.167 us
Parent (Device 0), the first bytes are 16843009
Child (Device 1), the first bytes are 16843009
Parent (Device 0) average cudaMalloc time after IPC: 209.476 us
Child (Device 1) average cudaMalloc time after IPC: 209.606 us

I can confirm the P2P access, because the child process can read the same memory as the parent process (verified by checking the first few bytes). However, after this P2P mapping is established, the cudaMalloc speed is still the same.

I suppose the claim in the blog may not be applicable to every scenario. It is certainly applicable to some.

It seems pretty obvious to me that:

  • enabling P2P implies/requires peer access for any appropriate device allocation on the peer devices.
  • enabling IPC does not. It only requires peer access for the allocation for which a mem handle is extracted. Therefore, the claim in the blog might not be applicable to the case where you use IPC on a particular allocation, but then simply do cudaMalloc subsequently.

Although IPC may require P2P capability/mapping for the IPC-shared allocation, it does not make any implication about future allocations via cudaMalloc.

On the other hand, cudaEnablePeerAccess does make implications about future allocations via cudaMalloc.

In short, your result is not surprising to me. It seems quite plausible, in fact it is what I would expect, and it is not at all in conflict with what you excerpted from the blog, which explicitly mentions cudaEnablePeerAccess and makes no mention of IPC.

The blog points out that there is no such thing as a global “enable peer access”, even though the API seems to work that way. The way it actually works is that each relevant allocation must be individually peer-mapped.

For CUDA IPC, each individual allocation need not be peer-mapped; only the allocation for which a shareable handle is extracted is. Therefore, although IPC seems to depend on P2P capability under the hood (for the device-to-device case), it seems to me that there is not a 1:1 equivalence between the P2P utilization of IPC and the P2P capability that is enabled via cudaEnablePeerAccess.
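
To make that distinction concrete, here is a rough sketch of the two paths for 2 devices (my own illustration, not tested code from this thread):

#include <cuda_runtime.h>

// (A) Explicit peer access, as in my t6.cu test: once both directions are
// enabled, every later cudaMalloc on either device must also be peer-mapped,
// which is where the extra cudaMalloc cost comes from.
void explicit_peer_access() {
    cudaSetDevice(1);
    cudaDeviceEnablePeerAccess(0, 0);
    cudaSetDevice(0);
    cudaDeviceEnablePeerAccess(1, 0);
    void *p = nullptr;
    cudaMalloc(&p, 1 << 20);   // also mapped for the peer device
}

// (B) IPC import: only the single allocation behind `handle` is peer-mapped
// for the importing device; later cudaMallocs are ordinary allocations.
void ipc_import(cudaIpcMemHandle_t handle) {
    cudaSetDevice(1);
    void *remote = nullptr;
    cudaIpcOpenMemHandle(&remote, handle, cudaIpcMemLazyEnablePeerAccess);
    void *p = nullptr;
    cudaMalloc(&p, 1 << 20);   // unaffected by the IPC import above
}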

Honestly speaking, the result looks plausible to me, too. I would expect CUDA to map only the memory exported through IPC.

My confusion comes from the documentation of cudaIpcOpenMemHandle, which clearly says:

For contexts on different devices cudaIpcOpenMemHandle can attempt to enable peer access between the devices as if the user called cudaDeviceEnablePeerAccess.

Does the above result indicate that the documentation is wrong?

I imagine the documentation could always be improved. Honestly, this looks like a matter of interpretation, but ideally documentation would be so crystal-clear that such interpretation would never be necessary. (It might be as simple as adding “for that allocation” to the excerpt you selected, which might be what was meant, even though it’s not stated exactly that way. But even that is arguably confusing.)

If you think the documentation could be improved you are welcome to file a bug.

From my experiment, it seems a more precise wording would be:

For contexts on different devices cudaIpcOpenMemHandle can attempt to enable peer access between the devices as if the user called cudaDeviceEnablePeerAccess, allocated the memory, and then called cudaDeviceDisablePeerAccess.
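
As a sketch of the sequence I have in mind (an analogy only, not the actual implementation):

#include <cuda_runtime.h>

// My reading of what cudaIpcOpenMemHandle effectively does for the importing
// device: peer access is in effect just long enough to map the shared
// allocation, so later cudaMallocs behave as if it had never been enabled.
void analogy_on_importing_device(int exporting_device) {
    cudaDeviceEnablePeerAccess(exporting_device, 0);   // "called cudaDeviceEnablePeerAccess"
    // ... the imported allocation gets mapped for this device here ...
    cudaDeviceDisablePeerAccess(exporting_device);     // "then called cudaDeviceDisablePeerAccess"

    void *p = nullptr;
    cudaMalloc(&p, 1 << 20);   // not peer-mapped; timing matches the no-P2P case
}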

Bug filed on the NVIDIA Developer site.

Thanks for filing a bug ticket. This maps to NVBUG ID 4680774. We will report the conclusion here once the internal ticket cycle is done.

Best,
Yuki

With the word “can”, the documentation does not explicitly prescribe anything, does it?

I don’t understand what you mean. cudaIpcOpenMemHandle will enable P2P access when the device of the current context is different from the device on which the memory behind the handle lives.
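
One way to probe this directly would be to call cudaDeviceEnablePeerAccess explicitly in the importing process right after cudaIpcOpenMemHandle and check its return code (an experiment idea only; I have not run this):

#include <cuda_runtime.h>
#include <cstdio>

// If cudaIpcOpenMemHandle really enabled peer access "as if the user called
// cudaDeviceEnablePeerAccess", an explicit call afterwards should return
// cudaErrorPeerAccessAlreadyEnabled; cudaSuccess would mean the IPC open did
// not enable context-wide peer access.
void probe_after_ipc_open(int exporting_device) {
    cudaError_t err = cudaDeviceEnablePeerAccess(exporting_device, 0);
    if (err == cudaErrorPeerAccessAlreadyEnabled)
        printf("peer access was already enabled by cudaIpcOpenMemHandle\n");
    else if (err == cudaSuccess)
        printf("peer access was NOT enabled by cudaIpcOpenMemHandle\n");
    else
        printf("cudaDeviceEnablePeerAccess: %s\n", cudaGetErrorString(err));
}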