How is elementwise operator fusion done by the compiler?

As far as I know, several consecutive elementwise operator kernels can be fused into a single kernel, which obviously reduces launch overhead.

However, my question is: how, and at which stage, is the memory I/O also reduced?

Here is the link I read: Making Deep Learning go Brrrr From First Principles

The author says we can avoid repeatedly moving data between SRAM and DRAM in this situation. But how is this idea implemented? Does the compiler do it for us?


Thank you for your reply. I actually understand the idea behind it. What I want to know is how to implement it in CUDA C++, because from my point of view I cannot control register behavior at the C++ level. So will the nvcc compiler do this for us automatically?

The compiler doesn’t help with that. Kernel fusion requires refactoring, and during that refactoring, redundant loads and stores (that would not be redundant in separate kernels) are removed by you, the programmer. That’s not a compiler function. And it’s the most immediate and obvious benefit of kernel fusion.

If the question is “Can the CUDA compiler merge two __global__ functions automagically?” the answer is “No” (to the best of my knowledge).

Here’s how you do kernel fusion. You discover in your code you have this:

__global__ void k1(int *d, int *r, int N){
    int idx = threadIdx.x+blockDim.x*blockIdx.x;
    if (idx < N) r[idx] = d[idx]*2;    // one load, one store per element
}

__global__ void k2(int *d, int *r, int N){
    int idx = threadIdx.x+blockDim.x*blockIdx.x;
    if (idx < N) r[idx] = d[idx]+2;    // another load and store per element
}

int main(){

  ...
  k1<<<grid, block>>>(d1, r1, N);
  k2<<<grid, block>>>(r1, r1, N);
  ...
}

And so what you do as a programmer is rewrite it like this:

__global__ void k12(int *d, int *r, int N){
    int idx = threadIdx.x+blockDim.x*blockIdx.x;
    if (idx < N) r[idx] = d[idx]*2+2;  // one load, one store; the intermediate stays in a register
}

int main(){

  ...
  k12<<<grid, block>>>(d1, r1, N);
  ...
}

And you save one store operation and one load operation (per element). You do that by refactoring the code. The compiler doesn’t do it “for you”, doesn’t “help”, and there is currently no way to coax this kind of transformation out of the nvcc compiler toolchain. You do it. And it doesn’t require any knowledge or control of the register-level behavior of the C++ compiler on the programmer’s part.
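As an aside, the same manual fusion can be written once in a reusable form by passing the composed elementwise operation into a templated kernel, which is similar in spirit to the code that fusing libraries and DL compilers generate. The names elementwise and MulAddOp below are my own; this is just a sketch of the pattern:

// Generic elementwise kernel: the fused operation is supplied as a functor.
// Everything between the single load and the single store stays in registers.
template <typename Op>
__global__ void elementwise(const int *d, int *r, int N, Op op){
    int idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx < N) r[idx] = op(d[idx]);   // one load, one store per element
}

// The "fused" operator: *2 followed by +2, composed in registers.
struct MulAddOp {
    __device__ int operator()(int x) const { return x * 2 + 2; }
};

// launch, e.g.:
//   elementwise<<<grid, block>>>(d1, r1, N, MulAddOp{});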


For this part, how is it guaranteed that r[idx] = d[idx]*2+2 is computed entirely in a register? (Does the compiler do this for us?) What if I write code like this:

__global__ void k12(int *d, int *r, int N){
    int idx = threadIdx.x+blockDim.x*blockIdx.x;
    if (idx < N) {
        r[idx] = d[idx]*2;
        r[idx] = r[idx]+2;
    }
}

int main(){

  ...
  k12<<<grid, block>>>(d1, r1, N);
  ...
}

And what if my expression is too long to write on one line? How can this fusion into registers still be achieved?

Source code formatting is irrelevant to the generated code. As long as compiler optimizations are enabled (which is the default for the CUDA compiler), the compiler will eliminate redundant memory operations. It may help to declare pointer function arguments as __restrict__ (see the CUDA documentation).
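For example (my own variation of the kernel above, not code taken from the CUDA documentation), the computation can be split across several statements and local variables; with optimizations on, the intermediates live in registers and only one load and one store per element reach global memory. The __restrict__ qualifiers are the hint mentioned above:

__global__ void k12(const int * __restrict__ d, int * __restrict__ r, int N){
    int idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx < N) {
        int x = d[idx];        // the only load from global memory
        int y = x * 2;         // intermediate values stay in registers
        y = y + 2;
        r[idx] = y;            // the only store to global memory
    }
}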

In other words, during code optimization, the compiler transforms code like this

    if (idx < N) {
        r[idx] = d[idx]*2;
        r[idx] = r[idx]+2;
    }

into something like this

    if (idx < N) {
        int temp = d[idx];
        temp = temp * 2 + 2;  // this may map to an IMAD, or ISCADD, or LEA instruction
        r[idx] = temp;
    }

where temp is placed in a register. Compared to the unmerged kernels, bandwidth requirements have been cut in half. You can always inspect the generated machine code (SASS) with cuobjdump --dump-sass to check the number of load and store instructions being produced.

Thank you for your detailed explanation! Sorry to bother but still another question. In this code:

 if (idx < N) {
    int temp = d[idx];
    temp = temp * 2 + 2;  // this may map to an IMAD, or ISCADD, or LEA instruction
    r[idx] = temp;
}

Why will temp be placed directly in a register? I thought that variables declared inside a __global__ function were placed in DRAM, and that variables declared with the __shared__ qualifier were placed in SRAM. Am I wrong?

Thread-local variables like temp are placed into local memory (a particular re-mapping of a chunk of global memory) by default. As an optimization the compiler places them into registers instead, subject to various restrictions. By default, the CUDA compiler compiles device code with full optimizations, comparable to what host compilers may do at -O3 or /Ox.

Note that debug builds turn off all optimizations.

By definition, __shared__ data is not thread-local data. See the CUDA documentation for details on the handling of __shared__ data.
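To illustrate the difference (a sketch of my own, with an assumed block size of at most 256 threads), the three kinds of placement look like this inside a kernel; whether the small thread-local array stays in registers or spills to local memory depends on how it is indexed and on register pressure:

__global__ void placement_demo(const float *d, float *r, int N){
    int idx = threadIdx.x + blockDim.x * blockIdx.x;

    float t;                     // thread-local scalar: normally kept in a register
    float tmp[2];                // thread-local array: registers if indexed with
                                 //   compile-time-constant indices, otherwise it may
                                 //   spill to local memory (backed by global memory)
    __shared__ float tile[256];  // shared memory, visible to the whole thread block
                                 //   (block size assumed to be <= 256 here)

    if (idx < N) {
        t = d[idx] * 2.0f + 2.0f;        // computed entirely in registers
        tmp[0] = t; tmp[1] = t + 1.0f;   // static indexing -> can stay in registers
        tile[threadIdx.x] = t;           // each thread writes/reads only its own slot,
                                         //   so no __syncthreads() is needed here
        r[idx] = tile[threadIdx.x] + tmp[0] + tmp[1];
    }
}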

I get it. Thank you so much!