Optimizing Unsorted Segmented Reduction with a Large Number of Segments

I have an input tensor of shape BxCxN and a segment ID tensor of shape CxNx2, where each input element corresponds to a 2D segment ID stored in the last dimension. In total, there are HxW segments. All inputs along the B dimension share the same segment IDs. Since the input tensor is unsorted, any input element can belong to any segment.

For each input element, we identify its segment ID and then reduce all values that belong to the same segment (using either max or mean). The expected output is a BxCxHxW tensor, where each element is the reduction result for a specific segment.
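
In plain CPU-style code, what I want to compute is roughly the following (max case shown; row-major layouts assumed, and all names here are just for illustration):

```cpp
#include <cfloat>

// out[b][c][h][w] = max over n of in[b][c][n] where seg[c][n] == (h, w).
// Segments with no elements keep the identity -FLT_MAX in this sketch.
void segmented_max_reference(const float* in,   // B x C x N
                             const int*   seg,  // C x N x 2, (h, w) pairs
                             float*       out,  // B x C x H x W
                             int B, int C, int N, int H, int W)
{
    for (long i = 0; i < (long)B * C * H * W; ++i)
        out[i] = -FLT_MAX;                       // reduction identity

    for (int b = 0; b < B; ++b)
        for (int c = 0; c < C; ++c)
            for (int n = 0; n < N; ++n) {
                int   h = seg[(c * N + n) * 2 + 0];
                int   w = seg[(c * N + n) * 2 + 1];
                float v = in[((long)b * C + c) * N + n];
                long  o = (((long)b * C + c) * H + h) * W + w;
                if (v > out[o]) out[o] = v;
            }
}
```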

The challenge is that the number of segments is large: both H and W are around 512, making it impractical to store the entire output in shared memory.

Currently, I’m storing the output in global memory and using atomic operations so that each thread can compare and update the output value for its input elements during the reduction.
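
For reference, a simplified sketch of what I am doing for the max case (the float atomic max is emulated with atomicCAS on the bit pattern; the output is pre-filled with -FLT_MAX, and all names and launch details here are just illustrative):

```cpp
// Common emulation of an atomic max on float via compare-and-swap.
__device__ float atomicMaxFloat(float* address, float val)
{
    int* address_as_int = (int*)address;
    int old = *address_as_int, assumed;
    do {
        assumed = old;
        if (__int_as_float(assumed) >= val) break;   // already large enough
        old = atomicCAS(address_as_int, assumed, __float_as_int(val));
    } while (assumed != old);
    return __int_as_float(old);
}

// One thread per input element; out must be pre-filled with -FLT_MAX.
__global__ void segmented_max_atomic(const float* __restrict__ in,   // B x C x N
                                     const int*   __restrict__ seg,  // C x N x 2
                                     float*       __restrict__ out,  // B x C x H x W
                                     int B, int C, int N, int H, int W)
{
    long idx   = blockIdx.x * (long)blockDim.x + threadIdx.x;
    long total = (long)B * C * N;
    if (idx >= total) return;

    int n = idx % N;
    int c = (idx / N) % C;
    int b = idx / ((long)N * C);

    int h = seg[(c * N + n) * 2 + 0];
    int w = seg[(c * N + n) * 2 + 1];

    long o = (((long)b * C + c) * H + h) * W + w;
    atomicMaxFloat(&out[o], in[idx]);
}
```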

I considered assigning each thread to process only a subset of the segments, so that the temporary outputs fit within shared memory. However, this approach would multiply both the number of reads and the conditional checks and reduction operations involved. I could also do a bucket sort first, but that would increase the number of reads as well.

I’m wondering if there are more efficient alternatives to improve performance…

Thanks!

Hi silvaurus17,

I would sort the CxNx2 tensor, or even put the (C, N) positions into (around 512*512) buckets, one per 2D segment ID. Then each thread, warp, or block can process one segment ID (and the elements belonging to it) at a time.
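
With Thrust, the bucketing step could look roughly like this (just a sketch; I am assuming the linearized key h * W + w fits into an int, and every function and variable name here is my own invention):

```cpp
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/transform.h>
#include <thrust/sort.h>
#include <thrust/binary_search.h>
#include <thrust/iterator/counting_iterator.h>

// Turns an element position i = c * N + n into its linearized segment key h * W + w.
struct LinearizeSegID {
    const int* seg;   // C x N x 2 tensor, flattened
    int        W;
    __host__ __device__ int operator()(int i) const {
        return seg[2 * i] * W + seg[2 * i + 1];
    }
};

// Bucket the (c, n) positions by segment ID so that all positions belonging
// to one segment end up contiguous, with seg_begin giving the bucket offsets.
void build_buckets(const thrust::device_vector<int>& seg,      // size C*N*2
                   int C, int N, int H, int W,
                   thrust::device_vector<int>& keys,           // sorted segment keys
                   thrust::device_vector<int>& positions,      // sorted c*N + n
                   thrust::device_vector<int>& seg_begin)      // size H*W + 1
{
    int M = C * N;
    keys.resize(M);
    positions.resize(M);
    thrust::sequence(positions.begin(), positions.end());      // 0 .. M-1 == c*N + n

    thrust::counting_iterator<int> first(0);
    thrust::transform(first, first + M, keys.begin(),
                      LinearizeSegID{thrust::raw_pointer_cast(seg.data()), W});

    // After this sort, positions[] is grouped by segment key.
    thrust::sort_by_key(keys.begin(), keys.end(), positions.begin());

    // seg_begin[s] .. seg_begin[s+1] is the range of positions in segment s.
    seg_begin.resize(H * W + 1);
    thrust::lower_bound(keys.begin(), keys.end(),
                        first, first + H * W + 1, seg_begin.begin());
}
```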

Why does C appear in the output? Shouldn’t it be BxHxW, since C is used for the segment ID? Or actually only HxW, as all tensors share the same segment IDs along B?

However large the output is, if you have to compute several independent reductions within the same segment ID (e.g. a 2D array of them), you can process those reductions element by element and do not have to keep huge intermediate results in memory. And the different threads/warps/blocks (whichever you choose) never work on the same segment ID, so you do not have to synchronize or reduce across those boundaries.
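
As a rough sketch of that second part, using the sorted positions and per-segment offsets from the bucketing sketch above (I am assuming a max reduction, a power-of-two block size, and the BxHxW output from my question above; if C stays in the output, loop over (b, c) pairs instead of b):

```cpp
#include <cfloat>

// One block per segment ID; within the block, loop over b one at a time
// ("element by element"), so only one running reduction lives in shared
// memory at any moment. No atomics are needed because no two blocks ever
// touch the same segment.
__global__ void reduce_per_segment(const float* __restrict__ in,        // B x C x N
                                   const int*   __restrict__ positions, // sorted c*N + n
                                   const int*   __restrict__ seg_begin, // H*W + 1 offsets
                                   float*       __restrict__ out,       // B x H x W
                                   int B, int C, int N, int H, int W)
{
    extern __shared__ float red[];                 // blockDim.x partial maxima
    int tid   = threadIdx.x;
    int bsz   = blockDim.x;
    int s     = blockIdx.x;                        // segment ID (h * W + w)
    int begin = seg_begin[s];
    int end   = seg_begin[s + 1];

    for (int b = 0; b < B; ++b) {                  // independent reductions, one at a time
        float best = -FLT_MAX;
        for (int i = begin + tid; i < end; i += bsz) {
            int pos = positions[i];                // pos = c * N + n
            best = fmaxf(best, in[(long)b * C * N + pos]);
        }
        red[tid] = best;
        __syncthreads();

        // Standard shared-memory tree reduction within the block.
        for (int stride = bsz / 2; stride > 0; stride >>= 1) {
            if (tid < stride)
                red[tid] = fmaxf(red[tid], red[tid + stride]);
            __syncthreads();
        }
        if (tid == 0)
            out[(long)b * H * W + s] = red[0];
        __syncthreads();                           // reuse red[] safely for the next b
    }
}
```

You would launch one block per segment, e.g. reduce_per_segment<<<H * W, 256, 256 * sizeof(float)>>>(...), and segments that received no elements simply keep the identity value.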