Accuracy-optimized implementation of expm1f() without performance penalty

The function expm1(x) := exp(x)-1, often used in conjunction with log1p(), was introduced some forty years ago to avoid the subtractive cancellation that occurs when exp(x)-1 is computed with plain exp() for x near zero.

A typical example can be seen in this recent question on Math Stackexchange, which arose in the context of neural networks.

For a standard math library it is desirable for the maximum error of simple math functions to be less than 1 ulp, making such functions faithfully rounded. An independent report shows the maximum error in CUDA’s current implementation of the single-precision variant expm1f() as < 1.45 ulp:

Brian Gladman, Vincenzo Innocente, John Mather, Paul Zimmermann, Accuracy of Mathematical Functions in Single, Double, Double Extended, and Quadruple Precision. HAL preprint hal-03141101, Feb. 2024 (online).
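For readers unfamiliar with the metric, an error in ulps can be estimated along the following lines. This is a rough sketch in plain C under my own assumptions: the helper name `ulp_error_f32` is hypothetical, the ulp is taken as the spacing of single-precision numbers in the binade of the reference value, and the reference is assumed to be finite and nonzero. The report and other test frameworks may define the ulp slightly differently near powers of two.

```c
#include <math.h>

/* Error of the float result `res` against a higher-precision reference `ref`,
   expressed in units of the float spacing (ulp) in the binade of `ref`.
   Assumes `ref` is finite and nonzero. */
double ulp_error_f32(float res, double ref)
{
    int e;
    frexp(ref, &e);                  /* |ref| = m * 2**e with 0.5 <= m < 1 */
    if (e < -125) e = -125;          /* subnormal range: spacing stays 2**-149 */
    double ulp = ldexp(1.0, e - 24); /* float has 24 significand bits */
    return fabs((double)res - ref) / ulp;
}
```

A function is faithfully rounded when this quantity stays below 1 for all arguments, given a sufficiently accurate reference.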

I reproduced this particular finding of the report: the maximum error of expm1f(x) in CUDA 12.3 was found to be 1.44463616 ulp @ x = 17.0407963f. The following code shows a faithfully-rounded implementation of expm1f() that does not incur a performance penalty compared to the built-in function from CUDA 12.3; on some GPU architectures it is likely to be even a bit faster.

/*
  Copyright (c) 2015-2024, Norbert Juffa
  All rights reserved.

  Redistribution and use in source and binary forms, with or without 
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright 
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/* Compute exponential base e minus 1. Maximum ulp error = 0.997458

   i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
   Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
   With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
   when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.

   NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
*/
__device__ float my_expm1f_scaled_unchecked (float a, float b)
{
    float f, j, r, s, t, u, v, x, y;
    int i;

    // exp(a) = 2**i * exp(f); i = rintf (a / log(2))
    j = fmaf (1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
    i = __float_as_int (j); // trailing bits contain integer
    j = j - 12582912.f; // 0x1.8p23
    f = fmaf (j, -6.93145752e-1f, a); // -0x1.62e400p-1  // log_2_hi 
    f = fmaf (j, -1.42860677e-6f, f); // -0x1.7f7d1cp-20 // log_2_lo

    // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
    s = f * f;
    if (a == 0.0f) s = a; // ensure -0 is passed through
    r = fmaf (1.98662281e-4f, f, 1.39354519e-3f);//0x1.a0a000p-13,0x1.6d4f3cp-10
    t = fmaf (8.33332818e-3f, f, 4.16667648e-2f);//0x1.111106p-7, 0x1.55558ap-5
    r = fmaf (r, s, t);
    r = fmaf (r, f, 1.66666716e-1f); // 0x1.55555cp-3
    r = fmaf (r, f, 4.99999970e-1f); // 0x1.fffffep-2

    // if i == 0, expm1(a) = r*(f*f)+f
    // if i == 1, expm1(a) = 2*(r*(f*f)+f+0.5)
    // if (i < 0) || (i > 1) expm1(a) = 2*((r*(f*f)+f)*t-0.5+t)
    u = (j == 1) ? (f + 0.5f) : f;
    v = fmaf (r, s, u);
    s = 0.5f * b;
    t = __int_as_float ((i << 23) + __float_as_int (s));
    y = t - s;
    x = t - y;
    x = x - s; // double-float canonicalization of difference
    r = fmaf (v, t, x);
    r = r + y;
    r = r + r;
    if (j == 0) r = v;
    if (j == 1) r = v + v;
    return r;
}

/* Compute exponential base e minus 1. max ulp err = 0.99746 */
__device__ float my_expm1f (float a)
{
    float r;

    r = my_expm1f_scaled_unchecked (a, 1.0f);
    /* handle severe overflow and underflow */
    if (fabsf (a - 1.0f) > 88.0f) {
        asm ("ex2.approx.ftz.f32 %0,%1;" : "=f"(r) : "f"(a));
        r = fmaf (r, r, -1.0f);
    }
    return r;
}