After taking on board a few pointers from vvolkov (thanks!), starting from scratch and adding complexity where necessary, I’ve so far reduced execution time by about 30%.

I’m struggling to narrow down what the bottleneck is after this point. If I comment out the atomic write — replacing it with a coalesced dense write — it’s only about 10% faster, which suggests I’m missing something crucial about the memory transactions that come before it. Will I be facing a large performance hit from the big read stride in each warp? (Each warp reads column (a), then column (a+4096), then column (a+8192).)

```
// Transposed ELL sparse-matrix x dense product kernel (presumably — inferred
// from the trailing "T" and the host launcher's `trans` flag; confirm).
// Launched with blockDim = (32, 8), so threadIdx.x is the lane within a warp
// and threadIdx.y selects one of 8 warps per block.
// __launch_bounds__(256, 8): 256 threads per block, target 8 resident blocks
// per SM (constrains register allocation).
template<typename T>
__launch_bounds__(256, 8)
__global__ void ELLgemmT(T* result, const unsigned int* __restrict__ index,
const T* __restrict__ vals, const T* __restrict__ dense, unsigned int length)
{
int lane = threadIdx.x; // lane id within the warp (0..31, given the 32x8 launch)
int warp = threadIdx.y; // warp id within the block (0..7)
int off1 = blockIdx.x * 256 + warp * 32; // this warp's starting element offset
int off2 = gridDim.x * 256; // grid-wide stride between chunks this warp processes
//advance pointers
index += off1;
vals += off1;
dense += off1;
int read = off1/32; // chunk counter, in units of 32-element warp-wide reads
T calc[32]={0.0f}; // per-lane accumulators, one per broadcast source lane
// Prime the pipeline: load the first 32-element chunk into registers.
int index_reg = index[lane];
T vals_reg = vals[lane];
T dense_reg = dense[lane];
// Register double buffer: the next chunk is prefetched into these while the
// current chunk is consumed, to hide global-load latency.
int index_reg_buf = 0;
T vals_reg_buf = 0.0f;
T dense_reg_buf = 0.0f;
while (read<length)
{
// NOTE(review): `read` advances by off2/32 per iteration, but this guard
// tests read+1 — confirm that `length`'s units account for the grid
// stride, otherwise the prefetch below can read past the end of the
// arrays whenever off2/32 > 1.
if (read+1<length)
{
// Prefetch this warp's next chunk.
index_reg_buf = index[lane + off2];
vals_reg_buf = vals[lane + off2];
dense_reg_buf = dense[lane + off2];
}
else
{
// Sentinel for the final chunk: -1 should never match a real column
// index loaded from the unsigned `index` array, so the mismatch test
// below is forced true and the accumulators are flushed — verify no
// legitimate index maps to -1 when truncated to int.
index_reg_buf = -1;
}
for (int i=0; i<32; i++)
{
// __shfl broadcasts lane i's register to every lane of the warp.
// NOTE(review): __shfl without a mask is the pre-CUDA-9 intrinsic;
// CUDA 9+ requires __shfl_sync(0xffffffff, ...).
T use_vals = __shfl(vals_reg,i);
int use_index = __shfl(index_reg,i);
int use_index_buf = __shfl(index_reg_buf,i);
calc[i]+=use_vals * dense_reg;
// When lane i's column index differs between the current chunk and the
// prefetched one, slot i's partial product is complete — write it out
// and reset the accumulator. The dense write below is an experiment
// standing in for the intended atomicAdd (kept commented).
if (use_index!=use_index_buf)
{
result[lane+off1]=calc[i]; //this isn't any faster
//atomicAdd(result + lane + 32 * use_index,calc[i]);
calc[i]=0;
}
}
// Rotate the double buffer and step to this warp's next chunk.
index_reg = index_reg_buf;
vals_reg = vals_reg_buf;
dense_reg = dense_reg_buf;
index += off2;
dense += off2;
vals += off2;
read+= off2/32;
}
}
// Host-side launcher for the ELL sparse x dense product.
//
// NOTE(review): in the original, BOTH branches of `if (trans)` launched the
// identical kernel (ELLgemmT), so the conditional was dead code; the
// duplicated branch is collapsed here with no behavior change. The
// non-transposed variant still needs its own kernel — reintroduce the branch
// once it exists.
//
// result : output buffer (device pointer)
// index  : ELL column-index array (device pointer)
// vals   : ELL value array (device pointer)
// dense  : dense input matrix (device pointer)
// length : loop bound passed through to the kernel — units unclear from here,
//          appears to be counted in 32-element chunks; TODO confirm
// trans  : transpose selector — currently has NO effect (see note above);
//          kept so existing callers keep compiling
void ELLgemm(float *result, unsigned int *index, float *vals, float *dense, unsigned int length, bool trans)
{
    dim3 threadsPerBlock(32, 8);        // 32 lanes x 8 warps = 256 threads/block
    int numblocks = BLOCKS_PER_SM * 32; // BLOCKS_PER_SM defined elsewhere; 32 presumably = SM count, confirm
    (void)trans;                        // both paths launch the same kernel today
    ELLgemmT<float><<<numblocks, threadsPerBlock>>>(result, index, vals, dense, length);
}
```