Over the past few days I have tried many ways to reduce register pressure, but I have made no progress. I also find nvcc's register allocation completely incomprehensible: changing just two lines of my code results in significantly higher register usage. So I am posting all the details here; I really need help fighting register pressure in my case.
I provide a minimal reproducible example below. Its structure is somewhat similar to matrix multiplication (e.g., A * B + C): the program first reads C, then repeatedly loads a piece of the inputs (A and B) of size 16 into shared memory, computes the output, and accumulates it into the final answer. The code defines a templated kernel that is specialized into two versions. The only difference between the two versions is how the address offset into the input pointer is computed; all other code is identical except for two lines, line 72 and line 131 (the two calls to get_offset). These two lines are not inside the main loop and are executed only at initialization and at the end, so I would expect the two kernels to perform similarly. However, after compilation kernel 1 uses 64 registers while kernel 2 uses 96! I cannot understand why the difference in these two lines leads to such a large gap in register allocation. In my opinion, the register bottleneck should be inside the main loop, where the heavy computation happens.
As a result, kernel 1 is 50% faster than kernel 2, even though the computation in the two kernels is exactly the same and the only difference is in how the input is fetched. I have tried hard to reduce the register pressure of kernel 2, but still with no progress: neither --maxrregcount nor __launch_bounds__ helps, even though they do reduce the register count (a sketch of how I applied them is below). After many days of struggling I finally have to come here for help. I would really appreciate it if someone could help me make the kernel faster. Thanks a lot!
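For completeness, here is a minimal sketch of the two register-limiting options mentioned above, applied to a toy kernel rather than my real one; the specific numbers (256 threads, 4 blocks per SM, a cap of 64 registers) are placeholders, not the exact values I swept:

// Sketch only: how the register caps can be applied (numbers are placeholders).
// (a) Per-kernel cap via __launch_bounds__; 256 = BLOCK_SIZE * BLOCK_SIZE threads per block,
//     the second argument is the minimum number of resident blocks per SM the compiler must allow for.
__global__ void __launch_bounds__(256, 4) toy_kernel(float* out) {
    out[threadIdx.x] = float(threadIdx.x);
}
// (b) Translation-unit-wide cap via the compiler flag:
//     nvcc --maxrregcount=64 ...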
#include <cuda_runtime.h>
#include <type_traits>
const int BLOCK_SIZE = 16;
#define CONST_PTR const float* __restrict__
#define PTR float* __restrict__
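// Offset computation for the fully-connected case: the input is laid out contiguously as
// (B, G, C, HW), so the offset is a plain linear index and the read always succeeds.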
struct FCInput {
    __device__ int get_offset(int G, int C, int HW, int b, int g, int c, int hw, bool& success) {
        success = true;
        return ((b * G + g) * C + c) * HW + hw;
    }
};
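// Offset computation for the convolution case: K = kernel size, P = padding, S = stride,
// WO = output width. The unfolded channel index c is mapped back to a position in the
// original (H, W) input; out-of-range taps are clamped to the border when near_padding is
// set, otherwise success is set to false so the caller can substitute zero.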
template <int K, int P, bool near_padding>
struct ConvInput {
    int H, W, S, WO;
    __device__ int get_offset(int G, int C, int HW, int b, int g, int c, int hw, bool& success) {
        int offset_bc = ((b * G + g) * C + c) / (K * K);
        int h = hw / WO, w = hw - h * WO;
        int offset_h = h * S + c / K % K - P, offset_w = w * S + c % K - P;
        if (P > 0 && near_padding) {
            offset_w = max(min(offset_w, W - 1), 0);
            offset_h = max(min(offset_h, H - 1), 0);
        }
        success = near_padding || (offset_w >= 0 && offset_w < W && offset_h >= 0 && offset_h < H);
        return (offset_bc * H + offset_h) * W + offset_w;
    }
};
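// FCInput has no data members, so is_conv tells the fully-connected case (write grad_input
// directly) apart from the convolution case (accumulate with atomicAdd).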
template <typename Input>
__device__ __forceinline__ bool is_conv(Input i_conf) { return !std::is_empty<Input>::value; }
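// Per-element backward term: grad_output * 2^(w * x + b - output).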
__device__ __forceinline__
float update_backward(float x, float w, float b, float o, float g) {
    return g * exp2f(w * x + b - o);
}
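// GROUP_B x GROUP_CI output values are accumulated per thread (register tiling).
// has_hw: whether a spatial dimension exists (if not, HW is forced to 1 and the roles of
//         threadIdx.x / threadIdx.y are swapped).
// check_co: whether CO_div_G may not be a multiple of BLOCK_SIZE, in which case the
//           output-channel index has to be bounds-checked.
// Input: either FCInput or ConvInput; it only affects how input/grad_input offsets are
//        computed and whether the final write uses atomicAdd.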
template <int GROUP_CI, int GROUP_B, bool has_hw, bool check_co, typename Input> __global__
void logsumexp_backward_input_kernel(CONST_PTR grad_output, CONST_PTR input, CONST_PTR weight, CONST_PTR bias,
                                     CONST_PTR output, int B, int CO_div_G, int CI_div_G, int HW, int G, PTR grad_input,
                                     Input i_conf) {
    if (!has_hw) HW = 1;
    int b_hw = blockIdx.x * (BLOCK_SIZE * GROUP_B) + (has_hw ? threadIdx.x : threadIdx.y);
    int b[GROUP_B], hw[GROUP_B];
    #pragma unroll
    for (int i = 0; i < GROUP_B; i++) {
        int b_hw_i = b_hw + i * BLOCK_SIZE;
        if (has_hw) { b[i] = b_hw_i / HW; hw[i] = b_hw_i % HW; }
        else { b[i] = b_hw_i; hw[i] = 0; }
    }
    __shared__ float blockO[GROUP_B][BLOCK_SIZE][BLOCK_SIZE]; // CO * B if has_hw else B * CO
    __shared__ float blockG[GROUP_B][BLOCK_SIZE][BLOCK_SIZE]; // CO * B if has_hw else B * CO
    __shared__ float blockW[GROUP_CI][BLOCK_SIZE][BLOCK_SIZE]; // CO * CI
    __shared__ float blockB[GROUP_CI][BLOCK_SIZE][BLOCK_SIZE]; // CO * CI
    float res[GROUP_B][GROUP_CI], x[GROUP_B][GROUP_CI];
    #pragma unroll
    for (int i = 0; i < GROUP_B; i++) {
        #pragma unroll
        for (int j = 0; j < GROUP_CI; j++)
            res[i][j] = 0;
    }
    // Read the per-thread input values once before the main loop; this is one of the two
    // places where the two kernels differ (the call to get_offset).
    #pragma unroll
    for (int i = 0; i < GROUP_B; i++) {
        int write_ci = blockIdx.y * (BLOCK_SIZE * GROUP_CI) + (has_hw ? threadIdx.y : threadIdx.x);
        #pragma unroll
        for (int j = 0; j < GROUP_CI; j++) {
            int channel = write_ci + j * BLOCK_SIZE;
            bool success;
            int offset = i_conf.get_offset(G, CI_div_G, HW, b[i], blockIdx.z, channel, hw[i], success);
            x[i][j] = success && b[i] < B && channel < CI_div_G ? input[offset] : 0;
        }
    }
    // Main loop over output channels, one BLOCK_SIZE tile at a time.
    for (int k = 0; k < CO_div_G; k += BLOCK_SIZE) {
        // Stage the output / grad_output tiles in shared memory.
        #pragma unroll
        for (int i = 0; i < GROUP_B; i++) {
            int channel = k + (has_hw ? threadIdx.y : threadIdx.x);
            if (b[i] < B) {
                int output_offset = ((b[i] * G + blockIdx.z) * CO_div_G + channel) * HW + hw[i];
                float value_o = check_co && channel >= CO_div_G ? 1e10f : output[output_offset];
                float value_g = check_co && channel >= CO_div_G ? 0 : grad_output[output_offset];
                blockO[i][threadIdx.y][threadIdx.x] = value_o;
                blockG[i][threadIdx.y][threadIdx.x] = value_g;
            }
        }
        // Stage the weight / bias tiles in shared memory.
        int read_w_ci = blockIdx.y * (BLOCK_SIZE * GROUP_CI) + threadIdx.x;
        #pragma unroll
        for (int i = 0; i < GROUP_CI; i++) {
            int in_channel = read_w_ci + i * BLOCK_SIZE;
            int out_channel = k + threadIdx.y;
            if (in_channel < CI_div_G) {
                int w_offset = (blockIdx.z * CO_div_G + out_channel) * CI_div_G + in_channel;
                if (check_co) {
                    blockW[i][threadIdx.y][threadIdx.x] = out_channel < CO_div_G ? weight[w_offset] : 0;
                    blockB[i][threadIdx.y][threadIdx.x] = out_channel < CO_div_G ? bias[w_offset] : 0;
                }
                else {
                    blockW[i][threadIdx.y][threadIdx.x] = weight[w_offset];
                    blockB[i][threadIdx.y][threadIdx.x] = bias[w_offset];
                }
            }
        }
        __syncthreads();
        // Accumulate the partial results; this is where all the heavy computation happens.
        #pragma unroll
        for (int t = 0; t < BLOCK_SIZE; t++) {
            #pragma unroll
            for (int i = 0; i < GROUP_B; i++) {
                #pragma unroll
                for (int j = 0; j < GROUP_CI; j++) {
                    float g = has_hw ? blockG[i][t][threadIdx.x] : blockG[i][threadIdx.y][t];
                    float w = blockW[j][t][has_hw ? threadIdx.y : threadIdx.x];
                    float bias = blockB[j][t][has_hw ? threadIdx.y : threadIdx.x];
                    float o = has_hw ? blockO[i][t][threadIdx.x] : blockO[i][threadIdx.y][t];
                    res[i][j] += update_backward(x[i][j], w, bias, o, g) * w;
                }
            }
        }
        __syncthreads();
    }
    // Write the results back; the other place where the two kernels differ (the call to get_offset).
    #pragma unroll
    for (int i = 0; i < GROUP_B; i++) {
        if (b[i] < B) {
            int write_ci = blockIdx.y * (BLOCK_SIZE * GROUP_CI) + (has_hw ? threadIdx.y : threadIdx.x);
            #pragma unroll
            for (int j = 0; j < GROUP_CI; j++) {
                int channel = write_ci + j * BLOCK_SIZE;
                bool success;
                int offset = i_conf.get_offset(G, CI_div_G, HW, b[i], blockIdx.z, channel, hw[i], success);
                if (success && channel < CI_div_G) {
                    if (!is_conv(i_conf)) grad_input[offset] = res[i][j]; // Fully connected
                    else atomicAdd(&grad_input[offset], res[i][j]);
                }
            }
        }
    }
}
// A simple test program that launches both instantiations.
int main() {
    int B = 512, CI = 128, CO = 128, H = 16, W = 16;
    float *grad_output, *input, *input_unfold, *weight, *bias, *output;
    cudaMallocManaged(&grad_output, B * CO * H * W * sizeof(float));
    cudaMallocManaged(&output, B * CO * H * W * sizeof(float));
    cudaMallocManaged(&input, B * CI * H * W * sizeof(float));
    cudaMallocManaged(&input_unfold, B * CI * 3 * 3 * H * W * sizeof(float));
    cudaMallocManaged(&weight, CO * CI * 3 * 3 * sizeof(float));
    cudaMallocManaged(&bias, CO * CI * 3 * 3 * sizeof(float));
    dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);
    dim3 dimGrid((B * H * W - 1) / (BLOCK_SIZE * 4) + 1, (CI * 3 * 3 - 1) / (BLOCK_SIZE * 4) + 1, 1);
    // Kernel 1: fully-connected style offsets (reads the pre-unfolded input).
    logsumexp_backward_input_kernel<4, 4, true, false><<<dimGrid, dimBlock>>>(
        grad_output, input_unfold, weight, bias, output, B, CO, CI * 3 * 3, H * W, 1, input_unfold, FCInput{});
    // Kernel 2: convolution style offsets (computes the unfolded mapping on the fly).
    logsumexp_backward_input_kernel<4, 4, true, false><<<dimGrid, dimBlock>>>(
        grad_output, input, weight, bias, output, B, CO, CI * 3 * 3, H * W, 1, input, ConvInput<3, 1, false>{H, W, 1, W});
    cudaDeviceSynchronize(); // wait for both kernels to finish before the program exits
}
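For anyone reproducing this, the per-kernel register counts can be seen in the ptxas resource-usage report, e.g. (the sm_70 target and the file name repro.cu are just placeholders, use whatever you build for):

// Build with ptxas statistics enabled (architecture and file name are placeholders):
//   nvcc -arch=sm_70 --ptxas-options=-v repro.cu -o repro
// ptxas then prints a "Used N registers" line for each kernel instantiation.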