Local memory conundrum

So, I’m implementing a grid-based computation with NxN threads (arranged in MxM blocks). Each grid cell reads data from its nearest neighbors, so we have an (N+2)x(N+2) grid of data cells to accommodate the boundary cases. Each block therefore uses a local grid region of (M+2)x(M+2) cells.

For efficiency, I want to copy each block’s local region into shared memory before doing any computation on it. The problem is how to efficiently copy an (M+2)x(M+2) region into shared memory using only MxM work items.

What I’m currently doing is having the MxM threads copy the “body” of the region in parallel, one cell per thread, and then having thread (0, 0) within each block copy over the boundary cells by itself.
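To make the two phases concrete, here is a rough host-side mock-up in plain C++ of what one block does. It is only a sketch of the index arithmetic, not the actual kernel: the function name load_tile, the row-major layout, and the assumption that N is a multiple of M are all made up for illustration. In the real kernel the phase 1 loops would be replaced by the per-thread indices, with a barrier before the tile is read.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side mock-up (not kernel code). Assumes a row-major
// (N+2)x(N+2) grid and that N is a multiple of M.
// Copies the (M+2)x(M+2) region of block (bx, by) into `tile` and returns
// the number of cells that thread (0, 0) had to copy on its own.
int load_tile(const std::vector<float>& grid, int N, int M,
              int bx, int by, std::vector<float>& tile) {
    const int gw = N + 2;    // global row width, boundary cells included
    const int tw = M + 2;    // tile row width, boundary cells included
    const int gx0 = bx * M;  // global column that maps to tile column 0
    const int gy0 = by * M;  // global row that maps to tile row 0
    tile.assign(static_cast<std::size_t>(tw) * tw, 0.0f);

    // Phase 1: the MxM "body". On the device each (tx, ty) iteration is
    // one thread, so this whole loop nest is a single parallel step.
    for (int ty = 0; ty < M; ++ty)
        for (int tx = 0; tx < M; ++tx)
            tile[(ty + 1) * tw + (tx + 1)] =
                grid[(gy0 + ty + 1) * gw + (gx0 + tx + 1)];

    // Phase 2: the boundary ring, copied by thread (0, 0) alone while the
    // other threads in the block wait.
    int serial_copies = 0;
    for (int y = 0; y < tw; ++y) {
        for (int x = 0; x < tw; ++x) {
            const bool on_ring =
                (y == 0 || y == tw - 1 || x == 0 || x == tw - 1);
            if (!on_ring) continue;
            tile[y * tw + x] = grid[(gy0 + y) * gw + (gx0 + x)];
            ++serial_copies;
        }
    }
    return serial_copies;
}
```

The return value is what bothers me: the ring has (M+2)^2 - M^2 = 4M+4 cells, so thread (0, 0) does 4M+4 copies one after another while the rest of the block sits idle.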

Any thoughts?