Is it correct that my Pascal card is calling maxwell_sgemm kernels through cublas? And if so, why is cublas unusably slow for me?

I’ve had problems with poor cublas gemm performance for a long time. I was hoping that a newer CUDA toolkit version would solve my issues, but now that I have upgraded to v9.2 I can see that the performance issues are still very much here, so I’m consulting this forum again to see if anyone has a solution.

My problem is that cublas functions such as dgemm and sgemm are so slow that they are practically unusable for me: even poorly written custom kernels often outperform them on my system. Here is an example that computes a large batch of small matrix-matrix multiplications, once with the cublasSgemmStridedBatched function and once with a naïve, unoptimized, hand-written kernel:

#include "mex.h"
#include "cublas_v2.h"
#include <time.h>

// Kernel code for naive matrix multiplication 
__global__ void MatrixMultiplication(const float* A, const float* B, float* C, const int m, const int n, const int q, const int k){
	float e;
	for (int page = blockIdx.z * blockDim.z + threadIdx.z; page < k; page += blockDim.z * gridDim.z) {
		// Loop rows of A
		for (int rows = blockIdx.x * blockDim.x + threadIdx.x; rows < m; rows += blockDim.x * gridDim.x) {
			// Loop columns of B.
			for (int cols = blockIdx.y * blockDim.y + threadIdx.y; cols < q; cols += blockDim.y * gridDim.y) {
				// Initialize temporary variable e to hold values of current row/col run
				e = 0;
				// Loop columns of A/rows of B
				for (int steps = 0; steps < n; steps++) {
					e = e + A[rows + steps*m + m*n*page] * B[steps + cols*n + n*q*page];
				}
				C[rows + cols*m + m*q*page] = e;
			}
		}
	}
}

/* The MEX gateway function */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {

// Get host input matrices A and B.
const float *A, *B;
A = (float*)mxGetData(prhs[0]);   // inputs are single-precision arrays, so use mxGetData rather than mxGetPr
B = (float*)mxGetData(prhs[1]);

// Get size of inputs.
size_t m, n, q, k;
clock_t t1, t2;
const mwSize *Adims, *Bdims;
Adims = mxGetDimensions(prhs[0]);
Bdims = mxGetDimensions(prhs[1]);
m = Adims[0];
n = Adims[1];
k = Adims[2];
q = Bdims[1];

// Allocate device memory for A, B and C.
float *dA, *dB, *dCcublas, *dCmanual;
cudaMalloc(&dA, sizeof(float) * m * n * k);
cudaMalloc(&dB, sizeof(float) * n * q * k);
cudaMalloc(&dCcublas, sizeof(float) * m * q * k);
cudaMalloc(&dCmanual, sizeof(float) * m * q * k);

// Copy A & B to device.
cudaMemcpy(dA, A, m * n * k * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(dB, B, n * q * k * sizeof(float), cudaMemcpyHostToDevice);

// Cublas handle
cublasHandle_t h;
cublasCreate(&h);

// Calculate C = A*B using cublas.
const float One = 1.0;
const float Zero = 0.0;

t1 = clock();
cublasSgemmStridedBatched(h, CUBLAS_OP_N, CUBLAS_OP_N, m, q, n, &One, dA, m, m * n, dB, n, n * q, &Zero, dCcublas, m, m*q, k);
cudaDeviceSynchronize();
t2 = clock();
printf("cublasSgemmStridedBatched took: %.3f ms\n", (double)(t2 - t1) / CLOCKS_PER_SEC * 1000);

// Calculate C = A*B using manual kernel.
int xnTPB = 32; int ynTPB = 2; int znTPB = 16;
int xnumBlocks = (m + xnTPB - 1) / xnTPB;
int ynumBlocks = (q + ynTPB - 1) / ynTPB;
int znumBlocks = (k + znTPB - 1) / znTPB;
dim3 Nblocks(xnumBlocks, ynumBlocks, znumBlocks);
dim3 nTPB(xnTPB, ynTPB, znTPB);

t1 = clock();
MatrixMultiplication<<<Nblocks, nTPB>>>(dA, dB, dCmanual, m, n, q, k);
cudaDeviceSynchronize();
t2 = clock();
printf("Manual matmult kernel took: %.3f ms\n", (double)(t2 - t1) / CLOCKS_PER_SEC * 1000);

float *HostOutputCcublas, *HostOutputCmanual;
mwSize dims[3] = { m, q, k };
plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL);
plhs[1] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL);
HostOutputCcublas = (float*)mxGetData(plhs[0]);
HostOutputCmanual = (float*)mxGetData(plhs[1]);
cudaDeviceSynchronize();
cudaMemcpy(HostOutputCcublas, dCcublas, m * q * k * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(HostOutputCmanual, dCmanual, m * q * k * sizeof(float), cudaMemcpyDeviceToHost);

// Free memory & destroy cublas handle.
cudaFree(dA);
cudaFree(dB);
cudaFree(dCcublas);
cudaFree(dCmanual);
cublasDestroy(h);
}

I’m compiling and running the function from MATLAB with the following code:

mexcuda WhatsWrongWithMyCublas.cu -lcublas
m = 150;
n = 256;
q = 2;
k = 10000;
A = rand(m,n,k,'single');
B = rand(n,q,k,'single');
[Cblas,Cmanual] = WhatsWrongWithMyCublas(A,B);

The output matrices Cblas and Cmanual are identical, and the printed output is:

cublasSgemmStridedBatched took: 7.000 ms
Manual matmult kernel took: 5.000 ms

Profiling the kernel launches with Nsight confirms what clock() reports:

Function Name: maxwell_sgemm_128_64_nn Grid Dimensions: {2,1,10000} Block Dimensions: {128,1,1} Duration: 7200.987 µs Occupancy: 25 %
Function Name: MatrixMultiplication    Grid Dimensions: {5,1,625}   Block Dimensions: {32,2,16} Duration: 5200.831 µs Occupancy: 100 %

So how can this possibly be? How can a highly optimized batched gemm library be slower than a naïve looped matrix multiplication at precisely what it is designed to do—thousands of equally sized small matrix multiplications?

One thing I noticed is that cublas is using the “maxwell_sgemm_128_64_nn” kernel, but I’m using a Pascal-architecture card (GTX 1080 Ti). Is it supposed to call a cublas kernel with “maxwell” in its name, or is there a corresponding “pascal_sgemm_128_64_nn” kernel that my system should be using?

This exact problem persists across multiple CUDA toolkit versions, multiple GPU driver versions, and multiple different GPUs I’ve tested (all Pascal series).

Is there anyone else who has this problem or knows what the reason for it could be? Obviously there is something wrong here…

Yes, it’s expected that cublas operations on Pascal GPUs may run cublas kernels whose internal names contain “maxwell”. If you compare a cc5.2 SM with a cc6.1 SM, I think you will find there is not much difference between them with respect to optimizing for this.
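
If you want to double-check what the runtime reports for your card, a quick sketch like the following (my own, not taken from your code) prints the compute capability; a GTX 1080 Ti should report 6.1 (Pascal), even though the cublas kernels dispatched for it carry “maxwell” in their names:

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    // Query device 0 and print its name and compute capability.
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    printf("%s: compute capability %d.%d\n", prop.name, prop.major, prop.minor);
    return 0;
}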

I agree that for the specific choices for m,n,q you have chosen, your kernel outpaces the cublas call (by ~2x factor, or perhaps less).

Try changing q from 2 to 200. When I do that I witness the cublas call being about 3x faster than your kernel.

My suggestion would be to file this as a performance bug/RFE at developer.nvidia.com

I suspect cublas is simply not well optimized for the very narrow q dimension.

If you profile your code with the flop_count_sp metric, you will find the cublas kernel using about 8x as many flops as your kernel. In the q=200 scenario, the profiler shows the cublas kernel using 2x as many flops as your kernel.
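
For reference, that metric can be collected with the command-line profiler, e.g. (assuming the t1403 executable built below):

$ nvprof --metrics flop_count_sp ./t1403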

A non-MATLAB version of your code, for the q=200 case, on CUDA 9.2 / Tesla V100 / CentOS 7:

$ cat t1403.cu
#include <stdio.h>
#include <cublas_v2.h>
#include <time.h>
#include <sys/time.h>
#define USECPSEC 1000000ULL

unsigned long long dtime_usec(unsigned long long start){

  timeval tv;
  gettimeofday(&tv, 0);
  return ((tv.tv_sec*USECPSEC)+tv.tv_usec)-start;
}

// Kernel code for naive matrix multiplication
__global__ void MatrixMultiplication(const float* A, const float* B, float* C, const int m, const int n, const int q, const int k){
        float e;
        for (int page = blockIdx.z * blockDim.z + threadIdx.z; page < k; page += blockDim.z * gridDim.z) {
                // Loop rows of A
                for (int rows = blockIdx.x * blockDim.x + threadIdx.x; rows < m; rows += blockDim.x * gridDim.x) {
                        // Loop columns of B.
                        for (int cols = blockIdx.y * blockDim.y + threadIdx.y; cols < q; cols += blockDim.y * gridDim.y) {
                                // Initialize temporary variable e to hold values of current row/col run
                                e = 0;
                                // Loop columns of A/rows of B
                                for (int steps = 0; steps < n; steps++) {
                                        e = e + A[rows + steps*m + m*n*page] * B[steps + cols*n + n*q*page];
                                }
                                C[rows + cols*m + m*q*page] = e;
                        }
                }
        }
}

void test() {
size_t m = 150;
size_t n = 256;
size_t q = 200;
size_t k = 10000;
// Get host input matrices A and B.
float *A, *B;
A = (float*)malloc(m*n*k*sizeof(A[0]));
B = (float*)malloc(n*q*k*sizeof(B[0]));

// Get size of inputs.
unsigned long long t1, t2;

// Allocate device memory for A, B and C.
float *dA, *dB, *dCcublas, *dCmanual;
cudaMalloc(&dA, sizeof(float) * m * n * k);
cudaMalloc(&dB, sizeof(float) * n * q * k);
cudaMalloc(&dCcublas, sizeof(float) * m * q * k);
cudaMalloc(&dCmanual, sizeof(float) * m * q * k);
for (size_t i = 0; i < m*n*k; i++) A[i] = 1.0f;
for (size_t i = 0; i < n*q*k; i++) B[i] = 1.0f;
// Copy A & B to device.
cudaMemcpy(dA, A, m * n * k * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(dB, B, n * q * k * sizeof(float), cudaMemcpyHostToDevice);
cudaMemset(dCcublas, 0, sizeof(float)*m*q*k);
cudaMemset(dCmanual, 0, sizeof(float)*m*q*k);
// Cublas handle
cublasHandle_t h;
cublasCreate(&h);

// Calculate C = A*B using cublas.
const float One = 1.0;
const float Zero = 0.0;

t1 = dtime_usec(0);
cublasSgemmStridedBatched(h, CUBLAS_OP_N, CUBLAS_OP_N, m, q, n, &One, dA, m, m * n, dB, n, n * q, &Zero, dCcublas, m, m*q, k);
cudaDeviceSynchronize();
t2 = dtime_usec(t1);
printf("cublasSgemmStridedBatched took: %.3f ms\n", (t2 / (float)USECPSEC) * 1000.0f);

// Calculate C = A*B using manual kernel.
int xnTPB = 32; int ynTPB = 2; int znTPB = 16;
int xnumBlocks = (m + xnTPB - 1) / xnTPB;
int ynumBlocks = (q + ynTPB - 1) / ynTPB;
int znumBlocks = (k + znTPB - 1) / znTPB;
dim3 Nblocks(xnumBlocks, ynumBlocks, znumBlocks);
dim3 nTPB(xnTPB, ynTPB, znTPB);

t1 = dtime_usec(0);
MatrixMultiplication<<<Nblocks, nTPB>>>(dA, dB, dCmanual, m, n, q, k);
cudaDeviceSynchronize();
t2 = dtime_usec(t1);
printf("Manual matmult kernel took: %.3f ms\n", (t2 / (float)USECPSEC)* 1000);

float *HostOutputCcublas, *HostOutputCmanual;
HostOutputCcublas = (float*)malloc(m*q*k*sizeof(float));
HostOutputCmanual = (float*)malloc(m*q*k*sizeof(float));
cudaMemcpy(HostOutputCcublas, dCcublas, m * q * k * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(HostOutputCmanual, dCmanual, m * q * k * sizeof(float), cudaMemcpyDeviceToHost);
for (size_t i = 0; i < m*q*k; i++) if (HostOutputCcublas[i] != HostOutputCmanual[i]) {printf("results mismatch at %lu, cublas: %f, kernel: %f\n", i, HostOutputCcublas[i], HostOutputCmanual[i]); return;}
// Free memory & destroy cublas handle.
printf("result: %f\n", HostOutputCcublas[0]);
cudaFree(dA);
cudaFree(dB);
cudaFree(dCcublas);
cudaFree(dCmanual);
cublasDestroy(h);
}

int main(){

  test();
  return 0;
}

$ nvcc t1403.cu -o t1403 -lcublas
$ ./t1403
cublasSgemmStridedBatched took: 27.222 ms
Manual matmult kernel took: 93.264 ms
result: 256.000000
$

For me the poor cublas performance is not limited to odd, narrow matrix sizes.
I’ve now written an example in which every dimension is n: a single n×n matrix A is multiplied by a batch of n matrices B of size n×n (A is reused for every page of the batch via a stride of 0). In other words, the matrices are all square and should be ideal for batched gemm.

#include "mex.h"
#include "cublas_v2.h"
#include <time.h>

// Kernel code for naive matrix multiplication 
__global__ void MatrixMultiplication(const double* __restrict__ A, const double* __restrict__ B, double* __restrict__ C, const int n) {
	double e;
	for (int page = blockIdx.z * blockDim.z + threadIdx.z; page < n; page += blockDim.z * gridDim.z) {
		// Loop rows of A
		for (int rows = blockIdx.x * blockDim.x + threadIdx.x; rows < n; rows += blockDim.x * gridDim.x) {
			// Loop columns of B.
			for (int cols = blockIdx.y * blockDim.y + threadIdx.y; cols < n; cols += blockDim.y * gridDim.y) {
				// Initialize temporary variable e to hold values of current row/col run
				e = 0;
				// Loop columns of A/rows of B
				for (int steps = 0; steps < n; steps++) {
					e += A[rows + steps*n] * B[steps + cols*n + n*n*page];
				}
				C[rows + cols*n + n*n*page] = e;
			}
		}
	}
}

/* The MEX gateway function */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {

	// Get host input matrices A and B.
	const double *A, *B;
	A = mxGetPr(prhs[0]);
	B = mxGetPr(prhs[1]);

	// Get size of inputs.
	size_t n;
	clock_t t1, t2;
	n = mxGetN(prhs[0]);

	// Allocate device memory for A, B and C.
	double *dA, *dB, *dCcublas, *dCmanual;
	cudaMalloc(&dA, sizeof(double) * n * n);
	cudaMalloc(&dB, sizeof(double) * n * n * n);

	cudaMalloc(&dCcublas, sizeof(double) * n * n * n);
	cudaMalloc(&dCmanual, sizeof(double) * n * n * n);

	// Copy A & B to device.
	cudaMemcpy(dA, A, n * n * sizeof(double), cudaMemcpyHostToDevice);
	cudaMemcpy(dB, B, n * n * n * sizeof(double), cudaMemcpyHostToDevice);

	// Cublas handle
	cublasHandle_t h;
	cublasCreate(&h);

	// Calculate C = A*B using cublas.
	const double One = 1.0;
	const double Zero = 0.0;

	t1 = clock();
	cublasDgemmStridedBatched(h, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &One, dA, n, 0, dB, n, n*n, &Zero, dCcublas, n, n*n, n);
	cudaDeviceSynchronize();
	t2 = clock();
	printf("cublasDgemmStridedBatched took: %.3f ms\n", (double)(t2 - t1) / CLOCKS_PER_SEC * 1000);

	// Calculate C = A*B using manual kernel.
	int xnTPB = 8; int ynTPB = 8; int znTPB = 8;
	int xnumBlocks = (n + xnTPB - 1) / xnTPB;
	int ynumBlocks = (n + ynTPB - 1) / ynTPB;
	int znumBlocks = (n + znTPB - 1) / znTPB;
	dim3 Nblocks(xnumBlocks, ynumBlocks, znumBlocks);
	dim3 nTPB(xnTPB, ynTPB, znTPB);

	t1 = clock();
	MatrixMultiplication<<<Nblocks, nTPB>>>(dA, dB, dCmanual, n);
	cudaDeviceSynchronize();
	t2 = clock();
	printf("Manual matmult kernel took: %.3f ms\n", (double)(t2 - t1) / CLOCKS_PER_SEC * 1000);

	double *HostOutputCcublas, *HostOutputCmanual;
	mwSize dims[3] = { n, n, n };
	plhs[0] = mxCreateNumericArray(3, dims, mxDOUBLE_CLASS, mxREAL);
	plhs[1] = mxCreateNumericArray(3, dims, mxDOUBLE_CLASS, mxREAL);
	HostOutputCcublas = (double*)mxGetPr(plhs[0]);
	HostOutputCmanual = (double*)mxGetPr(plhs[1]);
	cudaDeviceSynchronize();
	cudaMemcpy(HostOutputCcublas, dCcublas, n * n * n * sizeof(double), cudaMemcpyDeviceToHost);
	cudaMemcpy(HostOutputCmanual, dCmanual, n * n * n * sizeof(double), cudaMemcpyDeviceToHost);

	// Free memory & destroy cublas handle.
	cudaFree(dA);
	cudaFree(dB);
	cudaFree(dCcublas);
	cudaFree(dCmanual);
	cublasDestroy(h);
}

With n = 200, cublasDgemmStridedBatched takes 16 ms for me, while my manual kernel takes 9 ms.
And the profiler tells me that occupancy during the dgemm kernel is only 12.5 %…

I don’t think my suggestion would change in any way. File a performance RFE/bug.

It’s inconvenient for me to keep translating your examples from a MATLAB-callable routine to a C-callable routine, so I’ll otherwise skip trying to analyze your latest case.

Your GPU has relatively low double-precision (DP) throughput (its DP throughput is 1/32 of its SP throughput). On the surface this shouldn’t make a difference, since both of your tests use double precision, but it’s possible that cublas dgemm is simply not well optimized for that case (i.e. for running on a cc6.1 device). That’s just a guess, though. Maybe if you converted everything to float the relative results would still be the same.
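
If you want to try that, a minimal sketch of the single-precision analogue of your square-batch call (my own standalone code, not your MEX routine, and untimed) could look like this; it keeps the same layout, with strideA = 0 so the single A matrix is reused for every page of B:

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    const int n = 200;                              // same square size as the n = 200 test
    float *dA, *dB, *dC;
    cudaMalloc(&dA, sizeof(float) * n * n);         // one n x n matrix A, shared by all pages
    cudaMalloc(&dB, sizeof(float) * n * n * n);     // n pages of n x n B
    cudaMalloc(&dC, sizeof(float) * n * n * n);     // n pages of n x n C
    cudaMemset(dA, 0, sizeof(float) * n * n);
    cudaMemset(dB, 0, sizeof(float) * n * n * n);

    cublasHandle_t h;
    cublasCreate(&h);
    const float One = 1.0f, Zero = 0.0f;

    // Same call shape as the double-precision version: strideA = 0 broadcasts A to every page.
    cublasSgemmStridedBatched(h, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
                              &One,  dA, n, 0,
                                     dB, n, (long long)n * n,
                              &Zero, dC, n, (long long)n * n,
                              n);                   // batch count = n
    cudaDeviceSynchronize();
    printf("cublasSgemmStridedBatched (float, n = %d) finished\n", n);

    cublasDestroy(h);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}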

In my previous test where I built a C-callable harness, I was running on a Tesla V100, and the code was using float, but since I was able to “see” the trend/case you were referring to, I thought it was useful to share my observation. Perhaps not. When I change my previous test case to m=n=q=200, I get a cublas time of 21ms and your kernel takes 94ms (on Tesla V100). I don’t have a GTX1080Ti or anything like it currently conveniently available to test on.

I just ran my C-code version on a GTX 960 with CUDA 9.0:

with m=150, n=256, q = 2, got 35ms for the cublas version, and 20ms for the “manual” version
with m=150, n=256, q = 20, got 35ms for the cublas version, and 184ms for the “manual” version

Note that your originally posted code used float, whereas your second posting uses double.

My code uses float.

Yes, that is pretty consistent with what I’m getting. The float functions of CUBLAS are working fine in most circumstances. It’s the double-precision versions that sadly appear to be completely broken.

I’m also able to observe that in the case of double, the cublas method becomes slower and the kernel method becomes faster (on GTX 960). I’ve filed an internal bug, but I don’t expect any immediate progress on it. I don’t have any further information.