cuBLAS and CUTLASS 8-bit GEMM matrix size constraints

Hey,

For a standard GEMM routine C = alpha*(A*B) + beta*C, with dimensions A = M x K, B = K x N and C = M x N, what are the constraints on M, N and K for 8-bit integer operations? I remember reading somewhere that M, N and K need to be multiples of 4, but I can't find that reference anywhere.

Furthermore, I tested that (M = 4, N = 1, K = 4) with no transpose works.

Could someone please point me to the constraints on M, N and K for 8-bit integers, and let me know if there is a way to avoid them with a different CUTLASS routine?
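
In case it helps frame the question: the workaround I have in mind, if the multiple-of-4 rule is real, is to pad each matrix up to the next multiple of 4 and run the GEMM on the padded buffers. A minimal, untested sketch (roundUp4 and padColMajor are hypothetical helpers, not part of the reference code below):

#include <cuda_runtime.h>

inline int roundUp4(int x) { return (x + 3) / 4 * 4; }

// Copies a column-major rows x cols matrix into a padded buffer with leading
// dimension ldPadded >= rows; the destination is assumed to have been
// cudaMemset to zero beforehand. One cudaMemcpy2D "row" per matrix column.
cudaError_t padColMajor(float* dst, int ldPadded, const float* src, int rows, int cols) {
    return cudaMemcpy2D(dst, ldPadded * sizeof(float),
                        src, rows * sizeof(float),
                        rows * sizeof(float), cols,
                        cudaMemcpyDeviceToDevice);
}

The GEMM would then run with M, N and K rounded up, lda/ldb/ldc set to the padded leading dimensions, and the real M x N block of C copied back out afterwards.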

Cheers,

Nick

PS: reference code:

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
#include <array>
#include <vector>
#include <random>
#include <math.h>
#include <cutlass/gemm/device/gemm.h>

  cudaError_t cutlass_igemm_nn(
    int M,
    int N,
    int K,
    float alpha,
    int8_t const *A,
    int lda,
    int8_t const *B,
    int ldb,
    float beta,
    int32_t *C,
    int ldc) {
    // Define the CUTLASS GEMM type: 8-bit integer A and B, 32-bit integer C,
    // column-major layouts, with the threadblock tile size chosen by default.
    //
    // To keep the interface manageable, CUTLASS provides sensible defaults for the remaining
    // template arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for details.
    //
    // To view the full gemm device API interface, see `cutlass/gemm/device/gemm.h`
    using ColumnMajor = cutlass::layout::ColumnMajor;
    //using RowMajor = cutlass::layout::RowMajor;
    using CutlassGemm = cutlass::gemm::device::Gemm<int8_t,        // Data-type of A matrix
                                                    ColumnMajor,  // Layout of A matrix
                                                    int8_t,        // Data-type of B matrix
                                                    ColumnMajor,  // Layout of B matrix
                                                    int32_t,        // Data-type of C matrix
                                                    ColumnMajor>; // Layout of C matrix
    // Instantiate the CUTLASS GEMM operator
    CutlassGemm gemm_operator;
    // Construct the CUTLASS GEMM arguments object.
    //
    // One of CUTLASS's design patterns is to define gemm argument objects that are constructible
    // in host code and passed to kernels by value. These may include pointers, strides, scalars,
    // and other arguments needed by Gemm and its components.
    //
    // The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible
    // arguments to kernels and (2.) minimized initialization overhead on kernel entry.
    //
    CutlassGemm::Arguments args({M , N, K},  // Gemm Problem dimensions
                                {A, lda},    // Tensor-ref for source matrix A
                                {B, ldb},    // Tensor-ref for source matrix B
                                {C, ldc},    // Tensor-ref for source matrix C
                                {C, ldc},    // Tensor-ref for destination matrix D (may be different memory than source C matrix)
                                {alpha, beta}); // Scalars used in the Epilogue
    //
    // Launch the CUTLASS GEMM kernel.
    //
    cutlass::Status status = gemm_operator(args);
    //
    // Return a cudaError_t if the CUTLASS GEMM operator returned an error code.
    //
    if (status != cutlass::Status::kSuccess) {
      return cudaErrorUnknown;
    }
    // Return success, if no errors were encountered.
    return cudaSuccess;
  }
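
  // (Hedged addition, not part of my original reproducer: a small helper to surface
  //  the actual CUTLASS status instead of collapsing it to cudaErrorUnknown above.
  //  Assumes cutlass::cutlassGetStatusString() is available in the CUTLASS version in use.)
  void reportCutlassStatus(cutlass::Status status) {
    if (status != cutlass::Status::kSuccess) {
      fprintf(stderr, "CUTLASS GEMM returned status: %s\n", cutlass::cutlassGetStatusString(status));
    }
  }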

__global__ void quantize(const float * input, int8_t * output, size_t items, float quantMult) {
    size_t x = blockIdx.x * blockDim.x + threadIdx.x;
    if (x < items) {
        output[x] = (int8_t)llrintf(input[x] * quantMult);
        //printf("%d Input: %f, output %d\n", x, input[x], (int8_t)llrintf((input[x]*quantMult)));
    }
}
    
template<class intType>
__global__ void dequantize(intType * input, float * output, size_t items, float dequantMult) {
    size_t x = blockIdx.x * blockDim.x + threadIdx.x;
    if (x < items)
        output[x] = ((float)input[x])*dequantMult;
}
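
// Quantisation note (my reasoning, for reference): quantize() maps x to
// x_q = round(x * 127 / maxAbs), so x ~= x_q * maxAbs / 127. The int32 GEMM output
// therefore carries both input scales, which is why dequantize() is called further
// down with the factor (aMaxAbs/127) * (bMaxAbs/127).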

__global__ void findMaxMin(const float * input_gpu, int idxMax, int idxMin, float * output) {
    float absMax = fabsf(input_gpu[idxMax]);
    float absMin = fabsf(input_gpu[idxMin]);
    if (absMax > absMin) {
        output[0] = absMax;
    } else {
        output[0] = absMin;
    }
}

//@TODO rewrite with a nice single-pass GPU version that uses shared memory (a sketch follows this function)
float maxAbs(cublasHandle_t& handle, const float * input_gpu, size_t items, float * scratchMem) {
    //Get Max Absolute:
    int resMaxIdx;
     (cublasIsamax(handle, items, input_gpu, 1, &resMaxIdx));
    int resMinIdx;
     (cublasIsamin(handle, items, input_gpu, 1, &resMinIdx));
    float * output_gpu;
    if (scratchMem) {
        output_gpu = scratchMem;
    } else {
         (cudaMalloc(&output_gpu, 1*sizeof(float)));
    }
    findMaxMin<<<1,1>>>(input_gpu, resMaxIdx - 1, resMinIdx - 1, output_gpu); // cublasIsamax/Isamin return 1-based (Fortran-style) indices
    cudaDeviceSynchronize(); //Not strictly necessary because the following cudaMemcpy synchronizes
    float output;
     (cudaMemcpy(&output, &output_gpu[0], 1*sizeof(float), cudaMemcpyDeviceToHost));
    if (!scratchMem) {
         (cudaFree(output_gpu));
    }
    return output;
}
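
// (Hedged sketch for the @TODO above, not used elsewhere in this reproducer: a
//  single-pass shared-memory reduction that computes max(|x|) per block. Assumes
//  blockDim.x is a power of two and that the kernel is launched with
//  blockDim.x * sizeof(float) bytes of dynamic shared memory; a second pass or a
//  single-block launch is still needed to combine the per-block results.)
__global__ void maxAbsReduce(const float * input, float * blockMax, size_t items) {
    extern __shared__ float sdata[];
    unsigned int tid = threadIdx.x;
    size_t x = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
    sdata[tid] = (x < items) ? fabsf(input[x]) : 0.0f;
    __syncthreads();
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }
    if (tid == 0) {
        blockMax[blockIdx.x] = sdata[0];
    }
}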

void printGPUMatrix(const float * mat, size_t rows, size_t cols) { //Since they are col major printing is a bit weird
    std::vector<float> cpuMat(rows*cols, 2.0f);
     (cudaMemcpy(&cpuMat[0], mat, rows*cols*sizeof(float), cudaMemcpyDeviceToHost));
    for (int i = 0; i<rows; i++) {
        for (int j = 0; j< cols; j++) {
            printf("%f ", cpuMat[rows*j +i]);
        }
        printf("\n");
    }
}

void printGPUMatrix(float * mat, size_t rows, size_t cols) { //Since they are col major printing is a bit weird
    std::vector<float> cpuMat(rows*cols, 2.0f);
     (cudaMemcpy(&cpuMat[0], mat, rows*cols*sizeof(float), cudaMemcpyDeviceToHost));
    for (int i = 0; i<rows; i++) {
        for (int j = 0; j< cols; j++) {
            printf("%f ", cpuMat[rows*j +i]);
        }
        printf("\n");
    }
}

void printGPUMatrix(int8_t * mat, size_t rows, size_t cols) { //Since they are col major printing is a bit weird
    std::vector<int8_t> cpuMat(rows*cols, 2.0);
     (cudaMemcpy(&cpuMat[0], mat, rows*cols*sizeof(int8_t), cudaMemcpyDeviceToHost));
    for (int i = 0; i<rows; i++) {
        for (int j = 0; j< cols; j++) {
            printf("%d ", (int)cpuMat[rows*j +i]);
        }
        printf("\n");
    }
}

cublasStatus_t cublas8bitGemmmEx(cublasHandle_t handle,
    cublasOperation_t transa, 
    cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    const float* A, int lda,
    const float* B, int ldb,
    const float* beta,
    float* C, int ldc) {
        
    auto algorithm = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
    auto res = cublasGemmEx(handle, transa, transb, 
        m, n, k, alpha, 
        A, CUDA_R_32F, lda, 
        B, CUDA_R_32F, ldb, beta, 
        C, CUDA_R_32F, ldc,
        CUDA_R_32F, algorithm);
    printf("True matrix\n");
    printGPUMatrix(C, m, n);
    printf("Fake matrix\n\n");
    //return res;
    if (true || (m%4 == 0 && n % 4 == 0 && k % 4 ==0)) { // `true ||` bypasses the multiple-of-4 check while testing
        int rowsA = m;
        int colsA = k;
        int rowsB = k;
        int colsB = n;
        int rowsC = m;
        int colsC = n;

        // Make sure that we have enough threads so that kernel launches don't fail
        if (colsA > 512) {
            std::swap(rowsA, colsA);
            if (colsA > 512) {
                fprintf(stderr, "Incompatible sizes: rows %d, cols %d\n", rowsA, colsA);
            }
        }

        if (colsB > 512) {
            std::swap(rowsB, colsB);
            if (colsB > 512) {
                fprintf(stderr, "Incompatible sizes: rows %d, cols %d\n", rowsB, colsB);
            }
        }

        if (colsC > 512) {
            std::swap(rowsC, colsC);
            if (colsC > 512) {
                fprintf(stderr, "Incompatible sizes: rows %d, cols %d\n", rowsC, colsC);
            }
        }

        int32_t alpha_int = static_cast<int32_t>(*alpha);
        int32_t beta_int = static_cast<int32_t>(*beta);
        int8_t* in8bitIntA;
        int8_t* in8bitIntB;
        int32_t * out32bitInt;

         (cudaMalloc(&out32bitInt, m*n*sizeof(int32_t)));
         (cudaMalloc(&in8bitIntA, m*k*sizeof(int8_t)));
         (cudaMalloc(&in8bitIntB, k*n*sizeof(int8_t)));

        float aMaxAbs = maxAbs(handle, A, m*k, nullptr);
        float bMaxAbs = maxAbs(handle, B, k*n, nullptr);
         (cudaDeviceSynchronize());

        fprintf(stderr, "MaxAbs A: %f, MaxAbs B: %f\n", aMaxAbs, bMaxAbs);

        quantize<<<rowsA, colsA>>>(A, in8bitIntA, m*k, 127.0f/aMaxAbs);
         (cudaGetLastError());
         (cudaDeviceSynchronize()); //Shouldn't be necessary
        quantize<<<rowsB, colsB>>>(B, in8bitIntB, k*n, 127.0f/bMaxAbs);
         (cudaGetLastError());
         (cudaDeviceSynchronize());


        /* DEBUG: quantisation round-trip check for matrix A
         printf("A:\n");
         printGPUMatrix(A, rowsA, colsA);
         printf("A_quant:\n");
         printGPUMatrix(in8bitIntA, rowsA, colsA);
         printf("A_unquant:\n");
         float *A_restored;
         cudaMalloc(&A_restored, m*k*sizeof(float));
         dequantize<<<rowsA, colsA>>>(in8bitIntA, A_restored, m*k, (aMaxAbs/127.0f));
         printGPUMatrix(A_restored, rowsA, colsA);
         cudaFree(A_restored);
         printf("FAKE\n\n\n\n");
        END DEBUG */

        res = cublasGemmEx(handle, transa, transb, 
            m, n, k, &alpha_int, 
            in8bitIntA, CUDA_R_8I, lda, 
            in8bitIntB, CUDA_R_8I, ldb, &beta_int, 
            out32bitInt, CUDA_R_32I, ldc,
            CUDA_R_32I, algorithm);
        
        dequantize<<<rowsC, colsC>>>(out32bitInt, C, m*n, (aMaxAbs/127.0f)*(bMaxAbs/127.0f) );
         (cudaGetLastError());
         (cudaDeviceSynchronize());
     
        //Free temporary used memory
         (cudaFree(out32bitInt));
         (cudaFree(in8bitIntA));
         (cudaFree(in8bitIntB));
    }
    return res;
}
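
// (Hedged debugging helper, not part of my original code: copies two GPU float
//  matrices to the host and returns the largest absolute element-wise difference,
//  e.g. to compare the FP32 cuBLAS result against the dequantised int8 result.)
float maxAbsDiff(const float * gpuA, const float * gpuB, size_t items) {
    std::vector<float> a(items), b(items);
     (cudaMemcpy(a.data(), gpuA, items*sizeof(float), cudaMemcpyDeviceToHost));
     (cudaMemcpy(b.data(), gpuB, items*sizeof(float), cudaMemcpyDeviceToHost));
    float maxDiff = 0.0f;
    for (size_t i = 0; i < items; i++) {
        float diff = fabsf(a[i] - b[i]);
        if (diff > maxDiff) maxDiff = diff;
    }
    return maxDiff;
}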

cudaError_t cutlas8bitGemmmEx(cublasHandle_t handle,
    int m, int n, int k,
    const float* alpha,
    const float* A, int lda,
    const float* B, int ldb,
    const float* beta,
    float* C, int ldc) {
        
    cudaError_t res;
    
        int rowsA = m;
        int colsA = k;
        int rowsB = k;
        int colsB = n;
        int rowsC = m;
        int colsC = n;

        // Make sure that we have enough threads so that kernel launches don't fail
        if (colsA > 512) {
            std::swap(rowsA, colsA);
            if (colsA > 512) {
                fprintf(stderr, "Incompatible sizes: rows %d, cols %d\n", rowsA, colsA);
            }
        }

        if (colsB > 512) {
            std::swap(rowsB, colsB);
            if (colsB > 512) {
                fprintf(stderr, "Incompatible sizes: rows %d, cols %d\n", rowsB, colsB);
            }
        }

        if (colsC > 512) {
            std::swap(rowsC, colsC);
            if (colsC > 512) {
                fprintf(stderr, "Incompatible sizes: rows %d, cols %d\n", rowsC, colsC);
            }
        }

        int32_t alpha_int = static_cast<int32_t>(*alpha);
        int32_t beta_int = static_cast<int32_t>(*beta);
        int8_t* in8bitIntA;
        int8_t* in8bitIntB;
        int32_t * out32bitInt;

         (cudaMalloc(&out32bitInt, m*n*sizeof(int32_t)));
         (cudaMalloc(&in8bitIntA, m*k*sizeof(int8_t)));
         (cudaMalloc(&in8bitIntB, k*n*sizeof(int8_t)));

        float aMaxAbs = maxAbs(handle, A, m*k, nullptr);
        float bMaxAbs = maxAbs(handle, B, k*n, nullptr);
         (cudaDeviceSynchronize());

        fprintf(stderr, "MaxAbs A: %f, MaxAbs B: %f\n", aMaxAbs, bMaxAbs);

        quantize<<<rowsA, colsA>>>(A, in8bitIntA, m*k, 127.0f/aMaxAbs);
         (cudaGetLastError());
         (cudaDeviceSynchronize()); //Shouldn't be necessary
        quantize<<<rowsB, colsB>>>(B, in8bitIntB, k*n, 127.0f/bMaxAbs);
         (cudaGetLastError());
         (cudaDeviceSynchronize());

        //CUTLASS GEMM
        
        res = cutlass_igemm_nn(
        m, n, k, *alpha,
        in8bitIntA,
        lda,
        in8bitIntB,
        ldb,
        *beta,
        out32bitInt,
        ldc);
        
        (cudaGetLastError());
        (cudaDeviceSynchronize());
        
        dequantize<<<rowsC, colsC>>>(out32bitInt, C, m*n, (aMaxAbs/127.0f)*(bMaxAbs/127.0f) );
         (cudaGetLastError());
         (cudaDeviceSynchronize());
        
        //Free temporary used memory
         (cudaFree(out32bitInt));
         (cudaFree(in8bitIntA));
         (cudaFree(in8bitIntB));

    return res;
}

#define M 4
#define N 1
#define K 4

int main() {
    int m = M;
    int n = N;
    int k = K;

    std::mt19937 gen;
    // Go somewhat out of range too.
    std::uniform_real_distribution<float> dist(-2, 2);

    std::vector<float> A_cpu(m*k, 0);
    std::vector<float> B_cpu(k*n, 0);

    for (auto&& num : A_cpu) {
        num = dist(gen);
    }

    for (auto&& num : B_cpu) {
        num = dist(gen);
    }

    //GPU alloc
    float * A;
    float * B;
    float * C;

     (cudaMalloc(&A, m*k*sizeof(float)));
     (cudaMalloc(&B, k*n*sizeof(float)));
     (cudaMalloc(&C, m*n*sizeof(float)));
     (cudaMemcpy(A, &A_cpu[0], m*k*sizeof(float), cudaMemcpyHostToDevice));
     (cudaMemcpy(B, &B_cpu[0], k*n*sizeof(float), cudaMemcpyHostToDevice));

    //Cublas handle
    cublasHandle_t handle;
     (cublasCreate(&handle));

    //lda, ldb, ldc, transposes etc

    //auto transa = CUBLAS_OP_T;
    //auto transb = CUBLAS_OP_T;
    
    auto transa = CUBLAS_OP_N;
    auto transb = CUBLAS_OP_N;

    int lda = m; //k, maybe, for the transposed cases
    int ldb = k; //n
    int ldc = m; //n

    const float alpha = 1;
    const float beta = 0;
    
    fprintf(stderr, "\n\nCuBlas\n");
    
    auto res2 = cublas8bitGemmmEx(handle,
        transa, 
        transb,
        m, n, k,
        &alpha,
        A, lda,
        B, ldb,
        &beta,
        C, ldc);
    
    if (res2 != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "ERROR Cublass did not run!");
    } else {
        printGPUMatrix(C, m, n);
    }
    
    fprintf(stderr, "\n\nCutlass\n");
    
    auto res = cutlas8bitGemmmEx(handle,
    m, n, k,
    &alpha,
    A, lda,
    B, ldb,
    &beta,
    C,  ldc);
    if (res != cudaSuccess) {
        fprintf(stderr, "ERROR Cutlass");
    }
    
    printGPUMatrix(C, m, n);

}