Porting a complex polynomial root solver to CUDA - optimizing kernel performance

Hello!

I’m completely new to CUDA programming. I’ve ported an algorithm for finding all (complex) roots of polynomials, originally written in C, to CUDA, and I’m interested in ways of improving the performance of the kernel.

Here is the kernel function and the function which calls the kernel:

__global__ void ehrlich_aberth_kernel(std::int64_t size, std::int64_t deg,
                                      const thrust::complex<double> *poly,
                                      thrust::complex<double> *roots, double *alpha, bool *conv,
                                      point *points, point *hull) {
  const std::int64_t itmax = 50;

  // Compute roots
  std::int64_t i;
  // This is a "grid-stride loop", see http://alexminnaar.com/2019/08/02/grid-stride-loops.html
  for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += blockDim.x * gridDim.x) {
    // Polynomial idx owns deg + 1 coefficients (offset i) and deg roots (offset i - idx).
    i = idx * (deg + 1);
    ehrlich_aberth(poly + i, roots + i - idx, deg, itmax, alpha + i, conv + i - idx, points + i,
                   hull + i);
  }
}

// Function which calls the CUDA kernel
inline void apply_ehrlich_aberth(cudaStream_t stream, void **buffers, const char *opaque,
                                 std::size_t opaque_len) {
  const EhrlichAberthDescriptor &d =
      *UnpackDescriptor<EhrlichAberthDescriptor>(opaque, opaque_len);
  const std::int64_t size = d.size;
  const std::int64_t deg = d.deg;

  const thrust::complex<double> *poly =
      reinterpret_cast<const thrust::complex<double> *>(buffers[0]);
  thrust::complex<double> *roots = reinterpret_cast<thrust::complex<double> *>(buffers[1]);

  const int block_dim = 256;
  const int grid_dim = std::min<int>(1024, (size + block_dim - 1) / block_dim);

  // Preallocate memory for temporary arrays used within the kernel
  // allocating these arrays within the kernel with `new` results in an illegal memory access
  // error, for some reason I don't understand
  double *alpha;
  bool *conv;
  point *points;
  point *hull;

  cudaMalloc(&alpha, size * (deg + 1) * sizeof(double));
  cudaMalloc(&conv, size * deg * sizeof(bool));
  cudaMalloc(&points, size * (deg + 1) * sizeof(point));
  cudaMalloc(&hull, size * (deg + 1) * sizeof(point));

  ehrlich_aberth_kernel<<<grid_dim, block_dim, 0, stream>>>(size, deg, poly, roots, alpha, conv,
                                                            points, hull);

  // free memory
  cudaFree(alpha);
  cudaFree(conv);
  cudaFree(points);
  cudaFree(hull);

  cudaError_t cudaerr = cudaDeviceSynchronize();
  if (cudaerr != cudaSuccess)
    printf("kernel launch failed with error \"%s\".\n", cudaGetErrorString(cudaerr));
}

}  // namespace

The rest of the code is here.

The ehrlich_aberth function takes a vector of polynomial coefficients of shape n_polynomials * (deg + 1), where deg is the degree of the polynomials, and returns the roots in an array of shape n_polynomials * deg. There’s a test function test_cuda.cc.cu and a CPU version, test_cpu.cc.
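
To make the layout concrete, here are two index helpers matching the offsets used in the kernel above (just an illustration; coeff_index and root_index are made-up names, not functions from the actual code):

#include <cstdint>

// Flattened layout: polynomial j owns deg + 1 coefficients and deg roots.
inline std::int64_t coeff_index(std::int64_t j, std::int64_t k, std::int64_t deg) {
  return j * (deg + 1) + k;  // coefficient k of polynomial j
}

inline std::int64_t root_index(std::int64_t j, std::int64_t k, std::int64_t deg) {
  return j * deg + k;  // root k of polynomial j
}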

For 1M polynomials, the kernel call takes ~60 ms on an RTX 3090. This is at least an order of magnitude slower than I expected, because a paper implementing a similar complex polynomial solver claims ~10 ms per 1M polynomials, and that’s with a much older GPU (a Tesla C2075). The same code compiled for the CPU takes ~7 seconds for 1M polynomials.

Here is some output from Nsight Compute:

I’m wondering if there’s any low-hanging fruit in this code that would be obvious to CUDA experts.

Thanks!

EDIT: I was just looking at the specs for the Tesla C2075, and its theoretical FP64 performance is 514 GFLOPS, while for the RTX 3090 it’s 556 GFLOPS. I naively expected the RTX 3090 to be significantly faster in double precision as well as in single precision.

You don’t want to have adjacent threads processing groups of adjacent elements. For example, you have offset your alpha pointer by a constant times the thread index:

i = idx * (deg + 1);  // this breaks performance

Then in your ehrlich_aberth routine you have each thread doing for example this:

  for (int i = 0; i < deg; i++) {
    alpha[i] = thrust::abs(poly[i]);
    conv[i] = false;
  }

That is a bad design paradigm for CUDA. You want adjacent threads to load and store adjacent elements. This will involve code refactoring as well as storage refactoring to address, most likely. The (bad) pattern exists extensively in the code from what I have looked at.
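
To illustrate, here is a minimal sketch of the same per-coefficient loop written with your current layout and with a transposed layout. It is not taken from your code; the kernel names and the transposed poly/alpha buffers are assumptions for the example.

#include <cstdint>
#include <thrust/complex.h>

// Current layout: each thread owns a contiguous block of deg + 1 elements, so the
// threads of a warp touch addresses that are (deg + 1) elements apart on every
// loop iteration (uncoalesced).
__global__ void alpha_strided(std::int64_t size, std::int64_t deg,
                              const thrust::complex<double> *poly, double *alpha) {
  for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += blockDim.x * gridDim.x) {
    for (std::int64_t i = 0; i <= deg; i++)
      alpha[idx * (deg + 1) + i] = thrust::abs(poly[idx * (deg + 1) + i]);
  }
}

// Transposed layout: coefficient i of all polynomials is stored contiguously, so
// adjacent threads of a warp read and write adjacent addresses (coalesced).
__global__ void alpha_coalesced(std::int64_t size, std::int64_t deg,
                                const thrust::complex<double> *poly, double *alpha) {
  for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += blockDim.x * gridDim.x) {
    for (std::int64_t i = 0; i <= deg; i++)
      alpha[i * size + idx] = thrust::abs(poly[i * size + idx]);
  }
}

In the coalesced version the 32 threads of a warp access 32 consecutive elements per iteration, which the hardware can combine into a few memory transactions.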

This is a really basic CUDA programming principle, so when I teach CUDA, I suggest to people that they have a few basics in the area of “optimization” under their belt, before they undertake serious projects. You can get a basic, orderly introduction to CUDA here and the concepts I am referring to are covered in unit 4 there.


As Robert Crovella says, when undertaking performance optimization it is more often than not the case that optimizing data access patterns (and minimizing memory accesses where possible) is more important than re-arranging computation.

A few simple points:

(1) Make sure you are using release builds when analyzing performance.
(2) Acquaint yourself with CUDA profiler and use it to help you identify the most important bottlenecks in the code.
(3) Don’t use 64-bit integer types like size_t and int64_t unless strictly needed. GPUs are 32-bit platforms with minimal hardware extensions for 64-bit addressing.
(4) Don’t square floating-point numbers by calling pow(); a plain multiply does the job (see the short sketch after this list).
(5) When comparing performance with other codes, take into account quality aspects. For example, I see a compensated sum (priest_sum) in the code, which presumably is motivated by the desire to achieve highly accurate solutions. Other codes may emphasize speed over accuracy.
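
A minimal sketch of point (4), with made-up helper names; in device code x * x compiles to a single multiply, whereas pow() goes through a general-purpose software routine:

__device__ double square_slow(double x) { return pow(x, 2.0); }  // general pow() routine
__device__ double square_fast(double x) { return x * x; }        // single multiply

// Related to point (3): a 32-bit loop counter is sufficient for small per-polynomial loops.
__device__ double sum_of_squares(const double *v, int n) {
  double s = 0.0;
  for (int i = 0; i < n; ++i) s += v[i] * v[i];
  return s;
}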


Thank you for sharing those slides!

Unfortunately I don’t really have the resources or the knowledge to rewrite this root solver algorithm in a more parallel way. What I want is for each thread to work on one vector of polynomial coefficients at a time. I also forgot to mention that deg is quite small (<10) so these loops inside ehrlich_aberth aren’t that large.

For example, you have offset your alpha pointer by a constant times the thread index:

The only way I can think of to avoid doing this is to pass pointers to pointers (e.g. **poly) to the kernel function and then index poly[idx]. I’ve seen multiple threads on Stack Overflow saying that this is strongly discouraged, though, and that it’s best to pass flattened arrays to the kernel function. Also, when I tried to implement this approach I ran into segfaults when using cudaMalloc to allocate memory for 2D arrays. Is there anything else I could try?
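
For reference, the kind of layout I mean is roughly this (a simplified sketch, not my actual attempt; alloc_poly_ptrs is a made-up name):

#include <cstdint>
#include <vector>
#include <thrust/complex.h>

// Sketch: build a device array of per-polynomial coefficient pointers.
thrust::complex<double> **alloc_poly_ptrs(std::int64_t size, std::int64_t deg) {
  thrust::complex<double> **d_poly;
  cudaMalloc(&d_poly, size * sizeof(thrust::complex<double> *));

  // The inner pointers have to be prepared on the host and copied over in one
  // memcpy; dereferencing d_poly on the host (d_poly[j] = ...) is invalid.
  std::vector<thrust::complex<double> *> h_ptrs(size);
  for (std::int64_t j = 0; j < size; ++j)
    cudaMalloc(&h_ptrs[j], (deg + 1) * sizeof(thrust::complex<double>));
  cudaMemcpy(d_poly, h_ptrs.data(), size * sizeof(thrust::complex<double> *),
             cudaMemcpyHostToDevice);
  return d_poly;  // needs one cudaMalloc per polynomial, another reason it's discouraged
}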

I have also tried fixing the degree of the polynomials deg at compile time but this gave me only a slight speedup.

In principle, it should be sufficient to transpose your coefficient matrix and update the address calculations accordingly. But the benefits really depend on the maximum degree and the coefficient datatype.

I wrote a toy example (see below) which just computes the sum of the coefficients of each polynomial.
Approach 1 uses your data layout. Approach 2 uses a transposed data layout. Approach 3 uses your data layout, but with 8 threads per polynomial. Approach 4 uses a transposed data layout, but with 8 threads per polynomial.

Approaches 3 and 4 could give you an idea of how one could use multiple threads per polynomial, but this requires quite a bit of refactoring in general (and may not even prove beneficial).

Kernel timings on an RTX 3090 reported by Nsight Compute (approaches 1-4):

float coefficients, degree 10
  Approach 1: 113.18 µs
  Approach 2: 112.51 µs
  Approach 3:  73.41 µs
  Approach 4:  96.86 µs

float coefficients, degree 32
  Approach 1: 477.12 µs
  Approach 2: 352 µs
  Approach 3: 157.89 µs
  Approach 4: 285.54 µs

thrust::complex<double> coefficients, degree 10
  Approach 1: 328.45 µs
  Approach 2: 256.80 µs
  Approach 3: 679.90 µs
  Approach 4: 679.30 µs

thrust::complex<double> coefficients, degree 32
  Approach 1: 1.25 ms
  Approach 2: 788.13 µs
  Approach 3: 1.35 ms
  Approach 4: 1.35 ms

The smaller the element size and the access stride, the closer the performance of the non-transposed and transposed layouts.
Approach 3 works nicely for floats, since an 8-thread float sum reduction is hardware-accelerated on the 3090. The same applies to approach 4, but its memory access is worse (using multiple threads per polynomial counters the effect of the data transposition).

For your use case with degree 10 and double-complex type, transposition gives ~25% speedup on this example.

//nvcc -O3 -arch=sm_86 main.cu -o main

#include <vector>
#include <algorithm>
#include <iostream>
#include <cassert>
#include <thrust/complex.h>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

using CoeffT = thrust::complex<double>;

//compute sum of coefficients per polynomial
__global__
void coeffSumKernel(CoeffT* __restrict__ output, const CoeffT* __restrict__ coeffs, int N, int maxDegree){
    const int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if(tid < N){
        const CoeffT* myCoeffs = coeffs + tid * maxDegree;
        CoeffT result = 0;
        for(int i = 0; i < maxDegree; i++){
            result += myCoeffs[i];
        }
        output[tid] = result;
    }
}


//compute sum of coefficients per polynomial, coeffs are transposed
__global__
void coeffSumKernel2(CoeffT* __restrict__ output, const CoeffT* __restrict__ coeffs, int N, int maxDegree){
    const int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if(tid < N){
        const CoeffT* myCoeffs = coeffs + tid;
        CoeffT result = 0;
        for(int i = 0; i < maxDegree; i++){
            result += myCoeffs[i * N];
        }
        output[tid] = result;
    }
}


//compute sum of coefficients per polynomial, use groupsize threads (1,2,4,8,16,or 32) per polynomial
template<int groupsize>
__global__
void coeffSumKernel3(CoeffT* __restrict__ output, const CoeffT* __restrict__ coeffs, int N, int maxDegree){
    auto group = cg::tiled_partition<groupsize>(cg::this_thread_block());
    const int groupId = (threadIdx.x + blockIdx.x * blockDim.x) / groupsize;

    if(groupId < N){
        const CoeffT* myCoeffs = coeffs + groupId * maxDegree;
        CoeffT result = 0;
        const int numIters = (maxDegree + groupsize - 1) / groupsize;
        for(int iter = 0; iter < numIters; iter++){
            const int index = iter * groupsize + group.thread_rank();
            CoeffT val = index < maxDegree ? myCoeffs[index] : 0;
            result += cg::reduce(group, val, cg::plus<CoeffT>());
        }
        if(group.thread_rank() == 0){
            output[groupId] = result;
        }
    }
}

//compute sum of coefficients per polynomial, coeffs are transposed, use groupsize threads (1,2,4,8,16,or 32) per polynomial
template<int groupsize>
__global__
void coeffSumKernel4(CoeffT* __restrict__ output, const CoeffT* __restrict__ coeffs, int N, int maxDegree){
    auto group = cg::tiled_partition<groupsize>(cg::this_thread_block());
    const int groupId = (threadIdx.x + blockIdx.x * blockDim.x) / groupsize;

    if(groupId < N){
        const CoeffT* myCoeffs = coeffs + groupId;
        CoeffT result = 0;
        const int numIters = (maxDegree + groupsize - 1) / groupsize;
        for(int iter = 0; iter < numIters; iter++){
            const int index = iter * groupsize + group.thread_rank();
            CoeffT val = index < maxDegree ? myCoeffs[index * N] : 0;
            result += cg::reduce(group, val, cg::plus<CoeffT>());
        }
        if(group.thread_rank() == 0){
            output[groupId] = result;
        }
    }
}



int main(){
    const int N = 1000000;
    const int maxDegree = 10;
    const int maxCoeffs = N * maxDegree;

    //Approach 1, this is your current data layout
    std::vector<CoeffT> coeffs1(maxCoeffs);

    for(int i = 0; i < N; i++){
        for(int j = 0; j < maxDegree; j++){
            coeffs1[i * maxDegree + j] = i;
        }
    }

    CoeffT* d_coeffs1; 
    CoeffT* d_result1;
    cudaMalloc(&d_coeffs1, sizeof(CoeffT) * maxCoeffs);
    cudaMalloc(&d_result1, sizeof(CoeffT) * N);
    cudaMemcpy(d_coeffs1, coeffs1.data(), sizeof(CoeffT) * maxCoeffs, cudaMemcpyHostToDevice);

    coeffSumKernel<<<(maxCoeffs + 127)/128, 128>>>(d_result1, d_coeffs1, N, maxDegree);

    std::vector<CoeffT> result1(N);
    cudaMemcpy(result1.data(), d_result1, sizeof(CoeffT) * N, cudaMemcpyDeviceToHost);
    cudaDeviceSynchronize();

    // for(int i = 0; i < N; i++){
    //     if(result1[i] != i * maxDegree){
    //         printf("%f %d\n",result1[i], i * maxDegree);
    //     }
    //     assert(result1[i] == i * maxDegree);
    // }
    cudaFree(d_coeffs1);
    cudaFree(d_result1);


    //Approach 2, transpose the coefficient matrix. the coefficients of polynomial i are stored in the i-th column
    std::vector<CoeffT> coeffs2(maxCoeffs);

    for(int i = 0; i < N; i++){
        for(int j = 0; j < maxDegree; j++){
            coeffs2[i + j * N] = i;
        }
    }

    CoeffT* d_coeffs2; 
    CoeffT* d_result2;
    cudaMalloc(&d_coeffs2, sizeof(CoeffT) * maxCoeffs);
    cudaMalloc(&d_result2, sizeof(CoeffT) * N);
    cudaMemcpy(d_coeffs2, coeffs2.data(), sizeof(CoeffT) * maxCoeffs, cudaMemcpyHostToDevice);

    coeffSumKernel2<<<(maxCoeffs + 127)/128, 128>>>(d_result2, d_coeffs2, N, maxDegree);

    std::vector<CoeffT> result2(N);
    cudaMemcpy(result2.data(), d_result2, sizeof(CoeffT) * N, cudaMemcpyDeviceToHost);
    cudaDeviceSynchronize();

    //assert(result2 == result1);
    cudaFree(d_coeffs2);
    cudaFree(d_result2);


    //Approach 3, your data layout, use multiple threads per polynomial
    constexpr int numThreadsPerPoly3 = 8;
    std::vector<CoeffT> coeffs3(maxCoeffs);

    for(int i = 0; i < N; i++){
        for(int j = 0; j < maxDegree; j++){
            coeffs3[i * maxDegree + j] = i;
        }
    }

    CoeffT* d_coeffs3; 
    CoeffT* d_result3;
    cudaMalloc(&d_coeffs3, sizeof(CoeffT) * maxCoeffs);
    cudaMalloc(&d_result3, sizeof(CoeffT) * N);
    cudaMemcpy(d_coeffs3, coeffs3.data(), sizeof(CoeffT) * maxCoeffs, cudaMemcpyHostToDevice);    

    coeffSumKernel3<numThreadsPerPoly3><<<((N * numThreadsPerPoly3) + 127)/128, 128>>>(d_result3, d_coeffs3, N, maxDegree);

    std::vector<CoeffT> result3(N);
    cudaMemcpy(result3.data(), d_result3, sizeof(CoeffT) * N, cudaMemcpyDeviceToHost);
    cudaDeviceSynchronize();

    //assert(result3 == result1);

    cudaFree(d_coeffs3);
    cudaFree(d_result3);


    //Approach 4, transpose the coefficient matrix. use multiple threads per polynomial
    constexpr int numThreadsPerPoly4 = 8;
    std::vector<CoeffT> coeffs4(maxCoeffs);

    for(int i = 0; i < N; i++){
        for(int j = 0; j < maxDegree; j++){
            coeffs4[i + j * N] = i;
        }
    }

    CoeffT* d_coeffs4; 
    CoeffT* d_result4;
    cudaMalloc(&d_coeffs4, sizeof(CoeffT) * maxCoeffs);
    cudaMalloc(&d_result4, sizeof(CoeffT) * N);
    cudaMemcpy(d_coeffs4, coeffs4.data(), sizeof(CoeffT) * maxCoeffs, cudaMemcpyHostToDevice);    

    coeffSumKernel4<numThreadsPerPoly4><<<((N * numThreadsPerPoly4) + 127)/128, 128>>>(d_result4, d_coeffs4, N, maxDegree);

    std::vector<CoeffT> result4(N);
    cudaMemcpy(result4.data(), d_result4, sizeof(CoeffT) * N, cudaMemcpyDeviceToHost);
    cudaDeviceSynchronize();

    //assert(result4 == result1);

    cudaFree(d_coeffs4);
    cudaFree(d_result4);
}

Thank you for the very detailed answer, that clarified a few things for me. I don’t think I can easily apply approach 2, because the ehrlich_aberth function needs a pointer to a contiguous block of memory and I’d have to rewrite it in a different way. I think I’ll stick with approach 1 and maybe use a GPU that is better optimized for double precision.

One more question: I need to compile the code with --expt-relaxed-constexpr, otherwise the compiler complains about fma() being a host function. Do you know if this significantly affects performance?