Issue with cuBLAS Batched Matrix Multiplication Result

I’m writing a proof-of-concept piece of code to drop in a CUDA kernel and bind it to PyTorch. Originally I was doing matrix multiplication using cublasSgemm, but for efficiency I’d like to be able to handle batched matrix multiplication.

How it works/the issue

My C++ frontend receives a PyTorch (ATen) tensor, converts it to a float ** (this is the input I expect for any drop-in kernel), and passes it into the kernel. I also transpose the last two dimensions and swap the function’s arguments to account for the column-major computation done in cuBLAS. When I print out the data being passed into the kernel, it appears correct.
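
For context, the identity I’m relying on is that cuBLAS reads my contiguous row-major buffers as their column-major transposes, so a row-major product C = A·B can be recovered from a column-major GEMM that computes C^T = B^T·A^T, i.e. by swapping the operands and dimensions. A simplified sketch of how that maps onto the original, non-batched cublasSgemm call (not my exact binding code; the names and the row-major assumption are just for illustration):

#include <cublas_v2.h>

// Row-major C (m x n) = A (m x k) * B (k x n) via column-major cuBLAS:
// each row-major buffer is read as its column-major transpose, so the call
// computes C^T = B^T * A^T with the operands and dimensions swapped.
// d_A, d_B, d_C are contiguous row-major device buffers.
void sgemm_row_major(cublasHandle_t handle,
                     const float *d_A, const float *d_B, float *d_C,
                     int m, int k, int n) {
    const float alpha = 1.0f, beta = 0.0f;
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                n, m, k,        // dimensions of C^T = B^T * A^T
                &alpha,
                d_B, n,         // B^T is n x k in column-major terms, ld = n
                d_A, k,         // A^T is k x m in column-major terms, ld = k
                &beta,
                d_C, n);        // C^T is n x m, which reads back as row-major C
}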

The CUDA kernel then copies the matrices over to the device and performs the batched matrix multiplication on them. When I print A and B, they always appear correct; however, the result is completely wrong.

When I try a batched matrix multiplication on shapes (8, *n*, 1) by (8, 1, *n*), where n is any number, it passes. However, once I move to full matrices (e.g. (8, 2, 2) by (8, 2, 2)), it fails.

An example of how it is incorrect

The following tensors are the inputs:

a:  tensor([[[0.1780, 0.2388],
         [0.1507, 0.1490]],

        [[0.4671, 0.0749],
         [0.0233, 0.8508]],

        [[0.1400, 0.8500],
         [0.5346, 0.4268]],

        [[0.8869, 0.1096],
         [0.7395, 0.8337]],

        [[0.8988, 0.0331],
         [0.5237, 0.4365]],

        [[0.1290, 0.8415],
         [0.4120, 0.8382]],

        [[0.2308, 0.2783],
         [0.6880, 0.5639]],

        [[0.7866, 0.7133],
         [0.9002, 0.1314]]])
b:  tensor([[[0.6130, 0.8714],
         [0.2734, 0.1876]],

        [[0.9920, 0.0813],
         [0.1368, 0.0769]],

        [[0.6237, 0.6213],
         [0.0611, 0.8062]],

        [[0.6433, 0.6694],
         [0.0106, 0.0428]],

        [[0.0890, 0.2948],
         [0.0557, 0.7533]],

        [[0.4321, 0.3028],
         [0.9345, 0.5660]],

        [[0.5646, 0.2192],
         [0.3528, 0.0780]],

        [[0.1936, 0.9967],
         [0.1739, 0.7742]]])

And the next two tensors are my kernel’s output and the expected output, respectively.

Output:
tensor([[[0.2404, 0.0769],
         [0.2762, 0.0932]],

        [[0.4652, 0.0657],
         [0.1435, 0.0757]],

        [[0.4194, 0.4395],
         [0.7954, 0.3961]],

        [[1.0656, 0.0410],
         [0.6286, 0.0368]],

        [[0.2344, 0.4446],
         [0.1316, 0.3306]],

        [[0.1805, 0.3537],
         [0.6174, 1.2607]],

        [[0.2811, 0.1351],
         [0.2807, 0.1422]],

        [[1.0496, 0.8338],
         [0.2691, 0.2258]]])

Expected output:
tensor([[[0.1744, 0.1999],
         [0.1331, 0.1592]],

        [[0.4736, 0.0437],
         [0.1395, 0.0674]],

        [[0.1393, 0.7722],
         [0.3595, 0.6762]],

        [[0.5717, 0.5983],
         [0.4846, 0.5307]],

        [[0.0819, 0.2900],
         [0.0710, 0.4832]],

        [[0.8421, 0.5153],
         [0.9613, 0.5992]],

        [[0.2285, 0.0723],
         [0.5874, 0.1948]],

        [[0.2763, 1.3362],
         [0.1971, 0.9991]]])

My thoughts on why it may be incorrect

The result is completely different and there is no overlap in the numbers, which makes me think it isn’t the transpose/conversion from row-major to column-major. I also thought it might be the leading-dimension arguments in the cublasSgemmBatched call, but those look correct to me as well. Since the matrices seem to be passed in correctly, I’m not sure where the problem is arising, especially since the “dot product” case works. I’ve attached the code below. Any guidance would be great! Thanks in advance!
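
To help isolate whether the problem is in the cuBLAS call itself or in how I marshal the data, this is the kind of minimal standalone sanity check I compare against: a single 2x2 pair, batch of one, column-major, with error checking stripped for brevity (the values are arbitrary and only there so the result can be verified by hand):

// Minimal standalone sanity check of cublasSgemmBatched: one 2x2 pair,
// batch of one, column-major, no transposes. Build with nvcc and link cublas.
#include <cstdio>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
    // Column-major 2x2 matrices: A = [[1, 3], [2, 4]], B = [[5, 7], [6, 8]].
    float h_A[4] = {1, 2, 3, 4};
    float h_B[4] = {5, 6, 7, 8};
    float h_C[4] = {0, 0, 0, 0};

    float *d_A, *d_B, *d_C;
    cudaMalloc((void **)&d_A, sizeof(h_A));
    cudaMalloc((void **)&d_B, sizeof(h_B));
    cudaMalloc((void **)&d_C, sizeof(h_C));
    cudaMemcpy(d_A, h_A, sizeof(h_A), cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_B, sizeof(h_B), cudaMemcpyHostToDevice);

    // cublasSgemmBatched expects *device* arrays of device pointers.
    const float *h_Aarr[1] = {d_A};
    const float *h_Barr[1] = {d_B};
    float *h_Carr[1] = {d_C};
    const float **d_Aarr;
    const float **d_Barr;
    float **d_Carr;
    cudaMalloc((void **)&d_Aarr, sizeof(h_Aarr));
    cudaMalloc((void **)&d_Barr, sizeof(h_Barr));
    cudaMalloc((void **)&d_Carr, sizeof(h_Carr));
    cudaMemcpy(d_Aarr, h_Aarr, sizeof(h_Aarr), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Barr, h_Barr, sizeof(h_Barr), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Carr, h_Carr, sizeof(h_Carr), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.0f, beta = 0.0f;
    // C = A * B with m = n = k = 2, lda = ldb = ldc = 2, batch count 1.
    cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                       2, 2, 2, &alpha,
                       d_Aarr, 2,
                       d_Barr, 2,
                       &beta,
                       d_Carr, 2,
                       1);

    cudaMemcpy(h_C, d_C, sizeof(h_C), cudaMemcpyDeviceToHost);
    // Expected (column-major): {23, 34, 31, 46}, i.e. C = [[23, 31], [34, 46]].
    printf("%g %g %g %g\n", h_C[0], h_C[1], h_C[2], h_C[3]);

    cublasDestroy(handle);
    return 0;
}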

Also, on a side note: I couldn’t find anything about this online, but is it fine to initialize a cuBLAS handle globally to avoid the overhead of creating one on every call? This code is bound to Python, and I’m not sure I can pass a handle in from there, so I just create it in the frontend.
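
For reference, the alternative I’ve considered is creating the handle lazily on first use instead of exposing explicit init/destroy functions from Python, roughly like this (a sketch only; error handling omitted):

#include <cublas_v2.h>

// Create the cuBLAS handle once, on first use, and reuse it for the
// lifetime of the process.
cublasHandle_t get_cublas_handle() {
    static cublasHandle_t handle = nullptr;
    if (handle == nullptr) {
        cublasCreate(&handle);   // created exactly once, then reused
    }
    return handle;
}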

My code is heavily inspired by this Stack Overflow post and this GitHub repo.

Code to reproduce

My C++ frontend is as follows:

...
// global initialize to reduce overhead of creating handle
// in every call
cublasHandle_t g_cublas_handle = nullptr;

void init_cublas_handle() {
  cublasStatus_t status = cublasCreate(&g_cublas_handle);
  if (status != CUBLAS_STATUS_SUCCESS)
  {
    std::cerr << "cuBLAS initialization error.";
  }
}

void destroy_cublas_handle() {
  cublasStatus_t status = cublasDestroy(g_cublas_handle);
  if (status != CUBLAS_STATUS_SUCCESS)
  {
    std::cerr << "Shutdown error!";
  }
}

float **raw_data(torch::Tensor tensor, int batch_dim, int rows, int cols) {
  float **data_ptr = (float**) malloc(batch_dim * sizeof(float*));
  auto accessor = tensor.accessor<float, 3>();

  for (int b = 0; b < batch_dim; b++) {
    data_ptr[b] = (float*) malloc(rows * cols * sizeof(float));
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < cols; j++) {
        data_ptr[b][i * cols + j] = accessor[b][i][j];
      }
    }
  }

  return data_ptr;
}
void free_raw_data(float **ptr, int batch_dim) {
    for (int b = 0; b < batch_dim; b++) {
        free(ptr[b]);
    }
    free(ptr);
}

torch::Tensor cublas_bmm(torch::Tensor B, torch::Tensor A, int dim)
{
    // (B^T A^T)^T for row major to col major
    torch::Tensor A_tensor = torch::transpose(A, 1, 2);
    torch::Tensor B_tensor = torch::transpose(B, 1, 2);

    int A_rows = A_tensor.size(1);
    int A_cols = A_tensor.size(2);
    int B_rows = B_tensor.size(1);
    int B_cols = B_tensor.size(2);

    int batch_dim = A_tensor.size(0);
    assert(batch_dim == B_tensor.size(0));

    torch::Tensor C = torch::zeros({batch_dim, B_cols, A_rows}, torch::kFloat32).contiguous();
    int C_rows = C.size(1);
    int C_cols = C.size(2);

    // expand out arrays to fit batched operation
    float **A_arr = raw_data(A_tensor, batch_dim, A_rows, A_cols);
    float **B_arr = raw_data(B_tensor, batch_dim, B_rows, B_cols);
    float **C_arr = raw_data(C, batch_dim, C_rows, C_cols);

    cublas_bmm_wrapper(g_cublas_handle, A_arr, B_arr, C_arr, A_rows, B_rows, B_cols, batch_dim);
    auto accessor = C.accessor<float, 3>();

    for (int b = 0; b < batch_dim; b++) {
      for (int i = 0; i < C_rows; i++)
      {
        for (int j = 0; j < C_cols; j++)
        {
          accessor[b][i][j] = C_arr[b][i * C_cols + j];
        }
      }
    }

    free_raw_data(A_arr, batch_dim);
    free_raw_data(B_arr, batch_dim);
    free_raw_data(C_arr, batch_dim);
    
    return C;
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("init_cublas", &init_cublas_handle, "Create cuBLAS handle.");
  m.def("destroy_cublas", &destroy_cublas_handle, "Destroy cuBLAS handle.");
  m.def("cublas_bmm", &cublas_bmm, "cuBLAS Batched Torch Matrix Multiplication");
}

My CUDA kernel is as follows:

#ifndef __CUBLAS_BMM_KERNEL_H__
#define __CUBLAS_BMM_KERNEL_H__

...

#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess) 
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}

static inline void sgemmbatched(
        const cublasHandle_t handle,
        const cublasOperation_t &trans_a,
        const cublasOperation_t &trans_b,
        const int batch_size,

        const float **a, const int a_rows, const int a_cols, const int lda,
        const float **b, const int b_rows, const int b_cols, const int ldb,
        float **c, const int c_rows, const int c_cols, const int ldc,
        const float *alpha,
        const float *beta ) {
        int m, n, k;
        if (trans_a == cublasOperation_t::CUBLAS_OP_N)
        {
                m = a_rows;
                k = a_cols;
        }
        else
        {
                m = a_cols;
                k = a_rows;
        }
        if (trans_b == cublasOperation_t::CUBLAS_OP_N)
        {
                assert(k == b_rows);
                k = b_rows;
                n = b_cols;
        }
        else
        {
                assert(k == b_cols);
                k = b_cols;
                n = b_rows;
        }
        assert(m == c_rows);
        assert(n == c_cols);


        cublasStatus_t status = cublasSgemmBatched(handle,
                trans_a, trans_b,
                m, n, k,
                alpha,
                a, lda,
                b, ldb,
                beta,
                c, ldc,
                batch_size
        );

        assert(status == CUBLAS_STATUS_SUCCESS);
}

void cublas_bmm_wrapper(cublasHandle_t handle,
    float **A, float **B, float **C,
    size_t a_rows, size_t b_rows, size_t b_cols,
    size_t batch_size) {

    cudaError_t cudaStatus;
    cublasStatus_t status;

    float **d_A = (float**)malloc(batch_size * sizeof(float*));
    float **d_B = (float**)malloc(batch_size * sizeof(float*));
    float **d_C = (float**)malloc(batch_size * sizeof(float*));

    size_t size_A = a_rows * b_rows * sizeof(float);
    size_t size_B = b_rows * b_cols * sizeof(float);
    size_t size_C = a_rows * b_cols * sizeof(float);

    for (int i = 0; i < batch_size; i++) {
        gpuErrchk(cudaMalloc(&d_A[i], size_A));
        gpuErrchk(cudaMemcpy(d_A[i], A[i], size_A, cudaMemcpyHostToDevice));

        gpuErrchk(cudaMalloc(&d_B[i], size_B));
        gpuErrchk(cudaMemcpy(d_B[i], B[i], size_B, cudaMemcpyHostToDevice));

        gpuErrchk(cudaMalloc(&d_C[i], size_C));
        gpuErrchk(cudaMemcpy(d_C[i], C[i], size_C, cudaMemcpyHostToDevice));
    }
    cudaCheckErrors("inner cudaMalloc/cudaMemcpy fail");

    const float **d_A_arr = 0, **d_B_arr = 0;
    float **d_C_arr = 0;
    gpuErrchk(cudaMalloc(&d_A_arr, batch_size * sizeof(float*)));
    gpuErrchk(cudaMalloc(&d_B_arr, batch_size * sizeof(float*)));
    gpuErrchk(cudaMalloc(&d_C_arr, batch_size * sizeof(float*)));
    cudaCheckErrors("outer cudaMalloc fail");

    gpuErrchk(cudaMemcpy(d_A_arr, d_A, batch_size * sizeof(float*), cudaMemcpyHostToDevice));
    gpuErrchk(cudaMemcpy(d_B_arr, d_B, batch_size * sizeof(float*), cudaMemcpyHostToDevice));
    gpuErrchk(cudaMemcpy(d_C_arr, d_C, batch_size * sizeof(float*), cudaMemcpyHostToDevice));
    cudaCheckErrors("outer cudaMemcpy H2D fail");

    const float alpha = 1.0f, beta = 0.0f;

    sgemmbatched(
        handle, CUBLAS_OP_N, CUBLAS_OP_N,
        batch_size,
        d_A_arr, a_rows, b_rows, a_rows,
        d_B_arr, b_rows, b_cols, b_rows,
        d_C_arr, a_rows, b_cols, a_rows,
        &alpha, &beta
    );

    for (int i = 0; i < batch_size; i++) {
        gpuErrchk(cudaMemcpy(C[i], d_C[i], size_C, cudaMemcpyDeviceToHost));
        gpuErrchk(cudaFree(d_A[i]));
        gpuErrchk(cudaFree(d_B[i]));
        gpuErrchk(cudaFree(d_C[i]));
    }
    cudaCheckErrors("cudaMemcpy D2H fail");

    gpuErrchk(cudaFree(d_A_arr));
    gpuErrchk(cudaFree(d_B_arr));
    gpuErrchk(cudaFree(d_C_arr));

    free(d_A);
    free(d_B);
    free(d_C);
    cudaCheckErrors("free cuda memory fail");
 }

#endif // __CUBLAS_BMM_KERNEL_H__

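As an aside, since my batched tensors are contiguous with uniform strides, I’ve also been looking at cublasSgemmStridedBatched, which takes a single base pointer and a stride per operand and would let me skip the pointer-array setup entirely. An untested sketch of that call for tightly packed column-major batches (the function name and layout assumptions are mine):

#include <cublas_v2.h>

// Possible simplification: when every batch entry is tightly packed in one
// contiguous buffer, cublasSgemmStridedBatched takes a single base pointer
// plus a stride per operand, so no host/device pointer arrays are needed.
// Column-major convention: A is m x k, B is k x n, C is m x n per batch.
void sgemm_strided_batched(cublasHandle_t handle,
                           const float *d_A, const float *d_B, float *d_C,
                           int m, int k, int n, int batch_count) {
    const float alpha = 1.0f, beta = 0.0f;
    cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                              m, n, k, &alpha,
                              d_A, m, (long long)m * k,   // stride to next A
                              d_B, k, (long long)k * n,   // stride to next B
                              &beta,
                              d_C, m, (long long)m * n,   // stride to next C
                              batch_count);
}
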
setup.py to bind C++ frontend to python

from setuptools import setup, Extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='custom_mm',
    ext_modules=[
        CUDAExtension('custom_mm', [
            'custom_mm.cpp',
        ])
    ],
    cmdclass={
        'build_ext': BuildExtension
    })

setup.py to create matmuls library

from setuptools import setup

setup(
    name='matmuls',
    py_modules=['matmuls'],
    install_requires=['custom_mm', 'torch', 'numpy'],
    entry_points='''
        [console_scripts]
        matmuls=matmuls:matmuls
    ''',
)

Python interface

import torch
from torch.autograd.function import InplaceFunction
import custom_mm


def cublas_matmul(a: torch.Tensor, b: torch.Tensor, torch_: bool = False) -> torch.Tensor:

    a = a.contiguous()
    b = b.contiguous()

    if len(a.shape) >= 3 and len(b.shape) >= 3:
        _, a_dim2 = a.shape[-2:]
        b_dim1, _ = b.shape[-2:]
        lda, ldb = a.shape[0], b.shape[0]
        assert lda == ldb
        assert a_dim2 == b_dim1
        if len(a.shape) == 3 and len(b.shape) == 3:
            _c = custom_mm.cublas_bmm(a, b, 3)
        return _c.clone().detach()
    else:
        return a @ b

Python tests

import torch
import custom_mm
import matmuls

custom_mm.init_cublas()


def test_result(function, a: torch.Tensor, b: torch.Tensor):
    expected = torch.matmul(a, b)
    output = function(a, b)
    assert(expected.shape == output.shape)
    assert(torch.allclose(expected, output))
    return True


def test_raw_cublas_matmul(a_dim, b_dim):
    a = torch.rand(a_dim)
    b = torch.rand(b_dim)
    assert test_result(custom_mm.cublas_mmul, a, b)


def test_matmuls(a_dim, b_dim):
    a = torch.rand(a_dim)
    b = torch.rand(b_dim)
    assert test_result(matmuls.cublasMM.apply, a, b)


test_raw_cublas_matmul((8, 64), (64, 8))
test_raw_cublas_matmul((8, 64, 16), (16, 8))
test_raw_cublas_matmul((8, 64, 16), (8, 16, 8))
test_raw_cublas_matmul((1, 8, 64, 16), (1, 8, 16, 8))
test_raw_cublas_matmul((2, 8, 64, 16), (2, 8, 16, 8))

test_matmuls((8, 64), (64, 8))
test_matmuls((8, 64, 16), (16, 8))
test_matmuls((8, 64, 16), (8, 16, 8))
test_matmuls((1, 8, 64, 16), (1, 8, 16, 8))
test_matmuls((2, 8, 64, 16), (2, 8, 16, 8))

custom_mm.destroy_cublas()

I ended up resolving this by dropping the explicit transpose in the C++ frontend and instead swapping the dimensions in the batched matmul call. I’m still not sure why the original approach didn’t work, but it’s functioning correctly now.
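
For completeness, the shape bookkeeping after the fix is roughly the following: the row-major tensors are passed through without the explicit transpose, and the operands and dimensions are swapped inside the batched call so that each batch entry computes C^T = B^T·A^T in column-major terms. This is a sketch of the call only, not the full binding:

// Rough shape bookkeeping after the fix: row-major batches are passed through
// as-is, and the operands are swapped inside the batched call.
// A: (batch, m, k) row-major, B: (batch, k, n) row-major, C: (batch, m, n).
void bmm_row_major(cublasHandle_t handle,
                   const float **d_A_arr, const float **d_B_arr, float **d_C_arr,
                   int m, int k, int n, int batch) {
    const float alpha = 1.0f, beta = 0.0f;
    cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                       n, m, k, &alpha,
                       d_B_arr, n,      // B (row-major) read as B^T, ld = n
                       d_A_arr, k,      // A (row-major) read as A^T, ld = k
                       &beta,
                       d_C_arr, n,      // C comes back out row-major, ld = n
                       batch);
}

This would also be consistent with the (8, n, 1) by (8, 1, n) case passing earlier: for those shapes the row-major and column-major layouts of each batch entry coincide, so a layout mix-up never shows up.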