Mismatch between wmma and cuBLAS results while learning from a low-level tensor core code example (using wmma)

While learning from a low-level tensor core code example that uses wmma::, I found a mismatch in the resulting C matrix when the tensor core and cuBLAS results were compared; a large number of mismatches occurred. Because I am still learning tensor cores, I could not find where the error is. Can you investigate?
code-samples/posts/tensor-cores


[root@localhost tensor-cores]# ls -l
total 832
-rw-r--r--. 1 root root   1685 Feb  5 08:01 Makefile
-rw-r--r--. 1 root root    317 Feb  5 08:01 README.md
-rwxr-xr-x. 1 root root 830920 Feb  5 08:01 TCGemm
-rw-r--r--. 1 root root  11380 Feb  5 08:01 simpleTensorCoreGEMM.cu
[root@localhost tensor-cores]# make && ./TCGemm
nvcc -o TCGemm -arch=sm_70 -lcublas -lcurand simpleTensorCoreGEMM.cu

M = 16384, N = 16384, K = 16384. alpha = 2.000000, beta = 2.000000

Running with wmma...
Running with cuBLAS...

Checking results...
8266.587891 8267.766602
8240.230469 8241.420898
8242.393555 8243.574219
8209.478516 8210.649414
8100.519043 8101.664062
8251.499023 8252.675781
8189.156738 8190.297852
8260.410156 8261.580078
8311.802734 8313.015625
WMMA does not agree with cuBLAS! 268435456 errors!
[root@localhost tensor-cores]# git remote -v
origin  https://github.com/NVIDIA-developer-blog/code-samples.git (fetch)
origin  https://github.com/NVIDIA-developer-blog/code-samples.git (push)

@g900nvda I cannot reproduce this observation; the error check at the end of the app passes for me. Please add information about the exact CUDA version that was used to build this example code and the GPU you are running on.

The magnitude of the differences reported by OP seems plausibly explainable by accumulated rounding error when operating on random matrices of this size. The sample code uses the default generator of cuRAND. I wonder whether changes have been made to this default generator in recent years, possibly leading to differences in the matrices generated.
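For reference, "default generator" here means CURAND_RNG_PSEUDO_DEFAULT in the cuRAND host API. A minimal sketch of how a sample like this typically fills a device buffer (not the literal code from simpleTensorCoreGEMM.cu; the function name and seed are placeholders):

#include <cstddef>
#include <curand.h>

// Fill a device buffer with uniform random floats using cuRAND's default
// pseudo-random generator. If the algorithm behind CURAND_RNG_PSEUDO_DEFAULT
// ever changed between CUDA versions, the generated matrices would change too,
// and with them the accumulated rounding error of the GEMM.
void fill_random(float *d_buf, size_t n) {
    curandGenerator_t gen;
    curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
    curandSetPseudoRandomGeneratorSeed(gen, 1234ULL); // seed chosen arbitrarily here
    curandGenerateUniform(gen, d_buf, n);              // values in (0, 1]
    curandDestroyGenerator(gen);
}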


Interesting. I am using a workstation with an RTX 2070 and CentOS 9 Stream with the CUDA 12.2 toolkit + driver, but the example is running in a container: CentOS 9 Stream with the CUDA 12.3 toolkit on the 12.2 driver. What did you run on?

Secondly, I am wondering if there is a simpler tensor core example like this one that does addition (A + B = C) instead of matrix multiplication, for simplicity. I am somewhat struggling with this example.

Tensor cores are used for matrix multiplication.

C++ warp matrix operations leverage Tensor Cores to accelerate matrix problems of the form D=A*B+C

If you want to do matrix addition, you need to express it as matrix multiplication.
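As a worked example of that mapping (my own illustration, not from the sample): the hardware computes D = A*B + C, so to get an element-wise sum of two matrices X and Y you can set A = X, B = I (the identity matrix), and C = Y, which gives D = X*I + Y = X + Y, i.e. a matrix addition expressed as one multiply-accumulate.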

The sample code cannot really be simplified further. Both input matrices and the output matrix are split into tiles of size 16x16. Each warp computes one tile of the output matrix by multiplying and accumulating the corresponding tiles of A and B.
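To make that structure concrete, here is a heavily condensed sketch of such a kernel (my own illustration of the pattern, not the sample code itself; it assumes the usual 16x16x16 half-precision fragments with float accumulation and column-major matrices):

#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// One warp computes one 16x16 tile of C = A * B. Column-major storage is
// assumed; lda, ldb, ldc are the leading dimensions of A, B and C.
__global__ void wmma_tile_gemm(const half *a, const half *b, float *c,
                               int M, int N, int K, int lda, int ldb, int ldc) {
    // Which output tile this warp owns (x counts threads, so divide by warpSize).
    int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
    int warpN = blockIdx.y * blockDim.y + threadIdx.y;

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag;
    wmma::fill_fragment(acc_frag, 0.0f);

    // March along the K dimension one 16-wide slice at a time,
    // accumulating into the same output fragment.
    for (int k = 0; k < K; k += 16) {
        int aRow = warpM * 16;  // row of A / row of C
        int bCol = warpN * 16;  // column of B / column of C
        if (aRow < M && bCol < N) {
            wmma::load_matrix_sync(a_frag, a + aRow + k * lda, lda);
            wmma::load_matrix_sync(b_frag, b + k + bCol * ldb, ldb);
            wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);  // acc += A_tile * B_tile
        }
    }

    // Write the accumulated 16x16 tile of C back to global memory.
    int cRow = warpM * 16;
    int cCol = warpN * 16;
    if (cRow < M && cCol < N)
        wmma::store_matrix_sync(c + cRow + cCol * ldc, acc_frag, ldc, wmma::mem_col_major);
}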

Oops, I guess that is the essence of wmma! Thanks. I don't think I should express addition as matrix multiplication; I'll just work with matrix multiplication directly.

That is an illegal combination. You cannot use CUDA toolkit 12.3 on a GPU driver that supports CUDA 12.2. Furthermore, the forward compatibility library setup in CUDA is not applicable to GeForce GPUs.
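If in doubt about which combination you actually ended up with, a small check like this (my own snippet, not part of the sample) prints both versions:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int driverVer = 0, runtimeVer = 0;
    cudaDriverGetVersion(&driverVer);   // highest CUDA version the installed driver supports
    cudaRuntimeGetVersion(&runtimeVer); // CUDA runtime version the app was built against
    printf("Driver supports CUDA %d.%d, runtime is CUDA %d.%d\n",
           driverVer / 1000, (driverVer % 100) / 10,
           runtimeVer / 1000, (runtimeVer % 100) / 10);
    return 0;
}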

Sorry, I got it the other way around:
a 12.3 driver on the host with 12.2 running in the container. Will that work?

I might upgrade everything to 12.3 and retry.

I upgraded everything to 12.3 but still get a lot of errors:

[root@localhost tensor-cores]# make
nvcc -o TCGemm -arch=sm_70 -lcublas -lcurand simpleTensorCoreGEMM.cu
[root@localhost tensor-cores]# ./TCGemm
Makefile README.gg README.md TCGemm simpleTensorCoreGEMM.cu
[root@localhost tensor-cores]# ./TCGemm

M = 16384, N = 16384, K = 16384. alpha = 2.000000, beta = 2.000000

Running with wmma...
grid: 256, 256, 1, blockDim: 128, 4, 1.
warpSize: 32, M,N,K: 16384, 16384, 16384.
Running with cuBLAS...

Checking results...
8266.587891 8267.766602
8240.230469 8241.420898
8242.393555 8243.574219
8209.478516 8210.649414
8100.519043 8101.664062
8251.499023 8252.675781
8189.156738 8190.297852
8260.410156 8261.580078
8311.802734 8313.015625
WMMA does not agree with cuBLAS! 268435456 errors!
[root@localhost tensor-cores]# yum list installed | grep toolkit
[root@localhost tensor-cores]# yum list installed | grep cuda | head -1
[root@localhost tensor-cores]# yum list installed | grep cuda -i | head -
[root@localhost tensor-cores]# nvidia-smi
Tue Feb 6 05:52:38 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2070 ...    Off | 00000000:01:00.0 Off |                  N/A |
| 22%   43C    P0              N/A / 215W |      0MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory |
|        ID   ID                                                              Usage      |
|=========================================================================================|
|  No running processes found                                                            |
+---------------------------------------------------------------------------------------+
[root@localhost tensor-cores]# ls -l /usr/local/cuda
cuda/ cuda-12.3/
[root@localhost tensor-cores]# ls -l /usr/local/cuda*
lrwxrwxrwx. 1 root root 21 Feb 6 04:05 /usr/local/cuda -> /usr/local/cuda-12.3/

/usr/local/cuda-12.3:
total 92
-rw-r--r--. 1 root root 160 Feb 6 04:06 DOCS
-rw-r--r--. 1 root root 61498 Feb 6 04:06 EULA.txt
-rw-r--r--. 1 root root 524 Feb 6 04:06 README
drwxr-xr-x. 3 root root 4096 Feb 6 04:06 bin
drwxr-xr-x. 5 root root 4096 Feb 6 04:06 compute-sanitizer
drwxr-xr-x. 5 root root 53 Feb 6 04:06 extras
drwxr-xr-x. 6 root root 91 Feb 6 04:06 gds
drwxr-xr-x. 2 root root 40 Feb 6 04:06 gds-12.3
lrwxrwxrwx. 1 root root 28 Feb 6 04:06 include -> targets/x86_64-linux/include
lrwxrwxrwx. 1 root root 24 Feb 6 04:06 lib64 -> targets/x86_64-linux/lib
drwxr-xr-x. 7 root root 4096 Feb 6 04:06 libnvvp
drwxr-xr-x. 8 root root 130 Feb 6 04:06 nsight-compute-2023.3.1
drwxr-xr-x. 6 root root 100 Feb 6 04:06 nsight-systems-2023.3.3
drwxr-xr-x. 2 root root 53 Feb 6 04:06 nsightee_plugins
drwxr-xr-x. 3 root root 21 Feb 6 04:05 nvml
drwxr-xr-x. 6 root root 62 Feb 6 04:06 nvvm
drwxr-xr-x. 3 root root 17 Feb 6 04:06 share
drwxr-xr-x. 2 root root 4096 Feb 6 04:06 src
drwxr-xr-x. 3 root root 26 Feb 6 04:05 targets
drwxr-xr-x. 2 root root 43 Feb 6 04:06 tools
-rw-r--r--. 1 root root 2857 Feb 6 04:05 version.json
[root@localhost tensor-cores]#

Further up in the thread you mentioned building with -arch=sm_70.

RTX 2070 has compute capability 7.5 (Turing), so I would suggest building with -arch=sm_75. When I stated earlier that I had built and run this sample app without issues I was running on a Turing class GPU. But I have CUDA version 11.x installed on that machine.

As before, the mismatches you are observing are not indicative of either hardware or software being broken. My working hypothesis is that there is a difference in the generated matrices between CUDA versions and that this leads to slightly higher relative errors (such as 1.5e-4 instead of the 1.0e-4 limit used by the code) when compared against cuBLAS. It is also possible that changes were made to the cuBLAS GEMM implementation that impact this specific comparison (without being indicative of any real issue).

For a quick check, you could change the “magic” limit of 1e-4 in the code to 1.5e-4 to see whether that makes the mismatches go away. I am calling the error limit “magic” because my assumption is that this was just a value picked when this app was first created that would be sufficiently tight to flag real errors but large enough to avoid false positives due to accumulated rounding error.
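For context, the comparison being discussed boils down to a per-element relative-error check along these lines (a sketch only; the variable names and exact formula in simpleTensorCoreGEMM.cu may differ):

#include <cmath>
#include <cstddef>

// Count elements where the wmma result deviates from the cuBLAS result by more
// than the given relative tolerance. The tolerance is the "magic" limit discussed
// above: tight enough to flag real errors, loose enough to tolerate accumulated
// rounding differences (e.g. 1e-4f, or 1.5e-4f to relax the check).
int count_mismatches(const float *c_wmma, const float *c_cublas, size_t n,
                     float rel_tol) {
    int errors = 0;
    for (size_t i = 0; i < n; i++) {
        float v1 = c_wmma[i], v2 = c_cublas[i];
        float rel_err = fabsf(v1 - v2) / fabsf(v2);
        if (rel_err > rel_tol)
            errors++;
    }
    return errors;
}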

sm_75 worked, thanks! I will have to take some time to digest this.
[root@localhost tensor-cores]# make ; ./TCGemm
nvcc -o TCGemm -arch=sm_75 -lcublas -lcurand simpleTensorCoreGEMM.cu

M = 16384, N = 16384, K = 16384. alpha = 2.000000, beta = 2.000000

Running with wmma...
grid: 256, 256, 1, blockDim: 128, 4, 1.
warpSize: 32, M,N,K: 16384, 16384, 16384.
Running with cuBLAS...

Checking results...
Results verified: cublas and WMMA agree.

wmma took 1974.263672ms
cublas took 228.444153ms

For a faster code using wmma you should check out the cudaTensorCoreGemm sample in the CUDA Toolkit.
This code was written as a demo only!

[root@localhost tensor-cores]#

Interesting. When you compile with -arch=sm_70 but run on a GPU with CC 7.5, the PTX intermediate representation gets JIT-compiled for CC 7.5, since the embedded binary code for CC 7.0 cannot be used. When you compile with -arch=sm_75 the PTX gets compiled (offline) into machine code for CC 7.5.

Numerical differences due to the two compilation paths (offline compiler vs JIT compiler) could point to an issue with the latest JIT compiler (e.g. more aggressive re-association of floating-point expressions), which could be an intentional change or a bug.
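If you want to take the JIT compiler out of the picture entirely, the nvcc line in the Makefile can embed machine code for both architectures and keep PTX only as a fallback. Something along these lines (my suggestion, not the sample's actual Makefile):

nvcc -o TCGemm simpleTensorCoreGEMM.cu -lcublas -lcurand \
     -gencode arch=compute_70,code=sm_70 \
     -gencode arch=compute_75,code=sm_75 \
     -gencode arch=compute_75,code=compute_75

The last -gencode entry embeds PTX so that GPUs newer than Turing can still JIT-compile the code; on a CC 7.0 or 7.5 GPU the prebuilt machine code is used directly.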


This topic was automatically closed 14 days after the last reply. New replies are no longer allowed.