Originally published at: https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/
Update May 21, 2018: CUTLASS 1.0 is now available as Open Source software at the CUTLASS repository. CUTLASS 1.0 has changed substantially from our preview release described in the blog post below. We have decomposed the structure of the GEMM computation into deeper, structured primitives for loading data, computing predicate masks, streaming data at each level…
Thanks for the great tutorial!
I am trying to understand better what "fusing element-wise operation" means. I implement lots of custom LSTMs with PyTorch (and fusing is a big problem, if I understand things correctly). I don't write CUDA code, so the explanation in the tutorial about the gemm::epilogue_op is hard to follow for me. I am just looking for a theoretical understanding.
Please let me know if the following is correct:
Suppose I want to compute ReLU(A*B). Without fusing the pointwise operation, this means that I launch a GEMM kernel to compute A*B. Once the kernel is finished it will send the product C=A*B back to global memory. Then I will launch a second kernel to compute ReLU(C). To do this I will need to fetch the matrix C from global memory, send it to shared memory, and then threshold all the entries of C. Obviously, in this last step, all the time is spent fetching the matrix C from global memory. The goal of fusing is to eliminate this unnecessary fetching time.
In the "fusing scenario" we launch a single GEMM kernel, with simply an extra line at the end of the kernel code to threshold each entry of C once they become available available.
Did I understand correctly what "fusing element-wise operation" means?
Thanks
Hi Alain,
That's correct, although ReLU(C) wouldn't need to stage through shared memory, since each element is only accessed once. Fusing still saves a load and a store of C.
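For instance (purely illustrative, a naive kernel rather than the actual gemm::epilogue_op interface), the fused version just applies the threshold to the accumulator in registers right before the single store to global memory:

// Illustrative sketch only: a naive GEMM with ReLU folded into the epilogue.
// The thresholding happens on the register accumulator right before the one
// store of C, so C is never written out and re-read just to apply the ReLU.
__global__ void gemm_relu_fused(const float* A, const float* B, float* C,
                                int M, int N, int K)
{
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    float acc = 0.0f;
    for (int k = 0; k < K; ++k)
        acc += A[row * K + k] * B[k * N + col];   // GEMM main loop

    C[row * N + col] = fmaxf(acc, 0.0f);          // fused element-wise epilogue
}

The unfused version would instead need a second kernel that reads every element of C back from global memory just to apply the max.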
Niall
Thanks for the write up!
But I don't quite get the essence of the thread tile. In Figure 5, it seems that one thread is responsible for calculating the outer product for 4 locations in the warp accumulator, so I don't understand where the 8x8 matrix (on the right of Fig. 5) comes from.
Also, to my understanding, threads cannot share their register space, so how is the thread tile achieving O(mn) computations with only O(m+n) loads (as illustrated in the slides)? Are you using the __shfl_sync function to build another caching layer?
Thanks!
Hello Andrew. This is Isaac from GTC, who had the good fortune of talking to you about CUTLASS. Your explanation of CUTLASS was extremely helpful. Thanks so much.
In the "Complete GEMM" block code, this line:
accumulator[thread_x][thread_y] += frag_a[y]*frag_b[x];
seems to contain a typo. Should y be replaced with thread_y, and x with thread_x?
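In other words, I'd expect the per-thread accumulation to read something like this (sizes made up, just to show the indexing I have in mind):

// Hypothetical sketch: an 8x8 per-thread accumulator tile built as the outer
// product of 8-element fragments of A and B held in registers.
__device__ void thread_tile_outer_product(const float frag_a[8],
                                          const float frag_b[8],
                                          float accumulator[8][8])
{
    for (int thread_x = 0; thread_x < 8; ++thread_x) {
        for (int thread_y = 0; thread_y < 8; ++thread_y) {
            // 8 + 8 fragment values feed 8 * 8 = 64 multiply-adds per step.
            accumulator[thread_x][thread_y] += frag_a[thread_y] * frag_b[thread_x];
        }
    }
}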
Hello Andrew, I'm somewhat confused as to how you're getting simultaneous global load and computation in the same CTA (Software Pipelining) when those sections are separated by a syncthreads in your GEMM pseudo-code. My understanding was that in the following setup, all threads in a CTA must either be in section A or section B.
for (...) {
    // Section A
    shared[i] = global[i];
    __syncthreads();
    // Section B
    result = compute(shared[i]);
    __syncthreads();
}
Am I missing something fundamental about what syncthreads actually does? My apologies if this is the wrong place to ask this. However, this is the only place I've seen anyone suggest that loading from global into shared memory can be pipelined with computation using only a couple of syncthreads.
Thanks
Replying for Andrew:
Two buffers in shared memory are allocated. One is actively being written with values fetched by global memory loads (the threadblock tile), while the other SMEM buffer is being read into registers (the warp-scoped tile). At the appropriate point in the mainloop body, all data has been written to one SMEM buffer and all data has been read from the other and issued to multiply-add instructions, so the threadblock issues a barrier and then exchanges pointers.
Because one buffer is only being written to by the threads of the threadblock while the other is only being read from, there is no hazard. This permits a single barrier per iteration and provides tolerance of global memory latency.
Here’s pseudocode mirroring your example:
__shared__ float shared[2][N]; // two buffers
int write_buffer = 0;
int read_buffer = 1;
for (...) {
    tmp_registers = global[i];                  // global load
    result = compute(shared[read_buffer][j]);   // math instructions
    shared[write_buffer][i] = tmp_registers;    // shared memory store
    __syncthreads();
    swap(read_buffer, write_buffer);            // exchange pointers
}
While this doesn't directly answer my question, I like the solution you have presented here. It addresses the main issue we have with using shared memory. All of the simple examples I've seen that use shared memory seem to throw away the GPU's nice latency-hiding ability through naive usage of syncthreads. Double buffering the shared memory is such a simple and elegant way of regaining performance that I'm surprised it isn't presented as a standard way of using shared memory.
Thank you very much, Jen.
Thanks for the feedback! I'll pass this along to Andrew.
Sorry for digging this up, but I am really confused by Figure 5.
In particular, I don’t understand how we got the “8-by-8 overall thread tile”. There are a total of 4*8=32 threads in the warp, each computing a 2-by-2 block (there are 4 green cells on the left). How do we get 8-by-8 from that?
Additionally, does thread tile mean “a part of the warp tile as seen from a single thread’s point of view”?
Hi! I am deeply impressed by the “different policy for different type of SGEMM”. But I cannot find the code that decides which size corresponds to which type, such as “tall”, “large”, and so on. Could you provide a link to the code that decides the SGEMM type? Thank you!!!
I have the exact same question. How did we get 8x8 from that?
Confused too, any idea about that?
The state-of-the-art architecture at the time of that post was Volta, which had Tensor Cores, each capable of doing 64 fused multiply-add (FMA) operations per clock. That’s why the thread tile was organized in an 8 x 8 grid. Ampere’s Tensor Cores do 256 FMA operations per clock. Here’s a good post explaining more.