How to use the device-side CUDA Graph APIs? How to get hold of `cudaGraphDeviceNode_t`?

How to use the CUDA Graph __device__ APIs? The CUDA Runtime API reference for Graph Management lists a few device-side CUDA Graph APIs:

  1. `__device__ cudaError_t cudaGraphKernelNodeSetEnabled(cudaGraphDeviceNode_t node, bool enable)`: Enables or disables the given kernel node.
  2. `__device__ cudaError_t cudaGraphKernelNodeSetGridDim(cudaGraphDeviceNode_t node, dim3 gridDim)`: Updates the grid dimensions of the given kernel node.
  3. `__device__ cudaError_t cudaGraphKernelNodeSetParam(cudaGraphDeviceNode_t node, size_t offset, const void* value, size_t size)`: Updates the kernel parameters of the given kernel node.
  4. `__device__ cudaError_t cudaGraphKernelNodeUpdatesApply(const cudaGraphKernelNodeUpdate* updates, size_t updateCount)`: Batch-applies multiple kernel node updates.

Of these, I am most curious about cudaGraphKernelNodeUpdatesApply, which takes a pointer to an array of cudaGraphKernelNodeUpdate structures; each structure has a field of type cudaGraphDeviceNode_t, defined as:

typedef CUgraphDeviceUpdatableNode_st* cudaGraphDeviceNode_t;


    CUDA device node handle for device-side node update

Is there any documentation or example usage of these device-side CUDA Graph management APIs, i.e. how to get hold of the cudaGraphDeviceNode_t and then pass it to the runtime APIs listed above?

Neither the CUDA C++ Programming Guide nor the CUDA Samples show any example usage of these APIs. Any suggestion would be of great help. Here is a suggestion that says to set a certain attribute on the node and get hold of the device handle for it. The CUDA Toolkit 12.4 Enhancements blog says the same:

At graph creation, a node can “opt in” to being device-updatable through a new kernel node attribute. This attribute, when enabled, will return the device node handle as part of the setAttribute call.

To opt in to this feature, set the attribute CU_KERNEL_NODE_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE on the node after node creation.

But it’s not clear how to use it.

Looking into the documentation, I found two approaches.

  1. Mark the kernel at launch with a launch attribute (runtime) / launch attribute (driver) during CUDA stream capture.

    Runtime:

    cudaLaunchAttributeValue::deviceUpdatableKernelNode


    Value of launch attribute cudaLaunchAttributeDeviceUpdatableKernelNode with the following fields:

    • int deviceUpdatable - Whether or not the resulting kernel node should be device-updatable.

    • cudaGraphDeviceNode_t devNode - Returns a handle to pass to the various device-side update functions.

    Here, under enum cudaLaunchAttributeID and value cudaLaunchAttributeDeviceUpdatableKernelNode=13, it is mentioned:


    Valid for graph nodes, launches. This attribute is graphs-only, and passing it to a launch in a non-capturing stream will result in an error. cudaLaunchAttributeValue::deviceUpdatableKernelNode::deviceUpdatable can only be set to 0 or 1. Setting the field to 1 indicates that the corresponding kernel node should be device-updatable. On success, a handle will be returned via cudaLaunchAttributeValue::deviceUpdatableKernelNode::devNode which can be passed to the various device-side update functions to update the node’s kernel parameters from within another kernel. For more information on the types of device updates that can be made, as well as the relevant limitations thereof, see cudaGraphKernelNodeUpdatesApply.

    Nodes which are device-updatable have additional restrictions compared to regular kernel nodes. Firstly, device-updatable nodes cannot be removed from their graph via cudaGraphDestroyNode. Additionally, once opted-in to this functionality, a node cannot opt out, and any attempt to set the deviceUpdatable attribute to 0 will result in an error. Device-updatable kernel nodes also cannot have their attributes copied to/from another kernel node via cudaGraphKernelNodeCopyAttributes. Graphs containing one or more device-updatable nodes also do not allow multiple instantiation, and neither the graph nor its instantiated version can be passed to cudaGraphExecUpdate.

    If a graph contains device-updatable nodes and updates those nodes from the device from within the graph, the graph must be uploaded with cuGraphUpload before it is launched. For such a graph, if host-side executable graph updates are made to the device-updatable nodes, the graph must be uploaded before it is launched again.

    Driver:

    CUlaunchAttributeValue::deviceUpdatableKernelNode


    Value of launch attribute CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE with the following fields:

    • int deviceUpdatable - Whether or not the resulting kernel node should be device-updatable.

    • CUgraphDeviceNode devNode - Returns a handle to pass to the various device-side update functions.

    Here, under enum CUlaunchAttributeID and value CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE=13 (this enum value is mentioned here), the same note appears:


    Valid for graph nodes, launches. This attribute is graphs-only, and passing it to a launch in a non-capturing stream ….

    In simple words: the runtime exposes a launch attribute cudaLaunchAttributeDeviceUpdatableKernelNode. When you set it for a kernel launch, CUDA returns a cudaGraphDeviceNode_t in the attribute’s value; you then copy that handle to device memory. This seems to work with stream capture, so the kernel ends up as a node in your captured graph and you get its device-updatable handle.

    // mark_kernel_at_launch.cu
    #include <iostream>
    #include <cublas_v2.h>
    #include <cuda_runtime.h>
    #include <vector>
    #include <cuda.h>
    
    #define CUDA_SAFECALL(call)                                                 \
        {                                                                       \
            call;                                                               \
            cudaError err = cudaGetLastError();                                 \
            if (cudaSuccess != err) {                                           \
                fprintf(                                                        \
                    stderr,                                                     \
                    "Cuda error in function '%s' file '%s' in line %i : %s.\n", \
                    #call, __FILE__, __LINE__, cudaGetErrorString(err));        \
                fflush(stderr);                                                 \
                exit(EXIT_FAILURE);                                             \
            }                                                                   \
        }
    
    #define SAFECALL_DRV(call)                                                  \
        {                                                                       \
            CUresult err = call;                                                \
            if (err != CUDA_SUCCESS) {                                          \
                const char *errStr;                                             \
                cuGetErrorString(err, &errStr);                                 \
                fprintf(                                                        \
                    stderr,                                                     \
                    "CUDA Driver API error in function '%s' file '%s' in line %i : %s.\n", \
                    #call, __FILE__, __LINE__, errStr);                         \
                fflush(stderr);                                                 \
                exit(EXIT_FAILURE);                                             \
            }                                                                   \
        }
    
    #define CUBLAS_CALL(func)                                                      \
    {                                                                              \
        cublasStatus_t s = (func);                                                 \
        if(s != CUBLAS_STATUS_SUCCESS) {                                           \
            std::cerr << "cuBLAS Error: " << s << " at line " << __LINE__ << std::endl; \
            exit(EXIT_FAILURE);                                                    \
        }                                                                          \
    }
    
    const int N = 1024;  // Dimension of the square matrix (N x N)
    
    // Helper function to initialize matrices
    void initializeMatrix(float* matrix, int n, float value) {
        for (int i = 0; i < n * n; ++i) {
            matrix[i] = value;
        }
    }
    
    void printMatrix(float* matrix, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                std::cout << matrix[i * n + j] << " ";
            }
            std::cout << std::endl;
        }
    }
    
    // CUDA kernel that adds two vectors, each thread handles one element of c
    __global__ void vecAdd(float *a, float *b, float *c, int n) {
        if(threadIdx.x == 0 && blockIdx.x == 0) {
            printf("vecAdd kernel launched with %d blocks of %d threads; n = %d\n", gridDim.x, blockDim.x, n);
        }
        int id = blockIdx.x * blockDim.x + threadIdx.x;
        if (id < n) c[id] = a[id] + b[id];
    }
    
    __device__ cudaGraphDeviceNode_t g_vecadd_node_handle;
    
    __global__ void tweak_node() {
        // enable (or disable) the node
        cudaGraphKernelNodeSetEnabled(g_vecadd_node_handle, true);
    
        // change grid size (example)
        cudaGraphKernelNodeSetGridDim(g_vecadd_node_handle, dim3(512,1,1));
    
        // update a param (example): offset is ABI-dependent; if vecAdd args are
        // (float*, float*, float*, int), the int may come after 3 pointers.
        int newN = 4096;
        size_t offset = 3 * sizeof(void*);  // typical on 64-bit, for illustration
        cudaGraphKernelNodeSetParam(g_vecadd_node_handle, offset, &newN, sizeof(newN));
    
        // or do a batch with cudaGraphKernelNodeUpdatesApply(...)
    }
    
    int main() {
        // Initialize cuBLAS
        cublasHandle_t handle;
        CUBLAS_CALL(cublasCreate(&handle));
    
        // Create a CUDA stream
        cudaStream_t stream;
        CUDA_SAFECALL(cudaStreamCreate(&stream));
    
        // Set cuBLAS to use the created stream
        CUBLAS_CALL(cublasSetStream(handle, stream));
    
        // Allocate memory for input and output matrices on host
        // A * B = C
        float* h_A = (float*) malloc(N * N * sizeof(float));
        float* h_B = (float*) malloc(N * N * sizeof(float));
        float* h_C = (float*) malloc(N * N * sizeof(float));
        initializeMatrix(h_A, N, 1.0f);  // Matrix A with all elements 1.0
        initializeMatrix(h_B, N, 2.0f);  // Matrix B with all elements 2.0
    
        // Allocate memory on device
        float *d_A, *d_B, *d_C;
        CUDA_SAFECALL(cudaMalloc((void**)&d_A, N * N * sizeof(float)));
        CUDA_SAFECALL(cudaMalloc((void**)&d_B, N * N * sizeof(float)));
        CUDA_SAFECALL(cudaMalloc((void**)&d_C, N * N * sizeof(float)));
        
        // Copy data from host to device (Async using stream)
        CUDA_SAFECALL(cudaMemcpyAsync(d_A, h_A, N * N * sizeof(float), cudaMemcpyHostToDevice, stream));
        CUDA_SAFECALL(cudaMemcpyAsync(d_B, h_B, N * N * sizeof(float), cudaMemcpyHostToDevice, stream));
    
        // Create CUDA Graph
        cudaGraph_t graph;
        cudaGraphExec_t instance;
        
        // Begin CUDA Graph capture
        CUDA_SAFECALL(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
    
        // launch the tweaker kernel (still inside stream capture)
        tweak_node<<<1,1,0,stream>>>();
    
    #if 1
        //testing a vecAdd kernel
        int blockSize, gridSize;
        blockSize = 1024;
        gridSize = (int)ceil((float)N / blockSize);
        
        ////////////////////////////////////////////////////////////////////////////
        //build extensible launch config with the "device-updatable" attribute
        cudaLaunchAttribute attr{};
        attr.id = cudaLaunchAttributeDeviceUpdatableKernelNode;
    
        attr.val.deviceUpdatableKernelNode.deviceUpdatable = 1;
        attr.val.deviceUpdatableKernelNode.devNode         = nullptr;  // filled in by the launch
    
        cudaLaunchConfig_t cfg{};
        cfg.gridDim          = dim3(gridSize, 1, 1);
        cfg.blockDim         = dim3(blockSize, 1, 1);
        cfg.dynamicSmemBytes = 0;                       
        cfg.stream           = stream;
        cfg.attrs            = &attr;
        cfg.numAttrs         = 1;
    
        // launch (still inside stream capture)
        CUDA_SAFECALL(cudaLaunchKernelEx(&cfg, vecAdd, d_A, d_B, d_C, N));
    
        cudaGraphDeviceNode_t host_devnode = attr.val.deviceUpdatableKernelNode.devNode;
        ////////////////////////////////////////////////////////////////////////////
    #endif
        
        // Set up GEMM (Matrix Multiplication)
        float alpha = 1.0f, beta = 0.0f;
    
        // cublasSgemm performs matrix multiplication C = alpha * A * B + beta * C
        // where A, B, C are NxN matrices in column-major order.
        CUBLAS_CALL(cublasSgemm(handle,
                                CUBLAS_OP_N, CUBLAS_OP_N,
                                N, N, N,
                                &alpha,
                                d_A, N,
                                d_B, N,
                                &beta,
                                d_C, N));
        // End CUDA Graph capture
        CUDA_SAFECALL(cudaStreamEndCapture(stream, &graph));
    
        // Instantiate the graph
        CUDA_SAFECALL(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
        // Runtime equivalent of cuGraphUpload; graphs that update device-updatable
        // nodes from within the graph must be uploaded before launch.
        CUDA_SAFECALL(cudaGraphUpload(instance, stream));
    
        // After the call, host_devnode now holds the device-updatable node handle.
        // Stash it in a __device__ global so kernels can use it:
        CUDA_SAFECALL(cudaMemcpyToSymbol(g_vecadd_node_handle, &host_devnode,
                                        sizeof(host_devnode)));
        printf("device-updatable node handle: %p\n", (void*)host_devnode);
    
        // Get the number of nodes to allocate an array of nodes
        cudaGraphNode_t* nodes=NULL;
        size_t numNodes=0;
        CUDA_SAFECALL((cudaGraphGetNodes(graph, nodes, &numNodes)));
        nodes = (cudaGraphNode_t*)malloc(numNodes*sizeof(cudaGraphNode_t));
        CUDA_SAFECALL((cudaGraphGetNodes(graph, nodes, &numNodes)));
        
        std::vector<cudaGraphNode_t> kernelNode;
        
        printf("Number of nodes: %lu\n", numNodes);
        for(int i = 0; i < numNodes; i++){
            cudaGraphNodeType type;
            cudaGraphNodeGetType(nodes[i], &type);
            if (type == cudaGraphNodeTypeKernel){
                printf("Node %d is kernel Node\n", i+1);
                kernelNode.push_back(nodes[i]);
            }
            else if(type == cudaGraphNodeTypeMemset){
                printf("Node %d is memset Node\n", i+1);
            }
            else if(type == cudaGraphNodeTypeMemAlloc){
                printf("Node %d is memory allocation Node\n", i+1);
            }
            else if(type == cudaGraphNodeTypeMemFree){
                printf("Node %d is memory free Node\n", i+1);
            }
            else{
                printf("Node %d is unknown type Node\n", i+1);
            }
        }
        
        // Launch the instantiated graph
        CUDA_SAFECALL(cudaGraphLaunch(instance, stream));
        CUDA_SAFECALL(cudaStreamSynchronize(stream));
        printf("Graph launched successfully.\n");
    
        // Copy result from device to host (Async using stream)
        CUDA_SAFECALL(cudaMemcpyAsync(h_C, d_C, N * N * sizeof(float), cudaMemcpyDeviceToHost, stream));
    
        // Wait for the stream to finish all operations
        CUDA_SAFECALL(cudaStreamSynchronize(stream));
    
        // // Print the result matrix C
        // std::cout << "Matrix C (Result of A * B):" << std::endl;
        // printMatrix(h_C, N);
    
        // Cleanup
        CUBLAS_CALL(cublasDestroy(handle));
        CUDA_SAFECALL(cudaFree(d_A));
        CUDA_SAFECALL(cudaFree(d_B));
        CUDA_SAFECALL(cudaFree(d_C));
        free(h_A);
        free(h_B);
        free(h_C);
    
        // Destroy the stream
        CUDA_SAFECALL(cudaStreamDestroy(stream));
    
        return 0;
    }
    
    
    $ nvcc mark_kernel_at_launch.cu -o mark_kernel_at_launch -lcublas -lcuda
    $ ./mark_kernel_at_launch 
    device-updatable node handle: 0x790d2ac74400
    Number of nodes: 3
    Node 1 is kernel Node
    Node 2 is kernel Node
    Node 3 is kernel Node
    vecAdd kernel launched with 512 blocks of 1024 threads; n = 4096
    Graph launched successfully.
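
    The offset passed to `cudaGraphKernelNodeSetParam` in `tweak_node` follows the packed kernel parameter buffer layout: parameters are laid out in declaration order, each placed at the next multiple of its natural alignment. A small host-side helper (my own illustration, not a CUDA API) can compute such offsets; the assumption here is a 64-bit platform where pointers are 8 bytes:

    ```cpp
    #include <cstddef>
    #include <cstdio>

    // Illustration only: compute the packed offset of the index-th kernel
    // parameter, placing each parameter at the next multiple of its
    // alignment (the layout used for the kernel parameter buffer on 64-bit
    // platforms, to my understanding -- verify against your toolchain).
    constexpr size_t align_up(size_t off, size_t align) {
        return (off + align - 1) / align * align;
    }

    template <typename T, typename... Rest>
    constexpr size_t offset_of(size_t index, size_t cur = 0) {
        cur = align_up(cur, alignof(T));
        if (index == 0) return cur;
        if constexpr (sizeof...(Rest) > 0)
            return offset_of<Rest...>(index - 1, cur + sizeof(T));
        return cur;  // index out of range: clamp to the last parameter
    }

    int main() {
        // vecAdd(float*, float*, float*, int): the int lands after three 8-byte pointers
        printf("offset of a: %zu\n", offset_of<float*, float*, float*, int>(0));  // 0
        printf("offset of n: %zu\n", offset_of<float*, float*, float*, int>(3));  // 24
        return 0;
    }
    ```

    This reproduces the `3 * sizeof(void*)` used above and generalizes to other signatures, but the safest route is still to verify the layout for your particular kernel.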
    
  2. Alternatively, set a node attribute after graph construction.
    If you are not in control of the launch site (e.g., you constructed the graph via stream capture and only have the cudaGraphNode_t later), you can still ask CUDA to make a kernel node device-updatable and get the handle using the kernel-node attribute API:

    • Runtime: cudaGraphKernelNodeSetAttribute(node, cudaKernelNodeAttributeDeviceUpdatableKernelNode, &value)

    • Driver: cuGraphKernelNodeSetAttribute(node, CU_KERNEL_NODE_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE, &value)
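
    For the driver path, a minimal sketch of the same idea (not compiled here; `cuGraphKernelNodeSetAttribute`, `CUlaunchAttributeValue`, and the attribute name come from the driver API reference, and `node` is assumed to be a `CUgraphNode` you already hold):

    ```cuda
    // Opt an existing kernel node in to device updates via the driver API.
    CUlaunchAttributeValue value = {};
    value.deviceUpdatableKernelNode.deviceUpdatable = 1;
    value.deviceUpdatableKernelNode.devNode         = NULL;  // filled in by the call

    SAFECALL_DRV(cuGraphKernelNodeSetAttribute(
        node,
        CU_KERNEL_NODE_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE,
        &value));

    // Per the blog quote above, the handle comes back in the attribute value.
    CUgraphDeviceNode devNode = value.deviceUpdatableKernelNode.devNode;

    // Copy it to a __device__ symbol so kernels can call the device-side
    // update functions on it (cudaMemcpyToSymbol on the runtime side, or
    // cuModuleGetGlobal + cuMemcpyHtoD with the pure driver API).
    ```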

      Using the runtime API:

      // set_node_attribute_after_graph_construction.cu
      #include <iostream>
      #include <cublas_v2.h>
      #include <cuda_runtime.h>
      #include <vector>
      #include <cuda.h>
      
      #define CUDA_SAFECALL(call)                                                 \
          {                                                                       \
              call;                                                               \
              cudaError err = cudaGetLastError();                                 \
              if (cudaSuccess != err) {                                           \
                  fprintf(                                                        \
                      stderr,                                                     \
                      "Cuda error in function '%s' file '%s' in line %i : %s.\n", \
                      #call, __FILE__, __LINE__, cudaGetErrorString(err));        \
                  fflush(stderr);                                                 \
                  exit(EXIT_FAILURE);                                             \
              }                                                                   \
          }
      
      #define SAFECALL_DRV(call)                                                  \
          {                                                                       \
              CUresult err = call;                                                \
              if (err != CUDA_SUCCESS) {                                          \
                  const char *errStr;                                             \
                  cuGetErrorString(err, &errStr);                                 \
                  fprintf(                                                        \
                      stderr,                                                     \
                      "CUDA Driver API error in function '%s' file '%s' in line %i : %s.\n", \
                      #call, __FILE__, __LINE__, errStr);                         \
                  fflush(stderr);                                                 \
                  exit(EXIT_FAILURE);                                             \
              }                                                                   \
          }
      
      #define CUBLAS_CALL(func)                                                      \
      {                                                                              \
          cublasStatus_t s = (func);                                                 \
          if(s != CUBLAS_STATUS_SUCCESS) {                                           \
              std::cerr << "cuBLAS Error: " << s << " at line " << __LINE__ << std::endl; \
              exit(EXIT_FAILURE);                                                    \
          }                                                                          \
      }
      
      const int N = 1024;  // Dimension of the square matrix (N x N)
      
      // Helper function to initialize matrices
      void initializeMatrix(float* matrix, int n, float value) {
          for (int i = 0; i < n * n; ++i) {
              matrix[i] = value;
          }
      }
      
      void printMatrix(float* matrix, int n) {
          for (int i = 0; i < n; ++i) {
              for (int j = 0; j < n; ++j) {
                  std::cout << matrix[i * n + j] << " ";
              }
              std::cout << std::endl;
          }
      }
      
      // CUDA kernel that adds two vectors, each thread handles one element of c
      __global__ void vecAdd(float *a, float *b, float *c, int n) {
          if(threadIdx.x == 0 && blockIdx.x == 0) {
              printf("vecAdd kernel launched with %d blocks of %d threads; n = %d\n", gridDim.x, blockDim.x, n);
          }
          int id = blockIdx.x * blockDim.x + threadIdx.x;
          if (id < n) c[id] = a[id] + b[id];
      }
      
      __device__ cudaGraphDeviceNode_t g_vecadd_node_handle;
      
      __global__ void tweak_node() {
          // enable (or disable) the node
          cudaGraphKernelNodeSetEnabled(g_vecadd_node_handle, true);
      
          // change grid size (example)
          cudaGraphKernelNodeSetGridDim(g_vecadd_node_handle, dim3(512,1,1));
      
          // update a param (example): offset is ABI-dependent; if vecAdd args are
          // (float*, float*, float*, int), the int may come after 3 pointers.
          int newN = 4096;
          size_t offset = 3 * sizeof(void*);  // typical on 64-bit, for illustration
          cudaGraphKernelNodeSetParam(g_vecadd_node_handle, offset, &newN, sizeof(newN));
      
          // or do a batch with cudaGraphKernelNodeUpdatesApply(...)
      }
      
      int main() {
          // Initialize cuBLAS
          cublasHandle_t handle;
          CUBLAS_CALL(cublasCreate(&handle));
      
          // Create a CUDA stream
          cudaStream_t stream;
          CUDA_SAFECALL(cudaStreamCreate(&stream));
      
          // Set cuBLAS to use the created stream
          CUBLAS_CALL(cublasSetStream(handle, stream));
      
          // Allocate memory for input and output matrices on host
          // A * B = C
          float* h_A = (float*) malloc(N * N * sizeof(float));
          float* h_B = (float*) malloc(N * N * sizeof(float));
          float* h_C = (float*) malloc(N * N * sizeof(float));
          initializeMatrix(h_A, N, 1.0f);  // Matrix A with all elements 1.0
          initializeMatrix(h_B, N, 2.0f);  // Matrix B with all elements 2.0
      
          // Allocate memory on device
          float *d_A, *d_B, *d_C;
          CUDA_SAFECALL(cudaMalloc((void**)&d_A, N * N * sizeof(float)));
          CUDA_SAFECALL(cudaMalloc((void**)&d_B, N * N * sizeof(float)));
          CUDA_SAFECALL(cudaMalloc((void**)&d_C, N * N * sizeof(float)));
          
          // Copy data from host to device (Async using stream)
          CUDA_SAFECALL(cudaMemcpyAsync(d_A, h_A, N * N * sizeof(float), cudaMemcpyHostToDevice, stream));
          CUDA_SAFECALL(cudaMemcpyAsync(d_B, h_B, N * N * sizeof(float), cudaMemcpyHostToDevice, stream));
      
          // Create CUDA Graph
          cudaGraph_t graph;
          cudaGraphExec_t instance;
          
          // Begin CUDA Graph capture
          CUDA_SAFECALL(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
      
          // launch the tweaker kernel (still inside stream capture)
          tweak_node<<<1,1,0,stream>>>();
          
      #if 1
          //testing a vecAdd kernel
          int blockSize, gridSize;
          blockSize = 1024;
          gridSize = (int)ceil((float)N / blockSize);
          vecAdd<<<gridSize, blockSize, 0, stream>>>(d_A, d_B, d_C, N);
      
      #endif
          
          // Set up GEMM (Matrix Multiplication)
          float alpha = 1.0f, beta = 0.0f;
      
          // cublasSgemm performs matrix multiplication C = alpha * A * B + beta * C
          // where A, B, C are NxN matrices in column-major order.
          CUBLAS_CALL(cublasSgemm(handle,
                                  CUBLAS_OP_N, CUBLAS_OP_N,
                                  N, N, N,
                                  &alpha,
                                  d_A, N,
                                  d_B, N,
                                  &beta,
                                  d_C, N));
          // End CUDA Graph capture
          CUDA_SAFECALL(cudaStreamEndCapture(stream, &graph));
      
          
          // Get the number of nodes to allocate an array of nodes
          cudaGraphNode_t* nodes=NULL;
          size_t numNodes=0;
          CUDA_SAFECALL((cudaGraphGetNodes(graph, nodes, &numNodes)));
          nodes = (cudaGraphNode_t*)malloc(numNodes*sizeof(cudaGraphNode_t));
          CUDA_SAFECALL((cudaGraphGetNodes(graph, nodes, &numNodes)));
          
          std::vector<cudaGraphNode_t> kernelNode;
          
          printf("Number of nodes: %lu\n", numNodes);
          for(int i = 0; i < numNodes; i++){
              cudaGraphNodeType type;
              cudaGraphNodeGetType(nodes[i], &type);
              if (type == cudaGraphNodeTypeKernel){
                  printf("Node %d is kernel Node\n", i+1);
                  kernelNode.push_back(nodes[i]);
              }
              else if(type == cudaGraphNodeTypeMemset){
                  printf("Node %d is memset Node\n", i+1);
              }
              else if(type == cudaGraphNodeTypeMemAlloc){
                  printf("Node %d is memory allocation Node\n", i+1);
              }
              else if(type == cudaGraphNodeTypeMemFree){
                  printf("Node %d is memory free Node\n", i+1);
              }
              else{
                  printf("Node %d is unknown type Node\n", i+1);
              }
          }
          
          ////////////////////////////////////////////////////////////////////////////
          cudaKernelNodeAttrValue attr{};
          attr.deviceUpdatableKernelNode.deviceUpdatable = 1;        // enable device updates
          attr.deviceUpdatableKernelNode.devNode         = nullptr;  // CUDA will fill this in
      
          CUDA_SAFECALL(cudaGraphKernelNodeSetAttribute(
              kernelNode[1],  // assuming the 2nd kernel node is vecAdd
              cudaKernelNodeAttributeDeviceUpdatableKernelNode,
              &attr));
          
          // At this point, attr.deviceUpdatableKernelNode.devNode was populated.
          cudaGraphDeviceNode_t devHandle = attr.deviceUpdatableKernelNode.devNode;
      
          // stash it to a __device__ global so kernels can call the __device__ graph APIs:
          CUDA_SAFECALL(cudaMemcpyToSymbol(g_vecadd_node_handle, &devHandle, sizeof(devHandle)));
          printf("Captured vecAdd node handle: %p\n", (void*)devHandle);
          ////////////////////////////////////////////////////////////////////////////
      
          // Instantiate the graph
          CUDA_SAFECALL(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
          // Runtime equivalent of cuGraphUpload; graphs that update device-updatable
          // nodes from within the graph must be uploaded before launch.
          CUDA_SAFECALL(cudaGraphUpload(instance, stream));
      
          // Launch the instantiated graph
          CUDA_SAFECALL(cudaGraphLaunch(instance, stream));
          CUDA_SAFECALL(cudaStreamSynchronize(stream));
          printf("Graph launched successfully.\n");
      
          // Copy result from device to host (Async using stream)
          CUDA_SAFECALL(cudaMemcpyAsync(h_C, d_C, N * N * sizeof(float), cudaMemcpyDeviceToHost, stream));
      
          // Wait for the stream to finish all operations
          CUDA_SAFECALL(cudaStreamSynchronize(stream));
      
          // // Print the result matrix C
          // std::cout << "Matrix C (Result of A * B):" << std::endl;
          // printMatrix(h_C, N);
      
          // Cleanup
          CUBLAS_CALL(cublasDestroy(handle));
          CUDA_SAFECALL(cudaFree(d_A));
          CUDA_SAFECALL(cudaFree(d_B));
          CUDA_SAFECALL(cudaFree(d_C));
          free(h_A);
          free(h_B);
          free(h_C);
      
          // Destroy the stream
          CUDA_SAFECALL(cudaStreamDestroy(stream));
      
          return 0;
      }
      
      $ nvcc set_node_attribute_after_graph_construction.cu -o set_node_attribute_after_graph_construction -lcublas -lcuda
      $ ./set_node_attribute_after_graph_construction 
      Number of nodes: 3
      Node 1 is kernel Node
      Node 2 is kernel Node
      Node 3 is kernel Node
      Captured vecAdd node handle: 0x7f4c72c74400
      vecAdd kernel launched with 512 blocks of 1024 threads; n = 4096
      Graph launched successfully.
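
Both examples leave the batched path as a comment ("or do a batch with cudaGraphKernelNodeUpdatesApply"). For completeness, here is a hedged sketch of what tweak_node could look like using the batch API instead. The cudaGraphKernelNodeUpdate field names (node, field, updateData.gridDim, updateData.param.{pValue, offset, size}) and the cudaGraphKernelNodeField enum values follow the CUDA 12.4 runtime API reference, but I have not run this variant, so treat it as illustrative; the parameter offset is the same ABI-dependent value discussed above:

```cuda
__device__ cudaGraphDeviceNode_t g_vecadd_node_handle;

__global__ void tweak_node_batched() {
    int newN = 4096;

    // Describe all updates up front, then apply them in one call.
    cudaGraphKernelNodeUpdate updates[2] = {};

    updates[0].node               = g_vecadd_node_handle;
    updates[0].field              = cudaGraphKernelNodeFieldGridDim;
    updates[0].updateData.gridDim = dim3(512, 1, 1);

    updates[1].node                    = g_vecadd_node_handle;
    updates[1].field                   = cudaGraphKernelNodeFieldParam;
    updates[1].updateData.param.pValue = &newN;              // value to write
    updates[1].updateData.param.offset = 3 * sizeof(void*);  // as in the examples above
    updates[1].updateData.param.size   = sizeof(newN);

    cudaError_t err = cudaGraphKernelNodeUpdatesApply(updates, 2);
    if (err != cudaSuccess) {
        printf("batch update failed: %d\n", (int)err);
    }
}
```

This kernel would replace tweak_node in either example; the handle still has to be copied into g_vecadd_node_handle with cudaMemcpyToSymbol before the graph is launched.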