Warp-level reduction before atomic accumulation

Dear CUDA team,

I am trying to improve a kernel that performs atomic FP32 additions to global memory, by replacing it with a modified version that first does an intra-warp reduction to reduce the number of necessary memory transactions. The “naïve” version of my kernel performs the additions using the following function:

__device__ void atomicAddSlow(float *out, uint32_t index, float value) {
    atomicAdd(out + index, value);
}

My current attempt looks as follows, but I do not think it is correct. My code is also eerily similar to a “bad example” about unsafe implicit warp-synchronous programming here: Using CUDA Warp-Level Primitives | NVIDIA Technical Blog

__device__ void atomicAddReduce(float *out, uint32_t index, float value) {
    uint32_t active = __activemask();

    for (uint32_t offset = 16; offset != 0; offset >>= 1) {
        float value2    = __shfl_down_sync(active, value, offset);
        uint32_t index2 = __shfl_down_sync(active, index, offset);

        if (index == index2)
            value += value2;
    }

    // Thread's position within warp
    uint32_t lane_idx = threadIdx.x & 31;

    // Designate a leader thread within the set of peers
    uint32_t peers = __match_any_sync(active, index);
    uint32_t leader_idx  = __ffs(peers) - 1;

    // If the current thread is the leader, perform atomic op.
    if (lane_idx == leader_idx)
        atomicAdd(out + index, value);
}

I should mention that this implementation is subject to further constraints:

  • The atomicAdd*() function may be called in a conditional context where some threads in the warp are disabled.
  • The target pointer out is always consistent across the whole warp, but index may differ per thread. The implementation should reduce the addition locally per unique index within the warp and then perform one global memory transaction.
  • I can’t use shared memory, only warp-level register exchange.
  • I read that the CUDA compiler should in principle be able to perform a warp-local reduction automatically as an optimization. However, I can confirm that this optimization does not take place when looking at both the PTX and the decompiled SASS output of code using atomicAddSlow().

It seems to me that I can’t be the first person wanting to do this. Does anybody know of an existing implementation of such a function?

Many thanks,
Wenzel

What is your target device architecture?

A shuffle operation produces an undefined result if the target lane is not participating. You don’t seem to have handled this case. You might be able to fix this with an appropriate modification to your if-test.
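For example, something along these lines (an untested sketch of such a modification; note that lane_idx must then be computed before the loop, and this only addresses the undefined-shuffle issue, not the index-grouping problem discussed below):

// Thread's position within warp (moved before the loop)
uint32_t lane_idx = threadIdx.x & 31;

for (uint32_t offset = 16; offset != 0; offset >>= 1) {
    float value2    = __shfl_down_sync(active, value, offset);
    uint32_t index2 = __shfl_down_sync(active, index, offset);

    // Only use the shuffled values if the source lane exists and
    // participates in the shuffle; otherwise value2 and index2 are
    // undefined (or, for out-of-range lanes, the thread's own values).
    uint32_t src = lane_idx + offset;
    if (src < 32 && ((active >> src) & 1) && index == index2)
        value += value2;
}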

I’m not aware of one. CG reduce may be of interest, although it may not be readily amenable to the case of a varying index; I haven’t studied it carefully.

What are your concerns?

It seems to me that the tree-based reduction cannot work correctly when the indices mismatch: additions that the algorithm depends on might never take place. Instead, one must adapt the shift amount to reduce within each group of matching indices.

This was my next attempt:

__device__ void reduce_atomic(float *ptr, uint32_t index, float value) {
    uint32_t active = __activemask();
    uint32_t peers = __match_any_sync(active, index);

    // Thread's position within warp
    uint32_t lane_idx = threadIdx.x & 31;

    for (uint32_t offset = 16; offset != 0; offset >>= 1) {
        // Lane of the offset-th set bit of 'peers', searching upward from
        // lane_idx; 0xffffffff if there is no such bit. Since __fns returns
        // an absolute lane index rather than a delta, __shfl_sync is the
        // right primitive here (not __shfl_down_sync).
        uint32_t src_lane = __fns(peers, lane_idx, offset);
        float value2 = __shfl_sync(active, value, src_lane);
        if (src_lane < 32)
            value += value2;
    }

    // Designate a leader thread within the set of peers
    uint32_t leader_idx = __ffs(peers) - 1;

    // If the current thread is the leader, perform the atomic op.
    if (lane_idx == leader_idx)
        atomicAdd(ptr + index, value);
}

(Untested)
Alas, this method also does not work for me, because the __fns intrinsic (which maps to the fns PTX instruction) is not implemented in some target environments in which I want to run this code (notably, OptiX).
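One fallback that avoids __fns entirely would be to let each group leader gather its peers’ values with plain __shfl_sync, at the cost of a fixed 32 shuffles per warp. A minimal, untested sketch (the function name is mine):

__device__ void reduce_atomic_no_fns(float *ptr, uint32_t index, float value) {
    uint32_t active = __activemask();
    uint32_t peers = __match_any_sync(active, index);

    // Thread's position within warp
    uint32_t lane_idx = threadIdx.x & 31;

    // Designate a leader thread within the set of peers
    uint32_t leader_idx = __ffs(peers) - 1;

    // Every active lane executes all 32 shuffles so that the mask
    // requirement of __shfl_sync is satisfied; each leader accumulates
    // only the lanes belonging to its own peer group (a subset of
    // 'active', so only well-defined shuffle results are used).
    float sum = 0.0f;
    for (uint32_t src = 0; src < 32; ++src) {
        float v = __shfl_sync(active, value, src);
        if (lane_idx == leader_idx && ((peers >> src) & 1))
            sum += v;
    }

    // If the current thread is the leader, perform the atomic op.
    if (lane_idx == leader_idx)
        atomicAdd(ptr + index, sum);
}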

I am grasping at straws at this point; if you have any ideas, don’t hesitate to share them.

I took a look at the C++ CG reduce op, and it seems to only work for simpler reductions that don’t do what I need (reducing within each group of matching indices).

Edit: I was wrong.

Here is a CG solution for sm_70 and newer. It uses labeled_partition to group the threads of a warp by your index: Programming Guide :: CUDA Toolkit Documentation

#include <iostream>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

__global__
void kernel(float* output, const int* indices, const float* input){
  #if __CUDA_ARCH__ >= 700
  auto warp = cg::tiled_partition<32>(cg::this_thread_block());
  const float value = input[threadIdx.x];
  const int index = indices[threadIdx.x];
  auto sameindexgroup = cg::labeled_partition(warp, index);
  float reduced = cg::reduce(sameindexgroup, value, cg::plus<float>{});
  if(sameindexgroup.thread_rank() == 0){
    atomicAdd(output + index, reduced);
  }
  #endif
}

int main(){
  float* output; cudaMallocHost(&output, sizeof(float) * 32);
  int* indices; cudaMallocHost(&indices, sizeof(int) * 32);
  float* input; cudaMallocHost(&input, sizeof(float) * 32);

  for(int i = 0; i < 32; i++){
    indices[i] = i % 4;
    input[i] = 1;
    output[i] = 0;
  }

  kernel<<<1,32>>>(output, indices, input);
  cudaDeviceSynchronize();

  for(int i = 0; i < 32; i++){
    std::cout << output[i]  << " ";
  }
  std::cout << "\n";
}
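With 32 threads and indices i % 4, each of the four index groups contains eight threads, so this should print 8 8 8 8 followed by 28 zeros; the warp issues only four atomic additions in total.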

This was very helpful, many thanks @striker159.
