Hi,
I wrote this kernel to multiply a stack of 3x3 matrices together. It's used to combine left and right eigenvectors with eigenvalues: each output matrix is the product of one matrix, a diagonal matrix, and a second matrix. Each thread loads one element of each input matrix and computes one element of the output. I felt pretty clever combining this into one operation, but I am not getting the performance I had hoped for. I have written very similar kernels that run 2 to 3 times faster and can’t figure out what is different about this one.
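A host-side reference makes the intended product explicit and gives something to check the GPU output against. This is a sketch: the name `stackMatMulRef` is hypothetical, and the storage convention (element (i, j) of matrix m at `9*m + i + 3*j`, eigenvalue k of set m at `diag[3*m + k]`) is read off the kernel's index arithmetic rather than stated in the post.

```cuda
// Host-only reference (plain C++, also compiles with nvcc).
// Computes C_m = A_m * diag(d_m) * B_m for every matrix set m.
// Assumed storage: element (row i, col j) of set m at 9*m + i + 3*j,
// eigenvalue k of set m at diag[3*m + k].
#include <cassert>

void stackMatMulRef(const float *a, const float *b, float *c,
                    const float *diag, int matNum)
{
    assert(matNum >= 0);
    for (int m = 0; m < matNum; ++m) {
        for (int i = 0; i < 3; ++i) {          // row of C (threadIdx.x in the kernel)
            for (int j = 0; j < 3; ++j) {      // column of C (threadIdx.y in the kernel)
                float sum = 0.0f;
                for (int k = 0; k < 3; ++k)
                    sum += a[9*m + i + 3*k] * diag[3*m + k] * b[9*m + k + 3*j];
                c[9*m + i + 3*j] = sum;
            }
        }
    }
}
```

Comparing the kernel's output against this on a few random sets is a cheap way to confirm that any rearrangement of shared memory leaves the result unchanged.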
This kernel is not able to fully occupy each SM, but the kernels I am comparing it to have the same occupancy, so that is not the problem. I suspect that the way I handle shared memory is causing bank conflicts that slow down every shared-memory access. Unfortunately, I am not too familiar with what causes a bank conflict or how to avoid one. Does anyone have suggestions on how to improve the speed of this kernel?
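Shared memory is split into banks, with successive 4-byte words assigned to successive banks. When threads of one warp access different words in the same bank, those accesses are serialized; threads reading the same word get a broadcast instead. The host-only sketch below counts, for one warp of a (3,3,32) block, the largest number of distinct words any single bank must serve. It assumes 32 banks (compute capability 2.x and later); `conflictDegree` and the offset helpers are hypothetical names, and the flat layout `x + 3*y + 9*z` is a proposed alternative, not part of the posted code.

```cuda
// Host-only sketch (plain C++, also compiles with nvcc): counts shared-memory
// bank conflicts for one warp of a (3,3,32) block.
// Assumes 32 banks of 4-byte words (compute capability 2.x and later).
#include <cassert>
#include <set>
#include <vector>

const int kNumBanks = 32;

typedef int (*OffsetFn)(int x, int y, int z);  // float index into a __shared__ array

// Worst-case number of distinct words a single bank must serve for warp `warp`.
// 1 means conflict-free; threads touching the same word count once (broadcast).
int conflictDegree(OffsetFn offset, int warp = 0)
{
    assert(warp >= 0 && warp < 9);               // 288 threads = 9 warps
    std::vector<std::set<int> > wordsPerBank(kNumBanks);
    for (int lane = 0; lane < 32; ++lane) {
        int t = 32 * warp + lane;                // linear thread id = x + 3*y + 9*z
        int x = t % 3, y = (t / 3) % 3, z = t / 9;
        int word = offset(x, y, z);
        wordsPerBank[word % kNumBanks].insert(word);
    }
    int worst = 0;
    for (int b = 0; b < kNumBanks; ++b)
        if ((int)wordsPerBank[b].size() > worst)
            worst = (int)wordsPerBank[b].size();
    return worst;
}

// Posted layout, float As[3][3][32]: As[ix][iy][iz] is word 96*ix + 32*iy + iz.
int origStore(int x, int y, int z) { return 96 * x + 32 * y + z; }  // As[ix_i][iy_i][iz_i] = ...
int origReadA(int x, int, int z)   { return 96 * x + z; }           // As[ix_o][0][iz_o]
int origReadB(int, int y, int z)   { return 32 * y + z; }           // Bs[0][iy_o][iz_o]

// Proposed flat layout, float As[9*32]: element (x, y) of set z is word x + 3*y + 9*z.
int flatStore(int x, int y, int z) { return x + 3 * y + 9 * z; }
int flatReadA(int x, int, int z)   { return x + 9 * z; }            // k = 0 term
int flatReadB(int, int y, int z)   { return 3 * y + 9 * z; }        // k = 0 term
```

For warp 0, the posted store pattern comes out 9-way (iz is the innermost index and 96 and 32 are multiples of 32, so the nine threads of one matrix set share a bank) and the two inner-loop reads come out 3-way, while all three flat-layout patterns come out 1. The k = 1 and k = 2 terms only shift every word by a constant, so they give the same degrees.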
[codebox]// ---------------------- Modified StackMatMul Kernel --------------------------
// Kernel performs matrix multiplication for a stack of 3x3 matrices
// StackMatMul kernel modified to multiply three matrices, one of which is a diagonal
// Each block calculates product for 32 matrix sets
// Block dimensions are (3,3,32) with 288 threads/block
__global__ void StackMatMul(float *a, float *b, float *c, float *diag, int MatNum)
{
    // Shared memory for the 32 matrix sets handled by this block
    __shared__ float As[3][3][32];
    __shared__ float Bs[3][3][32];
    __shared__ float diag_s[3][32];
    int ix_i = threadIdx.x;
    int iy_i = threadIdx.y;
    int iz_i = threadIdx.z;
    int idd_i = blockIdx.x*blockDim.x*blockDim.z + ix_i + 3*iz_i;
    int idx_i = blockIdx.x*blockDim.x*blockDim.y*blockDim.z + (ix_i + 3*iy_i + 9*iz_i);
    // Load a and b from global to shared memory on block
    // Each thread loads one element; threads past the last matrix skip the load
    if(idx_i < 9*MatNum) {
        As[ix_i][iy_i][iz_i] = a[idx_i];
        Bs[ix_i][iy_i][iz_i] = b[idx_i];
        diag_s[ix_i][iz_i] = diag[idd_i];
    }
    __syncthreads(); // Every thread must reach this, so no early return above it
    int ix_o = threadIdx.x;
    int iy_o = threadIdx.y;
    int iz_o = threadIdx.z;
    int idx_o = blockIdx.x*blockDim.x*blockDim.y*blockDim.z + (ix_o + 3*iy_o + 9*iz_o);
    // Calculate one output element per thread and write to global memory space;
    // the check must come before the store to prevent overshooting allocated memory
    if(idx_o < 9*MatNum)
        c[idx_o] = As[ix_o][0][iz_o]*diag_s[0][iz_o]*Bs[0][iy_o][iz_o]
                 + As[ix_o][1][iz_o]*diag_s[1][iz_o]*Bs[1][iy_o][iz_o]
                 + As[ix_o][2][iz_o]*diag_s[2][iz_o]*Bs[2][iy_o][iz_o];
} // End StackMatMul
// -----------------------------------------------------------------------------[/codebox]
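One way to remove the conflicts is to index shared memory with the same thread-linear offset used for global memory, so that the per-set index iz is no longer innermost. In the [3][3][32] arrays, iz is innermost and every offset 96*ix + 32*iy is a multiple of 32, so (assuming 32 four-byte banks) the nine threads of one matrix set all land in the same bank. The sketch below keeps the kernel's launch shape and storage convention (blockDim (3,3,32), one block per 32 sets, element (i, j) of matrix m at `9*m + i + 3*j`). It also puts the bounds check ahead of both the global loads and the store, since a check placed after the store cannot prevent an out-of-range write, and no thread may return before `__syncthreads()`. `StackMatMulFlat`, `launchStackMatMulFlat` and `SETS_PER_BLOCK` are hypothetical names, and whether this closes the 2 to 3x gap has not been measured.

```cuda
// Sketch: StackMatMul with thread-linear shared-memory layouts.
// Assumptions (from the posted kernel): blockDim = (3,3,32), one block per
// 32 matrix sets, element (i,j) of matrix m at 9*m + i + 3*j, eigenvalue k of
// set m at diag[3*m + k]. All names below are hypothetical.
#include <cassert>
#include <cuda_runtime.h>

#define SETS_PER_BLOCK 32   // matrix sets per block (blockDim.z)

__global__ void StackMatMulFlat(const float *a, const float *b, float *c,
                                const float *diag, int MatNum)
{
    // Same linear order as global memory: element (ix,iy) of set iz at
    // ix + 3*iy + 9*iz, so the nine threads of one set no longer share a bank.
    __shared__ float As[9 * SETS_PER_BLOCK];
    __shared__ float Bs[9 * SETS_PER_BLOCK];
    __shared__ float diag_s[3 * SETS_PER_BLOCK];

    int ix = threadIdx.x, iy = threadIdx.y, iz = threadIdx.z;
    int t   = ix + 3 * iy + 9 * iz;                 // 0..287 within the block
    int idx = blockIdx.x * 9 * SETS_PER_BLOCK + t;  // element index in a, b, c
    int mat = blockIdx.x * SETS_PER_BLOCK + iz;     // global matrix-set number
    bool valid = mat < MatNum;                      // same as idx < 9*MatNum

    if (valid) {
        As[t] = a[idx];
        Bs[t] = b[idx];
        if (iy == 0)                                // one load per eigenvalue
            diag_s[3 * iz + ix] = diag[3 * mat + ix];
    }
    __syncthreads();  // every thread reaches this; no early return above it

    if (valid) {
        const float *A = &As[9 * iz];               // this thread's matrix set
        const float *B = &Bs[9 * iz];
        const float *d = &diag_s[3 * iz];
        // C(ix,iy) = sum_k A(ix,k) * d[k] * B(k,iy), with (i,j) stored at i + 3*j
        c[idx] = A[ix]     * d[0] * B[3 * iy]
               + A[ix + 3] * d[1] * B[3 * iy + 1]
               + A[ix + 6] * d[2] * B[3 * iy + 2];
    }
}

// Hypothetical host wrapper: one (3,3,32) block per 32 matrix sets.
cudaError_t launchStackMatMulFlat(const float *a, const float *b, float *c,
                                  const float *diag, int MatNum)
{
    assert(MatNum > 0);  // a grid of zero blocks is an invalid launch
    dim3 block(3, 3, SETS_PER_BLOCK);
    dim3 grid((MatNum + SETS_PER_BLOCK - 1) / SETS_PER_BLOCK);
    StackMatMulFlat<<<grid, block>>>(a, b, c, diag, MatNum);
    return cudaGetLastError();
}
```

Only the shared-memory indexing and the bounds handling differ from the posted kernel; its global accesses were already coalesced, since consecutive threads read and write consecutive elements of `a`, `b` and `c`.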