Recently I started to learn CUDA programming, and be attracted at once. Now I want to program a FFT code with CUDA. Before this, I tested a 2-dimension using CUFFT lib, it aquire a good performance in about 25000000 clock ticks( one third for memory copy). But my code do 1-dimension FFT will cost the same time. I checked it, found that it’s very time-consuming while writing shared memory.
To avoid bank conflict, I set stride to be a big odd number, where’s my fault? Can you help me and give me some advice? thank you!
[codebox]#define MATRIX_SIZE (1<<9)
#define BLOCK_SIZE (1<<9)
global
void FFT_2D_Radix2(DATA_TYPE* dg_buffer, int N )
{
int tid, rev, pos, pre, stride = 33;
tid = threadIdx.x;
rev = bit_reverse3(tid, tail_zero_nums(N));
shared DATA_TYPE s_DataR[MATRIX_SIZE]; // 512*4 = 2kB
shared DATA_TYPE s_DataI[MATRIX_SIZE]; // 512*4 = 2kB
shared DATA_TYPE s_CosTable[MATRIX_SIZE]; // 512*4 = 2kB
shared DATA_TYPE s_SinTable[MATRIX_SIZE]; // 512*4 = 2kB
pos = tid * stride % MATRIX_SIZE;
s_DataR[pos] = dg_buffer[blockIdx.x * BLOCK_SIZE + rev]; //------------------------------------time-consuming
s_DataI[pos] = dg_buffer[N*N + blockIdx.x * BLOCK_SIZE + rev]; //-----------------------------------time-consuming
float theta = GV_2PI / N;
s_SinTable[pos] = __sinf( theta * tid );
s_CosTable[pos] = __cosf( theta * tid );
__syncthreads();
int step, w;
for(step = 1;step<N;step=step*2)
{
if(tid & step)
{
w = ( tid & ( step - 1 ) ) * stride % MATRIX_SIZE;
DATA_TYPE tempR = s_DataR[pos] * s_CosTable[w] + s_DataI[pos] * s_SinTable[w];
DATA_TYPE tempI = s_DataI[pos] * s_CosTable[w] - s_DataR[pos] * s_SinTable[w];
pre = ( tid - step ) * stride % MATRIX_SIZE;
s_DataR[pos] = s_DataR[pre] - tempR; //-----------------------------------------------time-consuming
s_DataI[pos] = s_DataI[pre] - tempI; //-----------------------------------------------time-consuming
s_DataR[pre] += tempR; //-----------------------------------------------time-consuming
s_DataI[pre] += tempI; //-----------------------------------------------time-consuming
}
__syncthreads();
}
}[/codebox]