Feedback on forward only randomly initialized convolution

Hello, I’m writing a super wide forward-only convolution on CIFAR-10 (32x32x3). I’m taking advantage of the fact that I’m only going to use randomly initialized weights, so I can deterministically rerun with the same RNG seeds instead of loading weights from memory. Each thread in a warp stores the full 3x3x3 patch of source-image values in registers, so each warp can write coalesced into the 32x32xN output.

One issue I’m running into is that each warp owns a center pixel in the convolution, but my card is an RTX 3090 with 82 SMs, and 82/2 = 41, which is a large prime. I’m curious how to go about balancing blocks so I don’t have a large tail effect. Right now I have it running 512 threads per block = 16 warps, so it takes 2 blocks to cover a row of 32 pixels. If I naively overlap the blocks, then roughly half of the SMs get 3 warps each and the other half get 2 warps each.

Here’s my code, I welcome brutal feedback:

#define MAX_THREADS_PER_BLOCK 512
#define MIN_BLOCKS_PER_MP     2

// 1/255 as a float constant: the original `1.0 / 255.0` was a double
// expression, and constexpr (rather than plain const) guarantees the value is
// usable from __device__ code without relying on nvcc's const-folding rules.
constexpr float recip255 = 1.0f / 255.0f;

// Number of random output features per center pixel. Must stay a power of two
// with an even exponent so sqrt(WIDTH) below is exact.
constexpr size_t WIDTH = size_t(1) << 22;

// 1/sqrt(WIDTH), computed at compile time: sqrt(1 << 22) == 1 << 11.
// The original called sqrt() at runtime, which makes the variable dynamically
// initialized and therefore unusable from device code.
constexpr float recipWidth = 1.0f / float(size_t(1) << 11);
static_assert((size_t(1) << 11) * (size_t(1) << 11) == WIDTH,
              "recipWidth must equal 1/sqrt(WIDTH)");

// Draws four N(0,1) weights from this thread's Philox stream and accumulates
// weight*val into the four running sums (one sum per output feature).
// Advances the RNG state by one curand_normal4 call on every invocation, even
// when val == 0 — callers rely on a fixed consumption order for determinism.
__forceinline__
__device__ void multAndAdd(float& sum0, float& sum1, float& sum2, float& sum3, curandStatePhilox4_32_10_t& local_rng, const float val) {
    const float4 w = curand_normal4(&local_rng);
    sum0 = fmaf(w.x, val, sum0);
    sum1 = fmaf(w.y, val, sum1);
    sum2 = fmaf(w.z, val, sum2);
    sum3 = fmaf(w.w, val, sum3);
}

// Initializes one Philox state per thread from a single seed.
// BUGFIX: the subsequence is now the thread's lane id (threadIdx.x % 32), not
// its global thread id. With a per-global-thread subsequence, every warp drew
// a *different* weight stream, so each center pixel got independent random
// filters — a per-patch random projection, not a convolution. With the lane id
// as the subsequence, every warp replays the identical weight sequence, which
// is exactly the weight sharing a convolution requires and what makes
// rerunning from the same seed reproduce the filter bank.
__global__ void setup_kernel(curandStatePhilox4_32_10_t *state, unsigned long long seed) {
    const int id = threadIdx.x + blockIdx.x * blockDim.x;
    curand_init(seed, /*subsequence=*/threadIdx.x % 32, /*offset=*/0, &state[id]);
}

// Returns byte number `col` of the warp's shared byte window as a float scaled
// by 1/255. Each lane holds one 32-bit little-endian word, so byte `col` lives
// in lane col/4 at sub-byte position col%4; a warp shuffle fetches that word.
// `col` is warp-uniform at every call site (it derives from the warp id), so
// all 32 lanes reach the shuffle together and the full mask is valid.
__forceinline__
__device__ float getByteFromColumn(const int col, const unsigned int val) {
    const int src_lane  = col >> 2;        // lane whose word holds the byte
    const int bit_shift = (col & 3) * 8;   // byte position within that word
    const unsigned int word = __shfl_sync(0xffffffff, val, src_lane);
    return ((word >> bit_shift) & 0xffu) * recip255;
}

// Gathers the 3x3 neighbourhood of the centre pixel (row, col) for one colour
// channel into t[0..8], ordered ul, um, ur, ml, mm, mr, ll, lm, lr.
// `win` is this lane's 32-bit word of the warp's 96-byte, 3-row window:
//   bytes 0..31 -> row-1, bytes 32..63 -> row, bytes 64..95 -> row+1.
// Border taps are skipped, leaving the caller's zero padding in place.
// row/col are warp-uniform, so the shuffles inside getByteFromColumn see the
// whole warp even though the branches look divergent.
__forceinline__
__device__ void gatherTaps3x3(float (&t)[9], const int row, const int col, const unsigned int win) {
    if (row > 0) {                                            // top row exists
        if (col > 0)  t[0] = getByteFromColumn(col - 1,      win);
                      t[1] = getByteFromColumn(col,          win);
        if (col < 31) t[2] = getByteFromColumn(col + 1,      win);
    }
    if (col > 0)      t[3] = getByteFromColumn(col - 1 + 32, win);
                      t[4] = getByteFromColumn(col     + 32, win);
    if (col < 31)     t[5] = getByteFromColumn(col + 1 + 32, win);
    if (row < 31) {                                           // bottom row exists
        if (col > 0)  t[6] = getByteFromColumn(col - 1 + 64, win);
                      t[7] = getByteFromColumn(col     + 64, win);
        if (col < 31) t[8] = getByteFromColumn(col + 1 + 64, win);
    }
}

// One warp per centre pixel: computes WIDTH random features of the pixel's
// 3x3x3 neighbourhood and writes them to d_out[warp_id*WIDTH .. +WIDTH),
// coalesced in 32-float runs. d_in is one planar-RGB 32x32 image: three
// channel planes of 1024 bytes (256 words) each, rows of 8 words.
__device__ void byteImageToFloatConv(const uint8_t* const __restrict__ d_in, float* const __restrict__ d_out, curandStatePhilox4_32_10_t& rng) {
    const unsigned int* const d_in32 = reinterpret_cast<const unsigned int*>(d_in);
    cg::thread_block cta = cg::this_thread_block();
    cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);

    const int lane           = tile32.thread_rank();
    const int global_warp_id = tile32.meta_group_rank() + blockIdx.x*tile32.meta_group_size();
    const int row = global_warp_id / 32;   // centre pixel owned by this warp
    const int col = global_warp_id % 32;

    // Load the 3-row window centred on `row`. Two fixes vs. the original:
    //  * the window must start at word 8*(row-1), not 8*row — otherwise the
    //    "top" taps read the centre row and every tap is shifted one row down;
    //  * lanes whose word falls outside the 256-word channel plane (row 0 and
    //    rows >= 29) must not load: red/green would bleed into the next plane
    //    and blue would read past the end of the 3072-byte image.
    // Rows row-1..row+1 occupy lanes 0..23; lanes 24..31 are unused padding.
    // The middle row (lanes 8..15) is always in bounds; the guarded taps in
    // gatherTaps3x3 never read a lane that was skipped here.
    const int  w  = 8*(row - 1) + lane;            // word offset within a plane
    const bool ok = (w >= 0) && (w < 256);
    const unsigned int r_data = ok ? d_in32[  0 + w] : 0u;
    const unsigned int g_data = ok ? d_in32[256 + w] : 0u;
    const unsigned int b_data = ok ? d_in32[512 + w] : 0u;

    // taps[channel][position]; zero-initialised so border taps are zero-padded.
    float taps[3][9] = {};
    gatherTaps3x3(taps[0], row, col, r_data);
    gatherTaps3x3(taps[1], row, col, g_data);
    gatherTaps3x3(taps[2], row, col, b_data);

    // Each lane produces 4 features per iteration (one curand_normal4 per tap),
    // so a warp covers 128 features per pass. The 27 multAndAdd calls run in
    // the same fixed channel-major order as before, preserving RNG consumption.
    const size_t base = WIDTH * (size_t)global_warp_id;
    for (size_t i = lane; i < WIDTH; i += 4*32) {
        float sum0 = 0, sum1 = 0, sum2 = 0, sum3 = 0;

        #pragma unroll
        for (int c = 0; c < 3; ++c) {
            #pragma unroll
            for (int k = 0; k < 9; ++k) {
                multAndAdd(sum0, sum1, sum2, sum3, rng, taps[c][k]);
            }
        }

        d_out[base + i + 0*32] = sum0;
        d_out[base + i + 1*32] = sum1;
        d_out[base + i + 2*32] = sum2;
        d_out[base + i + 3*32] = sum3;
    }
}

// Launch wrapper: one thread per RNG state, one warp per centre pixel.
// Expects a 1-D launch with blockDim.x == MAX_THREADS_PER_BLOCK (16 warps per
// block) and d_rng sized to the total thread count.
// Each thread works on a private copy of its Philox state and the copy is
// deliberately never written back, so relaunching with the same d_rng contents
// regenerates the identical random weights (the author's rerun-from-seed
// design — confirm this is intended if states are ever meant to advance).
__global__ void
__launch_bounds__(MAX_THREADS_PER_BLOCK,MIN_BLOCKS_PER_MP) 
kerneltest(
    const uint8_t* const __restrict__ d_in,
    float* const __restrict__ d_out, 
    curandStatePhilox4_32_10_t* const __restrict__ d_rng
    ) {
        const int tid = blockIdx.x * blockDim.x + threadIdx.x;
        curandStatePhilox4_32_10_t local_state = d_rng[tid];
        byteImageToFloatConv(d_in, d_out, local_state);
}

Thanks!