Why is my 'trivial' convolution kernel faster than cuDNN?

I implemented a basic CUDA kernel for convolution:

    template<typename T>
    __global__ void convNoTensorKernel(int imgM, int imgN, int krnlM, int krnlN, const T *img, const T *krnl, T *out) {

        // Each thread computes one element of the "valid" output (no padding, stride 1):
        const int outM = imgM - krnlM + 1;
        const int outN = imgN - krnlN + 1;
        const int totalIdxs = outM * outN;

        // Blocks are launched with MAX_CUDA_THREAD_COUNT threads each, so this is the global thread index:
        const int idx = threadIdx.x + blockIdx.x * MAX_CUDA_THREAD_COUNT;
        if (idx >= totalIdxs) {
            return;
        }

        // Map the flat index to an output coordinate and the matching top-left corner in the image:
        const int outBlockX = idx % outN;
        const int outBlockY = idx / outN;
        const int imgStartIdx = outBlockX + outBlockY * imgN;

        T sum = 0;
        for (int i = 0; i < krnlM * krnlN; ++i) {
            const int krnlX = i % krnlN;
            const int krnlY = i / krnlN;
            const int rowOffset = krnlY * imgN;
            const int imgIdx = imgStartIdx + rowOffset + krnlX;

            const T imgValue = img[imgIdx];
            // Read the kernel back to front, i.e. a true (flipped) convolution, matching CUDNN_CONVOLUTION:
            const T krnlValue = krnl[(krnlM * krnlN) - 1 - i];

            sum += imgValue * krnlValue;
        }

        out[outBlockX + outBlockY * outN] = sum;
    }

For a convolution of a 2048x2048 image with a 128x128 kernel, this takes 0.481s.
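(Back-of-the-envelope: the output is 1921x1921 and each element needs 128x128 multiply-adds, so that is roughly 1921^2 * 128^2 * 2 ≈ 1.2e11 FLOPs, or about 250 GFLOP/s at 0.481 s.)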

With cuDNN, the same convolution takes around 7.4 s:

    template<typename T>
    void cuda_conv(Matrix<T> &img, Matrix<T> &krnl, Matrix<T> &out) {

        const float alpha = 1.0f;
        const float beta = 0.0f;

        // Create a cuDNN handle:
        cudnnHandle_t handle;
        checkCUDAError( "cudnnCreate failed", cudnnCreate(&handle) );

        // Create the tensor descriptors:
        cudnnTensorDescriptor_t cudnnIdesc;
        cudnnFilterDescriptor_t cudnnFdesc;
        cudnnTensorDescriptor_t cudnnOdesc;
        cudnnConvolutionDescriptor_t cudnnConvDesc;
        cudnnCreateTensorDescriptor( &cudnnIdesc );
        cudnnCreateFilterDescriptor( &cudnnFdesc );
        cudnnCreateTensorDescriptor( &cudnnOdesc );
        cudnnCreateConvolutionDescriptor( &cudnnConvDesc );

        checkCUDAError( "SetImgDescriptor failed", cudnnSetTensor4dDescriptor(cudnnIdesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, img.B(), img.C(), img.M(), img.N()) );

        checkCUDAError( "SetFilterDescriptor failed", cudnnSetFilter4dDescriptor(cudnnFdesc, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, krnl.B(), krnl.C(), krnl.M(), krnl.N()) );

        checkCUDAError( "SetOutDescriptor failed", cudnnSetTensor4dDescriptor(cudnnOdesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, out.B(), out.C(), out.M(), out.N()) );

        // No padding, stride 1, dilation 1, true convolution (flipped kernel):
        checkCUDAError( "SetConvDescriptor failed", cudnnSetConvolution2dDescriptor(cudnnConvDesc, 0, 0, 1, 1, 1, 1, CUDNN_CONVOLUTION, CUDNN_DATA_HALF) );

        // Set the math type to allow cuDNN to use Tensor Cores:
        checkCUDAError( "SetConvMathType failed", cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );

        // Choose a supported algorithm:
        int algoCount = 0;
        cudnnConvolutionFwdAlgoPerf_t algoPerf;
        checkCUDAError( "GetConvForwardAlgo failed", cudnnFindConvolutionForwardAlgorithm(handle, cudnnIdesc, cudnnFdesc, cudnnConvDesc, cudnnOdesc, 1, &algoCount, &algoPerf) );
        checkCUDAError( "ConvolutionForwardAlgorithm failed", algoPerf.status );

        // Allocate the workspace:
        uint8_t *workSpace = nullptr;
        size_t workSpaceSize = 0;
        checkCUDAError( "WorkspaceSize failed", cudnnGetConvolutionForwardWorkspaceSize(handle, cudnnIdesc, cudnnFdesc, cudnnConvDesc, cudnnOdesc, algoPerf.algo, &workSpaceSize) );
        if (workSpaceSize > 0) {
            checkCUDAError( "Workspace malloc failed", cudaMalloc((void**)&workSpace, workSpaceSize) );
        }

        checkCUDAError( "Conv failed", cudnnConvolutionForward(handle, (void*)(&alpha), cudnnIdesc, img.dataOnGPU,
                                            cudnnFdesc, krnl.dataOnGPU, cudnnConvDesc, algoPerf.algo,
                                            workSpace, workSpaceSize, (void*)(&beta),
                                            cudnnOdesc, out.dataOnGPU) );

        // Release the workspace, descriptors, and handle:
        if (workSpace) {
            checkCUDAError( "Workspace free failed", cudaFree(workSpace) );
        }
        cudnnDestroyTensorDescriptor(cudnnIdesc);
        cudnnDestroyFilterDescriptor(cudnnFdesc);
        cudnnDestroyTensorDescriptor(cudnnOdesc);
        cudnnDestroyConvolutionDescriptor(cudnnConvDesc);
        cudnnDestroy(handle);
    }

Compiled with nvcc --std=c++17 -lineinfo -g -O3 -arch=sm_75, running on an RTX 2080.

Even after skipping cudnnFindConvolutionForwardAlgorithm and using CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM directly, it still takes around 1 s.
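The only change for that test was dropping the Find call and hardcoding the algorithm; roughly like this (a sketch, everything else in cuda_conv unchanged):

    // Skip autotuning and use implicit GEMM directly:
    const cudnnConvolutionFwdAlgo_t algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;

    uint8_t *workSpace = nullptr;
    size_t workSpaceSize = 0;
    checkCUDAError( "WorkspaceSize failed", cudnnGetConvolutionForwardWorkspaceSize(handle, cudnnIdesc, cudnnFdesc, cudnnConvDesc, cudnnOdesc, algo, &workSpaceSize) );
    if (workSpaceSize > 0) {
        checkCUDAError( "Workspace malloc failed", cudaMalloc((void**)&workSpace, workSpaceSize) );
    }

    checkCUDAError( "Conv failed", cudnnConvolutionForward(handle, (void*)(&alpha), cudnnIdesc, img.dataOnGPU,
                                        cudnnFdesc, krnl.dataOnGPU, cudnnConvDesc, algo,
                                        workSpace, workSpaceSize, (void*)(&beta),
                                        cudnnOdesc, out.dataOnGPU) );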

Are there errors in my cuDNN code that I am not aware of?

I would be very grateful for any hints or tips.

Can you provide the complete code, with timing?

I measure time with std::chrono::high_resolution_clock.
There really is not much more to my code: Matrix is just a templated class holding the data in an array dataOnGPU, which is allocated via cudaMalloc( (void**)&dataOnGPU, sizeof(T) * B() * C() * M() * N() ) and filled with data beforehand.
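Roughly, it looks like this (a trimmed-down sketch; the constructors, the host-side buffer, and the cudaMemcpy helpers are simplified here):

    template<typename T>
    class Matrix {
    public:
        Matrix(int B, int C, int M, int N) : _B(B), _C(C), _M(M), _N(N) {}
        Matrix(int M, int N) : Matrix(1, 1, M, N) {}

        int B() const { return _B; }
        int C() const { return _C; }
        int M() const { return _M; }
        int N() const { return _N; }

        void allocateGPU() {
            checkCUDAError( "Matrix malloc failed", cudaMalloc((void**)&dataOnGPU, sizeof(T) * B() * C() * M() * N()) );
        }

        // Copy host data to the device, allocating on first use:
        void toGPU() {
            if (dataOnGPU == nullptr) allocateGPU();
            cudaMemcpy(dataOnGPU, dataOnCPU, sizeof(T) * B() * C() * M() * N(), cudaMemcpyHostToDevice);
        }

        // Copy the result back to the host:
        void toCPU() {
            cudaMemcpy(dataOnCPU, dataOnGPU, sizeof(T) * B() * C() * M() * N(), cudaMemcpyDeviceToHost);
        }

        T *dataOnGPU = nullptr;
        T *dataOnCPU = nullptr; // allocated and filled beforehand

        // ... plus the convcudnn / convNoTensor member functions below.

    private:
        int _B, _C, _M, _N;
    };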

Well, I do not see std::chrono::high_resolution_clock anywhere in the code you posted. Without knowing what exactly you are measuring, it is hard to help. That is why I asked for the complete code.

This is the function that launches the cuDNN code:

    Matrix<T> convcudnn(Matrix<T> &K) {
        Matrix<T> out(_B, K.B(), _M - K.M() + 1, _N - K.N() + 1);
        out.allocateGPU();

        K.toGPU();
        this->toGPU();

        cuda_conv<T>(*this, K, out);

        out.toCPU();

        return out;
    }

where this is the image and K is the kernel.

I measure like this:

    const auto begin = std::chrono::high_resolution_clock::now();
    Matrix<half> out = IMG.convcudnn(KRNL);
    const auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::high_resolution_clock::now() - begin);
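
The seconds quoted above are just that count converted, e.g.:

    std::cout << "conv took " << elapsed.count() * 1e-9 << " s" << std::endl;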

And this is how I launch my custom kernel function:

    Matrix<T> convNoTensor(Matrix<T> &kernel) {
        const int outM = M() - kernel.M() + 1;
        const int outN = N() - kernel.N() + 1;
        const int threads = outM * outN;
        // Round up so every output element gets a thread:
        const int blocks = (threads + MAX_CUDA_THREAD_COUNT - 1) / MAX_CUDA_THREAD_COUNT;

        Matrix<T> out(outM, outN);
        out.allocateGPU();

        this->toGPU();
        kernel.toGPU();

        std::cout << "Launching " << threads << " threads for convolution." << std::endl;
        convNoTensorKernel<T><<<blocks, MAX_CUDA_THREAD_COUNT>>>(M(), N(), kernel.M(), kernel.N(), dataOnGPU, kernel.dataOnGPU, out.dataOnGPU);

        checkCUDAError( "PeekAtLastError for convNoTensor failed", cudaPeekAtLastError() );
        checkCUDAError( "DeviceSync for convNoTensor failed", cudaDeviceSynchronize() );

        out.toCPU();
        return out;
    }

Measuring is the same; just exchange convcudnn with convNoTensor.
