Lots of small matrices

Hello everyone,

So, from what I’ve read, CUDA+GPU can achieve impressive speedups in linear algebra for large matrices. I have a different problem that I would like to solve, and I’m hoping for some advice.

The matrices in my application are fairly small, only around 80x80 to 100x100… but I have 40,000 of them. I must explicitly invert all of them on each iteration of the calculation (I actually need the inverse). On my current computer, I clocked inverting just one matrix at approximately 50ms, at which rate it will take over half an hour just to complete the matrix inversions, which is more than a little depressing.

Would CUDA+GPU be of much help on this problem? Everything I’ve read about matrix operations suggests that there’s little, if any, benefit to doing GPU matrix math on small matrices, but parallelization seems like it could be very useful for what I’m trying to do… Instead of aiming all the GPU resources at inverting one matrix really, really quickly, could they instead be partitioned into doing many inversions in parallel?

Thank you in advance!

A GPU does the SAME calculation on many different data sets simultaneously, as opposed to a multi-core CPU, whose cores can each do different types of calculations. Matrices, as in your case, are well suited to this type of computation, but a lot depends on the algorithm being used, since it has to be able to make the best use of the GPU. I don’t know much about matrix inversion algorithms, so I can’t comment further on that.
As for the speedup: if the code is written to make the most of the GPU, a card like the NVIDIA GTX 470, with 448 stream cores running at about 1.2 GHz, could potentially bring your 30-plus minutes of computation down to around 30 seconds.
I would recommend reading the CUDA programming guide to get an idea of how it works:
http://developer.download.nvidia.com/compute/cuda/3_2_prod/toolkit/docs/CUDA_C_Programming_Guide.pdf
So it depends on the algorithm; if you can share details about yours, I am sure you can get a speedup.

Thank you for your reply. I also found an older thread on essentially the same question:

So if I may abbreviate my original question:

  1. Can I run 40,000 small matrix inversions in parallel on a GPU, or any respectable fraction of that?

  2. Would there be any point in doing so? (ie. speed gain?)

Correct me if I’m wrong, but the thread cited above suggests that the performance gains are uncertain at best. (The reasons aren’t entirely clear to me.)

As far as I know, matrix inversion is normally accomplished through LU decomposition, which is a series of row-wise arithmetic operations on a matrix. The unfortunate thing is that LU normally requires an if-statement to swap rows (“pivot”) in the matrix to avoid a division by zero, and GPUs really don’t like conditional statements, as I understand it.
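
To make that concrete, here is a minimal, purely sequential sketch of LU factorization with partial pivoting in plain C (illustrative only, not taken from any library; the data-dependent row swap is the branch in question):

#include <math.h>

/* In-place LU factorization with partial pivoting of a row-major n x n matrix.
   On return, a[] holds U on and above the diagonal and the multipliers of L
   below it (unit diagonal implied); piv[] records the row permutation. */
void lu_factor(float *a, int *piv, int n)
{
    for (int i = 0; i < n; i++) piv[i] = i;

    for (int k = 0; k < n; k++) {
        /* pivot search: largest |a[i][k]| for i >= k */
        int p = k;
        for (int i = k + 1; i < n; i++)
            if (fabsf(a[i*n + k]) > fabsf(a[p*n + k])) p = i;

        /* the data-dependent row swap ("pivoting") */
        if (p != k) {
            for (int j = 0; j < n; j++) {
                float t = a[k*n + j]; a[k*n + j] = a[p*n + j]; a[p*n + j] = t;
            }
            int t = piv[k]; piv[k] = piv[p]; piv[p] = t;
        }

        /* eliminate below the pivot and update the trailing submatrix */
        for (int i = k + 1; i < n; i++) {
            a[i*n + k] /= a[k*n + k];
            for (int j = k + 1; j < n; j++)
                a[i*n + j] -= a[i*n + k] * a[k*n + j];
        }
    }
}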

I’m aware that people use the parallel capabilities of GPUs to speed up LU factorizations for large, single matrices. That is a much different problem; I’m not an expert in numerical linear algebra, so I’m uncertain how they get around the pivoting.

So, is there still any hope for my problem?

Yes, just launch as many inverting threads as you have memory for.

There is good hope, but you may have to write your own code, as there still doesn’t seem to be support for batched operations on many small matrices as opposed to single large matrices.

I believe CULA tools, for example, have this in their pipeline…

You definitely can run many inversions in parallel, but as said before you will need to write your own code for that. My guess is that it is best to start a thread block for each inversion; that way you can use the shared memory of a multiprocessor for the inversion.

I don’t see any reason why the potential speedup should be smaller than for large matrices as long as the total number of operations is similar; if anything, I would guess it could even be faster (though you will probably have to spend a lot of time before your code is as polished as the routines available for large matrices).
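
To make the launch configuration concrete, here is a bare skeleton of the one-block-per-matrix layout (kernel name, sizes, and variable names are invented for illustration, and the actual factorization is left out; it only shows staging each matrix through shared memory):

#include <cuda_runtime.h>

#define N 32           // matrix dimension (the sizes in this thread go up to ~100)
#define NUM_MAT 40000  // number of matrices in the batch

// Skeleton only: each block stages one matrix in shared memory; the actual
// in-place factorization/inversion would go where indicated.
__global__ void invert_one_matrix(float *AG)
{
    __shared__ float A[N][N];
    int mat = blockIdx.x;     // one block per matrix
    int tx  = threadIdx.x;    // one thread per column

    for (int row = 0; row < N; row++)
        A[row][tx] = AG[mat*N*N + row*N + tx];
    __syncthreads();

    // ... factorize / invert A cooperatively here ...

    for (int row = 0; row < N; row++)
        AG[mat*N*N + row*N + tx] = A[row][tx];
}

int main(void)
{
    float *d_A;
    cudaMalloc(&d_A, sizeof(float) * N * N * NUM_MAT);
    invert_one_matrix<<<NUM_MAT, N>>>(d_A);   // one block per matrix, N threads each
    cudaDeviceSynchronize();
    cudaFree(d_A);
    return 0;
}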

Cheers

Ceearem

Googling “small matrices” yields this forum thread fairly high up in the results list, so I’ll hazard a guess that a substantial number of people are interested in similar problems. I recently ran some tests on LU factorization with partial pivoting, so I thought I’d share my results.

1. 1-thread per Matrix approach

First and foremost, having each thread invert one matrix is a terrible, terrible idea. The CPU will dance circles around the GPU. It would appear that such an approach is doomed by the numerous reads and writes to global memory.

2. Multiple threads per Matrix approach

I took the advice of some members of this forum, including Ceearem above, and had one block of threads collaboratively invert one matrix. The matrix for each block is kept entirely in shared memory (this is essential). For an NxN matrix, I chose N threads per block, with the threads sweeping through the matrix row by row. This works far better than the 1-thread-per-matrix approach.

Code for LU factorization of a real matrix is posted below; the extension to complex matrices is straightforward. Unfortunately it’s highly divergent, and every time it advances one row there is one more thread with nothing to do. I would appreciate any suggestions on how to improve it.

3. GPU vs CPU results

I used a Quadro 600 GPU and one core of an Intel Xeon 5680 (3.3 GHz). Maybe I’m being unfair in comparing a mid-range GPU to a very high-end CPU, but that’s what I could get my hands on. The test is LU decomposition of 50,000 randomly generated 32x32 matrices. The best execution time out of three runs is reported below.

Quadro 600 (CUDA kernel running in PyCUDA)

single precision LU decomposition, real matrix: 1.06s

single precision LU decomposition, complex matrix: 2.27s

Xeon 5680, one core (MATLAB 2009a)

single precision LU decomposition, real matrix: 0.86s

single precision LU decomposition, complex matrix: 1.37s

Some caveats: the Quadro 600 was probably also doing some work to keep the desktop graphics running, so the deck may have been stacked against it. For the CPU, I also tried LU decomposition in Fortran, but it was actually slower than MATLAB; the reason is probably that MATLAB is using ATLAS while I think my LAPACK is linked against the netlib reference BLAS.

4. Conclusions

Granted, my code is probably non-optimal, but there doesn’t seem to be any advantage to using a GPU for this sort of problem. I think I am limited primarily by the number of multiprocessors right now, so much better performance could surely be attained on a more powerful GPU; the Quadro 600 has 2 multiprocessors, whereas a deluxe GPGPU card like the Tesla C2070 has 14, so I’d expect somewhere in the vicinity of a 6-7 fold improvement at best. But if I harnessed all 6 cores of the Xeon via, say, OpenMP, the performance would be similar. In any event, there aren’t going to be the orders-of-magnitude speedups seen with large matrix computations.

Comments and second opinions would be welcome.

5. Code

CUDA Kernel:

// ran with 32 thread blocks, blockDim.x = 1000, blockDim.y = 50
// in PyCUDA. sz=32 is the matrix dimension.
__global__ void LUfactor(float *AG, int *PG) {
    int bx = blockIdx.x + blockIdx.y*blockDim.x;
    int tx = threadIdx.x;

    __shared__ float A[sz][sz];
    __shared__ int piv[sz];
    int m, n, tmp;
    __shared__ float v;

    //Collaboratively load matrix into shared memory
    for(n=0; n<sz; n++){
        A[n][tx] = AG[bx*sz*sz + n*sz + tx];
    }
    piv[tx] = tx;
    __syncthreads();

    //LU decomposition
    for(n=0; n<sz; n++){
        //pivoting
        if (tx==0){
            for(m=n+1; m<sz; m++) {
                if( fabsf(A[piv[m]][n]) > fabsf(A[piv[n]][n]) ){
                    tmp = piv[n];
                    piv[n] = piv[m];
                    piv[m] = tmp;
                }
            }
            v = 1/A[piv[n]][n];
        }
        __syncthreads();

        //L block normalization
        if (tx>n) {
            A[piv[tx]][n] *= v;
        }
        __syncthreads();

        //reduction
        for (m=n+1; m<sz; m++) {
            if (tx>n) {
                A[piv[m]][tx] -= A[piv[m]][n]*A[piv[n]][tx];
            }
        }
    }

    //copy back to global memory
    for(n=0; n<sz; n++){
        AG[bx*sz*sz + n*sz + tx] = A[n][tx];
    }
    PG[bx*sz + tx] = piv[tx];
}

MATLAB code

A = rand(32,32,50000,'single');                                         % real test case
A = rand(32,32,50000,'single') + single(i)*rand(32,32,50000,'single');  % complex test case (overwrites the real one; run one or the other)

tic;
for n = 1:50000
    A(:,:,n) = lu(A(:,:,n));
end
toc

You state that for an NxN matrix your code uses N threads. So if N=32 as in your example, there is only one warp running per thread block. Since at most 8 thread blocks can run on the same multiprocessor simultaneously, this results in low occupancy, preventing the code from extracting the full performance of the hardware. You could assign multiple threads to each row, or in the extreme try assigning each matrix element to its own thread. Since this last variant limits the code to a single thread block of 1024 threads per multiprocessor, it is probably not optimal either. You may need to experiment to find the optimal tradeoff between work per thread and thread blocks per multiprocessor.

For matrices of this size it is probably better to perform explicit row exchanges instead of using a pivot vector, because the latter approach will likely use more memory accesses and instructions than the former. [Added later:] When LU decomposition is used to solve systems of linear equations, applying such explicit row exchanges to the matrix also requires that equivalent exchanges be applied to the RHSs. So getting rid of the pivot vector makes sense only if the number of RHSs is small. How many RHSs are there going to be for each matrix being factored?
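
For what it’s worth, a sketch (untested, and only my assumption of what this would look like) of an explicit row exchange in the shared-memory layout of the kernel posted above, with all sz threads of the block participating:

// Sketch: explicit row exchange instead of the piv[] indirection. A is the
// same __shared__ float A[sz][sz] as in the kernel above; tx is the thread's
// column index.
__device__ void exchange_rows(float A[sz][sz], int row_a, int row_b, int tx)
{
    if (row_a != row_b) {
        float t      = A[row_a][tx];
        A[row_a][tx] = A[row_b][tx];
        A[row_b][tx] = t;
    }
    __syncthreads();   // rows must be fully swapped before elimination continues
}

// At step n, once the pivot row p is known: exchange_rows(A, n, p, tx);
// the rest of the step can then index rows directly (A[m][...] rather than
// A[piv[m]][...]).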

A choke point of the algorithm is the pivot search, which uses only a single thread. Since this is basically a reduction, it would be beneficial to introduce some degree of parallelism (whether a full tree-like scheme makes sense at this size is unclear; it would be best to experiment).
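
A possible shape for that parallel pivot search, written as an untested sketch against the posted kernel (it assumes sz is a power of two, which holds for sz = 32):

// Sketch: tree reduction replacing the single-thread pivot scan. All sz
// threads of the block call it; it returns the position m (n <= m < sz) whose
// row piv[m] has the largest |A[piv[m]][n]|.
__device__ int find_pivot(float A[sz][sz], int *piv, int n, int tx)
{
    __shared__ float val[sz];
    __shared__ int   idx[sz];

    val[tx] = (tx >= n) ? fabsf(A[piv[tx]][n]) : -1.0f;   // exclude rows above n
    idx[tx] = tx;
    __syncthreads();

    for (int s = sz/2; s > 0; s >>= 1) {
        if (tx < s && val[tx + s] > val[tx]) {
            val[tx] = val[tx + s];
            idx[tx] = idx[tx + s];
        }
        __syncthreads();
    }
    return idx[0];   // thread 0 can then swap piv[n] and piv[idx[0]] as before
}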

I realize that my occupancy sucks, but I see no way to further subdivide the work so that more threads can participate. The problem with LU factorization, at least with the standard methods, is that each row update requires results from the previous row; this is what motivated my choice of N threads for an NxN matrix. Each thread updates one matrix element in the current row. If, for example, I had 2N threads, one set of N could work on the current row, but what could the other set of N possibly do? Likewise, if each matrix element had its own thread, most of the threads would just hang idle for the vast majority of the calculation, waiting on results from a previous row. If there’s an LU algorithm out there that employs greater parallelism, I’d love to know about it, but my own searches haven’t proven fruitful. As far as I can tell, LU factorization is necessarily sequential, with parallelism possible only in the updates between pivots.

I think I’d better keep the pivot vector. For my application I need to explicitly calculate matrix inverses, so there are always N RHSs per matrix.

I see your point that the pivot search is a bottleneck (I did a linear search with one thread mostly out of simplicity). If I turn off pivoting, the execution time drops to 0.66s, a 0.40s improvement. Assuming an ideal binary comparison search, the execution time ratio between a tree search and the linear search should be approximately sum(log(1:32)/log(2))/sum(1:32) = 0.2228. So I estimate a lower bound on the execution time to be 0.2228*0.40 + 0.66 = 0.75s — about 30% improvement, and clocking better than the CPU. Respectable perhaps, but it’s a pretty marginal benefit over the CPU for a much greater coding effort.

blockDim.x = 1000, blockDim.y = 50

I suppose this should be gridDim.x = 1000 and gridDim.y = 50?

So it seems to me you are comparing suboptimal GPU code against an optimized CPU routine. Add to this that you are comparing a high-end CPU against a low-end GPU, and your results aren’t very surprising.

Funnily enough, I have done this exact exercise using the single-thread-per-matrix approach. I have a smallish template library which implements selected bits of BLAS and LAPACK as device functions for use in kernels. I wrote it with the intention of working on smaller matrices than your case; 32x32 is at the very upper end of what I would usually work with. And my conclusion is the exact opposite of yours. I just built a little example to test my code on your 32x32 case: I can factorize 50,000 random, double precision matrices in 0.26 seconds on my GTX 470. For smaller matrices, the performance is considerably better; the 16x16 case takes 0.036s, and the 8x8 case takes 0.005s.

The kernel I used for this is very, very prosaic; everything is done in thread-local memory, with coalesced loads and stores between global and local memory, using a single thread per matrix:

template<typename T, int N>
void __global__ lakernel(T * InOut, const int LDA, const int M)
{
    const int NN = N*N;
    T A_[NN];
    int p_[N];

    int tid = threadIdx.x + blockIdx.x * blockDim.x;

    if (tid < M) {
        // Load matrices to memory
        {
            T * pos = InOut + tid;
            for(int i=0; i<NN; i++, pos+=LDA) {
                A_[i] = *pos;
            }
        }

        // Do the decomposition
        fortranArray<T> A(&A_[0],N);
        fortranVector<int> p(&p_[0]);
        getf2<T,N>(A,p);

        // Save factorization
        {
            T * pos = InOut + tid;
            for(int i=0; i<NN; i++, pos+=LDA) {
                *pos = A_[i];
            }
        }
    }
}

I don’t pretend that any of this code is even remotely optimized, and internally everything works in column-major order (hence the wrapper classes), but I think it shows that there is some scope for the single-thread-per-matrix approach to be useful.


A binary (tree-reduction) pivot search shouldn’t be that difficult to code. Throw in full loop unrolling and some micro-optimizations (exchange the if and the for in the reduction, compute fabsf only once), use a decent GPU, and you might end up with code running 20x faster on the GPU than on one core of the CPU. I think that is about what you can realistically expect from GPU computing.
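
Spelled out, the two micro-optimizations would amount to something like this rewrite of the corresponding parts of the kernel posted earlier (a sketch, not tested):

// (a) Pivot search with the fabsf of the current best kept in a register:
if (tx == 0) {
    float best = fabsf(A[piv[n]][n]);
    for (m = n+1; m < sz; m++) {
        float cand = fabsf(A[piv[m]][n]);
        if (cand > best) {
            tmp = piv[n]; piv[n] = piv[m]; piv[m] = tmp;
            best = cand;
        }
    }
    v = 1/A[piv[n]][n];
}

// (b) Reduction with the if hoisted out of the for, and this thread's entry of
//     row n kept in a register:
if (tx > n) {
    float unj = A[piv[n]][tx];
    for (m = n+1; m < sz; m++) {
        A[piv[m]][tx] -= A[piv[m]][n] * unj;
    }
}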

Obviously parallelism is limited during the pivot and scaling phases. But additional threads can be used for work during the reduction phase, as there is independent work across “m”.

If the ultimate goal is matrix inversion, it might be worthwhile looking into Gauss-Jordan elimination as an alternative to LU decomposition.
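
For reference, a plain, sequential sketch of Gauss-Jordan inversion with partial pivoting (host C, illustrative only; a GPU version would map the per-column updates across threads much like the LU kernels in this thread):

#include <math.h>

/* In-place Gauss-Jordan on [A | I]: reduces a (row-major, n x n) to the
   identity while applying the same operations to inv, which then holds A^-1. */
void gauss_jordan_inverse(float *a, float *inv, int n)
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            inv[i*n + j] = (i == j) ? 1.0f : 0.0f;

    for (int k = 0; k < n; k++) {
        /* partial pivoting: bring the largest |a[i][k]|, i >= k, to row k */
        int p = k;
        for (int i = k + 1; i < n; i++)
            if (fabsf(a[i*n + k]) > fabsf(a[p*n + k])) p = i;
        for (int j = 0; j < n; j++) {
            float t;
            t = a[k*n + j];   a[k*n + j]   = a[p*n + j];   a[p*n + j]   = t;
            t = inv[k*n + j]; inv[k*n + j] = inv[p*n + j]; inv[p*n + j] = t;
        }

        /* scale the pivot row so the pivot becomes 1 */
        float s = 1.0f / a[k*n + k];
        for (int j = 0; j < n; j++) { a[k*n + j] *= s; inv[k*n + j] *= s; }

        /* eliminate column k from every other row, above and below */
        for (int i = 0; i < n; i++) {
            if (i == k) continue;
            float f = a[i*n + k];
            for (int j = 0; j < n; j++) {
                a[i*n + j]   -= f * a[k*n + j];
                inv[i*n + j] -= f * inv[k*n + j];
            }
        }
    }
}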

I don’t know how fine-grained the scoreboarding is for shared memory, but doesn’t the GPU already extract the instruction level parallelism from the reduction phase?

In the code as written, O(N) threads are working on the reduction. One could have O(N^2) threads working on it by running with an N x N thread block. That might not be optimal either, as it limits us to a single thread block per multiprocessor for 32 x 32 matrices. But some intermediate scheme (for example, a 2 x N thread block where, during the reduction, odd and even rows are handled by different threads) might work well. The goal is to increase occupancy to the point where we can be sure that latencies are covered.
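
As an untested sketch of the 2 x N variant (blockDim = (sz, 2), with ty = threadIdx.y), the reduction phase could split odd and even rows between the two warps, while the load, pivot, and normalization phases would be restricted to ty == 0:

// Sketch: reduction phase with a 2 x N thread block; ty picks every other row
// so twice as many threads participate in step n.
int ty = threadIdx.y;
if (tx > n) {
    float unj = A[piv[n]][tx];
    for (m = n + 1 + ty; m < sz; m += 2) {   // ty=0: rows n+1, n+3, ...  ty=1: rows n+2, n+4, ...
        A[piv[m]][tx] -= A[piv[m]][n] * unj;
    }
}
__syncthreads();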

I was thinking along the lines of Vasily Volkov’s presentation Better Performance at Lower Occupancy. But I’m not sure the scoreboarding is fine-grained enough so the latencies of the reduction operations can be hidden with just a single warp per block. Of course this uncertainty can easily be overcome by turning the instruction-level parallelism into thread-level parallelism to make it explicit.

I will readily agree that occupancy is not always strongly correlated with performance. However an occupancy below 0.333 on Fermi is unlikely to result in the best performance, and in this code the occupancy is only 0.166 at N=32. What exacerbates the problem here is that fewer and fewer threads actively participate as we loop over “n”.

Which approach works best will depend on the specific properties of the problem to be solved. The following questions will likely inform the choice of algorithm and the mapping of work on the GPU: What is the size of each matrix? How many matrices are in each batch? Are we solving general systems of linear equations, or are we inverting matrices? If the former, how many RHSs are there typically for a given matrix? Are they all applied at the same time?

I dug out a quick and rough implementation of LU decomposition.

This was able to do 50,000 random 32 by 32 matrices (real only) in 0.129 s on my laptop’s GTX 460M (a mid-range mobile GPU). With good scaling you could expect 3x that performance on a high-end desktop GPU such as the GTX 580, meaning potentially 0.129/3 ≈ 0.043 s.

That would be a 0.86/0.043 ≈ 20x speedup over your CPU implementation.

So I didn’t spend much time on this and I’m in no way assuming that it’s fully optimized.

OK, since this is somewhat similar to something I might want to do in my own code, I’ve tried the different optimizations myself.

With the original code, 50,000 random real 32x32 matrices take 1.41s on the GT 9600m in my laptop. Simply moving the if out of the reduction loop reduces the time to 1.06s. Parallelizing the pivot search further reduces the time to 0.89s. Using two warps during the reduction phase further brings this down to 0.641s, or 0.638s with four warps. Loop unrolling had no effect at all.

So Norbert was right, the instruction level parallelism isn’t exploited automatically in this case.

Of course the exercise was pretty moot regarding the total possible speedup, as Jimmy had already confirmed my 20x educated guess using his own code. However I was also interested in seeing what each optimization step would gain us.

EDIT: Oh, and eliminating the bank conflicts takes us to 0.574s.
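
A sketch of the usual way to do that, assuming it applies here, is to pad the shared array by one column so that the strided accesses in the normalization phase (A[piv[tx]][n], stride sz = 32 floats) no longer all map to the same bank:

__shared__ float A[sz][sz+1];   // padding: was A[sz][sz]; the row stride becomes
// 33 floats, so column-wise accesses hit different banks, and all other
// indexing in the kernel stays unchanged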


Okay, this is way more feedback than I expected… Thank you everyone for your input, it’s been very helpful.

So, slightly off topic: one of the factors leading me to conclude that the GPU wasn’t advantageous over the CPU was that a Tesla C2050/C2070 costs more than a hex-core CPU (and I believe a fair comparison requires pitting it against all six cores, not just one). But after reading this thread, I looked at the commodity gaming cards, and to my surprise their performance specs aren’t all that different. Yet a Tesla C2050 costs roughly four times the price of a GeForce GTX ($2400 vs $580). If I’m not concerned about fast double precision calculations and ECC, why would I ever buy a high-end Tesla or Quadro? (Double precision I’m reasonably certain I could live without; is ECC really that important?)

Tera, could you be persuaded to post your modifications to my original code? I would like to run them on my own system to see if there’s a difference (and I’m sure studying them would be edifying for me personally, as a CUDA beginner). Since my earlier post, I tried swapping the if and the for (1.06s down to 0.91s) and parallelizing the pivot search (0.91s down to 0.86s), but using two warps didn’t do anything for me (maybe I’m doing it wrong).