How to further optimize a mixed-precision GEMM kernel

This is a 1-bit × 16-bit GEMM kernel running on CUDA cores.
I wrote it by imitating a float × float kernel in which each thread computes a V × V tile of C and shared memory is used for staging tiles of A and B.
The kernel is currently no faster than a plain float kernel. Is my high-level design going the wrong way? I would like to know how to optimize it further, and also where, in theory, the performance gain from reducing matrix B to binary values should come from.
Here are the profiling results from Nsight Compute:


Here is the kernel code:

#define OFFSET(i, j, ld) (((i) * (ld)) + (j))
#define FETCH_FLOAT4(ptr) (*reinterpret_cast<float4 *>(&(ptr)))
#define FETCH_FLOAT2(ptr) (*reinterpret_cast<float2 *>(&(ptr)))

template<const int BLOCK_SIZE, const int L, const int S, const int V>
__global__ void matmul_binary(half* A, uint32_t* B, float* C, int M, int K, int N){

    int tid     = threadIdx.x;
    int yblock  = blockIdx.y;
    int xblock  = blockIdx.x;
    int thread_id_x = tid % BLOCK_SIZE;
    int thread_id_y = tid / BLOCK_SIZE;
    int ybase   = blockIdx.y * BLOCK_SIZE + tid / BLOCK_SIZE;
    int xbase   = blockIdx.x * BLOCK_SIZE + tid % BLOCK_SIZE;
    int nthreads = blockDim.x;

    __shared__ half sA[L][S];
    __shared__ uint32_t sB[S][L / 32];
    float a[V], c[V][V] = {0};
    uint32_t b;

    for (int ko = 0; ko < K; ko += S){
        __syncthreads();

        // cooperative fetching
        int ELEMENTS_PER_THREAD_A = sizeof(float4) / sizeof(half);
        for (int i = 0; i < L * S / (nthreads * ELEMENTS_PER_THREAD_A); i++){
            int ya = (i * nthreads + tid) / (S / ELEMENTS_PER_THREAD_A);
            int xa = (i * nthreads + tid) % (S / ELEMENTS_PER_THREAD_A);
            FETCH_FLOAT4(sA[ya][xa * ELEMENTS_PER_THREAD_A]) = 
                FETCH_FLOAT4(A[OFFSET(yblock * L + ya, ko + xa * ELEMENTS_PER_THREAD_A, K)]);
        }
        int ELEMENTS_PER_THREAD_B = sizeof(float2) / sizeof(uint32_t);
        for (int i = 0; i < L * S / (nthreads * 32 * ELEMENTS_PER_THREAD_B); i++){ // requires L * S >= nthreads * 32 * ELEMENTS_PER_THREAD_B
            int yb = (i * nthreads + tid) / (L / (32 * ELEMENTS_PER_THREAD_B));
            int xb = (i * nthreads + tid) % (L / (32 * ELEMENTS_PER_THREAD_B));
            FETCH_FLOAT2(sB[yb][xb * ELEMENTS_PER_THREAD_B]) = 
                FETCH_FLOAT2(B[OFFSET((ko + yb), xblock * L / 32 + xb * ELEMENTS_PER_THREAD_B, N / 32)]);
        }

        __syncthreads();
        
        for (int ki = 0; ki < S; ki++){
            // load this thread's V values from the current column of the A tile
            for (int iv = 0; iv < V; iv++){
                a[iv] = sA[thread_id_y * V + iv][ki];
            }
            // this thread's V bits of B all fall inside a single 32-bit word,
            // because V divides 32 and bitNum is a multiple of V
            int bitNum  = thread_id_x * V;
            int wordIdx = bitNum / 32;
            int bitIdx  = bitNum % 32;
            b = sB[ki][wordIdx];
            for (int y = 0; y < V; y++){
                for (int x = 0; x < V; x++){
                    int bit  = (b >> (x + bitIdx)) & 1;
                    int sign = bit * 2 - 1;          // map {0, 1} -> {-1, +1}
                    c[y][x] += sign * __half2float(a[y]);
                }
            }
        }
    }

    for (int y = 0; y < V; y++){
        for (int x = 0; x < V; x++){
            C[(ybase * V + y) * N + xbase * V + x] = c[y][x];
        }
    }
}
const int L = 64;
const int S = 64;
const int V = 8;
const int BLOCK_SIZE = 8;
dim3 threadsPerBlock((L/V) * (L/V));
dim3 blocksPerGrid((N + L - 1 ) / L,
                   (M + L - 1 ) / L);

matmul_binary<BLOCK_SIZE, L, S, V><<<blocksPerGrid, threadsPerBlock>>>(d_input, d_weight, d_output, M, K, N);

I do not believe there is general 1-bit × 16-bit hardware support on the Tensor Cores. The closest thing that comes to mind is the structured-sparsity feature, which imposes conditions on which and how many elements may be nonzero and requires a special encoding of that information. It is really meant for precomputed coefficients.

But perhaps I am mistaken and there is some support.

Thanks for your reply. I don't think it's supported on Tensor Cores either; that's why I use CUDA cores.

You can of course convert the binary numbers to +1.f/-1.f (or +1.f/0.f) and use the normal half-precision matrix multiplications.
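A minimal sketch of such a conversion, assuming the same LSB-first bit packing within each 32-bit word that your kernel uses (the name `unpack_signs` is just for illustration). It is shown as host code; the same loop body could run in a simple `__global__` unpack kernel before handing the expanded matrix to a tuned half GEMM such as cublasHgemm:

```cpp
#include <cstdint>
#include <vector>

// Expand n_bits bit-packed values (LSB-first within each 32-bit word)
// into +1.f / -1.f values that a normal float/half GEMM can consume.
std::vector<float> unpack_signs(const uint32_t* packed, int n_bits) {
    std::vector<float> out(n_bits);
    for (int i = 0; i < n_bits; i++) {
        int bit = (packed[i / 32] >> (i % 32)) & 1;
        out[i] = bit ? 1.0f : -1.0f;   // map {0, 1} -> {-1, +1}
    }
    return out;
}
```

The unpack costs one extra pass over B and 16× more memory for the expanded copy, but it lets the multiply itself run through a library GEMM instead of the hand-written inner loop.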