cuBLAS INT8 tensor core mode vs. FP16 mode

Hi all,

I recently acquired an RTX card and was testing the new INT8 tensor core mode supported by Turing. I put together a simple test program (based on the “Programming Tensor Cores” devblogs article) to compare the execution times of INT8 mode and FP16 mode on the tensor cores. Strangely, the two modes run in practically the same time. I was expecting tensor-INT8 mode to be considerably faster, since it’s supposed to have nearly twice the throughput of tensor-FP16 mode.

Here are the timing results (these are for 16384x16384 matrices):

cublas FP16 with tensor cores
0: cublas time (ms): 314.191833
1: cublas time (ms): 316.307465
2: cublas time (ms): 314.961639
3: cublas time (ms): 314.648590
4: cublas time (ms): 313.170502
5: cublas time (ms): 316.192474
6: cublas time (ms): 313.694214
7: cublas time (ms): 315.624695
8: cublas time (ms): 313.759094
9: cublas time (ms): 313.800476
average time (ms): 314.635101

cublas INT8 with tensor cores
0: cublas time (ms): 309.059052
1: cublas time (ms): 309.326996
2: cublas time (ms): 308.243988
3: cublas time (ms): 308.633636
4: cublas time (ms): 309.602264
5: cublas time (ms): 310.339111
6: cublas time (ms): 309.275238
7: cublas time (ms): 308.934967
8: cublas time (ms): 310.953979
9: cublas time (ms): 308.894135
average time (ms): 309.326324
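
For reference, converting the averages above to achieved throughput (assuming the usual 2*N^3 operation count for an N x N GEMM) gives roughly 28 tera-ops/s for both runs, so the INT8 path shows none of its expected ~2x advantage. A quick back-of-the-envelope helper (hypothetical, not part of the test program):

// achieved tera-ops/s from a measured average GEMM time,
// assuming 2*N^3 operations for an N x N matrix product
double achieved_tops(int n, double avg_ms) {
  return 2.0 * n * n * n / (avg_ms * 1e-3) / 1e12;
}
// achieved_tops(16384, 314.635) -> ~28.0 (FP16 tensor mode)
// achieved_tops(16384, 309.326) -> ~28.4 (INT8 tensor mode)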

Does anyone have an idea why the execution times are nearly identical despite the supposed throughput advantage of tensor-INT8 mode?

Source code is as follows:

//
// source: https://github.com/NVIDIA-developer-blog/code-samples/blob/master/posts/tensor-cores/simpleTensorCoreGEMM.cu
//         https://devblogs.nvidia.com/parallelforall/programming-tensor-cores-cuda-9/
//

#include <stdio.h>
#include <stdlib.h>     // for atoi()
#include <cuda_fp16.h>  // for the half type
#include <cublas_v2.h>

// Define some error checking macros.
#define cudaErrCheck(stat) { cudaErrCheck_((stat), __FILE__, __LINE__); }
void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
  if (stat != cudaSuccess) {
    fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(stat), file, line);
  }
}

#define cublasErrCheck(stat) { cublasErrCheck_((stat), __FILE__, __LINE__); }
void cublasErrCheck_(cublasStatus_t stat, const char *file, int line) {
  if (stat != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "cuBLAS Error: %d %s %d\n", stat, file, line);
  }
}

// host code

int main(int argc, char* argv[])
{
  // variable declarations

  half *a_fp16;
  half *b_fp16;

  float *c_fp32;

  int8_t *a_i8;
  int8_t *b_i8;

  int *c_i32;

  cublasHandle_t cublasHandle;
   
  cudaEvent_t startcublas;
  cudaEvent_t stopcublas;

  // process command-line args: <device id> <matrix dimension>

  if (argc < 3) {
    fprintf(stderr, "usage: %s <device id> <matrix dim>\n", argv[0]);
    return 1;
  }

  cudaErrCheck(cudaSetDevice(atoi(argv[1])));

  int MatDim = atoi(argv[2]);

  // create timing events

  cudaErrCheck(cudaEventCreate(&startcublas));
  cudaErrCheck(cudaEventCreate(&stopcublas));

  // create CUBLAS handle

  cublasErrCheck(cublasCreate(&cublasHandle));

  // allocate device side memory

  cudaErrCheck(cudaMalloc((void**)&a_fp16, (size_t)MatDim * MatDim * sizeof(half)));
  cudaErrCheck(cudaMalloc((void**)&b_fp16, (size_t)MatDim * MatDim * sizeof(half)));

  cudaErrCheck(cudaMalloc((void**)&c_fp32, (size_t)MatDim * MatDim * sizeof(float)));

  cudaErrCheck(cudaMalloc((void**)&a_i8, (size_t)MatDim * MatDim * sizeof(int8_t)));
  cudaErrCheck(cudaMalloc((void**)&b_i8, (size_t)MatDim * MatDim * sizeof(int8_t)));

  cudaErrCheck(cudaMalloc((void**)&c_i32, (size_t)MatDim * MatDim * sizeof(int)));

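  // note: the device buffers are never initialized; that is fine here,
  // since we only measure execution time, not numerical results
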
  // perform FP16 CUBLAS matmul without tensor cores

  printf("\ncublas FP16 without tensor cores\n");

  cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH));

  float alpha_fp32 = 1.0f;
  float beta_fp32 = 0.0f;

  float cublasTime, cublasTimeTot = 0.0f;

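  // note: no warm-up iteration is performed, so the first timed measurement
  // may include one-time cuBLAS initialization / kernel-selection overhead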
  for (int l=0; l<10; l++) {
    cudaErrCheck(cudaEventRecord(startcublas));

    cublasErrCheck(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, 
                MatDim, MatDim, MatDim, 
                &alpha_fp32,
                a_fp16, CUDA_R_16F, MatDim,
                b_fp16, CUDA_R_16F, MatDim,
                &beta_fp32, 
                c_fp32, CUDA_R_32F, MatDim,
                CUDA_R_32F, CUBLAS_GEMM_DEFAULT));

    cudaErrCheck(cudaEventRecord(stopcublas));
    cudaErrCheck(cudaEventSynchronize(stopcublas));
    cudaErrCheck(cudaEventElapsedTime(&cublasTime, startcublas, stopcublas));

    cublasTimeTot += cublasTime;

    printf("%d: cublas time (ms): %f\n", l, cublasTime);
  }

  printf("average time (ms): %f\n\n", cublasTimeTot/10);

  // perform FP16 CUBLAS matmul with tensor cores

  printf("cublas FP16 with tensor cores\n");

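  // CUBLAS_TENSOR_OP_MATH allows cuBLAS to use tensor core kernels
  // for subsequent calls on this handle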
  cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));

  cublasTimeTot = 0.0f;

  for (int l=0; l<10; l++) {
    cudaErrCheck(cudaEventRecord(startcublas));

    cublasErrCheck(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, 
                MatDim, MatDim, MatDim, 
                &alpha_fp32,
                a_fp16, CUDA_R_16F, MatDim,
                b_fp16, CUDA_R_16F, MatDim,
                &beta_fp32, 
                c_fp32, CUDA_R_32F, MatDim,
                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

    cudaErrCheck(cudaEventRecord(stopcublas));
    cudaErrCheck(cudaEventSynchronize(stopcublas));
    cudaErrCheck(cudaEventElapsedTime(&cublasTime, startcublas, stopcublas));

    cublasTimeTot += cublasTime;

    printf("%d: cublas time (ms): %f\n", l, cublasTime);
  }

  printf("average time (ms): %f\n\n", cublasTimeTot/10);

  // perform INT8 CUBLAS matmul without tensor cores

  printf("cublas INT8 without tensor cores\n");

  cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH));

  int alpha_i32 = 1;
  int beta_i32 = 0;

  cublasTimeTot = 0.0f;

  for (int l=0; l<10; l++) {
    cudaErrCheck(cudaEventRecord(startcublas));

    cublasErrCheck(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, 
                MatDim, MatDim, MatDim, 
                &alpha_i32,
                a_i8, CUDA_R_8I, MatDim,
                b_i8, CUDA_R_8I, MatDim,
                &beta_i32, 
                c_i32, CUDA_R_32I, MatDim,
                CUDA_R_32I, CUBLAS_GEMM_DEFAULT));

    cudaErrCheck(cudaEventRecord(stopcublas));
    cudaErrCheck(cudaEventSynchronize(stopcublas));
    cudaErrCheck(cudaEventElapsedTime(&cublasTime, startcublas, stopcublas));

    cublasTimeTot += cublasTime;

    printf("%d: cublas time (ms): %f\n", l, cublasTime);
  }

  printf("average time (ms): %f\n\n", cublasTimeTot/10);

  // perform INT8 CUBLAS matmul with tensor cores

  printf("cublas INT8 with tensor cores\n");

  cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));

  cublasTimeTot = 0.0f;

  for (int l=0; l<10; l++) {
    cudaErrCheck(cudaEventRecord(startcublas));

    cublasErrCheck(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, 
                MatDim, MatDim, MatDim, 
                &alpha_i32,
                a_i8, CUDA_R_8I, MatDim,
                b_i8, CUDA_R_8I, MatDim,
                &beta_i32, 
                c_i32, CUDA_R_32I, MatDim,
                CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

    cudaErrCheck(cudaEventRecord(stopcublas));
    cudaErrCheck(cudaEventSynchronize(stopcublas));
    cudaErrCheck(cudaEventElapsedTime(&cublasTime, startcublas, stopcublas));

    cublasTimeTot += cublasTime;

    printf("%d: cublas time (ms): %f\n", l, cublasTime);
  }

  printf("average time (ms): %f\n\n", cublasTimeTot/10);

  // clean up

  cublasErrCheck(cublasDestroy(cublasHandle));

  cudaErrCheck(cudaEventDestroy(startcublas));
  cudaErrCheck(cudaEventDestroy(stopcublas));

  cudaErrCheck(cudaFree(a_fp16));
  cudaErrCheck(cudaFree(b_fp16));
  cudaErrCheck(cudaFree(c_fp32));

  cudaErrCheck(cudaFree(a_i8));
  cudaErrCheck(cudaFree(b_i8));
  cudaErrCheck(cudaFree(c_i32));

  cudaErrCheck(cudaDeviceReset());

  // all done

  return 0;

} // main

// End-of-File