cublasSgemv slower than expected

I am checking performance by computing M (40000 x 10000) x V (10000), and I find the cuBLAS routine significantly slower than a simple kernel.

rreddy78@jetson-nano:~/Desktop/Technical$ sudo /usr/local/cuda/bin/nvprof ./SgemvTester

Type Time(%) Time Calls Avg Min Max Name
GPU activities: 100.00% 76.6553s 125 613.24ms 273.62ms 778.29ms void gemv2N_kernel<int, int, float, float, float, int=128, int=4, int=4, int=4, int=1, cublasGemvParams<cublasGemvTensor<float const>, cublasGemvTensor<float>, float>>(float const)

rreddy78@jetson-nano:~/Desktop/Technical$ sudo /usr/local/cuda/bin/nvprof ./matrixVectorMultiplication

Type Time(%) Time Calls Avg Min Max Name
GPU activities: 100.00% 18.4994s 126 146.82ms 143.82ms 219.55ms matrixVectorMultiplication(float const *, float const *, float*, int, int)
How can this be? My kernel code is as follows:

const int M_SIZE = 40000;
const int W_SIZE = 10000;

__global__ void matrixVectorMultiplication(const float * __restrict__ M, const float * __restrict__ V, float *R, const int M_Size, const int W_Size)
{
  const int COL = blockIdx.x * blockDim.x + threadIdx.x;
  float tmpSum = 0.0f;

#pragma unroll  
  for(int k = 0;k < W_Size;++k)
  {
     // M is stored in column major order
     tmpSum += M[k*M_Size + COL] * V[k];
  }
  R[COL] = tmpSum;
}

The execution configuration is:
// Fastest block configuration on jetson nano
dim3 threadsPerBlock(1024);
dim3 blocksPerGrid((M_SIZE + (threadsPerBlock.x - 1))/threadsPerBlock.x);

Your kernel code would be making out-of-bounds accesses if you allocated your M and V according to your stated sizes.
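
With M_SIZE = 40000 and 1024-thread blocks, the grid rounds up to 40 blocks, i.e. 40,960 threads, so the threads with COL >= 40000 read past the end of M and write past the end of R. A minimal sketch of the guard that avoids this (keeping your indexing otherwise unchanged):

  if (COL < M_Size)
  {
      float tmpSum = 0.0f;
      for (int k = 0; k < W_Size; ++k)
          tmpSum += M[k * M_Size + COL] * V[k];   // column-major M
      R[COL] = tmpSum;
  }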

When I run what I think is a comparable test case, I observe more comparable times between your kernel code and an equivalent Sgemv call. I also see a slightly different gemv kernel being called on CUDA 11.1U1.

Your kernel is also not taking the alpha and beta parameters into account; cublasSgemv computes y = alpha * op(A) * x + beta * y, so it's not really equivalent in the general case.
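
In your simple kernel that would look something like this (a sketch only, with alpha, beta, and the Y vector passed in as extra kernel arguments):

  R[COL] = alpha * tmpSum + beta * Y[COL];   // y = alpha*A*x + beta*y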

Probably wouldn’t be able to say anything further without seeing your actual, complete code, for both cases.

#include <cuda_runtime.h>
#include <cublas_v2.h>

const int M_SIZE = 40000; /* Rows of M (and length of result Y) */
const int W_SIZE = 10000; /* Columns of M (and length of V) */

// Initializes a Fortran-style (column-major) matrix using column-wise access

void initializeFortranMatrix(float *M, const int N_Rows, const int N_Cols, int initVal)
{
    for (int col = 0; col < N_Cols; col++)
    {
        for (int row = 0; row < N_Rows; row++)
        {
            M[col * N_Rows + row] = static_cast<float>(initVal);
        }
    }
}

void initializeVector(float *V, const int N_Rows, int initVal)
{
    for (int row = 0; row < N_Rows; row++)
        V[row] = static_cast<float>(initVal);
}

int main(int argc, char **argv)
{
    float *M;
    float *V;
    float *Y;

    cublasHandle_t handle;

    checkCUBLAS(cublasCreate(&handle));

    /* Allocate managed storage */
    checkCuda(cudaMallocManaged(&M, sizeof(float) * M_SIZE * W_SIZE));
    checkCuda(cudaMallocManaged(&V, sizeof(float) * W_SIZE));
    checkCuda(cudaMallocManaged(&Y, sizeof(float) * M_SIZE));

    initializeFortranMatrix(M, M_SIZE, W_SIZE, 9); /* Assuming column-wise storage */
    initializeVector(V, W_SIZE, 2);                /* Column */
    initializeVector(Y, M_SIZE, 0);                /* Column */

    const float alpha = 1.0f;
    const float beta = 0.0f;

    for (int i = 0; i < 125; i++)
    {
        checkCUBLAS(cublasSgemv(handle,
                                CUBLAS_OP_N,
                                M_SIZE, W_SIZE, &alpha, M, M_SIZE, V, 1, &beta, Y, 1));
        checkCuda(cudaDeviceSynchronize());
    }

    checkCUBLAS(cublasDestroy(handle));

    checkCuda(cudaFree(M));
    checkCuda(cudaFree(V));
    checkCuda(cudaFree(Y));

    checkCuda(cudaDeviceReset());
    return 0;
}
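
The checkCuda / checkCUBLAS helpers are just the usual error-checking wrappers (their definitions are elided above for brevity); roughly:

#include <cstdio>
#include <cstdlib>

#define checkCuda(call)                                                  \
    do {                                                                 \
        cudaError_t err_ = (call);                                       \
        if (err_ != cudaSuccess) {                                       \
            fprintf(stderr, "CUDA error %s at %s:%d\n",                  \
                    cudaGetErrorString(err_), __FILE__, __LINE__);       \
            exit(EXIT_FAILURE);                                          \
        }                                                                \
    } while (0)

#define checkCUBLAS(call)                                                \
    do {                                                                 \
        cublasStatus_t stat_ = (call);                                   \
        if (stat_ != CUBLAS_STATUS_SUCCESS) {                            \
            fprintf(stderr, "cuBLAS error %d at %s:%d\n",                \
                    static_cast<int>(stat_), __FILE__, __LINE__);        \
            exit(EXIT_FAILURE);                                          \
        }                                                                \
    } while (0)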

An nvprof metrics run shows occupancy as the issue for the cuBLAS kernel:

    Kernel: void gemv2N_kernel<int, int, float, float, float, int=128, int=4, int=4, int=4, int=1, cublasGemvParams<cublasGemvTensor<float const >, cublasGemvTensor<float>, float>>(float const )
          1                             sm_efficiency                   Multiprocessor Activity      99.74%      99.74%      99.74%
          1                        achieved_occupancy                        Achieved Occupancy    0.561494    0.561494    0.561494 
          1                         branch_efficiency                         Branch Efficiency      99.76%      99.76%      99.76%
          1                            gld_efficiency             Global Memory Load Efficiency      99.61%      99.61%      99.61%
          1                            gld_throughput                    Global Load Throughput  2.3953GB/s  2.3953GB/s  2.3953GB/s
          1                            gst_efficiency            Global Memory Store Efficiency     100.00%     100.00%     100.00%
          1                            gst_throughput                   Global Store Throughput  242.60KB/s  242.60KB/s  242.60KB/s
          1                           tex_utilization                 Unified Cache Utilization     Low (1)     Low (1)     Low (1) 
          1                            l2_utilization                      L2 Cache Utilization     Low (1)     Low (1)     Low (1)
          1                        shared_utilization                 Shared Memory Utilization     Low (1)     Low (1)     Low (1)
          1                         shared_efficiency                  Shared Memory Efficiency      14.45%      14.45%      14.45%

For my kernel, it's as follows:

Device "NVIDIA Tegra X1 (0)"
    Kernel: matrixVectorMultiplication(float const *, float const *, float*, int, int)
          2                             sm_efficiency                   Multiprocessor Activity     100.00%     100.00%     100.00%
          2                        achieved_occupancy                        Achieved Occupancy    0.998419    0.998438    0.998429 
          2                         branch_efficiency                         Branch Efficiency     100.00%     100.00%     100.00%
          2                            gld_efficiency             Global Memory Load Efficiency      82.50%      82.50%      82.50%
          2                            gld_throughput                    Global Load Throughput  13.334GB/s  13.474GB/s  13.404GB/s
          2                            gst_efficiency            Global Memory Store Efficiency     100.00%     100.00%     100.00%
          2                            gst_throughput                   Global Store Throughput  1.0923MB/s  1.1038MB/s  1.0980MB/s
          2                           tex_utilization                 Unified Cache Utilization     Mid (4)    High (7)     Mid (5) 
          2                            l2_utilization                      L2 Cache Utilization     Low (2)     Low (2)     Low (2)
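
(For reference, the metric tables above come from an nvprof metrics run rather than the default timeline run, e.g. something like: sudo /usr/local/cuda/bin/nvprof --metrics achieved_occupancy,sm_efficiency,gld_efficiency,gst_efficiency ./SgemvTester)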

So I ran nvprof with --print-gpu-trace and the result is as follows:

3.09629s  5.8330us                    -               -         -         -         -      112B  18.312MB/s    Pageable      Device  NVIDIA Tegra X1         1         7  [CUDA memcpy HtoD]
9.61335s  565.03ms           (1250 1 1)       (128 1 1)        56  2.5000KB        0B         -           -           -           -  NVIDIA Tegra X1         1         7  void gemv2N_kernel<int, int, float, float, float, int=128, int=4, int=4, int=4, int=1, cublasGemvParams<cublasGemvTensor<float const >, cublasGemvTensor<float>, float>>(float const ) [345]

So my guess is that, as the kernel is using 56 registers per thread, it's unable to schedule more blocks per SM (see the occupancy arithmetic below the device limits).

Total number of registers available per block: 32768
Warp size: 32
Maximum number of threads per multiprocessor: 2048
Maximum number of threads per block: 1024
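
As a rough check, assuming the Tegra X1 (compute capability 5.3) has 65,536 registers per SM (the 32,768 figure above is the per-block limit): 56 registers x 128 threads = 7,168 registers per block, so at most 65,536 / 7,168 = 9 blocks can be resident per SM. That is 9 x 128 = 1,152 threads out of 2,048, or about 56% theoretical occupancy, which matches the achieved_occupancy of 0.5615 reported above.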

Though it's uncertain why 1,250 blocks are created; that's more than necessary (1,250 blocks x 128 threads = 160,000 threads, i.e. four threads per output row on average, which would be consistent with the int=4 template parameters in the kernel name).

Another point I found in my experimentation is that shared-memory tiling does not really benefit matrix-vector multiplication; it actually causes a slowdown. Relying on the L1 cache works better. A sketch of the kind of tiled kernel I mean is below.
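
This is roughly the shape of the tiled variant (a sketch only, not the exact code I benchmarked; matVecTiled and V_tile are placeholder names), which stages chunks of V through shared memory:

__global__ void matVecTiled(const float *__restrict__ M, const float *__restrict__ V,
                            float *R, const int M_Size, const int W_Size)
{
    extern __shared__ float V_tile[];   // blockDim.x floats, sized at launch
    const int ROW = blockIdx.x * blockDim.x + threadIdx.x;
    float tmpSum = 0.0f;

    for (int tile = 0; tile < W_Size; tile += blockDim.x)
    {
        // Each thread stages one element of V for this tile
        const int idx = tile + threadIdx.x;
        V_tile[threadIdx.x] = (idx < W_Size) ? V[idx] : 0.0f;
        __syncthreads();

        if (ROW < M_Size)
        {
            const int tileLen = min((int)blockDim.x, W_Size - tile);
            for (int k = 0; k < tileLen; ++k)
                tmpSum += M[(tile + k) * M_Size + ROW] * V_tile[k];   // column-major M
        }
        __syncthreads();
    }

    if (ROW < M_Size)
        R[ROW] = tmpSum;
}

It is launched with threadsPerBlock.x * sizeof(float) bytes of dynamic shared memory as the third launch-configuration parameter.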

This (shared memory not helping here), along with the lower occupancy, might be the reason for the cuBLAS numbers seen on the Jetson Nano.

The simple version with unrolling works really well. The only version faster than that was when the entire vector V was in constant memory.

Does the use of the values alpha = 1.0f and beta = 0.0f also cause any performance overhead?

Just to close out this thread: the following general kernel (handling alpha, beta, and Y) runs at about 87 ms.

__constant__ float VC[W_SIZE];

__global__ void matrixVectorMultiplication(const float *__restrict__ M,
                                           const float *__restrict__ V,
                                           float *R,
                                           const int M_Size, const int W_Size,
                                           const float alpha, const float beta,
                                           const float *__restrict__ Y)
{
    // Each thread computes two adjacent rows via float2 loads,
    // so the grid only needs to cover M_Size / 2 threads (M_Size assumed even).
    const int ROW = blockIdx.x * blockDim.x + threadIdx.x;

    if (ROW < (M_Size / 2))
    {
        float tmpSum1 = 0.0f;
        float tmpSum2 = 0.0f;

        // Loop unrolling does benefit this kernel.
        // The vector is read from constant memory (VC); the V parameter is unused here.
#pragma unroll
        for (int k = 0; k < W_Size; ++k)
        {
            // M is assumed stored in column-major format
            float2 valueA = reinterpret_cast<const float2 *>(M)[(k * (M_Size / 2)) + ROW];
            tmpSum1 += (alpha * valueA.x) * VC[k];
            tmpSum2 += (alpha * valueA.y) * VC[k];
        }

        float2 valueY = reinterpret_cast<const float2 *>(Y)[ROW];
        tmpSum1 += (valueY.x * beta);
        tmpSum2 += (valueY.y * beta);

        reinterpret_cast<float2 *>(R)[ROW] = float2{tmpSum1, tmpSum2};
    }
}
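
For completeness (not shown above): VC is populated once on the host with cudaMemcpyToSymbol, and the grid only covers M_SIZE / 2 threads since each thread produces two rows. A sketch, with R allocated the same way as Y:

checkCuda(cudaMemcpyToSymbol(VC, V, sizeof(float) * W_SIZE));   // 40,000 bytes, within the 64 KB constant-memory limit

dim3 threadsPerBlock(1024);
dim3 blocksPerGrid((M_SIZE / 2 + threadsPerBlock.x - 1) / threadsPerBlock.x);
matrixVectorMultiplication<<<blocksPerGrid, threadsPerBlock>>>(M, V, R, M_SIZE, W_SIZE, alpha, beta, Y);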