Load data for tensor core

I want to load 128 B (64 halfs) * 8 of data from shared memory into registers, so each thread ends up with 16 half registers. Here is a simple way to do that (just as an example; it has bank conflicts, and assume there is only one warp):

__global__ void f(int* ptr /* ... */) {
  __shared__ half src[64][64];
  half dst[16];
  int lane_id      = threadIdx.x;
  int group_id     = lane_id / 4;   // 8 groups of 4 threads
  int tid_in_group = lane_id % 4;
  for (int i = 0; i < 8; i++) {
      dst[i * 2]     = src[ptr[group_id]][i * 8 + tid_in_group * 2];
      dst[i * 2 + 1] = src[ptr[group_id]][i * 8 + tid_in_group * 2 + 1];
  }
}

This code causes bank conflicts. To avoid them, I tried loading the data in a cycle:

__global__ void f(int* ptr /* ... */) {
  __shared__ half src[64][64];
  half dst[16];
  int lane_id      = threadIdx.x;
  int group_id     = lane_id / 4;
  int tid_in_group = lane_id % 4;
  for (int i = 0; i < 8; i++) {
    int cycle_i = (i + group_id) % 8;   // each group starts at a different column block
    dst[cycle_i * 2]     = src[ptr[group_id]][cycle_i * 8 + tid_in_group * 2];
    dst[cycle_i * 2 + 1] = src[ptr[group_id]][cycle_i * 8 + tid_in_group * 2 + 1];
  }
}

but this causes local memory to be used, because the register to store into can't be determined at compile time. Is there any way to avoid both the bank conflicts and the spill?

I think using branches might solve this problem, but it would make the code very messy, and the overhead is also not negligible.

Is there any way to deal with this problem, e.g. with templates or something else? I don't want to use padding or to alter the element order of src's first dimension (ptr[i] is passed in completely from outside, and all I can guarantee is that the values are distinct non-negative integers).

To avoid local memory, insert #pragma unroll in front of the for loop, if it is not done automatically.
You have to make sure that the dst array is never accessed with dynamic indices. Unrolled loop variables are okay.
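For illustration, a minimal sketch with the pragma applied to the first version above (same names as in your snippets): after full unrolling, i is a compile-time constant in each copy of the loop body, so every dst index is fixed and dst can stay in registers.

#include <cuda_fp16.h>

__global__ void f(int* ptr /* ... */) {
  __shared__ half src[64][64];
  half dst[16];
  int lane_id      = threadIdx.x;
  int group_id     = lane_id / 4;
  int tid_in_group = lane_id % 4;
  #pragma unroll
  for (int i = 0; i < 8; i++) {
    // i is a constant in each unrolled copy, so dst[i * 2] names a fixed register
    dst[i * 2]     = src[ptr[group_id]][i * 8 + tid_in_group * 2];
    dst[i * 2 + 1] = src[ptr[group_id]][i * 8 + tid_in_group * 2 + 1];
  }
}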

Let me think about how to combine it with avoiding bank conflicts.

You have no control over how the data is arranged in shared memory?

How do you access the dst array later on? Can each thread have a different (but defined) order of the elements of dst?

Is it important that certain threads read specific data or can we reorder the threads? If you use matrix multiplications afterwards, we can perhaps also reorder the corresponding other matrix and the result.

How many bank conflicts do you get, i.e. what factor? That determines whether it would be better to reorder with load/store/load through a different shared memory array or with shuffle instructions. Probably 8x, because tid_in_group only takes 4 values instead of 32?
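A rough way to check that factor, assuming the usual 32 banks of 4 bytes and the original half src[64][64] layout (one row is exactly 128 B, so the row chosen by ptr[group_id] does not change the bank, only the column does); a small host-side check:

#include <stdio.h>

int main(void) {
    int i = 0;                                   // any iteration gives the same picture
    for (int lane = 0; lane < 32; lane++) {
        int tid_in_group = lane % 4;
        int col  = i * 8 + tid_in_group * 2;     // half index inside the row
        int bank = (col * 2 / 4) % 32;           // byte offset / 4 bytes per bank
        printf("lane %2d -> bank %d\n", lane, bank);
    }
    // only banks 0..3 show up, while the 8 groups read (usually) different rows,
    // so this points to roughly an 8-way conflict
    return 0;
}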

Your second approach is actually quite good.
(For simpler processing, and to ensure 32-bit accesses, I have changed to half2 for now.)

I have no full solution yet, just different ways to load without bank conflicts, with the data landing in the wrong registers or threads.

__global__ void f(int* ptr /* ... */) {
  __shared__ half2 src[64][32];
  half2 dst[8];
  half2 dst1[8];
  int lane_id      = threadIdx.x;
  int group_id     = lane_id / 4;
  int tid_in_group = lane_id % 4;
  for (int i = 0; i < 8; i++) {
    int cycle_i = (i + group_id) % 8;
    dst1[i] = src[ptr[group_id]][cycle_i * 4 + tid_in_group]; // wrongly ordered dst1
  }
}

You could start to load without bank conflicts like above and then try to reorder the values locally.

Or you load into the wrong threads, same tid_in_group but wrong group_id, and then use shuffle or change how dst is used:

__global__ void f(int* ptr /* ... */) {
  __shared__ half2 src[64][32];
  half2 dst[8];
  half2 dst1[8];
  int lane_id      = threadIdx.x;
  int group_id     = lane_id / 4;
  int tid_in_group = lane_id % 4;
  for (int i = 0; i < 8; i++) {
    // the whole warp reads row ptr[i] contiguously: no bank conflicts,
    // but each value lands in a lane with the right tid_in_group and the wrong group_id
    dst1[i] = src[ptr[i]][lane_id];
  }
}

Perhaps better: have you tried an ldmatrix solution? It loads consecutive half data into threads 0, 4, 8, 12, 16, 20, 24, 28, whereas you want to load into threads 0, 1, 2, 3. But perhaps you can rearrange the further processing.
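In case it helps, here is only the general shape of such a load as a sketch (the ldmatrix_x4 wrapper and its address argument are placeholders; which 16-byte row address each lane has to pass, and how the four fragments map back onto your dst, would still have to be worked out for the ptr[] indirection):

#include <cuda_fp16.h>
#include <cstdint>

// Sketch of the x4, b16 variant (needs sm_75 or newer).
__device__ void ldmatrix_x4(half* lane_row_addr, uint32_t frag[4]) {
    // convert the generic pointer to a shared-memory address for the PTX instruction
    uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(lane_row_addr));
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
        : "r"(smem_addr));
}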

The main problem is that threadIdx.x can't be determined at compile time, so I think this is still partly a kind of dynamic indexing.

I know that with dynamic indices it is difficult to avoid bank conflicts and spills at the same time, but this code snippet is so simple, and threadIdx.x only takes a very limited set of values, so I think there might be some solution.

I've thought about one idea: I can load from shared memory to registers and then use shfl to exchange the values, but that would take maybe 8 shfl instructions. I wonder if there are better ideas.

For src[64][64], I think there is no point in reordering the first dimension, as the values of ptr[i] are hard to predict (which is also why I haven't tried ldmatrix yet). For the second dimension, reordering may be possible.
Right now the data in shared memory (src) has exactly the same layout as in global memory, and it never changes once it is loaded.

I will use mma.m16n8k16 after the load; dst is the B matrix in the multiplication D = A * B + C. It will run for 4 turns, and turn i uses dst[4 * i + 0] through dst[4 * i + 3]. I don't understand how to reorder the threads.
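For reference, a rough sketch of what one such turn could look like, assuming the f16-accumulator form of the instruction and that dst is packed as half2 / 32-bit registers (mma_turn, a_frag and acc are placeholders; only the B operands correspond to dst[4*i+0..3]):

#include <cstdint>

// One mma turn (sketch, needs sm_80+). b_frag[0] holds dst[4*i+0..1] and
// b_frag[1] holds dst[4*i+2..3], each pair packed as one 32-bit register.
__device__ void mma_turn(const uint32_t a_frag[4],  // A fragment (placeholder)
                         const uint32_t b_frag[2],  // B fragment = 4 halfs of dst
                         uint32_t acc[2]) {         // C/D fragment, f16 accumulator assumed
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};\n"
        : "+r"(acc[0]), "+r"(acc[1])
        : "r"(a_frag[0]), "r"(a_frag[1]), "r"(a_frag[2]), "r"(a_frag[3]),
          "r"(b_frag[0]), "r"(b_frag[1]));
}

The asm operands have to name fixed registers, so the 4 halfs for each turn must sit in compile-time-known elements of dst.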

Yes, they will be called in order for the mma.
The main problem in my code is this snippet:

dst[cycle_i  * 2] = ...
dst[cycle_i * 2 + 1] = ...

I have to use cycle_i instead of a simple i to select the register that stores the data, so that when I use the registers later I can pass dst[0]-dst[3] to the first turn, dst[4]-dst[7] to the second, and so on; this is the same for all threads.
If I just used i instead of cycle_i, then in each turn different threads would need different registers for the mma instruction, which would cause dynamic indexing at the mma stage.

See my updated post above.
shfl takes the same time as one shared memory access (load or store) and uses the same resources (so using many shfls slows down shared memory accesses or takes bandwidth away from them).


But we could also use shared memory for the reordering.

You could do my first example / your second example:

Store each dst1 back into shared memory with every lane having its own bank, so each lane can store or load any of its own data without bank conflicts.
Then load it back in the correct order.

  __shared__ half2 src[64][32];
  __shared__ half2 reorder[8][32];
  for (int i = 0; i < 8; i++) {
    int cycle_i = (i + group_id) % 8;
    dst1[i] = src[ptr[group_id]][cycle_i * 4 + tid_in_group]; // wrongly ordered dst1
  }
  for (int i = 0; i < 8; i++) {
    reorder[i][lane_id] = dst1[i];        // each lane stays in its own bank
  }
  for (int i = 0; i < 8; i++) {
    int cycle2_i = (i + 8 - group_id) % 8;
    dst[i] = reorder[cycle2_i][lane_id];  // load back in the correct register order
  }
I hope I guessed cycle2_i correctly; please correct/change it if not. There are definitely no bank conflicts for any cycle2_i.
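If I follow the indices correctly, cycle2_i checks out: after the first loop, dst1[i] of a lane in group g holds the value that belongs in dst[(i + g) % 8]; loading reorder[cycle2_i][lane_id] returns that same lane's dst1[(i + 8 - g) % 8], i.e. the value for register ((i + 8 - g) + g) % 8 = i, which is exactly dst[i]. And since the reorder accesses always use column lane_id, each lane stays in its own bank regardless of cycle2_i.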

This algorithm takes 3x as many shared memory accesses instead of 8x as many.

Since shfl takes the same time as a shared memory access, if I use load + shfl it may take 8 shfls. Would that be a better choice than going through shared memory again?

It is the number of accesses (normalized to one word of 32 bits):

  • read from src with 8x bank conflicts

=> 8 accesses

  • read from src
  • write to reorder
  • read from reorder

=> 3 accesses

  • read from src
  • 8x shuffle

=> 9 accesses

When storing src, you probably do not know, at that time, the group_id of the thread that will later load the data?

Yes, I can’t know.

You have to think mathematically about what happens if you exchange data, which leads to exchanged columns or rows. With matrix multiplication, often just some columns or rows of the result end up exchanged, but it could be different in this example.

If just the result is rearranged in some way, you could take care of it when storing the result of the matrix multiplication instead.
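As a small check of that idea (plain linear algebra, not specific to the mma layout): exchanging columns of B is B' = B·P for a permutation matrix P, and then A·B' + C·P = (A·B + C)·P = D·P. So if C is permuted the same way (or is zero), the whole effect is just the same permutation of the columns of D, which could be undone when storing the result.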

If I use the first code snippet

__global__ void f(int* ptr /* ... */) {
  __shared__ half2 src[64][32];
  half2 dst[8];
  half2 dst1[8];
  int lane_id      = threadIdx.x;
  int group_id     = lane_id / 4;
  int tid_in_group = lane_id % 4;
  for (int i = 0; i < 8; i++) {
    int cycle_i = (i + group_id) % 8;
    dst1[i] = src[ptr[group_id]][cycle_i * 4 + tid_in_group]; // wrongly ordered dst1
  }
}

and then shfl, it will be 8 + 8 accesses. Can it be lower, like 8 + 4?

I have no good idea for shuffle yet. How can you do it with 8+8 (= 2x)?

I would use 8+8+8 (= 3x),

whereas the original code used
8*8 (= 8x).

I’ve made a mistake, I will check again. Thanks!