How to perform PyCUDA 4x4 matrix inversion with the same accuracy as numpy linalg

I am facing an accuracy issue with my code, which performs a large number (128, 256, 512) of 4x4 matrix inversions. When I use the original version, i.e. the classical numpy functions np.linalg.inv or np.linalg.pinv, everything works fine.

Unfortunately, with the CUDA code below, I get nan and inf values in the inverted matrices.

To be more explicit, here is an example of a matrix to invert:

2.120771107884677649e+09 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
0.000000000000000000e+00 3.557266600921528288e+27 3.557266600921528041e+07 3.557266600921528320e+17
0.000000000000000000e+00 3.557266600921528041e+07 3.557266600921528288e+27 3.557266600921528041e+07
0.000000000000000000e+00 3.557266600921528320e+17 3.557266600921528041e+07 1.778633300460764144e+27

If I use the classical numpy inv, I get the following inverted 4x4 matrix:

4.715266047722758306e-10 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
0.000000000000000000e+00 2.811147187396482366e-28 -2.811147186834252285e-48 -5.622294374792964645e-38
0.000000000000000000e+00 -2.811147186834252285e-48 2.811147187396482366e-28 -5.622294374230735768e-48
0.000000000000000000e+00 -5.622294374792964645e-38 -5.622294374230735768e-48 5.622294374792964732e-28

To check the validity of this inverse matrix, I have multiplied it by the original matrix and the result is the identity matrix.
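
For reference, this check amounts to the following minimal numpy sketch (the matrix above is hard-coded here as A):

import numpy as np

A = np.array([[2.120771107884677649e+09, 0.0, 0.0, 0.0],
              [0.0, 3.557266600921528288e+27, 3.557266600921528041e+07, 3.557266600921528320e+17],
              [0.0, 3.557266600921528041e+07, 3.557266600921528288e+27, 3.557266600921528041e+07],
              [0.0, 3.557266600921528320e+17, 3.557266600921528041e+07, 1.778633300460764144e+27]])

A_inv = np.linalg.inv(A)
# The product should be numerically close to the 4x4 identity matrix
print(A @ A_inv)
print(np.allclose(A @ A_inv, np.eye(4)))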

But with the CUDA GPU inversion, I get the following matrix after the inversion:

0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
-inf -inf -9.373764907941219970e-01 -inf
inf nan -inf nan

So, I would like to increase the precision in my CUDA kernel or Python code to avoid these nan and inf values.

Here is the CUDA kernel code and the calling part of my main code (I have commented out the classical method using the numpy inv function):

        # Create arrayFullCross_vec array
        arrayFullCross_vec = np.zeros((dimBlocks,dimBlocks,integ_prec,integ_prec))

        # Create invCrossMatrix_gpu array (flattened output of the GPU inversion)
        invCrossMatrix_gpu = np.zeros((dimBlocks*dimBlocks*integ_prec**2))

        # Create invCrossMatrix array
        invCrossMatrix = np.zeros((dimBlocks,dimBlocks,integ_prec,integ_prec))

        # Build observables covariance matrix
        arrayFullCross_vec = buildObsCovarianceMatrix4_vec(k_ref, mu_ref, ir)
        """
        # Compute integrand from covariance matrix
        for r_p in range(integ_prec):
          for s_p in range(integ_prec):
            # original version (without GPU)
            invCrossMatrix[:,:,r_p,s_p] = np.linalg.inv(arrayFullCross_vec[:,:,r_p,s_p])
        """
        # GPU version
        invCrossMatrix_gpu = gpuinv4x4(arrayFullCross_vec.flatten(),integ_prec**2)
        invCrossMatrix = invCrossMatrix_gpu.reshape(dimBlocks,dimBlocks,integ_prec,integ_prec)

and here are the CUDA kernel code and the gpuinv4x4 function:

kernel = SourceModule("""
    
    __device__ unsigned getoff(unsigned &off){
      unsigned ret = off & 0x0F;
      off = off >> 4;
      return ret;
    }
    
    const int block_size = 256;
    const unsigned tmsk = 0xFFFFFFFF;
    // in-place is acceptable (i.e. out == in)
    // T = float or double only
    typedef double T;
    __global__ void inv4x4(const T * __restrict__ in, T * __restrict__ out, const size_t n, const unsigned * __restrict__ pat){
    
      __shared__ T si[block_size];
      size_t idx = threadIdx.x+blockDim.x*blockIdx.x;
      if (idx < n*16){
        si[threadIdx.x] = in[idx];
        unsigned lane = threadIdx.x & 15;       // position of this thread within its 16-thread group
        unsigned sibase = threadIdx.x & 0x03F0; // base index of this group's 4x4 matrix in shared memory
        __syncwarp();
        unsigned off = pat[lane];
        T a,b;
        a  = si[sibase + getoff(off)];
        a *= si[sibase + getoff(off)];
        a *= si[sibase + getoff(off)];
        if (!getoff(off)) a = -a;
        b  = si[sibase + getoff(off)];
        b *= si[sibase + getoff(off)];
        b *= si[sibase + getoff(off)];
        if (getoff(off)) a += b;
        else a -=b;
        off = pat[lane+16];
        b  = si[sibase + getoff(off)];
        b *= si[sibase + getoff(off)];
        b *= si[sibase + getoff(off)];
        if (getoff(off)) a += b;
        else a -=b;
        b  = si[sibase + getoff(off)];
        b *= si[sibase + getoff(off)];
        b *= si[sibase + getoff(off)];
        if (getoff(off)) a += b;
        else a -=b;
        off = pat[lane+32];
        b  = si[sibase + getoff(off)];
        b *= si[sibase + getoff(off)];
        b *= si[sibase + getoff(off)];
        if (getoff(off)) a += b;
        else a -=b;
        b  = si[sibase + getoff(off)];
        b *= si[sibase + getoff(off)];
        b *= si[sibase + getoff(off)];
        if (getoff(off)) a += b;
        else a -=b;
        T det = si[sibase + (lane>>2)]*a;
        det += __shfl_down_sync(tmsk, det, 4, 16); // first add
        det += __shfl_down_sync(tmsk, det, 8, 16); // second add
        det =  __shfl_sync(tmsk, det, 0, 16); // broadcast
        out[idx] = a / det;
      }
    }
    """)
    
    # python function for inverting 4x4 matrices
    # n should be an even number
    def gpuinv4x4(inp, n):
        # internal constants not to be modified
        hpat = ( 0x0EB51FA5, 0x1EB10FA1, 0x0E711F61, 0x1A710B61, 0x1EB40FA4, 0x0EB01FA0, 0x1E700F60, 0x0A701B60, 0x0DB41F94, 0x1DB00F90, 0x0D701F50, 0x19700B50, 0x1DA40E94, 0x0DA01E90, 0x1D600E50, 0x09601A50, 0x1E790F69, 0x0E391F29, 0x1E350F25, 0x0A351B25, 0x0E781F68, 0x1E380F28, 0x0E341F24, 0x1A340B24, 0x1D780F58, 0x0D381F18, 0x1D340F14, 0x09341B14, 0x0D681E58, 0x1D280E18, 0x0D241E14, 0x19240A14, 0x0A7D1B6D, 0x1A3D0B2D, 0x063D172D, 0x16390729, 0x1A7C0B6C, 0x0A3C1B2C, 0x163C072C, 0x06381728, 0x097C1B5C, 0x193C0B1C, 0x053C171C, 0x15380718, 0x196C0A5C, 0x092C1A1C, 0x152C061C, 0x05281618)
        # Convert parameters into numpy array
        # float32 
        """
        inpd = np.array(inp, dtype=np.float32)
        hpatd = np.array(hpat, dtype=np.uint32)
        output = np.empty((n*16), dtype= np.float32)
        """
        # float64
        """
        inpd = np.array(inp, dtype=np.float64)
        hpatd = np.array(hpat, dtype=np.uint32)
        output = np.empty((n*16), dtype= np.float64)
        """
        # float128
        inpd = np.array(inp, dtype=np.float128)
        hpatd = np.array(hpat, dtype=np.uint32)
        output = np.empty((n*16), dtype= np.float128)
        # Get kernel function
        matinv4x4 = kernel.get_function("inv4x4")
        # Define block, grid and compute
        blockDim = (256,1,1) # do not change
        gridDim = ((n//16)+1,1,1) # integer division: each block of 256 threads handles 16 matrices
        # Kernel function
        matinv4x4 (
            cuda.In(inpd), cuda.Out(output), np.uint64(n), cuda.In(hpatd),
            block=blockDim, grid=gridDim)
        return output

As you can see, I tried to increase the accuracy of the inversion by replacing np.float32 with np.float64 or np.float128, but the problem remains.

I have also replaced typedef float T; with typedef double T; but without success.

Could anyone help me perform the correct inversion of these matrices and, above all, avoid the nan and inf values? I think this is a real precision issue, but I cannot find how to circumvent this problem.

Regards

I haven’t looked at your code, which may contain bugs.

  1. Is it possible some of the matrices are singular or nearly so? (A quick numpy check for this is sketched right after this list.)
  2. Are you using (partial) pivoting?
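
For point 1, a quick host-side check is to look at the condition numbers before inverting. A minimal numpy sketch, using the array names from your code (illustrative only):

import numpy as np

# arrayFullCross_vec is assumed to have shape (4, 4, integ_prec, integ_prec), as in your code
conds = np.array([[np.linalg.cond(arrayFullCross_vec[:, :, r, s])
                   for s in range(integ_prec)]
                  for r in range(integ_prec)])
# Condition numbers approaching 1/eps for float64 (about 4.5e15) indicate matrices that are
# numerically singular in double precision, so their inverses cannot be trusted
print(conds.max())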

Check whether the NVIDIA CUDA developer website still has BatchedSolver_v1_1 available as a download. In it you will find a function dmatinv_batch() that should be close to what you want/need, or at least you can extract relevant portions of the code and tune it to the latest GPU architectures (the code is under a BSD license and thus compatible with all usage scenarios).

/* dmatinv_batch() inverts one or many square, non-singular matrices of double-
   precision elements. Partial pivoting is employed in the inversion process 
   for increased numerical stability.

   A     pointer to an array of the double-precision matrices to be inverted, 
         where each matrix is stored in column-major order
   Ainv  pointer to an array of the double-precision matrices which receive
         the inverses of the corresponding matrices pointed to by A, where 
         each matrix is stored in column-major order
   n     number of rows and columns of the matrices in the arrays pointed to 
         by A and Ainv. n must be greater than, or equal to 2. On sm_13 GPUs,
         n must be less than, or equal to, 44. On sm_2x and sm_3x GPUs, n must
         be less than, or equal to, 77.
   batch the number of matrices to be inverted. It must be greater than zero.

   Returns:

    0    operation completed successfully
   -1    n is out of bounds, batch is out of bounds
   -2    a CUDA error occurred
*/
int dmatinv_batch(double *A, double *Ainv, int n, int batch);

For tiny matrices, such as 3x3, you would want to handle each matrix within a single thread. The register usage when doing this should be fairly light for 3x3 matrices, maybe around 32 registers.

Since the code from BatchedSolver_v1_1 is under an open source license, I may as well quote the core portion of the code for inverting 3x3 matrices here. It is as follows:

/*
 * Copyright (c) 2011-2013 NVIDIA Corporation. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without 
 * modification, are permitted provided that the following conditions are met:
 *
 *   Redistributions of source code must retain the above copyright notice, 
 *   this list of conditions and the following disclaimer.
 *
 *   Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 *   Neither the name of NVIDIA Corporation nor the names of its contributors
 *   may be used to endorse or promote products derived from this software 
 *   without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

template<typename T, int arch>
__global__ void matinv_3x3_matrix_per_thread (const T *A, T *Ainv, int batch)
{
    const int blkNum = blockIdx.y * gridDim.x + blockIdx.x;
    const int thrdNum = blkNum * blockDim.x + threadIdx.x;
    const int N = 3;
    int perm0, perm1, perm2;
    int icol0, icol1, icol2;
    T AA00, AA01, AA02; 
    T AA10, AA11, AA12;
    T AA20, AA21, AA22;
    T tmp;
#if USE_PIVOTING
    typename config<T,arch>::absValType t;
    typename config<T,arch>::absValType p;
    int i, pvt;
#endif

    A    += thrdNum * N * N;
    Ainv += thrdNum * N * N;

    if (thrdNum < batch) {

        AA00 = A[0];
        AA10 = A[1];
        AA20 = A[2];
        AA01 = A[3];
        AA11 = A[4];
        AA21 = A[5];
        AA02 = A[6];
        AA12 = A[7];
        AA22 = A[8];

        perm0 = 0;
        perm1 = 1;
        perm2 = 2;
        
        /****************** iteration 0 ***********/

#if USE_PIVOTING
        /* search pivot row */
        p = absOp (AA00);
        pvt = 0;
        t = absOp (AA10);
        if (t > p) { p = t;  pvt = 1; }
        t = absOp (AA20);
        if (t > p) { p = t;  pvt = 2; }
        
        /* swap pivot row with row 0 */
        if (pvt == 1) {
            tmp = AA00;  AA00 = AA10;  AA10 = tmp;
            tmp = AA01;  AA01 = AA11;  AA11 = tmp;
            tmp = AA02;  AA02 = AA12;  AA12 = tmp;
            /* update permutation vector based on row swap */
            i = perm0;  perm0 = perm1;  perm1 = i;
        }
        if (pvt == 2) {
            tmp = AA00;  AA00 = AA20;  AA20 = tmp;
            tmp = AA01;  AA01 = AA21;  AA21 = tmp;
            tmp = AA02;  AA02 = AA22;  AA22 = tmp;
            /* update permutation vector based on row swap */
            i = perm0;  perm0 = perm2;  perm2 = i;
        }
#endif // USE_PIVOTING

        /* scale current row */
        tmp = rcpOp (AA00);
        icol0 = perm0;
        AA00 = tmp;
        AA01 = mulOp (tmp, AA01);
        AA02 = mulOp (tmp, AA02);

        /* eliminate above and below current row */
        tmp = AA10;
        AA10 = mulOp (negOp(tmp), AA00);
        AA11 = fmnaOp (tmp, AA01, AA11);
        AA12 = fmnaOp (tmp, AA02, AA12);

        tmp = AA20;
        AA20 = mulOp (negOp(tmp), AA00);
        AA21 = fmnaOp (tmp, AA01, AA21);
        AA22 = fmnaOp (tmp, AA02, AA22);
        
        /****************** iteration 1 ***********/

#if USE_PIVOTING
        /* search pivot row */
        p = absOp (AA11);
        pvt = 1;
        t = absOp (AA21);
        if (t > p) { p = t;  pvt = 2; }

        /* swap pivot row with row 1 */
        if (pvt == 2) {
            tmp = AA10;   AA10 = AA20;   AA20 = tmp;
            tmp = AA11;   AA11 = AA21;   AA21 = tmp;
            tmp = AA12;   AA12 = AA22;   AA22 = tmp;
            /* update permutation vector based on row swap */
            i = perm1;   perm1 = perm2;   perm2 = i;
        }
#endif // USE_PIVOTING

        /* scale current row */
        tmp = rcpOp (AA11);
        icol1 = perm1;
        AA10 = mulOp (tmp, AA10);
        AA11 = tmp;
        AA12 = mulOp (tmp, AA12);

        /* eliminate above and below current row */
        tmp = AA01;
        AA00 = fmnaOp (tmp, AA10, AA00);
        AA01 = mulOp (negOp(tmp), AA11);
        AA02 = fmnaOp (tmp, AA12, AA02);
        
        tmp = AA21;
        AA20 = fmnaOp (tmp, AA10, AA20);
        AA21 = mulOp (negOp(tmp), AA11);
        AA22 = fmnaOp (tmp, AA12, AA22);
        
        /****************** iteration 2 ****************/

        /* scale current row */
        tmp = rcpOp (AA22);
        icol2 = perm2;
        AA20 = mulOp (tmp, AA20);
        AA21 = mulOp (tmp, AA21);
        AA22 = tmp;

        /* eliminate above and below current row */
        tmp = AA02;
        AA00 = fmnaOp (tmp, AA20, AA00);
        AA01 = fmnaOp (tmp, AA21, AA01); 
        AA02 = mulOp (negOp(tmp), AA22);

        tmp = AA12;
        AA10 = fmnaOp (tmp, AA20, AA10);
        AA11 = fmnaOp (tmp, AA21, AA11);
        AA12 = mulOp (negOp(tmp), AA22);

        /* sort columns into the correct order */
        Ainv(0,icol0) = AA00;
        Ainv(1,icol0) = AA10;
        Ainv(2,icol0) = AA20;
        Ainv(0,icol1) = AA01;
        Ainv(1,icol1) = AA11;
        Ainv(2,icol1) = AA21;
        Ainv(0,icol2) = AA02;
        Ainv(1,icol2) = AA12;
        Ainv(2,icol2) = AA22;
    }
}

The overloaded operations used in the code above are, in the case of double precision:

__device__ __forceinline__ double fmnaOp (double a, double b, double c)
{
    return -(a * b) + c;
}

__device__ __forceinline__ double mulOp (double a, double b)
{
    return a * b;
}

__device__ __forceinline__ double rcpOp (double a)
{
    return 1.0 / a;
}

__device__ __forceinline__ double absOp (double a)
{
    return fabs(a);
}

__device__ __forceinline__ double negOp (double a)
{
    return -(a);
}

I was unable to locate the Batched Solver code any longer on the developer website.

The cross-posted question has an answer:

python - How to perform PyCUDA 4x4 matrix inversion with same accuracy than numpy linalg “inv” or “pinv” function - Stack Overflow

Argh. [Later:] Actually, it seems logged-in registered CUDA developers can download the file in question here:

https://developer.nvidia.com/rdp/assets/cuda-batched-solver-tgz

The OP has now changed their post from a 3x3 matrix to a 4x4 matrix. Anyhow, below is working code (using partial pivoting) based on NVIDIA’s code in BatchedSolver_v1_1.tgz that delivers the result desired by the OP. The output of the program should look like this:

C:\Users\Norbert\My Programs>matinv4x4
 4.7152660477227583e-010   0.0000000000000000e+000   0.0000000000000000e+000  -0.0000000000000000e+000
-0.0000000000000000e+000   2.8111471873964824e-028  -2.8111471868342523e-048  -5.6222943747929646e-038
 0.0000000000000000e+000  -2.8111471868342523e-048   2.8111471873964824e-028  -5.6222943742307346e-048
 0.0000000000000000e+000  -5.6222943747929646e-038  -5.6222943742307346e-048   5.6222943747929647e-028

Note that the use of FMA in the GPU version leads to slightly different (presumably: more accurate) results than those returned by the CPU version. When configured for pivoting, and compiled for an sm_61 target by CUDA 8, the kernel requires 80 registers. When configured without pivoting, the kernel requires 60 registers.

#include <cstdio>
#include <cstdlib>

#define MAT_COLUMN_MAJOR  (0)
#define USE_PIVOTING      (1)

/*
 * Copyright (c) 2011-2013 NVIDIA Corporation. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without 
 * modification, are permitted provided that the following conditions are met:
 *
 *   Redistributions of source code must retain the above copyright notice, 
 *   this list of conditions and the following disclaimer.
 *
 *   Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 *   Neither the name of NVIDIA Corporation nor the names of its contributors
 *   may be used to endorse or promote products derived from this software 
 *   without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

__device__ __forceinline__ double fmnaOp (double a, double b, double c)
{
    return -(a * b) + c;
}

__device__ __forceinline__ double mulOp (double a, double b)
{
    return a * b;
}

__device__ __forceinline__ double rcpOp (double a)
{
    return 1.0 / a;
}

__device__ __forceinline__ double absOp (double a)
{
    return fabs(a);
}

__device__ __forceinline__ double negOp (double a)
{
    return -(a);
}

#if MAT_COLUMN_MAJOR
#define Ainv(row,col)    Ainv[(col)*N+(row)]
#else
#define Ainv(row,col)    Ainv[(row)*N+(col)]
#endif

template<typename T>
__global__ void matinv_4x4_matrix_per_thread (const T *A, T *Ainv, int batch)
{
    const int blkNum = blockIdx.y * gridDim.x + blockIdx.x;
    const int thrdNum = blkNum * blockDim.x + threadIdx.x;
    const int N = 4;
    int perm0, perm1, perm2, perm3;
    int icol0, icol1, icol2, icol3;
    T AA00, AA01, AA02, AA03; 
    T AA10, AA11, AA12, AA13;
    T AA20, AA21, AA22, AA23;
    T AA30, AA31, AA32, AA33;
    T tmp;
#if USE_PIVOTING
    T t;
    T p;
    int i, pvt;
#endif

    A    += thrdNum * N * N;
    Ainv += thrdNum * N * N;

    if (thrdNum < batch) {

#if MAT_COLUMN_MAJOR
        AA00 = A[0];
        AA10 = A[1];
        AA20 = A[2];
        AA30 = A[3];
        AA01 = A[4];
        AA11 = A[5];
        AA21 = A[6];
        AA31 = A[7];
        AA02 = A[8];
        AA12 = A[9];
        AA22 = A[10];
        AA32 = A[11];
        AA03 = A[12];
        AA13 = A[13];
        AA23 = A[14];
        AA33 = A[15];
#else
        AA00 = A[0];
        AA01 = A[1];
        AA02 = A[2];
        AA03 = A[3];
        AA10 = A[4];
        AA11 = A[5];
        AA12 = A[6];
        AA13 = A[7];
        AA20 = A[8];
        AA21 = A[9];
        AA22 = A[10];
        AA23 = A[11];
        AA30 = A[12];
        AA31 = A[13];
        AA32 = A[14];
        AA33 = A[15];
#endif

        perm0 = 0;
        perm1 = 1;
        perm2 = 2;
        perm3 = 3;
        
        /****************** iteration 0 ***********/

#if USE_PIVOTING
        /* search pivot row */
        p = absOp (AA00);
        pvt = 0;
        t = absOp (AA10);
        if (t > p) { p = t;  pvt = 1; }
        t = absOp (AA20);
        if (t > p) { p = t;  pvt = 2; }
        t = absOp (AA30);
        if (t > p) { p = t;  pvt = 3; }
        
        /* swap pivot row with row 0 */
        if (pvt == 1) {
            tmp = AA00;  AA00 = AA10;  AA10 = tmp;
            tmp = AA01;  AA01 = AA11;  AA11 = tmp;
            tmp = AA02;  AA02 = AA12;  AA12 = tmp;
            tmp = AA03;  AA03 = AA13;  AA13 = tmp;
            /* update permutation vector based on row swap */
            i = perm0;  perm0 = perm1;  perm1 = i;
        }
        if (pvt == 2) {
            tmp = AA00;  AA00 = AA20;  AA20 = tmp;
            tmp = AA01;  AA01 = AA21;  AA21 = tmp;
            tmp = AA02;  AA02 = AA22;  AA22 = tmp;
            tmp = AA03;  AA03 = AA23;  AA23 = tmp;
            /* update permutation vector based on row swap */
            i = perm0;  perm0 = perm2;  perm2 = i;
        }
        if (pvt == 3) {
            tmp = AA00;  AA00 = AA30;  AA30 = tmp;
            tmp = AA01;  AA01 = AA31;  AA31 = tmp;            
            tmp = AA02;  AA02 = AA32;  AA32 = tmp;
            tmp = AA03;  AA03 = AA33;  AA33 = tmp;
            /* update permutation vector based on row swap */
            i = perm0;  perm0 = perm3;  perm3 = i;
        }
#endif // USE_PIVOTING

        /* scale current row */
        tmp = rcpOp (AA00);
        icol0 = perm0;
        AA00 = tmp;
        AA01 = mulOp (tmp, AA01);
        AA02 = mulOp (tmp, AA02);
        AA03 = mulOp (tmp, AA03);

        /* eliminate above and below current row */
        tmp = AA10;
        AA10 = mulOp (negOp(tmp), AA00);
        AA11 = fmnaOp (tmp, AA01, AA11);
        AA12 = fmnaOp (tmp, AA02, AA12);
        AA13 = fmnaOp (tmp, AA03, AA13);

        tmp = AA20;
        AA20 = mulOp (negOp(tmp), AA00);
        AA21 = fmnaOp (tmp, AA01, AA21);
        AA22 = fmnaOp (tmp, AA02, AA22);
        AA23 = fmnaOp (tmp, AA03, AA23);

        tmp = AA30;
        AA30 = mulOp (negOp(tmp), AA00);
        AA31 = fmnaOp (tmp, AA01, AA31);
        AA32 = fmnaOp (tmp, AA02, AA32);
        AA33 = fmnaOp (tmp, AA03, AA33);

        /****************** iteration 1 ***********/

#if USE_PIVOTING
        /* search pivot row */
        p = absOp (AA11);
        pvt = 1;
        t = absOp (AA21);
        if (t > p) { p = t;  pvt = 2; }
        t = absOp (AA31);
        if (t > p) { p = t;  pvt = 3; }

        /* swap pivot row with row 1 */
        if (pvt == 2) {
            tmp = AA10;   AA10 = AA20;   AA20 = tmp;
            tmp = AA11;   AA11 = AA21;   AA21 = tmp;
            tmp = AA12;   AA12 = AA22;   AA22 = tmp;
            tmp = AA13;   AA13 = AA23;   AA23 = tmp;
            /* update permutation vector based on row swap */
            i = perm1;   perm1 = perm2;   perm2 = i;
        }
        if (pvt == 3) {
            tmp = AA10;   AA10 = AA30;   AA30 = tmp;
            tmp = AA11;   AA11 = AA31;   AA31 = tmp;
            tmp = AA12;   AA12 = AA32;   AA32 = tmp;
            tmp = AA13;   AA13 = AA33;   AA33 = tmp;
            /* update permutation vector based on row swap */
            i = perm1;   perm1 = perm3;   perm3 = i;
        }
#endif // USE_PIVOTING

        /* scale current row */
        tmp = rcpOp (AA11);
        icol1 = perm1;
        AA10 = mulOp (tmp, AA10);
        AA11 = tmp;
        AA12 = mulOp (tmp, AA12);
        AA13 = mulOp (tmp, AA13);

        /* eliminate above and below current row */
        tmp = AA01;
        AA00 = fmnaOp (tmp, AA10, AA00);
        AA01 = mulOp (negOp(tmp), AA11);
        AA02 = fmnaOp (tmp, AA12, AA02);
        AA03 = fmnaOp (tmp, AA13, AA03);
        
        tmp = AA21;
        AA20 = fmnaOp (tmp, AA10, AA20);
        AA21 = mulOp (negOp(tmp), AA11);
        AA22 = fmnaOp (tmp, AA12, AA22);
        AA23 = fmnaOp (tmp, AA13, AA23);
        
        tmp = AA31;
        AA30 = fmnaOp (tmp, AA10, AA30);
        AA31 = mulOp (negOp(tmp), AA11);
        AA32 = fmnaOp (tmp, AA12, AA32);
        AA33 = fmnaOp (tmp, AA13, AA33);
        
        /****************** iteration 2 ****************/

#if USE_PIVOTING
        /* search pivot row */
        p = absOp (AA22);
        pvt = 2;
        t = absOp (AA32);
        if (t > p) { p = t;  pvt = 3; }

        /* swap pivot row with row 2 */
        if (pvt == 3) {
            tmp = AA20;   AA20 = AA30;   AA30 = tmp;
            tmp = AA21;   AA21 = AA31;   AA31 = tmp;
            tmp = AA22;   AA22 = AA32;   AA32 = tmp;
            tmp = AA23;   AA23 = AA33;   AA33 = tmp;
            /* update permutation vector based on row swap */
            i = perm2;   perm2 = perm3;   perm3 = i;
        }
#endif // USE_PIVOTING

        /* scale current row */
        tmp = rcpOp (AA22);
        icol2 = perm2;
        AA20 = mulOp (tmp, AA20);
        AA21 = mulOp (tmp, AA21);
        AA22 = tmp;
        AA23 = mulOp (tmp, AA23);

        /* eliminate above and below current row */
        tmp = AA02;
        AA00 = fmnaOp (tmp, AA20, AA00);
        AA01 = fmnaOp (tmp, AA21, AA01); 
        AA02 = mulOp (negOp(tmp), AA22);
        AA03 = fmnaOp (tmp, AA23, AA03);

        tmp = AA12;
        AA10 = fmnaOp (tmp, AA20, AA10);
        AA11 = fmnaOp (tmp, AA21, AA11);
        AA12 = mulOp (negOp(tmp), AA22);
        AA13 = fmnaOp (tmp, AA23, AA13);

        tmp = AA32;
        AA30 = fmnaOp (tmp, AA20, AA30);
        AA31 = fmnaOp (tmp, AA21, AA31);
        AA32 = mulOp (negOp(tmp), AA22);
        AA33 = fmnaOp (tmp, AA23, AA33);

        /****************** iteration 3 ****************/

        /* scale current row */
        tmp = rcpOp (AA33);
        icol3 = perm3;
        AA30 = mulOp (tmp, AA30);
        AA31 = mulOp (tmp, AA31);
        AA32 = mulOp (tmp, AA32);
        AA33 = tmp;

        /* eliminate above and below current row */
        tmp = AA03;
        AA00 = fmnaOp (tmp, AA30, AA00);
        AA01 = fmnaOp (tmp, AA31, AA01);
        AA02 = fmnaOp (tmp, AA32, AA02);
        AA03 = mulOp (negOp(tmp), AA33);

        tmp = AA13;
        AA10 = fmnaOp (tmp, AA30, AA10);
        AA11 = fmnaOp (tmp, AA31, AA11);
        AA12 = fmnaOp (tmp, AA32, AA12);
        AA13 = mulOp (negOp(tmp), AA33);

        tmp = AA23;
        AA20 = fmnaOp (tmp, AA30, AA20);
        AA21 = fmnaOp (tmp, AA31, AA21);
        AA22 = fmnaOp (tmp, AA32, AA22);
        AA23 = mulOp (negOp(tmp), AA33);

        /* sort columns into the correct order */
        Ainv(0,icol0) = AA00;
        Ainv(1,icol0) = AA10;
        Ainv(2,icol0) = AA20;
        Ainv(3,icol0) = AA30;
        Ainv(0,icol1) = AA01;
        Ainv(1,icol1) = AA11;
        Ainv(2,icol1) = AA21;
        Ainv(3,icol1) = AA31;
        Ainv(0,icol2) = AA02;
        Ainv(1,icol2) = AA12;
        Ainv(2,icol2) = AA22;
        Ainv(3,icol2) = AA32;
        Ainv(0,icol3) = AA03;
        Ainv(1,icol3) = AA13;
        Ainv(2,icol3) = AA23;
        Ainv(3,icol3) = AA33;
    }
}

// Macro to catch CUDA errors in CUDA runtime calls
#define CUDA_SAFE_CALL(call)                                          \
do {                                                                  \
    cudaError_t err = call;                                           \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString(err) );       \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

// Macro to catch CUDA errors in kernel launches
#define CHECK_LAUNCH_ERROR()                                          \
do {                                                                  \
    /* Check synchronous errors, i.e. pre-launch */                   \
    cudaError_t err = cudaGetLastError();                             \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString(err) );       \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
    /* Check asynchronous errors, i.e. kernel failed (ULF) */         \
    err = cudaDeviceSynchronize();                                    \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString( err) );      \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

int main (void)
{
    double mat [4][4] = 
        {{2.120771107884677649e+09, 0.000000000000000000e+00, 0.000000000000000000e+00, 0.000000000000000000e+00},
         {0.000000000000000000e+00, 3.557266600921528288e+27, 3.557266600921528041e+07, 3.557266600921528320e+17},
         {0.000000000000000000e+00, 3.557266600921528041e+07, 3.557266600921528288e+27, 3.557266600921528041e+07},
         {0.000000000000000000e+00, 3.557266600921528320e+17, 3.557266600921528041e+07, 1.778633300460764144e+27}};
    double matinv [4][4] = {0};
    double *mat_d = 0, *matinv_d = 0;

    CUDA_SAFE_CALL (cudaMalloc ((void **)&mat_d, sizeof (mat)));
    CUDA_SAFE_CALL (cudaMalloc ((void **)&matinv_d, sizeof (matinv)));
    CUDA_SAFE_CALL (cudaMemcpy (mat_d, mat, sizeof (mat), cudaMemcpyHostToDevice));
    CUDA_SAFE_CALL (cudaMemset (matinv_d, 0xff, sizeof (matinv)));
    matinv_4x4_matrix_per_thread<double><<<1,1>>>(mat_d, matinv_d, 1);
    CHECK_LAUNCH_ERROR();
    CUDA_SAFE_CALL (cudaMemcpy (matinv, matinv_d, sizeof (matinv), cudaMemcpyDeviceToHost));
    
    for (int row = 0; row < 4; row++) {
        for (int col = 0; col < 4; col++) {
            printf ("% 23.16e  ", matinv[row][col]);
        }
        printf ("\n");
    }
    printf ("\n");
    
    CUDA_SAFE_CALL (cudaFree (mat_d));
    CUDA_SAFE_CALL (cudaFree (matinv_d));
    return EXIT_SUCCESS;
}

Should you notice that the code is limited by global memory bandwidth, with the sub-optimal global memory access pattern as a contributing factor, try buffering in shared memory: all threads of a thread block cooperate to copy the relevant source data from global memory into shared memory as efficiently as possible, then each thread pulls its data from shared memory into registers to compute its assigned matrix inverse. Since writing the results is typically not as critical, the inverses might be written out directly to global memory. A structural sketch of that staging step follows.
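
Structurally, the staging might look roughly like the sketch below, written here in the PyCUDA style of the question. The kernel name inv4x4_staged is made up, and the elimination body is elided; it would be identical to matinv_4x4_matrix_per_thread above.

import pycuda.autoinit
from pycuda.compiler import SourceModule

staged = SourceModule("""
typedef double T;

__global__ void inv4x4_staged(const T * __restrict__ A, T * __restrict__ Ainv, int batch)
{
    // dynamic shared memory: launch with shared = blockDim.x * 16 * sizeof(T)
    extern __shared__ T smem[];
    int base = blockIdx.x * blockDim.x * 16;        // first element owned by this block

    // cooperative, coalesced load: consecutive threads read consecutive elements
    for (int i = threadIdx.x; i < blockDim.x * 16; i += blockDim.x)
        if (base + i < batch * 16)
            smem[i] = A[base + i];
    __syncthreads();

    int m = blockIdx.x * blockDim.x + threadIdx.x;  // matrix handled by this thread
    if (m < batch) {
        T AA[16];
        for (int i = 0; i < 16; i++)                // pull this thread's matrix into registers
            AA[i] = smem[threadIdx.x * 16 + i];

        // ... Gauss-Jordan elimination with partial pivoting on AA[],
        //     exactly as in matinv_4x4_matrix_per_thread above ...

        for (int i = 0; i < 16; i++)                // write inverses straight to global memory
            Ainv[m * 16 + i] = AA[i];
    }
}
""")
matinv_staged = staged.get_function("inv4x4_staged")

When launching from PyCUDA, the dynamic shared memory size is passed via the shared= argument of the kernel call (blockDim.x * 16 * 8 bytes for double precision).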

I am not quite sure how to interpret the information on the problem size. Originally I assumed a 3D problem of 128x256x512, but if the tiny matrices to be inverted in fact number between 128 and 512, this is not a problem well matched to the use of GPUs, which require parallelism on the order of several tens of thousands of active threads to be fully utilized. If this computation is a tiny portion of a GPU-based computation, there should be nothing to worry about, but if it represents a hot spot because it is executed extremely often, that might be a problem.