cusparseScsrilu02 breaks with large matrices

Hello, I am trying to implement an iterative solver using incomplete LU factorization with sparse matrices. My code works perfectly for a 1000x1000 system with 3750 non-zero elements, but when the input is a 31287x31287 system with 2467643 non-zero elements, the code stalls after the call to cusparseScsrilu02. What I mean by that is:

  • I cannot copy any data from device to host
  • The code stalls at the SpSV analysis step and seems to run forever
  • Calling cudaDeviceSynchronize after cusparseScsrilu02 freezes the program

I have used a text-difference program (the one included in VS Code) to check that I read the CSR matrices correctly.

SparseMatrix contains these elements:

typedef struct
{
    double *values; // element values
    int *row_idx;   // cumulative number of elements per row → the last element equals NNZ
    int *col_idx;   // column indices of the elements
    int size;       // matrix size (assuming square matrices only)
} SparseMatrix;

and the Vector struct:

typedef struct
{
    double *values; // element values
    int size;       // vector size
} Vector;

Below is my function:

void solveSystemSparseIterative(SparseMatrix *mat, Vector *B, double *X, double tolerance)
{
    int n = mat->size;
    int nnz = mat->row_idx[n];
    int maxIters = 5000;

    // create float copy of system elements
    float *host_float_values = (float *)malloc(nnz * sizeof(float));
    float *host_float_rhs = (float *)malloc(n * sizeof(float));
    double *zeros = (double *)malloc(n * sizeof(double));
    float *tempf = (float *)malloc(nnz * sizeof(float));
    double *tempd = (double *)malloc(nnz * sizeof(double));

    int maxThreads, blocks, threads;
    threads = 256;
    if (nnz > threads)
    {
        maxThreads = threads;
        blocks = nnz / maxThreads + 1;
    }
    else
    {
        blocks = 1;
        maxThreads = nnz;
    }
    blocks = 1;

    for (int i = 0; i < n; i++)
    {
        host_float_rhs[i] = B->values[i];
        zeros[i] = 0.0;
    }
    for (int i = 0; i < nnz; i++)
        host_float_values[i] = mat->values[i];

    // INITIALIZE CUSPARSE / CUBLAS
    cusparseHandle_t sparseHandle = NULL;
    cublasHandle_t blasHandle;
    cudaStream_t stream = NULL;

    // cusparseStatus_t status;
    cusparseCreate(&sparseHandle);
    cublasCreate(&blasHandle);
    cudaStreamCreate(&stream);

    // ALLOCATE MEMORY
    double *Xcalculated = (double *)malloc(n * sizeof(double));
    double *temp = (double *)malloc(nnz * sizeof(double));
    double *Lvalues, *Uvalues, *Avalues, *solution, *rhs, *rhsCopy, *temp_solutionX, *temp_solutionY;
    float *f_values;
    int *rowPtr, *colIdx, *rowPtrCopy, *colIdxCopy;

    checkCudaErrors(cudaMalloc((void **)&Uvalues, nnz * sizeof(double)));
    checkCudaErrors(cudaMalloc((void **)&Lvalues, nnz * sizeof(double)));
    checkCudaErrors(cudaMalloc((void **)&Avalues, nnz * sizeof(double)));
    checkCudaErrors(cudaMalloc((void **)&rhs, n * sizeof(double)));
    checkCudaErrors(cudaMalloc((void **)&rhsCopy, n * sizeof(double)));
    checkCudaErrors(cudaMalloc((void **)&solution, n * sizeof(double)));
    checkCudaErrors(cudaMalloc((void **)&temp_solutionX, n * sizeof(double)));
    checkCudaErrors(cudaMalloc((void **)&temp_solutionY, n * sizeof(double)));
    checkCudaErrors(cudaMalloc((void **)&f_values, nnz * sizeof(float)));
    checkCudaErrors(cudaMalloc((void **)&rowPtr, (n + 1) * sizeof(int)));
    checkCudaErrors(cudaMalloc((void **)&rowPtrCopy, (n + 1) * sizeof(int)));
    checkCudaErrors(cudaMalloc((void **)&colIdx, nnz * sizeof(int)));
    checkCudaErrors(cudaMalloc((void **)&colIdxCopy, nnz * sizeof(int)));

    // COPY MATRIX A TO DEVICE MEMORY
    checkCudaErrors(cudaMemcpy(rowPtr, mat->row_idx, (n + 1) * sizeof(int), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(colIdx, mat->col_idx, nnz * sizeof(int), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(rowPtrCopy, rowPtr, (n + 1) * sizeof(int), cudaMemcpyDeviceToDevice));
    checkCudaErrors(cudaMemcpy(colIdxCopy, colIdx, nnz * sizeof(int), cudaMemcpyDeviceToDevice));

    // COPY FLOAT MATRIX ELEMENTS
    checkCudaErrors(cudaMemcpy(f_values, host_float_values, nnz * sizeof(float), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(Avalues, mat->values, nnz * sizeof(double), cudaMemcpyHostToDevice));

    // COPY B ELEMENTS
    // cudaMemcpy(rhs, B->values, n, cudaMemcpyHostToDevice);
    checkCudaErrors(cudaMemcpy(rhs, B->values, n * sizeof(double), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(rhsCopy, rhs, n * sizeof(double), cudaMemcpyDeviceToDevice));

    // INIT EMPTY VECTOR
    checkCudaErrors(cudaMemcpy(temp_solutionX, zeros, n * sizeof(double), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(temp_solutionY, temp_solutionX, n * sizeof(double), cudaMemcpyDeviceToDevice));


    // SET UP MATRIX DESCRIPTOR
    cusparseMatDescr_t descrA;
    cusparseCreateMatDescr(&descrA);
    cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL);
    cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO);

    // INITIALIZE VARIABLES FOR LU FACTORIZATION
    int pBufferSize;
    size_t spSvBufferSizeL, spSvBufferSizeU;
    void *pBuffer, *spSvBufferL, *spSvBufferU;
    // int structural_zero, numerical_zero;
    cusparseSolvePolicy_t policy = CUSPARSE_SOLVE_POLICY_NO_LEVEL;
    csrilu02Info_t LUinfo;
    cusparseCreateCsrilu02Info(&LUinfo);

    double tole = 0;
    float boost = 1e-8;
    checkCudaErrors(cusparseScsrilu02_numericBoost(sparseHandle, LUinfo, 1, &tole, &boost));

    printf("Buffer size..\n");
    // CALCULATE LU FACTORIZATION BUFFER SIZE
    checkCudaErrors(cusparseScsrilu02_bufferSize(sparseHandle, n, nnz, descrA, f_values, rowPtr, colIdx, LUinfo, &pBufferSize));
    checkCudaErrors(cudaMalloc(&pBuffer, pBufferSize));
    // pBuffer returned by cudaMalloc is automatically aligned to 128 bytes
    printf("Buffer size for LU is %d\n", pBufferSize);

    printf("Analysis..\n");
    // LU FACTORIZATION ANALYSIS
    checkCudaErrors(cusparseScsrilu02_analysis(sparseHandle, n, nnz, descrA, f_values, rowPtr, colIdx, LUinfo, policy, pBuffer));

    cusparseStatus_t status;
    int structural_zero;
    status = cusparseXcsrilu02_zeroPivot(sparseHandle, LUinfo, &structural_zero);
    if (CUSPARSE_STATUS_ZERO_PIVOT == status)
        printf("A(%d,%d) is missing\n", structural_zero, structural_zero);

    printf("Factorization..\n");
    // A = L * U
    checkCudaErrors(cusparseScsrilu02(sparseHandle, n, nnz, descrA, f_values, rowPtr, colIdx, LUinfo, policy, pBuffer));
    // f_values now contains the L and U factors

    cusparseDestroyMatDescr(descrA);
    cudaFree(pBuffer);
    cusparseDestroyCsrilu02Info(LUinfo);

    cudaError_t err;

    printf("Convert to double..\n");
    // DEVICE TYPECAST
    floatToDoubleVector<<<blocks, maxThreads>>>(f_values, Lvalues, nnz);
    checkCudaErrors(cudaMemcpy(Uvalues, Lvalues, nnz * sizeof(double), cudaMemcpyDeviceToDevice));
    printf("\ndone converting..\n");

    cusparseSpMatDescr_t descrL, descrU, descrACopy;
    // Create a copy of A to calculate the residual r = b - Ax
    cusparseCreateCsr(&descrACopy, n, n, nnz, rowPtrCopy, colIdxCopy, Avalues, CUSPARSE_INDEX_32I,
                      CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_64F);
    cusparseCreateCsr(&descrL, n, n, nnz, rowPtr, colIdx, Lvalues, CUSPARSE_INDEX_32I,
                      CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_64F);
    cusparseCreateCsr(&descrU, n, n, nnz, rowPtr, colIdx, Uvalues, CUSPARSE_INDEX_32I,
                      CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_64F);

    printf("Set attributes..\n");
    cusparseFillMode_t lower = CUSPARSE_FILL_MODE_LOWER;
    cusparseDiagType_t unit = CUSPARSE_DIAG_TYPE_UNIT;
    cusparseFillMode_t upper = CUSPARSE_FILL_MODE_UPPER;
    cusparseDiagType_t nonUnit = CUSPARSE_DIAG_TYPE_NON_UNIT;
    cusparseSpMatSetAttribute(descrL, CUSPARSE_SPMAT_FILL_MODE, (void *)&lower, sizeof(lower));
    cusparseSpMatSetAttribute(descrL, CUSPARSE_SPMAT_DIAG_TYPE, (void *)&unit, sizeof(unit));
    cusparseSpMatSetAttribute(descrU, CUSPARSE_SPMAT_FILL_MODE, (void *)&upper, sizeof(upper));
    cusparseSpMatSetAttribute(descrU, CUSPARSE_SPMAT_DIAG_TYPE, (void *)&nonUnit, sizeof(nonUnit));

    // INITIALIZE B, X, Y VECTOR DESCRIPTORS
    cusparseDnVecDescr_t descrX, descrY, descrB;
    cusparseCreateDnVec(&descrB, n, rhs, CUDA_R_64F);
    cusparseCreateDnVec(&descrY, n, temp_solutionY, CUDA_R_64F);
    cusparseCreateDnVec(&descrX, n, temp_solutionX, CUDA_R_64F);

    // SET UP TRIANGULAR SOLVER DESCRIPTORS
    cusparseSpSVDescr_t spsvDescrL, spsvDescrU;
    cusparseSpSV_createDescr(&spsvDescrL);
    cusparseSpSV_createDescr(&spsvDescrU);

    double plusOne = 1.0;

    printf("SpSv analysisL.. \n");
    checkCudaErrors(cusparseSpSV_bufferSize(sparseHandle, CUSPARSE_OPERATION_NON_TRANSPOSE, &plusOne, descrL, descrB, descrY, CUDA_R_64F, CUSPARSE_SPSV_ALG_DEFAULT, spsvDescrL, &spSvBufferSizeL));
    cudaMalloc((void **)&spSvBufferL, spSvBufferSizeL);
    printf("spSvBufferSizeL: %ld\n", spSvBufferSizeL);
    checkCudaErrors(cusparseSpSV_analysis(sparseHandle, CUSPARSE_OPERATION_NON_TRANSPOSE, &plusOne, descrL, descrB, descrY, CUDA_R_64F, CUSPARSE_SPSV_ALG_DEFAULT, spsvDescrL, spSvBufferL));

    printf("SpSv analysisU.. \n");
    checkCudaErrors(cusparseSpSV_bufferSize(sparseHandle, CUSPARSE_OPERATION_NON_TRANSPOSE, &plusOne, descrU, descrY, descrX, CUDA_R_64F, CUSPARSE_SPSV_ALG_DEFAULT, spsvDescrU, &spSvBufferSizeU));
    cudaMalloc((void **)&spSvBufferU, spSvBufferSizeU);
    printf("spSvBufferSizeU: %ld\n", spSvBufferSizeU);
    checkCudaErrors(cusparseSpSV_analysis(sparseHandle, CUSPARSE_OPERATION_NON_TRANSPOSE, &plusOne, descrU, descrY, descrX, CUDA_R_64F, CUSPARSE_SPSV_ALG_DEFAULT, spsvDescrU, spSvBufferU));

    printf("SpSv solve L.. \n");
    // solve L*y = b
    checkCudaErrors(cusparseSpSV_solve(sparseHandle, CUSPARSE_OPERATION_NON_TRANSPOSE, &plusOne, descrL, descrB, descrY, CUDA_R_64F, CUSPARSE_SPSV_ALG_DEFAULT, spsvDescrL));

    printf("SpSv solve U.. \n");
    // solve U*x = y
    checkCudaErrors(cusparseSpSV_solve(sparseHandle, CUSPARSE_OPERATION_NON_TRANSPOSE, &plusOne, descrU, descrY, descrX, CUDA_R_64F, CUSPARSE_SPSV_ALG_DEFAULT, spsvDescrU));


printf("enter loop\n");

for (int i = 0; i < maxIters; i++)

{

        double minusOne = -1.0;

        double one = 1.0;

        size_t spMvBufferSize = 0;

        void *spMvBuffer;

        // CALCULATE RESIDUAL and store it on B vector

        cusparseSpMV_bufferSize(sparseHandle, CUSPARSE_OPERATION_NON_TRANSPOSE, &minusOne, descrACopy, descrX, &one, descrB, CUDA_R_64F, CUSPARSE_SPMV_CSR_ALG2, &spMvBufferSize);

        cudaMalloc(&spMvBuffer, spMvBufferSize);

        cusparseSpMV(sparseHandle, CUSPARSE_OPERATION_NON_TRANSPOSE, &minusOne, descrACopy, descrX, &one, descrB, CUDA_R_64F, CUSPARSE_SPMV_CSR_ALG2, spMvBuffer);

        // CUBLAS NORM

        double resNormm, bNorm;

        cublasDnrm2(blasHandle, n, rhs, 1, &resNormm);

        cublasDnrm2(blasHandle, n, rhsCopy, 1, &bNorm);

        if ((resNormm / bNorm) < tolerance)

        {

            printf("Iters: %d\n", i);

            break;

        }

        // solve L*y = r : B contains the residual

        checkCudaErrors(cusparseSpSV_solve(sparseHandle, CUSPARSE_OPERATION_NON_TRANSPOSE, &plusOne, descrL, descrB,descrY, CUDA_R_64F, CUSPARSE_SPSV_ALG_DEFAULT, spsvDescrL));

        // solve U*c = y

        checkCudaErrors(cusparseSpSV_solve(sparseHandle, CUSPARSE_OPERATION_NON_TRANSPOSE, &plusOne, descrU, descrY,descrX, CUDA_R_64F, CUSPARSE_SPSV_ALG_DEFAULT, spsvDescrU));

        // Xn+1 = Xn + Cn

        cublasDaxpy(blasHandle, n, &one, temp_solutionX, 1, solution, 1);

        cudaMemcpy(temp_solutionX, solution, n * sizeof(double), cudaMemcpyDeviceToDevice);

        // restore B values

        cudaMemcpy(rhs, rhsCopy, n * sizeof(double), cudaMemcpyDeviceToDevice);


}

// TRANSFER SOLUTION TO X VECTOR

cudaMemcpy(X, temp_solutionX, n * sizeof(double), cudaMemcpyHostToDevice);

cusparseDestroyDnVec(descrX);

cusparseDestroyDnVec(descrY);

cusparseDestroyDnVec(descrB);

cusparseSpSV_destroyDescr(spsvDescrL);

cusparseSpSV_destroyDescr(spsvDescrU);

cusparseDestroy(sparseHandle);

cudaMemcpy(X, solution, n * sizeof(double), cudaMemcpyDeviceToHost);

// FREE RESOURCES

}

Can you please check which is the first call that stalls? You can put cudaDeviceSynchronize() after each call.
Secondly, you can try running compute-sanitizer to check that there are no earlier failing API calls or kernels.
Lastly, if you don't have success with the previous steps, can you please share the sparse matrix and the full code? That way, we can try to reproduce the issue on our side.
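
For example, a minimal sketch of that check pattern around the suspect call (using the checkCudaErrors macro the posted code already relies on):

// Narrow down the first stalling call by synchronizing and checking after each step.
checkCudaErrors(cusparseScsrilu02(sparseHandle, n, nnz, descrA, f_values, rowPtr, colIdx, LUinfo, policy, pBuffer));
checkCudaErrors(cudaGetLastError());      // reports any launch error from the call above
checkCudaErrors(cudaDeviceSynchronize()); // blocks until the factorization kernels finish (or hangs here)

and, for the sanitizer pass, run the whole binary under compute-sanitizer:

compute-sanitizer ./gpu_sparse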


Thank you for your immediate response! You can access the full code here in my GitHub repo, where I have included the matrices. First run make to compile the files, and then ./gpu_sparse to run the code shown above.

  • The first stalling function is actually cusparseScsrilu02().
  • compute-sanitizer does not report any earlier failing API calls or kernels.

We have found the cause of this issue. It has nothing to do with the matrix size. We'll include the fix in the next release.


Thank you very much for your answer. Is there any workaround until the fix is published?

No, the problem is related to specific matrix sizes. Potentially, you could add padding to increase the matrix size, or store explicit numerical zeros among the non-zero elements. However, I would not suggest these changes because they are invasive and may break the application workflow (a rough illustration of the padding idea follows below).
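
Purely as an illustration, such padding could look something like the sketch below. The pad_csr_system helper and the identity-row scheme are assumptions for this example only; they are not part of the posted code or the cuSPARSE API, and the remark above still applies.

#include <stdlib.h>

// Illustrative sketch: grow an n x n CSR system to (n + pad) x (n + pad) by
// appending rows that hold a single 1.0 on the diagonal, so the padded
// unknowns are decoupled from the original ones and solve to zero.
void pad_csr_system(SparseMatrix *mat, Vector *B, int pad)
{
    int n = mat->size;
    int nnz = mat->row_idx[n];

    mat->values  = (double *)realloc(mat->values,  (nnz + pad) * sizeof(double));
    mat->col_idx = (int *)realloc(mat->col_idx,    (nnz + pad) * sizeof(int));
    mat->row_idx = (int *)realloc(mat->row_idx,    (n + pad + 1) * sizeof(int));
    B->values    = (double *)realloc(B->values,    (n + pad) * sizeof(double));

    for (int k = 0; k < pad; k++)
    {
        mat->values[nnz + k]    = 1.0;         // unit diagonal entry
        mat->col_idx[nnz + k]   = n + k;       // column index of the new diagonal
        mat->row_idx[n + 1 + k] = nnz + 1 + k; // exactly one entry per padded row
        B->values[n + k]        = 0.0;         // zero right-hand side for padded rows
    }
    mat->size = n + pad;
    B->size   = n + pad;
}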


I understand. Should the total size be a multiple of 128 for example? I think it is worth a shot. I am developing this workflow for an internship, so a more immediate solution than just waiting for the next update is much appreciated.

Unfortunately, we don’t have a workaround for it at the moment. The fix will be released soon.


I am currently trying to replicate csrilu02 using standard C (CPU) code. Can you provide the reference paper or the pseudocode of this function? Thank you in advance.

This TR describes the implementation of ILU on the GPU. Based on it, you can implement a parallel CPU version in C.
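
As a simple sequential baseline to compare against, here is a textbook ILU(0) sketch operating in place on a CSR matrix with sorted column indices (the IKJ variant described in Saad's "Iterative Methods for Sparse Linear Systems"). This is my own reference sketch, not the level-scheduled parallel algorithm csrilu02 uses internally, and the linear search over row k is kept for clarity rather than speed.

#include <stdlib.h>

// In-place ILU(0) on a 0-based CSR matrix whose rows have sorted column indices.
// On return, val holds L strictly below the diagonal (unit diagonal not stored)
// and U on and above it, the same layout csrilu02 writes into its values array.
// Returns -1 on success, or the row index of a structurally missing diagonal entry.
static int ilu0_csr(int n, const int *rowPtr, const int *colIdx, double *val)
{
    int *diag = (int *)malloc(n * sizeof(int)); // position of the diagonal entry in each row
    for (int i = 0; i < n; i++)
    {
        diag[i] = -1;
        for (int p = rowPtr[i]; p < rowPtr[i + 1]; p++)
            if (colIdx[p] == i) { diag[i] = p; break; }
        if (diag[i] < 0) { free(diag); return i; } // structural zero on the diagonal
    }
    for (int i = 1; i < n; i++)
    {
        // eliminate entries a(i,k) with k < i, in increasing column order
        for (int p = rowPtr[i]; p < rowPtr[i + 1] && colIdx[p] < i; p++)
        {
            int k = colIdx[p];
            val[p] /= val[diag[k]]; // l(i,k) = a(i,k) / u(k,k)
            for (int q = p + 1; q < rowPtr[i + 1]; q++)
            {
                int j = colIdx[q];
                // subtract l(i,k) * u(k,j) when u(k,j) exists in row k
                for (int r = diag[k] + 1; r < rowPtr[k + 1]; r++)
                    if (colIdx[r] == j) { val[q] -= val[p] * val[r]; break; }
            }
        }
    }
    free(diag);
    return -1;
}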


Thank you!