Hi
I’m trying out a simple tensor core matrix multiplication example. Here are my normal matrix product implementation and one using tensor cores:
// Reference (non-tensor-core) tile product: every thread produces exactly one
// entry of tileC, selected by (threadIdx.y, threadIdx.x).
//
//   K          number of half2 elements walked along the reduction dimension
//              (each half2 contributes two multiply-adds)
//   tileA/LDA  packed half2 operand A and the leading dimension passed to IDX2D
//   tileB/LDB  packed half2 operand B and the leading dimension passed to IDX2D;
//              it is indexed by threadIdx.x first, i.e. read as if transposed
//   tileC/LDC  float output tile and its leading dimension
//
// NOTE(review): the meaning of IDX2D(i, j, ld) is defined outside this snippet;
// presumably row-major (i * ld + j) -- confirm against its definition.
void block_gemm(const int K,
                const half2* tileA, const int LDA,
                const half2* tileB, const int LDB,
                float* tileC, const int LDC)
{
    const int row = threadIdx.y;
    const int col = threadIdx.x;

    // Products are accumulated in fp32; the two lanes of each half2 are added
    // one after the other so the summation order stays fixed.
    float sum = 0.0f;
#pragma unroll
    for (int kk = 0; kk < K; ++kk) {
        const float2 a = __half22float2(tileA[IDX2D(row, kk, LDA)]);
        const float2 b = __half22float2(tileB[IDX2D(col, kk, LDB)]);
        sum += a.x * b.x;
        sum += a.y * b.y;
    }

    tileC[IDX2D(row, col, LDC)] = sum;
}
// Tensor-core tile product using the WMMA API: the thread block's output tile
// is split into 16x16 sub-tiles and each participating warp computes one of
// them, accumulating in fp32.
//
//   K          reduction length in half elements; consumed in steps of 16
//              (NOTE(review): the loop assumes K is a multiple of 16 -- a
//              trailing partial step would read past column K; confirm callers)
//   tileA/LDA  half operand A, loaded as a row-major matrix_a fragment
//   tileB/LDB  half operand B, loaded as a col-major matrix_b fragment
//   tileC/LDC  float output tile, stored row-major
//
// Warps whose index is >= NUMBLOCK_M * NUMBLOCK_N do no work. All threads of a
// participating warp take the same branch, as the *_sync WMMA calls require.
void block_gemm_wmma(const int K,
                     const half* tileA, const int LDA,
                     const half* tileB, const int LDB,
                     float* tileC, const int LDC)
{
    using namespace nvcuda;

    const int tid_x = threadIdx.x;
    const int tid_y = threadIdx.y;
    // Linear warp index inside the thread block.
    const int wid = (tid_y * blockDim.x + tid_x) / WARPSIZE;

    constexpr int WMMA_M = 16;
    constexpr int WMMA_N = 16;
    constexpr int WMMA_K = 16;

    // Number of 16x16 output sub-tiles along each dimension of the block tile.
    constexpr int NUMBLOCK_M = BLOCKDIM_Y / WMMA_M;
    constexpr int NUMBLOCK_N = BLOCKDIM_X / WMMA_N;

    if (wid < NUMBLOCK_M * NUMBLOCK_N) {
        // Enumerate the NUMBLOCK_M x NUMBLOCK_N sub-tiles row by row. Quotient
        // and remainder must use the SAME divisor (NUMBLOCK_N); dividing by
        // NUMBLOCK_M, as before, only covered every sub-tile when the grid was
        // square and left rows uncomputed otherwise.
        const int warpM = (wid / NUMBLOCK_N) * WMMA_M;
        const int warpN = (wid % NUMBLOCK_N) * WMMA_N;

        wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b; // col major as b is stored in transposed form in shared memory
        wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c;

        wmma::fill_fragment(c, 0.0f);

        // March along the reduction dimension one 16-wide slab at a time,
        // accumulating into the same fragment.
        for (int k = 0; k < K; k += WMMA_K) {
            wmma::load_matrix_sync(a, &tileA[IDX2D(warpM, k, LDA)], LDA);
            wmma::load_matrix_sync(b, &tileB[IDX2D(warpN, k, LDB)], LDB);
            wmma::mma_sync(c, a, b, c);
        }

        wmma::store_matrix_sync(&tileC[IDX2D(warpM, warpN, LDC)], c, LDC, wmma::mem_row_major);
    }
}
The input matrices are half-precision 32xK matrices in shared memory, with entries in the range [0, 1]. When I run these two matrix products and compare the results index by index, I get differences on the order of 1e-5. Is this to be expected, or is there a mistake somewhere in my implementation?