cuBLAS GEMM INT8 is much slower than FP16 on T4

We tried to run GEMM in INT8 using the cuBLAS cublasGemmEx API, but ran into the following issues:

  1. In our typical setting, M=768, N=786432, K=128, GEMM in INT8 (kernel volta_sgemm_int8_128x128_nt) is much slower than in FP16 (kernel turing_h1688gemm_128x128_ldg8_nt): 21.443 ms vs. 8.6957 ms per call. Switching from CUDA 10.1 to CUDA 11.2 gives the same performance.

  2. We would like to use UINT8 instead of INT8. How should cublasGemmEx be configured for that? The cuBLAS manual is not clear on this. I tried CUDA_R_8U instead of CUDA_R_8I, but the results seem wrong.

Our benchmark code is on GitHub: Junsong-Wang/cuBLASTest

The test was run on a Tesla T4 card with Driver Version: 418.181.07, CUDA Version: 10.1.

Test results:

root@c0dca262005a:~/cuBLASTest/build# nvprof ./cublastest
==7890== NVPROF is profiling process 7890, command: ./cublastest
===== start to test HGEMM, M=768, N=786432, K=128, test iterations:16 =====
FP16, total Time (timeofday) in 16 interations is 1.91351s.
===== start to test GEMMEx(INT8), M=768, N=786432, K=128, test iterations:16 =====
INT8, total Time (timeofday) in 16 interations is 3.74584s.
==7890== Profiling application: ./cublastest
==7890== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   90.96%  5.17309s        32  161.66ms  110.38ms  213.00ms  [CUDA memcpy DtoH]
                    6.03%  343.08ms        16  21.443ms  21.181ms  21.899ms  volta_sgemm_int8_128x128_nt
                    2.45%  139.13ms        16  8.6957ms  7.7914ms  12.507ms  turing_h1688gemm_128x128_ldg8_nt
                    0.56%  31.810ms         5  6.3621ms  2.0160us  20.800ms  [CUDA memcpy HtoD]
      API calls:   59.41%  5.68842s        36  158.01ms  65.628us  234.50ms  cudaMemcpy2D
                   22.32%  2.13684s         8  267.10ms  33.173us  1.18106s  cudaHostAlloc
                   10.39%  995.07ms         9  110.56ms  1.0420us  652.45ms  cudaFree
                    7.79%  746.06ms         6  124.34ms  59.643us  447.81ms  cudaFreeHost
                    0.05%  4.8069ms         6  801.16us  61.777us  2.4387ms  cudaMallocPitch
                    0.02%  1.7026ms        32  53.204us  29.020us  69.090us  cudaLaunchKernel
                    0.01%  852.75us         3  284.25us  277.59us  295.24us  cuDeviceTotalMem
                    0.01%  610.51us       285  2.1420us     158ns  97.318us  cuDeviceGetAttribute
                    0.00%  414.18us         3  138.06us  7.6690us  384.40us  cudaMalloc
                    0.00%  276.94us        80  3.4610us     933ns  13.357us  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags
                    0.00%  123.31us       169     729ns     428ns  7.0680us  cudaFuncSetAttribute
                    0.00%  119.43us         3  39.809us  33.549us  46.820us  cuDeviceGetName
                    0.00%  40.201us         1  40.201us  40.201us  40.201us  cudaMemcpy
                    0.00%  17.637us        16  1.1020us     517ns  7.7770us  cudaEventCreateWithFlags
                    0.00%  12.096us        32     378ns     225ns     577ns  cudaGetLastError
                    0.00%  8.4200us         1  8.4200us  8.4200us  8.4200us  cuDeviceGetPCIBusId
                    0.00%  6.4310us        11     584ns     345ns  1.7760us  cudaDeviceGetAttribute
                    0.00%  5.9760us         2  2.9880us  2.9040us  3.0720us  cuInit
                    0.00%  5.0820us         1  5.0820us  5.0820us  5.0820us  cudaGetDevice
                    0.00%  3.7190us         5     743ns     250ns  2.1850us  cuDeviceGetCount
                    0.00%  2.0360us         4     509ns     189ns     983ns  cuDeviceGet
                    0.00%  1.3220us         2     661ns     526ns     796ns  cuDriverGetVersion
                    0.00%     884ns         3     294ns     290ns     304ns  cuDeviceGetUuid
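To put the two kernel times in perspective, here is a small host-side sketch (all function names are hypothetical) that turns an average kernel time into achieved throughput, counting 2·M·N·K operations per GEMM. It also computes the size of the output matrix C; assuming the benchmark copies C back as FP16 in the HGEMM test and as INT32 in the INT8 test, that would explain why the [CUDA memcpy DtoH] rows, not the kernels, dominate the timeofday totals.

```cpp
#include <cassert>
#include <cstdint>

// Operation count of one GEMM: a multiply and an add for every (m, n, k) triple.
double gemm_ops(std::int64_t m, std::int64_t n, std::int64_t k) {
    return 2.0 * static_cast<double>(m) * static_cast<double>(n) * static_cast<double>(k);
}

// Achieved throughput in tera-operations per second for a kernel time in milliseconds.
double tera_ops_per_s(std::int64_t m, std::int64_t n, std::int64_t k, double kernel_ms) {
    assert(kernel_ms > 0.0);
    return gemm_ops(m, n, k) / (kernel_ms * 1e-3) / 1e12;
}

// Size in bytes of the m x n output matrix C for a given element size.
double output_bytes(std::int64_t m, std::int64_t n, std::int64_t elem_size) {
    return static_cast<double>(m) * static_cast<double>(n) * static_cast<double>(elem_size);
}
```

With the averages from the profile, tera_ops_per_s(768, 786432, 128, 21.443) is about 7.2 (INT8) and tera_ops_per_s(768, 786432, 128, 8.6957) is about 17.8 (FP16). The output C is about 1.2 GB as FP16 and about 2.4 GB as INT32.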

If you think you’re getting incorrect results, I suggest filing a bug (see “How to report a bug”).

I don’t think it is a bug. Maybe I didn’t use the API correctly, or the API does not support UINT8: the cuBLAS manual contains no explicit statement that UINT8 is supported. For the first issue (performance), I have posted our benchmark; could you take a look? Thanks.
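The thread never settles the UINT8 question. If, as suspected here, cublasGemmEx only takes signed CUDA_R_8I inputs, one possible workaround (the usual zero-point correction from integer quantization, not something proposed in this thread) is to shift both unsigned matrices into the signed range by subtracting 128, run the INT8 GEMM, and add a correction term:

C_u[i][j] = C_s[i][j] + 128·Σ_p A_s[i][p] + 128·Σ_p B_s[p][j] + 128·128·K

The host-only sketch below checks that identity. It is an illustration under stated assumptions: all names are hypothetical, gemm_ref stands in for the GPU INT8 GEMM, matrices are row-major for simplicity, and the final values must still fit in INT32 (each entry is at most 255·255·K).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Shift unsigned bytes into the signed range: 0..255 -> -128..127.
std::vector<std::int8_t> to_signed(const std::vector<std::uint8_t>& u) {
    std::vector<std::int8_t> s(u.size());
    for (std::size_t i = 0; i < u.size(); ++i)
        s[i] = static_cast<std::int8_t>(static_cast<int>(u[i]) - 128);
    return s;
}

// Host reference GEMM, row-major, int32 accumulation: C (m x n) = A (m x k) * B (k x n).
// In practice the signed product would come from an INT8 GEMM on the GPU.
template <typename T>
std::vector<std::int32_t> gemm_ref(const std::vector<T>& a, const std::vector<T>& b,
                                   int m, int n, int k) {
    std::vector<std::int32_t> c(static_cast<std::size_t>(m) * n, 0);
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            for (int p = 0; p < k; ++p)
                c[i * n + j] += static_cast<std::int32_t>(a[i * k + p]) *
                                static_cast<std::int32_t>(b[p * n + j]);
    return c;
}

// Recover C_u = A_u * B_u from C_s = A_s * B_s, where A_u = A_s + 128 and B_u = B_s + 128.
std::vector<std::int32_t> correct_uint8_gemm(const std::vector<std::int32_t>& c_s,
                                             const std::vector<std::int8_t>& a_s,
                                             const std::vector<std::int8_t>& b_s,
                                             int m, int n, int k) {
    assert(c_s.size() == static_cast<std::size_t>(m) * n);
    std::vector<std::int32_t> row_sum(m, 0), col_sum(n, 0);
    for (int i = 0; i < m; ++i)
        for (int p = 0; p < k; ++p) row_sum[i] += a_s[i * k + p];  // sum over row i of A_s
    for (int p = 0; p < k; ++p)
        for (int j = 0; j < n; ++j) col_sum[j] += b_s[p * n + j];  // sum over column j of B_s
    std::vector<std::int32_t> c_u(c_s.size());
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            c_u[i * n + j] = c_s[i * n + j] + 128 * row_sum[i] + 128 * col_sum[j] + 128 * 128 * k;
    return c_u;
}
```

The row and column sums cost O(M·K + K·N), which is small next to the O(M·N·K) GEMM itself.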

We also ran the benchmark on the latest RTX 3090; INT8 is still much slower than FP16.

The profile was collected with Nsight Compute.

I also attached the ncu report:
profile_rtx3090.ncu-rep (8.5 MB)

From the profiled metrics, it seems that the tensor cores are not used for INT8. Could you help find out why?

I have the same issue on A100.
Driver Version: 525.105.17
CUDA Version: 11.1, V11.1.105

I tried to use the tensor cores for INT8 matrix multiplication, but it was slower than FP16.

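      ....
      // Intended to enable tensor cores; note that CUBLAS_TF32_TENSOR_OP_MATH enables
      // TF32 tensor cores for single-precision routines only, not for INT8.
      cublasSetMathMode(cublasHandle, CUBLAS_TF32_TENSOR_OP_MATH);
      ....
      // NN layout: A is MATRIX_M x MATRIX_K (lda = MATRIX_M), B is MATRIX_K x MATRIX_N
      // (ldb = MATRIX_K), both column-major. With CUBLAS_COMPUTE_32I, alpha and beta
      // must point to int32 values (scale type CUDA_R_32I).
      cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, 
                  MATRIX_M, MATRIX_N, MATRIX_K, 
                  &alpha,
                  a_int8_device, CUDA_R_8I, MATRIX_M,
                  b_int8_device, CUDA_R_8I, MATRIX_K,
                  &beta, 
                  c_int32_device, CUDA_R_32I, MATRIX_M,
                  CUBLAS_COMPUTE_32I,
                  CUBLAS_GEMM_DEFAULT_TENSOR_OP);  // Use tensor cores op (deprecated in recent cuBLAS)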

After studying the documentation for a few hours, I suspect the API does not support INT8 with tensor cores.

According to the documentation, with tensor cores cublasGemmEx only supports these compute types: CUBLAS_COMPUTE_16F, CUBLAS_COMPUTE_32F_FAST_16F, CUBLAS_COMPUTE_32F_FAST_16BF, CUBLAS_COMPUTE_32F_FAST_TF32

Not sure if this is the reason. Would someone confirm this? Thanks.

An important requirement for using INT8 tensor cores in cuBLAS GEMM-style operations is given in the note in the documentation for the cublasGemmEx() function:

CUBLAS_COMPUTE_32I and CUBLAS_COMPUTE_32I_PEDANTIC compute types are only supported with A, B being 4-byte aligned and lda, ldb being multiples of 4. For a better performance, it is also recommended that IMMA kernels requirements for a regular data ordering are met (listed here).

An “IMMA kernel” is an integer tensor core kernel, so the implication is that specific conditions must be met to use tensor cores for integer work. Following that last “listed here” link to the relevant section, we see:

To use IMMA kernels, one of the following sets of requirements, with the first being the preferred one, must be met:

  1. Using a regular data ordering:
  • All matrix pointers must be 4-byte aligned. For even better performance, this condition should hold with 16 instead of 4.
  • Leading dimensions of matrices A, B, C must be multiples of 4.
  • Only the “TN” format is supported - A must be transposed and B non-transposed.
  • Dimensions m and k must be multiples of 4.
  2. Using the IMMA-specific data ordering - CUBLASLT_ORDER_COL32 for matrices A,C,D, and CUBLASLT_ORDER_COL4_4R2_8C (on Turing or Ampere architecture) or CUBLASLT_ORDER_COL32_2R_4R4 (on Ampere architecture) for matrix B:
  • Leading dimensions of matrices A, B, C must fulfill conditions specific to the memory ordering (see cublasLtOrder_t).
  • Matmul descriptor must specify CUBLAS_OP_T on matrix B and CUBLAS_OP_N (default) on matrix A and C.
  • If scaleType CUDA_R_32I is used, the only supported values for alpha and beta are 0 or 1.
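The first recipe can be written down as a host-side check. This is only a sketch: the GemmCall struct and both function names are hypothetical (not cuBLAS API), the conditions are copied from the list quoted above, and passing the check does not guarantee that the cuBLAS heuristics will pick an IMMA kernel.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical description of one cublasGemmEx call.
struct GemmCall {
    bool transa;   // true for CUBLAS_OP_T on A
    bool transb;   // true for CUBLAS_OP_T on B
    int m, n, k;
    int lda, ldb, ldc;
    std::uintptr_t a, b, c;  // device addresses of A, B, C
};

// "Regular data ordering" requirements for IMMA kernels, as quoted above.
bool meets_imma_regular_ordering(const GemmCall& g) {
    assert(g.m > 0 && g.n > 0 && g.k > 0);
    const bool aligned = g.a % 4 == 0 && g.b % 4 == 0 && g.c % 4 == 0;   // 4-byte aligned pointers
    const bool ld_ok = g.lda % 4 == 0 && g.ldb % 4 == 0 && g.ldc % 4 == 0;
    const bool tn = g.transa && !g.transb;                               // only "TN" is supported
    const bool dims_ok = g.m % 4 == 0 && g.k % 4 == 0;
    return aligned && ld_ok && tn && dims_ok;
}

// The quoted docs say performance is even better with 16-byte alignment.
bool has_preferred_alignment(const GemmCall& g) {
    return g.a % 16 == 0 && g.b % 16 == 0 && g.c % 16 == 0;
}
```

The original T4 problem (M=768, K=128) already satisfies the dimension conditions, so the layout and leading dimensions are what remain to check; the NN call in the A100 post above fails at least the TN condition.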

Those notes are important if you want to see INT8 calculations run on tensor cores. Two recipes are given there; I will follow the first, in particular by choosing appropriate dimensions and by choosing A transposed and B non-transposed.
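In cuBLAS's column-major convention, "A transposed, B non-transposed" has a concrete memory-layout meaning: with CUBLAS_OP_T, A is stored as a k x m column-major array (lda >= k), and with CUBLAS_OP_N, B is stored as a k x n column-major array (ldb >= k), so in both operands the K index is the contiguous one. The host reference below is a sketch of that indexing only (the function name is hypothetical, alpha = 1 and beta = 0 are assumed, and it is not how cuBLAS computes the product).

```cpp
#include <cassert>
#include <cstdint>

// Host reference for a "TN" GEMM in column-major convention: C (m x n) = A^T * B, where
//   A is stored k x m column-major (lda >= k): element A^T(i, p) = a[p + i * lda]
//   B is stored k x n column-major (ldb >= k): element B(p, j)   = b[p + j * ldb]
//   C is stored m x n column-major (ldc >= m): element C(i, j)   = c[i + j * ldc]
void gemm_tn_ref(int m, int n, int k,
                 const std::int8_t* a, int lda,
                 const std::int8_t* b, int ldb,
                 std::int32_t* c, int ldc) {
    assert(lda >= k && ldb >= k && ldc >= m);
    for (int j = 0; j < n; ++j)
        for (int i = 0; i < m; ++i) {
            std::int32_t acc = 0;
            for (int p = 0; p < k; ++p)  // consecutive p are adjacent in both a and b
                acc += static_cast<std::int32_t>(a[p + i * lda]) *
                       static_cast<std::int32_t>(b[p + j * ldb]);
            c[i + j * ldc] = acc;  // alpha = 1, beta = 0
        }
}
```

Put differently, in row-major terms a TN call reads A as an m x k row-major matrix and B as an n x k row-major matrix.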

Here is an example using CUDA 12.0 on Ampere A100:

$ cat t2.cu
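#include <cublas_v2.h>
#include <cstdint>
#include <iostream>
#ifdef USE_INT8
using mt = std::int8_t;   // element type of A and B
using rt = std::int32_t;  // element type of C
using st = std::int32_t;  // type of alpha/beta (scale type CUDA_R_32I)
cudaDataType   Atype = CUDA_R_8I;
cudaDataType   Ctype = CUDA_R_32I;
cublasComputeType_t   computeType = CUBLAS_COMPUTE_32I;
#else
// using FP16
#include <cuda_fp16.h>
using mt = half;
using rt = half;
using st = half;
cudaDataType   Atype = CUDA_R_16F;
cudaDataType   Ctype = CUDA_R_16F;
cublasComputeType_t   computeType = CUBLAS_COMPUTE_16F;
#endif
int main(){

  // Square problem: m, n, k and all leading dimensions are multiples of 4 (and 16),
  // and cudaMalloc returns pointers aligned to at least 256 bytes.
  int dim = 4096;
  int m = dim;
  int n = dim;
  int k = dim;
  mt *A, *B;
  rt *C;
  // Contents are left uninitialized: only kernel selection and timing matter here.
  cudaMalloc(&A, sizeof(A[0])*m*k);
  cudaMalloc(&B, sizeof(B[0])*n*k);
  cudaMalloc(&C, sizeof(C[0])*m*n);
  st alpha = 1;
  st beta = 0;
  cublasHandle_t h;
  cublasStatus_t stat = cublasCreate(&h);
  // "TN": A transposed, B non-transposed, as required for IMMA kernels with regular ordering.
  stat = cublasGemmEx(h,
                           CUBLAS_OP_T,
                           CUBLAS_OP_N,
                           m,
                           n,
                           k,
                           &alpha,
                           A,
                           Atype,
                           dim,
                           B,
                           Atype,
                           dim,
                           &beta,
                           C,
                           Ctype,
                           dim,
                           computeType,
                           CUBLAS_GEMM_DEFAULT);
  std::cout << (int)stat << std::endl;  // 0 means CUBLAS_STATUS_SUCCESS
  cudaDeviceSynchronize();
  cudaError_t err = cudaGetLastError();
  std::cout << cudaGetErrorString(err) << std::endl;
  cublasDestroy(h);
  cudaFree(A);
  cudaFree(B);
  cudaFree(C);
}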
$ nvcc -o t2 t2.cu -lcublas
$ nsys nvprof --print-gpu-trace ./t2
WARNING: t2 and any of its children processes will be profiled.

0
no error
Generating '/tmp/nsys-report-8be3.qdstrm'
[1/3] [========================100%] report6.nsys-rep
[2/3] [========================100%] report6.sqlite
[3/3] Executing 'gputrace' stats report

  Start (ns)    Duration (ns)  CorrId  GrdX  GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Bytes (MB)  Throughput (MBps)  SrcMemKd  DstMemKd           Device            Ctx  Strm                      Name               
 -------------  -------------  ------  ----  ----  ----  ----  ----  ----  -------  ------------  ------------  ----------  -----------------  --------  --------  -------------------------  ---  ----  ---------------------------------------------
 1,611,077,622        557,631   3,491    16    32     1   256     1     1      174         0.049         0.098                                                     NVIDIA A100-SXM4-40GB (0)    1     7  ampere_h16816gemm_256x128_ldg8_stages_64x3_tn

Generated:
    /home/.../report6.nsys-rep
    /home/.../report6.sqlite
$ nvcc -o t2 t2.cu -lcublas -DUSE_INT8
$ nsys nvprof --print-gpu-trace ./t2
WARNING: t2 and any of its children processes will be profiled.

0
no error
Generating '/tmp/nsys-report-be40.qdstrm'
[1/3] [========================100%] report7.nsys-rep
[2/3] [========================100%] report7.sqlite
[3/3] Executing 'gputrace' stats report

  Start (ns)    Duration (ns)  CorrId  GrdX  GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Bytes (MB)  Throughput (MBps)  SrcMemKd  DstMemKd           Device            Ctx  Strm                                             Name
 -------------  -------------  ------  ----  ----  ----  ----  ----  ----  -------  ------------  ------------  ----------  -----------------  --------  --------  -------------------------  ---  ----  -------------------------------------------------------------------------------------------
 1,565,091,375        420,671   3,180   512     4     1   128     1     1      156         0.000         0.074                                                     NVIDIA A100-SXM4-40GB (0)    1     7  void cutlass::Kernel<cutlass_80_tensorop_i16832gemm_s8_128x64_128x3_tn_align16>(T1::Params)

Generated:
    /home/.../report7.nsys-rep
    /home/.../report7.sqlite
$

In the first build (FP16), the kernel invoked is ampere_h16816gemm_256x128_ldg8_stages_64x3_tn, an FP16 tensor core (TC) kernel, and its duration is ~558 microseconds.

In the second build (INT8), the kernel invoked is cutlass::Kernel<cutlass_80_tensorop_i16832gemm_s8_128x64_128x3_tn_align16>(T1::Params), an INT8 TC kernel, and its duration is ~421 microseconds, somewhat faster than the FP16 kernel.
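As a worked calculation using only the two durations from the traces, and counting 2·m·n·k operations for m = n = k = 4096, the kernels reach roughly:

```latex
\begin{aligned}
\text{ops} &= 2\,m\,n\,k = 2 \cdot 4096^3 = 137{,}438{,}953{,}472 \approx 1.374 \times 10^{11} \\
\text{FP16:}\quad \frac{1.374 \times 10^{11}}{557.631\,\mu\text{s}} &\approx 246.5\ \text{TFLOP/s} \\
\text{INT8:}\quad \frac{1.374 \times 10^{11}}{420.671\,\mu\text{s}} &\approx 326.7\ \text{TOP/s} \\
\text{speedup} &= \frac{557.631}{420.671} \approx 1.33
\end{aligned}
```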

From what I can see of the documentation, there is no recipe that allows CUBLAS_OP_N on both A and B if you want an INT8 TC kernel to be used.
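If the data is already laid out for an NN call (A as an m x k column-major array with lda = m, as in the A100 snippet above), B is already K-contiguous but A is not, and swapping operands to compute C^T = B^T·A^T does not help because A^T would then need to be K-contiguous too. One way to reach the TN form is therefore to transpose A once into a k x m column-major buffer and pass it with CUBLAS_OP_T and lda = k. The host sketch below (hypothetical helper name; on the GPU this would be a device-side transpose, not shown) performs that copy.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Copy an m x k column-major matrix (element (i, p) at src[i + p * ld_src], ld_src >= m)
// into a k x m column-major buffer (element (p, i) at dst[p + i * k]), i.e. lda = k.
// The result has the layout a TN cublasGemmEx call expects for A with CUBLAS_OP_T.
std::vector<std::int8_t> to_k_major(const std::vector<std::int8_t>& src,
                                    int m, int k, int ld_src) {
    assert(m > 0 && k > 0 && ld_src >= m);
    assert(src.size() >= static_cast<std::size_t>(ld_src) * (k - 1) + m);
    std::vector<std::int8_t> dst(static_cast<std::size_t>(k) * m);
    for (int i = 0; i < m; ++i)
        for (int p = 0; p < k; ++p)
            dst[p + i * k] = src[i + p * ld_src];
    return dst;
}
```

Whether the one-time transpose pays off depends on how often the same A is reused across GEMM calls.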


Thanks for the code, it is helpful!

One question: to get cutlass::Kernel<cutlass_80_tensorop_i16832gemm_s8_128x64_128x3_tn_align16> to be called, it seems I must call
cublasGemmEx(h, CUBLAS_OP_T, CUBLAS_OP_N, …), i.e. matrix A must be transposed while B must not be.

So, what is the actual criterion for selecting this kernel?

I don’t have any further criterion beyond what I mentioned.

Thanks, that’s interesting. It really does require CUBLAS_OP_T on matrix A and CUBLAS_OP_N on matrix B; otherwise another kernel, ampere_igemm_int8_128x128_ldg4_nn, is called, which is significantly slower than cutlass::Kernel<cutlass_80_tensorop_i16832gemm_s8_128x64_128x3_tn_align16>.

That is exactly what was already stated above.

This really helps, thanks!