Reduced cuBLAS performance for a particular problem size?

Hi,
I profiled a cuBLAS call to measure its performance on the Tensor Cores, with FP16 operands, FP16 output and FP32 accumulation, for one particular problem size: 1536 (m) x 1024 (n) x 1024 (k). On CUDA 10.2, with the GPU clock locked at its maximum (1965 MHz), I measured around 46–48 TFLOPS. With the same configuration on CUDA 11.1 (GPU clocks still locked), I measured only around 33–35 TFLOPS. I was surprised to see performance drop on a newer CUDA version, so perhaps I am missing something. It would be great if someone with the same configuration as mine (listed below) could try it, or share their opinion.

GPU:
GeForce RTX 2080 Ti
Machine:
Intel(R) Xeon(R) Silver 4110 CPU
OS:
CentOS 8 (kernel 4.18.0-193.6.3.el8_2.x86_64)

Test 1: CUDA 10.2, driver 440.33.01
Test 2: CUDA 11.1, driver 455.23.05

Code:

#include <climits>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>

#include "curand.h"
#include "cuda_fp16.h"
#include "common.h"

using namespace std;

// Element-wise float -> half narrowing, one element per thread.
// Thread (blockIdx.x * blockDim.x + threadIdx.x) reads in[i], rounds it to half precision
// with __float2half, and stores it in out[i]. Launch a 1-D grid with at least n threads in
// total; surplus threads in the last block return without touching memory.
__global__ void convertFp32ToFp16 (half *out, float *in, int n) {
    const int element = threadIdx.x + blockIdx.x * blockDim.x;
    if (element >= n) {
        return;
    }
    out[element] = __float2half(in[element]);
}

// Element-wise half -> float widening, one element per thread.
// Thread (blockIdx.x * blockDim.x + threadIdx.x) reads in[i], widens it with __half2float,
// and stores it in out[i]. Launch a 1-D grid with at least n threads in total; surplus
// threads in the last block return without touching memory.
__global__ void convertFp16ToFp32 (float *out, half *in, int n) {
    const int element = threadIdx.x + blockIdx.x * blockDim.x;
    if (element >= n) {
        return;
    }
    out[element] = __half2float(in[element]);
}

// Print a row-major host matrix to std::cout.
// Each row goes on its own line, with every value followed by a single space; one extra
// empty line is written after the last row.
//   A         - host pointer to nr_rows_A * nr_cols_A floats, stored row by row
//   nr_rows_A - number of rows to print
//   nr_cols_A - number of values per row
void print_matrix(float *A, int nr_rows_A, int nr_cols_A) {
    const float *cell = A;  // walks the matrix in storage order
    for (int row = 0; row < nr_rows_A; ++row) {
        for (int col = 0; col < nr_cols_A; ++col, ++cell) {
            std::cout << *cell << " ";
        }
        std::cout << std::endl;
    }
    std::cout << std::endl;
}

// Fill the array with random numbers on GPU
// Fill a device buffer of nr_rows_A * nr_cols_A floats with uniform pseudo-random values
// produced by cuRAND's curandGenerateUniform.
//   A         - device pointer with room for at least nr_rows_A * nr_cols_A floats; the
//               buffer is filled as one flat array, so row/column layout does not matter
//   nr_rows_A - number of rows    (nothing is written if <= 0)
//   nr_cols_A - number of columns (nothing is written if <= 0)
// The seed is taken from clock(), so the data differs between runs; use a fixed seed if
// reproducible inputs are needed. cuRAND failures are reported on std::cerr and leave the
// buffer (partially) unfilled. The generator is released before returning; the original
// code leaked one generator per call.
void GPU_fill_rand(float *A, int nr_rows_A, int nr_cols_A) {
    if (nr_rows_A <= 0 || nr_cols_A <= 0) {
        return;  // nothing to fill (and avoids a bogus count from negative sizes)
    }
    // Compute the element count in size_t so large matrices cannot overflow int.
    const size_t count = (size_t) nr_rows_A * (size_t) nr_cols_A;

    // Create a pseudo-random number generator
    curandGenerator_t prng;
    curandStatus_t status = curandCreateGenerator(&prng, CURAND_RNG_PSEUDO_DEFAULT);
    if (status != CURAND_STATUS_SUCCESS) {
        std::cerr << "GPU_fill_rand: curandCreateGenerator failed (status "
                  << (int) status << ")" << std::endl;
        return;
    }

    // Seed from the process clock, then fill the device array.
    status = curandSetPseudoRandomGeneratorSeed(prng, (unsigned long long) clock());
    if (status == CURAND_STATUS_SUCCESS) {
        status = curandGenerateUniform(prng, A, count);
    }
    if (status != CURAND_STATUS_SUCCESS) {
        std::cerr << "GPU_fill_rand: random generation failed (status "
                  << (int) status << ")" << std::endl;
    }

    // Always release the generator, including on the error path.
    curandDestroyGenerator(prng);
}

// Issue one FP16 Tensor Core GEMM through cublasGemmEx (FP16 in/out, FP32 accumulation).
//
// Row-major view used by the caller:
//   A : m x k halves, row-major
//   B : k * n halves, read as a row-major n x k matrix Bt
//   C : m x n halves, row-major;  C = A * Bt^T
// cuBLAS is column-major, so a row-major m x n C is the column-major n x m matrix C^T.
// The call therefore computes C^T = Bt * A^T, with the operands and the m/n extents swapped.
// alpha/beta must point to float, because computeType CUDA_R_32F makes cuBLAS read the
// scaling factors as float.
static void launchTensorOpGemm(cublasHandle_t handle, const __half *A, const __half *B,
                               __half *C, int m, int k, int n,
                               const float *alpha, const float *beta) {
    // Leading dimensions of the column-major views cuBLAS sees.
    const int lda = k;  // A (row-major m x k) == column-major k x m; the original used m,
                        // which reads out of bounds for m > k and is rejected for m < k.
    const int ldb = k;  // B buffer == column-major k x n, transposed by CUBLAS_OP_T
    const int ldc = n;  // C (row-major m x n) == column-major n x m
    check_cuda_error(cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                                  /*number of rows of op(first operand) and of C^T*/ n,
                                  /*number of columns of op(second operand) and of C^T*/ m,
                                  /*shared inner dimension*/ k,
                                  alpha,
                                  B, CUDA_R_16F, ldb,
                                  A, CUDA_R_16F, lda,
                                  beta,
                                  C, CUDA_R_16F, ldc,
                                  CUDA_R_32F,  // FP32 accumulation
                                  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

// Benchmark the FP16 GEMM C = A * Bt^T (layouts documented on launchTensorOpGemm) and print
//   "m, n, k, <TFLOPS> FLOPs: <floating-point operations per GEMM>".
//   A, B, C - device buffers of m*k, k*n and m*n halves respectively
//   m, k, n - GEMM extents
//   iter    - number of timed GEMMs (must be > 0; otherwise an error is printed and
//             nothing runs)
// One untimed warm-up GEMM runs first. All `iter` GEMMs are then enqueued back-to-back
// between a single pair of CUDA events on the default stream, and the average time per
// GEMM is used.
void gpu_blas_mmul(__half *A, __half *B, __half *C, int m, int k, int n, int iter) {
    if (iter <= 0) {
        // The average time below would otherwise divide by zero.
        std::cerr << "gpu_blas_mmul: iteration count must be positive (got "
                  << iter << ")" << std::endl;
        return;
    }

    // Must be float, not __half. With computeType CUDA_R_32F, cuBLAS dereferences these as
    // 4-byte floats, so half-precision values would be read as garbage scale factors.
    const float alpha = 1.0f;
    const float beta = 0.0f;

    // Create a handle for cuBLAS and allow it to use Tensor Cores.
    cublasHandle_t handle;
    check_cuda_error(cublasCreate(&handle));
    check_cuda_error(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

    // Warm-up run (not timed).
    launchTensorOpGemm(handle, A, B, C, m, k, n, &alpha, &beta);
    check_cuda_error(cudaDeviceSynchronize());

    cudaEvent_t start, stop;
    check_cuda_error(cudaEventCreate(&start));
    check_cuda_error(cudaEventCreate(&stop));

    // Time the whole batch with one event pair. The original recorded and synchronized
    // around every call. That places the host-side work of each cublasGemmEx call inside
    // the measured interval, because the start event can complete before the kernel has
    // even been enqueued.
    check_cuda_error(cudaEventRecord(start, 0));
    for (int i = 0; i < iter; ++i) {
        launchTensorOpGemm(handle, A, B, C, m, k, n, &alpha, &beta);
    }
    check_cuda_error(cudaEventRecord(stop, 0));
    check_cuda_error(cudaEventSynchronize(stop));

    float matmulTime = 0.0f;  // total milliseconds for all `iter` GEMMs
    check_cuda_error(cudaEventElapsedTime(&matmulTime, start, stop));
    check_cuda_error(cudaEventDestroy(start));
    check_cuda_error(cudaEventDestroy(stop));

    const double flopsPerMatrixMul = 2.0 * (double) m * (double) n * (double) k;
    const double secondsPerMatrixMul = (double) matmulTime / iter / 1000.0;
    const double teraFlops = flopsPerMatrixMul * 1.0e-12 / secondsPerMatrixMul;
    // Plain ASCII quotes. The original's curly quotes merged n, k and the TFLOPS value into
    // one string literal, so those values were never printed.
    std::cout << m << ", " << n << ", " << k << ", " << teraFlops
              << " FLOPs: " << flopsPerMatrixMul << std::endl;

    // Destroy the handle
    check_cuda_error(cublasDestroy(handle));
}

// Allocate a rows x cols FP16 device matrix filled with uniform random values.
// GPU_fill_rand generates the values in a temporary FP32 buffer, and convertFp32ToFp16
// narrows them to half. The FP32 staging buffer is freed before returning. The caller owns
// the returned pointer and must cudaFree it. Assumes rows * cols fits in an int (main
// checks this before calling).
static __half *makeRandomHalfMatrix(int rows, int cols) {
    const int count = rows * cols;
    const int threadsPerBlock = 256;
    // Ceil-divide in size_t so count + threadsPerBlock - 1 cannot overflow int.
    const unsigned int blocks =
        (unsigned int) (((size_t) count + threadsPerBlock - 1) / threadsPerBlock);

    __half *d_half = NULL;
    float *d_float = NULL;
    check_cuda_error(cudaMalloc(&d_half, (size_t) count * sizeof(__half)));
    check_cuda_error(cudaMalloc(&d_float, (size_t) count * sizeof(float)));

    GPU_fill_rand(d_float, rows, cols);
    convertFp32ToFp16 <<< blocks, threadsPerBlock >>> (d_half, d_float, count);
    check_cuda_error(cudaGetLastError());       // launch-configuration errors
    check_cuda_error(cudaDeviceSynchronize());  // execution errors; conversion finished

    check_cuda_error(cudaFree(d_float));
    return d_half;
}

// Benchmark driver.
// Usage: <program> <m> <k> <n> <iterations>
// Builds random FP16 device matrices A (m*k elements) and B (k*n elements), allocates
// C (m*n elements), and passes them to gpu_blas_mmul with the requested number of timed
// iterations. Prints a message and returns 1 when arguments are missing or invalid;
// returns 0 otherwise.
int main(int argc, char * argv[]) {
    if (argc != 5) {
        printf("4 args required, aborting!.\n");
        printf("usage: %s <m> <k> <n> <iterations>\n", argv[0]);
        return 1;  // a failed run must not report success (was 0)
    }

    const int m = atoi(argv[1]);
    const int k = atoi(argv[2]);
    const int n = atoi(argv[3]);
    const int niter = atoi(argv[4]);

    // atoi yields 0 for non-numeric text, so this also rejects malformed arguments.
    if (m <= 0 || k <= 0 || n <= 0 || niter <= 0) {
        printf("m, k, n and iterations must be positive integers, aborting!\n");
        return 1;
    }
    // The conversion kernel indexes elements with int, so each input matrix must have at
    // most INT_MAX elements.
    if ((size_t) m * (size_t) k > (size_t) INT_MAX ||
        (size_t) k * (size_t) n > (size_t) INT_MAX) {
        printf("input matrix has more than %d elements, aborting!\n", INT_MAX);
        return 1;
    }

    __half *d_A = makeRandomHalfMatrix(m, k);  // A: m rows x k columns
    __half *d_B = makeRandomHalfMatrix(k, n);  // B: k rows x n columns
    __half *d_C = NULL;                        // C: m x n, written by the GEMM
    check_cuda_error(cudaMalloc(&d_C, (size_t) m * (size_t) n * sizeof(__half)));

    // The sixth argument is n (the column count of B). The original passed B's row count
    // (= k), which gave the intended shape only when n == k.
    gpu_blas_mmul(d_A, d_B, d_C, m, k, n, niter);

    check_cuda_error(cudaFree(d_A));
    check_cuda_error(cudaFree(d_B));
    check_cuda_error(cudaFree(d_C));

    return 0;
}