Is there an efficient pattern for multiple independent reductions in a single warp?

Let us say each thread in a warp has an array of N floats (32*N floats in total). I need to compute, for each index i, the sum of element i across all 32 lanes of the warp, leaving N floats after the reduction; none of the N outputs shares a data dependency with the others.

In the special (typical) case of N=1, I would just __shfl_down() by a stride of 1,2,4,8,16 with each followed by an addition. This means 5 shuffles, and 5 additions. (However, the last shuffle only has 1 thread performing useful work, of course.)

For the N>1 case, is it possible to perform N full-warp reductions in less than 5N shuffles and 5N additions? Let us assume for now that I am flexible with where my data ends up.

Great question. I puzzled over something similar a long time ago and am trying to refresh my memory.

I think the answer is “yes”: you can do better than 5N shuffles and 5N additions.

The quick sketch for N=2:

row0 += shfl_xor(row0,1)
row1 += shfl_xor(row1,1)
row01 = (lane_idx & 1) == 0 ? row0 : row1

→ Neighboring lanes from row0 are in even lanes and row1 neighbors are in odd lanes.


row01 += shfl_xor(row01,2)

→ Half of the lanes are now redundant and we can work on rows 2 and 3 after we partially reduce them… With clever shuffling you might be able to get them in the right position to just drop into the best “free” slots in the row01 warp when you build row0123.

( hand waving at this point because it feels like it’s going to work for N > 2 )

Here’s a diagram of what I’m talking about with N=2 and warps that are only 4 lanes wide:

_row 0_   _row 1_
0 1 2 3   4 5 6 7 
 X   X     X   X   shfl_xor(1)
0 1 2 3   4 5 6 7 
1 0 3 2   5 4 7 6  +
   \         /     
    \       /      "select" with ternary operator on even/odd                     
     0 5 2 7                            
     1 4 3 6                            
      \ X /        shfl_xor(2)          
       V V                              
       0 5                              
       1 4                              
       2 7                              
       3 6         +

Whatever the case, I would draw this out on paper or diagram it if you have a favorite app.

The routines you would develop for N=2-32 would be:

float csp256_warp_reduce_2(float v0, float v1);
float csp256_warp_reduce_3(float v0, float v1, float v2);

You would have to document where the sums end up for each case—most likely they’ll be in the first N lanes of the warp (?).

Thanks! After writing it down with two different color pens, I see how it works and can verify that it is more efficient (on paper). That is a very clever solution.

I might actually come back and post a detailed solution / implementation, later.

PS: Your blog and numerous posts have been helpful to me in the past, so thank you for that too.

@allanmac: +1 for the use of ASCII art :-)

Thanks @csp256! A working implementation would be cool to see.

@njuffa: Emacs meta-x picture-mode :)

I had the same problem and changed the whole algorithm to produce 32 data lines before starting to summarize them. So with N=32 the entire operation is just 32 additions, plus 33*128 bytes of shared memory for holding the data. I even decided to dedicate a separate warp to this process: 4-8 warps produce data into the first memory array while, at the same time, a 9th warp summarizes the data in the second array and writes the sums to memory.

With N=16, you need to perform 16 full-warp additions and then combine two halves of a warp (SHFL+ADD)

With N=8, you are going to perform 8 ADDs followed by 2*(SHFL+ADD)

N=5 is probably best dealt with as the sum of N=1 and N=4

However, I completely ignored the LD/ST operation count, which will essentially dominate the ALU ops in such an approach: it's N stores followed by N loads.


Here is a first stab at it. The function distance325 loads the value into an element of ssd. This is for a specialized brute-force k-NN implementation, so I did not write it in the most general way. However, it should be clear what needs to happen to make it work under more arbitrary circumstances.

The result of the posted code, for each warp, is that the distance squared between query vector i and the training vector is held in ssd[0] by all threads such that (threadIdx.x % 4 == i).

I haven’t profiled this section of the code yet, but it is the hottest part of my entire code base (including CPU code). It is being run by every warp in every block inside of four loops, and I am register limited.

register float ssd[3];

distance325(&ssd[0], query[0], s_training[trainingOffset]);
distance325(&ssd[1], query[1], s_training[trainingOffset]);
ssd[0] += __shfl_xor(ssd[0], 1);
ssd[1] += __shfl_xor(ssd[1], 1);
if (threadIdx.x & 1) {
    ssd[0] = ssd[1];
}

distance325(&ssd[1], query[2], s_training[trainingOffset]);
distance325(&ssd[2], query[3], s_training[trainingOffset]);
ssd[1] += __shfl_xor(ssd[1], 1);
ssd[2] += __shfl_xor(ssd[2], 1);
if (threadIdx.x & 1) {
    ssd[1] = ssd[2];
}

ssd[0] += __shfl_xor(ssd[0], 2);
ssd[1] += __shfl_xor(ssd[1], 2);
if (threadIdx.x & 2) {
    ssd[0] = ssd[1];
}

ssd[0] += __shfl_xor(ssd[0], 4);
ssd[0] += __shfl_xor(ssd[0], 8);
ssd[0] += __shfl_xor(ssd[0], 16);

EDIT: In case anyone finds this on Google, the code below is what ended up being the fastest for me (for computing a large number of Hamming distances between 2048-bit vectors). Note that each bitvector is stored as 2 Int32’s in each of the 32 threads of the warp. Because the Hamming distance is bounded above by 2048, well below 2^16, we can safely pack two values into each dist variable without the low half ever carrying into the high half.

// The compiler throws a hissy fit if you try to make dist an array, and tosses everything into local memory.
register int dist0, dist1, dist2, dist3, dist4, dist5, dist6, dist7;
// Also, the compiler does not like this being in a (fully unrolled) loop... drama queen.
dist0 = __popc(query[0][0] ^ train[0]) + __popc(query[0][1] ^ train[1]);
dist1 = __popc(query[1][0] ^ train[0]) + __popc(query[1][1] ^ train[1]);
dist2 = __popc(query[2][0] ^ train[0]) + __popc(query[2][1] ^ train[1]);
dist3 = __popc(query[3][0] ^ train[0]) + __popc(query[3][1] ^ train[1]);
dist4 = __popc(query[4][0] ^ train[0]) + __popc(query[4][1] ^ train[1]);
dist5 = __popc(query[5][0] ^ train[0]) + __popc(query[5][1] ^ train[1]);
dist6 = __popc(query[6][0] ^ train[0]) + __popc(query[6][1] ^ train[1]);
dist7 = __popc(query[7][0] ^ train[0]) + __popc(query[7][1] ^ train[1]);
dist0 |= (__popc(query[ 8][0] ^ train[0]) + __popc(query[ 8][1] ^ train[1]))<<16;
dist1 |= (__popc(query[ 9][0] ^ train[0]) + __popc(query[ 9][1] ^ train[1]))<<16;
dist2 |= (__popc(query[10][0] ^ train[0]) + __popc(query[10][1] ^ train[1]))<<16;
dist3 |= (__popc(query[11][0] ^ train[0]) + __popc(query[11][1] ^ train[1]))<<16;
dist4 |= (__popc(query[12][0] ^ train[0]) + __popc(query[12][1] ^ train[1]))<<16;
dist5 |= (__popc(query[13][0] ^ train[0]) + __popc(query[13][1] ^ train[1]))<<16;
dist6 |= (__popc(query[14][0] ^ train[0]) + __popc(query[14][1] ^ train[1]))<<16;
dist7 |= (__popc(query[15][0] ^ train[0]) + __popc(query[15][1] ^ train[1]))<<16;

dist0 += __shfl_xor(dist0,   1);
dist1 += __shfl_xor(dist1,   1);
if (threadIdx.x & 1) dist0 = dist1;
dist2 += __shfl_xor(dist2,   1);
dist3 += __shfl_xor(dist3,   1);
if (threadIdx.x & 1) dist2 = dist3;
dist4 += __shfl_xor(dist4,   1);
dist5 += __shfl_xor(dist5,   1);
if (threadIdx.x & 1) dist4 = dist5;
dist6 += __shfl_xor(dist6,   1);
dist7 += __shfl_xor(dist7,   1);
if (threadIdx.x & 1) dist6 = dist7;
dist0 += __shfl_xor(dist0,   2);
dist2 += __shfl_xor(dist2,   2);
if (threadIdx.x & 2) dist0 = dist2;
dist4 += __shfl_xor(dist4,   2);
dist6 += __shfl_xor(dist6,   2);
if (threadIdx.x & 2) dist4 = dist6;
dist0 += __shfl_xor(dist0,   4);
dist4 += __shfl_xor(dist4,   4);
if (threadIdx.x & 4) dist0 = dist4;
dist0 += __shfl_xor(dist0,   8);
dist0 += __shfl_xor(dist0,  16);
if (threadIdx.x < 8) dist0 &= 2047;
else dist0 >>= 16;

Threads 0 through 15 now have the Hamming distance between the bitvectors query[laneID] and train stored in dist0. This is several times faster than any alternative I am aware of.