Cooperative groups are much slower than CUB

A block reduce using cooperative groups takes 3.4 times as long as CUB. I find this odd since cooperative groups are hardware accelerated.

Questions:

  • Why might cooperative groups be slower?
  • Is CUB recommended over cooperative groups?

I tested this with the two programs given below on my RTX 3060, targeting compute capability 8.0. Each performs about 67 million block reductions in a loop on a single SM.

Method               Runtime
Cooperative groups   36.531s
CUB                  10.851s

Block reduce with cooperative groups

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

constexpr auto BLOCK_SIZE{1024};
constexpr auto ITERATIONS{64 * 1024 * 1024};

__global__ void reduce_kernel() {
  const auto block{cg::this_thread_block()};
  const auto tile{cg::tiled_partition<BLOCK_SIZE>(block)};
  const auto lane{tile.thread_rank()};
  for (auto i{0}; i < ITERATIONS; ++i) {
    cg::reduce(tile, lane, cg::plus<unsigned>());
  }
}

int main() {
  reduce_kernel<<<1, BLOCK_SIZE>>>();
  cudaDeviceSynchronize();
  return 0;
}

Block reduce with CUB

#include <cub/block/block_reduce.cuh>

constexpr auto BLOCK_SIZE{1024};
constexpr auto ITERATIONS{64 * 1024 * 1024};

__global__ void reduce_kernel() {
  const auto lane{threadIdx.x};
  for (auto i{0}; i < ITERATIONS; ++i) {
    cub::BlockReduce<unsigned, BLOCK_SIZE>{}.Sum(lane);
  }
}

int main() {
  reduce_kernel<<<1, BLOCK_SIZE>>>();
  cudaDeviceSynchronize();
  return 0;
}


Aside: Both of these kernels have the potential to be reduced to something meaningless, because neither one modifies global state. The compiler is free to discard code that has no effect on observable program behavior. Even if the code is not discarded entirely (which appears to be the case here), there is no guarantee that the compiler has not found other avenues for optimization, since neither kernel produces an observable change to global state.

I'm not sure what that means. CUB is free to use cooperative groups, and/or any feature that cooperative groups might use, such as warp shuffle. In fact, both appear to use REDUX (for cc 8.0).

You can study the SASS (for example, with cuobjdump -sass on the compiled executable).
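For reference, on cc 8.0 and higher the warp-wide REDUX operation is also exposed directly through the __reduce_add_sync family of intrinsics. A minimal sketch of a warp-wide sum (not one of the benchmarks in this thread):

__global__ void warp_redux_kernel(unsigned *out) {
  const unsigned lane = threadIdx.x % 32;
  // All 32 lanes of the warp participate; every lane receives the warp-wide sum.
  // On sm_80 this maps to a single REDUX instruction.
  const unsigned sum = __reduce_add_sync(0xffffffffu, lane);
  if (threadIdx.x == 0) {
    *out = sum;  // 0 + 1 + ... + 31 = 496
  }
}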

There is overlap in functionality, but there are also differences in capability. And I wouldn’t try to draw conclusions based on such a test case, for reasons already stated.

FWIW, after making some changes to address the concerns I raised above, I don’t see a difference in the measurement, and I can’t explain at the moment why cub is much faster than cg.


@Robert_Crovella Thank you for the quick response. I have found many of your answers on this forum very insightful.

That is an excellent point. I updated the programs to use the results; the updated versions are given below. The difference is similar (2.3 times as long), except now there is modulus overhead.

Method               Runtime
Cooperative groups   50.865s
CUB                  22.440s

I also observed a similar slowdown when switching a larger program from using CUB to cooperative groups.

The guide states “reduce […] takes advantage of hardware acceleration […] for the arithmetic add, min, or max operations and the logical AND, OR, or XOR”.

My interpretation of this quote is that cooperative groups use a hardware-accelerated reduction that, in theory, cannot be beaten. At best, I would expect CUB to match it, yet CUB is much faster.

In any case, it is surprising that there is a discrepancy between the two methods at all. Both have access to the same features and both know the block size at compile time. The only difference I know of is that CUB uses shared memory.
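For what it's worth, CUB's use of shared memory can be made explicit with the usual TempStorage pattern; as far as I understand, the default-constructed BlockReduce in my test allocates equivalent storage internally. A sketch (not one of my benchmarks):

#include <cub/block/block_reduce.cuh>

constexpr auto BLOCK_SIZE{1024};

__global__ void explicit_storage_kernel(unsigned *out) {
  using BlockReduce = cub::BlockReduce<unsigned, BLOCK_SIZE>;
  // The staging area CUB needs for the cross-warp steps lives in shared memory.
  __shared__ typename BlockReduce::TempStorage temp_storage;
  const auto sum{BlockReduce{temp_storage}.Sum(threadIdx.x)};
  if (threadIdx.x == 0) {
    *out = sum;  // only thread 0 is guaranteed to hold the block-wide result
  }
}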

Thank you for sharing this useful tool. What I meant to ask, from a user's perspective, is:

  • Why am I getting poorer performance with cooperative groups? Specifically, am I using it incorrectly?
  • Given that CUB is faster, should I prefer it where possible?

I agree. For now, I am focused only on block reduce, which both are capable of.

Thank you for mentioning the improvement. The updated test cases are below.

Block reduce with cooperative groups

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cstdio>

namespace cg = cooperative_groups;

constexpr auto BLOCK_SIZE{1024};
constexpr auto ITERATIONS{64 * 1024 * 1024};

__global__ void reduce_kernel(unsigned &result) {
  const auto block{cg::this_thread_block()};
  const auto tile{cg::tiled_partition<BLOCK_SIZE>(block)};
  const auto lane{tile.thread_rank()};
  auto value{0u};
  for (auto i{0}; i < ITERATIONS; ++i) {
    const auto thread_value{i % lane};
    const auto increment{cg::reduce(tile, thread_value, cg::plus<unsigned>())};
    value = increment % value;
  }
  invoke_one(tile, [&] { result = value; });
}

int main() {
  unsigned *result;
  cudaMallocManaged(&result, sizeof *result);
  reduce_kernel<<<1, BLOCK_SIZE>>>(*result);
  cudaDeviceSynchronize();
  // Prints "result: 139".
  printf("result: %u\n", *result);
  return 0;
}

Block reduce with CUB

#include <cub/block/block_reduce.cuh>
#include <cstdio>

constexpr auto BLOCK_SIZE{1024};
constexpr auto ITERATIONS{64 * 1024 * 1024};

__global__ void reduce_kernel(unsigned &result) {
  const auto lane{threadIdx.x};
  auto value{0u};
  for (auto i{0}; i < ITERATIONS; ++i) {
    const auto thread_value{i % lane};
    const auto increment{
        cub::BlockReduce<unsigned, BLOCK_SIZE>{}.Sum(thread_value)};
    value = increment % value;
  }
  if (threadIdx.x == 0) {
    result = value;
  }
}

int main() {
  unsigned *result;
  cudaMallocManaged(&result, sizeof *result);
  reduce_kernel<<<1, BLOCK_SIZE>>>(*result);
  cudaDeviceSynchronize();
  // Prints "result: 139".
  printf("result: %u\n", *result);
  return 0;
}

I suspect the discrepancy is related to the block size. If I change the block size to 32, then the cg method becomes ~4x faster rather than ~4x slower. (In that case, CUB might catch up if we switched from block reduce to warp reduce.)
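For reference, a warp-level CUB variant might look something like the sketch below (I have not benchmarked this):

#include <cub/warp/warp_reduce.cuh>

constexpr auto BLOCK_SIZE{1024};

__global__ void warp_reduce_kernel(unsigned *out) {
  using WarpReduce = cub::WarpReduce<unsigned>;
  // One TempStorage per warp in the block.
  __shared__ typename WarpReduce::TempStorage temp_storage[BLOCK_SIZE / 32];
  const auto warp_id{threadIdx.x / 32};
  const auto lane{threadIdx.x % 32};
  // Each warp reduces its own 32 values independently; lane 0 holds the result.
  const auto sum{WarpReduce{temp_storage[warp_id]}.Sum(lane)};
  if (warp_id == 0 && lane == 0) {
    *out = sum;  // 0 + 1 + ... + 31 = 496
  }
}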

AFAIK, the largest “hardware-accelerated” tile for cg is a tile size of 32 (templated, known at compile time). While you can obviously use a larger tile size, I don’t know what decisions/implementation cg is using in that case, and it may not be “best”. Most of the examples I see in the programming guide pick a tile size of 32.
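For illustration, a warp-sized tile keeps cg::reduce on that path; a minimal sketch (not one of the benchmarks in this thread):

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

__global__ void warp_tile_kernel(unsigned *out) {
  const auto block{cg::this_thread_block()};
  // A tile of 32 threads or fewer uses warp-level synchronization (REDUX on cc 8.0).
  const auto warp{cg::tiled_partition<32>(block)};
  const auto sum{cg::reduce(warp, warp.thread_rank(), cg::plus<unsigned>())};
  if (block.thread_rank() == 0) {
    *out = sum;  // 0 + 1 + ... + 31 = 496
  }
}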

I agree that it would be nice if CUB and cg were comparable in performance, generally. You’re welcome to file a bug.

Bug filed: 4950431.


This topic was automatically closed 14 days after the last reply. New replies are no longer allowed.

Bringing back the conclusion of ticket 4950431

Robert is correct in his last comment. What happened here is that you are comparing different levels of synchronization, and we cannot generically conclude whether CUB or cooperative groups is faster. At this time, we don't suggest using CG (cooperative groups) tile reduction block-wide; it will work for that purpose, but it should be used in cases where the tile is a subset of a block. In the future, we might consider a way to call CG reduce on a thread_block to match CUB performance, but for now the reported behavior is in line with our expectations.

If the tile partition size of CG is smaller than or equal to the size of a warp (32 threads), it will use warp synchronization (which also supports any subset of a warp; that is why it works for any size up to 32). This is why reduction of a tile of size 32 is fast, much faster than block-wide reduction in CUB, because warp synchronization is much faster than block synchronization.
If a tile is larger than a warp, we can't rely on warp synchronization alone, because the tile now contains multiple warps. The synchronization mechanism used by a cooperative groups tile is much more generic, and because of that it needs to be implemented in software. CUB can use block-wide synchronization, which is hardware accelerated. This is why you will see a big difference in performance between tiles of size <= 32 and tiles of size > 32 when compared to CUB.
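To make that concrete, here is a rough sketch, not the actual CG or CUB implementation, of what a reduction over a group larger than a warp has to do: warp-level partial reductions first, then a pass through shared memory and a barrier that spans all of the warps. It assumes a block of exactly 1024 threads (32 full warps), as in the test cases above.

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

// Assumes blockDim.x == 1024, i.e. 32 full warps.
__global__ void two_stage_block_reduce(unsigned *out) {
  const auto block{cg::this_thread_block()};
  const auto warp{cg::tiled_partition<32>(block)};

  // Stage 1: each warp reduces its own 32 values on the warp-level path.
  const auto partial{cg::reduce(warp, block.thread_rank(), cg::plus<unsigned>())};

  // Stage 2: one partial per warp goes through shared memory, then the whole
  // block synchronizes before the first warp combines the 32 partials.
  __shared__ unsigned partials[32];
  if (warp.thread_rank() == 0) {
    partials[block.thread_rank() / 32] = partial;
  }
  block.sync();  // a group spanning multiple warps cannot avoid a cross-warp barrier

  if (block.thread_rank() < 32) {
    const auto total{cg::reduce(warp, partials[warp.thread_rank()], cg::plus<unsigned>())};
    if (warp.thread_rank() == 0) {
      *out = total;  // 0 + 1 + ... + 1023 = 523776
    }
  }
}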

Another thing we want to point out: on Compute Capability 7.5 or lower, users need to reserve shared memory for thread_block_tile usage for tiles larger than 32 threads; on Compute Capability 8.0 and higher, block_tile_memory is not needed. This is described in the cooperative groups section of the CUDA C++ Programming Guide.

/// The following code will create tiles of size 128 on all Compute Capabilities.
/// block_tile_memory can be omitted on Compute Capability 8.0 or higher.
#include <cooperative_groups.h>
using namespace cooperative_groups;

__global__ void kernel(...) {
    // Reserve shared memory for thread_block_tile usage,
    // specifying that the block size will be at most 256 threads.
    __shared__ block_tile_memory<256> shared;
    thread_block thb = this_thread_block(shared);

    // Create tiles with 128 threads each.
    auto tile = tiled_partition<128>(thb);

    // ...
}