Why is writing to pinned memory much slower than loading from pinned memory in a multiprocessing multi-GPU program?

Hi,

I have recently run into an issue with pinned memory (zero-copy) in a multiprocessing multi-GPU program.

I pinned some data in host memory hoping to get better performance than with ordinary (pageable) CPU memory. The zero-copy load performance is good, but once the pinned buffer grows to 3*10^7 floats I see only a small improvement for writes: zero copy brings the load time down from 10.3 ms to 3 ms, while the write time only drops from 9.4 ms to 8 ms.

I pinned the memory with cudaHostRegister(), passing the flags cudaHostRegisterPortable | cudaHostAllocMapped. Does anyone know the reason?
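
To make the setup concrete, here is a simplified sketch of what I am doing (not my actual code; note that cudaHostRegisterMapped is the flag documented for cudaHostRegister and has the same numeric value as cudaHostAllocMapped):

#include <cuda_runtime.h>
#include <vector>

int main(){
  const size_t n = 30000000ULL;                    // ~3*10^7 floats, as described above
  std::vector<float> buf(n);

  // pin and map the existing host allocation so every GPU/process can access it directly
  cudaHostRegister(buf.data(), n * sizeof(float),
                   cudaHostRegisterPortable | cudaHostRegisterMapped);

  // under UVA the host pointer itself can be passed to a kernel;
  // otherwise cudaHostGetDevicePointer() gives the mapped device address
  float *dptr = nullptr;
  cudaHostGetDevicePointer((void **)&dptr, buf.data(), 0);

  // ... launch kernels that read from / write to dptr ...

  cudaHostUnregister(buf.data());
  return 0;
}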

My GPUs are four RTX 3090s.

Thank you!

I don’t see a performance difference like the one you are suggesting (3 ms vs. 8 ms), but I do note that performance seems to be better with cudaHostAlloc than with a host allocation followed by cudaHostRegister, so my suggestion would be to use cudaHostAlloc:

# cat t198.cu
#include <iostream>
#include <time.h>
#include <sys/time.h>
#define USECPSEC 1000000ULL

unsigned long long dtime_usec(unsigned long long start=0){

  timeval tv;
  gettimeofday(&tv, 0);
  return ((tv.tv_sec*USECPSEC)+tv.tv_usec)-start;
}

size_t sz = 1048576ULL*32;

template <typename T>
__global__ void write_kernel(T *d, size_t s, T val){
  for (size_t i = blockIdx.x*blockDim.x+threadIdx.x; i < s; i+=gridDim.x*blockDim.x)
    d[i] = val;
}

template <typename T>
__global__ void read_kernel(const T *d, size_t s, T tval, T *r){
  T val = 0;
  for (size_t i = blockIdx.x*blockDim.x+threadIdx.x; i < s; i+=gridDim.x*blockDim.x)
    val += d[i];
  if (val == tval) *r = val;
}


using mt = float;

int main(){
#ifndef USE_HALLOC
  mt *d = new mt[sz];
  cudaHostRegister(d, sizeof(*d)*sz, /* cudaHostRegisterPortable | */  cudaHostAllocMapped);
#else
  mt *d;
  cudaHostAlloc(&d, sizeof(*d)*sz, cudaHostAllocDefault);
#endif
  mt *r;
  cudaMalloc(&r, sizeof(mt));
  memset(d, 0, sizeof(*d)*sz);
  cudaMemset(r, 0, sizeof(*r));
  // warm-up
  write_kernel<<<3*58, 512>>>(d, sz, 1.0f);
  cudaDeviceSynchronize();
  read_kernel<<<3*58, 512>>>(d, sz, 1.0f, r);
  cudaDeviceSynchronize();

  unsigned long long dt = dtime_usec(0);
  write_kernel<<<3*58, 512>>>(d, sz, 1.0f);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "write kernel time: " << dt/(float)USECPSEC << "s" << std::endl;
  dt = dtime_usec(0);
  read_kernel<<<3*58, 512>>>(d, sz, 1.0f, r);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "read kernel time:  " << dt/(float)USECPSEC << "s" << std::endl;


}
# nvcc -o t198 t198.cu
# ./t198
write kernel time: 0.018488s
read kernel time:  0.016026s
# nvcc -o t198 t198.cu -DUSE_HALLOC
# ./t198
write kernel time: 0.011183s
read kernel time:  0.010902s
#

In the cudaHostAlloc case, there is about a 2.5% difference between read and write speeds for 128MB.

In the cudaHostRegister case, there is about a 13% difference between read and write speeds, and compared to the cudaHostAlloc case it is slower by around 40% or more. (Adding the cudaHostRegisterPortable flag doesn’t seem to affect the data. And cudaHostAllocMapped is implied in any UVA-enabled setting, which for practical purposes means all modern CUDA platforms.)
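
As a quick way to confirm the UVA/mapped point, the following minimal sketch (error checking omitted) queries the attribute and shows that the mapped device pointer is simply the host pointer:

#include <cuda_runtime.h>
#include <cstdio>

int main(){
  int uva = 0;
  cudaDeviceGetAttribute(&uva, cudaDevAttrUnifiedAddressing, 0);
  printf("unified addressing: %d\n", uva);   // 1 on any modern 64-bit CUDA platform

  float *h = nullptr, *d = nullptr;
  cudaHostAlloc(&h, 1024 * sizeof(float), cudaHostAllocDefault);
  cudaHostGetDevicePointer((void **)&d, h, 0);
  printf("host ptr %p, device ptr %p\n", (void *)h, (void *)d);  // identical under UVA
  cudaFreeHost(h);
  return 0;
}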

I don’t have an explanation for the various differences. Like I said, based on that data above, I would prefer cudaHostAlloc. Although both are methods to pin memory, I don’t know of any claims by NVIDIA that they result in identical behavior or circumstances. And the underlying mechanics of each are not documented or specified by NVIDIA.

CUDA 12.2, L4 GPU

I thought there could be an intersection with the cudaHostAllocPortable flag in a multi-GPU case, but when I ran the following code on a DGX-H100 I observed a somewhat similar scenario:

$ cat t4.cu
#include <iostream>
#include <time.h>
#include <sys/time.h>
#define USECPSEC 1000000ULL

unsigned long long dtime_usec(unsigned long long start=0){

  timeval tv;
  gettimeofday(&tv, 0);
  return ((tv.tv_sec*USECPSEC)+tv.tv_usec)-start;
}

size_t sz = 1048576ULL*32;

template <typename T>
__global__ void write_kernel(T *d, size_t s, T val){
  for (size_t i = blockIdx.x*blockDim.x+threadIdx.x; i < s; i+=gridDim.x*blockDim.x)
    d[i] = val;
}

template <typename T>
__global__ void read_kernel(const T *d, size_t s, T tval, T *r){
  T val = 0;
  for (size_t i = blockIdx.x*blockDim.x+threadIdx.x; i < s; i+=gridDim.x*blockDim.x)
    val += d[i];
  if (val == tval) *r = val;
}


using mt = float;

int main(){
  cudaSetDevice(0);
  cudaSetDevice(1);
  cudaSetDevice(0);
#ifndef USE_HALLOC
  mt *d = new mt[sz];
  cudaHostRegister(d, sizeof(*d)*sz,  cudaHostRegisterPortable |   cudaHostAllocMapped);
#else
  mt *d;
  cudaHostAlloc(&d, sizeof(*d)*sz, cudaHostAllocDefault);
#endif
  mt *r;
  cudaMalloc(&r, sizeof(mt));
  memset(d, 0, sizeof(*d)*sz);
  cudaMemset(r, 0, sizeof(*r));
  // warm-up
  write_kernel<<<3*58, 512>>>(d, sz, 1.0f);
  cudaDeviceSynchronize();
  read_kernel<<<3*58, 512>>>(d, sz, 1.0f, r);
  cudaDeviceSynchronize();

  unsigned long long dt = dtime_usec(0);
  write_kernel<<<3*58, 512>>>(d, sz, 1.0f);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "write kernel time: " << dt/(float)USECPSEC << "s" << std::endl;
  dt = dtime_usec(0);
  read_kernel<<<3*58, 512>>>(d, sz, 1.0f, r);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "read kernel time:  " << dt/(float)USECPSEC << "s" << std::endl;
}
$ nvcc -o t4 t4.cu
$ ./t4
write kernel time: 0.003816s
read kernel time:  0.004044s
$ nvcc -o t4 t4.cu -DUSE_HALLOC
$ ./t4
write kernel time: 0.002561s
read kernel time:  0.002675s
$

Overall it's faster because this machine has a newer PCIe generation.

In the cudaHostAlloc case, the performance difference between read and write is very small, about 4%. In the cudaHostRegister case, the perf difference between read and write is larger, about 6%, and the cudaHostRegister case is overall slower than the cudaHostAlloc case by about 34%.

Thanks for your reply! I looked at your code, but it seems to run on only one GPU.

I have a more complex system, which needs to create 4 processes for 4 different GPUs. What I suspect is that the performance degradation comes from PCIe, since all 4 GPUs load and write data through the zero-copy technique.

For example, if I use DMA loads and writes, the performance differs only a little between 1, 2, and 4 GPUs. But when I use zero-copy, the write operation on 4 GPUs is about 4 times slower than on 1 GPU, while the zero-copy load is only about 2 times slower.
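
To clarify what I mean by DMA vs. zero-copy, here is a simplified sketch with a placeholder kernel (not my actual code):

#include <cuda_runtime.h>

// placeholder for the real computation
__global__ void scale_kernel(float *p, size_t n){
  for (size_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += gridDim.x*blockDim.x)
    p[i] *= 0.1f;
}

int main(){
  const size_t n = 30000000ULL;                              // ~3*10^7 floats
  float *h, *d;
  cudaHostAlloc(&h, n*sizeof(float), cudaHostAllocDefault);  // pinned, mapped under UVA
  cudaMalloc(&d, n*sizeof(float));                           // device staging buffer

  // DMA path: stage the data in device memory, compute there, copy results back
  cudaMemcpy(d, h, n*sizeof(float), cudaMemcpyHostToDevice);
  scale_kernel<<<232, 512>>>(d, n);
  cudaMemcpy(h, d, n*sizeof(float), cudaMemcpyDeviceToHost);

  // zero-copy path: the kernel dereferences the pinned host buffer directly over PCIe
  scale_kernel<<<232, 512>>>(h, n);
  cudaDeviceSynchronize();

  cudaFree(d);
  cudaFreeHost(h);
  return 0;
}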

I did add another trivial case showing a multi-GPU setup, but only with one GPU active. If you have a more complex scenario, I was not able to deduce precisely what it was from your original posting. Therefore I responded as I did.

It takes a fair amount of effort to write test cases, and I have no ability to read someone else’s mind. If you don’t want to provide a well-specified question, it’s quite possible any response I may make is off target. Good luck!

I think I should describe my problem in more detail.

My code is based on PyTorch multiprocessing. Since I need to call CUDA kernels in each subprocess, I have to use the 'forkserver' start method. I then pin memory with cudaHostRegister using the flags cudaHostRegisterPortable | cudaHostAllocMapped. I also pin the memory again in each subprocess, because I found that if I pin it only in the main process, the subprocess may not get a valid device pointer and raises an illegal memory access error.

In my code, each GPU/subprocess loads data from the pinned memory, computes, and then writes results back to the pinned memory. That is the entire pipeline.

In my experiments, the timings look a little strange to me. Here is a table of my profiling results:

          1 GPU      2 GPUs     4 GPUs
read      1.52 ms    1.76 ms    2.97 ms
write     1.26 ms    1.81 ms    4.79 ms

I am so sorry, I will write a test case for it asap.

In main.py:

import torch.multiprocessing as mp
import torch as th
import time
import sys
from torch.utils.cpp_extension import load
import numpy as np


zerocopy_cpp = load(name='testcase', sources=['data_move/zp_test.cu'], extra_cflags=['-I/usr/local/cuda/include'], extra_cuda_cflags=['-I/usr/local/cuda/include'], extra_ldflags=['-lcuda', '-ldl'])


class ZeroCopy(th.autograd.Function):
    @staticmethod
    def forward(ctx, emb, indices, device):
        output = zerocopy_cpp.zero_copy_call(emb, indices, device)
        return output

    @staticmethod
    def backward(ctx):
        pass

class ZeroWrite(th.autograd.Function):
    @staticmethod
    def forward(ctx, emb, res, indices, device):
        zerocopy_cpp.zero_write(emb, res, indices, device)

    @staticmethod
    def backward(ctx):
        pass

class Pin_Mem(th.autograd.Function):
    @staticmethod
    def forward(ctx, emb):
        zerocopy_cpp.pin_mem(emb)
        return emb

    @staticmethod
    def backward(ctx):
        pass

zero_copy = ZeroCopy.apply
zero_write = ZeroWrite.apply
pin_mem = Pin_Mem.apply


def train_mp(emb, rank):
    pin_mem(emb)
    copy_time=0
    write_time=0
    for i in range(1000):
        indices = th.randint(0, 3000000, (10000,))
        start = time.time()
        data = zero_copy(emb, indices, rank)
        copy_time+=time.time() - start
        grad = data*0.1
        start = time.time()
        zero_write(emb, grad, indices, rank)
        write_time += time.time() - start
        
    print('copy time on {} is {}'.format(rank, copy_time))
    print('write time on {} is {}'.format(rank, write_time))


def main():
    mp.set_start_method('spawn')
    num_gpus = 4

    data = th.rand((3000000, 100))
    pin_mem(data)
    procs = []
    for i in range(num_gpus):
        proc = mp.Process(target=train_mp, args=(data, i))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()
    
if __name__ == '__main__':
    main()

My CUDA code:

#include <torch/extension.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cstdint>
#include <iostream>
#include <bitset>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <errno.h>
#include <error.h>
#include <stdlib.h>
#include <sys/time.h>

using namespace std;

typedef unsigned __int128 uint128_t;
#define abort(ret, errno, ...) error_at_line(ret, errno, __FILE__, __LINE__, \
                                             __VA_ARGS__)
#define CEIL(a, b) (((a)+(b)-1)/(b))

#define CHECK(call)                                                            \
{                                                                              \
    const cudaError_t error = call;                                            \
    if (error != cudaSuccess)                                                  \
    {                                                                          \
        fprintf(stderr, "Error: %s:%d, ", __FILE__, __LINE__);                 \
        fprintf(stderr, "code: %d, reason: %s\n", error,                       \
                cudaGetErrorString(error));                                    \
    }                                                                          \
}


__global__ void index_kernel(float *res, long *indices, float *src, int upper_bound, int dim)
{
    const int idx = blockIdx.x * blockDim.y + threadIdx.y;
    if (idx < upper_bound){
        for(int i=threadIdx.x; i<dim; i+=blockDim.x){
            res[idx * dim + i] = src[indices[idx] * dim + i];
        }
    }
}

torch::Tensor zero_copyH2D(torch::Tensor emb, torch::Tensor indices, int dev_id) {

    cudaSetDevice(dev_id);
    dim3 block(32, 32);
    dim3 grids = (CEIL(indices.size(0), block.y));
    dim3 grids_vec = (CEIL(indices.size(0), block.y*block.x));
    torch::Device dev = indices.device();
    
    long * idx;
    CHECK(cudaMalloc(&idx, sizeof(long) * indices.size(0)));
    CHECK(cudaMemcpy(idx, indices.data_ptr<long>(), sizeof(long) * indices.size(0), cudaMemcpyHostToDevice));
    torch::Tensor res = torch::empty({indices.size(0), emb.size(1)}, torch::TensorOptions(torch::kFloat32).device(torch::kCUDA, dev_id));
    index_kernel<<< grids, block, 0 >>>(res.data_ptr<float>(), idx, emb.data_ptr<float>(), indices.size(0), emb.size(1));
    CHECK(cudaFree(idx));
    cudaDeviceSynchronize();
    return res;
        
}

__global__ void write_kernel(float *emb, long *indices, float *res, int upper_bound, int dim)
{
    const int idx = blockIdx.x * blockDim.y + threadIdx.y;
    if (idx < upper_bound){
        for(int i=threadIdx.x; i<dim; i+=blockDim.x){
            emb[indices[idx] * dim + i] += res[idx * dim + i];
        }
    }
}

void zero_writeD2H(torch::Tensor emb, torch::Tensor res, torch::Tensor indices, int dev_id){
    cudaSetDevice(dev_id);
    dim3 block(32, 32);
    dim3 grids = (CEIL(indices.size(0), block.y));

    torch::Device dev = indices.device();
    long * idx;
    CHECK(cudaMalloc(&idx, sizeof(long) * indices.size(0)));
    CHECK(cudaMemcpy(idx, indices.data_ptr<long>(), sizeof(long) * indices.size(0), cudaMemcpyHostToDevice));
    
    write_kernel<<< grids, block, 0 >>>(emb.data_ptr<float>(), idx, res.data_ptr<float>(), indices.size(0), emb.size(1));
    cudaDeviceSynchronize();
    
    CHECK(cudaFree(idx));
}



void pin_mem(torch::Tensor emb){
    CHECK(cudaHostRegister(emb.data_ptr<float>(), sizeof(float) * emb.size(0)*emb.size(1), cudaHostRegisterPortable| cudaHostAllocMapped));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
	m.def("zero_copy_call", &zero_copyH2D, "zero copy data read from cpu to gpu");
    m.def("zero_write", &zero_writeD2H, "zero copy data read from cpu to gpu");
    m.def("pin_mem", &pin_mem, "pin memory on CPU");
}

As I just tested, the performance is as follows:

For 4 GPUs:
copy time on 0 is 2.006917715072632
write time on 0 is 6.646509170532227
copy time on 2 is 2.8822367191314697
write time on 2 is 7.350091934204102
copy time on 1 is 3.4186103343963623
write time on 1 is 6.843975782394409
copy time on 3 is 3.2107207775115967
write time on 3 is 7.03743577003479

For 2 GPUs:
copy time on 0 is 0.8724493980407715
write time on 0 is 2.2253098487854004
copy time on 1 is 1.565633773803711
write time on 1 is 2.2500340938568115

For 1 GPU:
copy time on 0 is 0.6452534198760986
write time on 0 is 1.0221374034881592

I think this may be because of the PCIe bandwidth limitation for zero-copy loads and writes, but I am not sure…

Sorry, I won’t be able to help with PyTorch code. This is a CUDA programming forum; you may get better help on a PyTorch forum like discuss.pytorch.org, where NVIDIA experts also patrol. Here is a 4-GPU concurrent read-and-write test case. The numbers look pretty sane to me:

$ cat t5.cu
#include <iostream>
#include <time.h>
#include <sys/time.h>
#define USECPSEC 1000000ULL

unsigned long long dtime_usec(unsigned long long start=0){

  timeval tv;
  gettimeofday(&tv, 0);
  return ((tv.tv_sec*USECPSEC)+tv.tv_usec)-start;
}

size_t sz = 1048576ULL*32;

template <typename T>
__global__ void write_kernel(T *d, size_t s, T val){
  for (size_t i = blockIdx.x*blockDim.x+threadIdx.x; i < s; i+=gridDim.x*blockDim.x)
    d[i] = val;
}

template <typename T>
__global__ void read_kernel(const T *d, size_t s, T tval, T *r){
  T val = 0;
  for (size_t i = blockIdx.x*blockDim.x+threadIdx.x; i < s; i+=gridDim.x*blockDim.x)
    val += d[i];
  if (val == tval) *r = val;
}


using mt = float;
const int nblk = 132*4;
int main(){
#ifndef USE_HALLOC
  mt *d = new mt[sz];
  cudaHostRegister(d, sizeof(*d)*sz,  cudaHostRegisterPortable |   cudaHostAllocMapped);
#else
  mt *d;
  cudaHostAlloc(&d, sizeof(*d)*sz, cudaHostAllocDefault);
#endif
  mt *r;
  cudaHostAlloc(&r, sizeof(mt), cudaHostAllocDefault);
  memset(d, 0, sizeof(*d)*sz);
  memset(r, 0, sizeof(*r));
  unsigned long long dt;
  // warm-up
  for (int i = 0; i < 4; i++) {
    cudaSetDevice(i);
    write_kernel<<<nblk, 512>>>(d, sz, 1.0f);
    read_kernel<<<nblk, 512>>>(d, sz, 1.0f, r);
    cudaDeviceSynchronize();}
  cudaSetDevice(0);
  // test writes
  dt = dtime_usec(0);
  write_kernel<<<nblk, 512>>>(d, sz, 1.0f);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "write kernel time: " << dt/(float)USECPSEC << "s" << std::endl;
  // test reads
  dt = dtime_usec(0);
  read_kernel<<<nblk, 512>>>(d, sz, 1.0f, r);
  cudaDeviceSynchronize();
  dt = dtime_usec(dt);
  std::cout << "read kernel time:  " << dt/(float)USECPSEC << "s" << std::endl;
  // concurrent writes
  dt = dtime_usec(0);
  for (int i = 0; i < 4; i++) {
    cudaSetDevice(i);
    write_kernel<<<nblk, 512>>>(d, sz, 1.0f);
  }
  for (int i = 0; i < 4; i++) {
    cudaSetDevice(i);
    cudaDeviceSynchronize();}
  dt = dtime_usec(dt);
  std::cout << "conc write time:  " << dt/(float)USECPSEC << "s" << std::endl;
  // concurrent reads
  dt = dtime_usec(0);
  for (int i = 0; i < 4; i++) {
    cudaSetDevice(i);
    read_kernel<<<nblk, 512>>>(d, sz, 1.0f, r);
  }
  for (int i = 0; i < 4; i++) {
    cudaSetDevice(i);
    cudaDeviceSynchronize();}
  dt = dtime_usec(dt);
  std::cout << "conc read time:  " << dt/(float)USECPSEC << "s" << std::endl;
}
$ nvcc -o t5 t5.cu
$ ./t5
write kernel time: 0.004772s
read kernel time:  0.004104s
conc write time:  0.007995s
conc read time:  0.007706s
$

Variation seems to be within 10%, read vs. write.
DGX-H100, CUDA 12.2

Thank you!

I did not read this lengthy thread, but in general, when there are simultaneous transfers from and to multiple GPUs the configuration of the host system can become important. With a PCIe gen4 x16 connection, bi-directional bandwidth per GPU is 2 x 25 GB/sec. So with four such GPUs transferring simultaneously, much system memory bandwidth is required to maintain full-speed operation: an eight-channel DDR4 system memory would just barely be sufficient.
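
To spell out the arithmetic (assuming roughly 25 GB/sec achievable per direction on a gen4 x16 link, and DDR4-3200 at about 25.6 GB/sec per channel): 4 GPUs x 2 directions x 25 GB/sec = 200 GB/sec of PCIe traffic landing in system memory, versus 8 x 25.6 GB/sec ≈ 205 GB/sec of theoretical DRAM bandwidth, which leaves essentially no headroom.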

There may be other choke points in the system. For example, does the CPU of the host system provide >= 64 PCIe gen4 lanes to support four PCIe gen4 x16 connections? If not, there will either be multiplexing of PCIe lanes or a reduced PCIe lane count per GPU, which can cause performance artifacts.

Also, if this is a system with more than one CPU socket, or with a CPU internally split across multiple chiplets, one needs to pay close attention to processor and memory affinity to make sure GPUs always “talk” to the nearest processor and memory controller; otherwise there could be choke points due to the inter-processor interconnect (either between the sockets or internally between the core complexes).
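
One practical way to check and enforce this on a Linux host: nvidia-smi topo -m prints the GPU-to-CPU/NUMA affinity matrix, and a tool such as numactl (for example numactl --cpunodebind=N --membind=N) can bind each worker process and its allocations to the NUMA node local to its GPU.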

When you look at NVIDIA’s highly optimized DGX systems, you will find that their system architecture is based on very expensive highest-end platforms, such as dual Xeon Platinums (cost per CPU > $10K), each with 80 PCIe lanes and eight-channel DDR5 memory. I suspect therein lies the difference between the OP’s and Robert Crovella’s observations.