I can't run cuFFTDx with more than 8192 FFT points

I can't run cuFFTDx with more than 8192 FFT points, even though the cuFFTDx documentation says sizes up to 32768 are possible on cc80. Would you help me run cuFFTDx with 32768 points?

Here are the hardware and software versions that I am using.

A100 PCIe
Cuda compilation tools, release 12.1, V12.1.105
cufftdx 1.1.0

It seems to me that register pressure is the main reason I can't run cuFFTDx with N >= 8192. The compilation command I used is as follows.

/nvcc -std=c++17 -arch sm_80 --ptxas-options=-v -O3 -I/opt/nvidia/mathdx/22.11/include first_cufftdx.cu -o first_cufftdx

Here is the program I ran. It is based on the sample in the "First FFT using cuFFTDx" section of the cuFFTDx documentation.


#include <iostream>
#include <fstream>
#include <cmath>

#include <cufftdx.hpp>

#ifndef CUDA_CHECK_AND_EXIT
#    define CUDA_CHECK_AND_EXIT(error)                                                                      \
        {                                                                                                   \
            auto status = static_cast<cudaError_t>(error);                                                  \
            if (status != cudaSuccess) {                                                                    \
                std::cout << cudaGetErrorString(status) << " " << __FILE__ << ":" << __LINE__ << std::endl; \
                std::exit(status);                                                                          \
            }                                                                                               \
        }
#endif // CUDA_CHECK_AND_EXIT

using namespace cufftdx;

template <class FFT>
void initializeDataCos(typename FFT::value_type *data, const size_t n, const size_t nffts) {
    size_t size = n * nffts;
    for (size_t i = 0; i < size; i++) {
        float angle = 2.0f * M_PI * i / n;
        data[i] = typename FFT::value_type{std::cos(angle), 0.0f};
    }
}

template <class FFT>
void saveData(const typename FFT::value_type *data, size_t size, const std::string &filename) {
    std::ofstream outfile(filename);
    if (outfile.is_open()) {
        for (size_t i = 0; i < size; i++) {
            outfile << i << " " << data[i].x << " " << data[i].y << "\n";
        }
        outfile.close();
        std::cout << "Data saved to " << filename << std::endl;
    } else {
        std::cout << "Unable to open the file " << filename << std::endl;
    }
}

template<class FFT>
__global__ void block_fft_kernel(typename FFT::value_type* data) {

  using complex_type = typename FFT::value_type;

  // Registers
  complex_type thread_data[FFT::storage_size];

  // Local batch id of this FFT in CUDA block, in range [0; FFT::ffts_per_block)
  const unsigned int local_fft_id = threadIdx.y;
  // Global batch id of this FFT in CUDA grid is equal to number of batches per CUDA block (ffts_per_block)
  // times CUDA block id, plus local batch id.
  const unsigned int global_fft_id = (blockIdx.x * FFT::ffts_per_block) + local_fft_id;
//  const unsigned int global_fft_id = (blockIdx.x * FFT::ffts_per_block) + threadIdx.y;

  // Load data from global memory to registers
  const unsigned int offset = size_of<FFT>::value * global_fft_id;
  const unsigned int stride = FFT::stride;
  unsigned int       index  = offset + threadIdx.x;
  for (unsigned int i = 0; i < FFT::elements_per_thread; i++) {
      // Make sure not to go out-of-bounds
      if ((i * stride + threadIdx.x) < size_of<FFT>::value) {
          thread_data[i] = data[index];
          index += stride;
      }
  }

  // FFT::shared_memory_size bytes of shared memory
  extern __shared__ complex_type shared_mem[];

  // Execute FFT
  FFT().execute(thread_data, shared_mem);

  // Save results
  index = offset + threadIdx.x;
  for (unsigned int i = 0; i < FFT::elements_per_thread; i++) {
      if ((i * stride + threadIdx.x) < size_of<FFT>::value) {
          data[index] = thread_data[i];
          index += stride;
      }
  }
}

int main() {

// FFT description:
// A single precision complex-to-complex forward FFT description
    using FFT = decltype(  Size<8192>()
                         + Precision<float>()
                         + Type<fft_type::c2c>()
                         + Direction<fft_direction::forward>()
                         + Block()
                         + FFTsPerBlock<1>()
//                         + ElementsPerThread<32>()
                         + SM<800>());

// The blockDim of the execution configuration is pre-determined by the following formula.
// FFT::block_dim.x = (size_of<FFT>::value / FFT::elements_per_thread)
// FFT::block_dim.y = (FFT::ffts_per_block / FFT::implicit_type_batching) 
// FFT::block_dim.z = 1
    std::cout << "size_of<FFT>=" << size_of<FFT>::value << std::endl;
    std::cout << "FFT::elements_per_thread=" << FFT::elements_per_thread << std::endl;
    std::cout << "FFT::fft_per_block=" << FFT::ffts_per_block << std::endl;
    std::cout << "FFT::implicit_type_batching=" << FFT::implicit_type_batching << std::endl;
    std::cout << "FFT::block_dim=" << "(" << FFT::block_dim.x << "," << FFT::block_dim.y << "," << FFT::block_dim.z << ")" << std::endl;
    std::cout << "FFT::storage_size=" << FFT::storage_size << std::endl;
    std::cout << "FFT::shared_memory_size=" << FFT::shared_memory_size << std::endl;
    std::cout << "FFT::stride=" << FFT::stride << std::endl;

    using complex_type = typename FFT::value_type;

    // Allocate pinned host memory and device memory for input/output
    complex_type *data_h, *data_d;
    auto          size       = FFT::ffts_per_block * size_of<FFT>::value;
    auto          size_bytes = size * sizeof(complex_type);

    cudaMallocHost(&data_h, size_bytes);
    cudaMalloc(&data_d, size_bytes);

     // Initialize input data
    initializeDataCos<FFT>(data_h, size_of<FFT>::value, FFT::ffts_per_block);
    saveData<FFT>(data_h, size, "input.dat");

    cudaMemcpy (data_d, data_h, size_bytes, cudaMemcpyHostToDevice);

    // Invokes kernel with FFT::block_dim threads in CUDA block
    block_fft_kernel<FFT><<<1, FFT::block_dim, FFT::shared_memory_size>>>(data_d);
    CUDA_CHECK_AND_EXIT(cudaPeekAtLastError());

    cudaMemcpy (data_h, data_d, size_bytes, cudaMemcpyDeviceToHost);

    saveData<FFT>(data_h, size, "output.dat");

    cudaFreeHost(data_h);
    cudaFree(data_d);

    return 0;
}

You may need to increase the dynamic shared memory limit. See "Execution Methods" in the cuFFTDx 1.1.0 documentation.

Many thanks for your comment. I did more experiments and found the following behaviors.

When N=8192 (the number of FFT points), the required shared memory size is 65536 bytes, which is larger than the default per-block limit of 49152 bytes on the A100. When I add the following call

CUDA_CHECK_AND_EXIT(cudaFuncSetAttribute(
    block_fft_kernel<FFT>,
    cudaFuncAttributeMaxDynamicSharedMemorySize,
    FFT::shared_memory_size));

the program runs without any complaints.

Even with the above function call, cuFFTDx won't run properly for N=16384 and N=32768, where the kernel runs out of registers. If you create a workspace with the following call

cudaError_t error_code = cudaSuccess;
auto workspace = make_workspace<FFT>(error_code);
CUDA_CHECK_AND_EXIT(error_code);

and pass it to the kernel and to FFT::execute() (a sketch of that change follows below), then the program runs without complaints. In fact, for those two cases the program starts to use a workspace in global memory, which will slow down the performance of cuFFTDx.
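
For reference, here is a minimal sketch of the workspace variant of the kernel. It assumes the same file and "using namespace cufftdx" as the program above; the kernel name block_fft_kernel_ws is mine, and the essential changes are the extra workspace parameter and the three-argument execute().

// Sketch only: workspace-aware variant of block_fft_kernel above
template<class FFT>
__global__ void block_fft_kernel_ws(typename FFT::value_type* data,
                                    typename FFT::workspace_type workspace) {
  using complex_type = typename FFT::value_type;

  // Registers
  complex_type thread_data[FFT::storage_size];

  const unsigned int local_fft_id  = threadIdx.y;
  const unsigned int global_fft_id = (blockIdx.x * FFT::ffts_per_block) + local_fft_id;

  // Load data from global memory to registers (same as in block_fft_kernel)
  const unsigned int offset = size_of<FFT>::value * global_fft_id;
  const unsigned int stride = FFT::stride;
  unsigned int       index  = offset + threadIdx.x;
  for (unsigned int i = 0; i < FFT::elements_per_thread; i++) {
      if ((i * stride + threadIdx.x) < size_of<FFT>::value) {
          thread_data[i] = data[index];
          index += stride;
      }
  }

  // FFT::shared_memory_size bytes of shared memory
  extern __shared__ complex_type shared_mem[];

  // Execute FFT with the workspace; this overload is needed when FFT::requires_workspace is true
  FFT().execute(thread_data, shared_mem, workspace);

  // Save results (same as in block_fft_kernel)
  index = offset + threadIdx.x;
  for (unsigned int i = 0; i < FFT::elements_per_thread; i++) {
      if ((i * stride + threadIdx.x) < size_of<FFT>::value) {
          data[index] = thread_data[i];
          index += stride;
      }
  }
}

On the host side, the kernel would then be launched as block_fft_kernel_ws<FFT><<<1, FFT::block_dim, FFT::shared_memory_size>>>(data_d, workspace) after the make_workspace call above.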


I ran into a similar problem when upgrading from 0.3.1 to 1.1.0. It looks like the problem is that c2r gets forwarded to c2c, i.e. 8192 points runs in 0.3.1 with 32768 bytes of shared memory but requires double that in 1.1.0. Could you please fix that? This means that we cannot run 16384 points in shared memory.
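
For illustration, here is a minimal sketch of the kind of comparison I mean; the descriptions are illustrative only (not my actual code) and can be built with the same nvcc command as above:

// Sketch: compare the reported shared memory of c2c and c2r block descriptions of the same size on sm_80
#include <iostream>
#include <cufftdx.hpp>

using namespace cufftdx;

using C2C = decltype(Size<8192>() + Precision<float>() + Type<fft_type::c2c>()
                     + Direction<fft_direction::forward>() + Block() + SM<800>());
using C2R = decltype(Size<8192>() + Precision<float>() + Type<fft_type::c2r>()
                     + Block() + SM<800>());

int main() {
    std::cout << "c2c shared_memory_size = " << C2C::shared_memory_size << " bytes\n";
    std::cout << "c2r shared_memory_size = " << C2R::shared_memory_size << " bytes\n";
    return 0;
}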