1D cuFFT of matrix columns is very slow (~1.5 seconds)

Hello.

My project involves a lot of Fourier transforms, mostly one-dimensional transforms of matrix rows and columns.
Matrix dimensions: 8192 x 8192, cuComplex.
The FFT along rows is pretty fast: ~6 ms.

But the FFT along columns takes an abnormally long time, ~1.5 seconds, and I suspect that I am doing something wrong.
I'll attach a small test showing how I perform the FFTs.

I use the AGX Orin 32GB H01 dev kit.
Before running the test I run sudo jetson_clocks and set the power mode to MAXN.

// nvcc -o fft_test.o fft_test.cu -lcufft
#include <math.h>
#include <stdio.h>
#include <cuComplex.h>
#include <cufft.h>
#define BLOCK_DIM 16

__global__ void data_set(cuComplex *data, float value, int col, int row)
{
    int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
    int yIndex = blockIdx.y * blockDim.y + threadIdx.y;

    if ( (xIndex < row) && (yIndex < col) ){
       int matIndex = yIndex * row + xIndex;
       data[matIndex].x               = value;
       data[matIndex].y               = 0;
    }
}

int main()
{
    dim3    gridSize;
    dim3    blockSize;

    float milliseconds = 0;
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    int n_col = 8192;
    int n_row = 8192;

    cuComplex *gpu_data;
    cudaMalloc( (void**)&gpu_data, n_col*n_row*sizeof(cuComplex) );
    cudaDeviceSynchronize();

    gridSize    = dim3(n_row/BLOCK_DIM,n_col/BLOCK_DIM,1);
    blockSize   = dim3(BLOCK_DIM,BLOCK_DIM,1);
    data_set<<<gridSize,blockSize>>>(gpu_data,1,n_col,n_row);

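    // Row FFT plan: a batch of 1D transforms of length n_row; elements within a row are
    // contiguous (istride = 1) and consecutive rows start n_row elements apart (idist = n_row).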
    cufftHandle row_plan;
    int row_n[] = {n_row};
    int row_inembed[1] = {n_row};
    int row_onembed[1] = {n_row};
    int row_istride = 1;
    int row_idist = n_row;
    int row_ostride = 1;
    int row_odist = n_row;
    int row_batch = n_row;
    cufftPlanMany(&row_plan, 1, row_n, row_inembed, row_istride, row_idist, row_onembed, row_ostride, row_odist, CUFFT_C2C, row_batch);

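    // Column FFT plan: a batch of 1D transforms of length n_col; consecutive elements of a column
    // are one full row apart in memory (istride = n_col, which equals the row length here because
    // the matrix is square) and consecutive columns start 1 element apart (idist = 1).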
    cufftHandle col_plan;
    int col_n[] = {n_col};
    int col_inembed[1] = {n_col};
    int col_onembed[1] = {n_col};
    int col_istride = n_col;
    int col_idist = 1;
    int col_ostride = n_col;
    int col_odist = 1;
    int col_batch = n_col;
    cufftPlanMany(&col_plan, 1, col_n, col_inembed, col_istride, col_idist, col_onembed, col_ostride, col_odist, CUFFT_C2C, col_batch);


    // row fft
    cudaEventRecord(start);

    cufftExecC2C(row_plan, gpu_data, gpu_data, CUFFT_FORWARD);
    cudaDeviceSynchronize();

    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    cudaEventElapsedTime(&milliseconds, start, stop);
    printf("row fft execotion time : %f  ms\n", milliseconds);
    // 6 ms


    // column fft
    cudaEventRecord(start);

    cufftExecC2C(col_plan, gpu_data, gpu_data, CUFFT_FORWARD);
    cudaDeviceSynchronize();

    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    cudaEventElapsedTime(&milliseconds, start, stop);
    printf("column fft execotion time : %f  ms\n", milliseconds);
    // 1.500 s

}

nvidia@nvidia-desktop:~$ ./fft_test.o

row fft execution time : 5.611392  ms
column fft execution time : 1625.680298  ms

Hi,

Could you add a warm-up function and benchmark the FFT call in a loop (e.g. 10,000 iterations) to see if it makes any difference?

Thanks.

Hello.
I added a warm-up function and ran the column FFT in a loop of 100 iterations.

// nvcc -o fft_test.o fft_test.cu -lcufft
#include <math.h>
#include <stdio.h>
#include <cuComplex.h>
#include <cufft.h>
#define BLOCK_DIM 16

__global__ void data_set(cuComplex *data, float value, int col, int row)
{
    int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
    int yIndex = blockIdx.y * blockDim.y + threadIdx.y;

    if ( (xIndex < row) && (yIndex < col) ){
       int matIndex = yIndex * row + xIndex;
       data[matIndex].x               = value;
       data[matIndex].y               = 0;
    }
}

__global__ void warm_up_gpu(){
  unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
  float ia, ib;
  ia = ib = 0.0f;
  ib += ia + tid;
}


int main()
{
    dim3    gridSize;
    dim3    blockSize;

    float milliseconds = 0;
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    int n_col = 8192;
    int n_row = 8192;

    cuComplex *gpu_data;
    cudaMalloc( (void**)&gpu_data, n_col*n_row*sizeof(cuComplex) );
    cudaDeviceSynchronize();

    gridSize    = dim3(n_row/BLOCK_DIM,n_col/BLOCK_DIM,1);
    blockSize   = dim3(BLOCK_DIM,BLOCK_DIM,1);
    data_set<<<gridSize,blockSize>>>(gpu_data,1,n_col,n_row);

    cufftHandle row_plan;
    int row_n[] = {n_row};
    int row_inembed[1] = {n_row};
    int row_onembed[1] = {n_row};
    int row_istride = 1;
    int row_idist = n_row;
    int row_ostride = 1;
    int row_odist = n_row;
    int row_batch = n_row;
    cufftPlanMany(&row_plan, 1, row_n, row_inembed, row_istride, row_idist, row_onembed, row_ostride, row_odist, CUFFT_C2C, row_batch);

    cufftHandle col_plan;
    int col_n[] = {n_col};
    int col_inembed[1] = {n_col};
    int col_onembed[1] = {n_col};
    int col_istride = n_col;
    int col_idist = 1;
    int col_ostride = n_col;
    int col_odist = 1;
    int col_batch = n_col;
    cufftPlanMany(&col_plan, 1, col_n, col_inembed, col_istride, col_idist, col_onembed, col_ostride, col_odist, CUFFT_C2C, col_batch);


    // row fft
    cudaEventRecord(start);

    cufftExecC2C(row_plan, gpu_data, gpu_data, CUFFT_FORWARD);
    cudaDeviceSynchronize();

    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    cudaEventElapsedTime(&milliseconds, start, stop);
    printf("row fft execotion time : %f  ms\n", milliseconds);
    // 6 ms


// warm_up function 
    gridSize    = dim3(n_row/BLOCK_DIM,1,1);
    blockSize   = dim3(BLOCK_DIM,BLOCK_DIM,1);
    warm_up_gpu<<<gridSize,blockSize>>>();


    // column fft
    cudaEventRecord(start);


    for (int i=0; i < 100; i+=1){
        cufftExecC2C(col_plan, gpu_data, gpu_data, CUFFT_FORWARD);
        cudaDeviceSynchronize();
    }
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    cudaEventElapsedTime(&milliseconds, start, stop);
    printf("column fft execotion time : %f  ms\n", milliseconds);
}

$ ./fft_test.o
row fft execution time : 5.604736 ms
column fft execution time : 148735.843750 ms

Still the same: very slow execution for the columns (148735 ms / 100 iterations ≈ 1.49 s per call).

Hello.
I updated the drivers to the latest version, but the problem is still there.

nvidia@jetsonHost:/usr/bin$ sudo ./jetson_clocks --show
SOC family:tegra234  Machine:Jetson AGX Orin
Online CPUs: 0-7
cpu0: Online=1 Governor=schedutil MinFreq=2188800 MaxFreq=2188800 CurrentFreq=2188800 IdleStates: WFI=0 c7=0 
cpu1: Online=1 Governor=schedutil MinFreq=2188800 MaxFreq=2188800 CurrentFreq=2188800 IdleStates: WFI=0 c7=0 
cpu2: Online=1 Governor=schedutil MinFreq=2188800 MaxFreq=2188800 CurrentFreq=2188800 IdleStates: WFI=0 c7=0 
cpu3: Online=1 Governor=schedutil MinFreq=2188800 MaxFreq=2188800 CurrentFreq=2188800 IdleStates: WFI=0 c7=0 
cpu4: Online=1 Governor=schedutil MinFreq=2188800 MaxFreq=2188800 CurrentFreq=2188800 IdleStates: WFI=0 c7=0 
cpu5: Online=1 Governor=schedutil MinFreq=2188800 MaxFreq=2188800 CurrentFreq=2188800 IdleStates: WFI=0 c7=0 
cpu6: Online=1 Governor=schedutil MinFreq=2188800 MaxFreq=2188800 CurrentFreq=2188800 IdleStates: WFI=0 c7=0 
cpu7: Online=1 Governor=schedutil MinFreq=2188800 MaxFreq=2188800 CurrentFreq=2188800 IdleStates: WFI=0 c7=0 
GPU MinFreq=930750000 MaxFreq=930750000 CurrentFreq=930750000
EMC MinFreq=204000000 MaxFreq=3199000000 CurrentFreq=3199000000 FreqOverride=1
DLA0_CORE:   Online=1 MinFreq=0 MaxFreq=1408000000 CurrentFreq=1408000000
DLA0_FALCON: Online=1 MinFreq=0 MaxFreq=742400000 CurrentFreq=742400000
DLA1_CORE:   Online=1 MinFreq=0 MaxFreq=1408000000 CurrentFreq=1408000000
DLA1_FALCON: Online=1 MinFreq=0 MaxFreq=742400000 CurrentFreq=742400000
PVA0_VPS0: Online=1 MinFreq=0 MaxFreq=704000000 CurrentFreq=704000000
PVA0_AXI:  Online=1 MinFreq=0 MaxFreq=486400000 CurrentFreq=486400000
FAN Dynamic Speed control=active hwmon3_pwm1=46
NV Power Mode: MAXN
nvidia@jetsonHost:/usr/local/cuda-11.4/samples/1_Utilities/deviceQuery$ ./deviceQuery
./deviceQuery Starting...

 CUDA Device Query (Runtime API) version (CUDART static linking)

Detected 1 CUDA Capable device(s)

Device 0: "Orin"
  CUDA Driver Version / Runtime Version          11.4 / 11.4
  CUDA Capability Major/Minor version number:    8.7
  Total amount of global memory:                 30588 MBytes (32074350592 bytes)
  (014) Multiprocessors, (128) CUDA Cores/MP:    1792 CUDA Cores
  GPU Max Clock rate:                            930 MHz (0.93 GHz)
  Memory Clock rate:                             930 Mhz
  Memory Bus Width:                              128-bit
  L2 Cache Size:                                 4194304 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(131072), 2D=(131072, 65536), 3D=(16384, 16384, 16384)
  Maximum Layered 1D Texture Size, (num) layers  1D=(32768), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(32768, 32768), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total shared memory per multiprocessor:        167936 bytes
  Total number of registers available per block: 65536
  Warp size:                                     32
  Maximum number of threads per multiprocessor:  1536
  Maximum number of threads per block:           1024
  Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
  Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
  Maximum memory pitch:                          2147483647 bytes
  Texture alignment:                             512 bytes
  Concurrent copy and kernel execution:          Yes with 2 copy engine(s)
  Run time limit on kernels:                     No
  Integrated GPU sharing Host Memory:            Yes
  Support host page-locked memory mapping:       Yes
  Alignment requirement for Surfaces:            Yes
  Device has ECC support:                        Disabled
  Device supports Unified Addressing (UVA):      Yes
  Device supports Managed Memory:                Yes
  Device supports Compute Preemption:            Yes
  Supports Cooperative Kernel Launch:            Yes
  Supports MultiDevice Co-op Kernel Launch:      Yes
  Device PCI Domain ID / Bus ID / location ID:   0 / 0 / 0
  Compute Mode:
     < Default (multiple host threads can use ::cudaSetDevice() with device simultaneously) >

deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 11.4, CUDA Runtime Version = 11.4, NumDevs = 1
Result = PASS
nvidia@jetsonHost:/usr/bin$ cat /etc/nv_tegra_release
# R35 (release), REVISION: 3.1, GCID: 32827747, BOARD: t186ref, EABI: aarch64, DATE: Sun Mar 19 15:19:21 UTC 2023

I use the AGX Orin 32 GB H01 with the latest JetPack available for the H01 version.
I would like an answer from NVIDIA: is 1.5 seconds for the 8192 x 8192 column FFT really the expected performance of the AGX Orin?

I changed the matrix size to 4096 x 8192 and the execution time decreased by a factor of about 1000.

row fft execution time : 1.426304 ms
column fft execution time : 1.426304 ms

What could be the reason for such low performance with the 8192 x 8192 matrix?

Hello.
I still haven't solved this problem.
I wrote a half-precision version.

For the FFT along rows the result has not changed: ~6 ms.

For the FFT along columns the execution time decreased from 1.5 s to ~90 ms.

// nvcc -o fft_test.o fft_test.cu -lcufft -gencode arch=compute_87,code=sm_87
#include <math.h>
#include <stdio.h>
#include <cuComplex.h>
#include <cufft.h>
#include <cufftXt.h>
#include <cuda_fp16.h>
#define BLOCK_DIM 16

__global__ void data_set(half2 *data, half value, int col, int row)
{
    int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
    int yIndex = blockIdx.y * blockDim.y + threadIdx.y;

    if ( (xIndex < row) && (yIndex < col) ){
       int matIndex = yIndex * row + xIndex;
       data[matIndex].x               = value;
       data[matIndex].y               = 0;
    }
}


__global__ void Scale_f(half2 *a, half size ,int col ,int row)
{
    __shared__ half2 a_t[BLOCK_DIM][BLOCK_DIM];
    half scale_f = __float2half(1.0)/size;

    int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
    int yIndex = blockIdx.y * blockDim.y + threadIdx.y;

    if ( (xIndex < row) && (yIndex < col) ){

        int matIndex = yIndex * row + xIndex;
        a_t[threadIdx.y][threadIdx.x] = a[matIndex];

    }
    __syncthreads();
    if ( (xIndex < row) && (yIndex < col) ){

        int matIndex = yIndex * row + xIndex;
        a[matIndex].x = a_t[threadIdx.y][threadIdx.x].x*scale_f;
        a[matIndex].y = a_t[threadIdx.y][threadIdx.x].y*scale_f;
    }
}
int main()
{
    dim3    gridSize;
    dim3    blockSize;

    float milliseconds = 0;
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    int n_col = 8192;
    int n_row = 8192;

    half2 *gpu_data;
    cudaMalloc( (void**)&gpu_data, n_col*n_row*sizeof(half2) );
    cudaDeviceSynchronize();
    half2 *host_data = (half2*)malloc(n_col*n_row*sizeof(half2));

    gridSize    = dim3(n_row/BLOCK_DIM,n_col/BLOCK_DIM,1);
    blockSize   = dim3(BLOCK_DIM,BLOCK_DIM,1);
    data_set<<<gridSize,blockSize>>>(gpu_data,1,n_col,n_row);

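    // Half-precision (CUDA_C_16F) transforms require cufftXtMakePlanMany with 64-bit
    // (long long int) layout parameters, so the plans below use the Xt API.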
    cufftHandle row_plan;
    cufftCreate(&row_plan);

    long long int row_n[] = {n_row};
    long long int row_inembed[1] = {n_row};
    long long int row_onembed[1] = {n_row};
    long long int row_istride = 1;
    long long int row_idist = n_row;
    long long int row_ostride = 1;
    long long int row_odist = n_row;
    long long int row_batch = n_col;
    size_t row_worksize[] = { n_row * sizeof(half2) };

    cufftXtMakePlanMany(row_plan, 1, row_n, row_inembed, row_istride, row_idist, CUDA_C_16F, row_onembed, row_ostride, row_odist, CUDA_C_16F, row_batch, row_worksize,CUDA_C_16F);

    cufftHandle col_plan;
    cufftCreate(&col_plan);

    long long int col_n[] = {n_col};
    long long int col_inembed[1] = {n_col};
    long long int col_onembed[1] = {n_col};
    long long int col_istride = n_row;
    long long int col_idist = 1;
    long long int col_ostride = n_row;
    long long int col_odist = 1;
    long long int col_batch = n_row;
    size_t col_worksize[] = { n_col * sizeof(half2) };

    cufftXtMakePlanMany(col_plan, 1, col_n, col_inembed, col_istride, col_idist, CUDA_C_16F, col_onembed, col_ostride, col_odist, CUDA_C_16F, col_batch, col_worksize,CUDA_C_16F);

    // row fft
    cudaEventRecord(start);

    cufftXtExec(row_plan, gpu_data, gpu_data, CUFFT_FORWARD);
    cudaDeviceSynchronize();

    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    cudaEventElapsedTime(&milliseconds, start, stop);
    printf("row fft execotion time : %f  ms\n", milliseconds);


    gridSize    = dim3(n_row/BLOCK_DIM+1,n_col/BLOCK_DIM+1,1);
    blockSize   = dim3(BLOCK_DIM,BLOCK_DIM,1);
    Scale_f<<<gridSize,blockSize>>>(gpu_data,n_row,n_col ,n_row);
    cudaDeviceSynchronize();


    // column fft
    cudaEventRecord(start);


    cufftXtExec(col_plan, gpu_data, gpu_data, CUFFT_FORWARD);
    cudaDeviceSynchronize();

    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    cudaEventElapsedTime(&milliseconds, start, stop);
    printf("column fft execotion time : %f  ms\n", milliseconds);

    cudaMemcpy(host_data, gpu_data, n_col*n_row*sizeof(half2), cudaMemcpyDeviceToHost);

    printf("host_data : %f\n",__half2float(host_data[0].x));

}
$ ./fft_test.o
row fft execution time : 5.604704  ms
column fft execution time : 87.617729  ms

A 15x difference cannot be explained by less efficient memory access alone; I could expect the column transform to be maybe two times slower, but not 15 times slower.
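As a rough sanity check (assuming each cuFFT execution reads and writes the whole matrix about once, which is a simplification), the measured times correspond to the following effective throughputs:

// Back-of-the-envelope effective throughput implied by the measured times.
#include <stdio.h>

int main(void)
{
    double n    = 8192.0 * 8192.0;
    double fp16 = n * 4 * 2;   /* sizeof(half2)     = 4 bytes, one read + one write pass */
    double fp32 = n * 8 * 2;   /* sizeof(cuComplex) = 8 bytes, one read + one write pass */

    printf("half  row    : ~%4.0f GB/s\n", fp16 / 5.6e-3  / 1e9);  /* ~ 96 GB/s */
    printf("half  column : ~%4.0f GB/s\n", fp16 / 87.6e-3 / 1e9);  /* ~  6 GB/s */
    printf("float row    : ~%4.0f GB/s\n", fp32 / 5.6e-3  / 1e9);  /* ~192 GB/s */
    printf("float column : ~%4.1f GB/s\n", fp32 / 1.5     / 1e9);  /* ~0.7 GB/s */
    return 0;
}

So the single-precision row FFT already moves data at roughly 190 GB/s effective, while the strided column FFT achieves less than 1 GB/s effective.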

Could someone run this test on their Jetson? Maybe I have a hardware problem; the code itself looks correct to me.

Hi,

Sorry for the late update.
We tested the last sample and can reproduce similar behavior in our environment:

$ ./fft_test.o
row fft execution time : 3.556832  ms
column fft execution time : 59.715359  ms
host_data : 8192.000000

We are checking this internally and will share more information with you later.
Thanks.

Hi,

Below is the feedback from our internal team:

The performance difference is within the expected range, due to the strided access to the data as well as the compute capabilities of the Orin board (datacenter GPUs show smaller differences between contiguous and strided data accesses).

We have taken note of your use case, and we may be able to optimize it in a future release. In the meantime, you could consider transposing the data to enable contiguous access, or using cuFFTDx to write a custom FFT kernel with loading/storing tailored to your problem, as sketched below.
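For illustration, a minimal sketch of the transpose approach (the kernel below is a plain shared-memory transpose, not a tuned implementation, and it relies on the matrix being square 8192 x 8192 so that the existing contiguous row_plan can be reused on the transposed data):

#define TILE 32

// Tiled transpose: both the reads and the writes are coalesced; the +1 padding
// avoids shared-memory bank conflicts.
__global__ void transpose(cuComplex *out, const cuComplex *in, int width, int height)
{
    __shared__ cuComplex tile[TILE][TILE + 1];

    int x = blockIdx.x * TILE + threadIdx.x;
    int y = blockIdx.y * TILE + threadIdx.y;
    if (x < width && y < height)
        tile[threadIdx.y][threadIdx.x] = in[y * width + x];
    __syncthreads();

    x = blockIdx.y * TILE + threadIdx.x;   // block origin in the transposed matrix
    y = blockIdx.x * TILE + threadIdx.y;
    if (x < height && y < width)
        out[y * height + x] = tile[threadIdx.x][threadIdx.y];
}

// Usage sketch (square n x n matrix, tmp is a second n*n cuComplex buffer):
//   dim3 block(TILE, TILE);
//   dim3 grid(n / TILE, n / TILE);
//   transpose<<<grid, block>>>(tmp, gpu_data, n, n);
//   cufftExecC2C(row_plan, tmp, tmp, CUFFT_FORWARD);   // now a contiguous batched FFT
//   transpose<<<grid, block>>>(gpu_data, tmp, n, n);

Whether the two extra transposes pay off depends on the workload, and the out-of-place transpose needs a second 512 MB buffer.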

Thanks.
