A GEMM kernel I wrote runs at about 65 TFLOP/s on an A5000 GPU but incurs about 8 million shared memory bank conflicts (l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum). To avoid these, I switched the global-to-shared-memory transfers to asynchronous copies using a cuda::pipeline, following the usage in the CUDA Samples project globalToShmemAsyncCopy. The async copies eliminate almost all bank conflicts, but the program runs slower, at about 55-60 TFLOP/s.
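For reference, the bank-conflict count above comes from Nsight Compute; I collect it along these lines (./gemm is just a placeholder for the compiled binary):
ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum ./gemm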
A sample program can be seen below:
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cuda/pipeline>
#include <mma.h>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <typeinfo>
#include <vector>
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line);
if (abort) exit(code);
}
}
template <size_t BM_, size_t BN_, size_t BK_>
struct Block {
constexpr static size_t kM = BM_;
constexpr static size_t kN = BN_;
constexpr static size_t kK = BK_;
};
constexpr size_t CUDA_VECTORIZED_BITS_LOAD = 128;
template <typename T>
class Matrix {
public:
Matrix(size_t rows, size_t cols)
: rows_(rows), cols_(cols), host(new T[rows * cols]) {  // rows*cols elements, not bytes
DeviceAlloc();
};
~Matrix() {
delete[] host;
cudaFree(device);
};
// Delete copy and move operations: Matrix owns raw host/device pointers and
// frees them in its destructor, so defaulted copies/moves would double-free.
Matrix(const Matrix &) = delete;
Matrix &operator=(const Matrix &) = delete;
Matrix(Matrix &&) = delete;
Matrix &operator=(Matrix &&) = delete;
void sync(cudaMemcpyKind kind) {
cudaError_t status{};
if (kind == cudaMemcpyHostToDevice) {
status = cudaMemcpy(device, host, n_ele() * sizeof(T), kind);
} else if (kind == cudaMemcpyDeviceToHost) {
status = cudaMemcpy(host, device, n_ele() * sizeof(T), kind);
} else {
throw std::runtime_error{"Not supported"};
}
gpuErrchk(status);
}
T *device_ptr() { return device; };
T *host_ptr() { return host; }
const T *host_ptr() const { return host; }
T &operator[](size_t i) { return *(host + i); };
T operator[](size_t i) const { return *(host + i); };
[[nodiscard]] size_t n_ele() const noexcept { return rows_ * cols_; };
private:
size_t rows_;
size_t cols_;
T *device;
T *host;
void DeviceAlloc() {
gpuErrchk(cudaMalloc((void **)&device, rows_ * cols_ * sizeof(T)));
};
};
template <typename T>
void random_fill_and_sync(Matrix<T> &matrix) {
std::default_random_engine engine{0};
std::uniform_real_distribution<float> uniform(-1, 1);
for (size_t i = 0; i < matrix.n_ele(); ++i) {
matrix[i] = (T)(uniform(engine));
}
matrix.sync(cudaMemcpyHostToDevice);
}
template <typename T>
void fill_and_sync(Matrix<T> &matrix, T val) {
for (size_t i = 0; i < matrix.n_ele(); ++i) {
matrix[i] = val;
}
matrix.sync(cudaMemcpyHostToDevice);
};
template <size_t rows, size_t cols, size_t num_vectorized_loads>
struct Index {
const size_t row{0};
const size_t col{0};
__host__ __device__ explicit Index(size_t laneId)
: row(laneId / (cols / num_vectorized_loads)),
col((laneId % (cols / num_vectorized_loads)) *
num_vectorized_loads){};
};
template <typename T, size_t rows, size_t cols, typename ThreadOffset>
struct Loader {
T (&shmem_)[rows][cols];
T *global_ptr_;
const ThreadOffset offset_;
const size_t stride_;
const size_t ld_;
__host__ __device__ Loader(T *global_ptr, T (&shmem)[rows][cols],
size_t threadId, size_t blockOffset, size_t ld,
size_t stride)
: global_ptr_(global_ptr + blockOffset),
shmem_(shmem),
offset_(threadId),
ld_(ld),
stride_(stride){};
// Asynchronous path: each thread submits 16-byte cuda::memcpy_async copies
// into the thread-scoped pipeline passed in by the kernel.
__host__ __device__ void load(
cuda::pipeline<cuda::thread_scope_thread> &pipe) {
const size_t global_idx = offset_.row * ld_ + offset_.col;
constexpr size_t load_bytes = 16;  // 8 halves = one 128-bit copy per thread
#pragma unroll
for (size_t row = 0; row < rows; row += stride_) {
const T *src = global_ptr_ + row * ld_ + global_idx;
T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
cuda::memcpy_async(dst, src, load_bytes, pipe);
}
}
// Synchronous path: 128-bit vectorized load from global, then store to shared memory.
__host__ __device__ void load() {
const size_t global_idx = offset_.row * ld_ + offset_.col;
#pragma unroll
for (size_t row = 0; row < rows; row += stride_) {
const T *src = global_ptr_ + row * ld_ + global_idx;
T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
int4 t = reinterpret_cast<const int4 *>(src)[0];
reinterpret_cast<int4 *>(dst)[0] = t;
}
}
__host__ __device__ void next(size_t stride) {
global_ptr_ = global_ptr_ + stride;
};
};
template <class Threadblock, class Warp, class WMMAblock>
__global__ void WarpCompute(half *__restrict__ A, half *__restrict__ B, half *C,
size_t M, size_t N, size_t K) {
assert(M % Threadblock::kM == 0);
assert(N % Threadblock::kN == 0);
assert(K % Threadblock::kK == 0);
const size_t threadblock_rows = M / Threadblock::kM;
const size_t threadblock_cols = N / Threadblock::kN;
const size_t num_blocks = gridDim.x * gridDim.y * gridDim.z;
assert(threadblock_rows * threadblock_cols == num_blocks);
static_assert(Threadblock::kM >= Warp::kM);
static_assert(Threadblock::kN >= Warp::kN);
static_assert(Threadblock::kK >= Warp::kK);
constexpr size_t warpRows = Threadblock::kM / Warp::kM;
constexpr size_t warpCols = Threadblock::kN / Warp::kN;
const size_t num_warps = blockDim.x * blockDim.y * blockDim.z / warpSize;
assert(warpRows * warpCols == num_warps);
constexpr size_t skew = 8;
__shared__ half As[Threadblock::kM][Threadblock::kK + skew];
__shared__ half Bs[Threadblock::kK][Threadblock::kN + skew];
const size_t total_threadId = threadIdx.x;
const size_t thread_num = blockDim.x;
constexpr size_t n_gld =
CUDA_VECTORIZED_BITS_LOAD / (sizeof(half) * 8);  // elements per 128-bit load (8 halves)
const size_t stride_a = thread_num * n_gld / Threadblock::kK;
const size_t stride_b = thread_num * n_gld / Threadblock::kN;
Loader<half, Threadblock::kM, Threadblock::kK + skew,
Index<Threadblock::kM, Threadblock::kK, n_gld>>
LoaderA{A, As, total_threadId, blockIdx.y * Threadblock::kM * K,
K, stride_a};
Loader<half, Threadblock::kK, Threadblock::kN + skew,
Index<Threadblock::kK, Threadblock::kN, n_gld>>
LoaderB{B, Bs, total_threadId, blockIdx.x * Threadblock::kN,
N, stride_b};
const size_t warp_id = threadIdx.x / warpSize;
const size_t warp_x = warp_id / warpCols;
const size_t warp_y = warp_id % warpRows;
const size_t AsWarpOffset = warp_y * Warp::kM;
const size_t BsWarpOffset = warp_x * Warp::kN;
constexpr size_t threadRows = Warp::kM / WMMAblock::kM;
constexpr size_t threadCols = Warp::kN / WMMAblock::kN;
C = &C[blockIdx.y * Threadblock::kM * N + blockIdx.x * Threadblock::kN +
warp_y * Warp::kM * N + warp_x * Warp::kN];
nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, WMMAblock::kM, WMMAblock::kN,
WMMAblock::kK, half, nvcuda::wmma::row_major>
A_frag[threadRows];
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, WMMAblock::kM, WMMAblock::kN,
WMMAblock::kK, half, nvcuda::wmma::row_major>
B_frag[threadRows];
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMAblock::kM,
WMMAblock::kN, WMMAblock::kK, half>
C_frag[threadRows][threadRows];
for (size_t i = 0; i < threadRows; ++i) {
for (size_t j = 0; j < threadRows; ++j) {
nvcuda::wmma::fill_fragment(C_frag[i][j], 0.0);
}
}
cuda::pipeline<cuda::thread_scope_thread> pipe = cuda::make_pipeline();
for (size_t block = 0; block < K; block += Threadblock::kK) {
// Fast kernel but with many bank conflicts upon store
LoaderA.load();
LoaderB.load();
LoaderA.next(Threadblock::kK);
LoaderB.next(Threadblock::kK * N);
// Slower kernel but with zero bank conflicts
// pipe.producer_acquire();
// LoaderA.load(pipe);
// LoaderB.load(pipe);
// pipe.producer_commit();
// pipe.consumer_wait();
__syncthreads();
#pragma unroll
for (size_t bk = 0; bk < Threadblock::kK; bk += WMMAblock::kK) {
#pragma unroll
for (size_t i = 0; i < threadRows; ++i) {
nvcuda::wmma::load_matrix_sync(
A_frag[i], &As[AsWarpOffset + i * WMMAblock::kM][bk],
Threadblock::kK + skew);
}
#pragma unroll
for (size_t i = 0; i < threadRows; ++i) {
nvcuda::wmma::load_matrix_sync(
B_frag[i], &Bs[bk][BsWarpOffset + i * WMMAblock::kN],
Threadblock::kN + skew);
}
#pragma unroll
for (size_t i = 0; i < threadRows; ++i) {
#pragma unroll
for (size_t j = 0; j < threadRows; j++) {
nvcuda::wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j],
C_frag[i][j]);
}
}
}
__syncthreads();
}
#pragma unroll
for (size_t i = 0; i < threadRows; ++i) {
#pragma unroll
for (size_t j = 0; j < threadCols; ++j) {
nvcuda::wmma::store_matrix_sync(
&C[i * WMMAblock::kM * N + j * WMMAblock::kN], C_frag[i][j], N,
nvcuda::wmma::mem_row_major);
}
}
}
void WarpedSharedUnrollVectorize(half *A, half *B, half *C, size_t M, size_t N,
size_t K, float alpha, float beta) {
using tb = Block<128, 128, 32>;
// CUTLASS commonly uses larger tiles (e.g. 256x128x64) for sm80
using wb = Block<32, 32, 32>;
using wmma_block = Block<16, 16, 16>;
assert(M % tb::kM == 0);
assert(N % tb::kN == 0);
static_assert(tb::kM % wb::kM == 0);
static_assert(tb::kN % wb::kN == 0);
dim3 gdim(M / tb::kM, N / tb::kN, 1);
dim3 bdim(32 * (tb::kM / wb::kM) * (tb::kN / wb::kN), 1, 1);
WarpCompute<tb, wb, wmma_block><<<gdim, bdim>>>(A, B, C, M, N, K);
}
int main() {
using T = half;
const float alpha{1.0}, beta{0.0};
constexpr size_t M{4096};
constexpr size_t N{4096};
constexpr size_t K{4096};
Matrix<T> A(M, K);
Matrix<T> B(K, N);
Matrix<T> D(M, N);
Matrix<T> Dref(M, N);
random_fill_and_sync(A);
random_fill_and_sync(B);
fill_and_sync(D, (T)0.0);
cudaEvent_t beg, end;
float elapsed_time{0};
cudaEventCreate(&beg);
cudaEventCreate(&end);
constexpr size_t n{50};
cudaEventRecord(beg);
for (int j = 0; j < n; ++j) {
WarpedSharedUnrollVectorize(A.device_ptr(), B.device_ptr(),
D.device_ptr(), M, N, K, alpha, beta);
};
gpuErrchk(cudaPeekAtLastError());
gpuErrchk(cudaDeviceSynchronize());
cudaEventRecord(end);
cudaEventSynchronize(end);
cudaEventElapsedTime(&elapsed_time, beg, end);
const double flops{2 * std::pow(double(M), 3) * 1e-9};  // GFLOP per GEMM (assumes M == N == K)
std::cout << "GFLOP per GEMM: " << flops << std::endl;
printf(
"Average elapsed time: (%7.6f) ms, performance: (%7.1f) GFLOPS. size: "
"(%ld).\n",
elapsed_time / n, (n * flops * 1000) / (double)elapsed_time, M);
}
The relevant lines of code are the instructions that load from global to shared memory:
// Fast kernel but with many bank conflicts upon store
LoaderA.load();
LoaderB.load();
LoaderA.next(Threadblock::kK);
LoaderB.next(Threadblock::kK * N);
// Slower kernel but with zero bank conflicts
// pipe.producer_acquire();
// LoaderA.load(pipe);
// LoaderB.load(pipe);
// pipe.producer_commit();
// pipe.consumer_wait();
__syncthreads();
and the two loading functions:
__host__ __device__ void load(
cuda::pipeline<cuda::thread_scope_thread> &pipe) {
const size_t global_idx = offset_.row * ld_ + offset_.col;
constexpr size_t load_bytes = 16;  // 8 halves = one 128-bit copy per thread
#pragma unroll
for (size_t row = 0; row < rows; row += stride_) {
const T *src = global_ptr_ + row * ld_ + global_idx;
T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
cuda::memcpy_async(dst, src, load_bytes, pipe);
}
}
__host__ __device__ void load() {
const size_t global_idx = offset_.row * ld_ + offset_.col;
#pragma unroll
for (size_t row = 0; row < rows; row += stride_) {
const T *src = global_ptr_ + row * ld_ + global_idx;
T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
int4 t = reinterpret_cast<const int4 *>(src)[0];
reinterpret_cast<int4 *>(dst)[0] = t;
}
}
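One variant I considered, but have not benchmarked, is passing an explicitly aligned size to memcpy_async, as (if I remember correctly) the sample does for its float4 copies. A hypothetical extra member of the Loader struct above could look like this, assuming cuda::aligned_size_t is visible through <cuda/pipeline> (otherwise <cuda/barrier> would be needed) and that both pointers really are 16-byte aligned:

__host__ __device__ void load_aligned(
    cuda::pipeline<cuda::thread_scope_thread> &pipe) {
  const size_t global_idx = offset_.row * ld_ + offset_.col;
#pragma unroll
  for (size_t row = 0; row < rows; row += stride_) {
    const T *src = global_ptr_ + row * ld_ + global_idx;
    T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
    // Hypothetical: aligned_size_t asserts 16-byte alignment of src/dst and a
    // size that is a multiple of 16, which should allow the 16-byte cp.async path.
    cuda::memcpy_async(dst, src, cuda::aligned_size_t<16>(16), pipe);
  }
}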
My questions are:
- Am I using the cuda::pipeline correctly? The CUDA samples use different variants of the async copies, and I am not sure which one is the most appropriate for my use case; a sketch of the single-stage pattern I think I am following is included below.
- Why is the kernel throughput not improved despite eliminating bank conflicts?
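For reference, this is the single-stage structure I believe I am reproducing. It is only a sketch of my understanding of the cuda::pipeline API, not code taken from the samples; in particular I am unsure whether the consumer_release() call, which my kernel above never makes, matters:

cuda::pipeline<cuda::thread_scope_thread> pipe = cuda::make_pipeline();
for (size_t block = 0; block < K; block += Threadblock::kK) {
  pipe.producer_acquire();            // acquire a pipeline stage
  LoaderA.load(pipe);                 // issue cuda::memcpy_async copies into As
  LoaderB.load(pipe);                 // ... and into Bs
  pipe.producer_commit();             // submit the batch of async copies
  pipe.consumer_wait();               // wait for this thread's copies to land in shared memory
  __syncthreads();                    // make As/Bs visible to the whole block
  // ... WMMA compute on As/Bs as in the kernel above ...
  pipe.consumer_release();            // release the stage (missing in my kernel)
  __syncthreads();                    // don't overwrite As/Bs while other warps still compute
  LoaderA.next(Threadblock::kK);
  LoaderB.next(Threadblock::kK * N);
}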