Question about the final exercise in "Multidimensional Grids and Shared Memory for CUDA Python with Numba"

The final exercise requires that I implement matrix multiplication with CUDA. I ran the code below and got:
“Your code did not produce the correct output. +0 pts”
I’m not sure why, because it seems to run without errors. I’ve already consulted

https://forums.developer.nvidia.com/t/need-help-in-implementing-matrix-multiplication-using-shared-memory-in-numba/111461

to no avail. I’m not sure how to approach this next; I thought I had done everything as Robert described there.

Here is my code:

```python
import numpy as np
from numba import cuda, types

# Leave the values in this cell alone
M = 128
N = 32

# Input vectors of MxN and NxM dimensions
a = np.arange(M*N).reshape(M,N).astype(np.int32)
b = np.arange(M*N).reshape(N,M).astype(np.int32)
c = np.zeros((M, M)).astype(np.int32)

d_a = cuda.to_device(a)
d_b = cuda.to_device(b)
d_c = cuda.to_device(c)

# NxN threads per block, in 2 dimensions
block_size = (N,N)

# MxM/NxN blocks per grid, in 2 dimensions
grid_size = (int(M/N),int(M/N))

@cuda.jit
def mm_shared(A, B, C):
    tpb = N
    sA = cuda.shared.array(shape=(tpb, tpb), dtype=types.float32)
    sB = cuda.shared.array(shape=(tpb, tpb), dtype=types.float32)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bx = cuda.blockIdx.x
    by = cuda.blockIdx.y
    bw = cuda.blockDim.x
    bh = cuda.blockDim.y

    bpg = cuda.gridDim.x

    o = bpg * tpb
    x = tx + bx * bw
    y = ty + by * bh

    acc = 0.
    for i in range(bpg):
        if x < o and y < o:
            sA[ty, tx] = A[y, tx + i * tpb]
            sB[ty, tx] = B[ty + i * tpb, x]

        cuda.syncthreads()

        if x < o and y < o:
            for j in range(tpb):
                acc += sA[ty, j] * sB[j, tx]

        cuda.syncthreads()

    if x < o and y < o:
        C[y, x] = acc

# There’s no need to update this kernel launch
mm_shared[grid_size, block_size](d_a, d_b, d_c)

# Do not modify the contents in this cell
from numpy import testing
solution = a@b
output = d_c.copy_to_host()

# This assertion will fail until you correctly update the kernel above.
testing.assert_array_equal(output, solution)
```

(I’m sorry, I don’t know how to do the formatting correctly.)