Confusion about bank conflicts in tiled matrix multiplication

Hi, I am running a simple tiled matrix multiplication code on a V100:

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#define TILE_WIDTH 32
__global__ void mm_2(float* A, float* B, float* C, int N){
    // row and column of the output C matrix that this thread computes
    int cRow = threadIdx.y + blockDim.y * blockIdx.y;
    int cCol = threadIdx.x + blockDim.x * blockIdx.x;
    // row and column in shared memory that this thread writes to
    int sRow = threadIdx.y;
    int sCol = threadIdx.x;
    // row and column that this thread loads from A and B into shared memory
    int gRow_A = cRow;
    int gCol_A;
    int gRow_B;
    int gCol_B = cCol;
    // at each tile we load a 32x32 submatrix
    __shared__ float sA[TILE_WIDTH*TILE_WIDTH];
    __shared__ float sB[TILE_WIDTH*TILE_WIDTH];

    float sum = 0;
    for (int kTile=0; kTile < N/TILE_WIDTH; kTile++){
        // load the current tiles of A and B into shared memory
        gCol_A = kTile*TILE_WIDTH + threadIdx.x;
        gRow_B = kTile*TILE_WIDTH + threadIdx.y;
        sA[sRow * TILE_WIDTH + sCol] = A[gRow_A * N + gCol_A];
        sB[sRow * TILE_WIDTH + sCol] = B[gRow_B * N + gCol_B];
        __syncthreads();
        // accumulate the partial dot product from shared memory
        for (int i=0; i<TILE_WIDTH; i++){
            sum += sA[sRow*TILE_WIDTH + i] * sB[i*TILE_WIDTH + sCol];
        }
        __syncthreads();
    }

    C[cRow*N + cCol] = sum;
}

int main(){
    int N = 1024;

    dim3 dimGrid(32, 32);
    dim3 dimBlock(32, 32);
    //
    // ...code to initialize N x N matrices A, B, so C is also N x N
    //
    mm_2<<<dimGrid, dimBlock>>>(thrust::raw_pointer_cast(dA.data()), thrust::raw_pointer_cast(dB.data()),
                                    thrust::raw_pointer_cast(dC.data()), N);
    hC = dC;

    return 0;
}

We have 1024x1024 matrices, and each thread computes a single entry of C. At each tile iteration we load a 32x32 submatrix of A and B into shared memory.

Shared store:

Relevant code:

// all threads in each warp will have the same sRow since we have dimBlock(32, 32)
int sRow = threadIdx.y;
int sCol = threadIdx.x;
sA[sRow * TILE_WIDTH + sCol] = A[gRow_A * N + gCol_A];
sB[sRow * TILE_WIDTH + sCol] = B[gRow_B * N + gCol_B];

From the above code, since sRow is the same for all threads in a warp, each warp writes 32 floats to 32 contiguous 4-byte addresses in shared memory (logically, each warp writes a single row of the tile, since TILE_WIDTH = 32), so the 32 words fall in 32 distinct banks → bank-conflict-free

Shared load:

Relevant code:

sum += sA[sRow*TILE_WIDTH + i] * sB[i*TILE_WIDTH + sCol];

For sA, all threads in a warp have the same sRow (and the same i), so they read the same address → warp broadcast → bank-conflict-free
For sB, by the same argument used for the shared store, the access is also bank-conflict-free

However, when I profile the kernel, the profiler reports many shared-memory bank conflicts.

Where is my explanation wrong? Thanks!

The metric in question is not necessarily reliable in all cases; various other questions on these forums indicate this, here is one. If you have difficulty assessing bank conflicts with the profiler in a case that should not generate any, my suggestion is to ask that question on the Nsight Compute forum.
