CUDA kernel slower when using cuda::pipeline despite avoiding bank conflicts

A GEMM kernel I wrote reaches about 65 TFLOP/s on an A5000 GPU, but it incurs roughly 8 million shared memory bank conflicts (the Nsight Compute metric l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum). To avoid these, I switched the global-to-shared-memory transfers to asynchronous copies through a cuda::pipeline, following the usage in the CUDA Samples globalToShmemAsyncCopy project. The async copies eliminate almost all bank conflicts, but the kernel gets slower, running at about 55-60 TFLOP/s.
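
For reference, the pattern I am imitating is the single-stage, thread-scoped variant from that sample. A stripped-down sketch of my understanding of it (my own paraphrase with illustrative names and sizes, not the sample code itself):

#include <cuda_fp16.h>
#include <cuda/pipeline>

// Thread-scoped pipeline issuing 16-byte async copies into shared memory.
// Illustrative assumptions: launched with TILE / 8 threads per block, TILE a
// multiple of 8, and n a multiple of TILE.
template <int TILE>
__global__ void pipelined_tile_copy(const half *__restrict__ src, half *dst,
                                    size_t n) {
    __shared__ __align__(16) half tile[TILE];
    auto pipe = cuda::make_pipeline();  // cuda::pipeline<cuda::thread_scope_thread>

    for (size_t base = size_t(blockIdx.x) * TILE; base < n;
         base += size_t(gridDim.x) * TILE) {
        pipe.producer_acquire();
        // Each thread issues one asynchronous 16-byte (8 x half) copy.
        cuda::memcpy_async(&tile[threadIdx.x * 8], &src[base + threadIdx.x * 8],
                           16, pipe);
        pipe.producer_commit();
        pipe.consumer_wait();   // waits only for this thread's copies
        __syncthreads();        // make every thread's part of the tile visible
        // ... compute on the tile; here it is simply copied back out ...
        for (int i = 0; i < 8; ++i)
            dst[base + threadIdx.x * 8 + i] = tile[threadIdx.x * 8 + i];
        __syncthreads();
        pipe.consumer_release();  // single stage: release before the next tile
    }
}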

A sample program can be seen below:

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cuda/pipeline>
#include <mma.h>

#include <cassert>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <typeinfo>
#include <vector>

#define gpuErrchk(ans) \
    { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line,
                      bool abort = true) {
    if (code != cudaSuccess) {
        fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
                line);
        if (abort) exit(code);
    }
}

template <size_t BM_, size_t BN_, size_t BK_>
struct Block {
    constexpr static size_t kM = BM_;
    constexpr static size_t kN = BN_;
    constexpr static size_t kK = BK_;
};

constexpr size_t CUDA_VECTORIZED_BITS_LOAD = 128;

template <typename T>
class Matrix {
   public:
    Matrix(size_t rows, size_t cols)
        : rows_(rows), cols_(cols), host(new T[rows * cols]) {
        DeviceAlloc();
    };
    ~Matrix() {
        delete[] host;
        cudaFree(device);
    };
    // Delete copy constructor/assignment operator because the Matrix has
    // reference semantics and a destructor that frees the memory
    Matrix(const Matrix &) = delete;
    Matrix &operator=(const Matrix &) = delete;
    Matrix(Matrix &&) noexcept = default;
    Matrix &operator=(Matrix &&) noexcept = default;
    void sync(cudaMemcpyKind kind) {
        cudaError_t status{};
        if (kind == cudaMemcpyHostToDevice) {
            status = cudaMemcpy(device, host, n_ele() * sizeof(T), kind);
        } else if (kind == cudaMemcpyDeviceToHost) {
            status = cudaMemcpy(host, device, n_ele() * sizeof(T), kind);
        } else {
            throw std::runtime_error{"Not supported"};
        }
        gpuErrchk(status);
    }
    T *device_ptr() { return device; };
    T *host_ptr() { return host; }
    const T *host_ptr() const { return host; }
    T &operator[](size_t i) { return *(host + i); };
    T operator[](size_t i) const { return *(host + i); };
    [[nodiscard]] size_t n_ele() const noexcept { return rows_ * cols_; };

   private:
    size_t rows_;
    size_t cols_;
    T *device;
    T *host;
    void DeviceAlloc() {
        cudaMalloc((void **)&device, rows_ * cols_ * sizeof(T));
    };
};

template <typename T>
void random_fill_and_sync(Matrix<T> &matrix) {
    std::default_random_engine engine{0};
    std::uniform_real_distribution<float> uniform(-1, 1);
    for (size_t i = 0; i < matrix.n_ele(); ++i) {
        matrix[i] = (T)(uniform(engine));
    }
    matrix.sync(cudaMemcpyHostToDevice);
}

template <typename T>
void fill_and_sync(Matrix<T> &matrix, T val) {
    for (size_t i = 0; i < matrix.n_ele(); ++i) {
        matrix[i] = val;
    }
    matrix.sync(cudaMemcpyHostToDevice);
};

template <size_t rows, size_t cols, size_t num_vectorized_loads>
struct Index {
    const size_t row{0};
    const size_t col{0};
    __host__ __device__ explicit Index(size_t laneId)
        : row(laneId / (cols / num_vectorized_loads)),
          col((laneId % (cols / num_vectorized_loads)) *
              num_vectorized_loads){};
};

template <typename T, size_t rows, size_t cols, typename ThreadOffset>
struct Loader {
    T (&shmem_)[rows][cols];
    T *global_ptr_;
    const ThreadOffset offset_;
    const size_t stride_;
    const size_t ld_;
    __host__ __device__ Loader(T *global_ptr, T (&shmem)[rows][cols],
                               size_t threadId, size_t blockOffset, size_t ld,
                               size_t stride)
        : global_ptr_(global_ptr + blockOffset),
          shmem_(shmem),
          offset_(threadId),
          ld_(ld),
          stride_(stride){};
    __host__ __device__ void load(
        cuda::pipeline<cuda::thread_scope_thread> &barrier) {
        const size_t global_idx = offset_.row * ld_ + offset_.col;
        constexpr size_t load_bytes = 16;
#pragma unroll
        for (size_t row = 0; row < rows; row += stride_) {
            const T *src = global_ptr_ + row * ld_ + global_idx;
            T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
            cuda::memcpy_async(dst, src, load_bytes, barrier);
        }
    }
    __host__ __device__ void load() {
        const size_t global_idx = offset_.row * ld_ + offset_.col;
#pragma unroll
        for (size_t row = 0; row < rows; row += stride_) {
            const T *src = global_ptr_ + row * ld_ + global_idx;
            T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
            int4 t = reinterpret_cast<const int4 *>(src)[0];
            reinterpret_cast<int4 *>(dst)[0] = t;
        }
    }

    __host__ __device__ void next(size_t stride) {
        global_ptr_ = global_ptr_ + stride;
    };
};

template <class Threadblock, class Warp, class WMMAblock>
__global__ void WarpCompute(half *__restrict__ A, half *__restrict__ B, half *C,
                            size_t M, size_t N, size_t K) {
    assert(M % Threadblock::kM == 0);
    assert(N % Threadblock::kN == 0);
    assert(K % Threadblock::kK == 0);
    const size_t threadblock_rows = M / Threadblock::kM;
    const size_t threadblock_cols = N / Threadblock::kN;
    const size_t num_blocks = gridDim.x * gridDim.y * gridDim.z;
    assert(threadblock_rows * threadblock_cols == num_blocks);

    static_assert(Threadblock::kM >= Warp::kM);
    static_assert(Threadblock::kN >= Warp::kN);
    static_assert(Threadblock::kK >= Warp::kK);

    constexpr size_t warpRows = Threadblock::kM / Warp::kM;
    constexpr size_t warpCols = Threadblock::kN / Warp::kN;
    const size_t num_warps = blockDim.x * blockDim.y * blockDim.z / warpSize;
    assert(warpRows * warpCols == num_warps);

    constexpr size_t skew = 8;

    __shared__ half As[Threadblock::kM][Threadblock::kK + skew];
    __shared__ half Bs[Threadblock::kK][Threadblock::kN + skew];

    const size_t total_threadId = threadIdx.x;
    const size_t thread_num = blockDim.x;
    constexpr size_t n_gld =
        CUDA_VECTORIZED_BITS_LOAD / (sizeof(half) * 8);  // half elements per 128-bit load

    const size_t stride_a = thread_num * n_gld / Threadblock::kK;
    const size_t stride_b = thread_num * n_gld / Threadblock::kN;

    Loader<half, Threadblock::kM, Threadblock::kK + skew,
           Index<Threadblock::kM, Threadblock::kK, n_gld>>
        LoaderA{A, As,      total_threadId, blockIdx.y * Threadblock::kM * K,
                K, stride_a};
    Loader<half, Threadblock::kK, Threadblock::kN + skew,
           Index<Threadblock::kK, Threadblock::kN, n_gld>>
        LoaderB{B, Bs,      total_threadId, blockIdx.x * Threadblock::kN,
                N, stride_b};

    const size_t warp_id = threadIdx.x / warpSize;

    const size_t warp_x = warp_id / warpCols;
    const size_t warp_y = warp_id % warpRows;
    const size_t AsWarpOffset = warp_y * Warp::kM;
    const size_t BsWarpOffset = warp_x * Warp::kN;

    constexpr size_t threadRows = Warp::kM / WMMAblock::kM;
    constexpr size_t threadCols = Warp::kN / WMMAblock::kN;

    C = &C[blockIdx.y * Threadblock::kM * N + blockIdx.x * Threadblock::kN +
           warp_y * Warp::kM * N + warp_x * Warp::kN];

    nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, WMMAblock::kM, WMMAblock::kN,
                           WMMAblock::kK, half, nvcuda::wmma::row_major>
        A_frag[threadRows];
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, WMMAblock::kM, WMMAblock::kN,
                           WMMAblock::kK, half, nvcuda::wmma::row_major>
        B_frag[threadRows];
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMAblock::kM,
                           WMMAblock::kN, WMMAblock::kK, half>
        C_frag[threadRows][threadRows];
    for (size_t i = 0; i < threadRows; ++i) {
        for (size_t j = 0; j < threadRows; ++j) {
            nvcuda::wmma::fill_fragment(C_frag[i][j], 0.0);
        }
    }

    cuda::pipeline<cuda::thread_scope_thread> pipe = cuda::make_pipeline();
    for (size_t block = 0; block < K; block += Threadblock::kK) {
        //  Fast kernel but with many bank conflicts upon store
        LoaderA.load();
        LoaderB.load();
        LoaderA.next(Threadblock::kK);
        LoaderB.next(Threadblock::kK * N);

        //  Slower kernel but with zero bank conflicts
        // pipe.producer_acquire();
        // LoaderA.load(pipe);
        // LoaderB.load(pipe);
        // pipe.producer_commit();
        // pipe.consumer_wait();

        __syncthreads();
#pragma unroll
        for (size_t bk = 0; bk < Threadblock::kK; bk += WMMAblock::kK) {
#pragma unroll
            for (size_t i = 0; i < threadRows; ++i) {
                nvcuda::wmma::load_matrix_sync(
                    A_frag[i], &As[AsWarpOffset + i * WMMAblock::kM][bk],
                    Threadblock::kK + skew);
            }
#pragma unroll
            for (size_t i = 0; i < threadRows; ++i) {
                nvcuda::wmma::load_matrix_sync(
                    B_frag[i], &Bs[bk][BsWarpOffset + i * WMMAblock::kN],
                    Threadblock::kN + skew);
            }
#pragma unroll
            for (size_t i = 0; i < threadRows; ++i) {
#pragma unroll
                for (size_t j = 0; j < threadRows; j++) {
                    nvcuda::wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j],
                                           C_frag[i][j]);
                }
            }
        }
        __syncthreads();
    }
#pragma unroll
    for (size_t i = 0; i < threadRows; ++i) {
#pragma unroll
        for (size_t j = 0; j < threadCols; ++j) {
            nvcuda::wmma::store_matrix_sync(
                &C[i * WMMAblock::kM * N + j * WMMAblock::kN], C_frag[i][j], N,
                nvcuda::wmma::mem_row_major);
        }
    }
}

void WarpedSharedUnrollVectorize(half *A, half *B, half *C, size_t M, size_t N,
                                 size_t K, float alpha, float beta) {
    using tb = Block<128, 128, 32>;
    // CUTLASS commonly uses threadblock tiles like 256x128x64 for sm80
    using wb = Block<32, 32, 32>;

    using wmma_block = Block<16, 16, 16>;
    assert(M % tb::kM == 0);
    assert(N % tb::kN == 0);
    static_assert(tb::kM % wb::kM == 0);
    static_assert(tb::kN % wb::kN == 0);
    dim3 gdim(N / tb::kN, M / tb::kM, 1);  // x indexes N tiles, y indexes M tiles, matching blockIdx use in the kernel
    dim3 bdim(32 * (tb::kM / wb::kM) * (tb::kN / wb::kN), 1, 1);
    WarpCompute<tb, wb, wmma_block><<<gdim, bdim>>>(A, B, C, M, N, K);
}

int main() {
    using T = half;
    const float alpha{1.0}, beta{0.0};
    constexpr size_t M{4096};
    constexpr size_t N{4096};
    constexpr size_t K{4096};
    Matrix<T> A(M, K);
    Matrix<T> B(K, N);
    Matrix<T> D(M, N);
    Matrix<T> Dref(M, N);
    random_fill_and_sync(A);
    random_fill_and_sync(B);
    fill_and_sync(D, (T)0.0);
    cudaEvent_t beg, end;
    float elapsed_time{0};
    cudaEventCreate(&beg);
    cudaEventCreate(&end);
    constexpr size_t n{50};
    cudaEventRecord(beg);
    for (int j = 0; j < n; ++j) {
        WarpedSharedUnrollVectorize(A.device_ptr(), B.device_ptr(),
                                    D.device_ptr(), M, N, K, alpha, beta);
    };
    gpuErrchk(cudaPeekAtLastError());
    gpuErrchk(cudaDeviceSynchronize());
    cudaEventRecord(end);
    cudaEventSynchronize(end);
    cudaEventElapsedTime(&elapsed_time, beg, end);
    const double flops{2 * std::pow(double(M), 3) * 1e-9};  // GFLOP per GEMM (M == N == K)
    std::cout << "GFLOP per GEMM: " << flops << std::endl;
    printf(
        "Average elapsed time: (%7.6f) ms, performance: (%7.1f) GFLOPS. size: "
        "(%ld).\n",
        elapsed_time / n, (n * flops * 1000) / (double)elapsed_time, M);
}

The relevant lines of code are the global-to-shared-memory loads in the main loop:

        LoaderA.load();
        LoaderB.load();
        LoaderA.next(Threadblock::kK);
        LoaderB.next(Threadblock::kK * N);

        //  Slower kernel but with zero bank conflicts
        // pipe.producer_acquire();
        // LoaderA.load(pipe);
        // LoaderB.load(pipe);
        // pipe.producer_commit();
        // pipe.consumer_wait();

        __syncthreads();

and the two loading functions:

    __host__ __device__ void load(
        cuda::pipeline<cuda::thread_scope_thread> &barrier) {
        const size_t global_idx = offset_.row * ld_ + offset_.col;
        constexpr size_t load_bytes = 16;
#pragma unroll
        for (size_t row = 0; row < rows; row += stride_) {
            const T *src = global_ptr_ + row * ld_ + global_idx;
            T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
            cuda::memcpy_async(dst, src, load_bytes, barrier);
        }
    }
    __host__ __device__ void load() {
        const size_t global_idx = offset_.row * ld_ + offset_.col;
#pragma unroll
        for (size_t row = 0; row < rows; row += stride_) {
            const T *src = global_ptr_ + row * ld_ + global_idx;
            T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
            int4 t = reinterpret_cast<const int4 *>(src)[0];
            reinterpret_cast<int4 *>(dst)[0] = t;
        }
    }
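
For completeness, this is roughly what the loop body looks like with the pipeline path enabled. The next() calls stay as before, and I release the stage at the end of the iteration; the consumer_release() placement is my own reading of the libcudacxx documentation rather than something taken from the sample, which ties into my first question below:

    for (size_t block = 0; block < K; block += Threadblock::kK) {
        pipe.producer_acquire();
        LoaderA.load(pipe);             // one 16-byte cuda::memcpy_async per row handled by this thread
        LoaderB.load(pipe);
        LoaderA.next(Threadblock::kK);
        LoaderB.next(Threadblock::kK * N);
        pipe.producer_commit();
        pipe.consumer_wait();           // waits only for this thread's own copies
        __syncthreads();                // make all threads' copies visible

        // ... the wmma load_matrix_sync / mma_sync loops, unchanged ...

        pipe.consumer_release();        // single stage: release it for the next tile
        __syncthreads();
    }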

My questions are:

  • Am I using cuda::pipeline correctly? The CUDA samples use several variants of the async copies, and I am not sure which one is most appropriate for my use case.
  • Why is the kernel throughput not improved despite eliminating bank conflicts?