Speeding up triple matrix multiply

Hello all!

I am developing a ray caster in which I have to compute a triple matrix product every frame. My equation is of the form:

R = M T N
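For checking a GPU result against the CPU, a minimal host-side sketch of R = M T N might look like the following. It is plain C++ with row-major storage; `Mat`, `matmul`, `trimatmul` and `trimatmul_naive` are hypothetical names, not part of any library. The product is evaluated as M (T N), and a direct double sum is included only to confirm that both orders agree.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Dense row-major matrix: element (r, c) is data[r * cols + c].
struct Mat {
    std::size_t rows, cols;
    std::vector<float> data;
    Mat(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0f) {}
    float& operator()(std::size_t r, std::size_t c) { return data[r * cols + c]; }
    float operator()(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Plain matrix product A * B.
Mat matmul(const Mat& A, const Mat& B) {
    assert(A.cols == B.rows);
    Mat C(A.rows, B.cols);
    for (std::size_t i = 0; i < A.rows; ++i)
        for (std::size_t k = 0; k < A.cols; ++k) {
            const float a = A(i, k);
            for (std::size_t j = 0; j < B.cols; ++j)
                C(i, j) += a * B(k, j);
        }
    return C;
}

// R = M T N evaluated as M (T N): form the small d x n matrix W = T N
// first, then do one product whose inner dimension is d.
Mat trimatmul(const Mat& M, const Mat& T, const Mat& N) {
    return matmul(M, matmul(T, N));
}

// Direct double sum R_ij = sum_k sum_l M_ik T_kl N_lj, used only for checking.
Mat trimatmul_naive(const Mat& M, const Mat& T, const Mat& N) {
    Mat R(M.rows, N.cols);
    for (std::size_t i = 0; i < M.rows; ++i)
        for (std::size_t j = 0; j < N.cols; ++j) {
            float s = 0.0f;
            for (std::size_t k = 0; k < T.rows; ++k)
                for (std::size_t l = 0; l < T.cols; ++l)
                    s += M(i, k) * T(k, l) * N(l, j);
            R(i, j) = s;
        }
    return R;
}
```

For the sizes in this post, M would be `Mat(768, d)`, T `Mat(d, d)` and N `Mat(d, 768)`.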

M is typically (768 x d), T is (d x d) and N is (d x 768), where d is the degree of the surface and is usually below 10. (To be more precise, T is a tensor of size (d x d x d), so it is stored as a tiled float4 texture from which I pull 4 components at a time into shared memory.)
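The float4 tiling of T can be sketched on the host as below. The layout is an assumption read off the kernel's indexing rather than something stated in the post: component c of the texel at column `col`, row `tile*d + row` holds slice `4*tile + c`, and missing slices in the last tile stay zero (the kernel notes that T is zero-filled). `Float4` is a stand-in for CUDA's `float4` so the sketch compiles as plain C++, and `pack_tensor_float4` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side stand-in for CUDA's float4.
struct Float4 { float x, y, z, w; };

// Hypothetical packing: slice s of the d x d x d tensor, stored as
// tensor[(s * d + row) * d + col], goes to component s % 4 of the texel
// at (col, (s / 4) * d + row). The result is a row-major image of width d
// and height ceil(d / 4) * d; unused components stay zero.
std::vector<Float4> pack_tensor_float4(const std::vector<float>& tensor, std::size_t d) {
    assert(tensor.size() == d * d * d);
    const std::size_t tiles = (d + 3) / 4;
    std::vector<Float4> image(tiles * d * d, Float4{0.0f, 0.0f, 0.0f, 0.0f});
    for (std::size_t s = 0; s < d; ++s)
        for (std::size_t row = 0; row < d; ++row)
            for (std::size_t col = 0; col < d; ++col) {
                const float v = tensor[(s * d + row) * d + col];
                Float4& texel = image[((s / 4) * d + row) * d + col];
                switch (s % 4) {
                    case 0: texel.x = v; break;
                    case 1: texel.y = v; break;
                    case 2: texel.z = v; break;
                    default: texel.w = v; break;
                }
            }
    return image;
}
```

With this layout, one fetch at `(col, tile*d + row)` returns the same matrix entry for four consecutive slices, which is what lets the kernel accumulate four results per thread.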

My current implementation (at the bottom of this message) is based on the example in chapter 7 of the CUDA 0.8 Programming Guide. I first form the necessary sub-block(s) of T N, then pull in the necessary block(s) of M and compute R. As far as I understand, the programming guide example is meant to be educational rather than a high-performance implementation. Could anyone shed some light on possible optimizations for doing several matrix multiplies in one go on the GPU?
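One way to frame the question is to count multiply-adds for a single slice of T, taking m = n = 768 and d = 10. Evaluating each output element as a double sum costs m n d², while forming W = T N first and then R = M W costs d² n + m d n:

```latex
\begin{aligned}
R_{ij} &= \sum_{k=1}^{d}\sum_{l=1}^{d} M_{ik}\,T_{kl}\,N_{lj}
  &&\Rightarrow\; m\,n\,d^{2} = 768^{2}\cdot 10^{2} = 58\,982\,400,\\
W &= T N,\quad R = M W
  &&\Rightarrow\; d^{2} n + m\,d\,n = 76\,800 + 5\,898\,240 = 5\,975\,040.
\end{aligned}
```

The two-stage order, which the kernel below already uses, is roughly ten times cheaper here, and the m d n term dominates. In that kernel, however, blocks with the same blockIdx.x but different blockIdx.y compute identical W sub-blocks, so the d² n term is paid once per block row of R rather than once. Computing W = T N in a separate small pass and then running a plain M W product is one possible way to remove that redundancy; whether it pays off at these sizes would need measuring.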

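// These textures hold the matrices M (m x t), N (t x n) and T (t x t);
// T is stored as float4 so that one fetch returns four components.
// (texfetch() is the CUDA 0.8 fetch function; later releases call it tex2D().)
texture<float, 2, cudaReadModeElementType> M;
texture<float, 2, cudaReadModeElementType> N;
texture<float4, 2, cudaReadModeElementType> T;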

__device__ void TriMatMul( float *res, const int inner_blocks,
                           const int t, const int tile ) {
    // Block index
    const int bx = blockIdx.x;
    const int by = blockIdx.y;

    // Thread index
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    // Accumulates four components (one float4) of one element of R
    float4 Csub = make_float4( 0.0f, 0.0f, 0.0f, 0.0f );

    // We allocate shared memory for one float4 block and one float block.
    // First we use them to read in values from T and N and compute one
    // element of W. Once all elements are computed, we read from M into Ms
    // and store the computed values in Ws. Reusing the two arrays this way
    // reduces the amount of shared memory we need. BLOCK_SIZE is #defined
    // elsewhere, as in the programming guide example.
    __shared__ float4 Ws[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float  Ms[BLOCK_SIZE][BLOCK_SIZE];

    for( int j = 0; j < inner_blocks; ++j ) {
        // First we build this block's part of W = T N
        float4 Wsub = make_float4( 0.0f, 0.0f, 0.0f, 0.0f );

        for( int i = 0; i < inner_blocks; ++i ) {
            // Load the matrices from texture memory to shared memory.
            // Each thread loads one element of each matrix. Notice that
            // the row of the T access is offset by tile*t. T is
            // originally filled with zeros, so reading extra elements
            // should not be a problem.
            Ws[ty][tx] = texfetch( T, tx + i*blockDim.x,
                                   ty + j*blockDim.y + tile*t );
            Ms[ty][tx] = texfetch( N, tx + bx*blockDim.x, ty + i*blockDim.y );
            __syncthreads();

            // The block size will seldom match the dimensions of the
            // matrices, so we have to avoid summing the invalid elements.
            const int kend = min( (int)blockDim.x, t - i*(int)blockDim.x );
            for( int k = 0; k < kend; ++k ) {
                Wsub.x += Ws[ty][k].x * Ms[k][tx];
                Wsub.y += Ws[ty][k].y * Ms[k][tx];
                Wsub.z += Ws[ty][k].z * Ms[k][tx];
                Wsub.w += Ws[ty][k].w * Ms[k][tx];
            }

            // Sync before the next iteration overwrites Ws and Ms
            // (and before Ws is reused for W below).
            __syncthreads();
        }

        // Write the element into Ws
        Ws[ty][tx] = Wsub;

        // Then each thread loads one element from the M matrix.
        Ms[ty][tx] = texfetch( M, tx + j*blockDim.x, ty + by*blockDim.y );

        // This __syncthreads() ensures both that Ws is fully computed
        // and that Ms has been filled from global memory.
        __syncthreads();

        const int kend = min( (int)blockDim.x, t - j*(int)blockDim.x );
        for( int k = 0; k < kend; ++k ) {
            Csub.x += Ms[ty][k] * Ws[k][tx].x;
            Csub.y += Ms[ty][k] * Ws[k][tx].y;
            Csub.z += Ms[ty][k] * Ws[k][tx].z;
            Csub.w += Ms[ty][k] * Ws[k][tx].w;
        }

        __syncthreads(); // Sync before we read more into Ms and Ws
    }

    // Write out the result. res has size maxsize, so writing past the
    // valid components here should not be a problem.
    res[tile*4 + 0] = Csub.x;
    res[tile*4 + 1] = Csub.y;
    res[tile*4 + 2] = Csub.z;
    res[tile*4 + 3] = Csub.w;
}
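To sanity-check the kernel's index arithmetic, in particular the two `kend` clippings when t is not a multiple of the block size, a scalar CPU emulation can replay the same loops for one slice of T. This is a sketch under assumptions: `Grid`, `fetch` and `blocked_trimatmul` are hypothetical names, the shared-memory staging is collapsed into direct sums, the `tile*t` offset is left out because only one slice is emulated, and out-of-range fetches return NaN (a real texture fetch would return a valid texel) so that any use of a padding element shows up in the result.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Row-major matrix whose fetch(x, y) takes (column, row) like texfetch.
// Out-of-range reads return NaN purely as a test aid.
struct Grid {
    int rows, cols;
    std::vector<float> v;
    float fetch(int x, int y) const {
        if (x < 0 || y < 0 || x >= cols || y >= rows)
            return std::numeric_limits<float>::quiet_NaN();
        return v[y * cols + x];
    }
};

// Scalar emulation of TriMatMul's loop structure for one slice of T.
// B stands in for BLOCK_SIZE; each (by, bx, ty, tx) plays one thread.
std::vector<float> blocked_trimatmul(const Grid& M, const Grid& T,
                                     const Grid& N, int B) {
    const int m = M.rows, t = T.rows, n = N.cols;
    assert(M.cols == t && T.cols == t && N.rows == t && B > 0);
    const int inner_blocks = (t + B - 1) / B;
    std::vector<float> R(m * n, 0.0f);
    std::vector<float> Ws(B * B);
    for (int by = 0; by * B < m; ++by)
        for (int bx = 0; bx * B < n; ++bx) {
            std::vector<float> Csub(B * B, 0.0f);
            for (int j = 0; j < inner_blocks; ++j) {
                // Stage 1: the B x B block of W = T N at block row j, block column bx.
                for (int ty = 0; ty < B; ++ty)
                    for (int tx = 0; tx < B; ++tx) {
                        float w = 0.0f;
                        for (int i = 0; i < inner_blocks; ++i) {
                            const int kend = std::min(B, t - i * B);
                            for (int k = 0; k < kend; ++k)
                                w += T.fetch(i * B + k, j * B + ty) *
                                     N.fetch(bx * B + tx, i * B + k);
                        }
                        Ws[ty * B + tx] = w;
                    }
                // Stage 2: accumulate the M block times the W block.
                const int kend = std::min(B, t - j * B);
                for (int ty = 0; ty < B; ++ty)
                    for (int tx = 0; tx < B; ++tx)
                        for (int k = 0; k < kend; ++k)
                            Csub[ty * B + tx] +=
                                M.fetch(j * B + k, by * B + ty) * Ws[k * B + tx];
            }
            // Keep only the part of the output block that lies inside R.
            for (int ty = 0; ty < B; ++ty)
                for (int tx = 0; tx < B; ++tx) {
                    const int row = by * B + ty, col = bx * B + tx;
                    if (row < m && col < n) R[row * n + col] = Csub[ty * B + tx];
                }
        }
    return R;
}
```

Running it with a block size of 1, 2 and 16 on a small problem (for example m = 5, t = 3, n = 7) and comparing against a direct double sum exercises the cases where the block size is smaller than, does not divide, and exceeds t.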