Advice on depthwise convolution kernel and local memory strategies

I am trying to write a depthwise convolution kernel in OpenCL, which I am going to port to CUDA soon.
shapes:

  1. input Hi x Wi x C
  2. kernel R x S x C
  3. output Ho x Wo x C

R is the kernel height, S the kernel width, C the number of channels, Hi the input height, Wi the input width, and Ho/Wo the output height/width.

Here’s my strategy:
Per-work-group shared-memory tiling: the work-items of a work-group cooperatively load all the input data they need into local memory (shared memory in CUDA).

Per-thread register tiling: each work-item loads all the values it needs from the input tile (in local memory) into registers.

Vectorization: I divide the channels into groups of 4 so that I can use half4 types.

Micro-kernel choice: I chose direct convolution. I don’t know whether im2col or Winograd methods are warranted here, because the arithmetic intensity is low.

Unfortunately, this strategy is beaten by my other kernel, which does only register tiling directly from global memory and uses no local memory.

Can someone please point out what I might be doing wrong?

My goal is to get very close to peak memory bandwidth, and I am only halfway there.

Here are my problem sizes and settings:

// Problem size: input Hi x Wi x C, filter R x S x C.
// Ho/Wo below are the output size of a stride-1 convolution without padding.
#define Hi 114
#define Wi 114
#define R 3
#define S 3

#define Ho (Hi - R + 1)
#define Wo (Wi - S + 1)
#define C 32

// Channels are processed as half4 vectors; CD4 is the number of channel vectors.
#define CD4 (C / 4)

// work group sizes
#define WG_SIZE_H 2
#define WG_SIZE_W 2
#define WG_SIZE_C 8


// each thread computes THREAD_H x THREAD_W x THREAD_C output elements
#define THREAD_H 2
#define THREAD_W 2
#define THREAD_C 1

// Compile-time checks for the assumptions the kernels rely on. The host computes
// the NDRange with integer division and the kernels do not bounds-check their
// stores, so any remainder would silently leave outputs unwritten.
static_assert(C % 4 == 0, "C must be a multiple of 4 (channels are packed as half4)");
static_assert(THREAD_C == 1, "the kernels compute exactly one channel vector per work-item");
static_assert(Ho % (THREAD_H * WG_SIZE_H) == 0, "Ho must be a multiple of THREAD_H * WG_SIZE_H");
static_assert(Wo % (THREAD_W * WG_SIZE_W) == 0, "Wo must be a multiple of THREAD_W * WG_SIZE_W");
static_assert(CD4 % (THREAD_C * WG_SIZE_C) == 0, "CD4 must be a multiple of THREAD_C * WG_SIZE_C");

// Build options for the OpenCL program: the problem and tiling constants are
// passed to the kernel source as -D macros, so the kernels can use them as
// array sizes and unroll counts.
// NOTE(review): DEF, CEIL_DIV, N and THREAD_H2/THREAD_W2/THREAD_C2 are not defined
// in this snippet; presumably DEF(x) expands to " -D x=<value>" — confirm.
// NOTE(review): "-O3" is not among the build options listed in the OpenCL spec;
// some drivers may reject it (CL_INVALID_BUILD_OPTIONS) or ignore it — confirm.
std::stringstream ss;
    ss
     << "-cl-std=CL2.0"
    << " -cl-fast-relaxed-math -O3"
    << DEF(N) << DEF(Hi) << DEF(Wi) << DEF(C) << DEF(R) << DEF(S) << DEF(Ho) << DEF(Wo) << DEF(CD4)
    << DEF(THREAD_H) << DEF(THREAD_W) << DEF(THREAD_C)
    << DEF(THREAD_H2) << DEF(THREAD_W2) << DEF(THREAD_C2)
    << DEF(WG_SIZE_H) << DEF(WG_SIZE_W) << DEF(WG_SIZE_C)
    // Per-work-item input patch needed for THREAD_H x THREAD_W outputs.
    << " -D INPUT_TILE_H=" << (THREAD_H + R - 1) << " -D INPUT_TILE_W=" << (THREAD_W + S - 1)
    // Per-work-group local-memory tile (output tile plus R-1 / S-1 halo) and its channel depth.
    << " -D INPUT_TILE_H2=" << (THREAD_H * WG_SIZE_H + R - 1) << " -D INPUT_TILE_W2=" << (THREAD_W * WG_SIZE_W + S - 1) << " -D INPUT_TILE_C2=" << (THREAD_C * WG_SIZE_C)
    // Rows/columns of the local tile each work-item copies during the cooperative load.
    << " -D LOAD_WORK_H=" << (CEIL_DIV(THREAD_H * WG_SIZE_H + R - 1, WG_SIZE_H)) << " -D LOAD_WORK_W=" << (CEIL_DIV(THREAD_W * WG_SIZE_W + S - 1, WG_SIZE_W))

Host code for enqueueing the kernel:

global_work_size = {Ho / THREAD_H, Wo / THREAD_W, CD4 / THREAD_C};
        local_work_size = {WG_SIZE_H, WG_SIZE_W, WG_SIZE_C};

        // warmup run
        ret = clEnqueueNDRangeKernel(queue, depthwiseconv2d3, 3, NULL, global_work_size.data(), local_work_size.data(), 0, NULL, &kernel_event);
        clWaitForEvents(1, &kernel_event);
        print_event2(kernel_event, name1);
        clReleaseEvent(kernel_event);

Kernel #1 uses local memory, but it still loses to the kernel below:

// Depthwise convolution (stride 1, no padding) with a work-group input tile staged
// in local memory.
//
// Each work-group covers (WG_SIZE_H * THREAD_H) x (WG_SIZE_W * THREAD_W) output
// pixels for WG_SIZE_C half4 channel vectors. Its work-items first copy the
// INPUT_TILE_H2 x INPUT_TILE_W2 input tile (output tile plus the R-1 / S-1 halo)
// into local memory. Each work-item then copies its own INPUT_TILE_H x INPUT_TILE_W
// patch into registers and computes THREAD_H x THREAD_W outputs.
//
// input_row_pack: image addressed as x = column * CD4 + channel vector, y = input row.
// filter:         half4 weights indexed (i * S + j) * CD4 + channel vector.
// output:         half4 values indexed (row * Wo + col) * CD4 + channel vector.
//
// Assumes THREAD_C == 1 and that the NDRange exactly covers Ho x Wo x CD4
// (stores are not bounds-checked).
// NOTE(review): image reads typically already go through a hardware cache, so this
// local-memory staging may not beat the register-only kernel — measure on target.
kernel void depthwiseconv2d5(read_only image2d_t input_row_pack, const constant half4 * restrict filter __attribute__( (max_constant_size(R * S * C * 2))), global half4 * restrict output) {
    sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;

    const int lh = (int)get_local_id(0);
    const int lw = (int)get_local_id(1);
    const int lc = (int)get_local_id(2);   // channel slot inside the local tile
    const int gh = (int)get_global_id(0);
    const int gw = (int)get_global_id(1);
    const int gc = (int)get_global_id(2);  // half4 channel-vector index in global data

    // Input coordinates of the tile's top-left pixel. The origin depends only on the
    // work-group: deriving it from get_global_id() * LOAD_WORK_* would stride by
    // LOAD_WORK_* per work-item and load the wrong input for every work-group but
    // the first.
    const int tileOriginH = (int)get_group_id(0) * (WG_SIZE_H * THREAD_H);
    const int tileOriginW = (int)get_group_id(1) * (WG_SIZE_W * THREAD_W);

    local half4 smemT[INPUT_TILE_H2][INPUT_TILE_W2][INPUT_TILE_C2];

    // Cooperative fill: each work-item copies up to LOAD_WORK_H x LOAD_WORK_W pixels
    // of its channel vector. The per-element guard keeps writes inside the tile when
    // INPUT_TILE_H2 / INPUT_TILE_W2 are not multiples of WG_SIZE_H / WG_SIZE_W.
    #pragma unroll LOAD_WORK_W
    REP(j, LOAD_WORK_W) {
        #pragma unroll LOAD_WORK_H
        REP(i, LOAD_WORK_H) {
            const int tileH = lh * LOAD_WORK_H + i;
            const int tileW = lw * LOAD_WORK_W + j;
            if (tileH < INPUT_TILE_H2 && tileW < INPUT_TILE_W2) {
                smemT[tileH][tileW][lc] = read_imageh(
                    input_row_pack,
                    sampler,
                    (int2)(
                        (tileOriginW + tileW) * CD4 + gc,
                        tileOriginH + tileH
                    )
                );
            }
        }
    }

    // Every work-item reaches this barrier (there is no early exit above).
    barrier(CLK_LOCAL_MEM_FENCE);

    // This work-item's input patch, relative to the tile origin.
    half4 regT[INPUT_TILE_H][INPUT_TILE_W];
    #pragma unroll INPUT_TILE_W
    REP(j, INPUT_TILE_W) {
        #pragma unroll INPUT_TILE_H
        REP(i, INPUT_TILE_H) {
            regT[i][j] = smemT[lh * THREAD_H + i][lw * THREAD_W + j][lc];
        }
    }

    // Accumulators must start at zero because they are updated with +=.
    half4 threadResults[THREAD_H][THREAD_W];
    #pragma unroll THREAD_H
    REP(threadH, THREAD_H) {
        #pragma unroll THREAD_W
        REP(threadW, THREAD_W) {
            threadResults[threadH][threadW] = (half4)((half)0.0f);
        }
    }

    // Direct convolution: each filter tap is loaded once and applied to all
    // THREAD_H x THREAD_W outputs of this work-item.
    #pragma unroll S
    REP(j, S) {
        #pragma unroll R
        REP(i, R) {
            const half4 filterReg = filter[(i * S + j) * CD4 + gc];
            #pragma unroll THREAD_H
            REP(threadH, THREAD_H) {
                #pragma unroll THREAD_W
                REP(threadW, THREAD_W) {
                    threadResults[threadH][threadW] += regT[threadH + i][threadW + j] * filterReg;
                }
            }
        }
    }

    #pragma unroll THREAD_H
    REP(threadH, THREAD_H) {
        #pragma unroll THREAD_W
        REP(threadW, THREAD_W) {
            output[((gh * THREAD_H + threadH) * Wo + (gw * THREAD_W + threadW)) * CD4 + gc] = threadResults[threadH][threadW];
        }
    }
}

Kernel #2 uses only register tiling and vectorization, with the direct convolution method:

// Depthwise convolution (stride 1, no padding) using register tiling only.
//
// Each work-item reads its INPUT_TILE_H x INPUT_TILE_W input patch directly from
// the image into registers and computes THREAD_H x THREAD_W output pixels for one
// half4 channel vector.
//
// input_row_pack: image addressed as x = column * CD4 + channel vector, y = input row.
// filter:         half4 weights indexed (i * S + j) * CD4 + channel vector.
// output:         half4 values indexed (row * Wo + col) * CD4 + channel vector.
//
// Assumes THREAD_C == 1 and that the NDRange exactly covers Ho x Wo x CD4
// (stores are not bounds-checked).
kernel void depthwiseconv2d4(read_only image2d_t input_row_pack, const constant half4 * restrict filter __attribute__( (max_constant_size(R * S * C * 2))), global half4 * restrict output) {
    sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;

    const int gh = (int)get_global_id(0);
    const int gw = (int)get_global_id(1);
    const int gc = (int)get_global_id(2);  // half4 channel-vector index

    // Load this work-item's input patch into registers.
    half4 regT[INPUT_TILE_H][INPUT_TILE_W];
    #pragma unroll INPUT_TILE_H
    REP(i, INPUT_TILE_H) {
        #pragma unroll INPUT_TILE_W
        REP(j, INPUT_TILE_W) {
            regT[i][j] = read_imageh(
                input_row_pack,
                sampler,
                (int2)(
                    (gw * THREAD_W + j) * CD4 + gc,
                    gh * THREAD_H + i
                )
            );
        }
    }

    // Accumulators must start at zero because they are updated with +=.
    half4 threadResults[THREAD_H][THREAD_W];
    #pragma unroll THREAD_H
    REP(threadH, THREAD_H) {
        #pragma unroll THREAD_W
        REP(threadW, THREAD_W) {
            threadResults[threadH][threadW] = (half4)((half)0.0f);
        }
    }

    // Direct convolution: each filter tap is loaded once and applied to all
    // THREAD_H x THREAD_W outputs of this work-item.
    #pragma unroll R
    REP(i, R) {
        #pragma unroll S
        REP(j, S) {
            const half4 regFilter = filter[(i * S + j) * CD4 + gc];
            #pragma unroll THREAD_H
            REP(threadH, THREAD_H) {
                #pragma unroll THREAD_W
                REP(threadW, THREAD_W) {
                    threadResults[threadH][threadW] += regT[threadH + i][threadW + j] * regFilter;
                }
            }
        }
    }

    #pragma unroll THREAD_H
    REP(threadH, THREAD_H) {
        #pragma unroll THREAD_W
        REP(threadW, THREAD_W) {
            output[((gh * THREAD_H + threadH) * Wo + (gw * THREAD_W + threadW)) * CD4 + gc] = threadResults[threadH][threadW];
        }
    }
}