This is a 1bit x 16bits GEMM kernel working on CUDA cores.
I wrote this imitating a float x float kernel where each thread compute a V by V tile of C and shared memory is used for loading A and B.
The current speed of this kernel is no faster than a float kernel, is my high-level design goes the wrong way? I wonder how to further optimize this kernel. Also I want to know theoretically where should be the performance gain by reducing matrix B to binary numbers.
Here is the profiling results from nsight compute:
Here is the kernel code:
#define OFFSET(i, j, ld) (((i) * (ld)) + (j))
#define FETCH_FLOAT4(ptr) (*reinterpret_cast<float4 *>(&(ptr)))
#define FETCH_FLOAT2(ptr) (*reinterpret_cast<float2 *>(&(ptr)))
template<const int BLOCK_SIZE, const int L, const int S, const int V>
__global__ void matmul_binary(half* A, uint32_t* B, float* C, int M, int K, int N){
int tid = threadIdx.x;
int yblock = blockIdx.y;
int xblock = blockIdx.x;
int thread_id_x = tid % BLOCK_SIZE;
int thread_id_y = tid / BLOCK_SIZE;
int ybase = blockIdx.y * BLOCK_SIZE + tid / BLOCK_SIZE;
int xbase = blockIdx.x * BLOCK_SIZE + tid % BLOCK_SIZE;
int nthreads = blockDim.x;
__shared__ half sA[L][S];
__shared__ uint32_t sB[S][L / 32];
float a[V], c[V][V] = {0};
uint32_t b;
for (int ko = 0; ko < K; ko += S){
__syncthreads();
// cooperative fetching
int ELEMENTS_PER_THREAD_A = sizeof(float4) / sizeof(half);
for (int i = 0; i < L * S / (nthreads * ELEMENTS_PER_THREAD_A); i++){
int ya = (i * nthreads + tid) / (S / ELEMENTS_PER_THREAD_A);
int xa = (i * nthreads + tid) % (S / ELEMENTS_PER_THREAD_A);
FETCH_FLOAT4(sA[ya][xa * ELEMENTS_PER_THREAD_A]) =
FETCH_FLOAT4(A[OFFSET(yblock * L + ya, ko + xa * ELEMENTS_PER_THREAD_A, K)]);
}
int ELEMENTS_PER_THREAD_B = sizeof(float2) / sizeof(uint32_t);
for (int i = 0; i < L * S / (nthreads * 32 * ELEMENTS_PER_THREAD_B); i++){ // must satisfy L * S / (nthreads * 32 * 2) >= 1
int yb = (i * nthreads + tid) / (L / (32 * ELEMENTS_PER_THREAD_B));
int xb = (i * nthreads + tid) % (L / (32 * ELEMENTS_PER_THREAD_B));
FETCH_FLOAT2(sB[yb][xb * ELEMENTS_PER_THREAD_B]) =
FETCH_FLOAT2(B[OFFSET((ko + yb), xblock * L / 32 + xb * ELEMENTS_PER_THREAD_B, N / 32)]);
}
__syncthreads();
for (int ki = 0; ki < S; ki++){
for (int iv = 0; iv < V; iv++){
a[iv] = sA[thread_id_y * V + iv][ki];
}
int bitNum = thread_id_x * V;
int wordIdx = bitNum / 32;
int bitIdx = bitNum % 32;
b = sB[ki][wordIdx];
for (int y = 0; y < V; y++){
for (int x = 0; x < V; x++){
int bit = b >> (x + bitIdx) & 1;
int sign = bit * 2 - 1;
c[y][x] += sign * __half2float(a[y]);
}
}
}
}
for (int y = 0; y < V; y++){
for (int x = 0; x < V; x++){
C[(ybase * V + y) * N + xbase * V + x] = c[y][x]; // ybase * V + y row, xbase * V +x
}
}
}
const int L = 64;
const int S = 64;
const int V =8;
const int BLOCK_SIZE = 8;
dim3 threadsPerBlock((L/V) * (L/V));
dim3 blocksPerGrid((N + L - 1 ) / L,
(M + L - 1 ) / L);
matmul_binary<BLOCK_SIZE, L, S, V><<<blocksPerGrid, threadsPerBlock>>>(d_input, d_weight, d_output, M, K, N);
