Constant memory provides no improvement

Hi

I am implementing a 2D convolution that applies a 3 x 3 filter to a 2048 x 2048 image. I wrote two versions: one reads the filter from global memory and the other reads it from constant memory. When I benchmark the code (on my RTX 3090), I see no improvement from using constant memory.

Kernel using global memory

// 2D convolution of an n_rows x n_cols single-channel float image with a
// (2*FILTER_RADIUS+1) x (2*FILTER_RADIUS+1) filter read from global memory.
// Requires the FILTER_RADIUS macro to be defined before this kernel.
//
// Launch layout: 2D grid of 2D blocks, one thread per output pixel
// (x -> column, y -> row). Threads outside the image do nothing.
// Input pixels that fall outside the image contribute zero.
//
//   d_N_ptr : input image, row-major, n_rows*n_cols floats (device pointer)
//   d_F_ptr : filter weights, row-major, (2*FILTER_RADIUS+1)^2 floats (device pointer)
//   d_P_ptr : output image, row-major, n_rows*n_cols floats (device pointer)
//
// The read-only inputs are `const __restrict__` so the compiler may route them
// through the read-only data cache and reorder the loads; d_P_ptr must
// therefore not alias either input.
__global__ void gpu_conv2d_kernel(const float* __restrict__ d_N_ptr,
                                  const float* __restrict__ d_F_ptr,
                                  float* __restrict__ d_P_ptr,
                                  int n_rows, int n_cols)
{
    constexpr int filter_width = 2*FILTER_RADIUS + 1;

    // Which output element this thread works on
    const int out_col = blockIdx.x*blockDim.x + threadIdx.x;
    const int out_row = blockIdx.y*blockDim.y + threadIdx.y;

    // Threads mapped past the image edge have no output element
    if (out_row >= n_rows || out_col >= n_cols)
        return;

    // Result (in thread register)
    float p_val = 0.0f;

    // Fixed trip counts: full unrolling exposes all loads to the scheduler
    #pragma unroll
    for (int f_row = 0; f_row < filter_width; f_row++)
    {
        const int in_row = out_row + (f_row - FILTER_RADIUS);
        if (in_row < 0 || in_row >= n_rows)
            continue;  // this whole filter row lies outside the image

        // 64-bit row offset: in_row*n_cols overflows int for images > 2^31 pixels
        const float* in_row_ptr = d_N_ptr + static_cast<size_t>(in_row) * n_cols;

        #pragma unroll
        for (int f_col = 0; f_col < filter_width; f_col++)
        {
            const int in_col = out_col + (f_col - FILTER_RADIUS);
            if (in_col >= 0 && in_col < n_cols)
                p_val += d_F_ptr[f_row*filter_width + f_col] * in_row_ptr[in_col];
        }
    }
    d_P_ptr[static_cast<size_t>(out_row) * n_cols + out_col] = p_val;
}

Kernel using constant memory

// Filter radius: the filter is (2*FILTER_RADIUS+1) x (2*FILTER_RADIUS+1).
// Keep this in sync with the FILTER_RADIUS used where d_F is defined.
#define FILTER_RADIUS 1
// Declaration only: d_F is defined in the host file and filled there with
// cudaMemcpyToSymbol before the kernel is launched.
// NOTE(review): using a __constant__ variable defined in a different .cu file
// requires relocatable device code (nvcc -rdc=true) — confirm the build flags.
extern __constant__ float d_F[(2*FILTER_RADIUS+1)*(2*FILTER_RADIUS+1)];

// 2D convolution of an n_rows x n_cols single-channel float image with the
// (2*FILTER_RADIUS+1) x (2*FILTER_RADIUS+1) filter stored in the __constant__
// array d_F (filled on the host with cudaMemcpyToSymbol).
//
// Launch layout: 2D grid of 2D blocks, one thread per output pixel
// (x -> column, y -> row). Threads outside the image do nothing.
// Input pixels that fall outside the image contribute zero.
//
//   d_N_ptr : input image, row-major, n_rows*n_cols floats (device pointer)
//   d_P_ptr : output image, row-major, n_rows*n_cols floats (device pointer)
//
// The d_F index depends only on the loop counters, never on the thread, so
// all threads of a warp request the same address — the pattern constant
// memory serves by broadcast. d_P_ptr must not alias d_N_ptr (__restrict__).
__global__ void gpu_conv2d_constMem_kernel(const float* __restrict__ d_N_ptr,
                                           float* __restrict__ d_P_ptr,
                                           int n_rows, int n_cols)
{
    constexpr int filter_width = 2*FILTER_RADIUS + 1;

    // Which output element this thread works on
    const int out_col = blockIdx.x*blockDim.x + threadIdx.x;
    const int out_row = blockIdx.y*blockDim.y + threadIdx.y;

    // Threads mapped past the image edge have no output element
    if (out_row >= n_rows || out_col >= n_cols)
        return;

    // Result (in thread register)
    float p_val = 0.0f;

    // Full unrolling turns every d_F index into a compile-time constant
    #pragma unroll
    for (int f_row = 0; f_row < filter_width; f_row++)
    {
        const int in_row = out_row + (f_row - FILTER_RADIUS);
        if (in_row < 0 || in_row >= n_rows)
            continue;  // this whole filter row lies outside the image

        // 64-bit row offset: in_row*n_cols overflows int for images > 2^31 pixels
        const float* in_row_ptr = d_N_ptr + static_cast<size_t>(in_row) * n_cols;

        #pragma unroll
        for (int f_col = 0; f_col < filter_width; f_col++)
        {
            const int in_col = out_col + (f_col - FILTER_RADIUS);
            if (in_col >= 0 && in_col < n_cols)
                p_val += d_F[f_row*filter_width + f_col] * in_row_ptr[in_col];
        }
    }
    d_P_ptr[static_cast<size_t>(out_row) * n_cols + out_col] = p_val;
}

I’ve written a main function that performs benchmarking.

// All essential includes
// .
// .
// .

// Filter radius: the filter is (2*FILTER_RADIUS+1) x (2*FILTER_RADIUS+1).
// Must match the FILTER_RADIUS the convolution kernels are compiled with.
// NOTE: main() below writes exactly 9 weights (F[0]..F[8]), so it assumes 1.
#define FILTER_RADIUS 1
// Filter weights in constant memory; main() copies them in with
// cudaMemcpyToSymbol before each kernel launch.
__constant__ float d_F[(2*FILTER_RADIUS+1)*(2*FILTER_RADIUS+1)];

// CUDA error checking.
// Evaluates `err` exactly once. If it is not cudaSuccess, prints the error
// string with file/line to stderr and exits with EXIT_FAILURE.
// The do { } while (0) wrapper makes `cuda_check(x);` a single statement, so it
// is safe inside an unbraced if/else. Evaluating once matters for calls like
// cudaGetLastError(), whose second call would report "no error".
#define cuda_check(err) do { \
    cudaError_t cuda_check_err_ = (err); \
    if (cuda_check_err_ != cudaSuccess) { \
        std::cerr << cudaGetErrorString(cuda_check_err_) << " in " << __FILE__ << " at line " << __LINE__ << "\n"; \
        exit(EXIT_FAILURE); \
    } \
} while (0)

int main(int argc, char const *argv[])
{
    // The filter-definition branches below write exactly 9 weights (F[0]..F[8]),
    // so this driver only supports a 3x3 filter.
    static_assert((2*FILTER_RADIUS+1)*(2*FILTER_RADIUS+1) == 9,
                  "main() fills a 3x3 filter; FILTER_RADIUS must be 1");

    // Benchmarking variables (cudaEventElapsedTime gives ms; converted to s below)
    float elapsed_time_mem_alloc, 
            elapsed_time_mem_t_in, elapsed_time_mem_t_f, elapsed_time_mem_t_out, 
            elapsed_time_kernel;
    cudaEvent_t beg, end;
    cudaEventCreate(&beg);
    cudaEventCreate(&end);
    // ---------------------------------------------------------- //
    // ------------------ Load image in memory ------------------ //
    // ---------------------------------------------------------- //
    //.
    //.
    //.
    // ---------------------------------------------------------- //
    // ----------------- GPU memory allocation ------------------ //
    // ---------------------------------------------------------- //
    cudaError_t err;
    
    std::cout << "Allocating GPU memory... \n";
    cudaEventRecord(beg);
    
    float* d_N;
    err = cudaMalloc((void**) &d_N, new_size*new_size*sizeof(float));
    cuda_check(err);

    float *d_P; 
    err = cudaMalloc((void**) &d_P, new_size*new_size*sizeof(float));
    cuda_check(err);

    cudaEventRecord(end);
    cudaEventSynchronize(beg);
    cudaEventSynchronize(end);
    cudaEventElapsedTime(&elapsed_time_mem_alloc, beg, end);
    elapsed_time_mem_alloc /= 1000.;

    std::cout << "Time for GPU memory allocation (seconds): " << elapsed_time_mem_alloc << "\n";
    std::cout << "\n";

    // ---------------------------------------------------------- //
    // ------------------- Move input to GPU -------------------- //
    // ---------------------------------------------------------- //
    std::cout << "Moving input to GPU memory... \n";
    cudaEventRecord(beg);
    
    err = cudaMemcpy(d_N, N, new_size*new_size*sizeof(float), cudaMemcpyHostToDevice);
    cuda_check(err);

    cudaEventRecord(end);
    cudaEventSynchronize(beg);
    cudaEventSynchronize(end);
    cudaEventElapsedTime(&elapsed_time_mem_t_in, beg, end);
    elapsed_time_mem_t_in /= 1000.;
    std::cout << "Time for input data transfer (seconds): " << elapsed_time_mem_t_in << "\n";
    std::cout << "\n";

    // ------------------------------------------------------------------------- //
    // ----------------------- Initialize filter ------------------------------- //
    // ------------------------------------------------------------------------- //
    std::string filter_type;
    float *F = new float[(2*FILTER_RADIUS+1)*(2*FILTER_RADIUS+1)];

    int iter = 0;
    while (true)
    {
        // ------------------------------------------------------------------------- //
        // Which filter; Options: Sharpen, High-pass, Low-pass, Gaussian, d_Gaussian //
        // ------------------------------------------------------------------------- //
        std::cout << "Filter options: Sharpen, High-pass, Low-pass, Gaussian, d_Gaussian \n";
        std::cout << "Enter filter (press 'q' to exit): ";
        // Stop on EOF / stream failure. Otherwise filter_type would keep its
        // previous value and the loop would repeat forever.
        if (!(std::cin >> filter_type))
            break;


        // ---------------------------------------------------------- //
        // ---------------- Defining filter matrix ------------------ //
        // ---------------------------------------------------------- //
        if (filter_type == "Sharpen")
        {
            float alpha = 0.8f;
            std::cout << "Enter alpha between 0 and 1 (default: 0.8): ";
            std::cin >> alpha;
            std::cout << "\n";

            F[0] = -alpha/(9-9*alpha);
            F[1] = -alpha/(9-9*alpha);
            F[2] = -alpha/(9-9*alpha);
            F[3] = -alpha/(9-9*alpha);
            F[4] = (9-alpha)/(9-9*alpha);
            F[5] = -alpha/(9-9*alpha);
            F[6] = -alpha/(9-9*alpha);
            F[7] = -alpha/(9-9*alpha);
            F[8] = -alpha/(9-9*alpha);
            
        }
        else if (filter_type == "High-pass")
        {
            std::cout << "\n";   
            F[0] = -1;
            F[1] = -1;
            F[2] = -1;
            F[3] = -1;
            F[4] = 8;
            F[5] = -1;
            F[6] = -1;
            F[7] = -1;
            F[8] = -1;
        }
        else if (filter_type == "Low-pass")
        {
            float alpha = 9.0f;
            std::cout << "Enter alpha (default: 9.0): ";
            std::cin >> alpha;
            std::cout << "\n";

            F[0] = 1/alpha;
            F[1] = 1/alpha;
            F[2] = 1/alpha;
            F[3] = 1/alpha;
            F[4] = 1/alpha;
            F[5] = 1/alpha;
            F[6] = 1/alpha;
            F[7] = 1/alpha;
            F[8] = 1/alpha;
        }
        else if (filter_type == "Gaussian")
        {
            float alpha = 16.0f;
            std::cout << "Enter alpha (default: 16.0): ";
            std::cin >> alpha;
            std::cout << "\n";

            // 3x3 Gaussian [1 2 1; 2 4 2; 1 2 1]: the weights sum to 16, so the
            // default alpha normalizes the filter.
            F[0] = 1/alpha;
            F[1] = 2/alpha;
            F[2] = 1/alpha;
            F[3] = 2/alpha;
            F[4] = 4/alpha;
            F[5] = 2/alpha;
            F[6] = 1/alpha;
            F[7] = 2/alpha;
            F[8] = 1/alpha;
        }
        else if (filter_type == "d_Gaussian")
        {
            std::cout << "\n";
            F[0] = -2;
            F[1] = 1;
            F[2] = -2;
            F[3] = 1;
            F[4] = 4;
            F[5] = 1;
            F[6] = -2;
            F[7] = 1;
            F[8] = -2;
        }
        else if (filter_type == "q")
        {
            break;
        }
        else
        {
            // Re-prompt rather than std::terminate(), which aborts the process
            // and skips the cleanup at the end of main.
            std::cout << "Filter not supported!" << "\n";
            continue;
        }

        
        // ---------------------------------------------------------- //
        // ------------------ Move filter to GPU -------------------- //
        // ---------------------------------------------------------- //
        std::cout << "Moving filter to GPU constant memory... \n";
        cudaEventRecord(beg);
        
        err = cudaMemcpyToSymbol(d_F, F, (2*FILTER_RADIUS+1)*(2*FILTER_RADIUS+1)*sizeof(float));
        cuda_check(err);
        err = cudaDeviceSynchronize();
        cuda_check(err);

        cudaEventRecord(end);
        cudaEventSynchronize(beg);
        cudaEventSynchronize(end);
        cudaEventElapsedTime(&elapsed_time_mem_t_f, beg, end);
        elapsed_time_mem_t_f /= 1000.;
        std::cout << "Time for filter data transfer (seconds): " << elapsed_time_mem_t_f << "\n";
        std::cout << "\n";

        // ---------------------------------------------------------- //
        // --------------------- 2D Convolution --------------------- //
        // ---------------------------------------------------------- //

        // Applying filters frame by frame
        std::cout << "Applying filter... \n"; 

        // Kernel execution
        cudaEventRecord(beg);

        const int block_size = 32;
        dim3 dim_block(block_size, block_size, 1);
        // Integer ceil-division: enough blocks to cover every pixel
        dim3 dim_grid((new_size + block_size - 1) / block_size,
                      (new_size + block_size - 1) / block_size, 1);
        gpu_conv2d_constMem_kernel<<<dim_grid, dim_block>>>(d_N, d_P, new_size, new_size);
        err = cudaGetLastError();       // launch-configuration errors
        cuda_check(err);
        err = cudaDeviceSynchronize();  // errors raised while the kernel ran
        cuda_check(err);
        
        cudaEventRecord(end);
        cudaEventSynchronize(beg);
        cudaEventSynchronize(end);
        cudaEventElapsedTime(&elapsed_time_kernel, beg, end);
        elapsed_time_kernel /= 1000.;
        std::cout << "Time for kernel execution (seconds): " << elapsed_time_kernel << "\n";
        std::cout << "\n";

        // ---------------------------------------------------------- //
        // ---------- Copying result back to host memory -------------//
        // ---------------------------------------------------------- //
        std::cout << "Moving result to CPU memory... \n";
        cudaEventRecord(beg);
        
        err = cudaMemcpy(P, d_P, new_size*new_size*sizeof(float), cudaMemcpyDeviceToHost);
        cuda_check(err);
        
        cudaEventRecord(end);
        cudaEventSynchronize(beg);
        cudaEventSynchronize(end);
        cudaEventElapsedTime(&elapsed_time_mem_t_out, beg, end);
        elapsed_time_mem_t_out /= 1000.;
        std::cout << "Time for output data transfer (seconds): " << elapsed_time_mem_t_out << "\n";
        std::cout << "\n";

        // ---------------------------------------------------------- //
        // --------------------- Benchmarking ------------------------//
        // ---------------------------------------------------------- //

        std::cout << "--------------------- \n";
        std::cout << "Benchmarking details: \n";
        std::cout << "--------------------- \n";
        if (iter == 0)
        {
            // First pass also pays for allocation and the input upload
            std::cout << "Time (total): " << elapsed_time_kernel + elapsed_time_mem_alloc + 
                                                elapsed_time_mem_t_in + elapsed_time_mem_t_f + elapsed_time_mem_t_out << "\n";
            std::cout << "FPS (total): " << 1 / (elapsed_time_kernel + elapsed_time_mem_alloc + 
                                                elapsed_time_mem_t_in + elapsed_time_mem_t_f + elapsed_time_mem_t_out) << "\n";
            std::cout << "\n";
        }
        else
        {
            std::cout << "Time (total): " << elapsed_time_kernel +  elapsed_time_mem_t_f + elapsed_time_mem_t_out << "\n";
            std::cout << "FPS (total): " << 1 / (elapsed_time_kernel +  elapsed_time_mem_t_f+ elapsed_time_mem_t_out) << "\n";
            std::cout << "\n";
        }

        std::cout << "Time (kernel): " << elapsed_time_kernel << "\n";
        std::cout << "FPS (kernel): " << 1 / (elapsed_time_kernel) << "\n";
        // 2 flops (mul + add) per filter tap per pixel; computed in double so the
        // count cannot overflow int for large images.
        std::cout << "GFLOPS (kernel): " << 2.0*new_size*new_size*(2*FILTER_RADIUS+1)*(2*FILTER_RADIUS+1) * 1e-9 / elapsed_time_kernel << "\n";
        std::cout << "------------------------------------ \n";
        std::cout << "\n";

        // ----------------------------------------------------------------- //
        // -------------------- Saving output as jpg ----------------------- //
        // ----------------------------------------------------------------- //
        //.
        //.
        //.
        iter += 1;
    }

    cudaEventDestroy(beg);
    cudaEventDestroy(end);

    delete[] N;
    delete[] F;
    delete[] P;

    cudaFree(d_N);
    cudaFree(d_P);

    return 0;
}

You can see the code repository here. I’m not sure if there is a problem with the implementation or maybe the problem is not complicated enough (although I tried large input image sizes).

What performance bottlenecks did the CUDA profiler identify when you profiled the code?

I am not able to profile using Nsight Compute. It keeps returning errors, and to be fair, it’s way too complicated for me (as I’m just starting out with CUDA).

Since you are just starting with CUDA and are also interested in performance issues (as opposed to just trying to achieve functional code), now is the perfect time to learn to use Nsight.

I was able to run the profiler for both kernels. Here are the details:

Global Memory

Constant Memory

I’m not sure why the SM and memory bandwidth utilizations have gone down. I’m no longer accessing global memory for the filter, so theoretically they should go up, and I’m not changing anything in the algorithm that should affect SM utilization.

Could you compare the Warp State Statistics within Nsight Compute? This statistics shows the reasons, why the scheduler has to wait instead of scheduling the next warp, and is a good hint for already rather well-performing kernels.

Is any performance difference observed between these two kernels as FILTER_RADIUS is increased?

What happens if a #pragma unroll is inserted directly before the outer loop?

I experimented with the block dims and found 16x16 to be optimum. However, I still don’t see any improvement from using constant memory.

Global Memory

Constant Memory

Warp state comparison

Global Memory

Constant Memory

Memory Workload Analysis

Global Memory

Constant Memory

I also tried #pragma unroll and it doesn’t change anything.

#pragma unroll on the outer loop should have caused the compiler to completely unroll the loop nest. That gives it more freedom to schedule the load instructions early, which can improve latency tolerance. This code is very load-intensive and does very little computation, so it is partially limited by load latency.

The regular cache hierarchy should work well for this code as far as there is data reuse. The generated machine code is dominated by loads, of which the loads converted from “gmem” to “constant” comprise a fairly small portion, so I would not expect much speedup from switching to the “constant” variant of the kernel. However, I would have expected a small performance difference, e.g. a 5% speedup.

What we all missed: the constant-memory version is indeed much faster. We should not look at the prominently displayed SOL (speed-of-light) charts, but at the reported running time in µs or cycles.

In the first example

global memory: 111,341 cycles
constant memory: 93,399 cycles

With optimized thread block size

global memory: 73,922 cycles
constant memory: 62,857 cycles

Theoretical optimal speed
The memory clock multiplier is 9.73 GHz/1.39 GHz (values shown in Nsight Compute) = 7. The RTX 3090 has a memory interface of 384 bits.

Memory traffic needed: 2048 × 2048 × 2 (input + output) × 32 bits = 268,435,456 bits
268,435,456 bits / (7 × 2 (double data rate) × 384 bits) ≈ 49,932 SM cycles

Reasons

The global memory requests to the L1 Cache were halved from 2.36M to 1.18M. With the reduced number of transactions the Stall LG Throttle Warp State (which usually shows too many requests for the local/global memory pipeline) nearly vanished.

It seems to me only one of these statements can be correct.

I don’t see any improvement when I run benchmarks.

Constant Memory

------------------------------------------ 
GPU (constant memory) Benchmarking details 
------------------------------------------ 
Time for GPU memory allocation (seconds): 0.00026624

Time for input data transfer (seconds): 0.00284182

Time for filter data transfer (seconds): 0.000221472

Time for kernel execution (seconds): 4.86093e-05

Time for output data transfer (seconds): 0.00571584

Time (total): 0.00909398
FPS (total): 109.963

Time (kernel): 4.86093e-05
FPS (kernel): 20572.2
GFLOPS (kernel): 1553.15
------------------------------------------

Global Memory

------------------------ 
GPU Benchmarking details 
------------------------ 
Time for GPU memory allocation (seconds): 0.000300032

Time for input data transfer (seconds): 0.00282214

Time for filter data transfer (seconds): 9.152e-06

Time for kernel execution (seconds): 5.08006e-05

Time for output data transfer (seconds): 0.0058751

Time (total): 0.00905723
FPS (total): 110.409

Time (kernel): 5.08006e-05
FPS (kernel): 19684.8
GFLOPS (kernel): 1486.15
------------------------

I hope I’m benchmarking correctly.

// Benchmarking variables (converted from ms to seconds after each measurement)
float elapsed_time_mem_alloc, 
          elapsed_time_mem_t_in, elapsed_time_mem_t_f, elapsed_time_mem_t_out, 
          elapsed_time_kernel;
// beg/end bracket each timed region; cudaEventElapsedTime reports milliseconds,
// hence the "/= 1000." after every measurement.
cudaEvent_t beg, end;
cudaEventCreate(&beg);
cudaEventCreate(&end);

// ---------------------------------------------------------- //
// ----------------- GPU memory allocation ------------------ //
// ---------------------------------------------------------- //
cudaError_t err;
  
std::cout << "Allocating GPU memory... \n";
cudaEventRecord(beg);
  
// Input (d_N) and output (d_P) images, new_size x new_size floats each
float* d_N;
err = cudaMalloc((void**) &d_N, new_size*new_size*sizeof(float));
cuda_check(err);

float *d_P; 
err = cudaMalloc((void**) &d_P, new_size*new_size*sizeof(float));
cuda_check(err);

cudaEventRecord(end);
// Block the host until the GPU has passed `end`, so the elapsed time is valid
cudaEventSynchronize(beg);
cudaEventSynchronize(end);
cudaEventElapsedTime(&elapsed_time_mem_alloc, beg, end);
elapsed_time_mem_alloc /= 1000.;

// ---------------------------------------------------------- //
// ------------------- Move input to GPU -------------------- //
// ---------------------------------------------------------- //
std::cout << "Moving input to GPU memory... \n";
cudaEventRecord(beg);
  
err = cudaMemcpy(d_N, N, new_size*new_size*sizeof(float), cudaMemcpyHostToDevice);
cuda_check(err);

cudaEventRecord(end);
cudaEventSynchronize(beg);
cudaEventSynchronize(end);
cudaEventElapsedTime(&elapsed_time_mem_t_in, beg, end);
elapsed_time_mem_t_in /= 1000.;

// .
//.
//.

// ---------------------------------------------------------- //
// --------------------- 2D Convolution --------------------- //
// ---------------------------------------------------------- //

const int block_size = 16;
const int n_warmup_runs = 10;
const int n_timed_runs = 100;

dim3 dim_block(block_size, block_size, 1);
// Integer ceil-division so partial tiles at the image edge are covered
dim3 dim_grid((new_size + block_size - 1) / block_size,
              (new_size + block_size - 1) / block_size, 1);

// Untimed warm-up launches, so start-up effects on the first launches do not
// skew the average below.
for (int i = 0; i < n_warmup_runs; i++)
    gpu_conv2d_constMem_kernel<<<dim_grid, dim_block>>>(d_N, d_P, new_size, new_size);
err = cudaGetLastError();
cuda_check(err);
err = cudaDeviceSynchronize();
cuda_check(err);

// Record `beg` immediately before the timed launches. In the visible code
// `beg` was last recorded before the input copy, so without this the timed
// region would not be kernels only.
cudaEventRecord(beg);
for (int i = 0; i < n_timed_runs; i++)
    gpu_conv2d_constMem_kernel<<<dim_grid, dim_block>>>(d_N, d_P, new_size, new_size);
err = cudaGetLastError();
cuda_check(err);

cudaEventRecord(end);
cudaEventSynchronize(beg);
cudaEventSynchronize(end);
cudaEventElapsedTime(&elapsed_time_kernel, beg, end);
// ms -> s, averaged over the timed launches
elapsed_time_kernel /= (1000. * n_timed_runs);

That is a performance improvement of 4.5% when __constant__ is used. This is outside noise level, and close to the 5% I expected.

One additional change which may or may not make sense for your use case is to stick the filter array into a struct and pass this struct as a kernel argument. Kernel arguments reside in constant memory.

If the number of possible filter configurations is “small”, you might also consider using a templated kernel, instantiated for each filter configuration, which lets you convert the filter data into literal constants, avoiding loads for the filter data altogether.

You are using no warmup phase for benchmarking.

Nsight Compute shows an even larger improvement than 5% for the kernel run alone: global memory is slower and constant memory faster than your measured numbers. Not sure why. Have you used the same 16x16 block configuration?

Apart from benchmarking for production code (as you mention the FPS figure) the GPU memory allocation and filter data transfer should be done once outside performance critical loops. The input and output data transfer should be overlapping with kernel execution and should be stored in pinned memory on the host side, if possible.

The filter data could also be compiled into the kernel, e.g. with kernel templates.

A 3x3 filter is computationally not a very demanding problem, so the PCIe transfers will eat up any performance gain the GPU would give you. If the data is not already on the GPU (or needed there afterwards for other processing), and you do not want to use larger filters, an AVX implementation on the CPU would probably be faster.

Why is the output transfer half as fast as the input transfer?

That was a one-off. Here are the benchmarks from another run (with 10 warm-up runs).

Global Memory

------------------------ 
GPU Benchmarking details 
------------------------ 
Time for GPU memory allocation (seconds): 0.000331488

Time for input data transfer (seconds): 0.00283232

Time for filter data transfer (seconds): 8.8e-06

Time for kernel execution (seconds): 5.10054e-05

Time for output data transfer (seconds): 0.00601456

Time (total): 0.00923817
FPS (total): 108.246

Time (kernel): 5.10054e-05
FPS (kernel): 19605.8
GFLOPS (kernel): 1480.18
------------------------

Constant Memory

------------------------------------------ 
GPU (constant memory) Benchmarking details 
------------------------------------------ 
Time for GPU memory allocation (seconds): 0.000233472

Time for input data transfer (seconds): 0.00281638

Time for filter data transfer (seconds): 0.000149728

Time for kernel execution (seconds): 5.48352e-05

Time for output data transfer (seconds): 0.00625805

Time (total): 0.00951247
FPS (total): 105.125

Time (kernel): 5.48352e-05
FPS (kernel): 18236.5
GFLOPS (kernel): 1376.81
------------------------------------------

There is quite a large percentage difference in execution time between Nsight Compute (45 µs) and your benchmark (55 µs) for the constant-memory case. With default parameters, Nsight Compute normally clears the L2 cache and locks the clock to the base clock for reproducibility, so your benchmark should be faster than Nsight Compute, not slower.