Hi all,
I recently changed code like this:
// Loop-based form: visits every (h, w) pair in 1..4 x 1..4 except (1, 1).
// For w == 1 the two source entries differ in the first index of max_v;
// for w >= 2 they differ in the second index.
for (int h = 1; h <= 4; ++h) {
    // -1 for threads with y > n(h)
    const int row = max_idx(h, threadIdx.y);

    // The (h, w) == (1, 1) pair is skipped: w starts at 2 when h == 1.
    for (int w = (h == 1) ? 2 : 1; w <= 4; ++w) {
        // -1 for threads with x > n(w)
        const int col = max_idx(w, threadIdx.x);

        if (w != 1) {
            // w >= 2: the instruction is looked up from the column index.
            if (row != -1 && col != -1) {
                IndexInstruction instr = idx_instr[col - 12];
                max_v[row][col] = max(max_v[row][instr.idx1],
                                      max_v[row][instr.idx2]);
                scan_inner(w, h, max_v[row][col]);
            }
        } else if (row != -1) {
            // w == 1 only occurs with h >= 2, for which row >= 12 (author's note),
            // so the instruction is looked up from the row index.
            IndexInstruction instr = idx_instr[row - 12];
            max_v[row][col] = max(max_v[instr.idx1][col],
                                  max_v[instr.idx2][col]);
            scan_inner(w, h, max_v[row][col]);
        }

        // Barrier sits outside every per-thread guard; w is the same for all
        // threads, so the whole block takes this branch together.
        if (w <= 2) {
            __syncthreads();
        }
    }
}
to this:
// Generated three-stage schedule. In each stage, a thread whose tid is below
// the stage's limit decodes one packed IndexInstruction and writes one max_v
// entry. Offsets into instrs: stage 1 -> 0, stage 2 -> 72, stage 3 -> 228
// (= 72 + 156).
//
// Each instruction is fetched inside the guard of the thread that consumes it,
// so threads that are idle in a stage never index instrs beyond the entries
// that stage actually uses.
//
// NOTE(review): assuming tid is the linear thread index and warps are 32
// threads wide, none of the limits 72, 48, 84, 156, 36 is a multiple of 32.
// A warp that straddles a limit (e.g. threads 32-63 at tid < 48) therefore
// still executes more than one path. Padding each stage's instruction range
// up to a multiple of 32 would be needed to align branches with warps.
int idx_x, idx_y, read_idx1, read_idx2;

// Stage 1: 72 instructions; sources share the first index idx_y.
if (tid < 72) {
    INDEXINSTRUCTION_SET_LOCALS(instrs[tid + 0]);
    max_v[idx_y][idx_x] = max(max_v[idx_y][read_idx1], max_v[idx_y][read_idx2]);
    scan_inner(2, 1, max_v[idx_y][idx_x]);
}
__syncthreads();  // reached by every thread in the block

// Stage 2: 156 instructions split into three scan_inner variants.
if (tid < 156) {
    INDEXINSTRUCTION_SET_LOCALS(instrs[tid + 72]);
    if (tid < 48) {
        max_v[idx_y][idx_x] = max(max_v[idx_y][read_idx1], max_v[idx_y][read_idx2]);
        scan_inner(3, 1, max_v[idx_y][idx_x]);
    }
    else if (tid < 84) {
        max_v[idx_y][idx_x] = max(max_v[idx_y][read_idx1], max_v[idx_y][read_idx2]);
        scan_inner(4, 1, max_v[idx_y][idx_x]);
    }
    else {
        // 84 <= tid < 156: sources share the second index idx_x instead.
        max_v[idx_y][idx_x] = max(max_v[read_idx1][idx_x], max_v[read_idx2][idx_x]);
        scan_inner(1, 2, max_v[idx_y][idx_x]);
    }
}
__syncthreads();  // reached by every thread in the block

// Stage 3: 36 instructions; sources share the first index idx_y.
if (tid < 36) {
    INDEXINSTRUCTION_SET_LOCALS(instrs[tid + 228]);
    max_v[idx_y][idx_x] = max(max_v[idx_y][read_idx1], max_v[idx_y][read_idx2]);
    scan_inner(2, 2, max_v[idx_y][idx_x]);
}
Most of the second version is generated (jinja2 templates plus some Python backend code).
I was hoping it would reduce warp divergence compared with the original 12x12 block (the block is now 192 threads so that the “tid < 156” stage fits), but it doesn’t seem to be doing so. Does anyone have ideas?
Thanks,
Nicholas
Edit: Sorry for the sloppy original post (I accidentally hit the wrong button and corrected it a few seconds later). I’m measuring divergence with the divergent_branch counter in the profiler. INDEXINSTRUCTION_SET_LOCALS is pretty simple:
/**
 * One packed instruction of the generated schedule.
 *
 * v_0 stores four 5-bit fields, least significant first:
 *   bits  0-4  : read_idx2
 *   bits  5-9  : read_idx1
 *   bits 10-14 : idx_y
 *   bits 15-19 : idx_x
 * Bits 20-31 are ignored by get().
 */
struct IndexInstruction {
    uint32_t v_0;

    /**
     * Unpacks the four 5-bit fields of v_0 into the caller's variables.
     *
     * The method only reads v_0, so it is const, and it is callable from both
     * host and device code so a packed table can be decoded or checked on the
     * host as well.
     *
     * @param idx_x      receives bits 15-19 of v_0
     * @param idx_y      receives bits 10-14 of v_0
     * @param read_idx1  receives bits 5-9 of v_0
     * @param read_idx2  receives bits 0-4 of v_0
     */
    __host__ __device__ inline void get(int *idx_x, int *idx_y, int *read_idx1, int *read_idx2) const {
        const uint32_t field_mask = 0x1f;  // five low bits
        const int field_bits = 5;

        uint32_t packed = v_0;
        // Fields are peeled off from the least significant end, in this order.
        *read_idx2 = (packed & field_mask);
        packed >>= field_bits;
        *read_idx1 = (packed & field_mask);
        packed >>= field_bits;
        *idx_y = (packed & field_mask);
        packed >>= field_bits;
        *idx_x = (packed & field_mask);
    }
};

// Decodes `name` into locals that must already be declared in the calling
// scope under exactly these names: idx_x, idx_y, read_idx1, read_idx2.
#define INDEXINSTRUCTION_SET_LOCALS(name) \
    (name).get(&idx_x, &idx_y, &read_idx1, &read_idx2)