Your posted code has a variety of problems that prevent it from even compiling. It could not possibly run and produce any sensible results.
Before diving in further: this is an inefficient way to do a general parallel reduction, so I don't recommend it; however, I imagine this is a learning exercise. Also, the methodology I will suggest below is really inferior to cooperative groups. So, for learning purposes, we can get something to work, but I would not recommend it for production code.
One problem with your approach is that, for correctness, the kernel code starting at line 20 in your kernel posting must not begin to execute until all other threadblocks are finished. __syncthreads() is not a device-wide barrier; it only synchronizes the threads of a single block.
Normally, for a device-wide barrier, I would suggest using CUDA cooperative groups, but a cooperative kernel launch cannot be used with CUDA dynamic parallelism (CDP). (Specifically, a kernel launched from device code cannot currently use the cooperative launch API.) Since this discussion is predicated on learning and experimentation, and the type of synchronization we need here is relatively simple, we could instead use an atomic block-draining technique like the one used in the CUDA sample code threadFenceReduction.
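To make the block-draining idea concrete, here is a minimal sketch of the pattern. This is my own illustrative fragment, not the actual threadFenceReduction code; the counter and kernel names are made up for the sketch.

// hypothetical illustration of atomic block-draining; names are invented for this sketch
__device__ unsigned int retirementCount = 0;   // must be zero before the launch

__global__ void twoPhaseKernel(float *partial_results /* one slot per block */)
{
  // ... phase 1: each block computes and stores its partial result ...

  __threadfence();                    // make this block's global writes visible device-wide
  __shared__ bool amLast;
  if (threadIdx.x == 0) {
    unsigned int ticket = atomicAdd(&retirementCount, 1);
    amLast = (ticket == gridDim.x - 1);   // true only in the last block to reach this point
  }
  __syncthreads();

  if (amLast) {
    // phase 2: exactly one block gets here, after every other block has finished phase 1,
    // so it can safely consume partial_results (or, with CDP, launch the next kernel)
    if (threadIdx.x == 0) retirementCount = 0;  // reset so a subsequent launch starts clean
  }
}

In the full example further down, the same idea appears as an atomicAdd on blkID, and the "last block" is the one that performs the child kernel launch.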
Another issue we will face is the launch nesting depth. Your kernel design has block 0 of each nested kernel launch launching a new kernel that works on a data set half the size. Therefore the total number of launches required is about log2(N), where N is the data set size. We can also see that the host-launched kernel that starts this process cannot be considered "complete" until the very last nested kernel is complete. This creates a chain of roughly log2(N) kernel launches, each nested inside its parent, so the nesting depth is also roughly log2(N); for example, with N = 2^28 elements the chain is about 28 launches deep. CDP has limits around both of these aspects: the total number of outstanding kernel launches must obey a particular limit, and so must the nesting depth.
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#nesting-and-synchronization-depth
The above section should be read carefully. The maximum nesting depth is 24, subject to various caveats, and the CUDA runtime cannot necessarily handle a depth of 24 with no help from the programmer. Furthermore, 2^24 is only about 16 million, so we might want another method to handle larger data set sizes. For example, once the data set size drops below what a single threadblock can handle with a typical threadblock-level reduction, we can stop the recursion and let that threadblock handle the remaining data. This might allow us to handle something like a 2^(24+10) data set size (on the order of 17 billion elements).
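To put rough numbers on that, the depth the halving recursion needs can be estimated with a few lines of host code. This is only a sketch, using the names (ds, nTPB) from the program further down; it assumes the recursion stops once a single block of nTPB threads can finish the remaining (at most 2*nTPB) elements.

// rough host-side estimate of the launch-chain depth the recursion below will need
int levels = 1;                                   // the host-launched kernel itself
for (unsigned long long n = ds; n > 2ULL*nTPB; n = (n + 1)/2)
  levels++;                                       // each nested launch roughly halves n
printf("approximate launch-chain depth: %d\n", levels);

// If device code synchronized on its child launches (this example does not), the
// device-runtime sync depth would also have to be raised from the host beforehand, e.g.:
//   cudaDeviceSetLimit(cudaLimitDevRuntimeSyncDepth, levels);

For the data set size used below this works out to 19, which happens to match the launch count the program reports.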
(We could extend that block-level idea to every level of the nested processing and drastically reduce the number of nesting levels required, to some very small number like 2 or 3; a rough sketch of that variant follows. But this is all trying to improve on an algorithm which is performance-flawed to begin with, so the demonstration code keeps things simple.)
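Just to make that parenthetical idea concrete (it is not used in the demonstration below; the kernel name and the d_partial scratch buffer are hypothetical): each level would do a full block-level reduction into one partial sum per block, and then recurse on only those gridDim.x partial sums, shrinking the problem by a factor of roughly 2*nTPB per level instead of 2.

// hypothetical sketch only; reuses the block-draining trick from the full example below
template <typename T>
__global__ void myReduceByBlocks(T *d_in, T *d_partial, unsigned n)
{
  __shared__ T sdata[nTPB];
  unsigned idx = blockIdx.x*(2*nTPB) + threadIdx.x;
  T val = 0;
  if (idx < n)        val  = d_in[idx];          // each block covers a 2*nTPB slice
  if (idx + nTPB < n) val += d_in[idx + nTPB];
  sdata[threadIdx.x] = val;
  for (int i = nTPB>>1; i > 0; i >>= 1){         // standard block-level tree reduction
    __syncthreads();
    if (threadIdx.x < i) sdata[threadIdx.x] += sdata[threadIdx.x + i];}
  if (threadIdx.x == 0) d_partial[blockIdx.x] = sdata[0];   // one partial sum per block
  __threadfence();
  // ... atomic block-draining exactly as in the full example; the last block to finish,
  // if gridDim.x > 1, would then launch the next level on just the partial sums, e.g.:
  //   myReduceByBlocks<<<(gridDim.x + 2*nTPB - 1)/(2*nTPB), nTPB>>>(d_partial, d_in, gridDim.x);
  // (the roles of d_in and d_partial swap at each level)
}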
Here’s a worked example of the simpler approach (halving at each level, and stopping the recursion once a single block can finish the job):
$ cat t1481.cu
#include <stdio.h>
#include <assert.h>

#define cudaCheckErrors(msg) \
  do { \
    cudaError_t __err = cudaGetLastError(); \
    if (__err != cudaSuccess) { \
      fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
              msg, cudaGetErrorString(__err), \
              __FILE__, __LINE__); \
      fprintf(stderr, "*** FAILED - ABORTING\n"); \
      exit(1); \
    } \
  } while (0)

const int nTPB = 1024;     // threads per block, must be a power of 2
__device__ int blkID = 0;  // block-draining counter
__device__ int numK  = 0;  // counts kernel launches, for reporting only

template <typename T>
__global__ void myReduce(T *d_in, unsigned n)
{
  unsigned ID = threadIdx.x + blockDim.x*blockIdx.x;
  if (ID == 0) atomicAdd(&numK, 1);
  int tID = threadIdx.x;
  __shared__ T sdata[nTPB];
  if (gridDim.x == 1){  // when kernel has only 1 block, use classical parallel reduction
    sdata[tID] = (tID < n)?d_in[tID]:0;
    if ((tID + nTPB) < n) sdata[tID] += d_in[tID+nTPB];
    for (int i = nTPB>>1; i > 0; i>>=1){
      __syncthreads();
      if (tID < i) sdata[tID]+=sdata[tID+i];}
    d_in[tID] = sdata[tID];
    }
  else {  // when kernel has more than 1 block, sweep down once and launch new kernel
    unsigned offset = ((n&1) + n)>>1;     // ceil(n/2): round n up to even, then halve
    if ((ID + offset) < n)
      d_in[ID] += d_in[ID + offset];      // sweep
    __threadfence();
    if (tID == 0) {
      int my_ID = atomicAdd(&blkID, 1);   // get my block ID (atomic block-draining)
      if (my_ID == gridDim.x - 1){        // am I the last block to finish for this launch?
        blkID = 0;                        // reset the drain counter for the next launch
        __threadfence();
        unsigned my_off = (offset&1) + offset;        // round offset up to even
        int num_b = ((my_off>>1)+nTPB-1)/nTPB;        // blocks needed for the next (halved) launch
        myReduce<<<num_b, nTPB>>>(d_in, offset);
        assert(cudaSuccess == cudaGetLastError());
        }
      }
    }
}

typedef int mt;
int main(){
  int ds = 1048576*256+37;
  int dsb = ds*sizeof(mt);  // make sure this fits in an int quantity
  mt *data, *h_data = (mt *)malloc(dsb);
  cudaMalloc(&data, dsb);
  for (int i = 0; i < ds; i++) h_data[i] = 1;
  cudaMemcpy(data, h_data, dsb, cudaMemcpyHostToDevice);
  int my_ds = (ds%2)?(ds/2+1):(ds/2);   // ceil(ds/2) threads are enough for the first sweep
  myReduce<<<(my_ds+nTPB-1)/nTPB, nTPB>>>(data, ds);
  cudaMemcpy(h_data, data, sizeof(mt), cudaMemcpyDeviceToHost);
  cudaCheckErrors("some error");
  printf("result is: %d, should be: %d\n", h_data[0], ds);
  int nk;
  cudaMemcpyFromSymbol(&nk, numK, sizeof(int));
  printf("%d kernel launches\n", nk);
  return 0;
}
$ nvcc -arch=sm_70 -rdc=true -o t1481 t1481.cu -lcudadevrt
$ ./t1481
result is: 268435493, should be: 268435493
19 kernel launches
$
The code above hasn’t been thoroughly validated; use at your own risk. (Or, even better, don’t use it at all.)