Advices needed tensor contractions

Hi,

I need suggestions on how to optimize my code or new implementation ideas to perform tensor contractions.

I need to perform a lot of contractions for a rank-4 tensor T, with each dimension D (D^4 elements). D is usually small ~4-10. and the final answer is another a rank-4 tensor TO:

TO[a,b,c,d]=T[a,l,c,i]*T[a,j,i,d]*T[j,b,k,d]*T[l,b,c,k]

Repeated indices are summed over.

Here is my code for this operation, and would like to ask for suggestions to optimize it or new ways to implement it. Currently I reshape T to a D^2xD^2 matrix. The block_size is set to D, so there will be DxD threads in a 2d block and the grid will be DxD as well. Each thread now computes an element of TO.

For D=4, I can achieve 33% occupancy, and about 1.5x speedup. When I increase D to D=8, the code seems to hang. The CPU benchmark takes less then 1s to run for D=8. I also put T into a 2D texture, and the code is actually slower for D=4.

Any suggestion will be appreciated.

Thank you.

__global__ void tensorCon( float* T, float* TO)

{

// Block index

int bx = blockIdx.x;

int by = blockIdx.y;

// Thread index

int tx = threadIdx.x;

int ty = threadIdx.y;

int alpha = bx / D;

int beta = bx % D;

int gamma = by / D;

int delta = by % D;

// Declaration of the shared memory arrays

__shared__ float As[D][D];

__shared__ float Bs[D][D];

__shared__ float Cs[D][D];

__shared__ float Ds[D][D];

__shared__ float sdata[D];

As[ty][tx] = T[ tx + gamma * D + ty * D*D + alpha * D*D*D ];

Bs[ty][tx] = T[ delta + ty * D + tx * D*D + alpha * D*D*D ];

Cs[ty][tx] = T[ delta + tx * D + beta * D*D + ty * D*D*D ];

Ds[ty][tx] = T[ ty + gamma * D + beta * D*D+ tx * D*D*D ];

__syncthreads();

float Csub = 0.;

float Dsub = 0.;

for (int k = 0; k < D; ++k){

Csub += As[ty][k] * Bs[k][tx];

Dsub += Cs[ty][k] * Ds[k][tx];

}

__syncthreads();

As[ty][tx] = Csub;

Bs[ty][tx] = Dsub;

__syncthreads();

if (ty==0) {

float res=0.;

for (int k =0; k < D; ++k )

res += As[tx][k]*Bs[k][tx];

sdata[tx]=res;

__syncthreads();

for ( unsigned int s=D/2; s> 0; s>>=1 ){

if ( tx < s )

sdata[tx] +=sdata[tx+s];

__syncthreads();

}

}

TO[delta+gamma*D+beta*D*D*D+alpha*D*D*D]=sdata[0];

}