Is there an efficient pattern for multiple independent reductions in a single warp?

Let us say each thread in a warp has an array of N floats (32*N floats in total). I need to sum, across the warp, all elements that share the same index, which leaves N floats after the reduction. The N outputs have no data dependencies on one another.

In the special (typical) case of N=1, I would just __shfl_down() by strides of 1, 2, 4, 8, 16, each followed by an addition. That is 5 shuffles and 5 additions. (Of course, in the last shuffle only 1 thread performs useful work.)
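For reference, the N=1 baseline can be sketched as below. This is a minimal sketch, not code from the original post: it uses `__shfl_down_sync`, the CUDA 9+ replacement for `__shfl_down`, and the names `warp_reduce_sum`, `reduce_one_kernel` and `run_reduce_one` are made up for illustration. It assumes a single, fully active warp.

```cuda
#include <cuda_runtime.h>

// Reduce one float per lane; only lane 0 is guaranteed to hold the full sum.
__device__ float warp_reduce_sum(float v) {
    const unsigned full_mask = 0xffffffffu;  // all 32 lanes participate
    for (int stride = 1; stride < 32; stride <<= 1) {
        v += __shfl_down_sync(full_mask, v, stride);  // 5 shuffles, 5 adds in total
    }
    return v;
}

__global__ void reduce_one_kernel(const float* in, float* out) {
    float sum = warp_reduce_sum(in[threadIdx.x]);
    if (threadIdx.x == 0) *out = sum;
}

// Hypothetical host helper: sums the values 1..32 on a single warp.
float run_reduce_one() {
    float h_in[32], h_out = 0.0f;
    for (int i = 0; i < 32; ++i) h_in[i] = float(i + 1);  // 1 + 2 + ... + 32 = 528
    float *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    reduce_one_kernel<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_in);
    cudaFree(d_out);
    return h_out;
}
```

Doing this N times is the 5N shuffles + 5N additions cost that the question asks to beat.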

For the N>1 case, is it possible to perform N full-warp reductions in fewer than 5N shuffles and 5N additions? Let us assume for now that I am flexible about where my data ends up.

Great question. I puzzled over something similar a long time ago and am trying to refresh my memory.

I think the answer is “yes” and you can do better than 5N+5N operations.

The quick sketch for N=2:

row0 += shfl_xor(row0,1)
row1 += shfl_xor(row1,1)
row01 = (lane_idx & 1) == 0 ? row0 : row1

→ Even lanes now hold the pairwise partial sums of row0, and odd lanes hold the pairwise partial sums of row1.

Continuing:

row01 += shfl_xor(row01,2)

→ Half of the lanes are now redundant and we can work on rows 2 and 3 after we partially reduce them… With clever shuffling you might be able to get them in the right position to just drop into the best “free” slots in the row01 warp when you build row0123.

( hand waving at this point because it feels like it’s going to work for N > 2 )

Here’s a diagram of what I’m talking about with N=2 and warps that are only 4 lanes wide:

_row 0_   _row 1_
0 1 2 3   4 5 6 7 
 X   X     X   X   shfl_xor(1)
0 1 2 3   4 5 6 7 
1 0 3 2   5 4 7 6  +
   \         /     
    \       /      "select" with ternary operator on even/odd                     
     0 5 2 7                            
     1 4 3 6                            
      \ X /        shfl_xor(2)          
       V V                              
       0 5                              
       1 4                              
       2 7                              
       3 6         +

Whatever the case, I would draw this out on paper or diagram it if you have a favorite app.

The routines you would develop for N=2-32 would be:

float csp256_warp_reduce_2(float v0, float v1);
float csp256_warp_reduce_3(float v0, float v1, float v2);
...

You would have to document where the sums end up for each case—most likely they’ll be in the first N lanes of the warp (?).
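As a rough sketch of what the N=2 routine could look like with the XOR pattern above (assuming CUDA 9+ `__shfl_xor_sync` and a fully active warp; `warp_reduce_2`, `reduce_two_kernel` and `run_reduce_two` are illustrative names, not an existing API): with XOR shuffles the sums do not land in the first N lanes. Instead every even lane ends up with the row 0 sum and every odd lane with the row 1 sum.

```cuda
#include <cuda_runtime.h>

// Two independent warp-wide sums in 6 shuffles instead of 10.
// On return: even lanes hold sum(v0 over the warp), odd lanes hold sum(v1 over the warp).
__device__ float warp_reduce_2(float v0, float v1) {
    const unsigned full_mask = 0xffffffffu;
    const int lane = threadIdx.x & 31;
    v0 += __shfl_xor_sync(full_mask, v0, 1);   // pairwise sums of row 0
    v1 += __shfl_xor_sync(full_mask, v1, 1);   // pairwise sums of row 1
    float v = (lane & 1) ? v1 : v0;            // interleave the two half-reduced rows
    for (int m = 2; m < 32; m <<= 1) {
        v += __shfl_xor_sync(full_mask, v, m); // partners always share the same parity
    }
    return v;
}

__global__ void reduce_two_kernel(float* out) {
    const float x = float(threadIdx.x + 1);    // row 0: 1..32, row 1: 2..64
    out[threadIdx.x] = warp_reduce_2(x, 2.0f * x);
}

// Hypothetical host helper: writes the per-lane results into h_out[0..31].
void run_reduce_two(float* h_out) {
    float* d_out = nullptr;
    cudaMalloc(&d_out, 32 * sizeof(float));
    reduce_two_kernel<<<1, 32>>>(d_out);
    cudaMemcpy(h_out, d_out, 32 * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_out);
}
```

If the results really are wanted in lanes 0..N-1, one extra shuffle or a `__shfl_down_sync`-based variant would be needed; that trade-off is not explored here.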

Thanks! After writing it down with two different color pens, I see how it works and can verify that it is more efficient (on paper). That is a very clever solution.
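To put a number on "more efficient (on paper)": a back-of-the-envelope count, assuming N = 2^k rows with 1 ≤ k ≤ 5 and that pairs of rows are merged after every shuffle stage as in the sketch above. The symbols S_naive and S_merged are just labels for this count.

```latex
\begin{align*}
S_{\text{naive}}(N)  &= 5N \\
S_{\text{merged}}(N) &= \underbrace{N + \tfrac{N}{2} + \dots + 2}_{k\ \text{merging stages}} + (5 - k)
                      = (2N - 2) + (5 - k)
                      = 2N + 3 - \log_2 N
\end{align*}
```

That gives 6 vs 10 shuffles for N=2, 9 vs 20 for N=4, 16 vs 40 for N=8, 31 vs 80 for N=16 and 62 vs 160 for N=32. The number of additions matches the number of shuffles in both schemes, and the merged scheme additionally needs N-1 selects.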

I might actually come back and post a detailed solution / implementation, later.

PS: Your blog and numerous posts have been helpful to me in the past, so thank you for that too.

@allanmac: +1 for the use of ASCII art :-)

Thanks @csp256! A working implementation would be cool to see.

@njuffa: Emacs meta-x picture-mode :)

I had the same problem and changed the whole algorithm to produce 32 data lines before I start summing them. So, with N=32 the entire operation is just 32 additions, plus 33*128 bytes of shared memory for holding the data. I even decided to dedicate a separate warp to this process, i.e. 4-8 warps produce data into the first memory array while, at the same time, the 9th warp sums the data in the second array and writes the sums to memory.

With N=16, you need to perform 16 full-warp additions and then combine two halves of a warp (SHFL+ADD)

With N=8, you are going to perform 8 ADDs followed by 2*(SHFL+ADD)

N=5 is probably best handled as N=1 plus N=4.

However, I completely ignored the LD/ST operation count, which will essentially dominate the ALU ops in such an approach: it is N stores followed by N loads.
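A minimal sketch of the N=32 shared-memory variant described above, for one warp only. The 32x33 layout matches the quoted 33*128 bytes; everything else (the kernel name `column_sums_32`, the helper `run_column_sums`, loading the inputs from global memory rather than from registers, and the use of CUDA 9+ `__syncwarp`) is an assumption made to keep the example self-contained.

```cuda
#include <cuda_runtime.h>

// One warp: lane L owns 32 values; afterwards lane i holds the sum of value i over all lanes.
__global__ void column_sums_32(const float* in, float* out) {
    __shared__ float tile[32][33];   // 33*128 bytes; the padding column keeps the stores conflict-free
    const int lane = threadIdx.x;

    // N = 32 stores per lane: lane L writes its own row L.
    for (int i = 0; i < 32; ++i) tile[lane][i] = in[lane * 32 + i];
    __syncwarp();                    // make the stores visible to the other lanes of the warp

    // N = 32 loads and 32 additions per lane: lane i walks down column i.
    float sum = 0.0f;
    for (int owner = 0; owner < 32; ++owner) sum += tile[owner][lane];
    out[lane] = sum;
}

// Hypothetical host helper: value i of lane L is L + i, so column i sums to 496 + 32*i.
void run_column_sums(float* h_out) {
    float h_in[32 * 32];
    for (int l = 0; l < 32; ++l)
        for (int i = 0; i < 32; ++i) h_in[l * 32 + i] = float(l + i);
    float *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, 32 * sizeof(float));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    column_sums_32<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, 32 * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_in);
    cudaFree(d_out);
}
```

No shuffles are involved at all; the cost moves entirely into the 32 stores and 32 loads per lane, which is the LD/ST traffic mentioned above.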

Nice!

Here is a first stab at it. The function distance325 loads the value into an element of ssd. This is for a specialized brute-force k-NN implementation, so I did not write it in the most general way. However, it should be clear what needs to happen to make it work in more general circumstances.

The result of the posted code, for each warp, is that the distance squared between query vector i and the training vector is held in ssd[0] by all threads such that (threadIdx.x % 4 == i).

I haven’t profiled this section of the code yet, but it is the hottest part of my entire code base (including CPU code). It is being run by every warp in every block inside of four loops, and I am register limited.

register float ssd[3];

distance325(&ssd[0], query[0], s_training[trainingOffset]);
distance325(&ssd[1], query[1], s_training[trainingOffset]);
ssd[0] += __shfl_xor(ssd[0], 1);
ssd[1] += __shfl_xor(ssd[1], 1);
if (threadIdx.x & 1) {
    ssd[0] = ssd[1];
}

distance325(&ssd[1], query[2], s_training[trainingOffset]);
distance325(&ssd[2], query[3], s_training[trainingOffset]);
ssd[1] += __shfl_xor(ssd[1], 1);
ssd[2] += __shfl_xor(ssd[2], 1);
if (threadIdx.x & 1) {
    ssd[1] = ssd[2];
}

ssd[0] += __shfl_xor(ssd[0], 2);
ssd[1] += __shfl_xor(ssd[1], 2);
if (threadIdx.x & 2) {
    ssd[0] = ssd[1];
}

ssd[0] += __shfl_xor(ssd[0], 4);
ssd[0] += __shfl_xor(ssd[0], 8);
ssd[0] += __shfl_xor(ssd[0], 16);

EDIT: In case anyone finds this on Google, the code below is what ended up being the fastest for me (in trying to compute a large number of Hamming distances between 2048-bit vectors). Note that each bitvector is stored as 2 Int32s in each of the 32 threads in the warp. Because the Hamming weight is bounded above (at most 2048, which fits comfortably in 16 bits), we can safely pack two values into each dist variable.

// The compiler throws a hissy fit if you try to make dist an array, and tosses everything into local memory.
                    register int dist0, dist1, dist2, dist3, dist4, dist5, dist6, dist7;
                    // Also, the compiler does not like this being in a (fully unrolled) loop... drama queen.
                    dist0 = __popc(query[0][0] ^ train[0]) + __popc(query[0][1] ^ train[1]);
                    dist1 = __popc(query[1][0] ^ train[0]) + __popc(query[1][1] ^ train[1]);
                    dist2 = __popc(query[2][0] ^ train[0]) + __popc(query[2][1] ^ train[1]);
                    dist3 = __popc(query[3][0] ^ train[0]) + __popc(query[3][1] ^ train[1]);
                    dist4 = __popc(query[4][0] ^ train[0]) + __popc(query[4][1] ^ train[1]);
                    dist5 = __popc(query[5][0] ^ train[0]) + __popc(query[5][1] ^ train[1]);
                    dist6 = __popc(query[6][0] ^ train[0]) + __popc(query[6][1] ^ train[1]);
                    dist7 = __popc(query[7][0] ^ train[0]) + __popc(query[7][1] ^ train[1]);
                    dist0 |= (__popc(query[ 8][0] ^ train[0]) + __popc(query[ 8][1] ^ train[1]))<<16;
                    dist1 |= (__popc(query[ 9][0] ^ train[0]) + __popc(query[ 9][1] ^ train[1]))<<16;
                    dist2 |= (__popc(query[10][0] ^ train[0]) + __popc(query[10][1] ^ train[1]))<<16;
                    dist3 |= (__popc(query[11][0] ^ train[0]) + __popc(query[11][1] ^ train[1]))<<16;
                    dist4 |= (__popc(query[12][0] ^ train[0]) + __popc(query[12][1] ^ train[1]))<<16;
                    dist5 |= (__popc(query[13][0] ^ train[0]) + __popc(query[13][1] ^ train[1]))<<16;
                    dist6 |= (__popc(query[14][0] ^ train[0]) + __popc(query[14][1] ^ train[1]))<<16;
                    dist7 |= (__popc(query[15][0] ^ train[0]) + __popc(query[15][1] ^ train[1]))<<16;

                    dist0 += __shfl_xor(dist0,   1);
                    dist1 += __shfl_xor(dist1,   1);
                    if (threadIdx.x & 1) dist0 = dist1;
                    dist2 += __shfl_xor(dist2,   1);
                    dist3 += __shfl_xor(dist3,   1);
                    if (threadIdx.x & 1) dist2 = dist3;
                    dist4 += __shfl_xor(dist4,   1);
                    dist5 += __shfl_xor(dist5,   1);
                    if (threadIdx.x & 1) dist4 = dist5;
                    dist6 += __shfl_xor(dist6,   1);
                    dist7 += __shfl_xor(dist7,   1);
                    if (threadIdx.x & 1) dist6 = dist7;
                    dist0 += __shfl_xor(dist0,   2);
                    dist2 += __shfl_xor(dist2,   2);
                    if (threadIdx.x & 2) dist0 = dist2;
                    dist4 += __shfl_xor(dist4,   2);
                    dist6 += __shfl_xor(dist6,   2);
                    if (threadIdx.x & 2) dist4 = dist6;
                    dist0 += __shfl_xor(dist0,   4);
                    dist4 += __shfl_xor(dist4,   4);
                    if (threadIdx.x & 4) dist0 = dist4;
                    dist0 += __shfl_xor(dist0,   8);
                    dist0 += __shfl_xor(dist0,  16);
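                    // Unpack: lanes 0-7 keep the low 16 bits (queries 0-7), the rest take the high 16 bits.
                    // Mask with 0xFFFF rather than 2047: a distance of exactly 2048 does not survive "& 2047".
                    if (threadIdx.x < 8) dist0 &= 0xFFFF;
                    else dist0 >>= 16;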

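Threads 0 through 15 now have the Hamming distance between the bitvectors query[laneID] and train stored in dist0. (As written, this relies on threadIdx.x being the lane ID, i.e. warps are not laid out along x beyond the first 32 threads.) This is several times faster than any alternative I am aware of.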
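The packing trick on its own can be isolated as follows. This is a sketch under stated assumptions, not the code above: it uses unsigned arithmetic and CUDA 9+ `__shfl_xor_sync`, the names `warp_reduce_packed`, `packed_kernel` and `run_packed` are invented for the example, and it is only valid while each of the two warp-wide sums stays below 65536 so that the low half never carries into the high half.

```cuda
#include <cuda_runtime.h>

// Reduce two small counters per lane with a single chain of 5 shuffles.
// Returns the packed result: low 16 bits = sum of lo, high 16 bits = sum of hi.
__device__ unsigned warp_reduce_packed(unsigned lo, unsigned hi) {
    unsigned packed = lo | (hi << 16);
    for (int m = 1; m < 32; m <<= 1) {
        packed += __shfl_xor_sync(0xffffffffu, packed, m);
    }
    return packed;
}

__global__ void packed_kernel(unsigned* out) {
    const unsigned lane = threadIdx.x & 31;
    // lo: 0..31 (sums to 496); hi: 64 per lane, the per-lane maximum for two 32-bit words (sums to 2048).
    const unsigned packed = warp_reduce_packed(lane, 64u);
    if (lane == 0) {
        out[0] = packed & 0xFFFFu;  // unpack the low counter
        out[1] = packed >> 16;      // unpack the high counter
    }
}

// Hypothetical host helper: writes the two unpacked sums into h_out[0..1].
void run_packed(unsigned* h_out) {
    unsigned* d_out = nullptr;
    cudaMalloc(&d_out, 2 * sizeof(unsigned));
    packed_kernel<<<1, 32>>>(d_out);
    cudaMemcpy(h_out, d_out, 2 * sizeof(unsigned), cudaMemcpyDeviceToHost);
    cudaFree(d_out);
}
```

The high counter reaching exactly 2048 here is the same edge case that makes a 16-bit mask the safer choice when unpacking.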