The best input layout settings in cuBLAS

I understand that the memory layout of the input matrices affects cuBLAS GEMM performance. According to the information I’ve found (cuBLAS related question - CUDA / CUDA Programming and Performance - NVIDIA Developer Forums), the NT case (that is, for C = A*B, A stored column-major and B effectively row-major) should be the fastest. However, I observed something different on my RTX 3080 laptop GPU. I ran the simpleCUBLAS example from the CUDA Toolkit samples, set cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);, and added some precise timing code. I tested several sizes, such as m=n=k=1024, 2048, 4096, and found that the TN case was always slightly faster than the NT case. So, has NVIDIA now done some special optimization for the TN case? Or could someone verify this behavior? Thanks.

The test results:
m=n=k=4096
TN: 3951 us
NT: 4247 us

m=n=k=2048
TN: 704 us
NT: 780 us

m=n=k=1024
TN: 205 us
NT: 220 us

My code:


/* Includes, system */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <helper_cuda.h>

#include "C:\F_develop\1-cuda\1-cuda_tookit_samples\cuda-samples-11.6\cuda-samples-11.6\Samples\4_CUDA_Libraries\batchCUBLAS\batchCUBLAS.h"

/* Matrix size */
#define N (4096)


/* Main */
int main(int argc, char** argv) {
    cublasStatus_t status;
    float* h_A;
    float* h_B;
    float* h_C;
    //float *h_C_ref;
    float* d_A = 0;
    float* d_B = 0;
    float* d_C = 0;
    float alpha = 1.0f;
    float beta = 0.0f;
    int n2 = N * N;
    int i;
    cublasHandle_t handle;

    int dev = findCudaDevice(argc, (const char**)argv);

    if (dev == -1) {
        return EXIT_FAILURE;
    }

    /* Initialize CUBLAS */
    printf("simpleCUBLAS test running..\n");

    status = cublasCreate(&handle);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! CUBLAS initialization error\n");
        return EXIT_FAILURE;
    }

    status = cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);  // enable tensor core

    /* Allocate host memory for the matrices */
    h_A = reinterpret_cast<float*>(malloc(n2 * sizeof(h_A[0])));

    if (h_A == 0) {
        fprintf(stderr, "!!!! host memory allocation error (A)\n");
        return EXIT_FAILURE;
    }

    h_B = reinterpret_cast<float*>(malloc(n2 * sizeof(h_B[0])));

    if (h_B == 0) {
        fprintf(stderr, "!!!! host memory allocation error (B)\n");
        return EXIT_FAILURE;
    }

    h_C = reinterpret_cast<float*>(malloc(n2 * sizeof(h_C[0])));

    if (h_C == 0) {
        fprintf(stderr, "!!!! host memory allocation error (C)\n");
        return EXIT_FAILURE;
    }

    /* Fill the matrices with test data */
    for (i = 0; i < n2; i++) {
        h_A[i] = 1;
        h_B[i] = 1;
        h_C[i] = 1;
    }

    /* Allocate device memory for the matrices */
    if (cudaMalloc(reinterpret_cast<void**>(&d_A), n2 * sizeof(d_A[0])) !=
        cudaSuccess) {
        fprintf(stderr, "!!!! device memory allocation error (allocate A)\n");
        return EXIT_FAILURE;
    }

    if (cudaMalloc(reinterpret_cast<void**>(&d_B), n2 * sizeof(d_B[0])) !=
        cudaSuccess) {
        fprintf(stderr, "!!!! device memory allocation error (allocate B)\n");
        return EXIT_FAILURE;
    }

    if (cudaMalloc(reinterpret_cast<void**>(&d_C), n2 * sizeof(d_C[0])) !=
        cudaSuccess) {
        fprintf(stderr, "!!!! device memory allocation error (allocate C)\n");
        return EXIT_FAILURE;
    }

    /* Initialize the device matrices with the host matrices */
    status = cublasSetVector(n2, sizeof(h_A[0]), h_A, 1, d_A, 1);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! device access error (write A)\n");
        return EXIT_FAILURE;
    }

    status = cublasSetVector(n2, sizeof(h_B[0]), h_B, 1, d_B, 1);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! device access error (write B)\n");
        return EXIT_FAILURE;
    }

    status = cublasSetVector(n2, sizeof(h_C[0]), h_C, 1, d_C, 1);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! device access error (write C)\n");
        return EXIT_FAILURE;
    }


    double start, stop, total = 0.0;

    int repeat = 10;

    for (int i = 0; i < repeat; i++) {

        start = second();

        /* Performs operation using cublas */

        //status = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, N, N, N, &alpha, d_A,
        //    N, d_B, N, &beta, d_C, N);


        status = cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, N, N, N, &alpha, d_A,
            N, d_B, N, &beta, d_C, N);

        cudaError_t cudaStatus = cudaDeviceSynchronize();

        stop = second();

        total += (stop - start);
        fprintf(stdout, "^^^^ elapsed = %10.8f sec \n", (stop - start));

    }
    fprintf(stdout, "^^^^ average elapsed = %10.8f sec \n", total / repeat);


    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! kernel execution error.\n");
        return EXIT_FAILURE;
    }


    /* Memory clean up */
    free(h_A);
    free(h_B);
    free(h_C);

    if (cudaFree(d_A) != cudaSuccess) {
        fprintf(stderr, "!!!! memory free error (A)\n");
        return EXIT_FAILURE;
    }

    if (cudaFree(d_B) != cudaSuccess) {
        fprintf(stderr, "!!!! memory free error (B)\n");
        return EXIT_FAILURE;
    }

    if (cudaFree(d_C) != cudaSuccess) {
        fprintf(stderr, "!!!! memory free error (C)\n");
        return EXIT_FAILURE;
    }

    /* Shutdown */
    status = cublasDestroy(handle);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! shutdown error (A)\n");
        return EXIT_FAILURE;
    }

}

Using approximately your code, on my L4 GPU on Linux with CUDA 12.2, I found the NT case to be a bit faster. It’s entirely possible for cuBLAS to behave slightly differently on different GPU architectures or different GPUs. It may also vary by CUDA version, so if you are using CUDA 11.4, that may be another factor.

Here is my test case:

# cat t264.cu
/* Includes, system */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>


/* Matrix size */
#define N (4096)


/* Main */
int main(int argc, char** argv) {
    cublasStatus_t status;
    float* h_A;
    float* h_B;
    float* h_C;
    //float *h_C_ref;
    float* d_A = 0;
    float* d_B = 0;
    float* d_C = 0;
    float alpha = 1.0f;
    float beta = 0.0f;
    int n2 = N * N;
    int i;
    cublasHandle_t handle;

    int dev = 0;

    if (dev == -1) {
        return EXIT_FAILURE;
    }

    /* Initialize CUBLAS */
    printf("simpleCUBLAS test running..\n");

    status = cublasCreate(&handle);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! CUBLAS initialization error\n");
        return EXIT_FAILURE;
    }

    status = cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);  // enable tensor core

    /* Allocate host memory for the matrices */
    h_A = reinterpret_cast<float*>(malloc(n2 * sizeof(h_A[0])));

    if (h_A == 0) {
        fprintf(stderr, "!!!! host memory allocation error (A)\n");
        return EXIT_FAILURE;
    }

    h_B = reinterpret_cast<float*>(malloc(n2 * sizeof(h_B[0])));

    if (h_B == 0) {
        fprintf(stderr, "!!!! host memory allocation error (B)\n");
        return EXIT_FAILURE;
    }

    h_C = reinterpret_cast<float*>(malloc(n2 * sizeof(h_C[0])));

    if (h_C == 0) {
        fprintf(stderr, "!!!! host memory allocation error (C)\n");
        return EXIT_FAILURE;
    }

    /* Fill the matrices with test data */
    for (i = 0; i < n2; i++) {
        h_A[i] = 1;
        h_B[i] = 1;
        h_C[i] = 1;
    }

    /* Allocate device memory for the matrices */
    if (cudaMalloc(reinterpret_cast<void**>(&d_A), n2 * sizeof(d_A[0])) !=
        cudaSuccess) {
        fprintf(stderr, "!!!! device memory allocation error (allocate A)\n");
        return EXIT_FAILURE;
    }

    if (cudaMalloc(reinterpret_cast<void**>(&d_B), n2 * sizeof(d_B[0])) !=
        cudaSuccess) {
        fprintf(stderr, "!!!! device memory allocation error (allocate B)\n");
        return EXIT_FAILURE;
    }

    if (cudaMalloc(reinterpret_cast<void**>(&d_C), n2 * sizeof(d_C[0])) !=
        cudaSuccess) {
        fprintf(stderr, "!!!! device memory allocation error (allocate C)\n");
        return EXIT_FAILURE;
    }

    /* Initialize the device matrices with the host matrices */
    status = cublasSetVector(n2, sizeof(h_A[0]), h_A, 1, d_A, 1);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! device access error (write A)\n");
        return EXIT_FAILURE;
    }

    status = cublasSetVector(n2, sizeof(h_B[0]), h_B, 1, d_B, 1);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! device access error (write B)\n");
        return EXIT_FAILURE;
    }

    status = cublasSetVector(n2, sizeof(h_C[0]), h_C, 1, d_C, 1);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! device access error (write C)\n");
        return EXIT_FAILURE;
    }



    int repeat = 2;

    for (int i = 0; i < repeat; i++) {


        /* Performs operation using cublas */
#ifdef USE_NT
        status = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, N, N, N, &alpha, d_A,
            N, d_B, N, &beta, d_C, N);
#else

        status = cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, N, N, N, &alpha, d_A,
            N, d_B, N, &beta, d_C, N);
#endif
        cudaError_t cudaStatus = cudaDeviceSynchronize();



    }


    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! kernel execution error.\n");
        return EXIT_FAILURE;
    }


    /* Memory clean up */
    free(h_A);
    free(h_B);
    free(h_C);

    if (cudaFree(d_A) != cudaSuccess) {
        fprintf(stderr, "!!!! memory free error (A)\n");
        return EXIT_FAILURE;
    }

    if (cudaFree(d_B) != cudaSuccess) {
        fprintf(stderr, "!!!! memory free error (B)\n");
        return EXIT_FAILURE;
    }

    if (cudaFree(d_C) != cudaSuccess) {
        fprintf(stderr, "!!!! memory free error (C)\n");
        return EXIT_FAILURE;
    }

    /* Shutdown */
    status = cublasDestroy(handle);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "!!!! shutdown error (A)\n");
        return EXIT_FAILURE;
    }

}
# nvcc -o t264 t264.cu -lcublas
# nsys nvprof --print-gpu-trace ./t264
WARNING: t264 and any of its children processes will be profiled.

simpleCUBLAS test running..
Generating '/tmp/nsys-report-e836.qdstrm'
[1/3] [========================100%] report10.nsys-rep
[2/3] [========================100%] report10.sqlite
[3/3] Executing 'cuda_gpu_trace' stats report

 Start (ns)   Duration (ns)  CorrId  GrdX  GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Bytes (MB)  Throughput (MBps)  SrcMemKd  DstMemKd     Device      Ctx  Strm                                            Name                             
 -----------  -------------  ------  ----  ----  ----  ----  ----  ----  -------  ------------  ------------  ----------  -----------------  --------  --------  -------------  ---  ----  -----------------------------------------------------------------------------------------
 871,260,644     17,751,373   1,174                                                                               67.109          3,758.096  Pageable  Device    NVIDIA L4 (0)    1     7  [CUDA memcpy HtoD]                                                         
 889,183,249     16,269,804   1,176                                                                               67.109          4,093.641  Pageable  Device    NVIDIA L4 (0)    1     7  [CUDA memcpy HtoD]                                                         
 905,625,789     16,187,244   1,178                                                                               67.109          4,093.641  Pageable  Device    NVIDIA L4 (0)    1     7  [CUDA memcpy HtoD]                                                         
 954,185,121      2,372,930   1,183   256     2     1   256     1     1      244         0.000         0.074                                                     NVIDIA L4 (0)    1     7  void cutlass::Kernel<cutlass_80_tensorop_s1688f16gemm_256x128_16x3_tn_align4>(T1::Params)
 956,592,099      2,397,250   1,188   256     2     1   256     1     1      244         0.000         0.074                                                     NVIDIA L4 (0)    1     7  void cutlass::Kernel<cutlass_80_tensorop_s1688f16gemm_256x128_16x3_tn_align4>(T1::Params)

Generated:
    /root/bobc/report10.nsys-rep
    /root/bobc/report10.sqlite
# nvcc -o t264 t264.cu -lcublas -DUSE_NT
# nsys nvprof --print-gpu-trace ./t264
WARNING: t264 and any of its children processes will be profiled.

simpleCUBLAS test running..
Generating '/tmp/nsys-report-276b.qdstrm'
[1/3] [========================100%] report11.nsys-rep
[2/3] [========================100%] report11.sqlite
[3/3] Executing 'cuda_gpu_trace' stats report

  Start (ns)    Duration (ns)  CorrId  GrdX  GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Bytes (MB)  Throughput (MBps)  SrcMemKd  DstMemKd     Device      Ctx  Strm                                            Name                           
 -------------  -------------  ------  ----  ----  ----  ----  ----  ----  -------  ------------  ------------  ----------  -----------------  --------  --------  -------------  ---  ----  -----------------------------------------------------------------------------------------
   925,761,720     17,598,285   1,174                                                                               67.109          3,758.096  Pageable  Device    NVIDIA L4 (0)    1     7  [CUDA memcpy HtoD]                                                       
   943,531,109     16,568,365   1,176                                                                               67.109          4,026.532  Pageable  Device    NVIDIA L4 (0)    1     7  [CUDA memcpy HtoD]                                                       
   960,278,034     16,061,292   1,178                                                                               67.109          4,160.750  Pageable  Device    NVIDIA L4 (0)    1     7  [CUDA memcpy HtoD]                                                       
 1,009,513,943      2,154,626   1,183   128     4     1   256     1     1      239         0.000         0.074                                                     NVIDIA L4 (0)    1     7  void cutlass::Kernel<cutlass_80_tensorop_s1688f16gemm_128x256_16x3_nt_align4>(T1::Params)
 1,011,702,137      2,184,514   1,188   128     4     1   256     1     1      239         0.000         0.074                                                     NVIDIA L4 (0)    1     7  void cutlass::Kernel<cutlass_80_tensorop_s1688f16gemm_128x256_16x3_nt_align4>(T1::Params)

Generated:
    /root/bobc/report11.nsys-rep
    /root/bobc/report11.sqlite
#

cuBLAS SGEMM has special behavior when you set the TENSOR_OP_MATH math mode. As far as I know that setting is deprecated, and cuBLAS SGEMM may someday no longer have that path to use tensor cores; you would then have to use an alternate cuBLAS API to get them. And if you are interested in a comparison with advice given in an 11-year-old thread, it probably doesn’t make sense to do the comparison with tensor cores anyway.
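
If you want to request the tensor-core-eligible path explicitly today, one option is cublasGemmEx, where the compute type (rather than cublasSetMathMode) controls the math path. A minimal sketch for the TN case on FP32 data; the wrapper function and its name are just for illustration:

#include <cublas_v2.h>

/* Hedged sketch, not code from this thread: a TN GEMM on FP32 data that is
   eligible for tensor cores via the compute type instead of cublasSetMathMode.
   CUBLAS_COMPUTE_32F_FAST_16F allows internal FP16 tensor-core math (similar in
   spirit to the old CUBLAS_TENSOR_OP_MATH path); CUBLAS_COMPUTE_32F_FAST_TF32
   would request TF32 instead. */
cublasStatus_t gemm_tn_tensorop(cublasHandle_t handle, int n,
                                const float* d_A, const float* d_B, float* d_C) {
    const float alpha = 1.0f;
    const float beta = 0.0f;
    return cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, n, n, n,
                        &alpha,
                        d_A, CUDA_R_32F, n,      /* A: FP32 storage */
                        d_B, CUDA_R_32F, n,      /* B: FP32 storage */
                        &beta,
                        d_C, CUDA_R_32F, n,      /* C: FP32 storage */
                        CUBLAS_COMPUTE_32F_FAST_16F,
                        CUBLAS_GEMM_DEFAULT);    /* let cuBLAS pick the kernel */
}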

I think it’s not likely, but it is possible, that Windows could have an effect on the timing you are measuring, since you haven’t shown your timing code (and I’m not really asking to see it, either). Work dispatch to your RTX 3080 GPU is subject to WDDM batching, which can affect what you measure from the host’s perspective. To get the best data for a comparison like this, I would generally use a profiler, as I have done.
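
If you do want a host-visible number without a profiler, CUDA events record on the GPU timeline and are much less sensitive to WDDM batching than a host timer. A minimal sketch, assuming the same handle and device buffers as in your code (the helper name is mine):

#include <stdio.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>

/* Hedged sketch: time one TN SGEMM with CUDA events instead of a host timer. */
void time_tn_gemm(cublasHandle_t handle, int n,
                  const float* d_A, const float* d_B, float* d_C) {
    const float alpha = 1.0f;
    const float beta = 0.0f;
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, n, n, n, &alpha,
                d_A, n, d_B, n, &beta, d_C, n);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);   /* GPU-side elapsed time in ms */
    printf("TN case: %.3f ms\n", ms);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
}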


Thank you very much for your suggestions and experiments.

  1. The timing function I used is from “Samples\4_CUDA_Libraries\batchCUBLAS\batchCUBLAS.h”. I didn’t modify it, but as you mentioned, perhaps it isn’t very reliable on Windows.
    So I tried to follow your approach using nvprof, but it failed with a message that nvprof cannot be used with compute capability 8.0 and higher. I therefore used the Nsight Compute GUI, and the results were the same as before: the TN case slightly outperforms the NT case.

My results for the NT and TN cases on Windows with CUDA 11.6:

NT case - cuda-11.6-windows:

TN case - cuda-11.6-windows:

  2. You mentioned that the CUDA version might have an impact, which I strongly agree with. So I upgraded from CUDA 11.6 to 12.3, but the results remained the same: the TN case slightly outperforms the NT case.

My results for the NT and TN cases on Windows with CUDA 12.3:

NT-case-cuda-12.3-windows:

TN-case-cuda-12.3-windows:

  3. As you said, the OS might also affect the results. So I ran the experiments on Ubuntu 22.04 with CUDA 12.2, using exactly the same code and commands as you did. The nvprof results were still consistent with what I observed on Windows: the TN case slightly outperforms the NT case.

My results for the NT and TN cases on Linux with CUDA 12.2, using nvprof and nsight-compute-gui:

NT-case-cuda-12-2-linux-nvprof:

TN-case-cuda-12-2-linux-nvprof:

NT-case-cuda-12-2-linux-nsc-gui:

TN-case-cuda-12-2-linux-nsc-gui:

So, after ruling out these factors, the difference is likely caused by the different architectures. I have three more points of curiosity:

a. If NT is optimal, do the other three combinations (NN, TN, TT) perform transpose operations during execution (e.g., via LDS / shared-memory loads) to get close to the NT case?

b. I also found a very old paper [link to paper]. Figure 1 in the paper shows that TN is the best performing among the four combinations.

c. The online slides about CUTLASS all use the TN case as the example for explanation, and the INT8 tensor core path only supports TN. So are you now focusing on optimizing TN to improve the speed of AI model forward inference?

Thanks.

The tool I used is nsys, not nvprof. nsys happens to have a command-line option called nvprof, but it does not use the nvprof tool. In any case, ncu is a suitable alternative.

I don’t doubt that in various situations TN may give better performance than NT. The reason for my posting was to indicate that I don’t think you can draw that conclusion universally.

For questions pertaining to the internal workings of cuBLAS, I won’t be able to answer those: it is a closed-source library.

Considering the test case you posted for CUDA 12.2 on Linux, the difference in performance from NT to TN appears to be about 4%. Given that you have a test case where TN appears to be faster than NT by about 4%, and I have a similar test case where NT appears to be faster than TN by about 10%, I would hesitate to draw directional conclusions or guidance from that. Neither one is dramatically faster than the other.

Thank you for your professional response.

I completely agree with your point. Actually, I came across the following information while working on my course assignment:

  1. Best Order For Performance · Issue #131 · NVIDIA/cutlass (github.com): “Column-major * row-major (“NT”) GEMMs typically achieve the highest performance for nearly every math instruction (except integer-valued TensorCore operations).”

  2. cuBLAS related question - CUDA / CUDA Programming and Performance - NVIDIA Developer Forums: here the (N,T) combination is the fastest.

  3. The paper [link to paper]: Figure 1 shows that TN is the best-performing of the four combinations.

  4. [QST] Int8 support for Dgrad+Wgrad conv2d · Issue #411 · NVIDIA/cutlass · GitHub: a tensor core instruction is basically a tiny TN (row x col -> row) GEMM, so it is not difficult to support if the data in global memory is also in a TN layout. For the other layouts, however, the data has to be transposed before it is loaded into registers to feed the tensor core; ldmatrix can do the transpose for FP16 data, but not for other data types.

Based on the first and second points, NT should be the fastest. However, according to the third and fourth points, TN should be the fastest, especially considering the fourth point, which says that the other layouts require additional transpose operations and that the ldmatrix instruction only supports transposing FP16 data.
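
To make the fourth point concrete, here is a warp-level sketch of my own (not from the linked issue; requires sm_70 or newer) using the CUDA WMMA API: one warp multiplies a row-major A tile by a column-major B tile into a row-major accumulator, which is exactly the TN (row x col -> row) pattern the tensor-core instruction natively consumes.

#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

/* Hedged illustration: one warp computes a single 16x16x16 tile D = A * B,
   reading A as row-major and B as column-major, i.e. the TN pattern. */
__global__ void wmma_tn_tile(const half* A, const half* B, float* D) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> d_frag;

    wmma::fill_fragment(d_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, A, 16);   /* leading dimension of the tile */
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(d_frag, a_frag, b_frag, d_frag);
    wmma::store_matrix_sync(D, d_frag, 16, wmma::mem_row_major);
}
/* Launch with one warp, e.g. wmma_tn_tile<<<1, 32>>>(dA, dB, dD); */

For other input layouts the data would first have to be rearranged (for example in shared memory) before it matches what the instruction expects, which is the extra transpose the issue refers to.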

That’s why I was confused and decided to run some experiments and discuss it with you. In fact, I ran additional experiments with FP16 data this morning, and NT turned out to be the fastest. NN and TT were slightly slower (by about 0.1 ms), while TN was the slowest (by about 0.2 ms).
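
For anyone who wants to repeat that FP16 comparison, here is a sketch of the kind of harness I mean (my own illustration with made-up names, not the exact code I ran): it sweeps all four (transa, transb) combinations with cublasGemmEx on FP16 inputs and times each one with CUDA events.

#include <stdio.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

/* Hedged sketch: time NN, NT, TN, TT for an FP16-input GEMM (FP32 accumulate).
   d_A and d_B are n*n __half matrices on the device, d_C is n*n float. */
void sweep_layouts_fp16(cublasHandle_t handle, int n,
                        const __half* d_A, const __half* d_B, float* d_C) {
    const float alpha = 1.0f;
    const float beta = 0.0f;
    const cublasOperation_t ops[2] = { CUBLAS_OP_N, CUBLAS_OP_T };
    const char* tag[2] = { "N", "T" };
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    for (int a = 0; a < 2; a++) {
        for (int b = 0; b < 2; b++) {
            /* warm-up call so the timed run excludes one-time setup */
            cublasGemmEx(handle, ops[a], ops[b], n, n, n, &alpha,
                         d_A, CUDA_R_16F, n, d_B, CUDA_R_16F, n, &beta,
                         d_C, CUDA_R_32F, n,
                         CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
            cudaEventRecord(start);
            cublasGemmEx(handle, ops[a], ops[b], n, n, n, &alpha,
                         d_A, CUDA_R_16F, n, d_B, CUDA_R_16F, n, &beta,
                         d_C, CUDA_R_32F, n,
                         CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
            cudaEventRecord(stop);
            cudaEventSynchronize(stop);
            float ms = 0.0f;
            cudaEventElapsedTime(&ms, start, stop);
            printf("%s%s: %.3f ms\n", tag[a], tag[b], ms);
        }
    }

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
}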

One thing is certain: there isn’t much of a performance gap between them. However, if anyone has more recent benchmark data for the four layouts in cuBLAS, it would be great if you could share it. Thanks.