I just wanted to reply here with the code I finally came up with, to help future readers who run into similar problems. Quick recap of the problem:
- A matrix X of size [M, N], a vector Y of size [N] (where N >> M), and a matrix Z of size [M, 256].
- X is an unsigned 8-bit integer (uint8) array, Y is a float32 vector, Z is a float32 array.
The loop I am optimizing is:
```python
for i in range(M):
    for j in range(N):
        Xc = X[i, j]
        for k in range(Xc):
            Z[i, k] += Y[j]
```
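For reference, the triple loop has a closed form per entry: Z[i][k] is the sum of Y[j] over all j with X[i][j] > k. A pure-Python sketch (the sizes M, N, K and all values below are made up for illustration) that checks the two formulations against each other:

```python
# Illustrative sizes and data, not from the original problem; K plays the
# role of Z's 256 columns.
M, N, K = 3, 5, 4
X = [[2, 0, 3, 1, 2],
     [1, 1, 0, 2, 4],
     [0, 3, 2, 2, 1]]
Y = [0.5, 1.0, 2.0, 0.25, 1.5]

# The original triple loop
Z_loop = [[0.0] * K for _ in range(M)]
for i in range(M):
    for j in range(N):
        for k in range(X[i][j]):
            Z_loop[i][k] += Y[j]

# Closed form per entry: Z[i][k] = sum of Y[j] over all j with X[i][j] > k
Z_direct = [[sum(Y[j] for j in range(N) if X[i][j] > k) for k in range(K)]
            for i in range(M)]

assert Z_loop == Z_direct
```

This per-entry view is what turns the scatter of the original loop into a reduction over j, which is the formulation the third kernel exploits.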
Solution
First, I implemented a naive atomicAdd kernel:
```cpp
__global__ void _cuda_kernel(
    const torch::PackedTensorAccessor32<unsigned char,2,torch::RestrictPtrTraits> X,
    const torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> Y,
    torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> Z) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  const int j = blockIdx.y * blockDim.y + threadIdx.y;
  if (i < X.size(0) && j < X.size(1)) {
    unsigned char Xc = X[i][j];
    float yj = Y[j];
    for (unsigned char k = 0; k < Xc; k++) {
      atomicAdd(&Z[i][k], yj);
    }
  }
}
```
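To spell out what this naive kernel parallelizes (pure Python, hypothetical sizes and values): each (i, j) thread adds Y[j] to Z[i][k] for every k < X[i][j]. Floating-point atomicAdd applies the additions in whatever order the threads happen to arrive, so the sketch scrambles the work-item order to show the result is order-independent (up to floating-point rounding; the values here are powers of two, so equality is exact):

```python
# Hypothetical sizes and data for illustration only
M, N, K = 2, 4, 3
X = [[1, 3, 0, 2],
     [2, 2, 1, 3]]
Y = [1.0, 0.5, 2.0, 0.25]  # powers of two, so reordered sums are exact

# Sequential reference (the triple loop from the problem statement)
Z_ref = [[0.0] * K for _ in range(M)]
for i in range(M):
    for j in range(N):
        for k in range(X[i][j]):
            Z_ref[i][k] += Y[j]

# One work item per (i, j) "thread", executed in scrambled order; each adds
# Y[j] into Z[i][0..X[i][j]-1], which is what atomicAdd serializes on the GPU
Z = [[0.0] * K for _ in range(M)]
threads = [(i, j) for i in range(M) for j in range(N)]
for i, j in reversed(threads):
    for k in range(X[i][j]):
        Z[i][k] += Y[j]

assert Z == Z_ref
```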
The next version uses shared memory; in my measurements this improves speed by roughly 30% over the first kernel. Here I use a two-dimensional thread block of size (1, 1024).
```cpp
__global__ void _cuda_kernel(
    const torch::PackedTensorAccessor32<unsigned char,2,torch::RestrictPtrTraits> X,
    const torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> Y,
    torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> Z) {
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int ti_j = threadIdx.y;
  unsigned int j = blockIdx.y * blockDim.y + ti_j;
  // Create shared variables
  __shared__ unsigned char Xs[1024];
  __shared__ float ys[1024];
  const bool in_bounds = i < X.size(0) && j < X.size(1);
  // Fill variables
  if (in_bounds) {
    Xs[ti_j] = X[i][j];
    ys[ti_j] = Y[j];
  }
  // Sync threads. The barrier sits outside the bounds check because every
  // thread in the block must reach it (calling __syncthreads() inside a
  // divergent branch is undefined behavior).
  __syncthreads();
  if (in_bounds) {
    // Loop over 2nd dimension of Z
    auto Xsc = Xs[ti_j];
    auto yc = ys[ti_j];
    for (unsigned char k = 0; k < Xsc; k++) {
      atomicAdd(&Z[i][k], yc);
    }
  }
}
```
The third and final version is based on the ‘reduce6’ example in the CUDA documentation. I use a three-dimensional grid. Because a fixed thread-block size of (32, 1, 32) is fine for me for now, for array sizes of (M, N, 256) respectively (see problem statement), I can significantly simplify the reduction kernel and eliminate shared memory altogether, so the kernel works purely in registers. This improves speed by another 30% over the previous version (note: this includes the final reduction of the block sums). It comes at the expense of potentially significant additional memory consumption (no free lunch, unfortunately), as I now allocate a new three-dimensional Z_block float tensor. This tensor can be quite large when N is large, so GPU memory consumption increases.
```cpp
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void _cuda_kernel(
    const torch::PackedTensorAccessor32<unsigned char,2,torch::RestrictPtrTraits> X,
    const torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> Y,
    torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> Z_block) {
  // Thread block identifier
  // https://developer.nvidia.com/blog/cooperative-groups/
  cg::thread_block cta = cg::this_thread_block();
  cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);
  // N index (grid-stride along x; note i indexes X.size(1) == N below)
  unsigned int ti_x = threadIdx.x;
  unsigned int blockSize_x = blockDim.x;
  unsigned int gridSize_x = blockSize_x * gridDim.x;
  unsigned int i = blockIdx.x * blockSize_x + ti_x;
  // M index
  unsigned int j = blockIdx.y * blockDim.y + threadIdx.y;
  // Z(2) index
  unsigned int k = blockIdx.z * blockDim.z + threadIdx.z;
  // Temporary accumulator
  float Z_temp = 0;
  // Load values and perform the first reduction. A multiplication is used
  // instead of an if-condition; this improves speed as no divergent branches
  // are created.
  while (i < X.size(1) && j < X.size(0) && k < Z_block.size(2)) {
    auto flag_add = k < X[j][i];
    Z_temp += Y[i] * flag_add;
    i += gridSize_x;
  }
  // Reduce within the warp using shuffles:
  // https://developer.nvidia.com/blog/faster-parallel-reductions-kepler/
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-shuffle-functions
  // This is even faster than shared memory, as threads within a warp exchange
  // data directly. Because warpSize is 32 and the block's x dimension is 32,
  // this kernel needs no shared memory at all and only uses the fast
  // registers of the warp (the local temporary floats).
  for (int offset = tile32.size() / 2; offset > 0; offset /= 2) {
    Z_temp += tile32.shfl_down(Z_temp, offset);
  }
  // Write the warp's sum to Z_block. We now have one partial sum per block;
  // these can be summed with LibTorch's sum afterwards (or another standard
  // sum kernel). This store must be guarded by an if: after the shuffle only
  // lane 0 holds the full sum, and letting all 32 lanes store
  // Z_temp * flag would make them race on the same Z_block element.
  if (ti_x == 0) {
    Z_block[blockIdx.x][j][k] = Z_temp;
  }
}
```
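The overall scheme is a standard two-stage reduction: each block writes one partial sum into Z_block, and the partials are summed afterwards (in my case by LibTorch's sum over the first dimension). A minimal pure-Python sketch of the idea, with illustrative sizes:

```python
# Illustrative sizes; in the real code the partials live in Z_block and
# stage 2 is torch's sum over Z_block's first dimension.
N = 10
Y = [float(j) for j in range(N)]
block_size = 4
num_blocks = (N + block_size - 1) // block_size  # ceiling division

# Stage 1: one partial sum per "block" (what the kernel writes to Z_block)
partials = [sum(Y[b * block_size:(b + 1) * block_size])
            for b in range(num_blocks)]

# Stage 2: reduce the block sums on the host
total = sum(partials)

assert total == sum(Y)
```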
I have two remaining questions:
- I’ve read about block-sum reductions, but I am curious whether there is more speed to be gained by improving the sum over blocks (I currently sum the block sums by calling LibTorch’s sum on Z_block’s first dimension after invoking the above kernel). I tried, e.g., atomicAdd on the global Z array, but this slows the code down significantly (so probably too many conflicts).
- The CUDA reduction example contains many if-statements. I replaced them with flags and multiplications, as this seems to improve speed (by measurement). I suspect this is because the kernel then has no divergent branches, but I was wondering whether there is any risk to this? (I.e., why is this not best practice if it seems to improve speed?) @Robert_Crovella
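For completeness, the flag-and-multiply rewrite asked about in the second question is a pure arithmetic identity. A tiny Python sketch (hypothetical values) showing that the branched and branchless accumulations agree whenever the skipped branch would contribute exactly zero:

```python
def branched(pairs, k):
    # if-statement form, as in the CUDA reduction example
    acc = 0.0
    for x, y in pairs:
        if k < x:
            acc += y
    return acc

def branchless(pairs, k):
    # flag-and-multiply form, as in the kernels above; the boolean coerces
    # to 0 or 1, so the "skipped" case adds exactly 0.0
    acc = 0.0
    for x, y in pairs:
        acc += y * (k < x)
    return acc

pairs = [(3, 0.5), (1, 2.0), (4, 1.25), (0, 7.0)]  # made-up (x, y) values
assert branched(pairs, 2) == branchless(pairs, 2)
assert branched(pairs, 0) == branchless(pairs, 0)
```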