I found a way to get more performance than CUBLAS sgemm when transA == 't' and transB == 't', for matrices whose dimensions are multiples of 16 (m and k) and 64 (n).
Thanks to V. Volkov and P. Leventis for the large-tiling and prefetching techniques.
On an 8800 GTS 512 (G92), for large matrices, it reaches 191 Gflops including all memory transactions (instead of CUBLAS's 130), on Windows XP.
It computes A^T * B^T by computing C^T = B * A, then transposing C^T during the write-back.
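Written out, the trick is just the transpose identity applied to the whole update:

$$C = \alpha\, A^T B^T + \beta\, C \;\Longleftrightarrow\; C^T = \alpha\, B A + \beta\, C^T, \qquad \text{since } (A^T B^T)^T = B A.$$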
To write it back without uncoalesced accesses I used shared memory, but 64×16×4 bytes were too much, so I serialized the write-back, using only 16×16×4 bytes at a time (actually 16×17×4, to avoid bank conflicts).
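The numbers, for reference (the G92 has 16 KB of shared memory per multiprocessor, so the smaller buffer is what keeps several blocks resident):

$$64 \times 16 \times 4\ \mathrm{B} = 4096\ \mathrm{B}, \qquad 16 \times 17 \times 4\ \mathrm{B} = 1088\ \mathrm{B}.$$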
Compiled with --maxrregcount 32, to cap the kernel at 32 registers per thread.
// Y[0..15] += alpha * X[0..15]; fully unrolled so Y stays in registers.
__device__ void scalarAlphaXPlusY16( float alpha, float *X, float *Y )
{
    Y[0] += alpha*X[0];
    Y[1] += alpha*X[1];
    Y[2] += alpha*X[2];
    Y[3] += alpha*X[3];
    Y[4] += alpha*X[4];
    Y[5] += alpha*X[5];
    Y[6] += alpha*X[6];
    Y[7] += alpha*X[7];
    Y[8] += alpha*X[8];
    Y[9] += alpha*X[9];
    Y[10] += alpha*X[10];
    Y[11] += alpha*X[11];
    Y[12] += alpha*X[12];
    Y[13] += alpha*X[13];
    Y[14] += alpha*X[14];
    Y[15] += alpha*X[15];
}
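Equivalently, the unrolling can be left to the compiler; a compact variant for reference (the _loop name is just for illustration, and the hand-unrolled version above is the one that was measured):

__device__ void scalarAlphaXPlusY16_loop( float alpha, float *X, float *Y )
{
    // Same 16 multiply-adds; #pragma unroll asks nvcc to flatten the loop.
    #pragma unroll
    for (int i = 0; i < 16; i++)
        Y[i] += alpha*X[i];
}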
__global__ void fml_sgemmTT_16_64_16( float* A, int lda, float* B, int ldb, float* C, int ldc, int nLoops, float alpha, float beta )
{
    // Each block computes a 16x64 tile of C (a 64x16 tile of C^T):
    // 64 threads (16x4), each accumulating 16 elements of C in registers.
    const int threadx = threadIdx.x;
    const int thready = threadIdx.y;
    const int ibx = blockIdx.x * 64;
    const int iby = blockIdx.y * 16;
    A += threadx + (iby + thready)*lda;
    B += ibx + threadx + thready*16;
    C += (ibx + thready*16)*ldc + iby + threadx;
    float c[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
    while (nLoops)
    {
        // Stage a 16x16 tile of A in shared memory; the extra column
        // of the [16][17] layout avoids bank conflicts.
        __shared__ float AShared[16][17];
        AShared[threadx][thready]    = A[0];
        AShared[threadx][thready+4]  = A[4*lda];
        AShared[threadx][thready+8]  = A[8*lda];
        AShared[threadx][thready+12] = A[12*lda];
        // Four values of B per thread stay in registers and are
        // prefetched so the loads overlap with the multiply-adds below.
        float b[4];
        b[0] = B[0];
        b[1] = B[ldb];
        b[2] = B[2*ldb];
        b[3] = B[3*ldb];
        nLoops--;
        B += 4*ldb;
        __syncthreads();
        scalarAlphaXPlusY16( b[0], AShared[0], c );  b[0] = B[0];
        scalarAlphaXPlusY16( b[1], AShared[1], c );  b[1] = B[ldb];
        scalarAlphaXPlusY16( b[2], AShared[2], c );  b[2] = B[2*ldb];
        scalarAlphaXPlusY16( b[3], AShared[3], c );  b[3] = B[3*ldb];
        B += 4*ldb;
        scalarAlphaXPlusY16( b[0], AShared[4], c );  b[0] = B[0];
        scalarAlphaXPlusY16( b[1], AShared[5], c );  b[1] = B[ldb];
        scalarAlphaXPlusY16( b[2], AShared[6], c );  b[2] = B[2*ldb];
        scalarAlphaXPlusY16( b[3], AShared[7], c );  b[3] = B[3*ldb];
        B += 4*ldb;
        scalarAlphaXPlusY16( b[0], AShared[8], c );  b[0] = B[0];
        scalarAlphaXPlusY16( b[1], AShared[9], c );  b[1] = B[ldb];
        scalarAlphaXPlusY16( b[2], AShared[10], c ); b[2] = B[2*ldb];
        scalarAlphaXPlusY16( b[3], AShared[11], c ); b[3] = B[3*ldb];
        B += 4*ldb;
        scalarAlphaXPlusY16( b[0], AShared[12], c );
        scalarAlphaXPlusY16( b[1], AShared[13], c );
        scalarAlphaXPlusY16( b[2], AShared[14], c );
        scalarAlphaXPlusY16( b[3], AShared[15], c );
        A += 16;                 // next 16x16 tile of A along k
        __syncthreads();         // AShared fully consumed before it is overwritten
    }
    // Write-back: the accumulated tile of C^T has to land in C transposed,
    // and without uncoalesced stores. Each of the four thready groups takes
    // a turn transposing its 16x16 slice through one padded shared buffer.
    __shared__ float CShared[16][17];
#pragma unroll
    for (int j=0; j<4; j++)
    {
        if (thready == j) {
            // The 16 threads of group j are consecutive within one warp,
            // so the transposed read below sees the write without an
            // extra __syncthreads() (warp-synchronous on G80/G92).
#pragma unroll
            for( int i = 0; i < 16; i++ )
                CShared[i][threadx] = alpha*c[i];
#pragma unroll
            for( int i = 0; i < 16; i++, C += ldc )
                CShared[threadx][i] += beta*C[0];
            C -= ldc;   // step back onto the last column just read
        }
        __syncthreads();
        if (thready == j) {
            // Coalesced store, walking the 16 columns back in reverse.
#pragma unroll
            for( int i = 15; i >= 0; i--, C -= ldc )
                C[0] = CShared[threadx][i];
        }
        __syncthreads();   // let the next group reuse CShared
    }
}
void fml_sgemmTT( int m, int n, int k, float alpha, float* A, int lda, float* B, int ldb, float beta, float* C, int ldc )
{
    // Requires m and k to be multiples of 16 and n a multiple of 64.
    dim3 grid( n/64, m/16 ), threads( 16, 4 );
    int nloops = k/16;
    fml_sgemmTT_16_64_16<<<grid, threads>>>( A, lda, B, ldb, C, ldc, nloops, alpha, beta );
}
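To check the result, something like the following naive harness can be used (sizes and fill are just an example; the column-major layout with lda = k, ldb = n, ldc = m matches how the kernel indexes A, B and C):

#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cuda_runtime.h>

int main()
{
    // m and k must be multiples of 16, n a multiple of 64.
    const int m = 256, n = 256, k = 256;
    const float alpha = 1.0f, beta = 0.0f;

    // Column-major: A is k x m (lda = k), B is n x k (ldb = n),
    // C is m x n (ldc = m), so C = alpha*A^T*B^T + beta*C.
    float *hA = (float*)malloc( k*m*sizeof(float) );
    float *hB = (float*)malloc( n*k*sizeof(float) );
    float *hC = (float*)malloc( m*n*sizeof(float) );
    for (int i = 0; i < k*m; i++) hA[i] = rand()/(float)RAND_MAX;
    for (int i = 0; i < n*k; i++) hB[i] = rand()/(float)RAND_MAX;
    for (int i = 0; i < m*n; i++) hC[i] = 0.0f;

    float *dA, *dB, *dC;
    cudaMalloc( (void**)&dA, k*m*sizeof(float) );
    cudaMalloc( (void**)&dB, n*k*sizeof(float) );
    cudaMalloc( (void**)&dC, m*n*sizeof(float) );
    cudaMemcpy( dA, hA, k*m*sizeof(float), cudaMemcpyHostToDevice );
    cudaMemcpy( dB, hB, n*k*sizeof(float), cudaMemcpyHostToDevice );
    cudaMemcpy( dC, hC, m*n*sizeof(float), cudaMemcpyHostToDevice );

    fml_sgemmTT( m, n, k, alpha, dA, k, dB, n, beta, dC, m );
    cudaMemcpy( hC, dC, m*n*sizeof(float), cudaMemcpyDeviceToHost );

    // Naive CPU reference: C[i + j*ldc] = alpha * sum_l A[l + i*lda] * B[j + l*ldb]
    double maxErr = 0.0;
    for (int j = 0; j < n; j++)
        for (int i = 0; i < m; i++) {
            double ref = 0.0;
            for (int l = 0; l < k; l++)
                ref += (double)hA[l + i*k] * hB[j + l*n];
            maxErr = fmax( maxErr, fabs( alpha*ref - hC[i + j*m] ) );
        }
    printf( "max abs error vs CPU reference: %g\n", maxErr );

    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    free(hA); free(hB); free(hC);
    return 0;
}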