Stream capture of cublas gemm

I am trying to capture graph associated with a cublas sgemm call (cuda 11.5 on Ampere). I modified the matrixMultiply function in the matrixMulCUBLAS CUDA sample as below. I get an error cudaErrorInvalidDeviceFunction for the function captured by stream capture, see line tagged as FAILURE. What am I doing wrong?

////////////////////////////////////////////////////////////////////////////////
//! Run a simple test matrix multiply using CUBLAS
////////////////////////////////////////////////////////////////////////////////
int matrixMultiply(int argc, char **argv, int devID, sMatrixSize &matrix_size) {
  cudaDeviceProp deviceProp;

  checkCudaErrors(cudaGetDeviceProperties(&deviceProp, devID));

  int block_size = 32;

  // set seed for rand()
  srand(2006);

  // allocate host memory for matrices A and B
  unsigned int size_A = matrix_size.uiWA * matrix_size.uiHA;
  unsigned int mem_size_A = sizeof(float) * size_A;
  float *h_A = (float *)malloc(mem_size_A);
  unsigned int size_B = matrix_size.uiWB * matrix_size.uiHB;
  unsigned int mem_size_B = sizeof(float) * size_B;
  float *h_B = (float *)malloc(mem_size_B);

  // set seed for rand()
  srand(2006);

  // initialize host memory
  randomInit(h_A, size_A);
  randomInit(h_B, size_B);

  // allocate device memory
  float *d_A, *d_B, *d_C;
  unsigned int size_C = matrix_size.uiWC * matrix_size.uiHC;
  unsigned int mem_size_C = sizeof(float) * size_C;

  // allocate host memory for the result
  float *h_C = (float *)malloc(mem_size_C);
  float *h_CUBLAS = (float *)malloc(mem_size_C);

  checkCudaErrors(cudaMalloc((void **)&d_A, mem_size_A));
  checkCudaErrors(cudaMalloc((void **)&d_B, mem_size_B));
  checkCudaErrors(cudaMemcpy(d_A, h_A, mem_size_A, cudaMemcpyHostToDevice));
  checkCudaErrors(cudaMemcpy(d_B, h_B, mem_size_B, cudaMemcpyHostToDevice));
  checkCudaErrors(cudaMalloc((void **)&d_C, mem_size_C));

  // setup execution parameters
  dim3 threads(block_size, block_size);
  dim3 grid(matrix_size.uiWC / threads.x, matrix_size.uiHC / threads.y);

  // create and start timer
  printf("Computing result using CUBLAS...");

  // execute the kernel
  int nIter = 30;

  // CUBLAS version 2.0
  {
    const float alpha = 1.0f;
    const float beta = 0.0f;
    cublasHandle_t handle;
    cudaStream_t stream;
    cudaEvent_t start, stop;
    cudaGraph_t graph;
    cudaGraphExec_t graphExec = NULL;
    cudaKernelNodeParams NodeParams;
    cudaMemsetParams MemsetParams;
    std::vector<cudaGraphNode_t> vnodes;
    
    checkCudaErrors(cublasCreate(&handle));
    checkCudaErrors(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE));
    checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
    cublasSetStream(handle,stream);
    checkCudaErrors(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));

    // Perform warmup operation with cublas
    checkCudaErrors(cublasSgemm(
        handle, CUBLAS_OP_N, CUBLAS_OP_N, matrix_size.uiWB, matrix_size.uiHA,
        matrix_size.uiWA, &alpha, d_B, matrix_size.uiWB, d_A, matrix_size.uiWA,
        &beta, d_C, matrix_size.uiWB));
    checkCudaErrors(cudaStreamEndCapture(stream, &graph));

    size_t numNodes;
    checkCudaErrors(cudaGraphGetNodes(graph, NULL, &numNodes));
    vnodes.resize(numNodes);
    checkCudaErrors(cudaGraphGetNodes(graph, vnodes.data(), &numNodes));
    std::vector<cudaGraphNodeType> nodeType(numNodes);

    for(size_t i=0; i<numNodes; i++) {
      checkCudaErrors(cudaGraphNodeGetType (vnodes[i], &nodeType[i]));
      switch(nodeType[i]) {
         case cudaGraphNodeTypeMemset:
           checkCudaErrors(cudaGraphMemsetNodeGetParams (vnodes[i], &MemsetParams));
           break;
         case cudaGraphNodeTypeKernel:
	   checkCudaErrors(cudaGraphKernelNodeGetParams(vnodes[i], &NodeParams));
          // FAILURE  code=98 (cudaErrorInvalidDeviceFunction)
	   break;
      }
   }


    cudaGraphExecUpdateResult updateResult_out;
    checkCudaErrors(cudaGraphExecUpdate(graphExec, graph, NULL, &updateResult_out));
    if (updateResult_out != cudaGraphExecUpdateSuccess) {
        if (graphExec != NULL) {
          checkCudaErrors(cudaGraphExecDestroy(graphExec));
        }
        printf("graph update failed with error - %d\n", updateResult_out);
        checkCudaErrors(cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0));
    }
    checkCudaErrors(cudaGraphLaunch(graphExec, stream));
    checkCudaErrors(cudaStreamSynchronize(stream));

    // Allocate CUDA events that we'll use for timing
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));

    // Record the start event
    checkCudaErrors(cudaEventRecord(start, NULL));

    for (int j = 0; j < nIter; j++) {
      // note cublas is column primary!
      // need to transpose the order
      checkCudaErrors(cudaGraphLaunch(graphExec, stream));
      checkCudaErrors(cudaStreamSynchronize(stream));
     }

    printf("done.\n");

    // Record the stop event
    checkCudaErrors(cudaEventRecord(stop, NULL));

    // Wait for the stop event to complete
    checkCudaErrors(cudaEventSynchronize(stop));

    float msecTotal = 0.0f;
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    // Compute and print the performance
    float msecPerMatrixMul = msecTotal / nIter;
    double flopsPerMatrixMul = 2.0 * (double)matrix_size.uiHC *
                               (double)matrix_size.uiWC *
                               (double)matrix_size.uiHB;
    double gigaFlops =
        (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
    printf("Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n",
           gigaFlops, msecPerMatrixMul, flopsPerMatrixMul);

    // copy result from device to host
    checkCudaErrors(
        cudaMemcpy(h_CUBLAS, d_C, mem_size_C, cudaMemcpyDeviceToHost));

    // Destroy the handle
    checkCudaErrors(cublasDestroy(handle));
    checkCudaErrors(cudaStreamDestroy(stream));

  }

  // compute reference solution
  printf("Computing result using host CPU...");
  float *reference = (float *)malloc(mem_size_C);
  matrixMulCPU(reference, h_A, h_B, matrix_size.uiHA, matrix_size.uiWA,
               matrix_size.uiWB);
  printf("done.\n");

  // check result (CUBLAS)
  bool resCUBLAS = sdkCompareL2fe(reference, h_CUBLAS, size_C, 1.0e-6f);

  if (resCUBLAS != true) {
    printDiff(reference, h_CUBLAS, matrix_size.uiWC, matrix_size.uiHC, 100,
              1.0e-5f);
  }

  printf("Comparing CUBLAS Matrix Multiply with CPU results: %s\n",
         (true == resCUBLAS) ? "PASS" : "FAIL");

  printf(
      "\nNOTE: The CUDA Samples are not meant for performance measurements. "
      "Results may vary when GPU Boost is enabled.\n");

  // clean up memory
  free(h_A);
  free(h_B);
  free(h_C);
  free(reference);
  checkCudaErrors(cudaFree(d_A));
  checkCudaErrors(cudaFree(d_B));
  checkCudaErrors(cudaFree(d_C));

  if (resCUBLAS == true) {
    return EXIT_SUCCESS;  // return value = 1
  } else {
    return EXIT_FAILURE;  // return value = 0
  }
}

please format your code correctly

a possible set of steps:

  • edit your post, by clicking on the pencil icon below it
  • select the code
  • press the </> button at the top of the edit window
  • save your changes
  1. Your usage of CUBLAS_POINTER_MODE is incorrect.
  2. I’m not sure what you are doing with all of the node lookup. I don’t typically expect to see that when using stream capture.
  3. Your instantiation call sequence after the capture is not correct.

This throws no errors for me:


$ cat t2029.cu
#include <cublas_v2.h>
#include <cstdio>
#include <helper_cuda.h>
#include <vector>

int main(int argc, char *argv[]){

////////////////////////////////////////////////////////////////////////////////
//! Run a simple test matrix multiply using CUBLAS
////////////////////////////////////////////////////////////////////////////////
  int devID = 0;

  cudaDeviceProp deviceProp;

  checkCudaErrors(cudaGetDeviceProperties(&deviceProp, devID));

  int block_size = 32;
  int WA, HA, WB, HB, WC, HC;
  WA = HA = WB = HB = WC = HC = 512;

  // set seed for rand()
  srand(2006);

  // allocate host memory for matrices A and B
  unsigned int size_A = WA * HA;
  unsigned int mem_size_A = sizeof(float) * size_A;
  float *h_A = (float *)malloc(mem_size_A);
  unsigned int size_B = WB * HB;
  unsigned int mem_size_B = sizeof(float) * size_B;
  float *h_B = (float *)malloc(mem_size_B);

  // set seed for rand()
  srand(2006);

  // initialize host memory

  // allocate device memory
  float *d_A, *d_B, *d_C;
  unsigned int size_C = WC * HC;
  unsigned int mem_size_C = sizeof(float) * size_C;

  // allocate host memory for the result
  float *h_C = (float *)malloc(mem_size_C);
  float *h_CUBLAS = (float *)malloc(mem_size_C);

  checkCudaErrors(cudaMalloc((void **)&d_A, mem_size_A));
  checkCudaErrors(cudaMalloc((void **)&d_B, mem_size_B));
  checkCudaErrors(cudaMemcpy(d_A, h_A, mem_size_A, cudaMemcpyHostToDevice));
  checkCudaErrors(cudaMemcpy(d_B, h_B, mem_size_B, cudaMemcpyHostToDevice));
  checkCudaErrors(cudaMalloc((void **)&d_C, mem_size_C));

  // setup execution parameters
  dim3 threads(block_size, block_size);
  dim3 grid(WC / threads.x, HC / threads.y);

  // create and start timer
  printf("Computing result using CUBLAS...");

  // execute the kernel
  int nIter = 30;

  // CUBLAS version 2.0
  {
    const float alpha = 1.0f;
    const float beta = 0.0f;
    cublasHandle_t handle;
    cudaStream_t stream;
    cudaEvent_t start, stop;
    cudaGraph_t graph;
    cudaGraphExec_t graphExec = NULL;
    cudaKernelNodeParams NodeParams;
    cudaMemsetParams MemsetParams;
    std::vector<cudaGraphNode_t> vnodes;

    checkCudaErrors(cublasCreate(&handle));
    //checkCudaErrors(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE));
    checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
    cublasSetStream(handle,stream);
    checkCudaErrors(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));

    // Perform warmup operation with cublas
    checkCudaErrors(cublasSgemm(
        handle, CUBLAS_OP_N, CUBLAS_OP_N, WB, HA,
        WA, &alpha, d_B, WB, d_A, WA,
        &beta, d_C, WB));
//    checkCudaErrors(cudaDeviceSynchronize());
//    return 0;
    checkCudaErrors(cudaStreamEndCapture(stream, &graph));
#if 0
    size_t numNodes;
    checkCudaErrors(cudaGraphGetNodes(graph, NULL, &numNodes));
    vnodes.resize(numNodes);
    checkCudaErrors(cudaGraphGetNodes(graph, vnodes.data(), &numNodes));
    std::vector<cudaGraphNodeType> nodeType(numNodes);

    for(size_t i=0; i<numNodes; i++) {
      checkCudaErrors(cudaGraphNodeGetType (vnodes[i], &nodeType[i]));
      switch(nodeType[i]) {
         case cudaGraphNodeTypeMemset:
           checkCudaErrors(cudaGraphMemsetNodeGetParams (vnodes[i], &MemsetParams));
           break;
         case cudaGraphNodeTypeKernel:
           checkCudaErrors(cudaGraphKernelNodeGetParams(vnodes[i], &NodeParams));
          // FAILURE  code=98 (cudaErrorInvalidDeviceFunction)
           break;
      }
   }


    cudaGraphExecUpdateResult updateResult_out;
    checkCudaErrors(cudaGraphExecUpdate(graphExec, graph, NULL, &updateResult_out));
    if (updateResult_out != cudaGraphExecUpdateSuccess) {
        if (graphExec != NULL) {
          checkCudaErrors(cudaGraphExecDestroy(graphExec));
        }
        printf("graph update failed with error - %d\n", updateResult_out);
#endif
        checkCudaErrors(cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0));
#if 0
    }
#endif
    checkCudaErrors(cudaGraphLaunch(graphExec, stream));
    checkCudaErrors(cudaStreamSynchronize(stream));

    // Allocate CUDA events that we'll use for timing
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));

    // Record the start event
    checkCudaErrors(cudaEventRecord(start, NULL));

    for (int j = 0; j < nIter; j++) {
      // note cublas is column primary!
      // need to transpose the order
      checkCudaErrors(cudaGraphLaunch(graphExec, stream));
      checkCudaErrors(cudaStreamSynchronize(stream));
     }

    printf("done.\n");

    // Record the stop event
    checkCudaErrors(cudaEventRecord(stop, NULL));

    // Wait for the stop event to complete
    checkCudaErrors(cudaEventSynchronize(stop));

    float msecTotal = 0.0f;
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    // Compute and print the performance
    float msecPerMatrixMul = msecTotal / nIter;
    double flopsPerMatrixMul = 2.0 * (double)HC *
                               (double)WC *
                               (double)HB;
    double gigaFlops =
        (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
    printf("Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n",
           gigaFlops, msecPerMatrixMul, flopsPerMatrixMul);

    // copy result from device to host
    checkCudaErrors(
        cudaMemcpy(h_CUBLAS, d_C, mem_size_C, cudaMemcpyDeviceToHost));

    // Destroy the handle
    checkCudaErrors(cublasDestroy(handle));
    checkCudaErrors(cudaStreamDestroy(stream));

  }

  // compute reference solution
  printf("Computing result using host CPU...");
  float *reference = (float *)malloc(mem_size_C);
  printf("done.\n");

  // check result (CUBLAS)


  printf(
      "\nNOTE: The CUDA Samples are not meant for performance measurements. "
      "Results may vary when GPU Boost is enabled.\n");

  // clean up memory
  free(h_A);
  free(h_B);
  free(h_C);
  free(reference);
  checkCudaErrors(cudaFree(d_A));
  checkCudaErrors(cudaFree(d_B));
  checkCudaErrors(cudaFree(d_C));

    return EXIT_SUCCESS;  // return value = 1
}
$ nvcc -o t2029 t2029.cu -lcublas -I/usr/local/cuda/samples/common/inc
t2029.cu(72): warning: variable "MemsetParams" was declared but never referenced

$ ./t2029
Computing result using CUBLAS...done.
Performance= 6471.19 GFlop/s, Time= 0.041 msec, Size= 268435456 Ops
Computing result using host CPU...done.

NOTE: The CUDA Samples are not meant for performance measurements. Results may vary when GPU Boost is enabled.
$

My intended code has some graph restructuring logic. The example I shared is just a reproducer, not my intended full code. I need to capture the graph node and restructure dependencies.
The example I shared show where failure occurs.
In Robert’s code, I do not have a chance to restructure the graph dependencies.
Could some of cublas kernel execute eagerly? and as such graph capture is not effectual?

Is it technically illegal to access a kernelNodeParam of a stream-captured CUDAGraph?

I face the following situation:

CASE 1: CudaGraph only has cuBLAS kernel (stream captured):

In the above situation, trying to do cudaGraphKernelNodeGetParams on the kernel node gives

Cuda error in function 'cudaGraphKernelNodeGetParams(kernelNode[i], &kernelNodeParams[i])' file 'cublas-matmul-cudagraph-stream-capture-post.cu' in line 140 : invalid device function.

that @cudevelop mentions in his code.

CASE 2: Same old code, but added a synthetic vecAdd kernel just before the cuBLAS matrix multiply call (stream captured):

In this case, cudaGraphKernelNodeGetParams is successful on the first vecAdd kernel node. But, when I try to do the same for the volta_sgemm_.. kernel node added by cuBLAS, I get the same error:

Cuda error in function 'cudaGraphKernelNodeGetParams(kernelNode[i], &kernelNodeParams[i])' file 'cublas-matmul-cudagraph-stream-capture-post.cu' in line 140 : invalid device function.

While the exact same line of code passes for the vecAdd kernel node. Does it have something to do with the fact that the implementation of cuBLAS is closed-source?


Example Code:

#include <iostream>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <vector>

#define CUDA_SAFECALL(call)                                                 \
    {                                                                       \
        call;                                                               \
        cudaError err = cudaGetLastError();                                 \
        if (cudaSuccess != err) {                                           \
            fprintf(                                                        \
                stderr,                                                     \
                "Cuda error in function '%s' file '%s' in line %i : %s.\n", \
                #call, __FILE__, __LINE__, cudaGetErrorString(err));        \
            fflush(stderr);                                                 \
            exit(EXIT_FAILURE);                                             \
        }                                                                   \
    }

#define CUBLAS_CALL(func)                                                      \
{                                                                              \
    cublasStatus_t s = (func);                                                 \
    if(s != CUBLAS_STATUS_SUCCESS) {                                           \
        std::cerr << "cuBLAS Error: " << s << " at line " << __LINE__ << std::endl; \
        exit(EXIT_FAILURE);                                                    \
    }                                                                          \
}

const int N = 1024;  // Dimension of the square matrix (N x N)

// Helper function to initialize matrices
void initializeMatrix(float* matrix, int n, float value) {
    for (int i = 0; i < n * n; ++i) {
        matrix[i] = value;
    }
}

void printMatrix(float* matrix, int n) {
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            std::cout << matrix[i * n + j] << " ";
        }
        std::cout << std::endl;
    }
}

// CUDA kernel that adds two vectors, each thread handles one element of c
__global__ void vecAdd(float *a, float *b, float *c, int n) {
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id < n) c[id] = a[id] + b[id];
}

int main() {
    // Initialize cuBLAS
    cublasHandle_t handle;
    CUBLAS_CALL(cublasCreate(&handle));

    // Create a CUDA stream
    cudaStream_t stream;
    CUDA_SAFECALL(cudaStreamCreate(&stream));

    // Set cuBLAS to use the created stream
    CUBLAS_CALL(cublasSetStream(handle, stream));

    // Allocate memory for input and output matrices on host
    // A * B = C
    float* h_A = (float*) malloc(N * N * sizeof(float));
    float* h_B = (float*) malloc(N * N * sizeof(float));
    float* h_C = (float*) malloc(N * N * sizeof(float));
    initializeMatrix(h_A, N, 1.0f);  // Matrix A with all elements 1.0
    initializeMatrix(h_B, N, 2.0f);  // Matrix B with all elements 2.0

    // Allocate memory on device
    float *d_A, *d_B, *d_C;
    CUDA_SAFECALL(cudaMalloc((void**)&d_A, N * N * sizeof(float)));
    CUDA_SAFECALL(cudaMalloc((void**)&d_B, N * N * sizeof(float)));
    CUDA_SAFECALL(cudaMalloc((void**)&d_C, N * N * sizeof(float)));

    // Copy data from host to device (Async using stream)
    CUDA_SAFECALL(cudaMemcpyAsync(d_A, h_A, N * N * sizeof(float), cudaMemcpyHostToDevice, stream));
    CUDA_SAFECALL(cudaMemcpyAsync(d_B, h_B, N * N * sizeof(float), cudaMemcpyHostToDevice, stream));

    // Create CUDA Graph
    cudaGraph_t graph;
    cudaGraphExec_t instance;
    
    // Begin CUDA Graph capture
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

# if 0
    //testing a vecAdd kernel
    int blockSize, gridSize;
    blockSize = 1024;
    gridSize = (int)ceil((float)N / blockSize);
    vecAdd<<<gridSize, blockSize, 0, stream>>>(d_A, d_B, d_C, N);
#endif

    // Set up GEMM (Matrix Multiplication)
    float alpha = 1.0f, beta = 0.0f;

    // cublasSgemm performs matrix multiplication C = alpha * A * B + beta * C
    // where A, B, C are NxN matrices in column-major order.
    CUBLAS_CALL(cublasSgemm(handle,
                            CUBLAS_OP_N, CUBLAS_OP_N,
                            N, N, N,
                            &alpha,
                            d_A, N,
                            d_B, N,
                            &beta,
                            d_C, N));
    // End CUDA Graph capture
    cudaStreamEndCapture(stream, &graph);      

    // Instantiate the graph
    cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
    
    // Get the number of nodes to allocate an array of nodes
    cudaGraphNode_t* nodes=NULL;
    size_t numNodes=0;
    CUDA_SAFECALL((cudaGraphGetNodes(graph, nodes, &numNodes)));
    nodes = (cudaGraphNode_t*)malloc(numNodes*sizeof(cudaGraphNode_t));
    CUDA_SAFECALL((cudaGraphGetNodes(graph, nodes, &numNodes)));
    
    std::vector<cudaGraphNode_t> kernelNode;
    
    printf("Number of nodes: %lu\n", numNodes);
    for(int i = 0; i < numNodes; i++){
        cudaGraphNodeType type;
        cudaGraphNodeGetType(nodes[i], &type);
        if (type == cudaGraphNodeTypeKernel){
            printf("Node %d is kernel Node\n", i+1);
            kernelNode.push_back(nodes[i]);
        }
    }

    cudaKernelNodeParams kernelNodeParams[kernelNode.size()] = {0};
    
    // Get the kernel node parameters
    for (int i = 0; i < kernelNode.size(); i++){
        CUDA_SAFECALL(cudaGraphKernelNodeGetParams(kernelNode[i], &kernelNodeParams[i]));
    }
    
    // Launch the instantiated graph
    cudaGraphLaunch(instance, stream);

    // Copy result from device to host (Async using stream)
    CUDA_SAFECALL(cudaMemcpyAsync(h_C, d_C, N * N * sizeof(float), cudaMemcpyDeviceToHost, stream));

    // Wait for the stream to finish all operations
    CUDA_SAFECALL(cudaStreamSynchronize(stream));

    // // Print the result matrix C
    // std::cout << "Matrix C (Result of A * B):" << std::endl;
    // printMatrix(h_C, N);

    // Cleanup
    CUBLAS_CALL(cublasDestroy(handle));
    CUDA_SAFECALL(cudaFree(d_A));
    CUDA_SAFECALL(cudaFree(d_B));
    CUDA_SAFECALL(cudaFree(d_C));
    free(h_A);
    free(h_B);
    free(h_C);

    // Destroy the stream
    CUDA_SAFECALL(cudaStreamDestroy(stream));

    return 0;
}