Most efficient way to work with three-dimensional data in CUDA?

I am solving a three-dimensional transient diffusion problem, so I need to update the right-hand side at every time step. Each element of the 3D array needs to communicate with its adjacent elements.
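For reference, the neighbour access this implies is a 7-point stencil: each interior point reads its six face neighbours. Below is a minimal serial sketch, assuming an explicit (forward Euler) finite-difference scheme, a uniform spacing `h`, diffusivity `D`, time step `dt` and fixed boundary values. The question does not state the scheme, so all of these, and the function name `diffusion_step`, are illustrative assumptions. A matrix-free right-hand-side update for an implicit scheme would touch exactly the same neighbours.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One explicit (forward Euler) step of du/dt = D * laplacian(u) on an
// nx*ny*nz grid with uniform spacing h. Storage is x-fastest:
// index = x + nx*(y + ny*z). Boundary cells are copied unchanged
// (an assumed Dirichlet condition).
void diffusion_step(const std::vector<double>& u, std::vector<double>& u_new,
                    int nx, int ny, int nz, double D, double dt, double h) {
    const std::size_t sx = 1;                                  // stride to x +- 1
    const std::size_t sy = static_cast<std::size_t>(nx);       // stride to y +- 1
    const std::size_t sz = static_cast<std::size_t>(nx) * ny;  // stride to z +- 1
    assert(u.size() == sz * nz);
    const double r = D * dt / (h * h);
    u_new = u;  // start from u so boundary values are preserved
    for (int z = 1; z < nz - 1; ++z)
        for (int y = 1; y < ny - 1; ++y)
            for (int x = 1; x < nx - 1; ++x) {
                const std::size_t i = x + sy * y + sz * z;
                // 7-point stencil: the six face neighbours plus the centre
                const double lap = u[i - sx] + u[i + sx] + u[i - sy] + u[i + sy]
                                 + u[i - sz] + u[i + sz] - 6.0 * u[i];
                u_new[i] = u[i] + r * lap;
            }
}
```

With this explicit scheme the step is only stable for `D * dt / (h * h) <= 1/6` in 3D, which is one reason implicit schemes (and hence repeated right-hand-side assembly) are common for this problem.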

What would be the most efficient way to handle the three-dimensional data in this case? I am thinking of mapping the 3D array to a 1D array and using only threadIdx.x.
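The flattening itself is cheap. If x varies fastest, the neighbours of linear index `i` sit at `i ± 1`, `i ± nx` and `i ± nx*ny`, and consecutive threads touch consecutive addresses along x, which is the pattern that gives coalesced global-memory access. Note that a single block is limited to 1024 threads on current GPUs, so in practice the flat index would be `blockIdx.x * blockDim.x + threadIdx.x` rather than `threadIdx.x` alone. The helper names `to_linear` and `from_linear` below are hypothetical, written as plain host C++ for illustration.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helpers for an nx*ny*nz array flattened with x varying fastest.
// Neighbour offsets: x +- 1 -> i +- 1, y +- 1 -> i +- nx, z +- 1 -> i +- nx*ny.
inline std::size_t to_linear(int x, int y, int z, int nx, int ny) {
    return static_cast<std::size_t>(x)
         + static_cast<std::size_t>(nx) *
               (static_cast<std::size_t>(y) + static_cast<std::size_t>(ny) * static_cast<std::size_t>(z));
}

// Inverse mapping, e.g. for a flat global thread index
// i = blockIdx.x * blockDim.x + threadIdx.x inside a kernel.
inline void from_linear(std::size_t i, int nx, int ny, int& x, int& y, int& z) {
    assert(nx > 0 && ny > 0);
    x = static_cast<int>(i % nx);
    y = static_cast<int>((i / nx) % ny);
    z = static_cast<int>(i / (static_cast<std::size_t>(nx) * ny));
}
```

The integer division and modulo in `from_linear` cost a little per thread; launching a 2D or 3D grid and reading `threadIdx.y`/`blockIdx.y` directly avoids them while keeping the same flat storage.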

Would it be a good idea to declare five 1D arrays in shared memory: one for (:,y,z), one for (:,y+1,z), one for (:,y-1,z), one for (:,y,z+1) and one for (:,y,z-1)? (Here `:` denotes every element along the first index, i.e. a whole row.)
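To make the proposed five-row scheme concrete, here is a serial C++ emulation of its data flow, assuming unit grid spacing and a hypothetical function name `laplacian_five_rows`. The five `std::vector` buffers stand in for the `__shared__` arrays; in a kernel, each x in the load loop would be one thread (indexed by `threadIdx.x`), followed by a `__syncthreads()` before the compute loop. It is a sketch of the indexing, not a tuned kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Serial emulation of the "five row buffers" idea: for each interior (y, z),
// copy rows (:,y,z), (:,y+-1,z), (:,y,z+-1) into buffers, then compute the
// discrete Laplacian of row (:,y,z) from them. Grid spacing is assumed to be 1.
void laplacian_five_rows(const std::vector<double>& u, std::vector<double>& lap,
                         int nx, int ny, int nz) {
    assert(u.size() == static_cast<std::size_t>(nx) * ny * nz);
    lap.assign(u.size(), 0.0);  // boundary entries stay 0
    std::vector<double> c(nx), yp(nx), ym(nx), zp(nx), zm(nx);  // the five row buffers
    // Offset of the first element of row (:, y, z); x is added to it.
    auto row = [&](int y, int z) {
        return static_cast<std::size_t>(nx) * (y + static_cast<std::size_t>(ny) * z);
    };
    for (int z = 1; z < nz - 1; ++z)
        for (int y = 1; y < ny - 1; ++y) {
            // Load phase: in CUDA, one element per thread (coalesced along x).
            for (int x = 0; x < nx; ++x) {
                c[x]  = u[row(y, z) + x];
                yp[x] = u[row(y + 1, z) + x];
                ym[x] = u[row(y - 1, z) + x];
                zp[x] = u[row(y, z + 1) + x];
                zm[x] = u[row(y, z - 1) + x];
            }
            // (a __syncthreads() would separate the two phases in a kernel)
            // Compute phase: x-neighbours come from the centre row buffer.
            for (int x = 1; x < nx - 1; ++x)
                lap[row(y, z) + x] = c[x - 1] + c[x + 1] + ym[x] + yp[x]
                                   + zm[x] + zp[x] - 6.0 * c[x];
        }
}
```

Within one row, only the centre buffer is read more than once (three times per point); each element of the other four rows is used once, so staging them in shared memory buys little by itself. The reuse appears when neighbouring rows are processed in sequence, e.g. marching along z and keeping the z-1, z and z+1 planes resident, because the row that is "z+1" for one step becomes the centre for the next.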