I attempted to fuse two GEMM operations. Although examples exist in CUTLASS, its support for fusing two GEMMs is too restrictive, making real-world implementation rather demanding. Out of personal interest, I want to implement a fused three-matrix multiplication, but I keep running into a problem for which I have yet to find a good solution:
Except for the Stream-K approach, every GEMM kernel’s CUDA launch configuration (e.g., the number of CTAs) is tied to the matrix dimensions. cuBLAS uses a heuristic to pick an optimal configuration for each individual GEMM, whereas in my case I am forced to use the same configuration for both GEMMs, inevitably incurring a performance penalty—especially when the two GEMMs have different sizes.
In one of my implementations I followed the mainstream fastest strategy and finished the first GEMM first. At that point every CTA holds its tile of intermediate results in registers. For the second GEMM, however, each CTA needs to access a row-wise band of the intermediate data, which implies inter-CTA communication. The remedies are either atomics or a global-memory write of the entire row band before the second computation, neither of which I consider a satisfactory solution.
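To make the structure concrete, here is a stripped-down sketch of that second option (stage the intermediate in global memory and use a grid-wide barrier instead of atomics). The names and the naive one-thread-per-element style are placeholders for illustration, not my actual tiled kernel:

```cuda
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// D[M x N2] = (A[M x K] * B[K x N1]) * C[N1 x N2]
// Tmp[M x N1] is the intermediate staged in global memory.
__global__ void fused_gemm_gemm(const float* A, const float* B, const float* C,
                                float* Tmp, float* D,
                                int M, int K, int N1, int N2)
{
    cg::grid_group grid = cg::this_grid();
    int tid      = blockIdx.x * blockDim.x + threadIdx.x;
    int nthreads = gridDim.x * blockDim.x;

    // GEMM1: Tmp = A * B (grid-stride loop over the M*N1 output elements)
    for (int idx = tid; idx < M * N1; idx += nthreads) {
        int i = idx / N1, j = idx % N1;
        float acc = 0.f;
        for (int k = 0; k < K; ++k)
            acc += A[i * K + k] * B[k * N1 + j];
        Tmp[idx] = acc;
    }

    // No CTA may read rows of Tmp written by other CTAs until GEMM1 has
    // finished everywhere; requires launching with cudaLaunchCooperativeKernel.
    grid.sync();

    // GEMM2: D = Tmp * C (each output element needs a full row of Tmp)
    for (int idx = tid; idx < M * N2; idx += nthreads) {
        int i = idx / N2, j = idx % N2;
        float acc = 0.f;
        for (int n = 0; n < N1; ++n)
            acc += Tmp[i * N1 + n] * C[n * N2 + j];
        D[idx] = acc;
    }
}
```

The only point of the sketch is the barrier placement: the grid-wide sync forces both GEMMs to run under one launch configuration, which is exactly where the configuration conflict comes from.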
Could you give example dimensions of the matrices so we can see where the bottleneck would be? E.g., when multiplying two matrices, the resulting matrix could be much smaller than, equal to, or much larger than the two input matrices. The same applies when chaining two multiplications.
Size and memory operations matter too. In some cases it is more efficient to recompute the same values multiple times than to store and reload them.
Thanks for your reply! Let me explain my implementation in more detail.
I’m still following the mainstream GEMM design, where each CTA is mapped to a block of the intermediate result matrix. This setup gives the first GEMM pretty good performance. After the first GEMM finishes, the intermediate data sit in the registers of each CTA, and one CTA’s registers are invisible to the others, so no single CTA can access a full row of the intermediate result. To fix this, I temporarily write the intermediate result out to global memory.
Then, when I move on to the second GEMM, I run into a CTA count mismatch problem.
For example, suppose the first GEMM has dimensions M×N×K = 256×256×256. The intermediate matrix is 256×256, and if each CTA handles a 128×128 tile, that means 4 CTAs are launched.
Now, in the second GEMM, assume the third matrix has dimensions 256×COL.
- If COL = 128, only 2 CTAs are needed.
- If COL = 512, 8 CTAs are needed.
But we only have 4 available from the first GEMM setup, so there’s a mismatch.
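A tiny host-side helper makes that arithmetic explicit (hypothetical code, just ceil-dividing the output into the 128×128 tiles assumed above):

```cuda
#include <cstdio>

// CTAs needed when each CTA owns one tile_m x tile_n tile of the output.
int ctas_for(int rows, int cols, int tile_m = 128, int tile_n = 128) {
    return ((rows + tile_m - 1) / tile_m) * ((cols + tile_n - 1) / tile_n);
}

int main() {
    printf("GEMM1, 256x256 intermediate: %d CTAs\n", ctas_for(256, 256)); // 4
    printf("GEMM2, COL = 128           : %d CTAs\n", ctas_for(256, 128)); // 2
    printf("GEMM2, COL = 512           : %d CTAs\n", ctas_for(256, 512)); // 8
    return 0;
}
```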
I also tried another approach — mapping CTAs directly to the final output matrix instead of the intermediate one. But I found that each CTA then needs a whole row of the intermediate matrix, which depends on one row from the first matrix and the entire second matrix. This means every CTA would have to repeatedly read the entire second matrix, which is clearly not acceptable.
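Putting a rough number on that redundancy (my own arithmetic, reusing the 256×256×256 example with 128×128 output tiles and COL = 512):

```latex
\[
\#\text{CTAs} \;=\; \tfrac{256}{128}\cdot\tfrac{512}{128} \;=\; 8,
\qquad
\text{global reads of } B \;\approx\; 8 \cdot 256 \cdot 256 \ \text{elements},
\]
% since every output-mapped CTA recomputes its own 128-row band of the
% intermediate, and each band needs one row band of the first matrix
% plus the entire second matrix.
```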
Yes, I understand. The mismatch in CTA counts comes from the GEMM implementation and from the data reuse demanded by the math of matrix multiplication.
So one should tackle that reuse first, in combination with the NVIDIA GPU memory system and its cache levels.
Depending on their shapes, matrix multiplications can scatter data or gather/reduce it.
Are the matrices square, as in your example, or very rectangular? And if rectangular, is the longer side an inner (contracted) dimension or an outer dimension of the multiplication expression?
Depending on your concrete matrices, one would choose a suitable approach.
In my case, most of the matrices are irregular rather than square, and the second GEMM is usually larger than the first one. So my main challenge is figuring out how to handle the second GEMM when there aren’t enough CTAs.
My idea is that the fused version should, in theory, outperform cuBLAS. Some blogs have already shown cases where a hand-written two-matrix multiplication (a single GEMM) can outperform cuBLAS.
However, when extending this to three matrices, the mismatch issue makes it hard to optimize the second GEMM as efficiently as the first one. This ends up slowing down the overall fused kernel, and honestly, there doesn’t seem to be an easy fix for it.
I don’t see why, at least with a custom kernel, the CTA mismatch would still be an issue. Let’s focus on the math instead.
If you have matrices with sizes K×L, L×M, and M×N, the result is K×N.
Then naively, computing each output element is a sum of L*M terms, each of which is a product of 3 scalars.
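In index form, that is just

```latex
\[
(ABC)_{k,n} \;=\; \sum_{l=1}^{L}\sum_{m=1}^{M} A_{k,l}\, B_{l,m}\, C_{m,n},
\]
% L*M terms per output element, and K*N output elements in total.
```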
So the rough relative sizing of K*N versus L*M would be interesting. One could be 10000 times larger than the other, which would totally change the problem statement, from a GPU implementation point of view, between the two extreme variants.
Also, especially for the inner dimensions, L could be much larger than M, or the other way around. In that case the multiplication that contracts the much larger inner dimension is best computed first as a plain 2-matrix multiplication; the additional, possibly uncached, memory accesses from not fusing would be negligible then.
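As a rough sanity check of that heuristic, the multiply-add counts of the two parenthesizations are

```latex
\[
\mathrm{cost}\big((AB)\,C\big) = K L M + K M N,
\qquad
\mathrm{cost}\big(A\,(BC)\big) = L M N + K L N .
\]
% Example (my own numbers): K = N = 1024, L = 16384, M = 128
%   (AB)C : 1024*16384*128 + 1024*128*1024   ~ 2.3e9  multiply-adds
%   A(BC) : 16384*128*1024 + 1024*16384*1024 ~ 1.9e10 multiply-adds
% Contracting the huge inner dimension L first is roughly 8x cheaper here.
```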