cuFFTDx kernel equivalent to executing a 3D cuFFT plan?

Hi, I’ve been sinking my teeth into cuFFTDx for the past two weeks, and I have been struggling quite a bit to replicate the behaviour of executing a 3D cuFFT plan using cuFFTDx. The specific test case I am trying to replicate uses an 8 by 8 by 4 plan. As far as I understand, with cuFFTDx this requires three batched 1D FFTs, of sizes 8, 8 and 4, one per dimension. The code I currently have is largely based on the 3D box example found here, with the big difference that I need FFTs that are executed per block rather than per thread (in the future I plan to use FFT sizes > 32).
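
For reference, this is roughly how I define the three FFT descriptions and the value types. This is only a sketch of my setup: the float precision, the SM<700>() architecture and FFTsPerBlock<1>() are placeholders for the values I actually build with.

#include <algorithm> // for std::max
#include <cufftdx.hpp>

using namespace cufftdx;

// Per-dimension FFT descriptions for the 8 x 8 x 4 box. I load the real
// input into complex registers myself, so all three passes are C2C.
using FFTX = decltype( Size<8>() + Precision<float>() + Type<fft_type::c2c>() +
					   Direction<fft_direction::forward>() + FFTsPerBlock<1>() +
					   SM<700>() + Block() );
using FFTY = decltype( Size<8>() + Precision<float>() + Type<fft_type::c2c>() +
					   Direction<fft_direction::forward>() + FFTsPerBlock<1>() +
					   SM<700>() + Block() );
using FFTZ = decltype( Size<4>() + Precision<float>() + Type<fft_type::c2c>() +
					   Direction<fft_direction::forward>() + FFTsPerBlock<1>() +
					   SM<700>() + Block() );

using complex_type = typename FFTX::value_type;
using real_type    = typename complex_type::value_type;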

The device function I call from my kernel is as follows:

template<class FFTX, class FFTY, class FFTZ,
		 unsigned int storage_size = std::max( { FFTX::storage_size, FFTY::storage_size, FFTZ::storage_size } )>
__device__ void executeR2C( const unsigned int N, const unsigned int M, const unsigned int L, 
							real_type* input, complex_type* output,
							typename FFTX::workspace_type workspace_x, typename FFTY::workspace_type workspace_y, typename FFTZ::workspace_type workspace_z ) {
	// Registers
	complex_type thread_data[storage_size];

	// FFT::shared_memory_size bytes of shared memory
	extern __shared__ complex_type shared_memory[];

	// Local batch id of this FFT in CUDA block, in range [0; FFT::ffts_per_block)
	const unsigned int local_fft_id = threadIdx.y;
	// Total number of elements in the volume
	const unsigned int LMN = L * M * N;

	// Load registers with zeroes
	for ( unsigned int i = 0; i < FFTX::elements_per_thread; i++ ) thread_data[i] = complex_type{ 0, 0 };

	const bool		   run_fft_x	   = threadIdx.x < max( 4U, M ) && threadIdx.y < max( 4U, L );
	const unsigned int stride_x		   = M * L;
	const unsigned int global_fft_id_x = ( blockIdx.x * FFTX::ffts_per_block ) + local_fft_id;
	const unsigned int offset_x		   = size_of<FFTX>::value * global_fft_id_x;
	unsigned int       index_x		   = offset_x + threadIdx.x + threadIdx.y * M;
	if ( run_fft_x ) {
		unsigned int index = index_x;

		for ( unsigned int i = 0; i < FFTX::elements_per_thread; i++ ) {
			// Make sure not to go out-of-bounds
			if ( index < LMN ) {
				thread_data[i] = complex_type{ input[index], 0 };
				index += stride_x;
			}
		}
	}

	// Execute FFT in X dimension
	FFTX().execute( thread_data, shared_memory, workspace_x );

	if ( run_fft_x ) {
		unsigned int index = index_x;

		// Exchange via shared memory
		for ( unsigned int i = 0; i < FFTX::elements_per_thread; i++ ) {
			shared_memory[index] = thread_data[i];
			index += stride_x;
		}
	}

	__syncthreads();
	
	// Load registers with zeroes
	for ( unsigned int i = 0; i < FFTY::elements_per_thread; i++ ) thread_data[i] = complex_type{ 0, 0 };

	const bool		   run_fft_y	   = threadIdx.x < max( 4U, L ) && threadIdx.y < max( 4U, N );
	const unsigned int stride_y		   = L;
	const unsigned int global_fft_id_y = ( blockIdx.x * FFTY::ffts_per_block ) + local_fft_id;
	const unsigned int offset_y		   = size_of<FFTY>::value * global_fft_id_y;
	unsigned int       index_y		   = offset_y + threadIdx.x + threadIdx.y * M;
	if ( run_fft_y ) {
		unsigned int index = index_y;

		for ( unsigned int i = 0; i < FFTY::elements_per_thread; i++ ) {
			// Make sure not to go out-of-bounds
			if ( index < LMN ) {
				thread_data[i] = complex_type{ input[index], 0 };
				index += stride_y;
			}
		}
	}

	// Execute FFT in Y dimension
	FFTY().execute( thread_data, shared_memory, workspace_y );

	if ( run_fft_y ) {
		unsigned int index = index_y;

		// Exchange via shared memory
		for ( unsigned int i = 0; i < FFTY::elements_per_thread; i++ ) {
			shared_memory[index] = thread_data[i];
			index += stride_y;
		}
	}

	__syncthreads();

	// Load registers with zeroes
	for ( unsigned int i = 0; i < FFTZ::elements_per_thread; i++ ) thread_data[i] = complex_type{ 0, 0 };

	const bool		   run_fft_z	   = threadIdx.x < max( 4U, N ) && threadIdx.y < max( 4U, M );
	const unsigned int stride_z		   = 1;
	const unsigned int global_fft_id_z = ( blockIdx.x * FFTZ::ffts_per_block ) + local_fft_id;
	const unsigned int offset_z		   = size_of<FFTZ>::value * global_fft_id_z;
	unsigned int       index_z		   = offset_z + ( threadIdx.x + threadIdx.y * N ) * L;
	if ( run_fft_z ) {
		unsigned int index = index_z;

		for ( unsigned int i = 0; i < FFTZ::elements_per_thread; i++ ) {
			// Make sure not to go out-of-bounds
			if ( index < LMN ) {
				thread_data[i] = complex_type{ input[index], 0 };
				index += stride_z;
			}
		}
	}

	// Execute FFT in Z dimension
	FFTZ().execute( thread_data, shared_memory, workspace_z );

	if ( run_fft_z ) {
		unsigned int index = index_z;

		// Exchange via shared memory
		for ( unsigned int i = 0; i < FFTZ::elements_per_thread; i++ ) {
			shared_memory[index] = thread_data[i];
			index += stride_z;
		}
	}

	__syncthreads();

	if ( run_fft_x ) {
		unsigned int index = index_x;
		for ( unsigned int i = 0; i < FFTX::elements_per_thread; i++ ) {
			thread_data[i] = shared_memory[index];
			index += stride_x;
		}

		index = index_x;
		for ( unsigned int i = 0; i < FFTX::elements_per_thread; i++ ) {
			output[index] = thread_data[i];
			index += stride_x;
		}
	}
}
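
For completeness, this is roughly how I launch it. Again a sketch rather than my exact code: executeR2C_kernel, launch_executeR2C, d_input and d_output are names I made up here, error checking is omitted, and I size the dynamic shared memory to cover both the per-FFT requirement and the full cube used for the exchanges.

// Thin __global__ wrapper around the device function above.
template<class FFTX, class FFTY, class FFTZ>
__global__ void executeR2C_kernel( const unsigned int N, const unsigned int M, const unsigned int L,
								   real_type* input, complex_type* output,
								   typename FFTX::workspace_type workspace_x,
								   typename FFTY::workspace_type workspace_y,
								   typename FFTZ::workspace_type workspace_z ) {
	executeR2C<FFTX, FFTY, FFTZ>( N, M, L, input, output, workspace_x, workspace_y, workspace_z );
}

void launch_executeR2C( real_type* d_input, complex_type* d_output ) {
	cudaError_t error = cudaSuccess;
	auto workspace_x  = make_workspace<FFTX>( error );
	auto workspace_y  = make_workspace<FFTY>( error );
	auto workspace_z  = make_workspace<FFTZ>( error );

	// The shared memory has to hold the whole 8 x 8 x 4 cube for the
	// exchanges between passes, not just the per-FFT execution scratch.
	const unsigned int cube_bytes   = 8 * 8 * 4 * sizeof( complex_type );
	const unsigned int shared_bytes = std::max( { FFTX::shared_memory_size, FFTY::shared_memory_size,
												  FFTZ::shared_memory_size, cube_bytes } );

	// I launch a single block with FFTX::block_dim; I am not sure this is
	// correct when three different FFT types share one kernel.
	executeR2C_kernel<FFTX, FFTY, FFTZ><<<1, FFTX::block_dim, shared_bytes>>>(
		8, 8, 4, d_input, d_output, workspace_x, workspace_y, workspace_z );
}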

I am wondering if I have perhaps misunderstood the example code and need to make more changes to accommodate Block() execution, or if there is another issue here?
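
For reference, the cuFFT setup whose output I am trying to match is simply the following (a sketch; reference_fft, d_input and d_output are placeholder names):

#include <cufft.h>

// Reference transform whose output I am trying to reproduce.
void reference_fft( cufftReal* d_input, cufftComplex* d_output ) {
	cufftHandle plan;
	// Dimensions are given slowest to fastest; the last one (4) is contiguous.
	cufftPlan3d( &plan, 8, 8, 4, CUFFT_R2C );
	// Note: cuFFT's R2C output is the non-redundant 8 x 8 x (4/2 + 1) block.
	cufftExecR2C( plan, d_input, d_output );
	cufftDestroy( plan );
}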