Hi, I am running a simple tiled matrix multiplication code on a V100:
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#define TILE_WIDTH 32

__global__ void mm_2(float* A, float* B, float* C, int N){
    // row and column of the output entry in C that this thread computes
    int cRow = threadIdx.y + blockDim.y * blockIdx.y;
    int cCol = threadIdx.x + blockDim.x * blockIdx.x;
    // row and column in shared memory that this thread writes to
    int sRow = threadIdx.y;
    int sCol = threadIdx.x;
    // rows and columns that this thread loads from A and B into shared memory
    int gRow_A = cRow;
    int gCol_A;
    int gRow_B;
    int gCol_B = cCol;
    // at each tile step we stage a 32x32 submatrix of A and of B
    __shared__ float sA[TILE_WIDTH*TILE_WIDTH];
    __shared__ float sB[TILE_WIDTH*TILE_WIDTH];
    float sum = 0;
    for (int kTile = 0; kTile < N/TILE_WIDTH; kTile++){
        gCol_A = kTile*TILE_WIDTH + threadIdx.x;
        gRow_B = kTile*TILE_WIDTH + threadIdx.y;
        // load the current tiles into shared memory
        sA[sRow * TILE_WIDTH + sCol] = A[gRow_A * N + gCol_A];
        sB[sRow * TILE_WIDTH + sCol] = B[gRow_B * N + gCol_B];
        __syncthreads();
        // read from shared memory and accumulate the partial dot product
        for (int i = 0; i < TILE_WIDTH; i++){
            sum += sA[sRow*TILE_WIDTH + i] * sB[i*TILE_WIDTH + sCol];
        }
        __syncthreads();
    }
    C[cRow*N + cCol] = sum;
}
int main(){
    int N = 1024;
    dim3 dimGrid(32, 32);
    dim3 dimBlock(32, 32);
    //
    // ...code to declare the thrust vectors dA, dB, dC (device) and hC (host),
    //    each of size N*N, and to initialize the N x N matrices A and B
    //
    mm_2<<<dimGrid, dimBlock>>>(thrust::raw_pointer_cast(dA.data()),
                                thrust::raw_pointer_cast(dB.data()),
                                thrust::raw_pointer_cast(dC.data()), N);
    hC = dC; // copy the result back to the host
    return 0;
}
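For completeness, the elided setup inside main() could look roughly like this (my reconstruction, not the original code: the vector names just match the launch call above, the fill values are arbitrary test data, and it needs #include <thrust/fill.h> in addition to the headers above):

    thrust::device_vector<float> dA(N * N), dB(N * N), dC(N * N); // device operands and result
    thrust::host_vector<float>   hC(N * N);                       // host-side copy of C
    thrust::fill(dA.begin(), dA.end(), 1.0f); // arbitrary test values for A
    thrust::fill(dB.begin(), dB.end(), 2.0f); // arbitrary test values for B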
We have 1024x1024 matrices, and each thread computes a single entry of C. At each tile step we load a 32x32 submatrix of each of A and B into shared memory.
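As a quick sanity check of those numbers (nothing new here, just the arithmetic from the post spelled out as a standalone host program):

#include <stdio.h>
#define TILE_WIDTH 32

int main(){
    int N = 1024;
    int kTiles = N / TILE_WIDTH;                                  // 32 iterations of the k-loop
    size_t smem = 2UL * TILE_WIDTH * TILE_WIDTH * sizeof(float);  // sA + sB = 8192 bytes per block
    printf("%d tile steps, %zu bytes of shared memory per block\n", kTiles, smem);
    return 0;
}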
Shared store:
Relevant code:
// all threads in a warp have the same sRow: with dimBlock(32, 32), threads are
// linearized x-fastest, so each warp is one full row of the block
// (fixed threadIdx.y, threadIdx.x = 0..31)
int sRow = threadIdx.y;
int sCol = threadIdx.x;
sA[sRow * TILE_WIDTH + sCol] = A[gRow_A * N + gCol_A];
sB[sRow * TILE_WIDTH + sCol] = B[gRow_B * N + gCol_B];
From the code above: since sRow is the same for every thread in a warp and sCol = threadIdx.x takes each value 0..31 exactly once across the warp, each warp writes 32 floats to 32 contiguous 4-byte addresses in shared memory (logically, each warp stores one row of the tile, since TILE_WIDTH = 32). Thirty-two consecutive 4-byte words map to 32 distinct banks → bank-conflict-free.
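To make that explicit, here is a tiny host-side sketch of the bank math I am assuming (32 banks of 4-byte words, which is the documented V100 layout; bankOf is just my illustrative helper, not a CUDA API):

#include <stdio.h>
#define TILE_WIDTH 32

// bank of a 4-byte shared-memory word, assuming 32 banks (V100)
int bankOf(int wordIdx) { return wordIdx % 32; }

int main(){
    int sRow = 5; // any fixed row; within a warp all lanes share the same sRow
    for (int lane = 0; lane < 32; lane++)
        printf("lane %2d stores to bank %2d\n", lane, bankOf(sRow * TILE_WIDTH + lane));
    // prints banks 0..31 exactly once each -> conflict-free store
    return 0;
}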
Shared load:
Relevant code:
sum += sA[sRow*TILE_WIDTH + i] * sB[i*TILE_WIDTH + sCol];
For sA, all threads in a warp have the same sRow, so within a given i they all read the same word sA[sRow*TILE_WIDTH + i] → warp broadcast → bank-conflict-free
For sB, lane l reads sB[i*TILE_WIDTH + l]; by the same argument as for the shared store, the warp touches 32 consecutive words and hence 32 distinct banks → bank-conflict-free
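The same toy model applied to the two loads in one k-iteration (again just my illustrative arithmetic, with arbitrary fixed values of i and sRow standing in for one warp at one loop step):

#include <stdio.h>
#define TILE_WIDTH 32

int bankOf(int wordIdx) { return wordIdx % 32; }

int main(){
    int i = 7, sRow = 5; // arbitrary fixed values for one warp at one k-iteration
    for (int lane = 0; lane < 32; lane++){
        int sA_word = sRow * TILE_WIDTH + i;          // identical for all lanes -> broadcast
        int sB_bank = bankOf(i * TILE_WIDTH + lane);  // == lane -> 32 distinct banks
        printf("lane %2d: sA word %d (broadcast), sB bank %2d\n", lane, sA_word, sB_bank);
    }
    return 0;
}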
However, when I profile the kernel, the profiler reports a large number of shared-memory bank conflicts.
Where is my reasoning wrong? Thanks!
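For reference, this is roughly how I am collecting the numbers (Nsight Compute; I am quoting the metric names from memory, so please double-check them against ncu --query-metrics):

ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum ./mm_2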