Performance Difference Between a Custom cuDNN Wrapper and PyTorch

I’ve implemented a wrapper function for calling the cuDNN convolution API, shown in the code snippet below. After running it and profiling with nsys, I noticed a significant performance gap compared to PyTorch’s implementation.

When profiling my implementation, I found that the core convolution computation takes more than 20 milliseconds:

cudnn_infer_ampere_scudnn_winograd_128x128_ldg1_ldg4_relu_tile148t_nt_v1
Begins: 3.61082s
Ends: 3.6346s (+23.778 ms)
grid: <<<1024, 4, 2>>>
block: <<<256, 1, 1>>>
Launch Type: Regular
Static Shared Memory: 49,152 bytes
Dynamic Shared Memory: 0 bytes
Registers Per Thread: 126
Local Memory Per Thread: 0 bytes
Local Memory Total: 127,401,984 bytes
Shared Memory executed: 102,400 bytes
Shared Memory Bank Size: 4 B
Theoretical occupancy: 25 %
Launched from thread: 3428407
Latency: ←109.155 ms
Correlation ID: 2594
Stream: Stream 22

However, when I wrote the same convolution operation in PyTorch and profiled it, the kernel took only about 100 microseconds:

cudnn_ampere_scudnn_128x32_relu_small_nn_v1
Begins: 13.3173s
Ends: 13.3174s (+93.952 μs)
grid: <<<3136, 2, 1>>>
block: <<<64, 1, 1>>>
Launch Type: Regular
Static Shared Memory: 5,376 bytes
Dynamic Shared Memory: 0 bytes
Registers Per Thread: 128
Local Memory Per Thread: 0 bytes
Local Memory Total: 205,258,752 bytes
Shared Memory executed: 102,400 bytes
Shared Memory Bank Size: 4 B
Theoretical occupancy: 25 %
Launched from thread: 3417000
Latency: ←124.023 ms
Correlation ID: 2300
Stream: Default stream 7

This performance difference is more than two orders of magnitude!

I suspect that PyTorch’s backend, or its Python layer, applies some compilation optimizations. However, judging from the kernel name in the profile, PyTorch still ultimately calls pre-built cuDNN kernels rather than generating them through compilation techniques.

Question:
Is it possible for my wrapper-function approach to achieve execution efficiency similar to PyTorch’s? If so, what additional work do I need to do?

Here’s my wrapper function implementation:

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCudnnConv2dForward(
    int n, int c, int h, int w_in,              // Input dimensions
    int k, int r, int s,                         // Kernel dimensions
    int pad_h, int pad_w,                        // Padding
    int stride_h, int stride_w,                  // Stride
    int dilation_h, int dilation_w,              // Dilation
    void* x_data, void* w_data, void* bias_data, // Input, weight, and bias pointers
    void* y_data,                               // Output pointer
    CUstream stream                             // CUDA stream
) {
    // Ensure global context
    mgpuEnsureContext();

    // Get cuDNN handle for this stream
    cudnnHandle_t handle = mgpuCudnnGetHandle(stream);
    
    // Create descriptors
    cudnnTensorDescriptor_t xDesc, yDesc, biasDesc;
    cudnnFilterDescriptor_t wDesc;
    cudnnConvolutionDescriptor_t convDesc;
    
    CUDNN_REPORT_IF_ERROR(cudnnCreateTensorDescriptor(&xDesc));
    CUDNN_REPORT_IF_ERROR(cudnnCreateFilterDescriptor(&wDesc));
    CUDNN_REPORT_IF_ERROR(cudnnCreateTensorDescriptor(&yDesc));
    CUDNN_REPORT_IF_ERROR(cudnnCreateTensorDescriptor(&biasDesc));
    CUDNN_REPORT_IF_ERROR(cudnnCreateConvolutionDescriptor(&convDesc));
    
    // Set input descriptor
    CUDNN_REPORT_IF_ERROR(cudnnSetTensor4dDescriptor(
        xDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w_in));
    
    // Set weight descriptor
    CUDNN_REPORT_IF_ERROR(cudnnSetFilter4dDescriptor(
        wDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, k, c, r, s));
    
    // Set convolution descriptor
    CUDNN_REPORT_IF_ERROR(cudnnSetConvolution2dDescriptor(
        convDesc, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
        CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));
    
    // Enable Tensor Cores
    CUDNN_REPORT_IF_ERROR(cudnnSetConvolutionMathType(convDesc, CUDNN_TENSOR_OP_MATH));

    // Get output dimensions
    int out_n, out_c, out_h, out_w;
    CUDNN_REPORT_IF_ERROR(cudnnGetConvolution2dForwardOutputDim(
        convDesc, xDesc, wDesc, &out_n, &out_c, &out_h, &out_w));

    fprintf(stderr, "Output dimensions: n=%d, c=%d, h=%d, w=%d\n", 
          out_n, out_c, out_h, out_w);
    
    CUDNN_REPORT_IF_ERROR(cudnnSetTensor4dDescriptor(
        yDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, out_n, out_c, out_h, out_w));
    
    // Set bias descriptor (1 x K x 1 x 1, one value per output channel)
    CUDNN_REPORT_IF_ERROR(cudnnSetTensor4dDescriptor(
        biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, k, 1, 1));
    
    int requestedAlgoCount = 10;
    int returnedAlgoCount;
    cudnnConvolutionFwdAlgoPerf_t perfResults[10];
    CUDNN_REPORT_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
        handle, xDesc, wDesc, convDesc, yDesc,
        requestedAlgoCount, &returnedAlgoCount, perfResults));

    cudnnConvolutionFwdAlgo_t algo = perfResults[0].algo;

    // Get workspace size
    size_t workspaceSize = 0;
    CUDNN_REPORT_IF_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
        handle, xDesc, wDesc, convDesc, yDesc, algo, &workspaceSize));
    
    // Allocate workspace
    void* workspace = nullptr;
    if (workspaceSize > 0) {
      CUdeviceptr wsPtr = 0;
      CUDA_REPORT_IF_ERROR(cuMemAlloc(&wsPtr, workspaceSize));
      workspace = reinterpret_cast<void*>(wsPtr);
    }

    // Execute convolution
    const float alpha = 1.0f;
    const float beta = 0.0f;
    cudnnStatus_t status = cudnnConvolutionForward(
        handle, &alpha, xDesc, x_data, wDesc, w_data, convDesc, algo,
        workspace, workspaceSize, &beta, yDesc, y_data);

    // Report errors (if any)
    CUDNN_REPORT_IF_ERROR(status);

    // Add bias (if provided)
    if (bias_data != nullptr) {
      const float alpha_bias = 1.0f;
      const float beta_bias = 1.0f;
      CUDNN_REPORT_IF_ERROR(cudnnAddTensor(
          handle, &alpha_bias, biasDesc, bias_data, &beta_bias, yDesc, y_data));
    }
    
    // Free workspace
    if (workspace != nullptr) {
      CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(workspace)));
    }
    
    // Clean up descriptors
    CUDNN_REPORT_IF_ERROR(cudnnDestroyTensorDescriptor(xDesc));
    CUDNN_REPORT_IF_ERROR(cudnnDestroyFilterDescriptor(wDesc));
    CUDNN_REPORT_IF_ERROR(cudnnDestroyTensorDescriptor(yDesc));
    CUDNN_REPORT_IF_ERROR(cudnnDestroyTensorDescriptor(biasDesc));
    CUDNN_REPORT_IF_ERROR(cudnnDestroyConvolutionDescriptor(convDesc));
}
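
For context, a call site that exercises this wrapper looks roughly like the sketch below. The shapes and sizes here are just an example, not my actual test case, and error checking is omitted; the point is that all four data pointers are plain device allocations.

// Illustrative call with example shapes: a 1x64x56x56 input convolved with
// 64 filters of size 3x3, stride 1, padding 1, dilation 1 (output stays 56x56).
// A current CUDA context is assumed (mgpuEnsureContext() inside the wrapper
// also establishes one). Error checking omitted for brevity.
size_t xBytes = 1ull * 64 * 56 * 56 * sizeof(float);
size_t wBytes = 64ull * 64 * 3 * 3 * sizeof(float);
size_t bBytes = 64ull * sizeof(float);
size_t yBytes = 1ull * 64 * 56 * 56 * sizeof(float);

CUdeviceptr x_dev, w_dev, b_dev, y_dev;
cuMemAlloc(&x_dev, xBytes);
cuMemAlloc(&w_dev, wBytes);
cuMemAlloc(&b_dev, bBytes);
cuMemAlloc(&y_dev, yBytes);
// ... fill x_dev / w_dev / b_dev from host data with cuMemcpyHtoD ...

mgpuCudnnConv2dForward(
    /*n=*/1, /*c=*/64, /*h=*/56, /*w_in=*/56,
    /*k=*/64, /*r=*/3, /*s=*/3,
    /*pad_h=*/1, /*pad_w=*/1,
    /*stride_h=*/1, /*stride_w=*/1,
    /*dilation_h=*/1, /*dilation_w=*/1,
    reinterpret_cast<void *>(x_dev), reinterpret_cast<void *>(w_dev),
    reinterpret_cast<void *>(b_dev), reinterpret_cast<void *>(y_dev),
    /*stream=*/nullptr);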

Any suggestions or insights would be greatly appreciated!

Hi @lilil111,
Let me fetch the detailed answer for this and get back to you.

Thanks

I’ve run into a similar issue before, and it troubled me as well.

I don’t know whether you have solved this problem yet, but in my case I found a point I had overlooked. When I called the cuDNN API myself, for example to run a convolution, I had allocated all of the input buffers in shared memory (memory visible to both the host and the device). That caused the API to keep transferring data during execution, which produced a large performance gap: my own calls were not even on the same order of magnitude as PyTorch.
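
A minimal sketch of what I mean, assuming the host-and-device-visible allocation was CUDA managed/unified memory (the helper name and surrounding code are made up for illustration, not my actual code):

#include <cuda.h>
#include <vector>

// Device-only operand buffer: the convolution kernel reads it at full device
// bandwidth, with a single explicit host-to-device copy up front.
// A current CUDA context is assumed; error checking is omitted.
CUdeviceptr makeDeviceOperand(const std::vector<float> &hostData) {
  CUdeviceptr devPtr = 0;
  size_t bytes = hostData.size() * sizeof(float);
  cuMemAlloc(&devPtr, bytes);                    // plain device memory
  cuMemcpyHtoD(devPtr, hostData.data(), bytes);  // one explicit upload
  return devPtr;
}

// By contrast, an operand allocated with
//   cuMemAllocManaged(&devPtr, bytes, CU_MEM_ATTACH_GLOBAL);
// is usable from both host and device, but its pages can migrate on demand
// while the cuDNN kernel runs, and that migration can dominate the kernel time.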
After I fixed that, my own calls were still slower than PyTorch’s, but within the same order of magnitude. I think that remaining gap is understandable, because PyTorch does a lot of preprocessing and selects the best execution algorithm at runtime.
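
On the preprocessing point: the wrapper above recreates every descriptor, runs the algorithm query, and allocates/frees the workspace on every call, whereas PyTorch effectively caches this work across calls with the same shapes. Below is a minimal sketch of one way to cache the algorithm choice; ConvKey, gAlgoCache, and pickAlgo are hypothetical names of mine, not part of cuDNN or of the wrapper's runtime.

#include <cudnn.h>
#include <map>
#include <tuple>

// Cache key built from the convolution shape; extend it with pads, strides,
// and dilations if those vary between calls.
using ConvKey = std::tuple<int, int, int, int, int, int, int>;  // n, c, h, w, k, r, s

static std::map<ConvKey, cudnnConvolutionFwdAlgo_t> gAlgoCache;

cudnnConvolutionFwdAlgo_t pickAlgo(cudnnHandle_t handle,
                                   cudnnTensorDescriptor_t xDesc,
                                   cudnnFilterDescriptor_t wDesc,
                                   cudnnConvolutionDescriptor_t convDesc,
                                   cudnnTensorDescriptor_t yDesc,
                                   const ConvKey &key) {
  auto it = gAlgoCache.find(key);
  if (it != gAlgoCache.end())
    return it->second;  // reuse the algorithm chosen for this shape earlier

  int returned = 0;
  cudnnConvolutionFwdAlgoPerf_t perf[8];
  cudnnGetConvolutionForwardAlgorithm_v7(handle, xDesc, wDesc, convDesc, yDesc,
                                         8, &returned, perf);
  cudnnConvolutionFwdAlgo_t algo =
      (returned > 0) ? perf[0].algo
                     : CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  gAlgoCache[key] = algo;
  return algo;
}

Descriptors and the workspace allocation can be cached and reused the same way, which should remove most of the per-call setup cost around the kernel launch itself.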