Using memcpy_async in matrix transpose

Hello everyone, I’m currently exploring the new asynchronous memory copy feature on an RTX 3050 laptop GPU running Windows 11, with the MSVC compiler (cl.exe) version 19.29.30152. Specifically, I’m attempting to use memcpy_async in a matrix transpose kernel that stages data through shared memory. Unfortunately, I’m running into problems: the results differ from the synchronous version.

I’m uncertain whether I’m using the API correctly, especially given the limited code samples I’ve come across. Below are the kernels I’m working with, and I would greatly appreciate guidance on the proper usage of the asynchronous copy APIs.

Kernel with the synchronous copy:

template <int BM = 32, int BN = 32>
__global__ void transpose_v3_kernel(const float* input, float* output, int M, int N) {
  __shared__ float shmem[BM][BN + 3];

  int bidx = blockIdx.x;
  int bidy = blockIdx.y;
  int tidx = threadIdx.x;
  int tidy = threadIdx.y;

  int row = bidy * blockDim.y + tidy;
  int col = bidx * blockDim.x + tidx;

  shmem[tidy][tidx] = input[row * N + col];
  __syncthreads();

  int tid  = threadIdx.y * blockDim.x + threadIdx.x;
  int srow = tid / blockDim.y;
  int scol = tid % blockDim.y;
  row      = bidx * blockDim.x + srow;
  col      = bidy * blockDim.y + scol;

  output[row * M + col] = shmem[scol][srow];
}

Kernel with the async copy API:

template <int BM = 32, int BN = 32>
__global__ void transpose_v4_kernel(const float* input, float* output, int M, int N) {
  __shared__ float shmem[BM][BN + 3];
  namespace cg = cooperative_groups;

  int bidx = blockIdx.x;
  int bidy = blockIdx.y;
  int tidx = threadIdx.x;
  int tidy = threadIdx.y;

  auto block = cg::this_thread_block();
  __shared__ cuda::barrier<cuda::thread_scope::thread_scope_block> barrier;
  if (block.thread_rank() == 0) {
    init(&barrier, block.size());  // Friend function initializes barrier
  }
  block.sync();

  int row = bidy * blockDim.y + tidy;
  int col = bidx * blockDim.x + tidx;

  cg::memcpy_async(block, &shmem[tidy][tidx], &input[row * N + col], sizeof(float));

  int tid  = threadIdx.y * blockDim.x + threadIdx.x;
  int srow = tid / blockDim.y;
  int scol = tid % blockDim.y;
  row      = bidx * blockDim.x + srow;
  col      = bidy * blockDim.y + scol;
  barrier.arrive_and_wait();
  block.sync();
  output[row * M + col] = shmem[scol][srow];
}

There seem to be several issues with your memcpy_async kernel.

  1. With respect to the use of cuda::barrier (which is distinct from cooperative groups): when I compile your code on CUDA 12.2, I get a warning about dynamic initialization of a static shared object. If I switch to the pattern suggested in the example here (i.e. thread_scope_system), the warning goes away. However, as we’ll see in a moment, we are going to dispense with the cuda::barrier altogether.

  2. Perhaps more importantly, you are mixing synchronization methods between cuda::barrier and cg::memcpy_async. There are two distinct ways to structure such an operation: (A) use cuda::barrier and cuda::memcpy_async, designating the barrier explicitly in the memcpy_async call, or (B) use cg::memcpy_async and cg::wait, with an implicit barrier that is handled “under the hood” by the cooperative groups machinery. You have mixed the two: you pair an explicit cuda::barrier with cg::memcpy_async, which has no way to specify an explicit barrier, and you never call cg::wait to properly synchronize the cg::memcpy_async. (A minimal sketch of both patterns follows this list.)

  3. Your general understanding of the usage of cg::memcpy_async is incorrect. It cannot be used as a 1:1 replacement for the per-thread shared memory copy you are doing in the non-cg kernel. There are several reasons for this, but they revolve around the fact that cg::memcpy_async is a collective operation: the arguments you pass specify the collective operation, not the work of individual threads, and they must not vary across the threads in the specified group. (See here: “For cooperative variants, if the parameters are not the same across all threads in group, the behavior is undefined.”) If you study the example this should become clear. Notice that the shared pointer passed to the operation in that example is just the base pointer of the shared memory allocation; there is no per-thread offsetting. Notice also that the specified size (which happens to be the entire size of shared memory in that case) is not a per-thread amount to copy, as your attempt assumes, but the total size of the copy performed by the collective operation.
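
To make the distinction in item 2 concrete, here is a minimal sketch of the two valid patterns, shown for a simple whole-block linear copy rather than for the transpose itself (which is addressed further below). The kernel names and buffer sizes are illustrative assumptions only, and both kernels assume a 1D block of 1024 threads:

#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda/barrier>
namespace cg = cooperative_groups;

// Pattern A: cuda::memcpy_async with an explicit cuda::barrier
__global__ void copy_with_barrier(const float* in, float* out) {
  __shared__ float smem[1024];
  auto block = cg::this_thread_block();
  __shared__ cuda::barrier<cuda::thread_scope_block> bar;
  if (block.thread_rank() == 0) init(&bar, block.size());
  block.sync();
  // one collective linear copy; identical arguments in every thread,
  // completion is tracked by the explicitly supplied barrier
  cuda::memcpy_async(block, smem, in + blockIdx.x * 1024, sizeof(float) * 1024, bar);
  bar.arrive_and_wait();  // wait for the copy to complete
  out[blockIdx.x * 1024 + threadIdx.x] = smem[threadIdx.x];
}

// Pattern B: cg::memcpy_async with an implicit barrier, completed by cg::wait
__global__ void copy_with_cg_wait(const float* in, float* out) {
  __shared__ float smem[1024];
  auto block = cg::this_thread_block();
  // base pointers and the total size of the collective copy, identical in every thread
  cg::memcpy_async(block, smem, in + blockIdx.x * 1024, sizeof(float) * 1024);
  cg::wait(block);  // waits for the copy and synchronizes the block
  out[blockIdx.x * 1024 + threadIdx.x] = smem[threadIdx.x];
}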

So we have to rework your example. An additional factor we run into is that memcpy_async performs a linear copy for the group you pass to the operation; it cannot do the strided copy that is needed here. We can address this by breaking your thread block group into (warp) tiles, where each tile copies a single row of shared memory and can therefore be expressed as a linear copy. Also note that although it appears below that there is only one group (the tile), there are in fact 32 tile groups; each warp belongs to a separate tile.

Here is a modification of your example, with these ideas in mind, that seems to work correctly for me:

# cat t135.cu
#include <iostream>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda/barrier>
template <int BM = 32, int BN = 32>
__global__ void transpose_v3_kernel(const float* input, float* output, int M, int N) {
  __shared__ float shmem[BM][BN + 1];

  int bidx = blockIdx.x;
  int bidy = blockIdx.y;
  int tidx = threadIdx.x;
  int tidy = threadIdx.y;

  int row = bidy * blockDim.y + tidy;
  int col = bidx * blockDim.x + tidx;

  shmem[tidy][tidx] = input[row * N + col];
  __syncthreads();
  row = bidx * blockDim.y+tidy;
  col = bidy * blockDim.x+tidx;
  output[row * M + col] = shmem[tidx][tidy];
}

template <int BM = 32, int BN = 32>
__global__ void transpose_v4_kernel(const float* input, float* output, int M, int N) {
  __shared__ float shmem[BM][BN+1];
  namespace cg = cooperative_groups;

  int bidx = blockIdx.x;
  int bidy = blockIdx.y;
  int tidx = threadIdx.x;
  int tidy = threadIdx.y;

  auto block = cg::this_thread_block();
  auto tile = cg::tiled_partition<32>(block);
  int bsrow = bidy * blockDim.y;  // block start row
  int bscol = bidx * blockDim.x;  // block start column

  // each warp-sized tile issues one linear copy: one 32-float row of this block's input tile into one row of shmem
  cg::memcpy_async(tile, shmem[tile.meta_group_rank()], &input[(bsrow+tile.meta_group_rank())*M + bscol], sizeof(float)*tile.size());
  int row = bidx * blockDim.y+tidy;
  int col = bidy * blockDim.x+tidx;
  cg::wait(block);

  output[row * M + col] = shmem[tidx][tidy];
}

int main(){
  const int mult = 32;
  const int M = mult*32;
  const int N = mult*32;
  float *input, *output;
  cudaMalloc(&input, M*N*sizeof(*input));
  cudaMalloc(&output, M*N*sizeof(*output));
  cudaMemset(output, 0, M*N*sizeof(*output));
  float *hi = new float[M*N];
  for (int i = 0; i < M*N; i++) hi[i] = i%5;
  cudaMemcpy(input, hi, M*N*sizeof(*input), cudaMemcpyHostToDevice);
  float *r1 = new float[M*N];
  transpose_v3_kernel<<<dim3(mult,mult), dim3(32,32)>>>(input, output, M, N);
  cudaMemcpy(r1, output, M*N*sizeof(*output), cudaMemcpyDeviceToHost);
  cudaMemset(output, 0, M*N*sizeof(*output));
  float *r2 = new float[M*N];
  transpose_v4_kernel<<<dim3(mult,mult), dim3(32,32)>>>(input, output, M, N);
  cudaMemcpy(r2, output, M*N*sizeof(*output), cudaMemcpyDeviceToHost);
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {std::cout << "Error: " << cudaGetErrorString(err) << std::endl; return 0;}
  for (int i = 0; i < M*N; i++) if (r1[i] != r2[i]) {std::cout << "Mismatch at: " << i << " was: " << r2[i]  << " should be: " << r1[i] << std::endl; return 0;}
  return 0;
}
# nvcc -o t135 t135.cu -arch=sm_89
# ./t135
#

Thanks very much for your comprehensive explanation and code examples. I will delve deeper into studying them and modify my code to observe the result. BTW, today (Feb. 10th) is the Spring Festival, a traditional celebration in China and Southeast Asia. Best wishes to you, and Happy Lunar New Year!!

Happy New Year!
