I wrote a CUDA kernel, compiled it with nvcc, and measured its performance with Nsight Compute. I was originally on CUDA 11.1 and thought that upgrading to the latest version, 11.6, might improve performance further. However, after upgrading CUDA and recompiling the kernel (same code, no changes), Nsight Compute shows that the performance drops severely (about 25% more time). I can't understand why the newer compiler produces slower code. Below are the profiling results for the two compilers. My GPU is an RTX 3090 (compute capability 8.6) and the driver version is 510.47.03.
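(For reference, reports with the sections below can be collected with an Nsight Compute command along these lines; the report name is just a placeholder: ncu --set full -o report_cuda_11_1 ./a.out)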
CUDA 11.1
Section: GPU Speed Of Light Throughput
---------------------------------------------------------------------- --------------- ------------------------------
DRAM Frequency cycle/nsecond 9.48
SM Frequency cycle/nsecond 1.39
Elapsed Cycles cycle 41,671,978
Memory [%] % 56.99
DRAM Throughput % 5.03
Duration msecond 29.90
L1/TEX Cache Throughput % 59.47
L2 Cache Throughput % 14.40
SM Active Cycles cycle 39,924,119.98
Compute (SM) [%] % 74.59
---------------------------------------------------------------------- --------------- ------------------------------
WRN Compute is more heavily utilized than Memory: Look at the Compute Workload Analysis section to see what the
compute pipelines are spending their time doing. Also, consider whether any computation is redundant and
could be reduced or moved to look-up tables.
INF The ratio of peak float (fp32) to double (fp64) performance on this device is 64:1. The kernel achieved 32%
of this device's fp32 peak performance and 0% of its fp64 peak performance. See the Kernel Profiling Guide
(https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#roofline) for more details on roofline
analysis.
Section: Compute Workload Analysis
---------------------------------------------------------------------- --------------- ------------------------------
Executed Ipc Active inst/cycle 3.11
Executed Ipc Elapsed inst/cycle 2.98
Issue Slots Busy % 77.83
Issued Ipc Active inst/cycle 3.11
SM Busy % 77.83
---------------------------------------------------------------------- --------------- ------------------------------
No compute pipeline is over-utilized.
Section: Memory Workload Analysis
---------------------------------------------------------------------- --------------- ------------------------------
Memory Throughput Gbyte/second 45.77
Mem Busy % 29.13
Max Bandwidth % 56.99
L1/TEX Hit Rate % 0.01
L2 Compression Success Rate % 0
L2 Compression Ratio 0
L2 Hit Rate % 88.32
Mem Pipes Busy % 56.99
---------------------------------------------------------------------- --------------- ------------------------------
Section: Scheduler Statistics
---------------------------------------------------------------------- --------------- ------------------------------
One or More Eligible % 77.83
Issued Warp Per Scheduler 0.78
No Eligible % 22.17
Active Warps Per Scheduler warp 3.84
Eligible Warps Per Scheduler warp 2.13
---------------------------------------------------------------------- --------------- ------------------------------
Section: Warp State Statistics
---------------------------------------------------------------------- --------------- ------------------------------
Warp Cycles Per Issued Instruction cycle 4.94
Warp Cycles Per Executed Instruction cycle 4.94
Avg. Active Threads Per Warp 32
Avg. Not Predicated Off Threads Per Warp 26.59
---------------------------------------------------------------------- --------------- ------------------------------
WRN On average, each warp of this kernel spends 1.7 cycles being stalled due to not being selected by the
scheduler. This represents about 35.1% of the total average of 4.9 cycles between issuing two instructions.
Not selected warps are eligible warps that were not picked by the scheduler to issue that cycle as another
warp was selected. A high number of not selected warps typically means you have sufficient warps to cover
warp latencies and you may consider reducing the number of active warps to possibly increase cache coherence
and data locality.
----- --------------------------------------------------------------------------------------------------------------
INF Check the Source Counters section for the top stall locations in your source based on sampling data. The
Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#sampling) provides
more details on each stall reason.
Section: Instruction Statistics
---------------------------------------------------------------------- --------------- ------------------------------
Avg. Executed Instructions Per Scheduler inst 31,074,586.93
Executed Instructions inst 10,192,464,512
Avg. Issued Instructions Per Scheduler inst 31,074,644.35
Issued Instructions inst 10,192,483,346
---------------------------------------------------------------------- --------------- ------------------------------
WRN This kernel executes 3715891200 fused and 2893749760 non-fused FP32 instructions. By converting pairs of
non-fused instructions to their fused
(https://docs.nvidia.com/cuda/floating-point/#cuda-and-floating-point), higher-throughput equivalent, the
achieved FP32 performance could be increased by up to 22% (relative to its current performance). Check the
Source page to identify where this kernel executes FP32 instructions.
Section: Launch Statistics
---------------------------------------------------------------------- --------------- ------------------------------
Block Size 256
Function Cache Configuration cudaFuncCachePreferNone
Grid Size 864
Registers Per Thread register/thread 127
Shared Memory Configuration Size Kbyte 102.40
Driver Shared Memory Per Block Kbyte/block 1.02
Dynamic Shared Memory Per Block byte/block 0
Static Shared Memory Per Block Kbyte/block 39.68
Threads thread 221,184
Waves Per SM 5.27
---------------------------------------------------------------------- --------------- ------------------------------
Section: Occupancy
---------------------------------------------------------------------- --------------- ------------------------------
Block Limit SM block 16
Block Limit Registers block 2
Block Limit Shared Mem block 2
Block Limit Warps block 6
Theoretical Active Warps per SM warp 16
Theoretical Occupancy % 33.33
Achieved Occupancy % 32.02
Achieved Active Warps Per SM warp 15.37
---------------------------------------------------------------------- --------------- ------------------------------
WRN This kernel's theoretical occupancy (33.3%) is limited by the number of required registers. This kernel's
theoretical occupancy (33.3%) is limited by the required amount of shared memory. See the CUDA Best Practices
Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on
optimizing occupancy.
Section: Source Counters
---------------------------------------------------------------------- --------------- ------------------------------
Branch Instructions Ratio % 0.00
Branch Instructions inst 32,140,672
Branch Efficiency % 100
Avg. Divergent Branches 0
---------------------------------------------------------------------- --------------- ------------------------------
WRN This kernel has uncoalesced shared accesses resulting in a total of 2322432 excessive wavefronts (0% of the
total 692586496 wavefronts). Check the L1 Wavefronts Shared Excessive table for the primary source
locations. The CUDA Best Practices Guide
(https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#shared-memory-in-matrix-multiplication-c
-aa) has an example on optimizing shared memory accesses.
CUDA 11.6
Section: GPU Speed Of Light Throughput
---------------------------------------------------------------------- --------------- ------------------------------
DRAM Frequency cycle/nsecond 9.48
SM Frequency cycle/nsecond 1.39
Elapsed Cycles cycle 52,519,743
Memory [%] % 45.42
DRAM Throughput % 3.78
Duration msecond 37.70
L1/TEX Cache Throughput % 47.73
L2 Cache Throughput % 11.42
SM Active Cycles cycle 49,963,566.56
Compute (SM) [%] % 58.01
---------------------------------------------------------------------- --------------- ------------------------------
WRN This kernel exhibits low compute throughput and memory bandwidth utilization relative to the peak performance
of this device. Achieved compute throughput and/or memory bandwidth below 60.0% of peak typically indicate
latency issues. Look at Scheduler Statistics and Warp State Statistics for potential reasons.
INF The ratio of peak float (fp32) to double (fp64) performance on this device is 64:1. The kernel achieved 25%
of this device's fp32 peak performance and 0% of its fp64 peak performance. See the Kernel Profiling Guide
(https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#roofline) for more details on roofline
analysis.
Section: Compute Workload Analysis
---------------------------------------------------------------------- --------------- ------------------------------
Executed Ipc Active inst/cycle 2.44
Executed Ipc Elapsed inst/cycle 2.32
Issue Slots Busy % 60.97
Issued Ipc Active inst/cycle 2.44
SM Busy % 60.97
---------------------------------------------------------------------- --------------- ------------------------------
No compute pipeline is over-utilized.
Section: Memory Workload Analysis
---------------------------------------------------------------------- --------------- ------------------------------
Memory Throughput Gbyte/second 34.43
Mem Busy % 23.22
Max Bandwidth % 45.42
L1/TEX Hit Rate % 0.05
L2 Compression Success Rate % 0
L2 Compression Ratio 0
L2 Hit Rate % 88.90
Mem Pipes Busy % 45.42
---------------------------------------------------------------------- --------------- ------------------------------
Section: Scheduler Statistics
---------------------------------------------------------------------- --------------- ------------------------------
One or More Eligible % 60.95
Issued Warp Per Scheduler 0.61
No Eligible % 39.05
Active Warps Per Scheduler warp 3.84
Eligible Warps Per Scheduler warp 1.18
---------------------------------------------------------------------- --------------- ------------------------------
Section: Warp State Statistics
---------------------------------------------------------------------- --------------- ------------------------------
Warp Cycles Per Issued Instruction cycle 6.30
Warp Cycles Per Executed Instruction cycle 6.30
Avg. Active Threads Per Warp 32
Avg. Not Predicated Off Threads Per Warp 26.47
---------------------------------------------------------------------- --------------- ------------------------------
Section: Instruction Statistics
---------------------------------------------------------------------- --------------- ------------------------------
Avg. Executed Instructions Per Scheduler inst 30,458,590.83
Executed Instructions inst 9,990,417,792
Avg. Issued Instructions Per Scheduler inst 30,462,031.41
Issued Instructions inst 9,991,546,301
---------------------------------------------------------------------- --------------- ------------------------------
WRN This kernel executes 3715891200 fused and 2893749760 non-fused FP32 instructions. By converting pairs of
non-fused instructions to their fused
(https://docs.nvidia.com/cuda/floating-point/#cuda-and-floating-point), higher-throughput equivalent, the
achieved FP32 performance could be increased by up to 22% (relative to its current performance). Check the
Source page to identify where this kernel executes FP32 instructions.
Section: Launch Statistics
---------------------------------------------------------------------- --------------- ------------------------------
Block Size 256
Function Cache Configuration cudaFuncCachePreferNone
Grid Size 864
Registers Per Thread register/thread 128
Shared Memory Configuration Size Kbyte 102.40
Driver Shared Memory Per Block Kbyte/block 1.02
Dynamic Shared Memory Per Block byte/block 0
Static Shared Memory Per Block Kbyte/block 39.68
Threads thread 221,184
Waves Per SM 5.27
---------------------------------------------------------------------- --------------- ------------------------------
Section: Occupancy
---------------------------------------------------------------------- --------------- ------------------------------
Block Limit SM block 16
Block Limit Registers block 2
Block Limit Shared Mem block 2
Block Limit Warps block 6
Theoretical Active Warps per SM warp 16
Theoretical Occupancy % 33.33
Achieved Occupancy % 32.02
Achieved Active Warps Per SM warp 15.37
---------------------------------------------------------------------- --------------- ------------------------------
WRN This kernel's theoretical occupancy (33.3%) is limited by the required amount of shared memory. This kernel's
theoretical occupancy (33.3%) is limited by the number of required registers. See the CUDA Best Practices
Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on
optimizing occupancy.
Section: Source Counters
---------------------------------------------------------------------- --------------- ------------------------------
Branch Instructions Ratio % 0.00
Branch Instructions inst 36,557,440
Branch Efficiency % 100
Avg. Divergent Branches 0
---------------------------------------------------------------------- --------------- ------------------------------
WRN This kernel has uncoalesced shared accesses resulting in a total of 2322432 excessive wavefronts (0% of the
total 692586496 wavefronts). Check the L1 Wavefronts Shared Excessive table for the primary source
locations. The CUDA Best Practices Guide
(https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#shared-memory-in-matrix-multiplication-c
-aa) has an example on optimizing shared memory accesses.
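One thing I notice from the two reports: the Launch Statistics and Occupancy sections are essentially identical, so the slowdown does not look like an occupancy effect. Working through the numbers myself (assuming the usual compute capability 8.6 limits of 65,536 registers and 48 resident warps per SM):
registers:      65,536 / (256 threads * 128 registers) = 2 blocks per SM (127 registers gives the same limit)
shared memory:  39.68 KB static + 1.02 KB driver = 40.70 KB per block, and floor(102.4 / 40.7) = 2 blocks per SM
warps:          2 blocks * (256 / 32) warps = 16 warps, i.e. 16 / 48 = 33.3% theoretical occupancy
Both builds hit the same 2-block, 16-warp limit; the difference instead shows up in Issued Ipc Active (3.11 vs 2.44) and Warp Cycles Per Issued Instruction (4.94 vs 6.30), so the 11.6 binary seems to stall more per issued instruction.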
A minimal reproducible example is included below (the code is a bit involved). The compile command is as follows: /usr/local/cuda-11.1/bin/nvcc -arch compute_86 -code sm_86 tmp.cu (and the corresponding nvcc from CUDA 11.6 for the second build).
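If it helps, I can also dump and diff the generated SASS from the two builds, e.g.:
/usr/local/cuda-11.1/bin/nvcc -arch compute_86 -code sm_86 -Xptxas -v tmp.cu -o tmp_11_1
/usr/local/cuda-11.1/bin/cuobjdump -sass tmp_11_1 > sass_11_1.txt
(-Xptxas -v prints the per-kernel register/spill/shared-memory usage at compile time; cuobjdump -sass dumps the machine code so the instruction mix of the two binaries can be compared. The output names are just placeholders.)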
#include <cstdlib>         // rand() / RAND_MAX used by the host-side test harness
#include <cuda_runtime.h>  // CUDA runtime API (nvcc also includes this implicitly for .cu files)
const int WARP_SIZE = 32;
#define CONST_PTR const float* __restrict__
#define PTR float* __restrict__
#define ASSERTION_FAILS "static assertion fails"
#define INF 1e10f
__device__ __forceinline__ float fast_exp2(float x) { return exp2f(x); }
__device__ __forceinline__
float2 update_forward(float mean, float diff, float w, float wp, float b) {
float t = mean * w + b;
return { fast_exp2(-wp * diff + t), fast_exp2(wp * diff + t) };
}
__device__ __forceinline__
float2 update_forward(float mean, float diff, float w, float wp, float b, float maxL, float maxU) {
float t = mean * w + b;
return { fast_exp2(t - (wp * diff + maxL)), fast_exp2(wp * diff + t - maxU) };
}
template <typename... T> __device__ __forceinline__
float2 update_backward(float mean, float diff, float w, float wp, float b, float gL, float gU, T... o) {
float2 t = update_forward(mean, diff, w, wp, b, o...);
float2 t2 = { t.x * gL, t.y * gU };
return { t2.x + t2.y, t2.y - t2.x };
}
// threadIdx.x: WARP_SIZE / GROUP_BCO; threadIdx.y: BLOCK_CI_SIZE
template <int GROUP_CI, int GROUP_BCO, int GROUP_BCO_SUB, int BLOCK_CI_SIZE, bool has_hw, bool stable> __global__
__launch_bounds__(BLOCK_CI_SIZE * WARP_SIZE / GROUP_BCO)
void logsumexp_bound_backward_input_weight_kernel(CONST_PTR grad_outputL, CONST_PTR grad_outputU, CONST_PTR inputL,
CONST_PTR inputU, CONST_PTR weight, CONST_PTR bias, CONST_PTR outputL,
CONST_PTR outputU, int B, int CO_div_G, int CI_div_G, int HW, int G,
PTR grad_inputL, PTR grad_inputU, PTR grad_weight, PTR grad_bias) {
const int THREAD_X = WARP_SIZE / GROUP_BCO;
int threadIdx_low, threadIdx_high;
if (THREAD_X == BLOCK_CI_SIZE) {
threadIdx_low = threadIdx.x;
threadIdx_high = threadIdx.y;
}
else {
int id = threadIdx.y * THREAD_X + threadIdx.x;
threadIdx_low = id & (BLOCK_CI_SIZE - 1);
threadIdx_high = id / BLOCK_CI_SIZE;
}
if (!has_hw) HW = 1;
int b_hw = blockIdx.x * WARP_SIZE + (has_hw ? threadIdx.x : threadIdx_high);
int b[GROUP_BCO], hw[GROUP_BCO];
#pragma unroll
for (int j = 0; j < GROUP_BCO; j++) {
int b_hw_j = b_hw + j * THREAD_X;
if (has_hw) { b[j] = b_hw_j / HW; hw[j] = b_hw_j % HW; }
else { b[j] = b_hw_j; hw[j] = 0; }
}
int write_ci = blockIdx.y * (BLOCK_CI_SIZE * GROUP_CI) + (has_hw ? threadIdx.y : threadIdx_low);
int read_w_ci = blockIdx.y * (BLOCK_CI_SIZE * GROUP_CI) + threadIdx_low;
static_assert(!(GROUP_BCO & (GROUP_BCO - 1)) && !(GROUP_BCO_SUB & (GROUP_BCO_SUB - 1)), ASSERTION_FAILS);
__shared__ float blockOL[WARP_SIZE][WARP_SIZE]; // has_hw ? CO * B : B * CO
__shared__ float blockOU[WARP_SIZE][WARP_SIZE]; // has_hw ? CO * B : B * CO
__shared__ float blockGL[WARP_SIZE][WARP_SIZE]; // has_hw ? CO * B : B * CO
__shared__ float blockGU[WARP_SIZE][WARP_SIZE]; // has_hw ? CO * B : B * CO
__shared__ float blockW[GROUP_CI][THREAD_X][BLOCK_CI_SIZE + WARP_SIZE / THREAD_X]; // CO * CI
__shared__ float blockB[GROUP_CI][THREAD_X][BLOCK_CI_SIZE + WARP_SIZE / THREAD_X]; // CO * CI
__shared__ float blockIM[GROUP_CI][GROUP_BCO][BLOCK_CI_SIZE][THREAD_X]; // CI * B
__shared__ float blockID[GROUP_CI][GROUP_BCO][BLOCK_CI_SIZE][THREAD_X]; // CI * B
// To reduce shared memory cost, does not avoid bank conflict for blockIM and blockID.
// Warning: bank conflict for blockW, blockB iff BLOCK_CI_SIZE < 32
#pragma unroll
for (int i = 0; i < GROUP_CI; i++) {
#pragma unroll
for (int j = 0; j < GROUP_BCO; j++) {
int channel = write_ci + i * BLOCK_CI_SIZE;
int offset = ((b[j] * G + blockIdx.z) * CI_div_G + channel) * HW + hw[j];
float xL = b[j] < B && channel < CI_div_G ? inputL[offset] : 0;
float xU = b[j] < B && channel < CI_div_G ? inputU[offset] : 0;
float tmp1 = (xL + xU) * 0.5, tmp2 = (xU - xL) * 0.5;
if (has_hw) {
blockIM[i][j][threadIdx.y][threadIdx.x] = tmp1;
blockID[i][j][threadIdx.y][threadIdx.x] = tmp2;
}
else {
blockIM[i][j][threadIdx_low][threadIdx_high] = tmp1;
blockID[i][j][threadIdx_low][threadIdx_high] = tmp2;
}
}
}
float grad_x1[GROUP_CI][GROUP_BCO][GROUP_BCO_SUB], grad_x2[GROUP_CI][GROUP_BCO][GROUP_BCO_SUB];
#pragma unroll
for (int i = 0; i < GROUP_CI; i++)
#pragma unroll
for (int j = 0; j < GROUP_BCO; j++) {
#pragma unroll
for (int s = 0; s < GROUP_BCO_SUB; s++) {
grad_x1[i][j][s] = grad_x2[i][j][s] = 0;
}
}
for (int k = 0; k < CO_div_G; k += WARP_SIZE) {
#pragma unroll
for (int i = 0; i < WARP_SIZE; i += BLOCK_CI_SIZE * THREAD_X / WARP_SIZE) {
int id = threadIdx.y * THREAD_X + threadIdx.x;
int threadIdx32x = id & (WARP_SIZE - 1);
int threadIdx32y = id / WARP_SIZE;
int read_output_co = k + (has_hw ? threadIdx32y + i: threadIdx32x);
int read_output_b_hw = blockIdx.x * WARP_SIZE + (has_hw ? threadIdx32x : threadIdx32y + i);
int read_b = has_hw ? read_output_b_hw / HW : read_output_b_hw;
int read_hw = has_hw ? read_output_b_hw % HW : 0;
int offset = ((read_b * G + blockIdx.z) * CO_div_G + read_output_co) * HW + read_hw;
float value_oL = read_b < B && read_output_co < CO_div_G ? outputL[offset] : INF;
float value_oU = read_b < B && read_output_co < CO_div_G ? outputU[offset] : INF;
float value_gL = read_b < B && read_output_co < CO_div_G ? grad_outputL[offset] : 0;
float value_gU = read_b < B && read_output_co < CO_div_G ? grad_outputU[offset] : 0;
if (stable) {
blockOL[threadIdx32y + i][threadIdx32x] = value_oL;
blockOU[threadIdx32y + i][threadIdx32x] = value_oU;
blockGL[threadIdx32y + i][threadIdx32x] = value_gL;
blockGU[threadIdx32y + i][threadIdx32x] = value_gU;
}
else {
blockGL[threadIdx32y + i][threadIdx32x] = value_gL * fast_exp2(-value_oL);
blockGU[threadIdx32y + i][threadIdx32x] = value_gU * fast_exp2(-value_oU);
}
}
for (int j_co = 0; j_co < GROUP_BCO; j_co++) {
float grad_w[GROUP_CI], grad_b[GROUP_CI];
#pragma unroll
for (int i = 0; i < GROUP_CI; i++)
grad_w[i] = grad_b[i] = 0;
#pragma unroll
for (int i = 0; i < GROUP_CI; i++) {
int in_channel = read_w_ci + i * BLOCK_CI_SIZE;
int out_channel = k + threadIdx_high + j_co * THREAD_X;
int w_offset = (blockIdx.z * CO_div_G + out_channel) * CI_div_G + in_channel;
bool valid = in_channel < CI_div_G && out_channel < CO_div_G;
blockW[i][threadIdx_high][threadIdx_low] = valid ? weight[w_offset] : 0;
blockB[i][threadIdx_high][threadIdx_low] = valid ? bias[w_offset] : 0;
}
__syncthreads();
#pragma unroll
for (int t = 0; t < THREAD_X / GROUP_BCO_SUB; t++) {
#pragma unroll
for (int i = 0; i < GROUP_CI; i++) {
float sum_res_w1, sum_res_w2, sum_res_b;
int co_ = threadIdx.x ^ t, co = co_ ^ (j_co * THREAD_X);
float w = blockW[i][co_][threadIdx.y], wp = abs(w);
float bias = blockB[i][co_][threadIdx.y];
#pragma unroll
for (int j_b = 0; j_b < GROUP_BCO; j_b++) {
#pragma unroll
for (int s = 0; s < GROUP_BCO_SUB; s++) {
int b_ = threadIdx.x ^ (THREAD_X / GROUP_BCO_SUB * s), b = b_ ^ (j_b * THREAD_X);
float x_mean = blockIM[i][j_b][threadIdx.y][b_];
float x_diff = blockID[i][j_b][threadIdx.y][b_];
float gL = has_hw ? blockGL[co][b] : blockGL[b][co];
float gU = has_hw ? blockGU[co][b] : blockGU[b][co];
float2 res_pair;
if (stable) {
float oL = has_hw ? blockOL[co][b] : blockOL[b][co];
float oU = has_hw ? blockOU[co][b] : blockOU[b][co];
res_pair = update_backward(x_mean, x_diff, w, wp, bias, gL, gU, oL, oU);
}
else res_pair = update_backward(x_mean, x_diff, w, wp, bias, gL, gU);
if (j_b == 0 && s == 0) {
sum_res_w1 = res_pair.x * x_mean;
sum_res_w2 = res_pair.y * x_diff;
sum_res_b = res_pair.x;
}
else {
sum_res_w1 += res_pair.x * x_mean;
sum_res_w2 += res_pair.y * x_diff;
sum_res_b += res_pair.x;
}
grad_x1[i][j_b][s] += res_pair.x * w;
grad_x2[i][j_b][s] += res_pair.y * wp;
}
}
float sgn = w > 0 ? 1.0f : -1.0f;
float sum_res_w = __shfl_xor_sync(0xffffffff, sum_res_w1 + sgn * sum_res_w2, t); // grad at co=threadIdx.x
sum_res_b = __shfl_xor_sync(0xffffffff, sum_res_b, t); // grad at co=threadIdx.x
grad_w[i] += sum_res_w;
grad_b[i] += sum_res_b;
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < GROUP_CI; i++) {
blockW[i][threadIdx.x][threadIdx.y] = grad_w[i];
blockB[i][threadIdx.x][threadIdx.y] = grad_b[i];
}
__syncthreads();
#pragma unroll
for (int i = 0; i < GROUP_CI; i++) {
int in_channel = read_w_ci + i * BLOCK_CI_SIZE;
int out_channel = k + threadIdx_high + j_co * THREAD_X;
if (in_channel < CI_div_G && out_channel < CO_div_G) {
int w_offset = (blockIdx.z * CO_div_G + out_channel) * CI_div_G + in_channel;
atomicAdd(&grad_weight[w_offset], blockW[i][threadIdx_high][threadIdx_low]);
atomicAdd(&grad_bias[w_offset], blockB[i][threadIdx_high][threadIdx_low]);
}
}
__syncthreads();
}
}
#pragma unroll
for (int i = 0; i < GROUP_CI; i++) {
#pragma unroll
for (int j = 0; j < GROUP_BCO; j++) {
#pragma unroll
for (int s = 1; s < GROUP_BCO_SUB; s++) {
grad_x1[i][j][0] += __shfl_xor_sync(0xffffffff, grad_x1[i][j][s], THREAD_X / GROUP_BCO_SUB * s);
grad_x2[i][j][0] += __shfl_xor_sync(0xffffffff, grad_x2[i][j][s], THREAD_X / GROUP_BCO_SUB * s);
}
}
}
if (!has_hw) {
#pragma unroll
for (int i = 0; i < GROUP_CI; i++) {
#pragma unroll
for (int j = 0; j < GROUP_BCO; j++) {
blockIM[i][j][threadIdx.y][threadIdx.x] = grad_x1[i][j][0];
blockID[i][j][threadIdx.y][threadIdx.x] = grad_x2[i][j][0];
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < GROUP_CI; i++) {
#pragma unroll
for (int j = 0; j < GROUP_BCO; j++) {
grad_x1[i][j][0] = blockIM[i][j][threadIdx_low][threadIdx_high];
grad_x2[i][j][0] = blockID[i][j][threadIdx_low][threadIdx_high];
}
}
}
#pragma unroll
for (int j = 0; j < GROUP_BCO; j++) {
if (b[j] < B) {
#pragma unroll
for (int i = 0; i < GROUP_CI; i++) {
int channel = write_ci + i * BLOCK_CI_SIZE;
int offset = ((b[j] * G + blockIdx.z) * CI_div_G + channel) * HW + hw[j];
if (channel < CI_div_G) {
grad_inputL[offset] = (grad_x1[i][j][0] - grad_x2[i][j][0]) * 0.5f;
grad_inputU[offset] = (grad_x1[i][j][0] + grad_x2[i][j][0]) * 0.5f;
}
}
}
}
}
void backward_input_weight(const float* grad_outputL, const float* grad_outputU, const float* inputL,
const float* inputU, const float* weight, const float* bias,
const float* outputL, const float* outputU, int B, int CO, int CI, int G, int HW,
float* grad_inputL, float* grad_inputU, float* grad_weight, float* grad_bias) {
const int GROUP_CI = 3;
const int GROUP_BCO = 4;
const int GROUP_BCO_SUB = 1;
const int BLOCK_CI_SIZE = 32;
int CI_div_G = CI / G, CO_div_G = CO / G;
dim3 dimBlock(WARP_SIZE / GROUP_BCO, BLOCK_CI_SIZE);
dim3 dimGrid((B * HW - 1) / WARP_SIZE + 1, (CI_div_G - 1) / (BLOCK_CI_SIZE * GROUP_CI) + 1, G);
cudaMemset(grad_weight, 0, CO * CI_div_G * sizeof(float));
cudaMemset(grad_bias, 0, CO * CI_div_G * sizeof(float));
logsumexp_bound_backward_input_weight_kernel<GROUP_CI, GROUP_BCO, GROUP_BCO_SUB, BLOCK_CI_SIZE, false, false><<<dimGrid, dimBlock>>>(
grad_outputL, grad_outputU, inputL, inputU, weight, bias, outputL, outputU, B, CO_div_G, CI_div_G, HW, G,
grad_inputL, grad_inputU, grad_weight, grad_bias);
}
void test(int B, int CI, int CO) {
int HW = 1;
float *y_lower = new float[B * CO * HW];
float *y_upper = new float[B * CO * HW];
float *x_lower = new float[B * CI * HW];
float *x_upper = new float[B * CI * HW];
float *weight = new float[CO * CI];
float *bias = new float[CO * CI];
float *grad_y_lower = new float[B * CO * HW];
float *grad_y_upper = new float[B * CO * HW];
for (int i = 0; i < B * CI * HW; i++) {
float mean = ((float)rand() - 0.5) / RAND_MAX;
float diff = (float)rand() / RAND_MAX;
x_lower[i] = mean - diff;
x_upper[i] = mean + diff;
}
for (int i = 0; i < CO * CI; i++)
weight[i] = ((float)rand() - 0.5) / RAND_MAX;
for (int i = 0; i < CO * CI; i++)
bias[i] = ((float)rand() - 0.5) / RAND_MAX;
for (int i = 0; i < B * CO * HW; i++) {
y_lower[i] = ((float)rand() - 0.5) / RAND_MAX;
y_upper[i] = ((float)rand() - 0.5) / RAND_MAX;
grad_y_lower[i] = ((float)rand() - 0.5) / RAND_MAX;
grad_y_upper[i] = ((float)rand() - 0.5) / RAND_MAX;
}
float *x_lower_cuda, *x_upper_cuda, *weight_cuda, *bias_cuda, *y_lower_cuda, *y_upper_cuda,
*grad_y_lower_cuda, *grad_y_upper_cuda, *grad_x_lower_cuda, *grad_x_upper_cuda, *grad_weight_cuda, *grad_bias_cuda;
cudaMallocManaged(&y_lower_cuda, B * CO * HW * sizeof(float));
cudaMallocManaged(&y_upper_cuda, B * CO * HW * sizeof(float));
cudaMallocManaged(&x_lower_cuda, B * CI * HW * sizeof(float));
cudaMallocManaged(&x_upper_cuda, B * CI * HW * sizeof(float));
cudaMallocManaged(&weight_cuda, CO * CI * sizeof(float));
cudaMallocManaged(&bias_cuda, CO * CI * sizeof(float));
cudaMallocManaged(&grad_y_lower_cuda, B * CO * HW * sizeof(float));
cudaMallocManaged(&grad_y_upper_cuda, B * CO * HW * sizeof(float));
cudaMallocManaged(&grad_x_lower_cuda, B * CI * HW * sizeof(float));
cudaMallocManaged(&grad_x_upper_cuda, B * CI * HW * sizeof(float));
cudaMallocManaged(&grad_weight_cuda, CO * CI * sizeof(float));
cudaMallocManaged(&grad_bias_cuda, CO * CI * sizeof(float));
cudaMemcpy(x_lower_cuda, x_lower, B * CI * HW * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(x_upper_cuda, x_upper, B * CI * HW * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(y_lower_cuda, y_lower, B * CO * HW * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(y_upper_cuda, y_upper, B * CO * HW * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(weight_cuda, weight, CO * CI * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(bias_cuda, bias, CO * CI * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(grad_y_lower_cuda, grad_y_lower, B * CO * HW * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(grad_y_upper_cuda, grad_y_upper, B * CO * HW * sizeof(float), cudaMemcpyHostToDevice);
backward_input_weight(grad_y_lower_cuda, grad_y_upper_cuda, x_lower_cuda, x_upper_cuda,
weight_cuda, bias_cuda, y_lower_cuda, y_upper_cuda, B, CO, CI, 1, HW, grad_x_lower_cuda, grad_x_upper_cuda, grad_weight_cuda, grad_bias_cuda);
cudaDeviceSynchronize();
cudaFree(y_lower_cuda);
cudaFree(y_upper_cuda);
cudaFree(x_lower_cuda);
cudaFree(x_upper_cuda);
cudaFree(weight_cuda);
cudaFree(bias_cuda);
cudaFree(grad_x_lower_cuda);
cudaFree(grad_x_upper_cuda);
cudaFree(grad_weight_cuda);
cudaFree(grad_y_lower_cuda);
cudaFree(grad_y_upper_cuda);
cudaFree(grad_bias_cuda);
delete[] y_lower;
delete[] y_upper;
delete[] x_lower;
delete[] x_upper;
delete[] weight;
delete[] bias;
delete[] grad_y_lower;
delete[] grad_y_upper;
}
int main() {
test(512, 5120, 5120);
}
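One caveat about the harness: it does not check the kernel launch or the CUDA API calls for errors. A minimal check along these lines could be added after the launch (CHECK_CUDA is just a helper name for this sketch, not part of the original code):
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
// Hypothetical helper: print the error string and abort on any CUDA failure.
#define CHECK_CUDA(call) do { \
    cudaError_t err_ = (call); \
    if (err_ != cudaSuccess) { \
        fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(err_), __FILE__, __LINE__); \
        exit(1); \
    } \
} while (0)
// Usage right after the kernel launch in backward_input_weight():
//   CHECK_CUDA(cudaGetLastError());      // catches launch/configuration errors
//   CHECK_CUDA(cudaDeviceSynchronize()); // catches asynchronous execution errors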