Slow CUDA SGEMM

I’m measuring two approaches to matrix multiplication performance: a naive CUDA kernel and SGEMM from cuBLAS. For simplicity all matrices are square, of type float and size n x n. The compilers are nvcc V11.7.64 and GCC 8.3.1 with -O3, building for architectures 70 and 80. The OS is CentOS 7.

I don’t understand why cuBLAS SGEMM is the slower of the two. Why would a naive GPU implementation be faster than an optimized library? What am I missing?

Simple code

__global__ void matrixMultiplicationKernel(int N, float* A, float* B, float* C) {
    float alpha = 1.f, beta = 0.f;
    int row = blockIdx.y*blockDim.y+threadIdx.y;
    int col = blockIdx.x*blockDim.x+threadIdx.x;

    if (row < N && col < N) {
        float tmpSum = 0;
        for (int i = 0; i < N; i++)
            tmpSum += A[row * N + i] * B[i * N + col];
        C[row * N + col] = beta*C[row * N + col] + alpha * tmpSum;
    }
}

void matmat_mul_cuda_kernel(int n, int bs, float *A, float *B, float *C) {
    thrust::device_vector<float> dvA(A, A + n*n);
    thrust::device_vector<float> dvB(B, B + n*n);
    thrust::device_vector<float> dvC(n*n);

    int nthreads = bs;
    int nblocks = ceil(float(n)/float(bs));
    dim3 blocksPerGrid(nblocks, nblocks);
    dim3 threadsPerBlock(nthreads, nthreads);

    matrixMultiplicationKernel<<<blocksPerGrid, threadsPerBlock>>>(n, 
        thrust::raw_pointer_cast(&dvA[0]),
        thrust::raw_pointer_cast(&dvB[0]),
        thrust::raw_pointer_cast(&dvC[0]));

    thrust::copy(dvC.begin(), dvC.end(), C);
}

CUBLAS SGEMM

void matmat_mul_cublas(int n, float *A, float *B, float *C) {
    thrust::device_vector<float> dvA(A, A + n*n);
    thrust::device_vector<float> dvB(B, B + n*n);
    thrust::device_vector<float> dvC(C, C + n*n);

    int lda=n, ldb=n, ldc=n;
    const float alpha = 1.0, beta = 0.0;

    cublasHandle_t handle;
    cublasCreate(&handle);
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha, 
                thrust::raw_pointer_cast(&dvB[0]), lda, 
                thrust::raw_pointer_cast(&dvA[0]), ldb, &beta, 
                thrust::raw_pointer_cast(&dvC[0]), ldc);

    thrust::copy(dvC.begin(), dvC.end(), C);
    cublasDestroy(handle);
}

The Google Benchmark results on a Tesla V100-PCIE (n = 4096 and 32768, block size bs = 16, i.e. 16x16 threads per block):

Benchmark                                     Time             CPU   Iterations
-------------------------------------------------------------------------------
MatMul/CudaKernel/4096/16/real_time        40.8 ms         40.8 ms           17
MatMul/CudaKernel/32768/16/real_time       4170 ms         4169 ms            1
MatMul/CuBlas/4096/real_time               61.0 ms         61.0 ms           11
MatMul/CuBlas/32768/real_time              9908 ms         9906 ms            1

ldd info

        libcublas.so.11 => /mnt/software/c/cuda/11.7.0_515.43.04/lib64/libcublas.so.11 
        libcudart.so.11.0 => /mnt/software/c/cuda/11.7.0_515.43.04/lib64/libcudart.so.11.0 
        libcublasLt.so.11 => /mnt/software/c/cuda/11.7.0_515.43.04/lib64/libcublasLt.so.11

My guess is you are making an error in measurement, or in some other aspect you haven’t shown. This is on a Tesla V100 PCIE with CUDA 11.4:

$ cat t2109.cu
#include <cublas_v2.h>
#include <thrust/device_vector.h>
#include <iostream>
#include <time.h>
#include <sys/time.h>
#define USECPSEC 1000000ULL

unsigned long long dtime_usec(unsigned long long start=0){

  timeval tv;
  gettimeofday(&tv, 0);
  return ((tv.tv_sec*USECPSEC)+tv.tv_usec)-start;
}
__global__ void matrixMultiplicationKernel(int N, float* A, float* B, float* C) {
    float alpha = 1.f, beta = 0.f;
    int row = blockIdx.y*blockDim.y+threadIdx.y;
    int col = blockIdx.x*blockDim.x+threadIdx.x;

    if (row < N && col < N) {
        float tmpSum = 0;
        for (int i = 0; i < N; i++)
            tmpSum += A[row * N + i] * B[i * N + col];
        C[row * N + col] = beta*C[row * N + col] + alpha * tmpSum;
    }
}

void matmat_mul_cuda_kernel(int n, int bs, float *A, float *B, float *C) {
    thrust::device_vector<float> dvA(A, A + n*n);
    thrust::device_vector<float> dvB(B, B + n*n);
    thrust::device_vector<float> dvC(n*n);

    int nthreads = bs;
    int nblocks = ceil(float(n)/float(bs));
    dim3 blocksPerGrid(nblocks, nblocks);
    dim3 threadsPerBlock(nthreads, nthreads);

    matrixMultiplicationKernel<<<blocksPerGrid, threadsPerBlock>>>(n,
        thrust::raw_pointer_cast(&dvA[0]),
        thrust::raw_pointer_cast(&dvB[0]),
        thrust::raw_pointer_cast(&dvC[0]));

    thrust::copy(dvC.begin(), dvC.end(), C);
}

void matmat_mul_cublas(int n, float *A, float *B, float *C) {
    thrust::device_vector<float> dvA(A, A + n*n);
    thrust::device_vector<float> dvB(B, B + n*n);
    thrust::device_vector<float> dvC(C, C + n*n);

    int lda=n, ldb=n, ldc=n;
    const float alpha = 1.0, beta = 0.0;

    cublasHandle_t handle;
    cublasCreate(&handle);
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha,
                thrust::raw_pointer_cast(&dvB[0]), lda,
                thrust::raw_pointer_cast(&dvA[0]), ldb, &beta,
                thrust::raw_pointer_cast(&dvC[0]), ldc);

    thrust::copy(dvC.begin(), dvC.end(), C);
    cublasDestroy(handle);
}

int main(){

  const int n = 4096;
  float *A = new float[n*n];
  float *B = new float[n*n];
  float *C = new float[n*n];
  // warm up
  matmat_mul_cublas(n, A, B, C);
  matmat_mul_cuda_kernel(n, 32, A, B, C);
  cudaDeviceSynchronize();
  unsigned long long dt = dtime_usec(0);
  matmat_mul_cublas(n, A, B, C);
  dt = dtime_usec(dt);
  std::cout << "cublas time: " << dt << "us" << std::endl;
  dt = dtime_usec(0);
  matmat_mul_cuda_kernel(n, 32, A, B, C);
  dt = dtime_usec(dt);
  std::cout << "kernel time: " << dt << "us" << std::endl;

}

$ nvcc -o t2109 t2109.cu -lcublas -arch=sm_70 -O3
$ ./t2109
cublas time: 83996us
kernel time: 121461us
$ nvprof ./t2109
==16739== NVPROF is profiling process 16739, command: ./t2109
cublas time: 84565us
kernel time: 120242us
==16739== Profiling application: ./t2109
==16739== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   42.08%  173.61ms        10  17.361ms  15.931ms  18.087ms  [CUDA memcpy HtoD]
                   34.42%  141.98ms         2  70.990ms  65.285ms  76.695ms  matrixMultiplicationKernel(int, float*, float*, float*)
                   18.50%  76.338ms         4  19.084ms  16.625ms  24.871ms  [CUDA memcpy DtoH]
                    4.96%  20.444ms         2  10.222ms  10.214ms  10.229ms  volta_sgemm_128x32_sliced1x4_nn
                    0.04%  163.94us         2  81.970us  81.954us  81.986us  void thrust::cuda_cub::core::_kernel_agent<thrust::cuda_cub::__parallel_for::ParallelForAgent<thrust::cuda_cub::__uninitialized_fill::functor<thrust::device_ptr<float>, float>, unsigned long>, thrust::cuda_cub::__uninitialized_fill::functor<thrust::device_ptr<float>, float>, unsigned long>(thrust::device_ptr<float>, float)
      API calls:   45.12%  603.34ms        21  28.731ms  2.9370us  365.01ms  cudaFree
                   31.33%  418.96ms        14  29.925ms  16.090ms  94.595ms  cudaMemcpyAsync
                   22.29%  298.00ms        18  16.556ms  4.7470us  292.94ms  cudaMalloc
                    0.66%  8.8728ms        12  739.40us  247.69us  2.9730ms  cuDeviceTotalMem
                    0.41%  5.5088ms      1188  4.6370us     158ns  263.45us  cuDeviceGetAttribute
                    0.08%  1.0749ms        16  67.180us  4.1460us  92.554us  cudaStreamSynchronize
                    0.05%  646.42us        12  53.868us  35.556us  131.27us  cuDeviceGetName
                    0.02%  221.64us       746     297ns     169ns  1.4220us  cuGetProcAddress
                    0.02%  211.24us         6  35.206us  9.9180us  82.808us  cudaLaunchKernel
                    0.00%  40.938us         9  4.5480us  1.5450us  10.199us  cudaDeviceSynchronize
                    0.00%  32.839us        36     912ns     445ns  7.9090us  cudaEventCreateWithFlags
                    0.00%  30.118us         4  7.5290us  3.1350us  13.649us  cuDeviceGetPCIBusId
                    0.00%  26.690us        36     741ns     413ns  6.8180us  cudaEventDestroy
                    0.00%  14.603us         9  1.6220us     337ns  4.7200us  cudaGetDevice
                    0.00%  13.311us         4  3.3270us     759ns  6.7450us  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags
                    0.00%  12.019us        35     343ns     159ns  2.0810us  cudaGetLastError
                    0.00%  10.797us        16     674ns     179ns  4.0850us  cuDeviceGet
                    0.00%  10.137us        16     633ns     316ns  2.2920us  cudaDeviceGetAttribute
                    0.00%  6.4130us         1  6.4130us  6.4130us  6.4130us  cudaFuncGetAttributes
                    0.00%  5.2780us        12     439ns     273ns  1.0140us  cuDeviceGetUuid
                    0.00%  4.0730us         5     814ns     382ns  1.6460us  cuDeviceGetCount
                    0.00%  2.5270us         1  2.5270us  2.5270us  2.5270us  cudaGetSymbolAddress
                    0.00%  1.7370us         2     868ns     823ns     914ns  cuInit
                    0.00%     956ns         4     239ns     146ns     405ns  cudaPeekAtLastError
                    0.00%     466ns         2     233ns     212ns     254ns  cuDriverGetVersion
                    0.00%     340ns         1     340ns     340ns     340ns  cudaGetDeviceCount
$

The overall timing of the routines you have presented suggests that cuBLAS is faster, and the profiler shows that the naive kernel takes about 70 ms whereas the cuBLAS kernel takes about 10 ms.
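As a rough sanity check (my own back-of-the-envelope arithmetic, not part of the profiler output): an n = 4096 SGEMM performs about 2*n^3 ≈ 1.37e11 floating-point operations. At ~10.2 ms that is roughly 13 TFLOP/s, close to the V100 PCIe's ~14 TFLOP/s FP32 peak, whereas ~71 ms for the naive kernel works out to roughly 1.9 TFLOP/s. A cuBLAS kernel time of about 10 ms is therefore about what one should expect on this GPU.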

Changing block size from 32 to 16 does not have a meaningful impact on the results.

No, I won’t be able to take a look at a google test harness.

Your cublas routine copies the host C vector to the device C vector at the beginning of the routine. Nothing wrong with that of course, but your naive routine does not. That makes it not an apples-to-apples comparison, but according to my testing that doesn’t seriously skew the results. CUBLAS still wins, whether you measure your whole routine, or the kernel itself.
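For reference, here is one way to measure just the kernel time (a minimal sketch added for illustration, not code from the thread; it assumes device pointers dA, dB and dC have already been allocated and filled, and that n, blocksPerGrid and threadsPerBlock are set up as in the routines above):

cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

cudaEventRecord(start);
matrixMultiplicationKernel<<<blocksPerGrid, threadsPerBlock>>>(n, dA, dB, dC);
cudaEventRecord(stop);
cudaEventSynchronize(stop);              // wait for the kernel to finish

float ms = 0.f;
cudaEventElapsedTime(&ms, start, stop);  // elapsed GPU time in milliseconds

cudaEventDestroy(start);
cudaEventDestroy(stop);

Timed this way, the host-side thrust allocations and copies are excluded, so the number is directly comparable to the per-kernel times nvprof reports.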

The classic computer science off-by-one error.

The number of the counting shall be 3.

The Google Benchmark routine just calls matmat_mul_*(n, A, B, C) repeatedly and reports the average time (one warm-up call is excluded from the timing). Therefore both functions copy the data on each call.
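For context, a harness with that structure might look roughly like the following (my own sketch; the actual harness was not posted, so the benchmark names and argument registration are assumptions):

#include <benchmark/benchmark.h>
#include <vector>

static void MatMul_CudaKernel(benchmark::State& state) {
    const int n  = state.range(0);
    const int bs = state.range(1);
    std::vector<float> A(n*n, 1.f), B(n*n, 1.f), C(n*n);
    matmat_mul_cuda_kernel(n, bs, A.data(), B.data(), C.data());  // warm-up, not timed
    for (auto _ : state)
        matmat_mul_cuda_kernel(n, bs, A.data(), B.data(), C.data());
}
BENCHMARK(MatMul_CudaKernel)->Args({4096, 16})->Args({32768, 16})->UseRealTime()->Unit(benchmark::kMillisecond);

static void MatMul_CuBlas(benchmark::State& state) {
    const int n = state.range(0);
    std::vector<float> A(n*n, 1.f), B(n*n, 1.f), C(n*n);
    matmat_mul_cublas(n, A.data(), B.data(), C.data());           // warm-up, not timed
    for (auto _ : state)
        matmat_mul_cublas(n, A.data(), B.data(), C.data());
}
BENCHMARK(MatMul_CuBlas)->Arg(4096)->Arg(32768)->UseRealTime()->Unit(benchmark::kMillisecond);

BENCHMARK_MAIN();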

I disagree.

This copies data:

    thrust::device_vector<float> dvC(C, C + n*n);

This doesn’t:

    thrust::device_vector<float> dvC(n*n);

This is evident from the profiler report I show: there are 10 instances of CUDA memcpy HtoD, because the two cublas calls each do 3 host-to-device copies (A, B and C) while the two naive-kernel calls each do only 2 (A and B). If this were an apples-to-apples comparison there would be either 8 of those (2 per call) or 12 of those (3 per call).

The code that you have shown is flawed in this respect, for careful comparison purposes.

I won’t argue it further. It’s OK if we disagree.


That is a mistake. Thanks.
Strangely, the Google Benchmark times are really not consistent with a straightforward time measurement.
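One thing that might be worth trying (my own suggestion, not something verified in this thread) is a variant of the harness sketch above that times each iteration explicitly with CUDA events and reports it through Google Benchmark’s manual-time mechanism, so that framework-side overhead cannot skew the comparison:

static void MatMul_CuBlas_Manual(benchmark::State& state) {
    const int n = state.range(0);
    std::vector<float> A(n*n, 1.f), B(n*n, 1.f), C(n*n);
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    for (auto _ : state) {
        cudaEventRecord(start);
        matmat_mul_cublas(n, A.data(), B.data(), C.data());
        cudaEventRecord(stop);
        cudaEventSynchronize(stop);
        float ms = 0.f;
        cudaEventElapsedTime(&ms, start, stop);
        state.SetIterationTime(ms / 1000.0);   // Google Benchmark expects seconds
    }
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
}
BENCHMARK(MatMul_CuBlas_Manual)->Arg(4096)->UseManualTime()->Unit(benchmark::kMillisecond);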