I have an input tensor of shape BxCxN and a segment ID tensor of shape CxNx2, where each input element has a 2D segment ID stored in the last dimension. In total, there are HxW segments. All inputs along the B dimension share the same segment IDs. Since the input tensor is unsorted, any input element can belong to any segment.
For each input element, we look up its segment ID, then reduce all values that belong to the same segment (either max or mean). The expected output is a BxCxHxW tensor, where each element is the reduction result for one segment.
The challenge is that the number of segments is large: both H and W are around 512, so it is impractical to keep the entire output in shared memory.
Currently, I store the output in global memory and have each thread use atomic operations to compare and update the output value of its element's segment during the reduction.
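For reference, this is roughly what the current kernel looks like. It is a simplified sketch for the max case; float inputs, int segment IDs, row-major contiguous layouts, the kernel/helper names, and the output being pre-filled with -FLT_MAX before the launch are all assumptions on my side. atomicMaxFloat is just the usual CAS loop, since CUDA has no native float atomicMax:

```
#include <cuda_runtime.h>

// CUDA has no native atomicMax for float, so fall back to a CAS loop.
__device__ float atomicMaxFloat(float* addr, float val) {
    int* addr_as_int = reinterpret_cast<int*>(addr);
    int old = *addr_as_int;
    int assumed;
    do {
        assumed = old;
        if (__int_as_float(assumed) >= val) break;  // current value already larger
        old = atomicCAS(addr_as_int, assumed, __float_as_int(val));
    } while (assumed != old);
    return __int_as_float(old);
}

// One thread per (b, c, n) element: look up that element's (h, w) segment and
// atomically fold its value into the BxCxHxW output in global memory.
// output is assumed to be pre-filled with -FLT_MAX before the launch.
__global__ void segment_max_atomic(const float* __restrict__ input,    // B x C x N
                                   const int*   __restrict__ seg_ids,  // C x N x 2
                                   float*       __restrict__ output,   // B x C x H x W
                                   int B, int C, int N, int H, int W) {
    long long idx   = blockIdx.x * (long long)blockDim.x + threadIdx.x;
    long long total = (long long)B * C * N;
    if (idx >= total) return;

    int n = (int)(idx % N);
    int c = (int)((idx / N) % C);
    int b = (int)(idx / ((long long)N * C));

    // segment IDs are shared across the batch dimension
    int h = seg_ids[(c * N + n) * 2 + 0];
    int w = seg_ids[(c * N + n) * 2 + 1];

    float v = input[idx];
    atomicMaxFloat(&output[(((long long)b * C + c) * H + h) * W + w], v);
}
```

(For the mean, the same pattern would use a plain atomicAdd on a sum buffer plus a count buffer, then a separate pass to divide.)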
I considered assigning each thread block only a subset of the segments, so that the temporary outputs fit in shared memory. However, this approach multiplies the number of global reads as well as the conditional checks and reduction work. I could also bucket-sort the inputs by segment first, but that would increase the reads as well.
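To make the tiled idea concrete, here is a hypothetical sketch for the mean reduction (so native shared-memory atomicAdd can be used): each block owns a TILE_H x TILE_W patch of segments for one (b, c) pair and re-scans that whole length-N row, which is exactly where the extra reads and conditional checks come from. The kernel name, tile sizes, and launch layout (grid = (num_tiles, B*C)) are placeholders, not something I have tuned:

```
#include <cuda_runtime.h>

constexpr int TILE_H = 16;
constexpr int TILE_W = 16;

__global__ void segment_mean_tiled(const float* __restrict__ input,    // B x C x N
                                   const int*   __restrict__ seg_ids,  // C x N x 2
                                   float*       __restrict__ output,   // B x C x H x W
                                   int B, int C, int N, int H, int W) {
    // blockIdx.x indexes the segment tile, blockIdx.y indexes the (b, c) pair
    int tiles_w = (W + TILE_W - 1) / TILE_W;
    int tile_h0 = (blockIdx.x / tiles_w) * TILE_H;
    int tile_w0 = (blockIdx.x % tiles_w) * TILE_W;
    int b = blockIdx.y / C;
    int c = blockIdx.y % C;

    __shared__ float s_sum[TILE_H * TILE_W];
    __shared__ int   s_cnt[TILE_H * TILE_W];
    for (int i = threadIdx.x; i < TILE_H * TILE_W; i += blockDim.x) {
        s_sum[i] = 0.0f;
        s_cnt[i] = 0;
    }
    __syncthreads();

    // Every thread strides over the whole row; elements whose segment falls
    // outside this block's tile are skipped -- these are the extra reads and
    // conditional checks mentioned above.
    for (int n = threadIdx.x; n < N; n += blockDim.x) {
        int h = seg_ids[(c * N + n) * 2 + 0];
        int w = seg_ids[(c * N + n) * 2 + 1];
        if (h < tile_h0 || h >= tile_h0 + TILE_H ||
            w < tile_w0 || w >= tile_w0 + TILE_W) continue;
        float v = input[((long long)b * C + c) * N + n];
        int t = (h - tile_h0) * TILE_W + (w - tile_w0);
        atomicAdd(&s_sum[t], v);   // native float atomicAdd on shared memory
        atomicAdd(&s_cnt[t], 1);
    }
    __syncthreads();

    // Write the finished tile back to global memory; no global atomics needed.
    for (int i = threadIdx.x; i < TILE_H * TILE_W; i += blockDim.x) {
        int h = tile_h0 + i / TILE_W;
        int w = tile_w0 + i % TILE_W;
        if (h < H && w < W && s_cnt[i] > 0)
            output[(((long long)b * C + c) * H + h) * W + w] = s_sum[i] / s_cnt[i];
    }
}
```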
I’m wondering if there are more efficient alternatives to improve performance…
Thanks!