Differences in Precision Between Tensor Cores and CUDA Cores

I am learning how to do in-depth algorithm optimization using the Tensor Core WMMA API (`wmma`).

I encountered an issue while using the API. Below is a GEMM example in which matrices A and B are FP16 and the result matrix C is FP32 (float); all of them are row-major.

As you can see, the Tensor Core code below is taken from the official cuda-samples (I only replaced the column-major input with row-major input), and the results computed by the Tensor Cores and by the CUDA cores differ by an error greater than 1e-3.

Is this an issue with my code, or is this expected behavior?

[Image]
These are the results from the naive CUDA GEMM kernel, the optimized CUDA kernel, and the Tensor Core wmma API.
You can see that the first two results are consistent as expected, but the final Tensor Core result has a significant discrepancy compared to the first two.

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <mma.h>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include <chrono>
#include <cstdio>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

using namespace nvcuda;

// Fills arr[0 .. n) with random half-precision values.
// Each value is drawn as a float from a uniform distribution over
// [-128, 127) using engine `e`, then converted with __float2half.
void fill_random_half_values(__half* arr, size_t n, std::default_random_engine& e)
{
    std::uniform_real_distribution<float> sampler(-128.0f, 127.0f);
    __half* const last = arr + n;
    for (__half* out = arr; out != last; ++out) {
        *out = __float2half(sampler(e));
    }
}

// Naive GEMM: one thread computes one element of C = A * B.
//   dA is indexed as dA[row * K + k], dB as dB[k * N + col],
//   dC as dC[row * N + col].
//   The x dimension of the launch selects the row, the y dimension the column.
// Each product is formed with T's operator* and only afterwards converted to
// float by __half2float; the running sum itself is kept in float.
template<typename T>
__global__ void __launch_bounds__(1024) gemm_naive(const T *__restrict__ dA, const T *__restrict__ dB, float *__restrict__ dC, int M, int K, int N)
{
    const int out_row = blockIdx.x * blockDim.x + threadIdx.x;
    const int out_col = blockIdx.y * blockDim.y + threadIdx.y;

    // Threads mapped outside the M x N output have nothing to compute.
    if (out_row >= M || out_col >= N)
    {
        return;
    }

    float acc = 0.0f;
    for (int k = 0; k < K; ++k)
    {
        acc += __half2float(dA[out_row * K + k] * dB[k * N + out_col]);
    }
    dC[out_row * N + out_col] = acc;
}

// Shared-memory tiled GEMM: c = a * b, with a indexed as a[row * K + k],
// b as b[k * N + col] and c as c[row * N + col].
// The x dimension of the launch selects the column, the y dimension the row.
// Tile offsets and shared-memory indices are derived from TILE_SIZE and
// threadIdx, so the kernel assumes blockDim == (TILE_SIZE, TILE_SIZE).
// Products are formed with T's operator* and then converted to float via
// __half2float; the accumulator is float.
template<typename T>
__global__ void __launch_bounds__(1024) gemm_CUDA(float *__restrict__ c, const T *__restrict__ a, const T *__restrict__ b, int M, int N, int K) {

    const int TILE_SIZE = 16;

    const int lx = threadIdx.x;
    const int ly = threadIdx.y;

    const int out_col = blockIdx.x * TILE_SIZE + lx;
    const int out_row = blockIdx.y * TILE_SIZE + ly;

    __shared__ T tileA[TILE_SIZE][TILE_SIZE];
    __shared__ T tileB[TILE_SIZE][TILE_SIZE];

    const int num_tiles = (K + TILE_SIZE - 1) / TILE_SIZE;

    float acc = 0.0f;
    for (int t = 0; t < num_tiles; ++t) {
        // k index this thread fetches for the A tile (along x) and the B tile (along y).
        const int a_k = t * TILE_SIZE + lx;
        const int b_k = t * TILE_SIZE + ly;

        // Out-of-range positions are zero-filled so they contribute nothing to the sum.
        if (out_row >= M || a_k >= K) {
            tileA[ly][lx] = 0;
        } else {
            tileA[ly][lx] = a[out_row * K + a_k];
        }

        if (out_col >= N || b_k >= K) {
            tileB[ly][lx] = 0;
        } else {
            tileB[ly][lx] = b[b_k * N + out_col];
        }

        // Both tiles must be fully written before any thread reads them.
        __syncthreads();

        for (int kk = 0; kk < TILE_SIZE; ++kk) {
            acc += __half2float(tileA[ly][kk] * tileB[kk][lx]);
        }

        // Keep the tiles intact until every thread has finished reading them.
        __syncthreads();
    }

    if (out_row < M && out_col < N) {
        c[out_row * N + out_col] = acc;
    }
}

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

#define WARP_SIZE 32

// Returns a / b, plus one whenever the division leaves a remainder.
__host__ __device__ int div_ceil(int a, int b) {
    const int quotient = a / b;
    return (a % b == 0) ? quotient : quotient + 1;
}

// WMMA GEMM: each thread block produces one WMMA_M x WMMA_N tile of C = A * B.
//   A is read with leading dimension K, B with leading dimension N (both
//   row_major fragments); C is stored row-major with leading dimension N.
//   blockIdx.y selects the tile row, blockIdx.x the tile column.
// Inputs are half fragments; the accumulator fragment is float.
// NOTE(review): only the tile origin is bounds-checked, so this looks like it
// expects M, N and K to be multiples of 16 - confirm before using other sizes.
__global__ void wmmaNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, float *__restrict__ C, size_t M, size_t N, size_t K) {
    const size_t tile_row = blockIdx.y * WMMA_M;
    const size_t tile_col = blockIdx.x * WMMA_N;

    // Blocks whose tile origin lies outside C do nothing.
    if (tile_row >= M || tile_col >= N) {
        return;
    }

    const size_t num_k_tiles = div_ceil(K, WMMA_K);

    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;

    wmma::fill_fragment(acc_frag, 0.0f);

    // Walk along K one WMMA_K-wide slab at a time, accumulating into acc_frag.
#pragma unroll
    for (size_t kt = 0; kt < num_k_tiles; ++kt) {
        wmma::load_matrix_sync(a_frag, A + tile_row * K + kt * WMMA_K, K);
        wmma::load_matrix_sync(b_frag, B + tile_col + kt * WMMA_K * N, N);

        wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }

    wmma::store_matrix_sync(C + tile_row * N + tile_col, acc_frag, N, wmma::mem_row_major);
}


// Host driver: fills A and B with random FP16 values, runs the three GEMM
// variants (gemm_naive, gemm_CUDA, wmmaNaiveKernel) on the same inputs, times
// each with CUDA events, and compares the second and third results against
// the naive one. Returns 1 if a CUDA launch/execution error is detected,
// otherwise 0.
int main() {

    int M = 128;
    int K = 128;
    int N = 128;

    std::cout << "Matrix Sizes" << std::endl;
    std::cout << "M: " << M << std::endl;
    std::cout << "N: " << N << std::endl;
    std::cout << "K: " << K << std::endl;

    // Clock-based seed: the inputs, and therefore the results, change on every run.
    unsigned seed = static_cast<unsigned>(std::chrono::system_clock::now().time_since_epoch().count());
    std::default_random_engine random_engine(seed);

    thrust::host_vector<__half> h_a_vec(M*K);
    thrust::host_vector<__half> h_b_vec(K*N);

    fill_random_half_values(h_a_vec.data(), h_a_vec.size(), random_engine);
    fill_random_half_values(h_b_vec.data(), h_b_vec.size(), random_engine);

    thrust::device_vector<__half> d_a_vec = h_a_vec;
    thrust::device_vector<__half> d_b_vec = h_b_vec;
    thrust::device_vector<float> d_c_vec(M*N);

    dim3 threadNum(16, 16);
    // gemm_naive maps x -> row and y -> column, so its grid covers M along x.
    dim3 naiveGrid((M + threadNum.x - 1)/threadNum.x, (N + threadNum.y - 1)/threadNum.y);
    // gemm_CUDA maps x -> column and y -> row, so its grid must cover N along x.
    // (Identical to naiveGrid when M == N, but required for non-square sizes.)
    dim3 tiledGrid((N + threadNum.x - 1)/threadNum.x, (M + threadNum.y - 1)/threadNum.y);

    // Kernel launches return no status; errors are only visible through
    // cudaGetLastError() and the next synchronizing call.
    auto cuda_ok = [](cudaError_t err, const char* what) -> bool {
        if (err == cudaSuccess) {
            return true;
        }
        printf("CUDA error (%s): %s\n", what, cudaGetErrorString(err));
        return false;
    };

    cudaEvent_t cuda_start, cuda_end;
    if (!cuda_ok(cudaEventCreate(&cuda_start), "cudaEventCreate")) return 1;
    if (!cuda_ok(cudaEventCreate(&cuda_end), "cudaEventCreate")) return 1;

    const int numIterations = 1;
    float naive_totalTime = 0.0f;

    // 1. CUDA NAIVE
    for (int i = 0; i < numIterations; ++i) {
        cudaEventRecord(cuda_start, 0);

        gemm_naive<__half><<<naiveGrid, threadNum>>>(d_a_vec.data().get(), d_b_vec.data().get(), d_c_vec.data().get(), M, K, N);
        if (!cuda_ok(cudaGetLastError(), "gemm_naive launch")) return 1;

        cudaEventRecord(cuda_end, 0);
        if (!cuda_ok(cudaEventSynchronize(cuda_end), "gemm_naive execution")) return 1;

        float ms = 0.0f;
        cudaEventElapsedTime(&ms, cuda_start, cuda_end);

        naive_totalTime += ms;
    }

    thrust::host_vector<float> h_naive_c_vec = d_c_vec;

    // 2. CUDA
    float v2_totalTime = 0.0f;
    for (int i = 0; i < numIterations; ++i) {
        cudaEventRecord(cuda_start, 0);

        gemm_CUDA<__half><<<tiledGrid, threadNum>>>(d_c_vec.data().get(), d_a_vec.data().get(), d_b_vec.data().get(), M, N, K);
        if (!cuda_ok(cudaGetLastError(), "gemm_CUDA launch")) return 1;

        cudaEventRecord(cuda_end, 0);
        if (!cuda_ok(cudaEventSynchronize(cuda_end), "gemm_CUDA execution")) return 1;

        float ms = 0.0f;
        cudaEventElapsedTime(&ms, cuda_start, cuda_end);

        v2_totalTime += ms;
    }

    thrust::host_vector<float> h_c_vec = d_c_vec;

    // 3. Tensor Core
    // One warp per block; each block produces one WMMA_M x WMMA_N tile of C.
    dim3 block(WARP_SIZE);
    dim3 grid(div_ceil(N, WMMA_N), div_ceil(M, WMMA_M));

    float v3_totalTime = 0.0f;
    for (int i = 0; i < numIterations; ++i) {
        cudaEventRecord(cuda_start, 0);

        wmmaNaiveKernel<<<grid, block>>>(d_a_vec.data().get(), d_b_vec.data().get(), d_c_vec.data().get(), M, N, K);
        if (!cuda_ok(cudaGetLastError(), "wmmaNaiveKernel launch")) return 1;

        cudaEventRecord(cuda_end, 0);
        if (!cuda_ok(cudaEventSynchronize(cuda_end), "wmmaNaiveKernel execution")) return 1;

        float ms = 0.0f;
        cudaEventElapsedTime(&ms, cuda_start, cuda_end);

        v3_totalTime += ms;
    }

    thrust::host_vector<float> h_c_vec_v3 = d_c_vec;

    // compare
    // Relative error against the naive result. The denominator is the
    // magnitude of the reference: dividing by the signed value would make
    // the error negative for negative references, so it could never exceed
    // the threshold. A small floor avoids dividing by zero.
    auto rel_err = [](float ref, float val) -> float {
        const float min_denom = 1e-6f;
        const float mag = fabsf(ref);
        return fabsf(ref - val) / (mag > min_denom ? mag : min_denom);
    };

    const float tolerance = 1e-3f;
    bool mismatch = false;
    float cuda_error = 0.0f;
    float tensor_error = 0.0f;

    // Stop at the first element whose error exceeds the tolerance.
    for (int r = 0; r < M && !mismatch; ++r) {
        for (int c = 0; c < N && !mismatch; ++c) {
            const float naive_res = h_naive_c_vec[r * N + c];
            const float cuda_res = h_c_vec[r * N + c];
            const float tensor_res = h_c_vec_v3[r * N + c];

            const float err_cuda = rel_err(naive_res, cuda_res);
            const float err_tensor = rel_err(naive_res, tensor_res);

            cuda_error += err_cuda;
            tensor_error += err_tensor;

            if (err_cuda > tolerance || err_tensor > tolerance) {
                printf("(%f, %f, %f)\n", naive_res, cuda_res, tensor_res);
                printf("Failed: cuda, tensor: %f, %f\n", err_cuda, err_tensor);
                mismatch = true;
            }
        }
    }

    // Timings and mean relative errors are only reported when every element passed.
    if (!mismatch) {
        std::cout << "NAIVE execution time: " << naive_totalTime/numIterations << " ms" << std::endl;
        std::cout << "V2 execution time: " << v2_totalTime/numIterations << " ms, error: " << cuda_error/(M*N) << std::endl;
        std::cout << "V3 execution time: " << v3_totalTime/numIterations << " ms, error: " << tensor_error/(M*N) << std::endl;
    }

    cudaEventDestroy(cuda_start);
    cudaEventDestroy(cuda_end);

    return 0;
}

The tensor core result is the more/most accurate of the 3.

The FP16 tensor core (TC) path, the way you are using it (with a 32-bit fragment for C), does a specialized kind of 16-bit multiplication that results in a 32-bit product (and it accumulates into a 32-bit accumulator). In the other two paths, you are doing a 16-bit multiplication into a 16-bit result, then converting that 16-bit result to a 32-bit value, and then accumulating that in a 32-bit accumulator.

These two methodologies are not identical, numerically.

To see some evidence of this, instead of doing this:

        tmp += __half2float(dA[row * K + s] * dB[s * N + col]);

do this:

        tmp += __half2float(dA[row * K + s]) * __half2float(dB[s * N + col]);

This of course is also not identical to what TC is doing either, but is quite a bit closer numerically. You should be able to make a similar transformation to the shared case, if you wish.

If you do that, I think you will now see a “smaller” difference between the “CUDA core” result(s) and the TC result.

Alternatively, I think everything should match also if you restrict yourself entirely to FP16, although I haven’t tried that. It’s important to keep in mind the limited range of FP16 and the limited mantissa of FP16 when doing work into FP16 results, as well.

The use of random numbers for this, as well as the limited range of FP16 and your particular error calculation, also combine to make things a little harder to pin down. I personally wouldn’t spend much time trying to explain a code where you have intentionally designed it to produce a different result every time you run it. That might be viewed as clever, or appropriate for certain kinds of production work, but it is certainly not conducive to careful analysis.

One of the challenges of your random distribution centered at zero is that it allows for fairly large magnitude differences in the result. A large negative value can “cancel” with a large positive value, leaving a relatively small number. The residual or accumulated error then becomes a much larger fraction of the result than it was or would have been for either the large positive or large negative values before cancellation.

Because of that, even with my changes, you will still sometimes get results that show a larger relative error, enough to hit your threshold. If we change the random distribution to have one end at zero (e.g. be entirely nonnegative) then making the change I suggested to both of the non TC paths makes a result that consistently is less than your particular error threshold.

# cat t329.cu
#include <cuda_runtime.h>
#include <mma.h>
using namespace nvcuda;
#include <iostream>
#include <random>
#include <utility>
#include <vector>
#include <cuda_fp16.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

// Fills arr[0 .. n) with random half-precision values.
// Each value is drawn as a float from a uniform distribution using engine
// `e`, then converted with __float2half. The range is [0, 127) when USE_FIX
// is defined and [-128, 127) otherwise.
void fill_random_half_values(__half* arr, size_t n, std::default_random_engine& e)
{
#ifdef USE_FIX
    std::uniform_real_distribution<float> sampler(0.f, 127.0f);
#else
    std::uniform_real_distribution<float> sampler(-128.0f, 127.0f);
#endif
    __half* const last = arr + n;
    for (__half* out = arr; out != last; ++out) {
        *out = __float2half(sampler(e));
    }
}

// Naive GEMM: one thread computes one element of C = A * B.
//   dA is indexed as dA[row * K + k], dB as dB[k * N + col],
//   dC as dC[row * N + col].
//   The x dimension of the launch selects the row, the y dimension the column.
// With USE_FIX defined, both operands are converted to float before the
// multiply; otherwise the product is formed with T's operator* and converted
// to float afterwards. The running sum is float in both cases.
template<typename T>
__global__ void __launch_bounds__(1024) gemm_naive(const T *__restrict__ dA, const T *__restrict__ dB, float *__restrict__ dC, int M, int K, int N)
{
    const int out_row = blockIdx.x * blockDim.x + threadIdx.x;
    const int out_col = blockIdx.y * blockDim.y + threadIdx.y;

    // Threads mapped outside the M x N output have nothing to compute.
    if (out_row >= M || out_col >= N)
    {
        return;
    }

    float acc = 0.0f;
    for (int k = 0; k < K; ++k)
    {
#ifdef USE_FIX
        acc += __half2float(dA[out_row * K + k]) * __half2float(dB[k * N + out_col]);
#else
        acc += __half2float(dA[out_row * K + k] * dB[k * N + out_col]);
#endif
    }
    dC[out_row * N + out_col] = acc;
}

// Shared-memory tiled GEMM: c = a * b, with a indexed as a[row * K + k],
// b as b[k * N + col] and c as c[row * N + col].
// The x dimension of the launch selects the column, the y dimension the row.
// Tile offsets and shared-memory indices are derived from TILE_SIZE and
// threadIdx, so the kernel assumes blockDim == (TILE_SIZE, TILE_SIZE).
// With USE_FIX defined, both operands are converted to float before the
// multiply; otherwise the product is formed with T's operator* and converted
// to float afterwards. The accumulator is float in both cases.
template<typename T>
__global__ void __launch_bounds__(1024) gemm_CUDA(float *__restrict__ c, const T *__restrict__ a, const T *__restrict__ b, int M, int N, int K) {

    const int TILE_SIZE = 16;

    const int lx = threadIdx.x;
    const int ly = threadIdx.y;

    const int out_col = blockIdx.x * TILE_SIZE + lx;
    const int out_row = blockIdx.y * TILE_SIZE + ly;

    __shared__ T tileA[TILE_SIZE][TILE_SIZE];
    __shared__ T tileB[TILE_SIZE][TILE_SIZE];

    const int num_tiles = (K + TILE_SIZE - 1) / TILE_SIZE;

    float acc = 0.0f;
    for (int t = 0; t < num_tiles; ++t) {
        // k index this thread fetches for the A tile (along x) and the B tile (along y).
        const int a_k = t * TILE_SIZE + lx;
        const int b_k = t * TILE_SIZE + ly;

        // Out-of-range positions are zero-filled so they contribute nothing to the sum.
        if (out_row >= M || a_k >= K) {
            tileA[ly][lx] = 0;
        } else {
            tileA[ly][lx] = a[out_row * K + a_k];
        }

        if (out_col >= N || b_k >= K) {
            tileB[ly][lx] = 0;
        } else {
            tileB[ly][lx] = b[b_k * N + out_col];
        }

        // Both tiles must be fully written before any thread reads them.
        __syncthreads();

        for (int kk = 0; kk < TILE_SIZE; ++kk) {
#ifdef USE_FIX
            acc += __half2float(tileA[ly][kk]) * __half2float(tileB[kk][lx]);
#else
            acc += __half2float(tileA[ly][kk] * tileB[kk][lx]);
#endif
        }

        // Keep the tiles intact until every thread has finished reading them.
        __syncthreads();
    }

    if (out_row < M && out_col < N) {
        c[out_row * N + out_col] = acc;
    }
}

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

#define WARP_SIZE 32

// Returns a / b, plus one whenever the division leaves a remainder.
__host__ __device__ int div_ceil(int a, int b) {
    const int quotient = a / b;
    return (a % b == 0) ? quotient : quotient + 1;
}

// WMMA GEMM: each thread block produces one WMMA_M x WMMA_N tile of C = A * B.
//   A is read with leading dimension K, B with leading dimension N (both
//   row_major fragments); C is stored row-major with leading dimension N.
//   blockIdx.y selects the tile row, blockIdx.x the tile column.
// Inputs are half fragments; the accumulator fragment is float.
// NOTE(review): only the tile origin is bounds-checked, so this looks like it
// expects M, N and K to be multiples of 16 - confirm before using other sizes.
__global__ void wmmaNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, float *__restrict__ C, size_t M, size_t N, size_t K) {
    const size_t tile_row = blockIdx.y * WMMA_M;
    const size_t tile_col = blockIdx.x * WMMA_N;

    // Blocks whose tile origin lies outside C do nothing.
    if (tile_row >= M || tile_col >= N) {
        return;
    }

    const size_t num_k_tiles = div_ceil(K, WMMA_K);

    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;

    wmma::fill_fragment(acc_frag, 0.0f);

    // Walk along K one WMMA_K-wide slab at a time, accumulating into acc_frag.
#pragma unroll
    for (size_t kt = 0; kt < num_k_tiles; ++kt) {
        wmma::load_matrix_sync(a_frag, A + tile_row * K + kt * WMMA_K, K);
        wmma::load_matrix_sync(b_frag, B + tile_col + kt * WMMA_K * N, N);

        wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }

    wmma::store_matrix_sync(C + tile_row * N + tile_col, acc_frag, N, wmma::mem_row_major);
}


// Host driver: fills A and B with random FP16 values, runs the three GEMM
// variants (gemm_naive, gemm_CUDA, wmmaNaiveKernel) on the same inputs, times
// each with CUDA events, and compares the second and third results against
// the naive one. Returns 1 if a CUDA launch/execution error is detected,
// otherwise 0.
int main() {

    int M = 128;
    int K = 128;
    int N = 128;

    std::cout << "Matrix Sizes" << std::endl;
    std::cout << "M: " << M << std::endl;
    std::cout << "N: " << N << std::endl;
    std::cout << "K: " << K << std::endl;

    // Clock-based seed: the inputs, and therefore the results, change on every run.
    // NOTE(review): std::chrono is declared in <chrono>, which this listing does
    // not include directly - it presumably arrives through another header.
    unsigned seed = static_cast<unsigned>(std::chrono::system_clock::now().time_since_epoch().count());
    std::default_random_engine random_engine(seed);

    thrust::host_vector<__half> h_a_vec(M*K);
    thrust::host_vector<__half> h_b_vec(K*N);

    fill_random_half_values(h_a_vec.data(), h_a_vec.size(), random_engine);
    fill_random_half_values(h_b_vec.data(), h_b_vec.size(), random_engine);

    thrust::device_vector<__half> d_a_vec = h_a_vec;
    thrust::device_vector<__half> d_b_vec = h_b_vec;
    thrust::device_vector<float> d_c_vec(M*N);

    dim3 threadNum(16, 16);
    // gemm_naive maps x -> row and y -> column, so its grid covers M along x.
    dim3 naiveGrid((M + threadNum.x - 1)/threadNum.x, (N + threadNum.y - 1)/threadNum.y);
    // gemm_CUDA maps x -> column and y -> row, so its grid must cover N along x.
    // (Identical to naiveGrid when M == N, but required for non-square sizes.)
    dim3 tiledGrid((N + threadNum.x - 1)/threadNum.x, (M + threadNum.y - 1)/threadNum.y);

    // Kernel launches return no status; errors are only visible through
    // cudaGetLastError() and the next synchronizing call.
    auto cuda_ok = [](cudaError_t err, const char* what) -> bool {
        if (err == cudaSuccess) {
            return true;
        }
        printf("CUDA error (%s): %s\n", what, cudaGetErrorString(err));
        return false;
    };

    cudaEvent_t cuda_start, cuda_end;
    if (!cuda_ok(cudaEventCreate(&cuda_start), "cudaEventCreate")) return 1;
    if (!cuda_ok(cudaEventCreate(&cuda_end), "cudaEventCreate")) return 1;

    const int numIterations = 1;
    float naive_totalTime = 0.0f;

    // 1. CUDA NAIVE
    for (int i = 0; i < numIterations; ++i) {
        cudaEventRecord(cuda_start, 0);

        gemm_naive<__half><<<naiveGrid, threadNum>>>(d_a_vec.data().get(), d_b_vec.data().get(), d_c_vec.data().get(), M, K, N);
        if (!cuda_ok(cudaGetLastError(), "gemm_naive launch")) return 1;

        cudaEventRecord(cuda_end, 0);
        if (!cuda_ok(cudaEventSynchronize(cuda_end), "gemm_naive execution")) return 1;

        float ms = 0.0f;
        cudaEventElapsedTime(&ms, cuda_start, cuda_end);

        naive_totalTime += ms;
    }

    thrust::host_vector<float> h_naive_c_vec = d_c_vec;

    // 2. CUDA
    float v2_totalTime = 0.0f;
    for (int i = 0; i < numIterations; ++i) {
        cudaEventRecord(cuda_start, 0);

        gemm_CUDA<__half><<<tiledGrid, threadNum>>>(d_c_vec.data().get(), d_a_vec.data().get(), d_b_vec.data().get(), M, N, K);
        if (!cuda_ok(cudaGetLastError(), "gemm_CUDA launch")) return 1;

        cudaEventRecord(cuda_end, 0);
        if (!cuda_ok(cudaEventSynchronize(cuda_end), "gemm_CUDA execution")) return 1;

        float ms = 0.0f;
        cudaEventElapsedTime(&ms, cuda_start, cuda_end);

        v2_totalTime += ms;
    }

    thrust::host_vector<float> h_c_vec = d_c_vec;

    // 3. Tensor Core
    // One warp per block; each block produces one WMMA_M x WMMA_N tile of C.
    dim3 block(WARP_SIZE);
    dim3 grid(div_ceil(N, WMMA_N), div_ceil(M, WMMA_M));

    float v3_totalTime = 0.0f;
    for (int i = 0; i < numIterations; ++i) {
        cudaEventRecord(cuda_start, 0);

        wmmaNaiveKernel<<<grid, block>>>(d_a_vec.data().get(), d_b_vec.data().get(), d_c_vec.data().get(), M, N, K);
        if (!cuda_ok(cudaGetLastError(), "wmmaNaiveKernel launch")) return 1;

        cudaEventRecord(cuda_end, 0);
        if (!cuda_ok(cudaEventSynchronize(cuda_end), "wmmaNaiveKernel execution")) return 1;

        float ms = 0.0f;
        cudaEventElapsedTime(&ms, cuda_start, cuda_end);

        v3_totalTime += ms;
    }

    thrust::host_vector<float> h_c_vec_v3 = d_c_vec;

    // compare
    // Relative error against the naive result. The denominator is the
    // magnitude of the reference: dividing by the signed value would make
    // the error negative for negative references, so it could never exceed
    // the threshold. A small floor avoids dividing by zero.
    auto rel_err = [](float ref, float val) -> float {
        const float min_denom = 1e-6f;
        const float mag = fabsf(ref);
        return fabsf(ref - val) / (mag > min_denom ? mag : min_denom);
    };

    const float tolerance = 1e-3f;
    bool mismatch = false;
    float cuda_error = 0.0f;
    float tensor_error = 0.0f;

    // Stop at the first element whose error exceeds the tolerance.
    for (int r = 0; r < M && !mismatch; ++r) {
        for (int c = 0; c < N && !mismatch; ++c) {
            const float naive_res = h_naive_c_vec[r * N + c];
            const float cuda_res = h_c_vec[r * N + c];
            const float tensor_res = h_c_vec_v3[r * N + c];

            const float err_cuda = rel_err(naive_res, cuda_res);
            const float err_tensor = rel_err(naive_res, tensor_res);

            cuda_error += err_cuda;
            tensor_error += err_tensor;

            if (err_cuda > tolerance || err_tensor > tolerance) {
                printf("(%f, %f, %f)\n", naive_res, cuda_res, tensor_res);
                printf("Failed: cuda, tensor: %f, %f\n", err_cuda, err_tensor);
                mismatch = true;
            }
        }
    }

    // Timings and mean relative errors are only reported when every element passed.
    if (!mismatch) {
        std::cout << "NAIVE execution time: " << naive_totalTime/numIterations << " ms" << std::endl;
        std::cout << "V2 execution time: " << v2_totalTime/numIterations << " ms, error: " << cuda_error/(M*N) << std::endl;
        std::cout << "V3 execution time: " << v3_totalTime/numIterations << " ms, error: " << tensor_error/(M*N) << std::endl;
    }

    cudaEventDestroy(cuda_start);
    cudaEventDestroy(cuda_end);

    return 0;
}
# nvcc -o t329 t329.cu -arch=sm_89 -DUSE_FIX
# ./t329
Matrix Sizes
M: 128
N: 128
K: 128
NAIVE execution time: 0.042688 ms
V2 execution time: 0.033024 ms, error: 0
V3 execution time: 0.03136 ms, error: 1.05642e-06
# ./t329
Matrix Sizes
M: 128
N: 128
K: 128
NAIVE execution time: 0.041984 ms
V2 execution time: 0.031904 ms, error: 0
V3 execution time: 0.032864 ms, error: 1.05686e-06
# ./t329
Matrix Sizes
M: 128
N: 128
K: 128
NAIVE execution time: 0.039936 ms
V2 execution time: 0.03232 ms, error: 0
V3 execution time: 0.031232 ms, error: 1.05867e-06
#

Regarding my earlier statement:

"The tensor core result is the more/most accurate of the 3. "

that applies to your original code. I’m not making that statement relative to my modifications.

1 Like