Hi!
I recently wanted to implement an LU-style linear solver for a sparse matrix with the cuSolver library, following the batched sparse-QR example in the cuSOLVER section of the CUDA Toolkit documentation. I implemented it as a CUDA extension for the PyTorch deep-learning framework; the code is appended at the end.
However, every time I call this kernel, it says
double free or corruption (out)
Aborted (core dumped)
without any further detail. Using `printf` debugging, I found that the error is raised inside `cusolverSpXcsrqrAnalysisBatched`, though of course I am not 100% confident that this localization is precise. I am really confused about the potential cause of this error, because there is no explicit memory free anywhere in my code, and I have verified that the pointers returned by `tensor.data_ptr()` are not NULL.
Can anyone suggest a possible cause of this error?
Thanks!!
#include <torch/extension.h>
#include <torch/data/iterator.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <cusolverSp.h>
#include <cuda_runtime_api.h>
// Solve x in the linear system Ax=rhs. A is a CSR-sparse squared matrix of size [num_v, num_v]. x and rhs are batched tensors of size [batch_size, num_v].
//
// Implementation: cuSOLVER batched sparse-QR (cusolverSp*csrqr*Batched) — QR,
// despite the "LU" in the function name. The sparsity pattern
// (csrRowA / csrColIndA) is shared across the batch; only the values array is
// replicated per batch element.
//
// NOTE(review): likely causes of the reported "double free or corruption":
//   1) csrRowA.add(1) / csrColIndA.add(1) below are OUT-OF-PLACE — torch's
//      Tensor::add() returns a new tensor and leaves the original untouched
//      (the in-place variant is add_(), which these const tensors would not
//      permit anyway; the device pointers are also captured before the calls).
//      The descriptor is set to CUSPARSE_INDEX_BASE_ONE, so cuSOLVER receives
//      0-based indices while expecting 1-based ones and effectively indexes
//      one element before every array -> out-of-bounds device reads/writes ->
//      heap corruption. Simplest fix: keep the tensors 0-based and use
//      CUSPARSE_INDEX_BASE_ZERO.
//   2) The clone+cat loop builds batch_size+1 copies of csrValA — one more
//      than the batch_size copies the batched routines read.
//   3) buffer_qr is never cudaFree'd and cusolverH/descrA/info are never
//      destroyed, so every call leaks (a leak, not the crash).
std::vector<torch::Tensor> solve_LU_cuda(
const torch::Tensor csrRowA,     // CSR row offsets of A; expected length num_v+1, int32 — TODO confirm
const torch::Tensor csrColIndA,  // CSR column indices of A; length nnz, int32
const torch::Tensor csrValA,     // CSR values of A; length nnz, float32
const torch::Tensor rhs,         // right-hand sides, [batch_size, num_v], float32, CUDA
const int num_v,                 // dimension of the square matrix A
const int batch_size,            // number of systems solved in the batch
const int nnz) {                 // number of non-zeros in A
// init x: solution tensor, same shape as rhs, float32 on GPU 0, no autograd
auto options_float = torch::TensorOptions()
.dtype(torch::kFloat32)
.layout(torch::kStrided)
.device(torch::kCUDA, 0)
.requires_grad(false);
torch::Tensor x = torch::zeros_like(rhs, options_float);
// copy and cat the values of A, make it batched A
// row and col indices are fixed, since A is shared inside the batch
// NOTE(review): starting from clone() and then appending batch_size more
// copies yields batch_size+1 copies in total — one too many.
torch::Tensor batched_csrValA = csrValA.clone();
for (int i = 0; i < batch_size; i++)
batched_csrValA = torch::cat({batched_csrValA, csrValA}, 0);
// Raw device pointers handed to cuSOLVER. Note they are captured here,
// BEFORE the (no-op) add(1) calls below — another reason those calls cannot
// re-base the index arrays that cuSOLVER actually sees.
int* ptr_csrRowA = csrRowA.data_ptr<int>();
int* ptr_csrColIndA = csrColIndA.data_ptr<int>();
float* ptr_batched_csrValA = batched_csrValA.data_ptr<float>();
float* ptr_rhs = rhs.data_ptr<float>();
float* ptr_x = x.data_ptr<float>();
// solve
// step 1: create cusolver handle, qr info and matrix descriptor
cusolverSpHandle_t cusolverH = NULL;
csrqrInfo_t info = NULL;
cusparseMatDescr_t descrA = NULL;
cusolverStatus_t cusolver_status = CUSOLVER_STATUS_SUCCESS;
cusparseStatus_t cusparse_status = CUSPARSE_STATUS_SUCCESS;
cusolver_status = cusolverSpCreate(&cusolverH);
assert (cusolver_status == CUSOLVER_STATUS_SUCCESS);
cusparse_status = cusparseCreateMatDescr(&descrA);
assert(cusparse_status == CUSPARSE_STATUS_SUCCESS);
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL);
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ONE); // * NOTE *: the index of A starts from 1 / base-1
// NOTE(review): add() is out-of-place; the returned tensors are discarded, so
// the arrays behind ptr_csrRowA/ptr_csrColIndA stay 0-based. Combined with
// the base-1 descriptor above this mismatch is the prime suspect for the
// out-of-bounds accesses that corrupt the heap.
csrRowA.add(1);
csrColIndA.add(1);
cusolver_status = cusolverSpCreateCsrqrInfo(&info);
assert(cusolver_status == CUSOLVER_STATUS_SUCCESS);
// step 2: symbolic analysis of the shared sparsity pattern (values not read)
cusolver_status = cusolverSpXcsrqrAnalysisBatched(
cusolverH, num_v, num_v, nnz, // num_v is the len of A's side, and A is squared
descrA, ptr_csrRowA, ptr_csrColIndA,
info);
assert(cusolver_status == CUSOLVER_STATUS_SUCCESS);
// step 3: prepare working space
size_t size_qr = 0;       // bytes of device workspace we must cudaMalloc below
size_t size_internal = 0; // bytes cuSOLVER keeps internally (informational only)
void *buffer_qr = NULL; // the working buffer for numerical factorization
cudaError_t cudaStat = cudaSuccess;
/*------ You can print something before here. ------*/
cusolver_status = cusolverSpScsrqrBufferInfoBatched( // calculate the size of buffer
cusolverH, num_v, num_v, nnz, // num_v is the len of A's side, and A is squared
descrA, ptr_batched_csrValA, ptr_csrRowA, ptr_csrColIndA,
batch_size,
info,
&size_internal,
&size_qr);
assert(cusolver_status == CUSOLVER_STATUS_SUCCESS);
/*------ But you cannot print anything since here. ------*/
cudaStat = cudaMalloc((void**)&buffer_qr, size_qr);
assert(cudaStat == cudaSuccess);
// step 4: numerical factorization and solve in one call (batched QR)
cusolver_status = cusolverSpScsrqrsvBatched(
cusolverH,num_v, num_v, nnz, // num_v is the len of A's side, and A is squared
descrA, ptr_batched_csrValA, ptr_csrRowA, ptr_csrColIndA,
ptr_rhs, ptr_x,
batch_size,
info,
buffer_qr);
assert(cusolver_status == CUSOLVER_STATUS_SUCCESS);
// x has been modified via ptr_x
// NOTE(review): buffer_qr, info, descrA and cusolverH are leaked here — add
// cudaFree(buffer_qr), cusolverSpDestroyCsrqrInfo(info),
// cusparseDestroyMatDescr(descrA) and cusolverSpDestroy(cusolverH) before
// returning.
return {x};