How to launch the CUDA Cooperative Groups standard deviation example kernel?

Hi everyone!
I’m in the process of learning CUDA C++ programming and have successfully implemented a couple of kernels without synchronization, but now, for one of my algorithms, I need to implement standard deviation and several other similar computations.
I’ve found an example of a standard deviation kernel in the CUDA Cooperative Groups documentation, but I don’t understand how to launch it correctly. I’ve also seen the reductionMultiBlockCG CUDA sample, but it appears to use a somewhat different approach.
I’d greatly appreciate some example code showing how to run this kernel correctly. Also, it is a bit unclear to me how the cg::reduce() collective knows the source of the data to reduce.

P.S. I’m experimenting on a laptop with an NVIDIA Quadro 2000 GPU with Compute Capability 7.5. On my laptop, the above-mentioned reductionMultiBlockCG CUDA sample outperforms a Thrust-based sum reduction by roughly 6x on the same data set, so I’m really interested in a Cooperative Groups implementation :).

The referenced functions in the documentation are ordinary device functions, meant to be called from within a kernel. They are not kernels themselves and thus cannot be launched as such.

The code to launch the multi-block CG kernel is shown in the sample.

In this line from the example:

int avg = cg::reduce(tile, thread_sum, cg::plus<int>()) / length;

cg::reduce performs a reduction of the per-thread thread_sum values across the threads specified by tile.
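
For illustration, here is a minimal, self-contained sketch of how such a device-side reduction is typically wrapped in a kernel and launched (the kernel name meanKernel and the launch configuration are mine, not from the documentation):

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

__global__ void meanKernel(const double* vec, int length, double* out) {
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

    // each thread accumulates a strided slice of vec into its own thread_sum
    double thread_sum = 0;
    for (int i = tile.thread_rank(); i < length; i += tile.num_threads())
        thread_sum += vec[i];

    // the collective combines the per-thread partial sums; every thread of
    // the tile must call it, and each contributes its local thread_sum
    double avg = cg::reduce(tile, thread_sum, cg::plus<double>()) / length;

    if (tile.thread_rank() == 0)
        *out = avg;   // one thread publishes the result
}

// host side, one block of 32 threads (no grid-wide sync, so a plain launch works):
// meanKernel<<<1, 32>>>(d_vec, length, d_out);

That also answers the second question: cg::reduce gets its input from the thread_sum argument of each participating thread, not from a pointer.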

Hi @striker159, thank you for the prompt response.
Yes, I understand that this function should be called from a __global__ kernel. But the problem is that I don’t understand how to call it correctly. For example, when I currently call this function from my kernel, it runs only on the first block and ignores the other. That is, when I run the kernel on the n=8 array {1,2,3,4,5,6,7,8} with 2 blocks and 4 threads per block, I get results only for the first half of the array. I would greatly appreciate a full example.

And thank you for the reduce clarification - it is clear how it works now.

If you need help debugging your code, please show your code.

Below is simplified code, based on the CUDA documentation, that only computes the mean. It works when I have 1 block and 8 threads, and (obviously) doesn’t work when I have 2 blocks and 4 threads. The problem is that I cannot figure out how to write statisticsKernel so that it returns the correct mean. With the current implementation, it returns the mean of only the first half of the data. I roughly understand why, but I don’t see how to fix it, i.e. how to combine the means from the different blocks. Also, for some reason, the debugger catches only the first block and doesn’t step into the other.

#include "cuda_runtime.h"
#include "device_launch_parameters.h"

#include <stdio.h>
#include <math.h>
#include <iostream>
#include <cassert>

// This is for Visual Studio IntelliSense to work
#ifndef __CUDACC__
#define __CUDACC__
#endif // !__CUDACC__
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

/// Calculate the mean of the values in vec (standard deviation part commented out below)
__device__ double reduce_mean(const cg::thread_block_tile<32>& tile, const double* vec, double length) {
    double thread_sum = 0;

    // calculate average first
    for (int i = tile.thread_rank(); i < length; i += tile.num_threads()) {
        thread_sum += vec[i];
    }
    // cg::plus<double> tells cg::reduce() which binary operation to apply
    return cg::reduce(tile, thread_sum, cg::plus<double>()) / length;

    //// standard-deviation part of the documentation example, disabled for now:
    //double thread_diffs_sum = 0;
    //for (int i = tile.thread_rank(); i < length; i += tile.num_threads()) {
    //    double diff = vec[i] - avg;
    //    thread_diffs_sum += diff * diff;
    //}
    //double diff_sum = cg::reduce(tile, thread_diffs_sum, cg::plus<double>()) / length;
    //return sqrt(diff_sum);
}


__global__ void statisticsKernel(const double* data, int length, double* mean)
{
    // Handle to thread block group
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

    *mean = reduce_mean(tile, data, length);
}

int main()
{
    constexpr int length = 8;
    double h_data[length]{ 1, 2, 3, 4, 5, 6, 7, 8 };

    double* d_data = nullptr;
    cudaMalloc(&d_data, sizeof(double) * length);
    cudaMemcpyAsync(d_data, h_data, sizeof(double) * length, cudaMemcpyHostToDevice);

    constexpr int n_blocks = 1;
    constexpr int n_threads = 8;

    dim3 dimBlock(n_threads, 1, 1);
    dim3 dimGrid(n_blocks, 1, 1);

    double* d_mean = nullptr;
    cudaMalloc(&d_mean, n_blocks * sizeof(double));

    void* kernelArgs[] = { (void*)&d_data, (void*)&length, (void*)&d_mean };

    int sharedMemorySize = n_threads * sizeof(double) * 3;

    cudaLaunchCooperativeKernel(
        (void*)statisticsKernel,
        dimGrid, dimBlock,
        kernelArgs,
        sharedMemorySize);

    auto error = cudaDeviceSynchronize();

    double h_mean = 0;
    cudaMemcpy(&h_mean, d_mean, sizeof(double), cudaMemcpyDeviceToHost); // synchronous, so h_mean is valid below
    cudaFree(d_data);
    cudaFree(d_mean);

    std::cout << "Mean: " << h_mean << std::endl;

    assert(h_mean == 4.5);

    return 0;
}

You need to learn more about CUDA C++ or perhaps C++.

This line:

*mean = reduce_mean(tile, data, length);

is causing every thread that calls that function to write its result to the same location. Obviously you cannot get a result from two different groups that way.

Take a look at a CUDA sample code like vectorAdd, to see how different threads write their individual results to different locations.

Furthermore, your code as posted only launches one block (constexpr int n_blocks = 1;).

So you have to change that to get two blocks to run.
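
Schematically, the pattern from codes like vectorAdd applied here would look like this (a sketch; blockResult stands in for whatever per-block value you compute):

__global__ void perBlockWrite(double* out) {
    double blockResult = blockIdx.x;    // placeholder for a real per-block computation
    if (threadIdx.x == 0)               // exactly one thread per block writes
        out[blockIdx.x] = blockResult;  // each block gets its own output slot
}

// host side: launch more than one block, e.g. perBlockWrite<<<2, 4>>>(d_out);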

Hi @Robert_Crovella, @striker159,

I probably asked my question poorly: the code above wasn’t finished, and I was looking for a suggestion on how to call the device function correctly. I managed to get correct results with the code below:


#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <device_atomic_functions.h>

#include <stdio.h>
#include <math.h>
#include <iostream>

// This is for Visual Studio IntelliSense to work
#ifndef __CUDACC__
#define __CUDACC__
#endif // !__CUDACC__
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

/// Calculate the mean of the values in vec
__device__ double reduce_mean(const cg::thread_block_tile<32>& tile, const double* vec, double length) {
    double thread_sum = 0;

    // calculate average first
    for (int i = tile.thread_rank(); i < length; i += tile.num_threads()) {
        thread_sum += vec[i];
    }
    // cg::plus<double> tells cg::reduce() which binary operation to apply
    return cg::reduce(tile, thread_sum, cg::plus<double>()) / length;
}

__global__ void statisticsKernel(const double* data, int length, double* mean)
{
    // Handle to thread block group
    cg::grid_group grid = cg::this_grid();
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

    mean[blockIdx.x] = reduce_mean(tile, data + blockIdx.x * block.size(), block.size());

    if (grid.thread_rank() == 0) {
        for (int block = 1; block < gridDim.x; block++) {
            mean[0] += mean[block];
        }
        mean[0] /= gridDim.x;
    }
}

int main()
{
    constexpr int length = 8;
    double h_data[length]{ 1, 2, 3, 4, 5, 6, 7, 8 };

    double* d_data = nullptr;
    cudaMalloc(&d_data, sizeof(double) * length);
    cudaMemcpyAsync(d_data, h_data, sizeof(double) * length, cudaMemcpyHostToDevice);

    constexpr int n_blocks = 2;
    constexpr int n_threads = 4;

    dim3 dimBlock(n_threads, 1, 1);
    dim3 dimGrid(n_blocks, 1, 1);

    double* d_mean = nullptr;
    cudaMalloc(&d_mean, n_blocks * sizeof(double));

    void* kernelArgs[] = { (void*)&d_data, (void*)&length, (void*)&d_mean };

    int sharedMemorySize = n_threads * sizeof(double);

    cudaLaunchCooperativeKernel(
        (void*)statisticsKernel,
        dimGrid, dimBlock,
        kernelArgs,
        sharedMemorySize);

    auto error = cudaDeviceSynchronize();

    double h_mean = 0;
    cudaMemcpy(&h_mean, d_mean, sizeof(double), cudaMemcpyDeviceToHost); // synchronous, so h_mean is valid below
    cudaFree(d_data);
    cudaFree(d_mean);

    std::cout << "Expected Mean: 4.5 Actual Mean: " << h_mean << '\n';

    return 0;
}


But there are a couple of things I don’t like here. For example, at the end of the __global__ kernel I still need to average the per-block means (mean[0] /= gridDim.x;). Is there any way to avoid this?
Furthermore, I feel something is wrong with the way data is passed to the reduce_mean() function. In the production version I will add a stride and use shared memory, but maybe there is some other pattern for doing this?

Perhaps you’ll have some other recommendations.
I greatly appreciate your time and patience. Thanks for your suggestions.

I see several issues with your code (a sketch addressing some of them follows the list).

  • unless you use grid_group.sync(), the kernel can be launched with the standard triple-chevron syntax.
  • you create a group of 32 threads, but your block only has 4 threads. That is not valid.
  • You should think about how to handle the case when there are multiple groups of size 32 per block.
  • every thread writes its result to mean[blockIdx.x] instead of only one thread.
  • you do not synchronize after writing the block-wide mean to memory
  • you should use multiple threads to compute the total mean: either use a second reduction kernel, or use grid.sync() and perform the final reduction in the same kernel
  • reduction of doubles does not use hardware acceleration
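
To make those points concrete, here is a sketch of the pattern, assuming blocks of exactly 32 threads (so the tile<32> partition is valid) and at most 32 blocks; the kernel name meanKernelCG, the separate partial buffer, and those restrictions are mine:

__global__ void meanKernelCG(const double* data, int length, double* partial, double* mean) {
    cg::grid_group grid = cg::this_grid();
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); // valid only if blockDim.x == 32

    // grid-stride loop: works for any input length
    double thread_sum = 0;
    for (int i = grid.thread_rank(); i < length; i += grid.num_threads())
        thread_sum += data[i];

    double block_sum = cg::reduce(tile, thread_sum, cg::plus<double>());

    if (block.thread_rank() == 0)          // one writer per block
        partial[blockIdx.x] = block_sum;

    grid.sync();                           // requires cudaLaunchCooperativeKernel

    if (blockIdx.x == 0) {                 // one block finishes the job, in parallel
        double s = (tile.thread_rank() < gridDim.x) ? partial[tile.thread_rank()] : 0.0; // assumes gridDim.x <= 32
        double total = cg::reduce(tile, s, cg::plus<double>());
        if (block.thread_rank() == 0)
            *mean = total / length;        // divide the grand total once, by the true length
    }
}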

Hi @striker159 ,

Thank you for your comments.
I think I got it working in the end. I’m not sure the implementation below is correct, but it works :). I mostly focused on the algorithm itself, so I temporarily ignored several of the issues you mentioned, like the lack of hardware acceleration for double reductions.
I still have several questions and would be grateful if you can guide me in the right direction:

  1. I still have several threads writing to global memory (though no longer every thread, as before). I haven’t found a way to do it from only one thread.
  2. This implementation works correctly only with a power-of-two number of elements in the input array. Is there some CUDA pattern to make it work with arbitrary sizes?
  3. Maybe some parts of the kernel could be implemented more elegantly? For example, I’ve seen the use of atomic functions and types in some examples. I would really appreciate suggestions here.

#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <device_atomic_functions.h>

#include <stdio.h>
#include <math.h>
#include <iostream>

// This is for Visual Studio IntelliSense to work
#ifndef __CUDACC__
#define __CUDACC__
#endif // !__CUDACC__
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

/// Calculate the mean of the values in vec
__device__ double reduce_mean(const cg::thread_block_tile<4>& tile, const double* vec, int length) {
    double thread_sum = 0;

    // calculate average first
    for (int i = tile.thread_rank(); i < length; i+= tile.num_threads()) {
        thread_sum += vec[i];
    }

    // cg::plus<double> tells cg::reduce() which binary operation to apply
    return cg::reduce(tile, thread_sum, cg::plus<double>()) / length;
}

__global__ void statisticsKernel(const double* data, int length, double* mean)
{
    cg::grid_group grid = cg::this_grid();
    cg::thread_block block = cg::this_thread_block();

    extern __shared__ double sdata[];

    // Stride over grid and add the values to a shared memory buffer
    sdata[block.thread_rank()] = 0;
    for (int i = grid.thread_rank(); i < length; i += grid.size())
    {
        sdata[block.thread_rank()] += data[i];
    }

    cg::sync(block);

    cg::thread_block_tile<4> tile = cg::tiled_partition<4>(block);
    double tile_mean = reduce_mean(tile, sdata, block.size());

    if (block.thread_rank() == 0)
        mean[blockIdx.x] = tile_mean;

    grid.sync();

    if (grid.thread_rank() == 0)
    {
        for (int block = 1; block < gridDim.x; block++) {
            mean[0] += mean[block];
        }
        mean[0] /= gridDim.x * length / grid.size();
    }
}

int main()
{
    constexpr int length = 16;
    double h_data[length]{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 };

    double* d_data = nullptr;
    cudaMalloc(&d_data, sizeof(double) * length);
    cudaMemcpyAsync(d_data, h_data, sizeof(double) * length, cudaMemcpyHostToDevice);

    constexpr int n_blocks = 2;
    constexpr int n_threads = 4;

    dim3 dimBlock(n_threads, 1, 1);
    dim3 dimGrid(n_blocks, 1, 1);

    double* d_mean = nullptr;
    cudaMalloc(&d_mean, n_blocks * sizeof(double));

    void* kernelArgs[] = { (void*)&d_data, (void*)&length, (void*)&d_mean };

    int sharedMemorySize = n_threads * sizeof(double);

    cudaLaunchCooperativeKernel(
        (void*)statisticsKernel,
        dimGrid, dimBlock,
        kernelArgs,
        sharedMemorySize);

    auto error = cudaDeviceSynchronize();

    double h_mean = 0;
    cudaMemcpy(&h_mean, d_mean, sizeof(double), cudaMemcpyDeviceToHost); // synchronous, so h_mean is valid below
    cudaFree(d_data);
    cudaFree(d_mean);

    std::cout << "Expected Mean: 8.5 Actual Mean: " << h_mean << '\n';

    return 0;
}

I think I’ve managed to implement the mean reduction for an array of arbitrary size. I pad an odd-length array with zeros up to a power-of-two size and compute the mean in the last block, taking this padding into account. I tested it with different array lengths, block counts, and thread counts. In production, I’ll change the thread block tile size from 4 to 32.
I got the idea of how to detect the last block from the threadFenceReduction CUDA sample. Incidentally, I’m not sure whether I need to call the __threadfence() intrinsic; everything appears to work even without it.

#include "cuda_runtime.h"
#include <cuda_runtime_api.h>
#include "device_launch_parameters.h"
#include <device_atomic_functions.h>

#include <stdio.h>
#include <math.h>
#include <iostream>

// This is for Visual Studio IntelliSense to work
#ifndef __CUDACC__
#define __CUDACC__
#endif // !__CUDACC__
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

/// Calculate the mean of the values in vec
__device__ float reduce_mean(const cg::thread_block_tile<4>& tile, const float* vec, int length) {
    float thread_sum = 0;

    // calculate average first
    for (int i = tile.thread_rank(); i < length; i+= tile.num_threads()) {
        thread_sum += vec[i];
    }

    // cg::plus<float> allows cg::reduce() to know it can use hardware acceleration for addition
    //auto reduce_sum = cg::reduce(tile, thread_sum, cg::plus<float>());
    //printf("Block %d reduce_sum=%f mean=%f \n", blockIdx.x, reduce_sum, reduce_sum / length);
    return cg::reduce(tile, thread_sum, cg::plus<float>()) / length;
}

// Global variable used by reduceSinglePass to count how many blocks have
// finished
__device__ unsigned int retirementCount = 0;

__global__ void statisticsKernelArbitrarySize(const float* data, int length, int odd_len, float* mean)
{
    cg::grid_group grid = cg::this_grid();
    cg::thread_block block = cg::this_thread_block();

    extern __shared__ float sdata[];

    // Stride over grid and add the values to a shared memory buffer
    sdata[block.thread_rank()] = 0;
    for (int i = grid.thread_rank(); i < length; i += grid.size())
    {
        sdata[block.thread_rank()] += data[i];
    }

    cg::sync(block);

    cg::thread_block_tile<4> tile = cg::tiled_partition<4>(block);

    float tile_mean = reduce_mean(tile, sdata, block.size());

    if (block.thread_rank() == 0)
        mean[blockIdx.x] = tile_mean;

    __shared__ bool amLast;

    // wait until all outstanding memory instructions in this thread are
    // finished
    //__threadfence();

    // Thread 0 takes a ticket
    if (block.thread_rank() == 0) {
        unsigned int ticket = atomicInc(&retirementCount, grid.num_blocks());
        // If the ticket ID is equal to the number of blocks, we are the last
        // block!
        amLast = (ticket == grid.num_blocks() - 1);
    }

    grid.sync();

    if (grid.thread_rank() == 0)
    {
        for (int block = 1; block < grid.num_blocks(); block++) {
            mean[0] += mean[block];
        }
        if (amLast)
            mean[0] /= grid.num_blocks() * (float)(length - odd_len) / grid.size();
        else
            mean[0] /= grid.num_blocks() * (float)length / grid.size();
    }
}

int main()
{

    constexpr int length = 16;
    constexpr int odd_len = 2;
    float h_data[length]{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0 };

    float* d_data = nullptr;
    cudaMalloc(&d_data, sizeof(float) * length);
    cudaMemcpyAsync(d_data, h_data, sizeof(float) * length, cudaMemcpyHostToDevice);

    constexpr int n_blocks = 2;
    constexpr int n_threads = 4;

    dim3 dimBlock(n_threads, 1, 1);
    dim3 dimGrid(n_blocks, 1, 1);

    float* d_mean = nullptr;
    cudaMalloc(&d_mean, n_blocks * sizeof(float));

    int sharedMemorySize = n_threads * sizeof(float);

    //void* kernelArgs[] = { (void*)&d_data, (void*)&length, (void*)&d_mean };
    void* kernelArgs[] = { (void*)&d_data, (void*)&length, (void*)&odd_len, (void*)&d_mean };

    cudaLaunchCooperativeKernel(
        (void*)statisticsKernelArbitrarySize,
        dimGrid, dimBlock,
        kernelArgs,
        sharedMemorySize);

    auto error = cudaDeviceSynchronize();

    float h_mean = 0;
    cudaMemcpy(&h_mean, d_mean, sizeof(float), cudaMemcpyDeviceToHost); // synchronous, so h_mean is valid below
    cudaFree(d_data);
    cudaFree(d_mean);

    float expected = 0;
    for (int i = 0; i < length; i++)
    {
        expected += h_data[i];
    }

    std::cout << "\nExpected Mean: " << expected / (length - odd_len) <<" Actual Mean: " << h_mean << '\n';

    return 0;
}

I don’t know why you want to pad your data. The kernel already uses strided loops to be independent of the input size, doesn’t it?
Why do you need to know which block is last? After the sync, you can use any block for the final reduction.

I would suggest that you take a step back and focus on the core algorithm, which is the parallel reduction.

Assume you have a function blockreduce that computes a block-wide sum with more than 32 threads per block for inputs of arbitrary size (how would you implement it?)

then the reduction kernel simply follows this pseudo-code:

each block performs blockreduce on one or more chunks of data
each block writes its result to global memory
sync all blocks
one block performs blockreduce on the intermediate results in global memory
write final result to global memory

To sync all blocks, one can either use grid.sync() (with a cooperative launch), or use a second kernel. Performance-wise, it does not matter.
And the resulting algorithm will have the same performance as the one used by Thrust. There is no magic behind cooperative groups. (I wrote a simple benchmark comparing thrust::reduce to the custom kernel on an RTX 4090 with CUDA 11.8 on Linux.)
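
For reference, the Thrust side of such a benchmark is essentially a one-liner (the wrapper function below is mine):

#include <thrust/reduce.h>
#include <thrust/execution_policy.h>

// d_data is a raw device pointer to n doubles; returns the mean on the host
double device_mean(const double* d_data, int n) {
    double sum = thrust::reduce(thrust::device, d_data, d_data + n, 0.0);
    return sum / n;
}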

If you get all this working, computing the mean correctly should be simple enough.
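
Putting the pieces together, here is one possible sketch of that pseudo-code using the second-kernel variant (blockreduce, sumKernel, and the assumption that blockDim.x is a multiple of 32 are my choices, not the only way to do it):

// block-wide sum; every thread of the block must call this
__device__ double blockreduce(cg::thread_block block, double v, double* smem) {
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
    v = cg::reduce(tile, v, cg::plus<double>());       // reduce within each 32-thread tile
    if (tile.thread_rank() == 0)
        smem[tile.meta_group_rank()] = v;              // one partial per tile
    block.sync();
    double r = 0;
    if (tile.meta_group_rank() == 0) {                 // first tile reduces the tile partials
        r = (tile.thread_rank() < tile.meta_group_size()) ? smem[tile.thread_rank()] : 0.0;
        r = cg::reduce(tile, r, cg::plus<double>());
    }
    return r;                                          // valid in thread 0 of the block
}

__global__ void sumKernel(const double* data, int length, double* result) {
    extern __shared__ double smem[];                   // blockDim.x / 32 doubles
    cg::thread_block block = cg::this_thread_block();
    double v = 0;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < length; i += gridDim.x * blockDim.x)
        v += data[i];                                  // grid-stride: any input size, no padding
    double s = blockreduce(block, v, smem);
    if (block.thread_rank() == 0)
        result[blockIdx.x] = s;                        // one partial per block
}

// host side (the second launch replaces the grid-wide sync):
// sumKernel<<<n_blocks, n_threads, (n_threads / 32) * sizeof(double)>>>(d_data, length, d_partials);
// sumKernel<<<1, n_threads, (n_threads / 32) * sizeof(double)>>>(d_partials, n_blocks, d_total);
// mean = total / length, dividing once by the true element count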

Hello @striker159,
You were completely right. I double-checked the Thrust performance, and it is actually faster than the Cooperative Groups example. I was using Windows & Visual Studio and had forgotten to switch from Debug to Release build mode. Feeling like an idiot. For that reason, and because I’m already a bit behind schedule, I’ll continue to use Thrust in my project (the library already has ready-to-use implementations of the algorithms I require). I’ll definitely return to this topic later, but I’m not sure it makes sense to keep this thread open until then.
Anyway, thank you a lot for your time and for being my guide inside the CUDA world :).