cuBLAS INT8 tensor core mode vs. FP16 mode

Hi all,

I recently acquired an RTX card and was testing the new INT8 tensor core mode supported by Turing. I put together a simple test program (based on the “Programming Tensor Cores” devblogs article) to compare the execution times of INT8 mode vs. FP16 mode using the tensor cores. Strangely, the execution times of tensor-FP16 mode and tensor-INT8 mode are practically the same. I was expecting much better execution times for tensor-INT8 mode, since it’s supposed to have nearly twice the throughput of tensor-FP16 mode.

Here are the timing results (for 16384x16384 matrices):

cublas FP16 with tensor cores
    0: cublas time (ms): 314.191833
    1: cublas time (ms): 316.307465
    2: cublas time (ms): 314.961639
    3: cublas time (ms): 314.648590
    4: cublas time (ms): 313.170502
    5: cublas time (ms): 316.192474
    6: cublas time (ms): 313.694214
    7: cublas time (ms): 315.624695
    8: cublas time (ms): 313.759094
    9: cublas time (ms): 313.800476
    average time (ms): 314.635101

    cublas INT8 with tensor cores
    0: cublas time (ms): 309.059052
    1: cublas time (ms): 309.326996
    2: cublas time (ms): 308.243988
    3: cublas time (ms): 308.633636
    4: cublas time (ms): 309.602264
    5: cublas time (ms): 310.339111
    6: cublas time (ms): 309.275238
    7: cublas time (ms): 308.934967
    8: cublas time (ms): 310.953979
    9: cublas time (ms): 308.894135
    average time (ms): 309.326324

Does anyone have an idea why the execution times are nearly the same, despite the supposed throughput advantage of tensor-INT8 mode?

Source code is as follows:

//
    // source: https://github.com/NVIDIA-developer-blog/code-samples/blob/master/posts/tensor-cores/simpleTensorCoreGEMM.cu

    //	   https://devblogs.nvidia.com/parallelforall/programming-tensor-cores-cuda-9/

    //

    #include <stdio.h>
    #include <stdlib.h>
    #include <cublas_v2.h>

    // Define some error checking macros.
    // Check a CUDA runtime return code; on failure print the error string and
    // the call site, then terminate. The do { } while (0) wrapper makes the
    // macro behave as a single statement (safe inside an unbraced if/else).
    #define cudaErrCheck(stat) do { cudaErrCheck_((stat), __FILE__, __LINE__); } while (0)
    // Report a failed CUDA runtime call `stat` made at `file`:`line` and exit.
    // Continuing after a failure would only print meaningless benchmark
    // numbers (and later calls may keep failing on the same sticky error).
    void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
      if (stat != cudaSuccess) {
        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(stat), file, line);
        exit(EXIT_FAILURE);
      }
    }

    // Check a cuBLAS status code; on failure print the numeric status and the
    // call site, then terminate. The do { } while (0) wrapper makes the macro
    // behave as a single statement (safe inside an unbraced if/else).
    #define cublasErrCheck(stat) do { cublasErrCheck_((stat), __FILE__, __LINE__); } while (0)
    // Report a failed cuBLAS call `stat` made at `file`:`line` and exit, since
    // any timings printed after a failed GEMM would be meaningless.
    void cublasErrCheck_(cublasStatus_t stat, const char *file, int line) {
      if (stat != CUBLAS_STATUS_SUCCESS) {
        // cublasStatus_t is an enum; cast so the argument matches %d exactly.
        fprintf(stderr, "cuBLAS Error: %d %s %d\n", (int)stat, file, line);
        exit(EXIT_FAILURE);
      }
    }

    // host code

    // Run an n x n x n cublasGemmEx (no transposes; leading dimension n for
    // A, B and C) `reps` times on the default stream. Print each iteration's
    // elapsed time and the average, in milliseconds.
    // One untimed warm-up call is issued first, so that any one-time setup cost
    // on first use of this configuration is not charged to iteration 0.
    static void timeGemmEx(cublasHandle_t handle, int n,
                           const void *alpha,
                           const void *A, const void *B, cudaDataType abType,
                           const void *beta,
                           void *C, cudaDataType cType,
                           cudaDataType computeType, cublasGemmAlgo_t algo,
                           cudaEvent_t start, cudaEvent_t stop, int reps)
    {
      cublasErrCheck(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
                                  alpha, A, abType, n, B, abType, n,
                                  beta, C, cType, n, computeType, algo));
      cudaErrCheck(cudaDeviceSynchronize());

      float totalMs = 0.0f;
      for (int l = 0; l < reps; l++) {
        float ms = 0.0f;
        cudaErrCheck(cudaEventRecord(start));
        cublasErrCheck(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
                                    alpha, A, abType, n, B, abType, n,
                                    beta, C, cType, n, computeType, algo));
        cudaErrCheck(cudaEventRecord(stop));
        cudaErrCheck(cudaEventSynchronize(stop));
        cudaErrCheck(cudaEventElapsedTime(&ms, start, stop));
        totalMs += ms;
        printf("%d: cublas time (ms): %f\n", l, ms);
      }
      printf("average time (ms): %f\n\n", totalMs / reps);
    }

    // Usage: <program> <device id> <matrix dimension>
    // Times MatDim x MatDim x MatDim GEMMs for FP16 inputs (FP32 output and
    // compute) and INT8 inputs (INT32 output and compute). Each is run once
    // with CUBLAS_DEFAULT_MATH / CUBLAS_GEMM_DEFAULT and once with
    // CUBLAS_TENSOR_OP_MATH / CUBLAS_GEMM_DEFAULT_TENSOR_OP.
    int main(int argc, char* argv[])
    {
      // process command-line args
      if (argc < 3) {
        fprintf(stderr, "usage: <program> <device> <matrix dim>\n");
        return 1;
      }

      cudaErrCheck(cudaSetDevice(atoi(argv[1])));

      int MatDim = atoi(argv[2]);
      if (MatDim <= 0) {
        fprintf(stderr, "matrix dim must be a positive integer\n");
        return 1;
      }

      // Element count in size_t: MatDim * MatDim overflows int for MatDim > 46340.
      const size_t elems = (size_t)MatDim * (size_t)MatDim;
      const int reps = 10;

      // variable declarations
      half *a_fp16;
      half *b_fp16;
      float *c_fp32;

      int8_t *a_i8;
      int8_t *b_i8;
      int *c_i32;

      cublasHandle_t cublasHandle;

      cudaEvent_t startcublas;
      cudaEvent_t stopcublas;

      // create timing events
      cudaErrCheck(cudaEventCreate(&startcublas));
      cudaErrCheck(cudaEventCreate(&stopcublas));

      // create CUBLAS handle
      cublasErrCheck(cublasCreate(&cublasHandle));

      // allocate device side memory
      cudaErrCheck(cudaMalloc((void**)&a_fp16, elems * sizeof(half)));
      cudaErrCheck(cudaMalloc((void**)&b_fp16, elems * sizeof(half)));
      cudaErrCheck(cudaMalloc((void**)&c_fp32, elems * sizeof(float)));

      cudaErrCheck(cudaMalloc((void**)&a_i8, elems * sizeof(int8_t)));
      cudaErrCheck(cudaMalloc((void**)&b_i8, elems * sizeof(int8_t)));
      cudaErrCheck(cudaMalloc((void**)&c_i32, elems * sizeof(int)));

      // Give the inputs defined contents so the GEMMs never read uninitialized
      // device memory. Byte 0x3C makes every FP16 element 0x3C3C (~1.06), and
      // byte 0x01 makes every INT8 element 1. C needs no init: beta is 0 below.
      cudaErrCheck(cudaMemset(a_fp16, 0x3C, elems * sizeof(half)));
      cudaErrCheck(cudaMemset(b_fp16, 0x3C, elems * sizeof(half)));
      cudaErrCheck(cudaMemset(a_i8, 0x01, elems * sizeof(int8_t)));
      cudaErrCheck(cudaMemset(b_i8, 0x01, elems * sizeof(int8_t)));

      // alpha/beta must match the compute type: float for CUDA_R_32F, int for CUDA_R_32I.
      float alpha_fp32 = 1.0f;
      float beta_fp32 = 0.0f;
      int alpha_i32 = 1;
      int beta_i32 = 0;

      // perform FP16 CUBLAS matmul without tensor cores
      printf("\ncublas FP16 without tensor cores\n");
      cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH));
      timeGemmEx(cublasHandle, MatDim, &alpha_fp32, a_fp16, b_fp16, CUDA_R_16F,
                 &beta_fp32, c_fp32, CUDA_R_32F, CUDA_R_32F, CUBLAS_GEMM_DEFAULT,
                 startcublas, stopcublas, reps);

      // perform FP16 CUBLAS matmul with tensor cores
      printf("cublas FP16 with tensor cores\n");
      cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));
      timeGemmEx(cublasHandle, MatDim, &alpha_fp32, a_fp16, b_fp16, CUDA_R_16F,
                 &beta_fp32, c_fp32, CUDA_R_32F, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP,
                 startcublas, stopcublas, reps);

      // perform INT8 CUBLAS matmul without tensor cores
      printf("cublas INT8 without tensor cores\n");
      cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH));
      timeGemmEx(cublasHandle, MatDim, &alpha_i32, a_i8, b_i8, CUDA_R_8I,
                 &beta_i32, c_i32, CUDA_R_32I, CUDA_R_32I, CUBLAS_GEMM_DEFAULT,
                 startcublas, stopcublas, reps);

      // perform INT8 CUBLAS matmul with tensor cores
      // NOTE(review): whether this INT8 configuration is actually dispatched to
      // tensor-core kernels depends on the cuBLAS version. Confirm the kernel
      // names with a profiler before comparing timings.
      printf("cublas INT8 with tensor cores\n");
      cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));
      timeGemmEx(cublasHandle, MatDim, &alpha_i32, a_i8, b_i8, CUDA_R_8I,
                 &beta_i32, c_i32, CUDA_R_32I, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP,
                 startcublas, stopcublas, reps);

      // clean up (the cuBLAS handle is released before the device is reset)
      cublasErrCheck(cublasDestroy(cublasHandle));

      cudaErrCheck(cudaEventDestroy(startcublas));
      cudaErrCheck(cudaEventDestroy(stopcublas));

      cudaErrCheck(cudaFree(a_fp16));
      cudaErrCheck(cudaFree(b_fp16));
      cudaErrCheck(cudaFree(c_fp32));

      cudaErrCheck(cudaFree(a_i8));
      cudaErrCheck(cudaFree(b_i8));
      cudaErrCheck(cudaFree(c_i32));

      cudaErrCheck(cudaDeviceReset());

      // all done
      return 0;

    } // main

    // End-of-File

When the blog article was written, it was not possible to run INT8 operations through the tensor cores. Furthermore, nowhere in the GitHub code you listed as the source is there any code like this:

// Quoted INT8 GEMM call: CUDA_R_8I inputs, CUDA_R_32I output and compute type,
// requesting the CUBLAS_GEMM_DEFAULT_TENSOR_OP algorithm.
cublasErrCheck(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, 
                    MatDim, MatDim, MatDim, 
                    &alpha_i32,
                    a_i8, CUDA_R_8I, MatDim,
                    b_i8, CUDA_R_8I, MatDim,
                    &beta_i32, 
                    c_i32, CUDA_R_32I, MatDim,
    		CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

In the current CUDA 10.0 cuBLAS docs, I don’t see that combination listed as a supported tensor-core operation.

https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode

I would conclude that the method for running INT8 operations through the tensor cores via cuBLAS on Turing is currently unpublished, unless you know otherwise.

With a bit of study of the profiler output, you can tell when the tensor cores are being used on Volta and on Turing (the kernel signatures are somewhat different). You might want to check whether nvprof shows any candidate kernels.

I’m not trying to be catty. I’m not allowed to release currently unpublished information. You may have to wait.

Thanks Mod, you’re probably right that tensor-INT8 mode is not yet supported under CUBLAS.


I also tried out tensor-INT8 mode, this time using WMMA code directly. I posted my question about this in the “Mixed-Precision and Tensor Cores” sub-forum: https://devtalk.nvidia.com/default/topic/1047382/mixed-precision-and-tensor-cores/tensor-wmma-int8-vs-fp16-processing-speed/

When testing tensor-INT8 WMMA execution time vs. tensor-FP16 WMMA execution time I also saw nearly the same execution times for both modes (around 0.11 ms for 2048x2048 matrices). This was surprising as I was expecting INT8 WMMA to run much faster due to the double throughput of tensor-INT8 mode vs. tensor-FP16 mode.

Do you have any idea why the execution times are nearly the same for the FP16 and INT8 tensor WMMA modes?

Thanks again.

This was the WMMA kernel code I used in those tests (more details can also be seen in the linked post above):

__global__ void wmma_example_f16(half *a, half *b, float *c = NULL)
    {
       // Benchmark kernel: each warp accumulates one WmmaDim x WmmaDim tile of
       // a * b, with FP16 inputs and an FP32 accumulator. a, b (and c) are read
       // as column-major MatDim x MatDim matrices with leading dimension MatDim.
       // WmmaDim and MatDim are defined elsewhere.
       // NOTE(review): assumes blockDim.x is a multiple of warpSize, so every
       // lane of a warp shares warpM/warpN. Confirm against the launch code.
       //
       // c (optional, default NULL): when non-NULL, the finished FP32 tile is
       // stored there. The accumulator must be consumed somewhere. If acc_frag
       // is never written out, every mma_sync result is dead and the compiler
       // may remove that work, so the kernel would not time the tensor-core
       // math. Because `c` is only known at run time, the mma chain stays live
       // even for callers that still launch with just (a, b).

       // Tile using a 2D grid
       int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
       int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

       // Declare the fragments
       wmma::fragment<wmma::matrix_a, WmmaDim, WmmaDim, WmmaDim, half, wmma::col_major> a_frag;
       wmma::fragment<wmma::matrix_b, WmmaDim, WmmaDim, WmmaDim, half, wmma::col_major> b_frag;
       wmma::fragment<wmma::accumulator, WmmaDim, WmmaDim, WmmaDim, float> acc_frag;

       wmma::fill_fragment(acc_frag, 0.0f);

       // Loop over k
       for (int i = 0; i < MatDim; i += WmmaDim) {
          int aRow = warpM * WmmaDim;
          int aCol = i;

          int bRow = i;
          int bCol = warpN * WmmaDim;

          // Bounds checking
          if (aRow < MatDim && aCol < MatDim && bRow < MatDim && bCol < MatDim) {
             // Load the inputs
             wmma::load_matrix_sync(a_frag, a + aRow + aCol * MatDim, MatDim);
             wmma::load_matrix_sync(b_frag, b + bRow + bCol * MatDim, MatDim);

             // Perform the matrix multiplication
             wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
          }
       }

       // Write the result tile (skipped when no output buffer was supplied).
       int cRow = warpM * WmmaDim;
       int cCol = warpN * WmmaDim;
       if (c != NULL && cRow < MatDim && cCol < MatDim) {
          wmma::store_matrix_sync(c + cRow + cCol * MatDim, acc_frag, MatDim, wmma::mem_col_major);
       }
    } // wmma_example_f16

    __global__ void wmma_example_i8(signed char *a, signed char *b, int *c = NULL)
    {
       // Benchmark kernel: each warp accumulates one WmmaDim x WmmaDim tile of
       // a * b, with signed 8-bit inputs and an int accumulator. a, b (and c)
       // are read as column-major MatDim x MatDim matrices with leading
       // dimension MatDim. WmmaDim and MatDim are defined elsewhere.
       // NOTE(review): assumes blockDim.x is a multiple of warpSize, so every
       // lane of a warp shares warpM/warpN. Confirm against the launch code.
       //
       // c (optional, default NULL): when non-NULL, the finished int tile is
       // stored there. The accumulator must be consumed somewhere. If acc_frag
       // is never written out, every mma_sync result is dead and the compiler
       // may remove that work, so the kernel would not time the tensor-core
       // math. Because `c` is only known at run time, the mma chain stays live
       // even for callers that still launch with just (a, b).

       // Tile using a 2D grid
       int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
       int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

       // Declare the fragments
       wmma::fragment<wmma::matrix_a, WmmaDim, WmmaDim, WmmaDim, signed char, wmma::col_major> a_frag;
       wmma::fragment<wmma::matrix_b, WmmaDim, WmmaDim, WmmaDim, signed char, wmma::col_major> b_frag;
       wmma::fragment<wmma::accumulator, WmmaDim, WmmaDim, WmmaDim, int> acc_frag;

       wmma::fill_fragment(acc_frag, 0);

       // Loop over k
       for (int i = 0; i < MatDim; i += WmmaDim) {
          int aRow = warpM * WmmaDim;
          int aCol = i;

          int bRow = i;
          int bCol = warpN * WmmaDim;

          // Bounds checking
          if (aRow < MatDim && aCol < MatDim && bRow < MatDim && bCol < MatDim) {
             // Load the inputs
             wmma::load_matrix_sync(a_frag, a + aRow + aCol * MatDim, MatDim);
             wmma::load_matrix_sync(b_frag, b + bRow + bCol * MatDim, MatDim);

             // Perform the matrix multiplication
             wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
          }
       }

       // Write the result tile (skipped when no output buffer was supplied).
       int cRow = warpM * WmmaDim;
       int cCol = warpN * WmmaDim;
       if (c != NULL && cRow < MatDim && cCol < MatDim) {
          wmma::store_matrix_sync(c + cRow + cCol * MatDim, acc_frag, MatDim, wmma::mem_col_major);
       }
    } // wmma_example_i8

Thanks Mod, you’re probably right that CUBLAS doesn’t yet support tensor-INT8 mode.

I also tried tensor-INT8 mode directly using WMMA code. I posted regarding this in the “Mixed-Precision and Tensor Cores” sub-forum: https://devtalk.nvidia.com/default/topic/1047382/mixed-precision-and-tensor-cores/tensor-wmma-int8-vs-fp16-processing-speed/

Surprisingly, even when using WMMA code directly the tensor-INT8 mode execution time is almost the same as the tensor-FP16 mode. This doesn’t seem to be correct as INT8 mode is supposed to have double the throughput of FP16 mode.

Any ideas as to why the execution times are the same for the two modes?

The kernel code used to test this was as follows:

__global__ void wmma_example_f16(half *a, half *b, float *c = NULL)
    {
       // Benchmark kernel: each warp accumulates one WmmaDim x WmmaDim tile of
       // a * b, with FP16 inputs and an FP32 accumulator. a, b (and c) are read
       // as column-major MatDim x MatDim matrices with leading dimension MatDim.
       // WmmaDim and MatDim are defined elsewhere.
       // NOTE(review): assumes blockDim.x is a multiple of warpSize, so every
       // lane of a warp shares warpM/warpN. Confirm against the launch code.
       //
       // c (optional, default NULL): when non-NULL, the finished FP32 tile is
       // stored there. The accumulator must be consumed somewhere. If acc_frag
       // is never written out, every mma_sync result is dead and the compiler
       // may remove that work, so the kernel would not time the tensor-core
       // math. Because `c` is only known at run time, the mma chain stays live
       // even for callers that still launch with just (a, b).

       // Tile using a 2D grid
       int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
       int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

       // Declare the fragments
       wmma::fragment<wmma::matrix_a, WmmaDim, WmmaDim, WmmaDim, half, wmma::col_major> a_frag;
       wmma::fragment<wmma::matrix_b, WmmaDim, WmmaDim, WmmaDim, half, wmma::col_major> b_frag;
       wmma::fragment<wmma::accumulator, WmmaDim, WmmaDim, WmmaDim, float> acc_frag;

       wmma::fill_fragment(acc_frag, 0.0f);

       // Loop over k
       for (int i = 0; i < MatDim; i += WmmaDim) {
          int aRow = warpM * WmmaDim;
          int aCol = i;

          int bRow = i;
          int bCol = warpN * WmmaDim;

          // Bounds checking
          if (aRow < MatDim && aCol < MatDim && bRow < MatDim && bCol < MatDim) {
             // Load the inputs
             wmma::load_matrix_sync(a_frag, a + aRow + aCol * MatDim, MatDim);
             wmma::load_matrix_sync(b_frag, b + bRow + bCol * MatDim, MatDim);

             // Perform the matrix multiplication
             wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
          }
       }

       // Write the result tile (skipped when no output buffer was supplied).
       int cRow = warpM * WmmaDim;
       int cCol = warpN * WmmaDim;
       if (c != NULL && cRow < MatDim && cCol < MatDim) {
          wmma::store_matrix_sync(c + cRow + cCol * MatDim, acc_frag, MatDim, wmma::mem_col_major);
       }
    } // wmma_example_f16

    __global__ void wmma_example_i8(signed char *a, signed char *b, int *c = NULL)
    {
       // Benchmark kernel: each warp accumulates one WmmaDim x WmmaDim tile of
       // a * b, with signed 8-bit inputs and an int accumulator. a, b (and c)
       // are read as column-major MatDim x MatDim matrices with leading
       // dimension MatDim. WmmaDim and MatDim are defined elsewhere.
       // NOTE(review): assumes blockDim.x is a multiple of warpSize, so every
       // lane of a warp shares warpM/warpN. Confirm against the launch code.
       //
       // c (optional, default NULL): when non-NULL, the finished int tile is
       // stored there. The accumulator must be consumed somewhere. If acc_frag
       // is never written out, every mma_sync result is dead and the compiler
       // may remove that work, so the kernel would not time the tensor-core
       // math. Because `c` is only known at run time, the mma chain stays live
       // even for callers that still launch with just (a, b).

       // Tile using a 2D grid
       int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
       int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

       // Declare the fragments
       wmma::fragment<wmma::matrix_a, WmmaDim, WmmaDim, WmmaDim, signed char, wmma::col_major> a_frag;
       wmma::fragment<wmma::matrix_b, WmmaDim, WmmaDim, WmmaDim, signed char, wmma::col_major> b_frag;
       wmma::fragment<wmma::accumulator, WmmaDim, WmmaDim, WmmaDim, int> acc_frag;

       wmma::fill_fragment(acc_frag, 0);

       // Loop over k
       for (int i = 0; i < MatDim; i += WmmaDim) {
          int aRow = warpM * WmmaDim;
          int aCol = i;

          int bRow = i;
          int bCol = warpN * WmmaDim;

          // Bounds checking
          if (aRow < MatDim && aCol < MatDim && bRow < MatDim && bCol < MatDim) {
             // Load the inputs
             wmma::load_matrix_sync(a_frag, a + aRow + aCol * MatDim, MatDim);
             wmma::load_matrix_sync(b_frag, b + bRow + bCol * MatDim, MatDim);

             // Perform the matrix multiplication
             wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
          }
       }

       // Write the result tile (skipped when no output buffer was supplied).
       int cRow = warpM * WmmaDim;
       int cCol = warpN * WmmaDim;
       if (c != NULL && cRow < MatDim && cCol < MatDim) {
          wmma::store_matrix_sync(c + cRow + cCol * MatDim, acc_frag, MatDim, wmma::mem_col_major);
       }
    } // wmma_example_i8

Throughput in terms of operations per second is doubled.

The same 32-bit register that previously held two FP16 values now holds four INT8 values. The tensor core performs twice as many multiply-add operations in approximately the same run time.

This research article dives into details such as clock cycles (TABLE I)

It turns out that WMMA 8x32x16 in INT8 mode executes a bit faster than FP16 on RTX tensor cores.

Thanks for the document link @cbuchner1.

It looks like even WMMA 16x16x16 INT8 mode is nearly as fast as 8x32x16 INT8 mode, ie. 59 clock cycles for the former and 56 clock cycles for the latter.

Based on the values given, 16x16x16 INT8 mode at 59 clock cycles versus 16x16x16 FP16 (with FP32 accumulate) at 99 clock cycles makes INT8 mode around 68% faster than FP16 mode (99/59 ≈ 1.68).

But the two test kernels I posted previously (“wmma_example_f16” and “wmma_example_i8”) are showing nearly the same execution times. Both these kernels are using 16x16x16 WMMA dimensions (ie. “WmmaDim” in the code is defined as 16) so based on the theoretical clock cycle count from the document INT8 mode should be much faster, and yet it seems to be showing nearly the same execution times.

Any idea on this?

Here are the execution times for multiplying 4096x4096 matrices:

wmma FP16 (FP32 acc)
0: wmma time (ms): 0.836704
1: wmma time (ms): 0.834112
2: wmma time (ms): 0.834720
3: wmma time (ms): 0.843648
4: wmma time (ms): 0.835616
5: wmma time (ms): 0.835296
6: wmma time (ms): 0.841440
7: wmma time (ms): 0.840512
8: wmma time (ms): 0.839680
9: wmma time (ms): 0.841856
average time (ms): 0.838359

wmma INT8 (INT32 acc)
0: wmma time (ms): 0.833984
1: wmma time (ms): 0.832512
2: wmma time (ms): 0.833120
3: wmma time (ms): 0.831968
4: wmma time (ms): 0.832384
5: wmma time (ms): 0.833344
6: wmma time (ms): 0.833088
7: wmma time (ms): 0.835584
8: wmma time (ms): 0.833504
9: wmma time (ms): 0.833600
average time (ms): 0.833309

There is one major issue with the integer wmma API in the nvcuda namespace.

It introduces additional byte permutations. This is because the fragment load instruction places each 8-bit integer in its own 32-bit register. Before the WMMA hardware instruction can execute, groups of four 8-bit integers therefore have to be packed back into a single 32-bit register using byte permutes.

This extra overhead is already visible in the PTX code generated by nvcc and it causes a notable slowdown.

The only way I got more speed in my particular application (big-integer multiplication) was to use PTX WMMA inline assembly. However, the exact mapping between threads, registers, and matrix elements is undocumented and may change across architectures. Still, if you require peak performance, this is the way to go.

// Thin wrapper around one PTX warp-level integer matrix multiply-accumulate:
//   D (8x32, s32) = A (8x16, u8) * B (16x32, u8) + C (8x32, s32),
// issued as wmma.mma.sync.aligned.row.row.m8n32k16.s32.u8.u8.s32.
// Per-thread operands, as listed in the asm below:
//   A fragment:   1 x 32-bit register  (a0)
//   B fragment:   4 x 32-bit registers (b0..b3)
//   C fragment:   8 x 32-bit registers (c0..c7, inputs)
//   D fragment:   8 x 32-bit registers (d0..d7, outputs)
// The u8 inputs are therefore packed four per 32-bit register.
// As a .sync.aligned warp instruction, it is meant to be executed by all
// 32 lanes of a warp together.
// NOTE(review): which matrix elements each lane's registers hold is not
// specified here (per the thread, that mapping is undocumented and may change
// across architectures); callers must pack a0 and b0..b3 accordingly.
// NOTE(review): integer WMMA presumably needs sm_72 or newer. Confirm against
// the PTX ISA before building for other targets.
__device__ __inline__ void wmma_8x32x16(int &d0, int &d1, int &d2, int &d3, int &d4, int &d5, int &d6, int &d7,
         const unsigned int &a0,
         const unsigned int &b0, const unsigned int &b1, const unsigned int &b2, const unsigned int &b3,
         const int &c0, const int &c1, const int &c2, const int &c3, const int &c4, const int &c5, const int &c6, const int &c7)
{
    // Operands %0..%7 are the write-only outputs d0..d7 ("=r").
    // %8 = a0, %9..%12 = b0..b3, %13..%20 = c0..c7 (read-only inputs, "r").
    asm(" wmma.mma.sync.aligned.row.row.m8n32k16.s32.u8.u8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, {%8}, {%9,%10,%11,%12}, {%13,%14,%15,%16,%17,%18,%19,%20};\n\t"
        : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3), "=r"(d4), "=r"(d5), "=r"(d6), "=r"(d7) : "r"(a0), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(c4), "r"(c5), "r"(c6), "r"(c7));
}

Thanks for the info. I generated the PTX file and looking at the code produced, I can’t find the additional byte permutations you mentioned.

I see the two “wmma.load” instructions (lines 54 and 55 in the PTX below), each of which loads into a vector of two 32-bit registers (which matches what the PTX documentation says for INT8 mode). These are directly followed by the “wmma.mma” instruction (line 56), with the result placed in a vector of eight 32-bit registers (also as specified in the PTX documentation for INT8 mode). This group of two “wmma.load” instructions and one “wmma.mma” instruction appears four times, because the compiler unrolled the main loop four times.

I can’t seem to find where byte permutations were being done.

Or do you mean the byte permutation code can’t be seen in the PTX code, but is present in the SASS code?

This was the PTX code produced:

// PTX generated by nvcc for the wmma_example_i8 kernel posted in this thread
// (16x16x16 WMMA tiles, 256x256 matrices, A and B column-major).
// What the instructions below show:
//  - The k-loop is unrolled 4x. Blocks BB2_2, BB2_4, BB2_6 and BB2_8 each
//    contain one wmma.load.a, one wmma.load.b and one wmma.mma.
//  - No prmt (byte-permute) instructions are emitted.
//  - The accumulator registers %r137..%r144 are never stored (there is no
//    st.* instruction), so the mma results are unused. ptxas may therefore be
//    able to drop that work when generating SASS.
//    NOTE(review): inspect the SASS (cuobjdump -sass) before trusting timings.
// .globl	_Z15wmma_example_i8PaS_
.visible .entry _Z15wmma_example_i8PaS_(
	.param .u64 _Z15wmma_example_i8PaS__param_0,
	.param .u64 _Z15wmma_example_i8PaS__param_1
)
{
	.reg .pred 	%p<5>;
	.reg .b32 	%r<169>;
	.reg .b64 	%rd<21>;


	// %rd1 = global address of a (param_0); %rd7 = global address of b (param_1).
	ld.param.u64 	%rd5, [_Z15wmma_example_i8PaS__param_0];
	ld.param.u64 	%rd6, [_Z15wmma_example_i8PaS__param_1];
	cvta.to.global.u64 	%rd7, %rd6;
	// %r90 = (ntid.x * ctaid.x + tid.x) / WARP_SZ, i.e. warpM.
	mov.u32 	%r85, %ntid.x;
	mov.u32 	%r86, %ctaid.x;
	mov.u32 	%r87, %tid.x;
	mad.lo.s32 	%r88, %r85, %r86, %r87;
	mov.u32 	%r89, WARP_SZ;
	div.u32 	%r90, %r88, %r89;
	mov.u32 	%r91, %ntid.y;
	mov.u32 	%r92, %ctaid.y;
	mov.u32 	%r93, %tid.y;
	cvta.to.global.u64 	%rd1, %rd5;
	// %r1 = warpM << 4 (aRow); %r94 = ntid.y * ctaid.y + tid.y (warpN); %r95 = warpN << 4 (bCol).
	shl.b32 	%r1, %r90, 4;
	mad.lo.s32 	%r94, %r91, %r92, %r93;
	shl.b32 	%r95, %r94, 4;
	// %p1 = (aRow < 256) && (bCol < 256). No separate check on aCol/bRow (= i) remains.
	setp.lt.s32	%p2, %r1, 256;
	setp.lt.s32	%p3, %r95, 256;
	and.pred  	%p1, %p2, %p3;
	// %rd20 = b + (warpN << 12) = b + bCol*256: start of this warp's B tile.
	shl.b32 	%r96, %r94, 12;
	cvt.s64.s32	%rd8, %r96;
	add.s64 	%rd20, %rd7, %rd8;
	// Zero the eight accumulator registers (fill_fragment) and the loop counter %r136 (i).
	mov.u32 	%r137, 0;
	mov.u32 	%r138, %r137;
	mov.u32 	%r139, %r137;
	mov.u32 	%r140, %r137;
	mov.u32 	%r141, %r137;
	mov.u32 	%r142, %r137;
	mov.u32 	%r143, %r137;
	mov.u32 	%r144, %r137;
	mov.u32 	%r136, %r137;

// Unrolled step 1 (k = i): skip the tile when %p1 is false.
BB2_1:
	@!%p1 bra 	BB2_3;
	bra.uni 	BB2_2;

BB2_2:
	// A address = a + aRow + i*256; both loads use leading dimension 256.
	shl.b32 	%r97, %r136, 8;
	add.s32 	%r98, %r97, %r1;
	cvt.s64.s32	%rd9, %r98;
	add.s64 	%rd10, %rd1, %rd9;
	mov.u32 	%r99, 256;
	wmma.load.a.sync.aligned.col.m16n16k16.global.s8 	{%r100, %r101}, [%rd10], %r99;
	wmma.load.b.sync.aligned.col.m16n16k16.global.s8 	{%r102, %r103}, [%rd20], %r99;
	wmma.mma.sync.aligned.col.col.m16n16k16.s32.s8.s8.s32 {%r144, %r143, %r142, %r141, %r140, %r139, %r138, %r137}, {%r100, %r101}, {%r102, %r103}, {%r144, %r143, %r142, %r141, %r140, %r139, %r138, %r137};

BB2_3:
	@!%p1 bra 	BB2_5;
	bra.uni 	BB2_4;

// Unrolled step 2 (k = i + 16): A advances 16 columns (+4096), B advances 16 rows (+16).
BB2_4:
	shl.b32 	%r104, %r136, 8;
	add.s32 	%r105, %r104, %r1;
	add.s32 	%r106, %r105, 4096;
	cvt.s64.s32	%rd11, %r106;
	add.s64 	%rd12, %rd1, %rd11;
	mov.u32 	%r107, 256;
	wmma.load.a.sync.aligned.col.m16n16k16.global.s8 	{%r108, %r109}, [%rd12], %r107;
	add.s64 	%rd13, %rd20, 16;
	wmma.load.b.sync.aligned.col.m16n16k16.global.s8 	{%r110, %r111}, [%rd13], %r107;
	wmma.mma.sync.aligned.col.col.m16n16k16.s32.s8.s8.s32 {%r144, %r143, %r142, %r141, %r140, %r139, %r138, %r137}, {%r108, %r109}, {%r110, %r111}, {%r144, %r143, %r142, %r141, %r140, %r139, %r138, %r137};

BB2_5:
	@!%p1 bra 	BB2_7;
	bra.uni 	BB2_6;

// Unrolled step 3 (k = i + 32): A offset +8192, B offset +32.
BB2_6:
	shl.b32 	%r112, %r136, 8;
	add.s32 	%r113, %r112, %r1;
	add.s32 	%r114, %r113, 8192;
	cvt.s64.s32	%rd14, %r114;
	add.s64 	%rd15, %rd1, %rd14;
	mov.u32 	%r115, 256;
	wmma.load.a.sync.aligned.col.m16n16k16.global.s8 	{%r116, %r117}, [%rd15], %r115;
	add.s64 	%rd16, %rd20, 32;
	wmma.load.b.sync.aligned.col.m16n16k16.global.s8 	{%r118, %r119}, [%rd16], %r115;
	wmma.mma.sync.aligned.col.col.m16n16k16.s32.s8.s8.s32 {%r144, %r143, %r142, %r141, %r140, %r139, %r138, %r137}, {%r116, %r117}, {%r118, %r119}, {%r144, %r143, %r142, %r141, %r140, %r139, %r138, %r137};

BB2_7:
	@!%p1 bra 	BB2_9;
	bra.uni 	BB2_8;

// Unrolled step 4 (k = i + 48): A offset +12288, B offset +48.
BB2_8:
	shl.b32 	%r120, %r136, 8;
	add.s32 	%r121, %r120, %r1;
	add.s32 	%r122, %r121, 12288;
	cvt.s64.s32	%rd17, %r122;
	add.s64 	%rd18, %rd1, %rd17;
	mov.u32 	%r123, 256;
	wmma.load.a.sync.aligned.col.m16n16k16.global.s8 	{%r124, %r125}, [%rd18], %r123;
	add.s64 	%rd19, %rd20, 48;
	wmma.load.b.sync.aligned.col.m16n16k16.global.s8 	{%r126, %r127}, [%rd19], %r123;
	wmma.mma.sync.aligned.col.col.m16n16k16.s32.s8.s8.s32 {%r144, %r143, %r142, %r141, %r140, %r139, %r138, %r137}, {%r124, %r125}, {%r126, %r127}, {%r144, %r143, %r142, %r141, %r140, %r139, %r138, %r137};

// i += 64 and the B pointer += 64 (four steps of 16); repeat while i < 256.
BB2_9:
	add.s64 	%rd20, %rd20, 64;
	add.s32 	%r136, %r136, 64;
	setp.lt.s32	%p4, %r136, 256;
	@%p4 bra 	BB2_1;

	// Returns without ever storing the accumulator registers.
	ret;
}

This was the kernel (configured for 16x16x16 WMMA tiles, and 256x256 input matrices):

__global__ void wmma_example_i8(signed char *a, signed char *b, int *c = NULL)
{
  // Fixed-size benchmark kernel: each warp accumulates one 16x16 tile of
  // a * b. a and b are 256x256 column-major signed 8-bit matrices; the
  // accumulator is int.
  // NOTE(review): assumes blockDim.x is a multiple of warpSize, so every lane
  // of a warp shares warpM/warpN. Confirm against the launch code.
  //
  // c (optional, default NULL): when non-NULL, the finished int tile is stored
  // there (column-major, leading dimension 256). The accumulator must be
  // consumed somewhere. If acc_frag is never written out, every mma_sync
  // result is dead and the compiler may remove that work, so the kernel would
  // not time the tensor-core math. Because `c` is only known at run time, the
  // mma chain stays live even for callers that still launch with just (a, b).
  const int N = 256;    // matrix dimension (also the leading dimension)
  const int TILE = 16;  // WMMA M = N = K

  // Tile using a 2D grid
  int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
  int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

  // Declare the fragments
  wmma::fragment<wmma::matrix_a, TILE, TILE, TILE, signed char, wmma::col_major> a_frag;
  wmma::fragment<wmma::matrix_b, TILE, TILE, TILE, signed char, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, TILE, TILE, TILE, int> acc_frag;

  wmma::fill_fragment(acc_frag, 0);

  // Loop over k
  for (int i = 0; i < N; i += TILE) {
    int aRow = warpM * TILE;
    int aCol = i;

    int bRow = i;
    int bCol = warpN * TILE;

    // Bounds checking
    if (aRow < N && aCol < N && bRow < N && bCol < N) {
      // Load the inputs
      wmma::load_matrix_sync(a_frag, a + aRow + aCol * N, N);
      wmma::load_matrix_sync(b_frag, b + bRow + bCol * N, N);

      // Perform the matrix multiplication
      wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }
  }

  // Write the result tile (skipped when no output buffer was supplied).
  int cRow = warpM * TILE;
  int cCol = warpN * TILE;
  if (c != NULL && cRow < N && cCol < N) {
    wmma::store_matrix_sync(c + cRow + cCol * N, acc_frag, N, wmma::mem_col_major);
  }
} // wmma_example_i8

If you don’t see any permute operations, then it’s probably my use of elementwise access of the wmma::fragment<> of matrix_a/b that triggered a particularly slow code path.

I was accessing/writing to fragment.x[i] with i indexing 0 to fragment.num_elements-1 just before the WMMA. For multiplication I needed to set up some matrices in a special way that could not be achieved by simply using a fragment load operation.

Christian

Ok, thanks for clarifying that.

This is really weird behavior: the execution times are almost the same for the two modes. Even after removing the “wmma.load” calls, so that only the “wmma.mma_sync” loop is being timed, the two modes are still nearly equal in execution time.