#include "utils/cudadrv.hpp"
#include "utils/cudart.hpp"
#include <cstdio>
#include <numeric>
#include <vector>
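// CUDADRV_CHECK / CUDART_CHECK are assumed to be error-checking macros from
// the two utils headers above (their exact definitions are not shown here).
// A minimal sketch of what such a macro might look like:
//
//   #define CUDADRV_CHECK(call)                                  \
//     do {                                                       \
//       CUresult res_ = (call);                                  \
//       if (res_ != CUDA_SUCCESS) {                              \
//         const char *msg_ = nullptr;                            \
//         cuGetErrorString(res_, &msg_);                         \
//         std::fprintf(stderr, "%s failed: %s\n", #call, msg_);  \
//         std::abort();                                          \
//       }                                                        \
//     } while (0)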
// Grid-stride loop: each thread processes elements tid, tid + stride, ...
// so any <<<grid, block>>> shape covers all n elements.
template <typename T>
__global__ void vec_add_one(T *arr, uint32_t n) {
  uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
  uint32_t stride = gridDim.x * blockDim.x;
  for (uint32_t i = tid; i < n; i += stride) {
    arr[i] += 1;
  }
}
// RAII helper: binds `ctx` to the calling thread for the current scope and
// unbinds it on exit.
struct ScopedContext {
  explicit ScopedContext(CUcontext ctx) {
    CUDADRV_CHECK(cuCtxSetCurrent(ctx));
  }
  ~ScopedContext() {
    CUDADRV_CHECK(cuCtxSetCurrent(nullptr));
  }
};
// A possible cudaMalloc-style allocator built on the CUDA driver API.
// CudaMmapProxy owns one driver context per device and enables peer access
// between device pairs that support it.
struct CudaMmapProxy {
  std::vector<CUcontext> cu_contexts_;
  std::vector<uint32_t> devices_;

  CudaMmapProxy() {
    CUDADRV_CHECK(cuInit(0));
    int num_devices = 0;
    CUDADRV_CHECK(cuDeviceGetCount(&num_devices));
    // resize(), not reserve(): std::iota writes through [begin, end), which
    // stays empty after reserve() alone.
    devices_.resize(num_devices);
    std::iota(devices_.begin(), devices_.end(), 0u);
    for (int i = 0; i < num_devices; ++i) {
      CUcontext ctx;
      CUDADRV_CHECK(cuCtxCreate(&ctx, 0, i));
      cu_contexts_.push_back(ctx);
    }
    // cuCtxEnablePeerAccess grants the *current* context access to the peer
    // context, so bind context i before enabling access to context j.
    int access = 0;
    for (int i = 0; i < num_devices; ++i) {
      CUDADRV_CHECK(cuCtxSetCurrent(cu_contexts_[i]));
      for (int j = 0; j < num_devices; ++j) {
        if (i == j) {
          continue;
        }
        CUDADRV_CHECK(cuDeviceCanAccessPeer(&access, i, j));
        if (access) {
          std::printf("Peer access enabled between %d and %d\n", i, j);
          CUDADRV_CHECK(cuCtxEnablePeerAccess(cu_contexts_[j], 0));
        }
      }
    }
    // Leave no context bound; callers bind the one they need explicitly.
    CUDADRV_CHECK(cuCtxSetCurrent(nullptr));
  }
  ~CudaMmapProxy() {
    for (auto ctx : cu_contexts_) {
      CUDADRV_CHECK(cuCtxDestroy(ctx));
    }
  }

  CUcontext context(uint32_t i) {
    return cu_contexts_[i];
  }
};
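// CudaMmapAllocation builds a device allocation out of the driver's virtual
// memory management (VMM) primitives instead of a single cuMemAlloc: it
// reserves a virtual address range (cuMemAddressReserve), creates physical
// memory (cuMemCreate), and maps the two together (cuMemMap). Keeping the VA
// range separate from the physical backing is what lets migrate() below move
// the data to another device without invalidating existing pointers.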
struct CudaMmapAllocation {
  size_t size_{0}, align_size_{0};
  CUmemGenericAllocationHandle handle_{};
  CUmemAccessDesc access_desc_{};
  CUcontext ctx_{nullptr};
  CUdeviceptr device_ptr_{};
  CudaMmapAllocation(CUcontext ctx, size_t size) {
    ScopedContext _{ctx};
    int device_id;
    CUDADRV_CHECK(cuCtxGetDevice(&device_id));
    ctx_ = ctx;
    CUmemAllocationProp prop{};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = device_id;
    access_desc_.location = prop.location;
    access_desc_.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    // cuMemCreate only accepts sizes that are multiples of the allocation
    // granularity, so round the requested size up.
    CUDADRV_CHECK(cuMemGetAllocationGranularity(&align_size_, &prop,
                                                CU_MEM_ALLOC_GRANULARITY_MINIMUM));
    size_ = ((size + align_size_ - 1) / align_size_) * align_size_;
    // Reserve a VA range, create physical memory, map the two together, and
    // grant this device read/write access.
    CUDADRV_CHECK(cuMemAddressReserve(&device_ptr_, size_, 0ULL, 0ULL, 0ULL));
    CUDADRV_CHECK(cuMemCreate(&handle_, size_, &prop, 0));
    CUDADRV_CHECK(cuMemMap(device_ptr_, size_, 0ULL, handle_, 0ULL));
    CUDADRV_CHECK(cuMemSetAccess(device_ptr_, size_, &access_desc_, 1ULL));
  }
  ~CudaMmapAllocation() {
    if (size_ > 0) {
      ScopedContext _{ctx_};
      // Tear down in reverse order of construction: unmap the VA range,
      // free it, then release the physical allocation.
      CUDADRV_CHECK(cuMemUnmap(device_ptr_, size_));
      CUDADRV_CHECK(cuMemAddressFree(device_ptr_, size_));
      CUDADRV_CHECK(cuMemRelease(handle_));
    }
  }
  template <typename T = void>
  T *ptr() {
    return reinterpret_cast<T *>(device_ptr_);
  }
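  // migrate() moves the backing physical memory to the device owning `ctx`
  // while keeping device_ptr_ unchanged:
  //   1. create new physical memory on the target device,
  //   2. map it at a temporary VA range and enable access,
  //   3. peer-copy the contents from the old memory,
  //   4. unmap/release the old memory and the temporary VA range,
  //   5. map the new memory at the original VA range.
  // It assumes no kernel is still using the allocation when called; the
  // caller is expected to synchronize first (as main() does below).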
  void migrate(CUcontext ctx) {
    CUmemAllocationProp new_prop{};
    new_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    new_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    CUmemAccessDesc new_access_desc{};
    new_access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CUdeviceptr new_ptr;
    CUmemGenericAllocationHandle new_handle;
    int new_device;
    {
      ScopedContext _{ctx};
      CUDADRV_CHECK(cuCtxGetDevice(&new_device));
      new_prop.location.id = new_device;
      new_access_desc.location = new_prop.location;
      // Allocate physical memory on the target device
      CUDADRV_CHECK(cuMemCreate(&new_handle, size_, &new_prop, 0));
      // Map it to a temporary VA range so it can serve as a copy target
      CUDADRV_CHECK(cuMemAddressReserve(&new_ptr, size_, 0ULL, 0ULL, 0ULL));
      CUDADRV_CHECK(cuMemMap(new_ptr, size_, 0ULL, new_handle, 0ULL));
      CUDADRV_CHECK(cuMemSetAccess(new_ptr, size_, &new_access_desc, 1ULL));
      // Copy from source device to target device
      CUDADRV_CHECK(cuMemcpyPeer(new_ptr, ctx, device_ptr_, ctx_, size_));
    }
    // Remap: detach the old physical memory from the original VA range and
    // release it.
    CUDADRV_CHECK(cuMemUnmap(device_ptr_, size_));
    CUDADRV_CHECK(cuMemRelease(handle_));
    // Leave the destination context current for subsequent launches.
    CUDADRV_CHECK(cuCtxSetCurrent(ctx));
    // Drop the temporary mapping (cuMemUnmap takes the range size) and free
    // its VA range, then map the new physical memory at the original VA.
    CUDADRV_CHECK(cuMemUnmap(new_ptr, size_));
    CUDADRV_CHECK(cuMemAddressFree(new_ptr, size_));
    CUDADRV_CHECK(cuMemMap(device_ptr_, size_, 0ULL, new_handle, 0ULL));
    CUDADRV_CHECK(cuMemSetAccess(device_ptr_, size_, &new_access_desc, 1ULL));
    handle_ = new_handle;
    ctx_ = ctx;
    access_desc_ = new_access_desc;
  }
};
// Ask the runtime which device currently backs `ptr`.
uint32_t pointer_device(const void *ptr) {
  cudaPointerAttributes ptr_attr{};
  CUDART_CHECK(cudaPointerGetAttributes(&ptr_attr, ptr));
  return ptr_attr.device;
}
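// Demo: allocate on device 0, run the kernel, migrate the backing memory to
// device 1, then run the kernel again through the same pointer and verify
// that both increments landed.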
int main() {
  auto proxy = CudaMmapProxy();
  int num_devices = 0;
  CUDADRV_CHECK(cuDeviceGetCount(&num_devices));
  if (num_devices < 2) {
    std::printf("This demo needs at least two GPUs.\n");
    return 0;
  }
  size_t size = 2048;
  uint32_t n = static_cast<uint32_t>(size / sizeof(int));
  auto alloc = CudaMmapAllocation(proxy.context(0), size);
  auto *device_ptr = alloc.ptr<int>();
  CUDADRV_CHECK(cuCtxSetCurrent(proxy.context(0)));
  CUDART_CHECK(cudaMemset(device_ptr, 0, size));
  vec_add_one<<<1, 256>>>(device_ptr, n);
  std::vector<int> host_data(n, 0);
  CUDART_CHECK(cudaMemcpy(host_data.data(), device_ptr, size, cudaMemcpyDeviceToHost));
  CUDART_CHECK(cudaDeviceSynchronize());
  std::printf("Before migrate: device %u\n", pointer_device(device_ptr));
  uint32_t new_device = 1;
  alloc.migrate(proxy.context(new_device));
  std::printf("After migrate: device %u\n", pointer_device(device_ptr));
  int cu_device;
  CUDADRV_CHECK(cuCtxGetDevice(&cu_device));
  std::printf("Current context device: %d\n", cu_device);
  // No need to refresh device_ptr: migrate() keeps the VA range intact.
  vec_add_one<<<1, 256>>>(device_ptr, n);
  CUDART_CHECK(cudaMemcpy(host_data.data(), device_ptr, size, cudaMemcpyDeviceToHost));
  CUDART_CHECK(cudaDeviceSynchronize());
  // Each element was incremented once before and once after the migration.
  for (size_t i = 0; i < host_data.size(); ++i) {
    if (host_data[i] != 2) {
      std::printf("Mismatch at %zu: %d\n", i, host_data[i]);
      return 1;
    }
  }
  std::printf("Migration verified: all values survived.\n");
  return 0;
}