I’ve written a simple Numba CUDA kernel that implements argmax (code below). It returns the correct result about 99.99% of the time and a wrong one about 0.01% of the time.
I’ve tried to debug it, and the closest I’ve gotten is finding that what should be the max value sometimes gets overwritten by a thread that definitely should not be writing anything at that moment. That would be explainable if `cuda.syncthreads()` weren’t working, but I think I’m using it correctly: I call it before and after every read from and write to shared memory, and none of the calls are inside conditionals that could resolve differently for different threads. What am I missing?
```python
from numba import cuda
import math


@cuda.jit
def argmax2(arr, out):
    n_actual_threads = cuda.blockDim.x
    max_virtual_threads = len(arr)
    # Dynamic shared arrays all start at the same address: the values use the
    # first len(arr) float32 slots and the indices the int32 slots after them,
    # so the launch needs at least 8 * len(arr) bytes of dynamic shared memory.
    argmax_value_scratchpad = cuda.shared.array(0, dtype='float32')[:max_virtual_threads]
    argmax_index_scratchpad = cuda.shared.array(0, dtype='int32')[max_virtual_threads:]

    # Copy arr into shared memory, each thread handling a strided set of indices.
    n_virtual_threads = max_virtual_threads
    for thread_batch in range(math.ceil(n_virtual_threads / n_actual_threads)):
        thread_idx = thread_batch * n_actual_threads + cuda.threadIdx.x
        if thread_idx < n_virtual_threads:
            argmax_value_scratchpad[thread_idx] = arr[thread_idx]
            argmax_index_scratchpad[thread_idx] = thread_idx
    cuda.syncthreads()

    # Tree reduction: each pass halves the scratchpad length (rounding up).
    scratchpad_length = len(arr)
    while scratchpad_length > 1:
        new_length = math.ceil(scratchpad_length / 2)
        n_virtual_threads = new_length
        for thread_batch in range(math.ceil(n_virtual_threads / n_actual_threads)):
            thread_idx = thread_batch * n_actual_threads + cuda.threadIdx.x
            cuda.syncthreads()
            if thread_idx < n_virtual_threads:
                a = argmax_value_scratchpad[thread_idx*2]
                idx_a = argmax_index_scratchpad[thread_idx*2]
                b = argmax_value_scratchpad[thread_idx*2+1]
                idx_b = argmax_index_scratchpad[thread_idx*2+1]
                if scratchpad_length % 2 == 1 and thread_idx == new_length - 1:
                    # special case if the length is odd. ceil already added one more thread for it
                    # a is already set correctly for it too
                    # but b is missing so we gotta set it to the first value in the scratchpad
                    b = argmax_value_scratchpad[0]
                    idx_b = argmax_index_scratchpad[0]
                if a > b:
                    winning_value = a
                    winning_index = idx_a
                else:
                    winning_value = b
                    winning_index = idx_b
            cuda.syncthreads()
            if thread_idx < new_length:
                argmax_value_scratchpad[thread_idx] = winning_value
                argmax_index_scratchpad[thread_idx] = winning_index
            cuda.syncthreads()
        scratchpad_length = new_length
    out[0] = argmax_index_scratchpad[0]
```
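To check the reduction logic separately from the GPU, a minimal CPU sketch can replay the same batched tree reduction with one Python thread per CUDA thread. `emulate_argmax2` is a hypothetical helper, and `threading.Barrier` is only a stand-in for `cuda.syncthreads()`, placed once after the load and around the write phase of each batch. It models the intended synchronization, not the real GPU memory model, so agreement here only suggests the indexing scheme itself is sound.

```python
import math
import threading


def emulate_argmax2(arr, n_actual_threads):
    """Hypothetical CPU replay of argmax2: one Python thread per CUDA thread.

    threading.Barrier stands in for cuda.syncthreads(); this is an assumed
    model of the kernel for debugging, not Numba or CUDA code.
    """
    n = len(arr)
    values = [0.0] * n  # plays the role of argmax_value_scratchpad
    indices = [0] * n   # plays the role of argmax_index_scratchpad
    barrier = threading.Barrier(n_actual_threads)

    def body(tid):
        # Load phase: each thread copies a strided subset of arr.
        for batch in range(math.ceil(n / n_actual_threads)):
            i = batch * n_actual_threads + tid
            if i < n:
                values[i] = arr[i]
                indices[i] = i
        barrier.wait()

        length = n
        while length > 1:
            new_length = math.ceil(length / 2)
            # Every thread runs the same number of batches, so barrier counts match.
            for batch in range(math.ceil(new_length / n_actual_threads)):
                i = batch * n_actual_threads + tid
                winner = None
                if i < new_length:
                    a, idx_a = values[2 * i], indices[2 * i]
                    if length % 2 == 1 and i == new_length - 1:
                        # Odd length: pair the last element with slot 0. The kernel
                        # also reads slot 2*i + 1 first (past the live data) and then
                        # overwrites b; the emulation skips that stray read.
                        b, idx_b = values[0], indices[0]
                    else:
                        b, idx_b = values[2 * i + 1], indices[2 * i + 1]
                    # Same comparison as the kernel: ties keep b.
                    winner = (a, idx_a) if a > b else (b, idx_b)
                barrier.wait()  # all reads of this batch finish before any write
                if winner is not None:
                    values[i], indices[i] = winner
                barrier.wait()  # all writes land before the next batch reads
            length = new_length

    threads = [threading.Thread(target=body, args=(t,)) for t in range(n_actual_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return indices[0]


print(emulate_argmax2([3, 1, 4, 1, 5], 2))  # → 4
```

Like the kernel, the emulation keeps `b` when `a == b`, so with duplicated maxima it can report a later index than a first-occurrence argmax such as `numpy.argmax`.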