Wrong answer with mma.sync.aligned.m8n8k4

I tried to write a simple matrix multiplication on a V100 that uses the Tensor Cores, but it gives the wrong answer.
I used the mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 instruction; here is my code:

#include <cuda_fp16.h>
#include <iostream>
#include <mma.h>
using namespace nvcuda;

__global__ void mma_test(half* output)
{
    int lane = threadIdx.x % 32;
    uint out[4] = { 0 };  // 8 fp16 accumulators, packed two per 32-bit register

    for (int i = 0; i < 10000; i++) {
        uint MultiA[2] = { 0 };  // 4 fp16 values of this thread's A fragment
        uint MultiB[2] = { 0 };  // 4 fp16 values of this thread's B fragment

        half* test1 = reinterpret_cast<half*>(MultiA);
        half* test2 = reinterpret_cast<half*>(MultiB);
        test1[0] = 0.8;
        test1[1] = 0.8;
        test1[2] = 0.8;
        test1[3] = 0.8;
        test2[0] = 0.7;
        test2[1] = 0.7;
        test2[2] = 0.7;
        test2[3] = 0.7;

        // D = A * B + C, with all operands and the accumulation in fp16
        asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 "
                     "{ %0, %1, %2, %3 },"
                     "{ %4, %5 },"
                     "{ %6, %7 },"
                     "{ %8, %9, %10, %11 };\n"
                     : "=r"(out[0]), "=r"(out[1]), "=r"(out[2]), "=r"(out[3])
                     : "r"(MultiA[0]), "r"(MultiA[1]),
                     "r"(MultiB[0]), "r"(MultiB[1]),
                     "r"(out[0]), "r"(out[1]), "r"(out[2]), "r"(out[3]));
    }
    // Each thread stores its 8 fp16 results (one 16-byte uint4) to global memory
    int store_row = lane % 4 + lane / 16 * 4;
    int store_col = (lane % 16) / 4;
    reinterpret_cast<uint4*>(output)[store_row * 4 + store_col] = reinterpret_cast<uint4*>(out)[0];
}

int main()
{
    half* output = (half*)malloc(sizeof(half) * 32 * 8);
    half* output_d = NULL;
    cudaMalloc(&output_d, sizeof(half) * 32 * 8);

    mma_test<<<1, 32>>>(output_d);

    cudaMemcpy(output, output_d, sizeof(half) * 32 * 8, cudaMemcpyDeviceToHost);

    for (int i = 0; i < 32 * 8; i++) {
        std::cout << (float)output[i] << " ";
    }
    std::cout << std::endl;
}

I found that when the loop count i was small, the result was close to the correct answer, but when i became large, the result diverged.

You’re running into the limits of the fp16 format. In floating point, you cannot sum very large values and very small values and expect an accurate result.

How much different? Can you show a set of inputs and corresponding output(s)?

Note that your sample data as shown is not exactly representable in the binary16 half-precision format. If I got my math right (still on the first mug of coffee this morning :-), 0.7 is stored as 0.7002 and 0.8 is stored as 0.7998. When computing dot products, the resulting numerical error will increase in magnitude with the length of the vectors.
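
For reference, here is a minimal host-side sketch (compile with nvcc) that round-trips the two constants through half precision using the conversion helpers from cuda_fp16.h, confirming the stored values:

#include <cuda_fp16.h>
#include <cstdio>

int main()
{
    // Round 0.7 and 0.8 to binary16, then convert back to float
    // to see the values that are actually stored.
    printf("0.7 -> %.10f\n", __half2float(__float2half(0.7f)));  // 0.7001953125
    printf("0.8 -> %.10f\n", __half2float(__float2half(0.8f)));  // 0.7998046875
}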

The result of the product itself is always 0.8 x 0.7 x 4 (because k=4), in each output location. Taking into account the fp16 rounding of the inputs and of the arithmetic, that value is 2.24023.

At each step, you are summing that value with the sum of the previous iterations. As the running sum gets large (relative to what can be represented in fp16), the sum of e.g. 8192 + 2.24023 doesn’t give you 8194.24023 as you might expect.

This problem is due to the limited precision of the mantissa/significand in any modern “floating point” number representation. How far apart in magnitude two numbers can be and still be combined meaningfully depends on the accuracy you expect.
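
This is easy to reproduce with a minimal host-side sketch (compile with nvcc) that emulates a single fp16 addition by computing the sum exactly and rounding it back to half precision:

#include <cuda_fp16.h>
#include <cstdio>

int main()
{
    // Near 8192 the spacing between adjacent fp16 values is 8, so the
    // addend 2.24023 is lost entirely when the sum is rounded.
    float r = __half2float(__float2half(8192.0f + 2.24023f));
    printf("8192 + 2.24023 in fp16 = %f\n", r);  // prints 8192.000000
}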

Here is a slightly modified example, showing results after varying amounts of accumulation:

$ cat t2242.cu
#include <cuda_fp16.h>
#include <iostream>
#include <mma.h>
using namespace nvcuda;

__global__ void mma_test(int loops, half* output)
{
    int lane = threadIdx.x % 32;
    uint out[4] = { 0 };

    for (int i = 0; i < loops; i++) {
        uint MultiA[2] = { 0 };
        uint MultiB[2] = { 0 };

        half* test1 = reinterpret_cast<half*>(MultiA);
        half* test2 = reinterpret_cast<half*>(MultiB);
        test1[0] = 0.8;
        test1[1] = 0.8;
        test1[2] = 0.8;
        test1[3] = 0.8;
        test2[0] = 0.7;
        test2[1] = 0.7;
        test2[2] = 0.7;
        test2[3] = 0.7;

        asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 "
                     "{ %0, %1, %2, %3 },"
                     "{ %4, %5 },"
                     "{ %6, %7 },"
                     "{ %8, %9, %10, %11 };\n"
                     : "=r"(out[0]), "=r"(out[1]), "=r"(out[2]), "=r"(out[3])
                     : "r"(MultiA[0]), "r"(MultiA[1]),
                     "r"(MultiB[0]), "r"(MultiB[1]),
                     "r"(out[0]), "r"(out[1]), "r"(out[2]), "r"(out[3]));
    }
    int store_row = lane % 4 + lane / 16 * 4;
    int store_col = (lane % 16) / 4;
    reinterpret_cast<uint4*>(output)[store_row * 4 + store_col] = reinterpret_cast<uint4*>(out)[0];
}

int main(int argc, char *argv[])
{
    int loops = 1;
    if (argc > 1) loops = atoi(argv[1]);
    half* output = (half*)malloc(sizeof(half) * 32 * 8);
    half* output_d = NULL;
    cudaMalloc(&output_d, sizeof(half) * 32 * 8);

    mma_test<<<1, 32>>>(loops, output_d);

    cudaMemcpy(output, output_d, sizeof(half) * 32 * 8, cudaMemcpyDeviceToHost);

    for (int i = 0; i < 1; i++) {
        std::cout << __half2float(output[i]) << " ";
    }
    std::cout << std::endl;
}
$ nvcc -o t2242 t2242.cu -arch=sm_70
$ ./t2242 1
2.24023
$ ./t2242 2
4.48047
$ ./t2242 4
8.96094
$ ./t2242 100
224.75
$ ./t2242 1000
2058
$ ./t2242 2000
4058
$ ./t2242 3041
8184
$ ./t2242 3042
8188
$ ./t2242 3043
8192
$ ./t2242 3044
8192
$ ./t2242 3045
8192
$

If I want to use the Tensor Cores to multiply matrices larger than 2000 × 2000, what should I do?
Can Tensor Cores only be used for small matrices?

Generally speaking, if the number of bits in the significand of the floating-point format is N, and assuming the numbers summed are of roughly the same magnitude, you should expect significant numerical issues when the number of addends is >= 2^N. FP16 provides 10 stored significand bits (11 with the implicit leading bit), so on the order of 2^10 to 2^11, i.e. 1000-2000 addends; for a 2000x2000 element matrix, whose dot products have 2000 addends each, this means trouble.
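
The 2^N threshold is easy to demonstrate: summing the value 1.0 repeatedly in fp16 stalls at exactly 2048 = 2^11. A minimal sketch (compile with nvcc) that emulates fp16 rounding with the cuda_fp16.h conversion helpers:

#include <cuda_fp16.h>
#include <cstdio>

int main()
{
    // Round the running sum to fp16 after every addition. Once the sum
    // reaches 2048, where the fp16 spacing is 2, adding 1.0 produces a
    // tie that rounds back down (round-to-even) to 2048, so the sum
    // stops growing.
    float sum = 0.0f;
    for (int i = 0; i < 5000; i++)
        sum = __half2float(__float2half(sum + 1.0f));
    printf("fp16 sum of 5000 ones = %.0f\n", sum);  // prints 2048
}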

The easiest way to get around the summing problem is to accumulate in FP32. You can see from the description in the PTX manual that the precision for source and destination data can be selected independently, so this is a distinct possibility.

An alternative might be to split long dot products of FP16 source data into segments of, say, 256 elements each, and then sum the partial sums from these segments at the end, as in the sketch below. This is not going to be as robust as summing in FP32.
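
Here is a host-side sketch of that idea (compile with nvcc), emulating fp16 rounding with the cuda_fp16.h conversion helpers and comparing a naive fp16 running sum against a segmented one; the 2.240234375 term and the 10000 iterations match the example earlier in the thread:

#include <cuda_fp16.h>
#include <cstdio>

// Emulate a single fp16 addition: compute the sum exactly in float,
// then round it back to half precision.
static float hadd(float a, float b)
{
    return __half2float(__float2half(a + b));
}

int main()
{
    const int n = 10000;
    const float term = 2.240234375f;  // one k=4 dot product, as above

    // Naive fp16 running sum: stalls at 8192, where the fp16
    // spacing (8) exceeds the addend.
    float naive = 0.0f;
    for (int i = 0; i < n; i++) naive = hadd(naive, term);

    // Segmented fp16 sum: 256-element partial sums stay small enough
    // to absorb each addend; the partials are combined at the end.
    float total = 0.0f, partial = 0.0f;
    for (int i = 0; i < n; i++) {
        partial = hadd(partial, term);
        if ((i + 1) % 256 == 0) { total = hadd(total, partial); partial = 0.0f; }
    }
    total = hadd(total, partial);

    printf("exact %.1f  naive fp16 %.1f  segmented fp16 %.1f\n",
           (double)n * term, naive, total);
}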

1 Like

Note that since the “problem” in this particular case is in the accumulation step, you could work around it by using f32 accumulation. Example:

$ cat t2242.cu
#include <cuda_fp16.h>
#include <iostream>
#include <mma.h>
using namespace nvcuda;

__global__ void mma_test(int loops, float* output)
{
    float out[8] = { 0 };

    uint MultiA[2] = { 0 };
    uint MultiB[2] = { 0 };

    half* test1 = reinterpret_cast<half*>(MultiA);
    half* test2 = reinterpret_cast<half*>(MultiB);
    test1[0] = 0.8;
    test1[1] = 0.8;
    test1[2] = 0.8;
    test1[3] = 0.8;
    test2[0] = 0.7;
    test2[1] = 0.7;
    test2[2] = 0.7;
    test2[3] = 0.7;
    for (int i = 0; i < loops; i++) {

        asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
                     "{ %0, %1, %2, %3, %4, %5, %6, %7 },"
                     "{ %8, %9 },"
                     "{ %10, %11 },"
                     "{ %12, %13, %14, %15, %16, %17, %18, %19 };\n"
                     : "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]), "=f"(out[4]), "=f"(out[5]), "=f"(out[6]), "=f"(out[7])
                     : "r"(MultiA[0]), "r"(MultiA[1]),
                     "r"(MultiB[0]), "r"(MultiB[1]),
                     "f"(out[0]), "f"(out[1]), "f"(out[2]), "f"(out[3]), "f"(out[0]), "f"(out[1]), "f"(out[2]), "f"(out[3]));
    }
    for (int i = 0; i < 8; i++) output[threadIdx.x*8+i] =out[i];
}

int main(int argc, char *argv[])
{
    int loops = 1;
    if (argc > 1) loops = atoi(argv[1]);
    float* output = (float*)malloc(sizeof(float) * 32 * 8);
    float* output_d = NULL;
    cudaMalloc(&output_d, sizeof(float) * 32 * 8);

    mma_test<<<1, 32>>>(loops, output_d);

    cudaMemcpy(output, output_d, sizeof(float) * 32 * 8, cudaMemcpyDeviceToHost);

    for (int i = 0; i < 1; i++) {
        std::cout << output[i] << " ";
    }
    std::cout << std::endl;
}
$ nvcc -o t2242 t2242.cu -arch=sm_70
$ ./t2242 1
2.24008
$ ./t2242 2
4.48016
$ ./t2242 10000
22374.7
$

Note that fp32 only has sufficient mantissa bits to represent about 6-7 decimal digits, and likewise fp16 only has sufficient mantissa bits to represent ~3 decimal digits. Therefore the result above being accurate to about 3 digits is about as much as you should expect. (2.24 is the correct result for a single iteration, so 22400 would be the correct result for 10000 iterations. The above result, rounded to 3 significant digits, is 22400.)
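
As a sanity check, the expected value can be computed in double precision from the fp16-rounded inputs derived earlier in the thread:

#include <cstdio>

int main()
{
    // 0.8 and 0.7 are stored in fp16 as 0.7998046875 and 0.7001953125;
    // each k=4 dot product contributes 4 times their exact product.
    double expected = 4.0 * 0.7998046875 * 0.7001953125 * 10000.0;
    printf("expected: %.1f\n", expected);  // ~22400.8, vs the 22374.7 measured above
}

The remaining difference is consistent with rounding during the fp32 accumulation.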
