Computing the log of a matmul from the log of two matrices

Hello everyone,

I am looking for a way to compute the elementwise log of a matrix product from the elementwise logs of the two factors: I want $\log(AB)$ given $\log(A)$ and $\log(B)$, with the log taken elementwise.

My initial goal is to implement this in Triton, but I'm open to any solution in CUDA or even pseudocode. Do you have any suggestions for how I could modify the code in the Triton tutorial without losing too much efficiency?

https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py

As far as I know, this is not generally possible for arbitrary matrices $A$ and $B$: $\log(AB) = \log(A) + \log(B)$ holds for the matrix logarithm only when $AB = BA$. Do your matrices always commute?

Thank you for your answer. Sorry, I forgot to mention that I'm talking about the elementwise log of the matrices. Basically, the matrices A and B contain very small probabilities (sometimes $\exp(-6000)$), so I can't store them in standard float32 without underflowing. Now I need to compute the elementwise log of the matrix product, but I'm not sure how to do that efficiently.

My current solution consists of computing the row-wise max of $\log(A)$ and the column-wise max of $\log(B)$, exponentiating $\log(A) - \text{rowmax}$ and $\log(B) - \text{colmax}$, doing the matmul, taking the elementwise log of the result, and adding $\text{rowmax}_i + \text{colmax}_j$ back to entry $(i, j)$.
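
In PyTorch-like code, it looks roughly like this (the name `log_matmul` and the shapes are just for illustration); the plan would be to fuse the same steps into the Triton tutorial kernel:

```python
import torch

def log_matmul(logA: torch.Tensor, logB: torch.Tensor) -> torch.Tensor:
    """log(A @ B) computed stably from logA (M, K) and logB (K, N).

    Subtracting the per-row / per-column maxima keeps every argument of
    exp() at or below 0, so the exponentials cannot overflow and the
    largest summand is exactly 1.
    """
    row_max = logA.amax(dim=1, keepdim=True)   # (M, 1)
    col_max = logB.amax(dim=0, keepdim=True)   # (1, N)
    P = torch.exp(logA - row_max) @ torch.exp(logB - col_max)
    return torch.log(P) + row_max + col_max    # offsets broadcast to (M, N)
```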

I was wondering whether there is a more efficient way to do this. Right now the exponentiation is by far the bottleneck: I must be saturating the special function units, but I don't know how to do any better.

I’m working with an RTX 4090.

One element of a matrix multiplication is a sum of products, and you want to compute the log of this sum. Let's assume those summands are all positive. The sum then lies between its largest term and (number of terms) times its largest term, so the resulting log lies between $\log\big(\max_k A_{ik} B_{kj}\big)$ and $\log\big(K \cdot \max_k A_{ik} \cdot \max_k B_{kj}\big)$, where $K$ is the number of sum terms.

Perhaps a better approach is to roughly normalize the A and B matrices multiplicatively (this only adds a constant offset to the result after the log).

Then use a polynomial approximation for the final log: do the actual, exact matrix multiplication, form several elementwise powers of the result, and combine them with the polynomial coefficients.

In the end this is probably more accurate than going through log A and log B, and it would not need the special function unit.
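
A minimal sketch of such a polynomial log (my notation, under the assumption that you first split off the binary exponent with `frexp`, so the polynomial only ever sees mantissas in $[0.5, 1)$; the degree-6 truncation is an arbitrary accuracy choice):

```python
import math
import torch

def poly_log(x: torch.Tensor) -> torch.Tensor:
    """Elementwise natural log without the special function unit.

    Writes x = m * 2**e with m in [0.5, 1), so that
    log(x) = e*log(2) + log(1 + t) with t = m - 1 in [-0.5, 0).
    """
    m, e = torch.frexp(x)
    t = m - 1.0
    # truncated series for log(1 + t) in Horner form; a minimax polynomial
    # of the same degree would be more accurate for the same cost
    p = t * (1 + t * (-1/2 + t * (1/3 + t * (-1/4 + t * (1/5 - t/6)))))
    return e.to(x.dtype) * math.log(2.0) + p
```

With this truncation you only get a few decimal digits of accuracy, but everything is plain multiply-add.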

Not sure whether the normalization helps enough with your small non-representable values.

Or you keep the $\log A$ and $\log B$ representation, but instead of a matrix multiplication, where each output element is the scalar product of two vectors (a row of $A$ and a column of $B$), you take, for each output element, the elementwise sum of the log-row and the log-column and keep only its largest component.
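
Sketched in PyTorch (my notation; `logA` is $(M, K)$, `logB` is $(K, N)$):

```python
import torch

def maxplus_matmul(logA: torch.Tensor, logB: torch.Tensor) -> torch.Tensor:
    # tropical (max-plus) product: out[i, j] = max_k (logA[i, k] + logB[k, j]),
    # i.e. the log of the largest summand of (A @ B)[i, j]
    return (logA.unsqueeze(2) + logB.unsqueeze(0)).amax(dim=1)
```

This underestimates the true $\log\big((AB)_{ij}\big)$ by at most $\log K$ and needs no exp or log at all, though as written it materializes an $(M, K, N)$ intermediate; a real kernel would reduce over $K$ in tiles instead.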