Overlap between TensorCore GEMM operation and Softmax (exp) operation

Hello!

I am currently analyzing a kernel that performs an attention operation on an NVIDIA RTX 3080 and 4080. The workflow involves the following sequence:

  1. A Tensor Core GEMM operation.
  2. A softmax computation (exponentials via hexp2 on FP16 values) applied to the GEMM output.
  3. Another GEMM operation using the softmax output.

The GEMM operations utilize mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16. Based on profiling results from Nsight Systems and Nsight Compute, TensorCore utilization is constrained by the exp operations inside softmax.

I am exploring ways to mitigate this bottleneck and increase Tensor Core utilization. I'm wondering whether it's possible to run the GEMM and softmax operations in parallel. For instance, could I split the registers so that one half performs the softmax on one half of the data while the other half performs the GEMM on the other half of the data (such that there is no data dependency between them), and then swap tasks afterwards?

I attempted this approach but didn't observe any improvement in Tensor Core utilization. Even though both the GEMM and the exp operations are issued synchronously, they belong to different execution pipelines, so I would expect them to be able to execute in parallel. Is there any way to make them overlap, given that there is no data dependency between them?
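
To illustrate the kind of interleaving I have in mind, here is a simplified sketch (the function signature, fragment names, and tile sizes are invented for illustration; this is not my actual kernel): the MMA works on fragments that have no dependency on the softmax tile, and the exp2 for the other half of the data is issued right next to it, hoping the Tensor Core and MUFU pipelines overlap.

__device__ void interleaved_mma_exp (unsigned int acc[2],          // f16x2 accumulators (GEMM half)
                                     const unsigned int frag_a[4], // f16x2 A fragments (GEMM half)
                                     const unsigned int frag_b[2], // f16x2 B fragments (GEMM half)
                                     half2 s_tile[8])              // softmax half, independent of the MMA
{
    asm volatile ("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                  "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%0,%1};\n"
                  : "+r"(acc[0]), "+r"(acc[1])
                  : "r"(frag_a[0]), "r"(frag_a[1]), "r"(frag_a[2]), "r"(frag_a[3]),
                    "r"(frag_b[0]), "r"(frag_b[1]));
    #pragma unroll
    for (int k = 0; k < 8; k++) {
        s_tile[k] = h2exp2 (s_tile[k]);   // exp2 of data unrelated to the MMA above
    }
}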

I would greatly appreciate any insights or suggestions to help resolve this issue.

Thank you!

If the softmax uses the output of the GEMM, there surely is a data dependency? Have you tried using a low-accuracy exp replacement? I am thinking of Schraudolph’s approach in particular, an FP32 incarnation of which looks like so:

// max rel err = 3.56027983e-2  RMS rel err = 1.82537677e-2
__device__ float schraudolph_expf (float x)
{
    const int FP32_MANT_BITS = 23;
    const int FP32_EXPO_BIAS = 127;
    const float L2E = 1.442695041f;
    const int CORR = -298685;   // correction constant tuned for minimal relative error
    float a = L2E * (1 << FP32_MANT_BITS);
    int b = FP32_EXPO_BIAS * (1 << FP32_MANT_BITS) + CORR;
    // construct the bit pattern of a float whose exponent/mantissa field
    // approximates log2(e) * x plus the exponent bias, i.e. a value near exp(x)
    int i = (int)(a * x) + b;
    return __int_as_float (i);
}

An alternative that can take advantage of FMA would be Mineiro’s variant of the same algorithm:

// max rel err = 3.56061513e-2  RMS rel err = 1.82539111e-2
__device__ float mineiro_expf (float x)
{
    const int FP32_MANT_BITS = 23;
    const float FP32_EXPO_BIAS = 127;
    const float L2E = 1.442695041f;
    const float CORR = -298685;
    float c0 = L2E * (1 << FP32_MANT_BITS);
    float c1 = FP32_EXPO_BIAS * (1 << FP32_MANT_BITS) + CORR;
    int i = (int)(c0 * x + c1);   // c0 * x + c1 maps to a single FMA
    return __int_as_float (i);
}

Either variant should map to three SASS instructions and could conceivably be faster than computing via MUFU instructions, especially on the latest architectures. Before embarking on an FP16 implementation, you might want to bolt the above into your existing code to see whether it can tolerate the reduced accuracy.

N. N. Schraudolph. “A fast, compact approximation of the exponential function.” Neural Computation, 11(4), May 1999, pp. 853-862.

If instead you want to try to get the maximum benefit from the MUFU (multi-function unit), you could experiment with the following:

/* extract least significant 16 bits of an unsigned int into an unsigned short */
__forceinline__ __device__ unsigned short uint2loushort (unsigned int arg)
{
    unsigned short res;
    asm ("{\n\t"
         ".reg .b16 lo, hi;\n\t"
         "mov.b32 {lo, hi}, %1;\n\t"
         "mov.b16 %0, lo;\n\t"
         "}\n\t"
         : "=h"(res) : "r"(arg));
    return res;
}

/* extract most significant 16 bits of an unsigned int into an unsigned short */
__forceinline__ __device__ unsigned short uint2hiushort (unsigned int arg)
{
    unsigned short res;
    asm ("{\n\t"
         ".reg .b16 lo, hi;\n\t"
         "mov.b32 {lo, hi}, %1;\n\t"
         "mov.b16 %0, hi;\n\t"
         "}\n\t"
         : "=h"(res) : "r"(arg));
    return res;
}

/* compute 2**x for both halves of a half2 via a single MUFU ex2 instruction */
__device__ half2 raw_ex2 (half2 arg)
{
    half2 res;
    half hi, lo;
    unsigned short ilo, ihi;
    unsigned int in, out;

    lo = __low2half (arg);
    hi = __high2half (arg);
    ilo = __half_as_ushort (lo);
    ihi = __half_as_ushort (hi);
    in = ((unsigned int)ihi << 16) | ((unsigned int)ilo);
    asm ("ex2.approx.f16x2 %0, %1;\n\t" : "=r"(out) : "r"(in));
    ilo = uint2loushort (out);
    ihi = uint2hiushort (out);
    lo = __ushort_as_half (ilo);
    hi = __ushort_as_half (ihi);
    res = __halves2half2 (lo, hi);
    return res;
}

Thank you for your response!

I’ve read the paper and successfully integrated the Schraudolph exponential function into my implementation.

My setup differs slightly because I'm using a complete FP16 pipeline, since MMA with FP16 accumulators delivers twice the throughput of FP32 accumulators. I've also implemented the softmax in full FP16. In my application, the input values to the softmax are small enough that overflow should not be a concern.

Therefore, I used an FP16 version of the Schraudolph exponential function. However, I haven't observed much of a performance improvement. Nsys indicates that the MFU pipeline is throttled, while the FMA pipelines are fairly underutilized.

When using the standard FP16 base-2 exponential (hexp2), it occupies the MFU pipeline. After switching to the Schraudolph exponential, the float-to-integer conversion (int i = (int)(a * x) + b; in the FP32 version, or __half2short in my FP16 version) also goes through the MFU pipeline.
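
For reference, my FP16 variant looks roughly like the sketch below (simplified, not my exact code; the correction constant is just the FP32 one rescaled by 2^-13 and has not been re-tuned, and it is only valid over the limited input range where exp(x) is a normal FP16 number):

__device__ half schraudolph_hexp (half x)
{
    const int FP16_MANT_BITS = 10;   // FP16: 10 mantissa bits, exponent bias 15
    const int FP16_EXPO_BIAS = 15;
    const half a = __float2half_rn (1.442695041f * (1 << FP16_MANT_BITS));
    const short b = (short)((FP16_EXPO_BIAS << FP16_MANT_BITS) - 36);  // ~ -298685 / 2**13, untuned
    short i = __half2short_rz (__hmul (a, x));   // this conversion is what occupies the MFU pipeline
    return __short_as_half ((short)(i + b));
}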

(I am using Orin, which has relatively poor MFU throughput. Each thread processes the softmax for 128 individual FP16 values at a time. This is part of a larger GEMM-softmax-GEMM pipeline with outer loops, since tiling is employed for the GEMM operations.)

I did find that, even though the MFU is still throttled, the latency of each exponential (and of the softmax overall) has indeed decreased. However, since the softmax is immediately followed by a GEMM in this workload, I suspect throughput matters more than latency for overall performance.

Did you try “__half_as_ushort”?

Bummer. It is news to me that MUFU instructions and numeric conversion instructions go through the same pipeline. I guess it makes sense in terms of re-using hardware: the MUFU uses fixed-point computation at its core (table lookup with quadratic interpolation), but MUFU instructions have FP32 inputs and outputs, so the hardware must include FP32 ↔ INT32 conversion.

There are alternative schemes for floating-point to integer conversion, such as the “add-magic-number” approach, but that only works over a limited range (roughly |x| < 2^22 for FP32). Here we need pretty much the full range of INT32, which would require double-precision arithmetic and is therefore even less suitable for the GPU platform.
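
For reference, a minimal sketch of the add-magic-number conversion in FP32 (it produces the round-to-nearest result and is only valid for |x| < 2^22, which is exactly why it does not help with the Schraudolph conversion):

__device__ int magic_float2int (float x)
{
    // adding 2**23 + 2**22 places the integer part of x in the low mantissa
    // bits; subtracting the magic constant's bit pattern recovers the integer
    const float magic = 12582912.0f;   // 2**23 + 2**22
    return __float_as_int (x + magic) - __float_as_int (magic);
}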

A note on Schraudolph: As I recall, the correction constants in the original paper are not tuned to achieve the smallest possible relative error, something I addressed in the code I showed above.

I don’t see how that is applicable? The Schraudolph method requires float-to-int conversion, not re-interpretation.

Here is an alternate implementation of a low-precision expf without use of MUFU instructions or conversions. This should map to eight FP32 / INT32 instructions on all recent GPU architectures (Volta and newer). Note the restricted input domain!

You should be able to turn this into an FP16 implementation in a straightforward manner.

/*
  Copyright (c) 2024, Norbert Juffa

  Redistribution and use in source and binary forms, with or without 
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright 
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/* Compute exp(x) in [-87.33654, 88.72283]. Maximum relative error 1.726e-3. 
   exp(x) = 2**i * 2**f, where i = rint (log2(e) * x), and -0.5 <= f <= 0.5 
*/
__device__ float very_fast_expf (float x)
{
    const int FP32_MANT_BITS = 23;
    const float l2e = 1.442695041f;
    const float cvt = (1 << FP32_MANT_BITS) + (1 << (FP32_MANT_BITS - 1));
    const float c0 = 0.238428936f;
    const float c1 = 0.703448006f;
    const float c2 = 1.000443142f;
    float f, p, t;
    int i;

    t = fmaf (x, l2e, cvt);
    i = __float_as_int (t);
    t = t - cvt;             // t = rint (l2e * x)
    f = fmaf (x, l2e, -t);   // f = l2e * x - t
    p =             c0;
    p = fmaf (p, f, c1);
    p = fmaf (p, f, c2);
    i = i << FP32_MANT_BITS;
    return __int_as_float (__float_as_int (p) + i);
}
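
For what it is worth, a possible half2 adaptation might look like the sketch below. This is untested and I have not characterized its accuracy; the polynomial coefficients are simply carried over from the FP32 version, and it assumes inputs for which exp(x) is a normal FP16 number (roughly -9.7 <= x <= 11).

__device__ half2 very_fast_h2exp (half2 x)
{
    const int FP16_MANT_BITS = 10;
    const half2 l2e = __float2half2_rn (1.442695041f);   // log2(e)
    const half2 cvt = __float2half2_rn ((float)((1 << FP16_MANT_BITS) +
                                                (1 << (FP16_MANT_BITS - 1))));
    const half2 c0 = __float2half2_rn (0.238428936f);
    const half2 c1 = __float2half2_rn (0.703448006f);
    const half2 c2 = __float2half2_rn (1.000443142f);
    half2 f, p, t;

    t = __hfma2 (x, l2e, cvt);   // low mantissa bits of t now hold i = rint (l2e * x)
    // capture i << 10 for each half (only i's bits survive the 16-bit shift)
    unsigned short ilo = (unsigned short)(__half_as_ushort (__low2half  (t)) << FP16_MANT_BITS);
    unsigned short ihi = (unsigned short)(__half_as_ushort (__high2half (t)) << FP16_MANT_BITS);
    t = __hsub2 (t, cvt);                 // t = rint (l2e * x)
    f = __hfma2 (x, l2e, __hneg2 (t));    // f = l2e * x - t, in [-0.5, 0.5]
    p = c0;                               // p ~= 2**f, same minimax polynomial as above
    p = __hfma2 (p, f, c1);
    p = __hfma2 (p, f, c2);
    // scale each half of p by 2**i by adding i to its exponent field
    half lo = __ushort_as_half ((unsigned short)(__half_as_ushort (__low2half  (p)) + ilo));
    half hi = __ushort_as_half ((unsigned short)(__half_as_ushort (__high2half (p)) + ihi));
    return __halves2half2 (lo, hi);
}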

Thank you for your reply!
__half2short is a value conversion, while __half_as_short is a reinterpretation of the bits. It's true that __half_as_short does not go through the MFU, but the two serve different purposes.

Thank you so much for your response! The fast exp implementations have been very helpful for my workload!

Another bottleneck in the softmax is the reduction step, where each thread needs to compute the sum and the max of an array of data. In the current setup, each thread computes the max and sum of 128 values, and this is repeated across the outer loops.

This operation essentially throttles the ALU. Currently, I perform this using a simple for loop that iterates through all 128 values. I’ve considered using a tree-based reduction with loop unrolling, but since the ALU is already throttled, adding more independent operations may not improve performance. Do you see any potential optimization opportunities in this context?

Thanks!

Instead of merely considering the tree-based reduction, I would suggest actually trying it. In my experience, the world of HPC software engineering is one of actual experiments much more than one of thought experiments.
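
As a concrete starting point, here is a rough sketch of what breaking up the serial dependency chain might look like: a few independent accumulator chains over packed half2 data, combined in a small tree at the end. The data layout (64 half2 per thread) is an assumption on my part, as is the use of __hmax2 / __hmax, which require sm_80 or newer.

__device__ void thread_local_max_sum (const half2 *v,   // 64 x half2 = 128 FP16 values
                                      half *out_max, half *out_sum)
{
    // four independent accumulator chains shorten the serial dependency chain ...
    half2 mx0 = v[0], mx1 = v[1], mx2 = v[2], mx3 = v[3];
    half2 sm0 = v[0], sm1 = v[1], sm2 = v[2], sm3 = v[3];
    #pragma unroll
    for (int i = 4; i < 64; i += 4) {
        mx0 = __hmax2 (mx0, v[i+0]);  sm0 = __hadd2 (sm0, v[i+0]);
        mx1 = __hmax2 (mx1, v[i+1]);  sm1 = __hadd2 (sm1, v[i+1]);
        mx2 = __hmax2 (mx2, v[i+2]);  sm2 = __hadd2 (sm2, v[i+2]);
        mx3 = __hmax2 (mx3, v[i+3]);  sm3 = __hadd2 (sm3, v[i+3]);
    }
    // ... and are combined in a small tree at the end
    half2 mx = __hmax2 (__hmax2 (mx0, mx1), __hmax2 (mx2, mx3));
    half2 sm = __hadd2 (__hadd2 (sm0, sm1), __hadd2 (sm2, sm3));
    *out_max = __hmax (__low2half (mx), __high2half (mx));
    *out_sum = __hadd (__low2half (sm), __high2half (sm));
}

Whether this actually helps under an ALU-throttled profile is exactly the kind of question the profiler will answer faster than speculation.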

You might also want to check whether the Thrust library is applicable to your work to avoid re-inventing the wheel.