Hi all,
I worked a bit with the device graph launch feature when it came out in CUDA 12.0 and built a nice proof of concept with it.
I wanted to try it with some features of CUDA 12.3, but my test no longer passes.
I investigated a bit and have some insight into the issue.
Here is a description:
The issue occurs on an Ubuntu 22.04 server equipped with an NVIDIA A100 40GB, and I also reproduced it on a DGX server with V100 GPUs.
Below is a really simple sample code that just creates a cuBLAS node within a graph (via stream capture), instantiates the graph for device launch and then uploads it to the device.
```cpp
/*
instructions:
#> nvcc gpuGraphInstantiation_bug.cu -lcublas -o gpuGraphInstantiation_bug
#> ./gpuGraphInstantiation_bug 4096 12500 -> test fails
#> ./gpuGraphInstantiation_bug 1024 4096  -> test passes! (on A100, this result may differ depending on the underlying cuBLAS kernel choice)
*/
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <string>
#include <cuda_runtime.h>
#include <cublas_v2.h>
////////////////////////////////////////////////////////////////////////
// ERROR check tools
// Aborts with a readable message if a CUDA runtime call failed
void check(cudaError_t result, char const *const func, const char *const file,
           int const line) {
  if (result != cudaSuccess) {
    fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line,
            static_cast<int>(result), cudaGetErrorName(result), func);
    exit(EXIT_FAILURE);
  }
}
// cuBLAS calls return cublasStatus_t rather than cudaError_t,
// so they get their own checker
void checkCublas(cublasStatus_t status, char const *const func,
                 const char *const file, int const line) {
  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "cuBLAS error at %s:%d status=%d \"%s\" \n", file, line,
            static_cast<int>(status), func);
    exit(EXIT_FAILURE);
  }
}
// This will output the proper CUDA error strings in the event
// that a CUDA host call returns an error
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
#define checkCublasErrors(val) checkCublas((val), #val, __FILE__, __LINE__)
#define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__)
inline void __getLastCudaError(const char *errorMessage, const char *file,
                               const int line) {
  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr,
            "%s(%i) : getLastCudaError() CUDA error :"
            " %s : (%d) %s.\n",
            file, line, errorMessage, static_cast<int>(err),
            cudaGetErrorString(err));
    exit(EXIT_FAILURE);
  }
}
//////////////////////////////////////////////////////////////////////////////
int main(int argc, char *argv[]) {
  if (argc < 3) {
    std::cout << "usage: ./binary matrix_height matrix_width" << std::endl;
    return 1;
  }
  int m = std::stoi(argv[1]);
  int n = std::stoi(argv[2]);

  cudaStream_t stream;
  checkCudaErrors(cudaStreamCreate(&stream));

  cublasHandle_t d_handle;
  checkCublasErrors(cublasCreate(&d_handle));
  checkCublasErrors(cublasSetStream(d_handle, stream));

  float *d_matrixData;
  float *d_vectorData;
  float *d_res;
  checkCudaErrors(cudaMalloc(&d_matrixData, sizeof(float) * m * n));
  checkCudaErrors(cudaMalloc(&d_vectorData, sizeof(float) * n));
  checkCudaErrors(cudaMalloc(&d_res, sizeof(float) * m));

  cudaGraphExec_t computeGraph;
  cudaGraph_t cuGraph;
  float alpha = 1.f;
  float beta = 0.f;

  // Capture a single SGEMV (d_res = alpha * A * x + beta * d_res) into a graph
  checkCudaErrors(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
  checkCublasErrors(cublasSgemv(d_handle, CUBLAS_OP_N, m, n, &alpha,
                                d_matrixData, m, d_vectorData, 1, &beta,
                                d_res, 1));
  checkCudaErrors(cudaStreamEndCapture(stream, &cuGraph));

  // Instantiating for device launch is the call that fails for large matrices
  checkCudaErrors(cudaGraphInstantiate(&computeGraph, cuGraph,
                                       cudaGraphInstantiateFlagDeviceLaunch));
  checkCudaErrors(cudaGraphDebugDotPrint(cuGraph, "debugGraph.dot", 0));
  checkCudaErrors(cudaGraphDestroy(cuGraph));
  checkCudaErrors(cudaGraphUpload(computeGraph, stream));
  checkCudaErrors(cudaStreamSynchronize(stream));
  std::cout << "Graph instantiation was a success!" << std::endl;

  // Release everything that was created above
  checkCudaErrors(cudaGraphExecDestroy(computeGraph));
  checkCudaErrors(cudaFree(d_res));
  checkCudaErrors(cudaFree(d_vectorData));
  checkCudaErrors(cudaFree(d_matrixData));
  checkCublasErrors(cublasDestroy(d_handle));
  checkCudaErrors(cudaStreamDestroy(stream));
  return 0;
}
```
Here is some information about this bug:
- It occurs with driver version 545.23.08 and CUDA 12.3. However, as I mentioned earlier, it does not occur if my library is linked against CUDA 12.0 at runtime.
- The issue appears starting from CUDA 12.1.
- Removing the cudaGraphInstantiateFlagDeviceLaunch flag from the cudaGraphInstantiate call fixes the issue.
- Whether the bug appears depends on the matrix size.
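For scale, here is a small host-only sketch of the device memory the repro itself allocates with cudaMalloc for the two sizes used above. The function name reproDeviceBytes is hypothetical, and it deliberately ignores any workspace cuBLAS may allocate internally, so it only quantifies the user buffers:

```cpp
#include <cstddef>

// Hypothetical helper: bytes allocated by the repro with cudaMalloc, i.e. the
// m x n matrix, the n-element input vector and the m-element result vector,
// all single-precision floats. cuBLAS internal workspace is NOT included.
std::size_t reproDeviceBytes(std::size_t m, std::size_t n) {
  return sizeof(float) * (m * n + n + m);
}
```

With this formula, the failing 4096 x 12500 case allocates about 195 MiB (204866384 bytes) and the passing 1024 x 4096 case about 16 MiB (16797696 bytes).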
Here is the graph generated when instantiation succeeds (obtained by removing the problematic flag).
And here is the one for a smaller matrix, which passes the test regardless of the CUDA version.
As far as I understand, this seems strongly correlated with memory allocation nodes in the captured graph, which do not seem to be supported when the graph is instantiated for device launch. This is quite troublesome for me, so I would welcome any insight into how to fix the issue or work around it on the user side.
Best regards.