Why does this implementation of argmax with Numba CUDA return the wrong result 0.01% of the time?

I’ve made a simple kernel implementing argmax using Numba (code below). It returns the correct result about 99.99% of the time, and a wrong one about 0.01% of the time.
I’ve tried to debug it, and the closest I got is finding that sometimes what should be the max value gets overwritten by a thread that definitely should not be writing anything at that moment. That would be explainable if syncthreads didn’t work, but I think I’m using it correctly: I call cuda.syncthreads() before and after every read and write to shared memory, and none of those calls sit inside conditionals that could resolve differently for different threads. What am I missing?

```python
from numba import cuda
import math

@cuda.jit
def argmax2(arr, out):
    # Single-block argmax: each real thread handles several "virtual threads"
    n_actual_threads = cuda.blockDim.x
    max_virtual_threads = len(arr)
    # Both scratchpads share one dynamic shared-memory allocation; float32 and
    # int32 are both 4 bytes, so the index array starts right after the values
    argmax_value_scratchpad = cuda.shared.array(0, dtype='float32')[:max_virtual_threads]
    argmax_index_scratchpad = cuda.shared.array(0, dtype='int32')[max_virtual_threads:]
    n_virtual_threads = max_virtual_threads
    # Copy the input (and its indices) into shared memory
    for thread_batch in range(math.ceil(n_virtual_threads / n_actual_threads)):
        thread_idx = thread_batch * n_actual_threads + cuda.threadIdx.x
        if thread_idx < n_virtual_threads:
            argmax_value_scratchpad[thread_idx] = arr[thread_idx]
            argmax_index_scratchpad[thread_idx] = thread_idx
    cuda.syncthreads()
    scratchpad_length = len(arr)
    while scratchpad_length > 1:
        new_length = math.ceil(scratchpad_length / 2)
        n_virtual_threads = new_length
        for thread_batch in range(math.ceil(n_virtual_threads / n_actual_threads)):
            thread_idx = thread_batch * n_actual_threads + cuda.threadIdx.x
            winning_value = 0.0
            winning_index = 0
            if thread_idx < n_virtual_threads:
                a = argmax_value_scratchpad[thread_idx*2]
                idx_a = argmax_index_scratchpad[thread_idx*2]
                if scratchpad_length % 2 == 1 and thread_idx == new_length - 1:
                    # special case if the length is odd: ceil already added one more thread for it,
                    # and a is already set correctly, but b is missing,
                    # so we compare against the first value in the array instead
                    b = argmax_value_scratchpad[0]
                    idx_b = argmax_index_scratchpad[0]
                else:
                    b = argmax_value_scratchpad[thread_idx*2+1]
                    idx_b = argmax_index_scratchpad[thread_idx*2+1]
                if a > b:
                    winning_value = a
                    winning_index = idx_a
                else:
                    winning_value = b
                    winning_index = idx_b
            # every read in this batch must finish before anyone overwrites shared memory
            cuda.syncthreads()
            if thread_idx < new_length:
                argmax_value_scratchpad[thread_idx] = winning_value
                argmax_index_scratchpad[thread_idx] = winning_index
            # the writes must be visible before the next batch or level reads them
            cuda.syncthreads()
        scratchpad_length = new_length
    if cuda.threadIdx.x == 0:
        out[0] = argmax_index_scratchpad[0]
```
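
One way to check the reduction logic without a GPU is a sequential model of the same pairwise tree. The sketch below is a hypothetical helper (simulate_argmax2 is not part of Numba). It assumes each level fits in a single batch, so every pair is read from the previous level before the new level is stored, which is what the two cuda.syncthreads() calls enforce on the device. It keeps the kernel's a > b comparison, so on a tie the b candidate wins.

```python
import math

def simulate_argmax2(arr):
    """Sequential model of the argmax2 tree reduction (no GPU needed).

    Each while-iteration is one level of the tree. All pairs are read from
    the previous level before the new level is stored, mirroring the
    read / syncthreads / write / syncthreads pattern in the kernel.
    """
    values = list(arr)
    indices = list(range(len(arr)))
    length = len(values)
    while length > 1:
        new_length = math.ceil(length / 2)
        new_values, new_indices = [], []
        for t in range(new_length):  # t plays the role of thread_idx
            a, idx_a = values[2 * t], indices[2 * t]
            if length % 2 == 1 and t == new_length - 1:
                # odd length: the unpaired last element is compared with element 0
                b, idx_b = values[0], indices[0]
            else:
                b, idx_b = values[2 * t + 1], indices[2 * t + 1]
            # same comparison as the kernel: on a tie, b wins
            if a > b:
                new_values.append(a)
                new_indices.append(idx_a)
            else:
                new_values.append(b)
                new_indices.append(idx_b)
        values, indices, length = new_values, new_indices, new_length
    return indices[0]

print(simulate_argmax2([1.0, 3.0, 2.0]))  # 1
print(simulate_argmax2([1.0, 3.0, 3.0]))  # 2, the later of the two tied maxima
```

With distinct values the model always lands on the true maximum. With duplicated maxima it can return a different (still correct) index than a first-occurrence argmax would.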

I guess your intent is to run this code with only a single threadblock? It would probably help if you gave a complete example showing the kernel launch and the data set size you are passing, and indicated how to reproduce a failure.

Yeah, actually, you saying that pointed me in the right direction. The code works 100% of the time; it’s my testing method that fails 0.01% of the time. My sample size was big enough that sometimes two or more random floats had exactly the same maximum value, so the two argmax implementations I was comparing gave two different, but equally correct, answers. Sorry to have wasted space on this forum for something so obvious. You can lock/delete this thread.
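
A tie-aware test avoids this kind of false failure: instead of requiring the returned index to equal the reference index, check that the value at the returned index equals the maximum. If the reference is something like numpy.argmax, which returns the first occurrence of the maximum, another option is to make the kernel break ties the same way, for example by changing `if a > b:` to `if a > b or (a == b and idx_a < idx_b):`. The sketch below shows both ideas in plain Python; is_valid_argmax and pick are hypothetical helpers, and it assumes the data contains no NaNs (with NaNs the comparison is no longer a total order).

```python
from functools import reduce

def is_valid_argmax(arr, idx):
    """Tie-aware check: any index holding the maximum value is a correct argmax."""
    return 0 <= idx < len(arr) and arr[idx] == max(arr)

def pick(a, b):
    """Combine two (value, index) candidates, breaking ties toward the lower index."""
    (val_a, idx_a), (val_b, idx_b) = a, b
    if val_a > val_b or (val_a == val_b and idx_a < idx_b):
        return a
    return b

arr = [1.0, 3.0, 3.0, 2.0]
candidates = [(v, i) for i, v in enumerate(arr)]

# Both tied positions pass the tie-aware check; a non-maximal index fails
print(is_valid_argmax(arr, 1), is_valid_argmax(arr, 2), is_valid_argmax(arr, 0))  # True True False

# With the tie-break, the combine order no longer changes the answer
print(reduce(pick, candidates)[1], reduce(pick, reversed(candidates))[1])  # 1 1
```

Because the tie-breaking comparison orders every (value, index) pair, the result should not depend on the order in which candidates are combined, including the kernel’s pairwise tree.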