Performance issue: deploying a faster load method results in decreased performance

In my kernel, there are 3 stages:

  1. Load data from global to shared memory in a non-trivial pattern: columns from global memory (strided elements) are loaded into rows in shared memory (sequential elements). While loading to shared memory, I also convert from int16 to float:
shared_mem[shared_idx] = (float)global_mem[global_idx];
  2. Compute: convolution via FFT, vector multiplication, and then IFFT (implemented with cufftDx).

  3. Store data from shared memory to global memory in a non-trivial pattern: rows from shared memory are stored in global memory as columns. (A standalone sketch of this load/store pattern appears right after this list.)

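To make the access pattern concrete, here is a minimal, self-contained sketch of stages 1 and 3 only (single block, no FFT; short2 stands in for my int16 complex type, and the names and dimensions are simplified assumptions, not my production code):

// Toy illustration of stages 1 and 3: global columns <-> shared-memory rows,
// with int16 -> float conversion on load. Not the production kernel.
// Launch: one block, dim3 block(threads_x, CHANNELS),
//         dynamic shared memory = CHANNELS * nx * sizeof(float2).
__global__ void transpose_load_store_sketch(const short2* __restrict__ src,
                                            float2* __restrict__ dst,
                                            int nx, int nz)
{
    extern __shared__ float2 smem[];   // one row of nx float2 per channel

    const int ch = threadIdx.y;        // which global column / shared-memory row

    // Stage 1: element j of channel ch is strided in global memory (stride nz)
    // but lands sequentially in its shared-memory row, converted to float.
    for (int j = threadIdx.x; j < nx; j += blockDim.x) {
        short2 v = src[j * nz + ch];                              // strided global read ("column")
        smem[ch * nx + j] = make_float2((float)v.x, (float)v.y);  // sequential shared write ("row")
    }
    __syncthreads();

    // ... stage 2 (FFT, vector multiply, IFFT) would operate on smem here ...

    // Stage 3: the shared-memory row goes back out as a strided global column.
    for (int j = threadIdx.x; j < nx; j += blockDim.x) {
        dst[j * nz + ch] = smem[ch * nx + j];
    }
}
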
I tried to optimize stage 1. I implemented two kernels that perform only stage 1, each with a different loading method (a simplified sketch of both is below):
In method A, each thread loads a few sequential elements from global memory and stores them in shared memory in different rows.
In method B, sequential threads load sequential elements from global memory to their places in shared memory.

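To make the difference concrete, here is roughly what the two stage-1-only test kernels look like, heavily simplified (short2 stands in for the int16 complex type, CH is the number of channels a block handles, N is the row length, and the block's CH channels are assumed contiguous in global memory; it is also assumed that blockDim.x == N for method A and that blockDim.x divides CH*N for method B; this is a sketch of the idea, not the exact code):

// Method A: each thread reads CH sequential global elements and scatters them
// into CH different shared-memory rows. Neighbouring threads read addresses
// that are CH elements apart, so the warp's global loads are not coalesced.
__global__ void load_method_A(const short2* __restrict__ src, float2* out,
                              int N, int CH)
{
    extern __shared__ float2 smem[];                 // CH * N float2 elements
    for (int i = 0; i < CH; ++i) {
        short2 v = src[threadIdx.x * CH + i];        // sequential per thread
        smem[i * N + threadIdx.x] = make_float2((float)v.x, (float)v.y);
    }
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];            // keep the load from being optimized away
}

// Method B: consecutive threads read consecutive global elements (coalesced),
// and each thread writes its element to the transposed position in shared memory.
__global__ void load_method_B(const short2* __restrict__ src, float2* out,
                              int N, int CH)
{
    extern __shared__ float2 smem[];
    const int elems = (CH * N) / blockDim.x;         // elements per thread
    for (int j = 0; j < elems; ++j) {
        int g   = j * blockDim.x + threadIdx.x;      // consecutive across threads
        int row = g % CH;                            // global column -> shared row
        int col = g / CH;
        smem[row * N + col] = make_float2((float)src[g].x, (float)src[g].y);
    }
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}
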
I found method B to be about 2x faster than method A when running stage 1 alone. But when I deployed method B in my main kernel (with all 3 stages), the kernel ran dramatically slower (by a factor of 2 compared to the same kernel with method A in stage 1).
I measured the time each stage takes inside the kernel using clock() (a sketch of the timing code is below) and found that when method B is used in stage 1, it runs slower than method A, and it also makes stages 2 and 3 run slower (by a factor of 2-3).
This surprised me because the performance dropped dramatically even though we are talking about the exact same kernel, with the only difference being in stage 1, and __syncthreads() is called after stage 1. So, in my understanding, the method used in stage 1 should not affect the next two stages.
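
For reference, the in-kernel timing is done roughly like this (the measurement lines were stripped from the kernel listing below, which only keeps the clock_t declarations; this sketch shows the pattern, not the exact code, and the output array is an assumed name):

// Per-stage timing with clock(): thread 0 of each block writes out the raw
// cycle counts for each stage; converting cycles to time and aggregating over
// blocks is done on the host.
__global__ void timed_kernel_sketch(long long* stage_cycles /* 3 per block */)
{
    clock_t start_unpack, finish_unpack, start_PC, finish_PC, start_transpose, finish_transpose;

    start_unpack = clock();
    // ... stage 1: load global -> shared, int16 -> float ...
    __syncthreads();
    finish_unpack = start_PC = clock();

    // ... stage 2: FFT, vector multiply, IFFT ...
    finish_PC = start_transpose = clock();

    // ... stage 3: store shared -> global, transposed ...
    finish_transpose = clock();
    __syncthreads();

    if (threadIdx.x == 0) {
        stage_cycles[blockIdx.x * 3 + 0] = (long long)(finish_unpack - start_unpack);
        stage_cycles[blockIdx.x * 3 + 1] = (long long)(finish_PC - start_PC);
        stage_cycles[blockIdx.x * 3 + 2] = (long long)(finish_transpose - start_transpose);
    }
}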

Can someone explain to me what happened and what I can do?
Thanks a lot,
Ori.

  • I run on an RTX 5000 with CUDA 11.6.
    Here is my kernel:
template<class FFT, class IFFT>
__launch_bounds__(FFT::max_threads_per_block)
__global__ void unpack_pulse_compresion_transpose_kernel(Complex32_t *pSrc, typename FFT::value_type *pDest ,
													int nx, int nz, int ny, cufftComplex *pVector) {

    using complex_type = typename FFT::value_type;
    using scalar_type  = typename complex_type::value_type;
	const int stride = size_of<FFT>::value / FFT::elements_per_thread;

	clock_t start_unpack, finish_unpack, start_PC, finish_PC, start_transpose, finish_transpose;


	extern __shared__ complex_type shared_mem[];
	int stride_gloabl = (nz*blockDim.x)/CHANNELS_PER_SM;
	int stride_sh = blockDim.x/CHANNELS_PER_SM;
	int elements_per_thread = (CHANNELS_PER_SM*nx)/blockDim.x;

	int start_load_idx = blockIdx.y*nz*nx + blockIdx.x*CHANNELS_PER_SM;

	int global_col = threadIdx.x%CHANNELS_PER_SM;
	int global_row = threadIdx.x/CHANNELS_PER_SM;
	int sh_row = global_col;
	int sh_col = global_row;

// Load data from global memory and convert to float (choose method A or B)

#if LOAD_IMPLEMENTATION_A
	for (size_t j = 0; j < FFT::elements_per_thread; j++)
	{
		for (size_t i = 0; i < CHANNELS_PER_SM; i++)
		{
			shared_mem[i*nx + threadIdx.x +j*stride].x = (float)pSrc[start_load_idx +(threadIdx.x +j*stride)*nz +i].x;
			shared_mem[i*nx + threadIdx.x +j*stride].y = (float)pSrc[start_load_idx +(threadIdx.x +j*stride)*nz +i].y;
		}
	}
#endif

#if LOAD_IMPLEMENTATION_B
	for (size_t j = 0; j < elements_per_thread; j++)
	{
		shared_mem[sh_row*nx + sh_col +j*stride_sh].x = (float)pSrc[start_load_idx +global_row*nz + global_col +stride_gloabl*j].x;
		shared_mem[sh_row*nx + sh_col +j*stride_sh].y = (float)pSrc[start_load_idx +global_row*nz + global_col +stride_gloabl*j].y;
	}
#endif

	__syncthreads();

// Convolution using FFT and IFFT

	for (size_t i = 0; i < CHANNELS_PER_SM; i++)
	{
		FFT().execute(shared_mem + i*nx);

		for (size_t j = 0; j < FFT::elements_per_thread; j++)
		{
			float destI = pVector[threadIdx.x + i*stride].x * shared_mem[threadIdx.x + j*stride + i*nx].x - pVector[threadIdx.x + i*stride].y * shared_mem[threadIdx.x + j*stride + i*nx].y;
			float destQ = pVector[threadIdx.x + i*stride].x * shared_mem[threadIdx.x + j*stride + i*nx].y + pVector[threadIdx.x + i*stride].y * shared_mem[threadIdx.x + j*stride + i*nx].x;
			shared_mem[threadIdx.x + j*stride + i*nx].x = destI;
			shared_mem[threadIdx.x + j*stride + i*nx].y = destQ;
		}

		IFFT().execute(shared_mem + i*nx);

		for (size_t j = 0; j < FFT::elements_per_thread; j++)
		{
			pDest[ (blockIdx.x*CHANNELS_PER_SM + i)*nx*ny + (threadIdx.x +j*stride)*ny + blockIdx.y ].x = shared_mem[threadIdx.x +j*stride + i*nx].x;
			pDest[ (blockIdx.x*CHANNELS_PER_SM + i)*nx*ny + (threadIdx.x +j*stride)*ny + blockIdx.y ].y = shared_mem[threadIdx.x +j*stride + i*nx].y;
		}
		
	}

// Store back to global memory as columns


	for (size_t i = 0; i < CHANNELS_PER_SM; i++)
	{
		for (size_t j = 0; j < FFT::elements_per_thread; j++)
		{
			pDest[ (blockIdx.x*CHANNELS_PER_SM + i)*nx*ny + (threadIdx.x +j*stride)*ny + blockIdx.y ].x = shared_mem[threadIdx.x +j*stride + i*nx].x;
			pDest[ (blockIdx.x*CHANNELS_PER_SM + i)*nx*ny + (threadIdx.x +j*stride)*ny + blockIdx.y ].y = shared_mem[threadIdx.x +j*stride + i*nx].y;
		}
	}

	__syncthreads();


}

Have you used a profiler?

Yes, I used Nsight Compute.
When running the kernels with only stage 1, the profiler shows that:
Method A has a compute throughput of 4.46% and a memory throughput of 57.11%.
Method B has a compute throughput of 12.09% and a memory throughput of 47.5% (seems more balanced).
Method A gets a warning about a non-optimal L1TEX global load access pattern, while method B doesn't get such a warning.
With method A, the kernel issues an instruction every 64 cycles, while with method B it issues an instruction every 16.6 cycles.
With method A, LG throttle causes a stall of 214.2 cycles per warp, while with method B it only causes a 32.7-cycle stall.
Both methods almost achieve the theoretical occupancy (47% / 50%).

After adding stages 2 and 3 to the kernel:
Method A has a compute throughput of 21.55% and a memory throughput of 41.91%.
Method B has a compute throughput of 9.63% and a memory throughput of 18.57% (seems more balanced).
Method B gets a warning about low utilization of all compute pipelines, while method A doesn't get such a warning.
With method A, the kernel issues an instruction every 7.2 cycles, while with method B it issues an instruction every 15.1 cycles.
With method A, LG throttle is the main stall, resulting in 9.5 stall cycles per warp. With method B, the barrier stall is the main one, resulting in 18.11 stall cycles per warp; the long scoreboard stall accounts for 16.14 cycles, MIO throttle for 8.87, and LG throttle for 7.47 stall cycles.
Both methods achieve the theoretical occupancy (50%).