Does every access to constant memory require an LDC?

I compiled the same CUDA kernel targeting different architectures and found a difference in constant memory access patterns:

SM_52

SASS code uses constant memory operands directly in arithmetic instructions:


FFMA R2, R2, c[0x3][0x1010], R3 ; // Direct constant memory operand

SM_90

Constant memory must first be loaded using LDC instructions:


LDC R16, c[0x0][0x210] ;     // Load constant first
ADD R1, R2, R16 ;             // Then use the loaded register

I’m not sure whether every constant memory access requires a load first, so I tried many cases and found the same result.

I am a software developer who has studied the compiler in the open-source MESA project (a graphics driver). My impression was that constant memory is configured before execution and can be used directly because it is on-chip.
Why is a load needed here?
I apologize if this is a beginner’s question.

code here:

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#define SCALE    (32)
#define FOO_SIZE ((int)(65536/sizeof(float)))
#define BAR_SIZE ((int)(65536/sizeof(float)/SCALE))

// Multiple constant memory arrays for more C-bank operations (reduced sizes)
__constant__ float foo[1024];            // Reduced from FOO_SIZE
__constant__ float coeffs[16];           // Coefficients for calculations
__constant__ int lookup_table[32];       // Integer lookup table (reduced)
__constant__ float matrix[4][4];         // 4x4 transformation matrix
__constant__ float thresholds[8];        // Threshold values
__constant__ int config_params[16];      // Configuration parameters (reduced)
__constant__ float sin_table[64];        // Sine lookup table (reduced)
__constant__ float weights[32];          // Weight values (reduced)

float foo_shadow[1024];                    // Reduced size
float coeffs_shadow[16] = {1.0f, 2.0f, 0.5f, 1.5f, 3.0f, 0.25f, 4.0f, 0.125f,
                          5.0f, 0.1f, 6.0f, 0.0625f, 7.0f, 0.05f, 8.0f, 0.03125f};
int lookup_shadow[32];                  // Reduced size
float matrix_shadow[4][4];
float thresholds_shadow[8] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
int config_shadow[16];                  // Reduced size
float sin_shadow[64];                   // Reduced size
float weights_shadow[32];               // Reduced size

struct soa {
    float arr[BAR_SIZE];
} bar_shadow;

__global__ void kernel(int i, struct soa bar)
{
    int tid = threadIdx.x + blockIdx.x * blockDim.x;

    // Multiple constant memory accesses - should generate direct C-bank operations
    float base_val = foo[i % 1024];             // c[0x0][offset] - use modulo for safety

    // Arithmetic operations with multiple constant arrays
    float result = base_val * coeffs[0];        // MUL with c[0x0][offset]
    result += coeffs[1];                        // ADD with c[0x0][offset]
    result -= coeffs[2] * base_val;             // SUB/MUL with c[0x0][offset]

    // Matrix operations using constant memory
    float transformed = 0.0f;
    for (int j = 0; j < 4; j++) {
        transformed += matrix[0][j] * result;    // FMAD with c[0x0][offset]
        transformed *= coeffs[j + 4];           // MUL with c[0x0][offset]
    }

    // Conditional operations with constant thresholds
    if (transformed > thresholds[0]) {          // SETP with c[0x0][offset]
        result = result * coeffs[8] + thresholds[1];  // FMAD with c[0x0][offset]
    } else if (transformed < thresholds[2]) {   // SETP with c[0x0][offset]
        result = result - coeffs[9] * thresholds[3]; // FSUB/FMUL with c[0x0][offset]
    }

    // Lookup operations
    int index = tid % 32;                       // Reduced size
    int lookup_val = lookup_table[index];       // Load from c[0x0][offset]
    result += weights[lookup_val % 32];         // ADD with c[0x0][offset] - reduced size

    // Complex arithmetic with multiple constant operands
    float complex_calc = 0.0f;
    for (int k = 0; k < 8; k++) {
        complex_calc += sin_table[k * 8] * weights[k * 4]; // FMAD with c[0x0][offset] - adjusted indices
        complex_calc *= coeffs[k];              // MUL with c[0x0][offset]

        // Nested conditionals with constants
        if (complex_calc > thresholds[k]) {     // SETP with c[0x0][offset]
            complex_calc = complex_calc / coeffs[k + 8]; // DIV with c[0x0][offset]
        }
    }

    // Min/Max operations with constants
    result = fmaxf(result, thresholds[4]);      // MAX with c[0x0][offset]
    result = fminf(result, thresholds[5]);      // MIN with c[0x0][offset]

    // Bit operations with constant integers
    int int_result = __float_as_int(result);
    int_result ^= config_params[0];             // XOR with c[0x0][offset]
    int_result &= config_params[1];             // AND with c[0x0][offset]
    int_result |= config_params[2];             // OR with c[0x0][offset]

    // Final computation mixing all constant arrays
    float final_result = __int_as_float(int_result);
    final_result = final_result * coeffs[15] +
                   complex_calc * thresholds[6] -
                   weights[31] +
                   sin_table[63] * matrix[3][3];

    float t = bar.arr[i / SCALE];
    printf("i=%d base=%15.8e final=%15.8e t=%15.8e prod=%15.8e\n",
           i, base_val, final_result, t, final_result * t);
}

int main(void)
{
    int pos = FOO_SIZE - 1;

    // Initialize all shadow arrays
    for (int i = 0; i < 1024; i++) {           // Reduced size
        foo_shadow[i] = sqrtf((float)i);
    }

    for (int i = 0; i < 32; i++) {             // Reduced size
        lookup_shadow[i] = i * 2 + 1;
    }

    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
            matrix_shadow[i][j] = (i == j) ? 1.0f : 0.1f * (i + j);
        }
    }

    for (int i = 0; i < 16; i++) {             // Reduced size
        config_shadow[i] = 0xFF00FF00 >> (i % 8);
    }

    for (int i = 0; i < 64; i++) {             // Reduced size
        sin_shadow[i] = sinf(i * M_PI / 32.0f);
    }

    for (int i = 0; i < 32; i++) {             // Reduced size
        weights_shadow[i] = 1.0f / (i + 1);
    }

    for (int i = 0; i < BAR_SIZE; i++) {
        bar_shadow.arr[i] = (float)(9999 - i);
    }

    // Copy all constant memory
    cudaMemcpyToSymbol(foo, foo_shadow, sizeof(foo));
    cudaMemcpyToSymbol(coeffs, coeffs_shadow, sizeof(coeffs));
    cudaMemcpyToSymbol(lookup_table, lookup_shadow, sizeof(lookup_table));
    cudaMemcpyToSymbol(matrix, matrix_shadow, sizeof(matrix));
    cudaMemcpyToSymbol(thresholds, thresholds_shadow, sizeof(thresholds));
    cudaMemcpyToSymbol(config_params, config_shadow, sizeof(config_params));
    cudaMemcpyToSymbol(sin_table, sin_shadow, sizeof(sin_table));
    cudaMemcpyToSymbol(weights, weights_shadow, sizeof(weights));

    // Launch kernel
    kernel<<<1, 1>>>(pos, bar_shadow);
    cudaDeviceSynchronize();

    return EXIT_SUCCESS;
}

Whether to support a load-execute instruction type is a design choice made by processor architects. From work on x86 processors I know that load-execute instructions can improve efficiency via a reduction in dynamic instruction count and in execution latency, at the expense of making the control path more complicated. The alternative approach, typically chosen by “RISC” processors, is to employ a pure load-store architecture.

It is possible that NVIDIA processor architects changed their approach from older GPUs and that load-execute instructions are no longer supported in the latest GPUs, presumably to simplify op steering / scheduling. I have not checked into that.

Before we reach such a conclusion, you might want to check that the code idiom you are observing is not simply a consequence of multiple uses of the same constant-memory data. When the same constant-memory data object is used multiple times, moving it to a register first may have advantages in the compiler’s view. This could be driven by any number of heuristics considering instruction scheduling, resource contention, energy efficiency, etc.

While that is true, an access to constant memory may still be more expensive than an access to the register file. In older GPUs, the assumption was that the costs are “close enough” and that a constant-memory access occurs with “near-register speed”, making constant memory access an attractive alternative especially on register-starved early GPU architectures. The balance may have shifted more strongly in favor of register access for best performance in the recent past, but that is just (reasoned) speculation. I don’t think NVIDIA has published a detailed discussion of these tradeoffs for their recent GPU architectures.
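One way to test the reuse hypothesis would be a kernel in which each constant is read exactly once; a minimal sketch (the names c0, c1, and single_use are placeholders I made up):

__constant__ float c0, c1;

__global__ void single_use(float *out)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    // Each constant is read exactly once, so the compiler has no
    // reuse-based reason to stage it in a register first.
    out[i] = fmaf(out[i], c0, c1);
}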

[Later:]

After poking around a bit, I see load-execute instructions with constant-bank references being generated up to and including sm_89, but not for sm_90 and later architectures. So it does seem like load-execute instructions no longer exist in the latest GPU architectures.
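For anyone who wants to repeat that check, a throwaway kernel like the following is enough (my own example; the file name probe.cu is arbitrary, and the commands assume a recent CUDA toolkit with nvcc and cuobjdump in the PATH):

// probe.cu -- compile for each architecture and disassemble, e.g.:
//   nvcc -arch=sm_89 -cubin -o probe_89.cubin probe.cu
//   cuobjdump -sass probe_89.cubin
//   nvcc -arch=sm_90 -cubin -o probe_90.cubin probe.cu
//   cuobjdump -sass probe_90.cubin
__constant__ float k;

__global__ void probe(float *y, const float *x)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    // sm_89: the constant appears as a c[bank][offset] operand of the FFMA.
    // sm_90: it is first loaded into a register with LDC.
    y[i] = fmaf(x[i], k, 1.0f);
}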


Constant memory is not “on-chip” any more than global memory is on-chip.

The constant memory is populated (i.e. physically backed) in GPU DRAM (which is off-chip). When it is accessed, it is loaded through the so-called constant cache, which is on chip. An analogous scenario exists for typical accesses through the logical global space.

Imagining that a load occurs or is needed in one case but not the other isn’t how the GPU works.

LDC loads explicitly from the logical constant space.

An instruction that has a constant operand still requires a load from the logical constant space; it’s just that the load occurs implicitly (and presumably does not create register pressure).
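For what it’s worth, the load is always explicit at the PTX level (PTX arithmetic instructions only take register or immediate operands), and it is ptxas that decides whether to fold the ld.const into the consuming instruction as a c[bank][offset] operand or keep it as a separate LDC. A small illustration, not taken from the code above:

__constant__ float scale;

__global__ void axpy1(float *y, const float *x)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    // "nvcc -ptx" shows roughly: ld.const.f32 %f, [scale]; then fma.rn.f32.
    // Whether ptxas folds that load into the FFMA operand (sm_52 style) or
    // keeps an explicit LDC (sm_90) is a backend decision; either way the
    // value is read through the constant cache.
    y[i] = fmaf(scale, x[i], y[i]);
}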


Thank you all. The MESA projects and AMD GPUs I worked with are from about 20 years ago. Thanks for sharing new things with me; I learned a lot.

One reason to load the constant into a register first is that it effectively prefetches the data. If there is any delay (e.g. a conflict when the threads of a warp read different addresses, or, as Robert indirectly mentioned, the data sitting in DRAM instead of the cache), the load instruction can simply be issued into the pipeline and execution continues, whereas the add has to wait for all of its source operands to be available before it can execute.
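As a rough source-level analogy (my own sketch; the compiler schedules the SASS on its own, so this only illustrates the idea):

__constant__ float gain;

__global__ void overlap(float *out, const float *in, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    float g = gain;            // constant load issued early, like an LDC
    float x = in[i] * in[i];   // independent work overlaps with the load
    float y = x + in[i];
    out[i] = y * g;            // first use of g; by now it is likely available
}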

The compiler may apply different heuristics for different architectures, depending on whether higher register pressure or a higher instruction count typically gives the advantage.

It can also be that a new compiler heuristic or optimization was only activated for newer architectures, so as not to pessimize the compilation of existing kernels and to spare NVIDIA the extensive simulation and testing work that would otherwise also have to be done on and for very old architectures.

All current SM architectures are Volta-based (sm_70), with just small tweaks and some extensions (especially Tensor Cores, asynchronous engines, shorter data formats) added over the generations.
