Hi, this topic is essentially a copy of https://stackoverflow.com/questions/52185919/cuda-kernel-for-xnor-convolution-super-fast-58x-in-theory-is-too-slow.
Inspired by XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks (https://arxiv.org/pdf/1603.05279.pdf), and having failed to find an open-source XNOR conv2d, I made my own implementation based on a custom CUDA kernel and linked it with PyTorch using this example: https://gist.github.com/szagoruyko/440c561f7fce5f1b20e6154d801e6033. The problem with my implementation is that, although it is functionally correct, it does not give the significant speedup it should. Yes, PyTorch uses NVIDIA's modern cuDNN library, which is very efficient and relies on clever sub-cubic algorithms such as Winograd convolution that are not applicable in my case. Still, a properly written XNOR conv2d kernel should be significantly faster.
Here is my code, which is fairly self-explanatory. I am new to CUDA, so I have definitely missed something, such as the use of shared memory, properly organized warps, or efficient memory access patterns.
Here is the part I am mainly interested in:
// Binary (XNOR + popcount) 2-D convolution, unpadded, stride 1.
//
// Thread mapping (from the indexing below): blockIdx = (b, w_out, h_out),
// threadIdx.x = c_out -- one block per output pixel, one thread per output
// channel.  The upper-case dimension names are compile-time constants.
//
// Each 32-bit word of bImg / bweight packs 32 input channels, one bit each.
// 32 - 2 * popc(x ^ w) = (#equal bits) - (#differing bits), i.e. the dot
// product of the two words read as +1/-1 vectors.
//
// Performance note: consecutive threads (consecutive c_out) access bweight
// one whole filter apart and Omg one whole output plane apart, so those
// global accesses are generally not coalesced.  bImg is read at the same
// address by every thread of a block.
__global__ void bconv2d(float (*__restrict__ Omg)[C_OUT][W_OUT][H_OUT],
                        const unsigned int (*__restrict__ bImg)[W_IN][H_IN][C_BIN_IN],
                        const unsigned int (*__restrict__ bweight)[W_KERN][H_KERN][C_BIN_IN],
                        const float *__restrict__ bias)
{
    // Inputs are const __restrict__ so the compiler may serve them from the
    // read-only data cache and reorder the loads freely.
    int acc = 0;
    const int w_out = blockIdx.y;
    const int h_out = blockIdx.z;
    const int c_out = threadIdx.x;
    const int b = blockIdx.x;
#pragma unroll
    for (int c_bin_in = 0; c_bin_in < C_BIN_IN; c_bin_in++)
#pragma unroll
        for (int w_kern = 0; w_kern < W_KERN; w_kern++)
#pragma unroll
            for (int h_kern = 0; h_kern < H_KERN; h_kern++)
                acc += 32 - 2 * __popc(bImg[b][w_out + w_kern][h_out + h_kern][c_bin_in] ^
                                       bweight[c_out][w_kern][h_kern][c_bin_in]);
    Omg[b][c_out][w_out][h_out] = acc + bias[c_out];
}
And here is the “whole” code:
import torch
import pycuda.autoinit
import pycuda.driver as drv
import numpy as np
import time
# FIXME: get rid of this hack if possible.
# Allocating (and immediately discarding) a tiny CUDA tensor should make PyTorch
# initialise CUDA right here, i.e. after `import pycuda.autoinit` above.
# NOTE(review): presumably the point is that torch then picks up the context
# pycuda just made current, so torch data_ptr()s are valid inside the pycuda
# kernels below -- confirm before removing.
torch.cuda.FloatTensor(2)
from pycuda.compiler import SourceModule
class Holder(pycuda.driver.PointerHolderBase):
    """Wrap a PyTorch tensor so it can be passed as a PyCUDA kernel argument.

    PyCUDA treats a ``PointerHolderBase`` argument as a device pointer and
    obtains the address through ``get_pointer()``; here that address is the
    wrapped tensor's ``data_ptr()``.
    """

    def __init__(self, t):
        super(Holder, self).__init__()
        # Holding `t` keeps the tensor alive for as long as this wrapper is.
        self.t = t
        self.gpudata = self.get_pointer()

    def get_pointer(self):
        """Return the address of the wrapped tensor's first element."""
        tensor = self.t
        return tensor.data_ptr()
# CUDA source for the three kernels used by TrueBinConv.  Every upper-case
# dimension name in it is a placeholder that TrueBinConv.__init__ replaces with
# an integer literal via str.replace() before compiling, so the compiled
# kernels are specialised to one input/weight shape.  Text added to this
# string must therefore never contain those placeholder names by accident.
# Packed words use bit c for input channel 32 * c_bin_in + c; both packing
# kernels store a 1 bit for a "negative" element (sign bit / bit 7 set).
ModuleText = """
#include <math.h>
// W_OUT, H_OUT, W_IN, H_IN, W_KERN, H_KERN, C_OUT, C_IN, B_SIZE
#define C_BIN_IN (C_IN/32)

// Packs bit 7 of 32 consecutive channel bytes into one 32-bit word.
// blockIdx = (b, w_in, h_in), threadIdx.x = c_bin_in.
__global__ void Img2bImg(unsigned int bin_arr[B_SIZE][W_IN][H_IN][C_BIN_IN],
                         const char arr[B_SIZE][W_IN][H_IN][C_IN])
{
    const int b = blockIdx.x;
    const int w_in = blockIdx.y;
    const int h_in = blockIdx.z;
    const int c_bin_in = threadIdx.x;
    // Unsigned accumulator: setting bit 31 of a signed int is not portable.
    unsigned int acc = 0;
    const char *cur_arr = &arr[b][w_in][h_in][c_bin_in << 5];
    for (int c = 0; c < 32; c++)
    {
        unsigned int t = ((unsigned char)cur_arr[c] & 0x80u) >> 7;
        acc |= t << c;
    }
    bin_arr[b][w_in][h_in][c_bin_in] = acc;
}

// Packs the sign bits of 32 consecutive input-channel weights into one word,
// reordering the weights to [c_out][w_kern][h_kern][c_bin_in].
// blockIdx = (c_out, w_kern, h_kern), threadIdx.x = c_bin_in.
__global__ void weight2bweight(unsigned int bin_arr[C_OUT][W_KERN][H_KERN][C_BIN_IN],
                               const float arr[C_OUT][C_IN][W_KERN][H_KERN])
{
    const int c_out = blockIdx.x;
    const int w_kern = blockIdx.y;
    const int h_kern = blockIdx.z;
    const int c_bin_in = threadIdx.x;
    unsigned int acc = 0;
    for (int c = 0; c < 32; c++)
    {
        // signbit() only promises a non-zero result when the sign bit is set;
        // normalise it to exactly 1 before shifting it into bit c.
        unsigned int s = signbit(arr[c_out][c + (c_bin_in << 5)][w_kern][h_kern]) ? 1u : 0u;
        acc |= s << c;
    }
    bin_arr[c_out][w_kern][h_kern][c_bin_in] = acc;
}

// Binary convolution, unpadded, stride 1.
// blockIdx = (b, w_out, h_out), threadIdx.x = c_out.
// 32 - 2 * popc(x ^ w) is the dot product of two packed +1/-1 vectors.
// Note: consecutive threads (consecutive c_out) access bweight one whole
// filter apart and the output one whole plane apart, so those accesses are
// generally not coalesced; the bImg reads are uniform across a block.
__global__ void bconv2d(float (*__restrict__ Omg)[C_OUT][W_OUT][H_OUT],
                        const unsigned int (*__restrict__ bImg)[W_IN][H_IN][C_BIN_IN],
                        const unsigned int (*__restrict__ bweight)[W_KERN][H_KERN][C_BIN_IN],
                        const float *__restrict__ bias)
{
    int acc = 0;
    const int w_out = blockIdx.y;
    const int h_out = blockIdx.z;
    const int c_out = threadIdx.x;
    const int b = blockIdx.x;
#pragma unroll
    for (int c_bin_in = 0; c_bin_in < C_BIN_IN; c_bin_in++)
#pragma unroll
        for (int w_kern = 0; w_kern < W_KERN; w_kern++)
#pragma unroll
            for (int h_kern = 0; h_kern < H_KERN; h_kern++)
                acc += 32 - 2 * __popc(bImg[b][w_out + w_kern][h_out + h_kern][c_bin_in] ^
                                       bweight[c_out][w_kern][h_kern][c_bin_in]);
    Omg[b][c_out][w_out][h_out] = acc + bias[c_out];
}
"""
class TrueBinConv():
    """Binary (XNOR + popcount) stand-in for a 2-D convolution module.

    The weights of ``m`` are sign-binarised once in ``__init__``; the input is
    sign-binarised on every ``ForwardPass``.  The result approximates the
    convolution as ``alpha * binconv(sign(x), sign(w)) + bias``, where
    ``alpha`` is the mean absolute weight of each output channel.

    The CUDA kernels are compiled with every dimension baked in, so an
    instance only accepts inputs of exactly the shape passed as ``x_size``.

    NOTE(review): the output is always the stride-1, unpadded convolution;
    any padding/stride/dilation/groups configured on ``m`` are ignored --
    confirm callers only wrap such modules.  bconv2d uses one thread per
    output channel, so C_OUT is also limited by the device's maximum threads
    per block.
    """

    def __init__(self, x_size, m):
        """Compile the kernels for ``x_size`` and binarise ``m``'s weights.

        Args:
            x_size: input shape ``(B_SIZE, C_IN, W_IN, H_IN)``.
            m: module with ``weight`` of shape ``(C_OUT, C_IN, W_KERN, H_KERN)``
                and an optional ``bias`` (may be ``None``).  Presumably both
                are float32 CUDA tensors -- TODO confirm.

        Raises:
            AssertionError: if the channel counts disagree or are not
                multiples of 32.
        """
        B_SIZE, C_IN, W_IN, H_IN = x_size
        C_OUT, C_IN_2, W_KERN, H_KERN = m.weight.data.size()
        assert C_IN_2 == C_IN
        assert C_IN % 32 == 0 and C_OUT % 32 == 0
        C_BIN_IN = C_IN // 32
        W_OUT = W_IN - W_KERN + 1
        H_OUT = H_IN - H_KERN + 1
        # The kernels hard-code this shape; ForwardPass checks inputs against it.
        self.x_size = (B_SIZE, C_IN, W_IN, H_IN)
        CompileText = ModuleText.replace('W_IN', str(W_IN)).replace('H_IN', str(H_IN)) \
            .replace('C_IN', str(C_IN)).replace('B_SIZE', str(B_SIZE)).replace('H_OUT', str(H_OUT)) \
            .replace('W_OUT', str(W_OUT)).replace('H_KERN', str(H_KERN)).replace('W_KERN', str(W_KERN)) \
            .replace('C_OUT', str(C_OUT))
        sm = SourceModule(CompileText)
        self.Img2bImg = sm.get_function('Img2bImg')
        weight2bweight = sm.get_function('weight2bweight')
        self.bconv2d = sm.get_function('bconv2d')
        self.bweight = torch.cuda.IntTensor(C_OUT, W_KERN, H_KERN, C_BIN_IN).zero_()
        weight2bweight(Holder(self.bweight), Holder(m.weight.data.contiguous()),
                       block=(C_BIN_IN, 1, 1), grid=(C_OUT, W_KERN, H_KERN))
        # A module built without a bias has m.bias set to None.
        if m.bias is not None:
            self.bias = m.bias.data
        else:
            self.bias = torch.cuda.FloatTensor(C_OUT).zero_()
        # bconv2d adds its bias argument before ForwardPass multiplies by
        # alpha, which would scale the bias as well.  Hand the kernel zeros
        # and add the real bias after scaling instead.
        self._zero_bias = torch.cuda.FloatTensor(C_OUT).zero_()
        self.alpha = (torch.abs(m.weight.data).sum(1).sum(1).sum(1) / (C_IN * W_KERN * H_KERN)).view(C_OUT, 1, 1)

    def ForwardPass(self, x):
        """Binary-convolve ``x`` and return ``alpha * binconv + bias``.

        Args:
            x: input tensor laid out ``(B_SIZE, C_IN, W_IN, H_IN)``; its shape
                must equal the ``x_size`` given to ``__init__``.

        Returns:
            CUDA float tensor of shape ``(B_SIZE, C_OUT, W_OUT, H_OUT)``.

        Side effect: prints the sum of the raw kernel output and the time the
        bconv2d kernel took, in milliseconds.

        Raises:
            AssertionError: if ``x`` does not have the compiled-for shape.
        """
        # The kernels index with compile-time strides; any other shape would
        # silently read/write out of bounds instead of failing.
        assert tuple(x.size()) == self.x_size, \
            'TrueBinConv compiled for input %s, got %s' % (self.x_size, tuple(x.size()))
        B_SIZE, C_IN, W_IN, H_IN = x.size()
        C_OUT, W_KERN, H_KERN, C_BIN_IN = self.bweight.size()
        assert C_BIN_IN == C_IN // 32
        W_OUT = W_IN - W_KERN + 1
        H_OUT = H_IN - H_KERN + 1
        # Img2bImg packs bit 7 of each byte, so encode "x < 0" as 0x80 and
        # everything else as 0x00, in (B, W, H, C) order.  Casting the floats
        # straight to uint8 drops the sign of values in (-1, 0) and wraps
        # magnitudes >= 128; it also went through a CPU ByteTensor.
        x2 = torch.cuda.ByteTensor(B_SIZE, W_IN, H_IN, C_IN)
        x2.copy_((x < 0).transpose(1, 2).transpose(2, 3))
        x2.mul_(128)
        bx = torch.cuda.IntTensor(B_SIZE, W_IN, H_IN, C_BIN_IN)
        self.Img2bImg(Holder(bx), Holder(x2), block=(C_BIN_IN, 1, 1), grid=(B_SIZE, W_IN, H_IN))
        ox = torch.cuda.FloatTensor(B_SIZE, C_OUT, W_OUT, H_OUT)
        # PyCUDA launches are asynchronous, so wall-clock time around a plain
        # call mostly measures launch overhead.  time_kernel=True synchronises
        # before and after the launch and returns the elapsed seconds.
        elapsed = self.bconv2d(Holder(ox), Holder(bx), Holder(self.bweight), Holder(self._zero_bias),
                               block=(C_OUT, 1, 1), grid=(B_SIZE, W_OUT, H_OUT),
                               time_kernel=True)
        print('Binary kernel runs: ', ox.sum(), elapsed * 1000)
        return ox * self.alpha + self.bias.view(C_OUT, 1, 1)