How to speed up AtomicAdd kernel using shared memory

Hi. I am seeking help to understand why my code using shared memory and atomic operations is not working.

I’m relatively new to CUDA programming. I’ve studied the various explanations and examples around creating custom kernels and using atomic operations (here, here, here and various other explanatory sites / links I could find on SO and this forum).

Conceptually my problem is as follows:

  • I have a matrix X of size [M, N], a vector Y of size [N], where N >> M, and a matrix Z of size [M, 256].

  • X is an unsigned 8-bit integer (uint8) array, Y is a float32 vector, Z is a float32 array.

  • For each k < X[i, j]: Z[i, k] += Y[j]

Hence, the following (stylized) CUDA kernel naively implements this procedure using AtomicAdd:

const int i = blockIdx.x * blockDim.x + threadIdx.x; 
const int j = blockIdx.y * blockDim.y + threadIdx.y;
  
unsigned char Xc = X[i, j];
float yj = Y[j];
for (unsigned char k = 0; k < Xc; k++){
	atomicAdd(&Z[i, k], yj);
}

I am trying to improve on this working solution using shared memory but can’t get it to work. Conceptually, I think the solution should look as follows:

  • Assign values to shared memory arrays;
  • Synchronize threads;
  • Compute the loop on the shared arrays;
  • Synchronize threads;
  • Global AtomicAdd over the results in the shared memory

Thus, a starting implementation would look like this (with a threadblock size of (16, 64)):

// Dimensions
const int i = blockIdx.x * blockDim.x + threadIdx.x;
const int ti = threadIdx.x;
const int j = blockIdx.y * blockDim.y + threadIdx.y;
const int tj = threadIdx.y;
  
// Define shared arrays
__shared__ unsigned char Xs[16][64];
__shared__ float Ys[64];
__shared__ float Zs[16][256];

// Fill shared arrays
Xs[ti, tj] = X[i, j];
Ys[tj] = Y[j];
// Synchronize
__syncthreads();
// Fill shared array
for (unsigned char k = 0; k < Xs[ti, tj]; k++){
    atomicAdd(&Zs[ti, k], Ys[tj]);
}
// Synchronize
__syncthreads();
// Atomic add over global memory
?

I don’t understand how to use the shared results to obtain the global results, nor do I understand how to obtain the correct shared results in the first place. Note that all the shared arrays are of fixed size and I don’t need to dynamically change them. Also, I believe that with the chosen sizes (16·64·1 B + 64·4 B + 16·256·4 B ≈ 17.3 kB) I remain below the maximum of 48 kB of shared memory per threadblock.

Hence, this is probably a question of basic syntax around the use of shared memory, but I really can’t wrap my head around what it should look like. Your help is very much appreciated!

It’s difficult to be precise when working with “stylized” code.

However if it were me, I would first make sure that my global accesses were coalesced (if possible) and my shared accesses were un-bank-conflicted (if possible). At first glance, this kind of “stylized” code (which could not be actual C++ code) suggests the possibility of inefficient usage of memory:

X[i, j]  (global, might be uncoalesced)
Xs[ti, tj] (shared, might be bank-conflicted)

On to your question. It’s not clear what you’re struggling with. Since you’ve decided that you can work with a 16x256 array in shared memory, and this is presumably a copy of the corresponding array in global (i.e. I am assuming M = 16), your shared atomics seem roughly correct, and the global atomics (after all the shared updates are done) would be to put the data from shared into global on a 1:1 correspondence basis:

atomicAdd(&Z[ti, tj], Zs[ti, tj]);
atomicAdd(&Z[ti, tj+64], Zs[ti, tj+64]);
atomicAdd(&Z[ti, tj+128], Zs[ti, tj+128]);
atomicAdd(&Z[ti, tj+192], Zs[ti, tj+192]);
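
Putting those pieces together, a minimal sketch of the whole kernel might look like the following (plain CUDA C rather than your Torch accessors, hypothetical kernel name and raw-pointer signature, no bounds checks, assuming M and N are multiples of the 16x64 block). Note that the shared accumulator Zs also has to be zeroed before the shared atomics:

__global__ void z_accumulate_shared(const unsigned char* X, const float* Y, float* Z, int N)
{
    const int i  = blockIdx.x * blockDim.x + threadIdx.x;   // row in X / Z
    const int j  = blockIdx.y * blockDim.y + threadIdx.y;   // column in X / index into Y
    const int ti = threadIdx.x;
    const int tj = threadIdx.y;

    __shared__ unsigned char Xs[16][64];
    __shared__ float Ys[64];
    __shared__ float Zs[16][256];

    // Stage inputs and clear the per-block accumulator
    Xs[ti][tj] = X[i * N + j];
    if (ti == 0) Ys[tj] = Y[j];
    for (int k = tj; k < 256; k += 64) Zs[ti][k] = 0.0f;
    __syncthreads();

    // Block-local accumulation using shared-memory atomics
    for (int k = 0; k < Xs[ti][tj]; k++) {
        atomicAdd(&Zs[ti][k], Ys[tj]);
    }
    __syncthreads();

    // Fold the per-block partial sums into global memory, 1:1
    for (int k = tj; k < 256; k += 64) {
        atomicAdd(&Z[i * 256 + k], Zs[ti][k]);
    }
}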

Hi Robert, thanks for taking the time to answer and apologies for not making things clear enough. The full working code of my current kernel is as follows (I am using Torch, that’s why I stylized the code):

__global__ void _cuda_kernel(
    const torch::PackedTensorAccessor32<unsigned char,2,torch::RestrictPtrTraits> X,
    const torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> Y,
    torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> Z) {
  
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  const int j = blockIdx.y * blockDim.y + threadIdx.y;
    
  if (i < X.size(0) && j < X.size(1)){
    unsigned char Xc = X[i][j];
    float yj = Y[j];
    for (unsigned char k = 0; k < Xc; k++){
        atomicAdd(&Z[i][k], yj);
    }
  }
}

std::vector<torch::Tensor> _call_cuda_kernel(
    torch::Tensor X,
    torch::Tensor Y,
    torch::Tensor Z) {
  
  int n_x = X.size(0);
  int n_y = X.size(1);
  const dim3 threadsPerBlock(16, 64);
  int bpg_x = (n_x + threadsPerBlock.x - 1) / threadsPerBlock.x;
  int bpg_y = (n_y + threadsPerBlock.y - 1) / threadsPerBlock.y;
  const dim3 numBlocks(bpg_x, bpg_y);

  _cuda_kernel<<<numBlocks, threadsPerBlock>>>(
        X.packed_accessor32<unsigned char,2,torch::RestrictPtrTraits>(),
        Y.packed_accessor32<float,1,torch::RestrictPtrTraits>(),
        Z.packed_accessor32<float,2,torch::RestrictPtrTraits>());
  
  return {Z};
  }

I am trying to improve on this because using shared memory might make the kernel faster (to be verified). I don’t know whether my accesses to X are coalesced; however, I explicitly make the array contiguous in row-major order before sending it to the CUDA kernel (this is done in Python).

M is typically much larger than 16 (into the hundreds), and N is typically much larger than 64 (into the millions). I’ve chosen 16 to correspond with the 16 threads per block in the M dimension, and likewise 64 for the N dimension. I thought that is how you should allocate shared memory (i.e. the shared array sizes should equal the number of threads per block in that dimension of the CUDA grid - at least that is how every example in the CUDA explanations seems to work).

I apologize, but I don’t understand your code - Z has dimensions [M, 256], so why should I use tj in the second dimension of Z (which can’t possibly be larger than 256)? An example size of my problem looks like:

  • M = 100
  • N = 100,000

A naive Python loop would then look as follows (maybe this illustrates the problem better):

for i in range(100):            # i runs from 0 to 99 (the size of M)
  for j in range(100_000):      # j runs from 0 to 99,999 (the size of N)
    Xc = X[i, j]                # temporarily assign element [i, j] of X to Xc
    for k in range(Xc):         # k counts up to the uint8 value in Xc (i.e. at most up to 255)
      Z[i, k] += Y[j]           # add the jth element of Y to Z[i, k]

Essentially I am trying to write a custom CUDA kernel for this loop.

The question is: how to use shared memory and shared/global atomics to improve over the naive solution using global atomics above?

Thanks again for your help - really appreciated. I hope I made the problem more clear.

I just wanted to reply here with the code that I finally came up with in order to help future people experiencing similar problems. Quick recap of the problem:

  • Matrix X of size [M, N], a vector Y of size [N], where N >> M, and a matrix Z of size [M, 256].
  • X is an unsigned 8-bit integer (uint8) array, Y is a float32 vector, Z is a float32 array.

The loop I am optimizing is:

for i in range(M):
  for j in range(N):
    Xc = X[i, j]
    for k in range(Xc):
      Z[i, k] += Y[j]

Solution
First, I implemented a naive AtomicAdd kernel:

__global__ void _cuda_kernel(
    const torch::PackedTensorAccessor32<unsigned char,2,torch::RestrictPtrTraits> X,
    const torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> Y,
    torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> Z) {
  
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  const int j = blockIdx.y * blockDim.y + threadIdx.y;
    
  if (i < X.size(0) && j < X.size(1)){
    unsigned char Xc = X[i][j];
    float yj = Y[j];
    for (unsigned char k = 0; k < Xc; k++){
        atomicAdd(&Z[i][k], yj);
    }
  }
}

The next version uses shared memory; in my measurements this improves speed by roughly 30% compared with the first kernel. Here, I use a two-dimensional thread block of size (1, 1024).

__global__ void _cuda_kernel(
  const torch::PackedTensorAccessor32<unsigned char,2,torch::RestrictPtrTraits> X,
  const torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> Y,
  torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> Z) {

  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int ti_j = threadIdx.y;
  unsigned int j = blockIdx.y * blockDim.y + ti_j;

  if (i < X.size(0) && j < X.size(1)){
    // Create shared variables
    __shared__ unsigned char Xs[1024];
    __shared__ float ys[1024];
    // Fill variables
    Xs[ti_j] = X[i][j];
    ys[ti_j] = Y[j];
    // Sync threads
    __syncthreads();
    // Loop over 2nd dimension of Z
    auto Xsc = Xs[ti_j];
    auto yc = ys[ti_j];
    for (unsigned char k = 0; k < Xsc; k++){
      atomicAdd(&Z[i][k], yc);
    }
    // Sync threads
    __syncthreads();
  }
}
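
For reference, the launch for this version mirrors the first wrapper, just with the block reshaped to (1, 1024) (a hypothetical sketch; blockDim.x = 1 gives one row i per block, so the grid is (M, ceil(N/1024))):

std::vector<torch::Tensor> _call_cuda_kernel(
    torch::Tensor X,
    torch::Tensor Y,
    torch::Tensor Z) {

  int n_x = X.size(0);   // M
  int n_y = X.size(1);   // N
  const dim3 threadsPerBlock(1, 1024);
  int bpg_x = (n_x + threadsPerBlock.x - 1) / threadsPerBlock.x;   // = M
  int bpg_y = (n_y + threadsPerBlock.y - 1) / threadsPerBlock.y;   // = ceil(N / 1024)
  const dim3 numBlocks(bpg_x, bpg_y);

  _cuda_kernel<<<numBlocks, threadsPerBlock>>>(
        X.packed_accessor32<unsigned char,2,torch::RestrictPtrTraits>(),
        Y.packed_accessor32<float,1,torch::RestrictPtrTraits>(),
        Z.packed_accessor32<float,2,torch::RestrictPtrTraits>());

  return {Z};
}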

The third and final version is based on the ‘reduce6’ example in the CUDA samples. I use a three-dimensional grid. Because for now it is fine for me to work with a fixed threadblock size of (32, 1, 32) for the respective array sizes of (M, N, 256) (see problem statement), I can significantly simplify the reduction kernel and eliminate any use of shared memory, so that only registers are needed. This improves speed by another 30% over the previous version (note: this includes the final reduction of the block sums). Note that this comes at the expense of (potentially) significant additional memory consumption (no free lunch, unfortunately), as I now allocate a new 3-dimensional Z_block float tensor. This tensor can be quite large when N is very large, and therefore GPU memory consumption increases.

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void _cuda_kernel(
  const torch::PackedTensorAccessor32<unsigned char,2,torch::RestrictPtrTraits> X, 
  const torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> Y,
  torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> Z_block) {
  
  // Thread block identifier
  // https://developer.nvidia.com/blog/cooperative-groups/
  cg::thread_block cta = cg::this_thread_block(); 
  cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);
  // N index (i now runs over the N / reduction dimension)
  unsigned int ti_x = threadIdx.x;
  unsigned int blockSize_x = blockDim.x;
  unsigned int gridSize_x = blockSize_x*gridDim.x;
  unsigned int i = blockIdx.x * blockSize_x + ti_x; 
  // M index (j now runs over the M dimension)
  unsigned int j = blockIdx.y * blockDim.y + threadIdx.y; 
  // Z(2) index
  unsigned int k = blockIdx.z * blockDim.z + threadIdx.z;

  // Temporary float
  float Z_temp = 0;

  // Load values and perform first reduction
  // Formulation of multiplication instead of if-condition improves speed as no divergent branches are created.
  while (i < X.size(1) && j < X.size(0) && k < Z_block.size(2)){
    auto flag_add = k < X[j][i];
    Z_temp += Y[i] * flag_add;
    i += gridSize_x;  
  }

  // Reduce within warp using shuffle: https://developer.nvidia.com/blog/faster-parallel-reductions-kepler/
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-shuffle-functions 
  // This is even faster than using shared memory as threads within a warp can directly exchange data
  // Because warpSize is 32, with a threadblock size of 32 we don't need any shared memory in this kernel, and therefore
  // only have to use the fast registers of the warp (invoked by the local temporary floats)
  auto flag_warpreduce = ti_x < 32;
  for (int offset = tile32.size()/2; offset > 0; offset /= 2) {
    Z_temp += tile32.shfl_down(Z_temp * flag_warpreduce, offset);
  }

  // Write overall sum to 0 index of Z_block. We now have a sum per block, but we can sum these blocks using LibTorch sum afterwards (or another standard sum kernel)
  // Again use flag instead of if-statement to avoid divergent branches.
  auto flag_thread_zero = ti_x == 0;
  Z_block[blockIdx.x][j][k] = Z_temp * flag_thread_zero;
}
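
For completeness, a host-side wrapper for this kernel could look roughly like the sketch below (the exact number of x-blocks is a tuning choice I haven’t detailed here; the last line is the libtorch block-sum reduction mentioned above):

std::vector<torch::Tensor> _call_cuda_kernel(
    torch::Tensor X,          // [M, N], uint8
    torch::Tensor Y,          // [N], float32
    torch::Tensor Z) {        // [M, 256], float32

  const int M = X.size(0);
  const int N = X.size(1);
  const dim3 threadsPerBlock(32, 1, 32);

  // Each x-block accumulates a strided subset of the N dimension via the
  // grid-stride loop; more x-blocks means more parallelism but a larger Z_block.
  const int elements_per_block = 4096;                                  // tuning choice
  const int bpg_x = (N + elements_per_block - 1) / elements_per_block;
  const int bpg_y = M;                                                  // blockDim.y == 1
  const int bpg_z = 256 / 32;                                           // blockDim.z == 32
  const dim3 numBlocks(bpg_x, bpg_y, bpg_z);

  // One partial result per x-block
  auto Z_block = torch::zeros({bpg_x, M, 256}, Y.options());

  _cuda_kernel<<<numBlocks, threadsPerBlock>>>(
      X.packed_accessor32<unsigned char,2,torch::RestrictPtrTraits>(),
      Y.packed_accessor32<float,1,torch::RestrictPtrTraits>(),
      Z_block.packed_accessor32<float,3,torch::RestrictPtrTraits>());

  // Sum the block partials over the first dimension (the libtorch reduction)
  Z += Z_block.sum(0);
  return {Z};
}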

I have two remaining questions:

  1. I’ve read about block-sum reductions, but I am curious whether there is more speed to be gained by improving the summation over blocks (I currently sum the block sums by calling libtorch’s sum function on Z_block’s first dimension after invoking the above kernel). I tried e.g. atomicAdd on the global Z array, but this slows the code down significantly (probably too much contention).
  2. In the CUDA reduction example, there are many if-statements. I replaced them with flags and multiplications as this seems to improve speed (by measurement). I suspect this is due to not having any divergent branches in the kernel, but I was wondering whether there is any risk to this? (i.e. why is this not best practice if it seems to improve speed?) @Robert_Crovella

I’m skeptical of the claim. The only other comment I have is that this doesn’t make sense to me:

Z_block[blockIdx.x][j][k] = Z_temp * flag_thread_zero;

If you have multiple threads in X (seems to be the case, impossible to tell from the code you have posted), how could that do anything correctly? You have multiple threads writing to the same location (right?)

Thanks - re. the other comment: yes, that is indeed strange and I would expect incorrect results. Yet the result is completely correct (checked against multiple GPU and CPU alternatives). The result is also identical on each consecutive run. So according to the theory it should not produce anything correctly, yet it does… Maybe (this is speculating) it is due to the shfl_down sequence that effectively creates only a single Z_temp in the entire warp, which can then be added irrespective of the thread_id within the warp? (I tested this by removing the flag and that gives the same results, so the flag isn’t even necessary). I encourage you to try it out yourself…

Edit: the flag_warpreduce isn’t necessary either if threadIdx.x never exceeds 32.

As it happens, we appear to have different definitions of “correctly”

I don’t doubt that you might observe what appears to be correct behavior. When multiple writes occur to the same GPU global memory location, one of the writes is guaranteed to show up there, eventually. (This is a loose description, we can tighten it up by being very specific about the case we are referring to. For example, here, in this case, I am referring to multiple writes from threads in the same warp, from a single store instruction, with no other write activity to that location of any kind.)

However, the correct description of which write will show up there is “it is undefined”. It may well be that the actual implementation is such that the GPU always chooses the winner as being the one with lowest numbered warp lane, but such behavior is not specified anywhere.

So if you’d like to code according to what appears to work, go ahead. If you’d like to rest your code correctness on specified behavior, I would say your code is broken (in that particular line) even though it appears to work correctly. Good luck!
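
In other words, the version that relies only on specified behavior simply guards the store so that exactly one lane writes, e.g.:

// After the shuffle reduction, lane 0 holds the full warp sum
if (ti_x == 0) {
  Z_block[blockIdx.x][j][k] = Z_temp;
}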

Thanks - my goal was to understand what happens under the hood because it felt a bit annoying that I observe certain behaviour but can’t really explain / understand it. Fwiw, I added the final if-statement back because even though the observed behaviour might be correct I don’t want the code to display unspecified behaviour as you indicate. Thanks for the help!

it felt a bit annoying that I observe certain behaviour but can’t really explain

Well, sometimes engineering life is a bit like the old joke about a computer scientist, a physicist, and a mathematician who share a train compartment on their way to a conference in a neighboring country.

Shortly after they cross the border, they observe a flock of black sheep in a field. “Funny,” the computer scientist says, “all sheep in this country are black.” “No, no, no,” the physicist objects, “based on our observation we can only state that some sheep in this country are black.” The mathematician replies: “You are both wrong. At best we can claim that some sheep in this country are black on at least one side.”


My two cents is that OP’s latest implementation still does not address the uncoalesced global memory access problem that @Robert_Crovella mentioned at the beginning. The fix would be swapping the definition of i and j, i.e.

 int i = blockDim.y * blockIdx.y + threadIdx.y;
 int j = blockDim.x * blockIdx.x + threadIdx.x;

As a tradeoff, the atomicAdd(&Z[i][k], yc) can now give rise to a greater level of serialization, given that threads in the same warp would have to update the same memory locations. The fix would be to introduce a warp-level reduction over the active mask, where the float values held by the active threads in a warp are reduced to the leader lane (the active thread with the smallest lane index) and only that leader lane performs the atomicAdd operation. But it would be wise to profile the program beforehand and ensure the serialization is indeed a bottleneck.
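
A sketch of such a warp-aggregated atomic could look like the following (hypothetical helper name; this variant uses cooperative groups’ coalesced_threads() and cg::reduce, which require CUDA 11+, rather than raw mask/shuffle intrinsics). It only pays off when all active lanes of a warp target the same &Z[i][k], which is exactly what the swapped indexing produces:

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

// Sum 'val' over the lanes that are currently active in this warp and let
// only the leader lane (rank 0 of the coalesced group) issue the global atomic.
__device__ void atomicAddWarpAggregated(float* addr, float val) {
  cg::coalesced_group active = cg::coalesced_threads();
  float warp_sum = cg::reduce(active, val, cg::plus<float>());
  if (active.thread_rank() == 0) {
    atomicAdd(addr, warp_sum);
  }
}

// Usage inside the k-loop, instead of the plain atomicAdd:
//   atomicAddWarpAggregated(&Z[i][k], yc);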

Also nice joke by @njuffa.