I crafted a benchmark for median computation over segments of 40 elements each, with 1 million segments, using 1 warp per segment and 2 elements per thread. It compares cub::BlockRadixSort, cub::WarpMergeSort, and a sorting network for this thread configuration taken from the linked repository (https://gitlab.rlp.net/pararch/faster-segmented-sort-on-gpus/-/blob/main/gtx_1080_ti_keys_only/my_regsort_kernels.h?ref_type=heads#L4582)
Using CUDA 12.2 on an A100, I observe the following timings:
medianKernel_sortingnetwork: 0.000992722s
medianKernel_cubradixsort: 0.00561694s
medianKernel_cubmergesort: 0.00165363s
I did not check other thread configurations for 64 elements, so I cannot tell whether 32 threads × 2 elements per thread is the optimal configuration for any of the three algorithms.
//nvcc -std=c++17 -O3 -arch=sm_80 --expt-relaxed-constexpr main.cu -o main
#include <cub/cub.cuh>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <vector>
#include <numeric>
#include <iostream>
#include <algorithm>
#include <cassert>
#include <chrono>
#include <random>
//Compare-and-swap: orders the pair (_a,_b) ascending. t1 is the element type
//used for the temporary.
#define CMP_SWP(t1,_a,_b) if(_a>_b) {t1 _t=_a;_a=_b;_b=_t;}
//Swap _a and _b unless they are equal. Applied after CMP_SWP this yields the
//pair in descending order; used for the reversed halves of the network stages.
#define EQL_SWP(t1,_a,_b) if(_a!=_b) {t1 _t=_a;_a=_b;_b=_t;}
//Unconditional swap of _a and _b.
#define SWP(t1,_a,_b) {t1 _t=_a;_a=_b;_b=_t;}
template<class K>
//One inter-lane exchange step of the warp sorting network (from the
//faster-segmented-sort repository linked at the top of the file). Each lane
//ships its k1 to the partner lane XOR `mask` lanes away, compare-swaps the
//received value against its own k0, and ships the result back. When `bit` is
//set, the pair is additionally re-swapped via EQL_SWP, producing the opposite
//(descending) order for that half of the butterfly. All 32 lanes of the warp
//must participate (full 0xffffffff mask) — must not be called from
//warp-divergent code.
__device__ inline void exch_intxn(K &k0, K &k1, int mask, const int bit) {
K ex_k0, ex_k1;
ex_k0 = k0;
//fetch the partner lane's k1
ex_k1 = __shfl_xor_sync(0xffffffff,k1, mask);
CMP_SWP(K, ex_k0, ex_k1);
//descending half of the stage: reverse the order CMP_SWP established
if(bit) EQL_SWP(K, ex_k0, ex_k1);
k0 = ex_k0;
//return the exchanged value to the partner lane
k1 = __shfl_xor_sync(0xffffffff,ex_k1, mask);
}
template<class K>
//Parallel-exchange step of the warp sorting network, the companion to
//exch_intxn. When `bit` is set the lane first swaps its two registers so that
//the value routed through the shuffle is the one this stage must compare, and
//swaps back afterwards; otherwise the structure matches exch_intxn. All 32
//lanes must participate (full 0xffffffff mask) — must not be called from
//warp-divergent code.
__device__ inline void exch_paral(K &k0, K &k1, int mask, const int bit) {
K ex_k0, ex_k1;
//select which register takes part in the cross-lane exchange
if(bit) SWP(K, k0, k1);
ex_k0 = k0;
ex_k1 = __shfl_xor_sync(0xffffffff,k1, mask);
CMP_SWP(K, ex_k0, ex_k1);
//descending half of the stage: reverse the order CMP_SWP established
if(bit) EQL_SWP(K, ex_k0, ex_k1);
k0 = ex_k0;
k1 = __shfl_xor_sync(0xffffffff,ex_k1, mask);
//restore the original register roles
if(bit) SWP(K, k0, k1);
}
template<class K>
__device__
//Sorts 64 keys held 2-per-lane across a full 32-lane warp into ascending
//order, using a sorting network generated for this exact 32x2 configuration
//(taken from the faster-segmented-sort repository linked at the top of the
//file). The median kernels below read sorted element i from register i%2 of
//lane i/2, i.e. the network leaves the keys in a blocked sorted layout.
//Every lane of the warp must call this (no divergence) since the exchange
//helpers use full-warp shuffles. The stage order below is the network itself
//— do not reorder.
void sortnetwork64_32_2(
K& rg_k0,
K& rg_k1
){
const int tid = (threadIdx.x & 31);
//bitN selects the descending half of each butterfly stage of width 2^N
const int bit1 = (tid>>0)&0x1;
const int bit2 = (tid>>1)&0x1;
const int bit3 = (tid>>2)&0x1;
const int bit4 = (tid>>3)&0x1;
const int bit5 = (tid>>4)&0x1;
//sort the 2 in-register keys of each lane
CMP_SWP(K, rg_k0, rg_k1);
//merge sorted runs of 2 into runs of 4
exch_intxn(rg_k0, rg_k1, 0x1, bit1);
CMP_SWP(K, rg_k0, rg_k1);
//merge runs of 4 into runs of 8
exch_intxn(rg_k0, rg_k1, 0x3, bit2);
exch_paral(rg_k0, rg_k1, 0x1, bit1);
CMP_SWP(K, rg_k0, rg_k1);
//merge runs of 8 into runs of 16
exch_intxn(rg_k0, rg_k1, 0x7, bit3);
exch_paral(rg_k0, rg_k1, 0x2, bit2);
exch_paral(rg_k0, rg_k1, 0x1, bit1);
CMP_SWP(K, rg_k0, rg_k1);
//merge runs of 16 into runs of 32
exch_intxn(rg_k0, rg_k1, 0xf, bit4);
exch_paral(rg_k0, rg_k1, 0x4, bit3);
exch_paral(rg_k0, rg_k1, 0x2, bit2);
exch_paral(rg_k0, rg_k1, 0x1, bit1);
CMP_SWP(K, rg_k0, rg_k1);
//final merge: runs of 32 into one sorted run of 64
exch_intxn(rg_k0, rg_k1, 0x1f, bit5);
exch_paral(rg_k0, rg_k1, 0x8, bit4);
exch_paral(rg_k0, rg_k1, 0x4, bit3);
exch_paral(rg_k0, rg_k1, 0x2, bit2);
exch_paral(rg_k0, rg_k1, 0x1, bit1);
CMP_SWP(K, rg_k0, rg_k1);
}
//Computes the median of each segment into medianOut, sorting with the warp
//sorting network. Launch config: one warp (block of 32 threads) per segment,
//gridDim.x >= numSegments; 2 elements per thread, hence max segment size 64.
//Out-of-range slots are padded with INT_MAX so they sort to the top. Empty
//segments (beginOffsets[i] == endOffsets[i]) produce a median of 0.
__global__
void medianKernel_sortingnetwork(float* medianOut, const int* data, const int* beginOffsets, const int* endOffsets, int numSegments){
__builtin_assume(blockDim.x == 32);
const int warpId = blockIdx.x;
if(warpId < numSegments){
const int segmentBegin = beginOffsets[warpId];
const int segmentEnd = endOffsets[warpId];
const int segmentSize = segmentEnd - segmentBegin;
int items[2];
//coalesced striped load: lane t receives elements t and t+32
cub::LoadDirectStriped<32>(threadIdx.x, data + segmentBegin, items, segmentSize, std::numeric_limits<int>::max());
sortnetwork64_32_2(items[0], items[1]);
//after sorting, global sorted index i resides in register i%2 of lane i/2
float median = 0;
//segmentSize is warp-uniform, so branching here keeps the warp converged
//for the full-mask shuffles below
if(segmentSize == 0){
//fix: an empty segment previously fell into the even branch and computed
//its "median" from the INT_MAX padding; keep median = 0 instead
}else if(segmentSize % 2 == 1){
//odd count: broadcast the single middle element from its owning lane
const int medianIndex = segmentSize / 2;
const int medianThread = medianIndex / 2;
const int medianIndexInThread = medianIndex % 2;
const int src = (medianIndexInThread == 0) ? items[0] : items[1];
median = __shfl_sync(0xFFFFFFFF, src, medianThread);
}else{
//even count: average the two middle elements
float median_l = 0;
float median_r = 0;
const int medianIndex_r = segmentSize / 2;
const int medianThread_r = medianIndex_r / 2;
const int medianIndexInThread_r = medianIndex_r % 2;
const int src_r = (medianIndexInThread_r == 0) ? items[0] : items[1];
median_r = __shfl_sync(0xFFFFFFFF, src_r, medianThread_r);
const int medianIndex_l = (segmentSize-1) / 2;
const int medianThread_l = medianIndex_l / 2;
const int medianIndexInThread_l = medianIndex_l % 2;
const int src_l = (medianIndexInThread_l == 0) ? items[0] : items[1];
median_l = __shfl_sync(0xFFFFFFFF, src_l, medianThread_l);
median = (median_l + median_r) / 2.0f;
}
if(threadIdx.x == 0){
medianOut[warpId] = median;
}
}
}
//Computes the median of each segment into medianOut, sorting with
//cub::BlockRadixSort. Launch config: one warp (block of 32 threads) per
//segment, gridDim.x >= numSegments; 2 elements per thread, hence max segment
//size 64. Out-of-range slots are padded with INT_MAX so they sort to the top.
//Empty segments produce a median of 0. Requires the BlockRadixSort shared
//TempStorage declared below.
__global__
void medianKernel_cubradixsort(float* medianOut, const int* data, const int* beginOffsets, const int* endOffsets, int numSegments){
__builtin_assume(blockDim.x == 32);
using BlockRadixSort = cub::BlockRadixSort<int, 32, 2>;
__shared__ typename BlockRadixSort::TempStorage tempSort;
const int warpId = blockIdx.x;
if(warpId < numSegments){
const int segmentBegin = beginOffsets[warpId];
const int segmentEnd = endOffsets[warpId];
const int segmentSize = segmentEnd - segmentBegin;
int items[2];
//coalesced striped load: lane t receives elements t and t+32
cub::LoadDirectStriped<32>(threadIdx.x, data + segmentBegin, items, segmentSize, std::numeric_limits<int>::max());
//Sort leaves a blocked arrangement: sorted index i in register i%2 of lane i/2
BlockRadixSort(tempSort).Sort(items);
float median = 0;
//segmentSize is warp-uniform, so branching here keeps the warp converged
//for the full-mask shuffles below
if(segmentSize == 0){
//fix: an empty segment previously fell into the even branch and computed
//its "median" from the INT_MAX padding; keep median = 0 instead
}else if(segmentSize % 2 == 1){
//odd count: broadcast the single middle element from its owning lane
const int medianIndex = segmentSize / 2;
const int medianThread = medianIndex / 2;
const int medianIndexInThread = medianIndex % 2;
const int src = (medianIndexInThread == 0) ? items[0] : items[1];
median = __shfl_sync(0xFFFFFFFF, src, medianThread);
}else{
//even count: average the two middle elements
float median_l = 0;
float median_r = 0;
const int medianIndex_r = segmentSize / 2;
const int medianThread_r = medianIndex_r / 2;
const int medianIndexInThread_r = medianIndex_r % 2;
const int src_r = (medianIndexInThread_r == 0) ? items[0] : items[1];
median_r = __shfl_sync(0xFFFFFFFF, src_r, medianThread_r);
const int medianIndex_l = (segmentSize-1) / 2;
const int medianThread_l = medianIndex_l / 2;
const int medianIndexInThread_l = medianIndex_l % 2;
const int src_l = (medianIndexInThread_l == 0) ? items[0] : items[1];
median_l = __shfl_sync(0xFFFFFFFF, src_l, medianThread_l);
median = (median_l + median_r) / 2.0f;
}
if(threadIdx.x == 0){
medianOut[warpId] = median;
}
}
}
//Computes the median of each segment into medianOut, sorting with
//cub::WarpMergeSort. Launch config: one warp (block of 32 threads) per
//segment, gridDim.x >= numSegments; 2 elements per thread, hence max segment
//size 64. Out-of-range slots are padded with INT_MAX so they sort to the top.
//Empty segments produce a median of 0. Requires the WarpMergeSort shared
//TempStorage declared below.
__global__
void medianKernel_cubmergesort(float* medianOut, const int* data, const int* beginOffsets, const int* endOffsets, int numSegments){
__builtin_assume(blockDim.x == 32);
using WarpMergeSort = cub::WarpMergeSort<int, 2>;
__shared__ typename WarpMergeSort::TempStorage tempSort;
const int warpId = blockIdx.x;
if(warpId < numSegments){
const int segmentBegin = beginOffsets[warpId];
const int segmentEnd = endOffsets[warpId];
const int segmentSize = segmentEnd - segmentBegin;
int items[2];
//coalesced striped load: lane t receives elements t and t+32
cub::LoadDirectStriped<32>(threadIdx.x, data + segmentBegin, items, segmentSize, std::numeric_limits<int>::max());
//Sort leaves a blocked arrangement: sorted index i in register i%2 of lane i/2
WarpMergeSort(tempSort).Sort(items, std::less<int>{});
float median = 0;
//segmentSize is warp-uniform, so branching here keeps the warp converged
//for the full-mask shuffles below
if(segmentSize == 0){
//fix: an empty segment previously fell into the even branch and computed
//its "median" from the INT_MAX padding; keep median = 0 instead
}else if(segmentSize % 2 == 1){
//odd count: broadcast the single middle element from its owning lane
const int medianIndex = segmentSize / 2;
const int medianThread = medianIndex / 2;
const int medianIndexInThread = medianIndex % 2;
const int src = (medianIndexInThread == 0) ? items[0] : items[1];
median = __shfl_sync(0xFFFFFFFF, src, medianThread);
}else{
//even count: average the two middle elements
float median_l = 0;
float median_r = 0;
const int medianIndex_r = segmentSize / 2;
const int medianThread_r = medianIndex_r / 2;
const int medianIndexInThread_r = medianIndex_r % 2;
const int src_r = (medianIndexInThread_r == 0) ? items[0] : items[1];
median_r = __shfl_sync(0xFFFFFFFF, src_r, medianThread_r);
const int medianIndex_l = (segmentSize-1) / 2;
const int medianThread_l = medianIndex_l / 2;
const int medianIndexInThread_l = medianIndex_l % 2;
const int src_l = (medianIndexInThread_l == 0) ? items[0] : items[1];
median_l = __shfl_sync(0xFFFFFFFF, src_l, medianThread_l);
median = (median_l + median_r) / 2.0f;
}
if(threadIdx.x == 0){
medianOut[warpId] = median;
}
}
}
//Benchmark driver: generates 1 million segments of 40 random ints each,
//computes per-segment medians on the CPU as a reference, then runs, times and
//validates the three GPU median kernels. Returns nonzero if a CUDA error is
//detected. Fixes vs. the original: steady_clock instead of the non-monotonic
//system_clock, a device sync before each GPU timer starts (so pending async
//work such as thrust::fill is not charged to the next kernel), and CUDA error
//checking after each launch loop.
int main(){
int numSegments = 1000000;
int elementsPerSegment = 40;
const int timingIterations = 10;
std::mt19937 gen(42);
std::uniform_int_distribution<> distrib(0, 65536);
std::vector<int> data(numSegments * elementsPerSegment);
std::generate(data.begin(), data.end(), [&](){return distrib(gen);});
//segment i spans [offsets[i], offsets[i+1])
std::vector<int> offsets(numSegments+1);
for(int i = 0; i < numSegments+1; i++){
offsets[i] = elementsPerSegment * i;
}
//copy to the device before the CPU reference below mutates `data`
//(std::nth_element partially sorts in place)
thrust::device_vector<int> d_data = data;
thrust::device_vector<int> d_offsets = offsets;
thrust::device_vector<float> d_medianOut(numSegments);
std::vector<float> medianOutGpu(numSegments);
std::vector<float> medianOutCpu(numSegments);
//drains pending device work and reports any outstanding CUDA error
//(launch-config errors via cudaGetLastError, async errors via the sync);
//returns false after printing on failure
auto checkCuda = [](const char* context) -> bool {
cudaError_t status = cudaGetLastError();
if(status == cudaSuccess){
status = cudaDeviceSynchronize();
}
if(status != cudaSuccess){
std::cout << context << ": CUDA error: " << cudaGetErrorString(status) << "\n";
return false;
}
return true;
};
//steady_clock is monotonic; system_clock can jump (e.g. NTP adjustments)
auto timebegin = std::chrono::steady_clock::now();
for(int iter = 0; iter < timingIterations; iter++){
//NOTE: after the first iteration the segments are already partially
//sorted by nth_element, so later iterations may run faster; the reported
//CPU average is therefore optimistic
for(int i = 0; i < numSegments; i++){
const int segmentBegin = offsets[i];
const int segmentEnd = offsets[i+1];
const int segmentSize = segmentEnd - segmentBegin;
if (segmentSize % 2 == 0) {
//even count: median is the mean of the two middle elements
std::nth_element(
data.data() + segmentBegin,
data.data() + segmentBegin + segmentSize / 2,
data.data() + segmentEnd
);
std::nth_element(
data.data() + segmentBegin,
data.data() + segmentBegin + (segmentSize - 1) / 2,
data.data() + segmentEnd
);
medianOutCpu[i] = (data[segmentBegin + (segmentSize - 1) / 2] + data[segmentBegin + segmentSize / 2]) / 2.0;
}else{
std::nth_element(
data.data() + segmentBegin,
data.data() + segmentBegin + segmentSize / 2,
data.data() + segmentEnd
);
medianOutCpu[i] = data[segmentBegin + segmentSize / 2];
}
}
}
auto timeend = std::chrono::steady_clock::now();
std::chrono::duration<double> delta = timeend - timebegin;
std::cout << "median CPU: " << delta.count() / timingIterations << "s\n";
//make sure no earlier device work leaks into the first GPU timing
if(!checkCuda("before medianKernel_sortingnetwork")) return 1;
timebegin = std::chrono::steady_clock::now();
for(int iter = 0; iter < timingIterations; iter++){
medianKernel_sortingnetwork<<<numSegments, 32>>>(
d_medianOut.data().get(),
d_data.data().get(),
d_offsets.data().get(),
d_offsets.data().get() + 1,
numSegments
);
}
if(!checkCuda("medianKernel_sortingnetwork")) return 1;
timeend = std::chrono::steady_clock::now();
delta = timeend - timebegin;
std::cout << "medianKernel_sortingnetwork: " << delta.count() / timingIterations << "s\n";
thrust::copy(d_medianOut.begin(), d_medianOut.end(), medianOutGpu.begin());
for(int i = 0; i < numSegments; i++){
if(std::abs(medianOutGpu[i] - medianOutCpu[i]) > 1e-5){
std::cout << "medianKernel_sortingnetwork error segment " << i << " " << medianOutGpu[i] << " " << medianOutCpu[i] << "\n";
break;
}
}
thrust::fill(d_medianOut.begin(), d_medianOut.end(), 0);
//thrust::fill is asynchronous; sync so it is not charged to the next timing
if(!checkCuda("before medianKernel_cubradixsort")) return 1;
timebegin = std::chrono::steady_clock::now();
for(int iter = 0; iter < timingIterations; iter++){
medianKernel_cubradixsort<<<numSegments, 32>>>(
d_medianOut.data().get(),
d_data.data().get(),
d_offsets.data().get(),
d_offsets.data().get() + 1,
numSegments
);
}
if(!checkCuda("medianKernel_cubradixsort")) return 1;
timeend = std::chrono::steady_clock::now();
delta = timeend - timebegin;
std::cout << "medianKernel_cubradixsort: " << delta.count() / timingIterations << "s\n";
thrust::copy(d_medianOut.begin(), d_medianOut.end(), medianOutGpu.begin());
for(int i = 0; i < numSegments; i++){
if(std::abs(medianOutGpu[i] - medianOutCpu[i]) > 1e-5){
std::cout << "medianKernel_cubradixsort error segment " << i << " " << medianOutGpu[i] << " " << medianOutCpu[i] << "\n";
break;
}
}
thrust::fill(d_medianOut.begin(), d_medianOut.end(), 0);
//thrust::fill is asynchronous; sync so it is not charged to the next timing
if(!checkCuda("before medianKernel_cubmergesort")) return 1;
timebegin = std::chrono::steady_clock::now();
for(int iter = 0; iter < timingIterations; iter++){
medianKernel_cubmergesort<<<numSegments, 32>>>(
d_medianOut.data().get(),
d_data.data().get(),
d_offsets.data().get(),
d_offsets.data().get() + 1,
numSegments
);
}
if(!checkCuda("medianKernel_cubmergesort")) return 1;
timeend = std::chrono::steady_clock::now();
delta = timeend - timebegin;
std::cout << "medianKernel_cubmergesort: " << delta.count() / timingIterations << "s\n";
thrust::copy(d_medianOut.begin(), d_medianOut.end(), medianOutGpu.begin());
for(int i = 0; i < numSegments; i++){
if(std::abs(medianOutGpu[i] - medianOutCpu[i]) > 1e-5){
std::cout << "medianKernel_cubmergesort error segment " << i << " " << medianOutGpu[i] << " " << medianOutCpu[i] << "\n";
break;
}
}
return 0;
}