I am trying to use shared memory for GEMM, but it turns out slower than GEMM with global memory. The shared-memory kernel is taken from CUDA docs section 3.2.4 (Shared Memory). For the same matrix size, 4096x4096, cuBLAS works fine but the shared-memory kernel is far too slow. I want to know what's wrong with my code.
Environment: RTX 4070 Laptop GPU, i7-12800HX, 32 GB RAM, Windows 11
nvcc -std=c++17 -arch=sm_89 -g -lcublas -lcudart -G -O3 -o test test.cu
shared mem kernel time 1631.566650ms
cublas time 41.416702ms
naive time 1208.498047ms
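For scale (my arithmetic): a 4096x4096x4096 SGEMM is 2 * 4096^3 ≈ 1.37e11 floating-point operations, so 41.4 ms for cuBLAS is roughly 3.3 TFLOP/s, while 1631 ms for the shared-memory kernel is only about 84 GFLOP/s.
The full test code: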
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <iostream>
#include <stdio.h>
#include <cublas_v2.h>
// Thread block size
#define BLOCK_SIZE 16
// Matrices are stored in row-major order:
// M(row, col) = *(M.elements + row * M.stride + col)
typedef struct {
    int width;
    int height;
    int stride;
    float* elements;
} Matrix;
// Get a matrix element
__device__ float GetElement(const Matrix A, int row, int col)
{
    return A.elements[row * A.stride + col];
}
// Set a matrix element
__device__ void SetElement(Matrix A, int row, int col, float value)
{
    A.elements[row * A.stride + col] = value;
}
// Get the BLOCK_SIZExBLOCK_SIZE sub-matrix Asub of A that is
// located col sub-matrices to the right and row sub-matrices down
// from the upper-left corner of A
__device__ Matrix GetSubMatrix(Matrix A, int row, int col)
{
    Matrix Asub;
    Asub.width = BLOCK_SIZE;
    Asub.height = BLOCK_SIZE;
    Asub.stride = A.stride;
    Asub.elements = &A.elements[A.stride * BLOCK_SIZE * row + BLOCK_SIZE * col];
    return Asub;
}
// Forward declaration of the matrix multiplication kernel
__global__ void MatMulKernel(const Matrix, const Matrix, Matrix);
__global__ void naiveKernel(const Matrix, const Matrix, Matrix);
// Matrix multiplication - Host code
// Matrix dimensions are assumed to be multiples of BLOCK_SIZE
void MatMul(const Matrix A, const Matrix B, Matrix C)
{
    // Load A and B to device memory
    Matrix d_A;
    d_A.width = d_A.stride = A.width; d_A.height = A.height;
    size_t size = A.width * A.height * sizeof(float);
    cudaMalloc(&d_A.elements, size);
    cudaMemcpy(d_A.elements, A.elements, size, cudaMemcpyHostToDevice);
    Matrix d_B;
    d_B.width = d_B.stride = B.width; d_B.height = B.height;
    size = B.width * B.height * sizeof(float);
    cudaMalloc(&d_B.elements, size);
    cudaMemcpy(d_B.elements, B.elements, size, cudaMemcpyHostToDevice);
    // Allocate C in device memory
    Matrix d_C;
    d_C.width = d_C.stride = C.width; d_C.height = C.height;
    size = C.width * C.height * sizeof(float);
    cudaMalloc(&d_C.elements, size);
    // Invoke kernel
    dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);
    dim3 dimGrid(B.width / dimBlock.x, A.height / dimBlock.y);
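    // 4096 / 16 = 256, so the grid is 256 x 256 blocks of 16 x 16 threads each.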
    cudaEvent_t start, end;
    cudaEventCreate(&start);
    cudaEventCreate(&end);
    cudaEventRecord(start);
    // cudaEventSynchronize(start);
    MatMulKernel<<<dimGrid, dimBlock>>>(d_A, d_B, d_C);
    cudaEventRecord(end);
    cudaEventSynchronize(end);
    float msec;
    cudaEventElapsedTime(&msec, start, end);
    printf("shared mem kernel time %.6fms\n", msec);
    cublasHandle_t cublas_handle;
    cublasCreate(&cublas_handle);
    float cublas_alpha = 1.0f;
    float cublas_beta = 0.0f;
    cudaEvent_t start2, end2;
    cudaEventCreate(&start2);
    cudaEventCreate(&end2);
    cudaEventRecord(start2);
    cudaEventSynchronize(start2);
    cublasSgemm_v2(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, 4096, 4096, 4096,
                   &cublas_alpha, d_A.elements, 4096, d_B.elements, 4096,
                   &cublas_beta, d_C.elements, 4096);
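    // Note: cuBLAS assumes column-major storage, so with these row-major
    // buffers this call actually computes B*A; with A all 1s and B all 2s
    // the values happen to match A*B.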
    cudaEventRecord(end2);
    cudaEventSynchronize(end2);
    float msec2;
    cudaEventElapsedTime(&msec2, start2, end2);
    printf("cublas time %.6fms\n", msec2);
    cudaEvent_t start3, end3;
    cudaEventCreate(&start3);
    cudaEventCreate(&end3);
    cudaEventRecord(start3);
    cudaEventSynchronize(start3);
    naiveKernel<<<dimGrid, dimBlock>>>(d_A, d_B, d_C);
    cudaEventRecord(end3);
    cudaEventSynchronize(end3);
    float msec3;
    cudaEventElapsedTime(&msec3, start3, end3);
    printf("naive time %.6fms\n", msec3);
    // Read C from device memory
    cudaMemcpy(C.elements, d_C.elements, size, cudaMemcpyDeviceToHost);
    // Free the cuBLAS handle and device memory
    cublasDestroy(cublas_handle);
    cudaFree(d_A.elements);
    cudaFree(d_B.elements);
    cudaFree(d_C.elements);
}
int main(void) {
    Matrix A, B, C;
    A.width = B.width = C.width = 4096;
    A.height = B.height = C.height = 4096;
    // stride is the row pitch in elements; for a dense 4096-wide matrix it is 4096
    A.stride = B.stride = C.stride = 4096;
    size_t sizeA = A.width * A.height * sizeof(float);
    size_t sizeB = B.width * B.height * sizeof(float);
    size_t sizeC = C.width * C.height * sizeof(float);
    A.elements = (float *)malloc(sizeA);
    B.elements = (float *)malloc(sizeB);
    C.elements = (float *)malloc(sizeC);
    for (int i = 0; i < A.width * A.height; i++)
        A.elements[i] = 1;
    for (int i = 0; i < B.width * B.height; i++)
        B.elements[i] = 2;
    MatMul(A, B, C);
    free(A.elements);
    free(B.elements);
    free(C.elements);
    return 0;
}
// Matrix multiplication kernel called by MatMul()
__global__ void MatMulKernel(Matrix A, Matrix B, Matrix C)
{
    // Block row and column
    int blockRow = blockIdx.y;
    int blockCol = blockIdx.x;
    // Each thread block computes one sub-matrix Csub of C
    Matrix Csub = GetSubMatrix(C, blockRow, blockCol);
    // Each thread computes one element of Csub
    // by accumulating results into Cvalue
    float Cvalue = 0;
    // Thread row and column within Csub
    int row = threadIdx.y;
    int col = threadIdx.x;
    // Loop over all the sub-matrices of A and B that are
    // required to compute Csub
    // Multiply each pair of sub-matrices together
    // and accumulate the results
    for (int m = 0; m < (A.width / BLOCK_SIZE); ++m) {
        // Get sub-matrix Asub of A
        Matrix Asub = GetSubMatrix(A, blockRow, m);
        // Get sub-matrix Bsub of B
        Matrix Bsub = GetSubMatrix(B, m, blockCol);
        // Shared memory used to store Asub and Bsub respectively
        __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
        __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];
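        // Each tile holds 16 x 16 floats (1 KB), i.e. 2 KB of shared memory per block.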
        // Load Asub and Bsub from device memory to shared memory
        // Each thread loads one element of each sub-matrix
        As[row][col] = GetElement(Asub, row, col);
        Bs[row][col] = GetElement(Bsub, row, col);
        // Synchronize to make sure the sub-matrices are loaded
        // before starting the computation
        __syncthreads();
        // Multiply Asub and Bsub together
        for (int e = 0; e < BLOCK_SIZE; ++e)
            Cvalue += As[row][e] * Bs[e][col];
        // Synchronize to make sure that the preceding
        // computation is done before loading two new
        // sub-matrices of A and B in the next iteration
        __syncthreads();
    }
    // Write Csub to device memory
    // Each thread writes one element
    SetElement(Csub, row, col, Cvalue);
}
__global__ void naiveKernel(Matrix A, Matrix B, Matrix C)
{
    int n = blockIdx.x * blockDim.x + threadIdx.x; // column of C
    int m = blockIdx.y * blockDim.y + threadIdx.y; // row of C
    int M = A.height; // rows of C
    int N = B.width;  // columns of C
    int K = A.width;  // inner dimension
    if (n < N && m < M) {
        float sum = 0;
        for (int k = 0; k < K; k++) {
            // sum += A[m*K+k] * B[k*N+n];
            sum = fmaf(A.elements[m*K+k], B.elements[k*N+n], sum);
        }
        C.elements[m*N+n] = sum;
    }
}
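For completeness, here is a quick host-side check that can be added after MatMul(A, B, C) in main() (a minimal sketch based on the initialization above; with A all 1s and B all 2s, every entry of C should be 1 * 2 * 4096 = 8192):

// Sketch: verify the result copied back from the device.
bool ok = true;
for (int i = 0; i < C.width * C.height; i++) {
    if (C.elements[i] != 8192.0f) { ok = false; break; }
}
printf("C is %s\n", ok ? "correct" : "WRONG");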