Hi,
I am looking for suggestions on how to optimize my code, or for new implementation ideas, for performing tensor contractions.
I need to perform a lot of contractions on a rank-4 tensor T in which every dimension has size D (D^4 elements in total). D is usually small, around 4 to 10, and the result is another rank-4 tensor TO:
TO[a,b,c,d]=T[a,l,c,i]*T[a,j,i,d]*T[j,b,k,d]*T[l,b,c,k]
The indices that appear only on the right-hand side (i, j, k, l) are summed over.
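As a sanity check for any GPU version, the contraction can be written as a brute-force host-side loop. This is only a sketch: it assumes T is stored row-major, with T[a][b][c][d] at offset ((a*D + b)*D + c)*D + d (which appears to match the indexing used in the kernel further down), and the names idx4 and contractReference are made up for illustration.

```cpp
#include <cstddef>
#include <vector>

// Row-major offset of element [a][b][c][d] in a rank-4 tensor with extent D per axis.
// (Assumed layout; idx4 is a hypothetical helper, not part of the original code.)
inline std::size_t idx4(int a, int b, int c, int d, int D) {
    return ((static_cast<std::size_t>(a) * D + b) * D + c) * D + d;
}

// Brute-force reference:
// TO[a,b,c,d] = sum over i,j,k,l of T[a,l,c,i] * T[a,j,i,d] * T[j,b,k,d] * T[l,b,c,k]
// Costs O(D^4) work per output element, O(D^8) in total.
std::vector<float> contractReference(const std::vector<float>& T, int D) {
    std::vector<float> TO(T.size(), 0.0f);
    for (int a = 0; a < D; ++a)
    for (int b = 0; b < D; ++b)
    for (int c = 0; c < D; ++c)
    for (int d = 0; d < D; ++d) {
        float sum = 0.0f;
        for (int i = 0; i < D; ++i)
        for (int j = 0; j < D; ++j)
        for (int k = 0; k < D; ++k)
        for (int l = 0; l < D; ++l)
            sum += T[idx4(a, l, c, i, D)] * T[idx4(a, j, i, d, D)]
                 * T[idx4(j, b, k, d, D)] * T[idx4(l, b, c, k, D)];
        TO[idx4(a, b, c, d, D)] = sum;
    }
    return TO;
}
```

With T filled with ones, every output element equals D^4, which gives a quick check of the loop bounds.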
Here is my code for this operation. Currently I reshape T to a D^2 x D^2 matrix. The block_size is set to D, so there are D x D threads in a 2D block, and the grid is D x D as well. Each thread now computes an element of TO.
For D=4, I can achieve 33% occupancy and about a 1.5x speedup. When I increase D to 8, the code seems to hang. The CPU benchmark takes less than 1 s to run for D=8. I also tried putting T into a 2D texture, but that version is actually slower for D=4.
Any suggestions will be appreciated.
Thank you.
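// D (the extent of each tensor dimension) must be a compile-time constant here,
// because it sizes the shared-memory arrays.
__global__ void tensorCon( float* T, float* TO)
{
    // Block index
    int bx = blockIdx.x;
    int by = blockIdx.y;
    // Thread index
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    // Decode the four output indices (a, b, c, d) from the block index
    int alpha = bx / D;
    int beta  = bx % D;
    int gamma = by / D;
    int delta = by % D;
    // Declaration of the shared memory arrays
    __shared__ float As[D][D];
    __shared__ float Bs[D][D];
    __shared__ float Cs[D][D];
    __shared__ float Ds[D][D];
    __shared__ float sdata[D];
    // Load the four D x D slices of T needed for this output element
    As[ty][tx] = T[ tx + gamma * D + ty * D*D + alpha * D*D*D ];   // T[a,l,c,i], l=ty, i=tx
    Bs[ty][tx] = T[ delta + ty * D + tx * D*D + alpha * D*D*D ];   // T[a,j,i,d], i=ty, j=tx
    Cs[ty][tx] = T[ delta + tx * D + beta * D*D + ty * D*D*D ];    // T[j,b,k,d], j=ty, k=tx
    Ds[ty][tx] = T[ ty + gamma * D + beta * D*D + tx * D*D*D ];    // T[l,b,c,k], k=ty, l=tx
    __syncthreads();
    // Two D x D matrix products: As*Bs (sum over i) and Cs*Ds (sum over k)
    float Csub = 0.f;
    float Dsub = 0.f;
    for (int k = 0; k < D; ++k) {
        Csub += As[ty][k] * Bs[k][tx];
        Dsub += Cs[ty][k] * Ds[k][tx];
    }
    // Wait until every thread has finished reading As and Bs before overwriting them
    __syncthreads();
    As[ty][tx] = Csub;
    Bs[ty][tx] = Dsub;
    __syncthreads();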
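    // Row ty == 0 computes the diagonal of As*Bs; the sum of the diagonal is the result.
    if (ty == 0) {
        float res = 0.f;
        for (int k = 0; k < D; ++k)
            res += As[tx][k] * Bs[k][tx];
        sdata[tx] = res;
    }
    // __syncthreads() must be reached by every thread in the block,
    // so the barriers sit outside the ty == 0 branch.
    __syncthreads();
    // Tree reduction over sdata; this assumes D is a power of two.
    for (unsigned int s = D/2; s > 0; s >>= 1) {
        if (ty == 0 && tx < s)
            sdata[tx] += sdata[tx + s];
        __syncthreads();
    }
    // One thread writes TO[a,b,c,d], stored at d + c*D + b*D*D + a*D*D*D
    if (tx == 0 && ty == 0)
        TO[delta + gamma*D + beta*D*D + alpha*D*D*D] = sdata[0];
}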