Unexpected cudnnConvolutionForward performance with varying input channels

Hi all,

I apologize if a similar question has already been asked – a quick search did not yield relevant results, but I may have missed some.

I am benchmarking cudnnConvolutionForward for a single-precision convolution with parameters listed below and with varying number of input channels from 64 to 129 in increments of 1. Benchmarking is performed on a Tesla T4 with other environmental details listed below. The code used for benchmarking is also pasted below.

According to the very helpful deep learning performance guide, I expected to see a gradual increase in average inference time as the number of input channels increased from 64 to 129, with “dips” in inference time at a number of channels divisible by certain powers of two (e.g., 64, 96, 128) due to avoiding quantization effects.

However, I am experiencing the opposite behavior. At batch size 1, I experience large spikes in inference time when the number of input channels is divisible by 16 (64, 80, 96, 112, and 128) and minor spikes at other values divisible by 8 (72, 88, 104, 120). These results are depicted in the attached image.

batch-size-1

Does anyone have any insight into why these apparent performance regressions might be occurring? It seems that there must be a gap in my understanding, as I thought that the values for which I’m experiencing spikes in inference time should be those in which I’d experience dips. Or perhaps I am incorrectly benchmarking performance or using the API (code attached below, for reference).

It may be worth noting that I also experience this behavior (though it is subdued) with slightly larger batch sizes (e.g., 2, 4), but these effects mostly subside at significantly larger batch sizes (e.g., 16, 32).

Additionally, all convolutions report using algorithm CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM after using cudnnFindConvolutionForwardAlgorithm. I have also experienced the same behavior described above when using the heuristic selection (which selects CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM), though the spikes are subdued and the performance for all convolutions is worse than using IMPLICIT_GEMM.

Thank you in advance for any help!

Convolution parameters

  • Batch size: 1
  • Input channels: vary from 64 to 129
  • Output channels: 128
  • Filter shape: 3x3
  • Input shape: 56x56
  • Strides: 2, 2
  • Padding: 1, 1
  • Dilation: 1, 1
  • All operations are performed in single precision

Environment

  • GPU: Tesla T4 (on an AWS g4dn.xlarge instance)
  • Driver version: 440.64.00
  • CUDA version: 10.2
  • cuDNN version: 7.6.5
  • Linux distro: Ubuntu 16.04.6
  • Compiler flags: -arch=sm_75 -std=c++11 -O3

Code

#include <cudnn.h>
#include <cassert>
#include <cstdlib>
#include <iostream>

// Check a cuDNN status code; on failure, print the error string with the
// call site's line number and terminate. Wrapped in do { } while (0) so the
// macro expands to a single statement and is safe in unbraced if/else bodies.
#define checkCUDNN(expression)                               \
  do {                                                       \
    cudnnStatus_t status = (expression);                     \
    if (status != CUDNN_STATUS_SUCCESS) {                    \
      std::cerr << "Error on line " << __LINE__ << ": "      \
                << cudnnGetErrorString(status) << std::endl; \
      std::exit(EXIT_FAILURE);                               \
    }                                                        \
  } while (0)

// Check a CUDA runtime status code; on failure, print the error string plus
// file/line and terminate. Exiting matters: CUDA errors are sticky, so
// continuing after one makes every subsequent call (and the timing result)
// meaningless. do { } while (0) makes the macro a single safe statement.
#define cudaErrCheck(stat) do { cudaErrCheck_((stat), __FILE__, __LINE__); } while (0)
void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
   if (stat != cudaSuccess) {
      fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(stat), file, line);
      std::exit(EXIT_FAILURE);  // was: printed and silently continued
   }
}

// Benchmarks cudnnConvolutionForward for a single-precision 3x3, stride-2,
// pad-1 convolution over an NHWC input of spatial size 56x56 with 128 output
// channels. Batch size and input-channel count come from the command line.
// Picks the fastest algorithm via cudnnFindConvolutionForwardAlgorithm,
// reports it on stderr, then prints the average per-call time (ms) over
// num_iteration timed launches (after num_warmup untimed launches) to stdout.
int main(int argc, const char* argv[]) {
  if (argc != 3) {
    std::cerr << "Usage: " << argv[0] << " batch_size in_channels" << std::endl;
    std::exit(EXIT_FAILURE);
  }

  cudnnHandle_t cudnn;
  checkCUDNN(cudnnCreate(&cudnn));

  // Convolution parameters; only batch size and input channels vary per run.
  int batch_size = std::atoi(argv[1]);
  int in_channels = std::atoi(argv[2]);
  int in_height = 56;
  int in_width = 56;
  int out_channels = 128;
  int kernel_height = 3;
  int kernel_width = 3;
  int stride_height = 2;
  int stride_width = 2;
  int pad_height = 1;
  int pad_width = 1;

  int num_warmup = 100;
  int num_iteration = 10000;

  cudnnTensorDescriptor_t input_descriptor;
  checkCUDNN(cudnnCreateTensorDescriptor(&input_descriptor));
  checkCUDNN(cudnnSetTensor4dDescriptor(input_descriptor,
                                        /*format=*/CUDNN_TENSOR_NHWC,
                                        /*dataType=*/CUDNN_DATA_FLOAT,
                                        batch_size,
                                        in_channels,
                                        in_height,
                                        in_width));

  cudnnFilterDescriptor_t kernel_descriptor;
  checkCUDNN(cudnnCreateFilterDescriptor(&kernel_descriptor));
  checkCUDNN(cudnnSetFilter4dDescriptor(kernel_descriptor,
                                        /*dataType=*/CUDNN_DATA_FLOAT,
                                        /*format=*/CUDNN_TENSOR_NCHW,
                                        out_channels,
                                        in_channels,
                                        kernel_height,
                                        kernel_width));

  cudnnConvolutionDescriptor_t convolution_descriptor;
  checkCUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor));
  // Tensor-op math is requested but has no effect for FP32 data on Turing;
  // kept to match the original benchmark configuration.
  checkCUDNN(cudnnSetConvolutionMathType(convolution_descriptor, CUDNN_TENSOR_OP_MATH));
  checkCUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor,
                                             pad_height,
                                             pad_width,
                                             stride_height,
                                             stride_width,
                                             /*dilation_height=*/1,
                                             /*dilation_width=*/1,
                                             /*mode=*/CUDNN_CROSS_CORRELATION,
                                             /*computeType=*/CUDNN_DATA_FLOAT));

  // Let cuDNN derive the output shape so the output descriptor always
  // matches the convolution parameters above.
  int batch_size_o{0}, channels_o{0}, height_o{0}, width_o{0};
  checkCUDNN(cudnnGetConvolution2dForwardOutputDim(convolution_descriptor,
                                                   input_descriptor,
                                                   kernel_descriptor,
                                                   &batch_size_o,
                                                   &channels_o,
                                                   &height_o,
                                                   &width_o));

  cudnnTensorDescriptor_t output_descriptor;
  checkCUDNN(cudnnCreateTensorDescriptor(&output_descriptor));
  checkCUDNN(cudnnSetTensor4dDescriptor(output_descriptor,
                                        /*format=*/CUDNN_TENSOR_NHWC,
                                        /*dataType=*/CUDNN_DATA_FLOAT,
                                        batch_size_o,
                                        channels_o,
                                        height_o,
                                        width_o));

  // Exhaustively benchmark candidate algorithms; results come back sorted
  // fastest-first. constexpr bound keeps the array standard C++ (the
  // original used a runtime-sized VLA, a non-standard extension), and {}
  // value-initialization replaces the memset.
  constexpr int requested_algo_count = 100;
  int returned_algo_count{0};
  cudnnConvolutionFwdAlgoPerf_t perf_results[requested_algo_count] = {};
  checkCUDNN(
      cudnnFindConvolutionForwardAlgorithm(cudnn,
                                           input_descriptor,
                                           kernel_descriptor,
                                           convolution_descriptor,
                                           output_descriptor,
                                           requested_algo_count,
                                           &returned_algo_count,
                                           perf_results));
  assert(returned_algo_count > 0);
  cudnnConvolutionFwdAlgo_t convolution_algorithm = perf_results[0].algo;
  // .memory is a size_t; storing it in an int could truncate large workspaces.
  size_t workspace_bytes = perf_results[0].memory;
  std::cerr << "Convolution algorithm: " << convolution_algorithm << std::endl;

  void* d_workspace{nullptr};
  cudaErrCheck(cudaMalloc(&d_workspace, workspace_bytes));

  // Element counts and byte sizes in size_t to avoid int overflow for large
  // batches/shapes.
  size_t in_elems = static_cast<size_t>(batch_size) * in_channels * in_height * in_width;
  size_t out_elems = static_cast<size_t>(batch_size_o) * channels_o * height_o * width_o;
  size_t in_bytes = in_elems * sizeof(float);
  size_t out_bytes = out_elems * sizeof(float);

  float* d_input{nullptr};
  cudaErrCheck(cudaMalloc(&d_input, in_bytes));
  // cudaMemset fills byte-wise: every float becomes 0x02020202 (a tiny
  // positive value). Fine for timing; not a meaningful numeric input.
  cudaErrCheck(cudaMemset(d_input, 2, in_bytes));

  float* d_output{nullptr};
  cudaErrCheck(cudaMalloc(&d_output, out_bytes));
  cudaErrCheck(cudaMemset(d_output, 0, out_bytes));

  float* d_kernel{nullptr};
  size_t kernel_bytes =
      static_cast<size_t>(kernel_height) * kernel_width * out_channels * in_channels * sizeof(float);
  cudaErrCheck(cudaMalloc(&d_kernel, kernel_bytes));
  cudaErrCheck(cudaMemset(d_kernel, 1, kernel_bytes));

  // Warm-up launches: excluded from timing so one-time costs (kernel module
  // load, clock ramp-up) do not skew the average.
  const float alpha = 1.0f, beta = 0.0f;
  for (int i = 0; i < num_warmup; i++) {
    checkCUDNN(cudnnConvolutionForward(cudnn,
                                       &alpha,
                                       input_descriptor,
                                       d_input,
                                       kernel_descriptor,
                                       d_kernel,
                                       convolution_descriptor,
                                       convolution_algorithm,
                                       d_workspace,
                                       workspace_bytes,
                                       &beta,
                                       output_descriptor,
                                       d_output));
  }

  // Time num_iteration back-to-back launches with CUDA events; the
  // cudaEventSynchronize before reading the elapsed time makes the
  // measurement valid despite asynchronous launches.
  cudaEvent_t start;
  cudaEvent_t stop;
  cudaErrCheck(cudaEventCreate(&start));
  cudaErrCheck(cudaEventCreate(&stop));
  cudaErrCheck(cudaEventRecord(start));
  for (int i = 0; i < num_iteration; i++) {
    checkCUDNN(cudnnConvolutionForward(cudnn,
                                       &alpha,
                                       input_descriptor,
                                       d_input,
                                       kernel_descriptor,
                                       d_kernel,
                                       convolution_descriptor,
                                       convolution_algorithm,
                                       d_workspace,
                                       workspace_bytes,
                                       &beta,
                                       output_descriptor,
                                       d_output));
  }
  cudaErrCheck(cudaEventRecord(stop));
  float duration_ms = 0.0f;
  cudaErrCheck(cudaEventSynchronize(stop));
  cudaErrCheck(cudaEventElapsedTime(&duration_ms, start, stop));
  std::cout << duration_ms / num_iteration << std::endl;

  // Copy the output back so the convolutions cannot be optimized away.
  // Allocate out_elems floats (the original allocated out_bytes floats —
  // 4x more memory than needed).
  float* h_output = new float[out_elems];
  cudaErrCheck(cudaMemcpy(h_output, d_output, out_bytes, cudaMemcpyDeviceToHost));

  delete[] h_output;
  cudaErrCheck(cudaEventDestroy(start));
  cudaErrCheck(cudaEventDestroy(stop));
  cudaErrCheck(cudaFree(d_kernel));
  cudaErrCheck(cudaFree(d_input));
  cudaErrCheck(cudaFree(d_output));
  cudaErrCheck(cudaFree(d_workspace));

  cudnnDestroyTensorDescriptor(input_descriptor);
  cudnnDestroyTensorDescriptor(output_descriptor);
  cudnnDestroyFilterDescriptor(kernel_descriptor);
  cudnnDestroyConvolutionDescriptor(convolution_descriptor);
  cudnnDestroy(cudnn);
  return 0;
}

Thanks for sharing the detailed info along with sample script to reproduce the issue.
We will look into it and update you accordingly.

Thanks

Thanks, SunilJB.

Just checking in whether there’s any update on this.

Thanks!