A more accurate, performance-competitive implementation of lgammaf()

The report by Vincenzo Innocente and Paul Zimmermann, “Accuracy of Mathematical Functions in Single, Double, Double Extended, and Quadruple Precision”, Preprint hal-03141101, February 2022, highlights the potentially very large error of the lgammaf() function in CUDA 11. While I cannot reproduce the worst-case error of 1.35e7 ulps stated in that report, I find a worst-case error of 3.7366 ulps in the positive half-plane and 7.1824e6 ulps in the negative half-plane.

Accuracy in the positive half-plane is important, whereas I am not aware of a single use case involving lgammaf() in the negative half-plane. Traditionally, math libraries have computed lgammaf() in the negative half-plane by a straightforward application of the reflection formula for the gamma function. This leads to catastrophic cancellation near the zeros of log (abs (gamma (x))), in particular in the interval [-11, -2], where all accuracy is lost.
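For reference, this is the reflection formula and the form in which it is applied to negative arguments. With x = |a|, the second line corresponds to the expression LOG_PI - my_logf (fabsf (fa * my_sinpif (fa))) - r in my_lgammaf below, where r holds log (gamma (x)):

$$\Gamma(x)\,\Gamma(1-x) = \frac{\pi}{\sin(\pi x)}, \qquad \Gamma(1+x) = x\,\Gamma(x)$$

$$\log\lvert\Gamma(-x)\rvert = \log\pi - \log\lvert x\sin(\pi x)\rvert - \log\Gamma(x), \qquad x > 0$$

Near a zero of log|Γ(-x)| the result is much smaller than the individual terms on the right-hand side, so the rounding errors of those terms dominate the difference.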

However, I note that the math library of the Intel toolchain has delivered accurate results for lgammaf() in the negative half-plane for many years, and glibc started doing so in 2015. A reasonable hypothesis is that this extra effort is expended for the benefit of rare use cases. Fixing the accuracy in the negative half-plane has no impact on the performance of the function in the positive half-plane.

Below I am showing a CUDA implementation of lgammaf() that reduces the maximum error to 3.0739 ulps in the positive half-plane and 6.6313e+6 ulps in the negative half-plane. It also has an option NEG_ACCURATE that drops the maximum error in the negative half-plane to 7.1639 ulps. Accurate computation is achieved by mapping inputs from the critical interval to the interval (0, 1) with the recurrence gamma (x+1) = x * gamma (x), computing the gamma function there in double precision, and then taking the logarithm to obtain lgammaf(). I have tried to make the bounding box in which this “slowpath” is activated as tight as possible, so that the double-precision computation is only invoked for about 10 million arguments in the negative half-plane.

Without the NEG_ACCURATE option, the code below is faster than the CUDA 11 built-in by a few percent across the full range of arguments on a Turing GPU, and its code size is smaller. Enabling NEG_ACCURATE (by defining it to a nonzero value) naturally increases the code size significantly and has some negative performance impact for arguments in the negative half-plane. On GPUs with low double-precision throughput, arguments that trigger the slowpath will see a drastic performance drop, so a variant using double-float computation is also provided. It is selected by defining FAST_DP to zero, which is also what happens when FAST_DP is left undefined, since an undefined macro evaluates to 0 in an #if directive.

In general, the trade-offs seem fine to me as-is, as this limited use of double-precision (or double-float) computation makes the frequent case fast, and the infrequent case correct.

[ Code below updated 4/24/2022, 5/19/2022, 5/29/2022, 6/19/2023 ]

/*
  Copyright (c) 2022-2023, Norbert Juffa
  All rights reserved.

  Redistribution and use in source and binary forms, with or without 
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright 
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/* compute natural logarithm. max ulp error = 0.85089 ulp1 = 72569983 */
__device__ float my_logf (float a)
{
    const float INF_F = __int_as_float (0x7f800000);
    const float LOG_TWO = 0.693147182f; // 0x1.62e430p-1
    const float TWO_TO_23 = 8388608.0f; // 0x1.0p+23
    const float TWO_TO_M23 = 1.19209290e-7f; // 0x1.0p-23
    const float TWO_TO_M126 = 1.175494351e-38f; // 0x1.0p-126
    float i, m, r, s, t;
    int e;

    i = 0.0f;
    if (a < TWO_TO_M126) {
        a = a * TWO_TO_23;
        i = -23.0f;
    }
    e = (__float_as_int (a) - __float_as_int (0.666666667f)) & 0xff800000;
    m = __int_as_float (__float_as_int (a) - e);
    i = fmaf ((float)e, TWO_TO_M23, i);
    /* m in [2/3, 4/3] */
    m = m - 1.0f;
    s = m * m;
    /* Compute log1p(m) for m in [-1/3, 1/3] */
    r =             -0.130310059f;  // -0x1.0ae000p-3
    t =              0.140869141f;  //  0x1.208000p-3
    r = fmaf (r, s, -0.121483095f); // -0x1.f19842p-4
    t = fmaf (t, s,  0.139814854f); //  0x1.1e5740p-3
    r = fmaf (r, s, -0.166846171f); // -0x1.55b372p-3
    t = fmaf (t, s,  0.200120345f); //  0x1.99d8b2p-3
    r = fmaf (r, s, -0.249996200f); // -0x1.fffe02p-3
    r = fmaf (t, m, r);
    r = fmaf (r, m,  0.333331972f); //  0x1.5554fap-2
    r = fmaf (r, m, -0.500000000f); // -0x1.000000p-1  
    r = fmaf (r, s, m);
    r = fmaf (i,  LOG_TWO, r);
    if (!((a > 0.0f) && (a < INF_F))) {
        asm ("lg2.approx.ftz.f32 %0,%1;" : "=f"(r) : "f"(a)); // handle zero, negative, NaN, INF
    }
    return r;
}

/* log1p(a) = log(a+1) = log(2**e * t) = log(2)*e + log(t). With t = m + 1,
   log1p(a) = log(2)*e + log1p(m). Choose e such that m is in [-0.25, 0.5], 
   with s = 2**(-e) we then have m = s*(a+1) - 1 = s*a + (s - 1). Instead 
   of using s directly, an intermediate scale factor s' = 4*s is utilized 
   to ensure this is representable as a normalized floating-point number.

   max ulp err = 0.87168 ulp1=43376783
*/
__device__ float my_log1pf (float a)
{
    const float LOG_TWO = 0.693147182f; // 0x1.62e430p-1
    const float TWO_TO_M23 = 1.19209290e-7f; // 0x1.0p-23
    const float INF_F = __int_as_float (0x7f800000);
    float m, r, s, t, u; 
    int e;

    u = a + 1.0f;
    e = (__float_as_int (u) - __float_as_int (0.75f)) & 0xff800000;
    m = __int_as_float (__float_as_int (a) - e);
    s = __int_as_float (__float_as_int (4.0f) - e); // s' in [2**-126,2**26]
    m = m + fmaf (0.25f, s, -1.0f);
    // approximate log(1+m) on [-0.25, 0.5]
    s = m * m;
    r =             -4.54559326e-2f;  // -0x1.746000p-5
    t =              1.05529785e-1f;  //  0x1.b04000p-4
    r = fmaf (r, s, -1.32279143e-1f); // -0x1.0ee85ep-3
    t = fmaf (t, s,  1.44911006e-1f); //  0x1.28c71ap-3
    r = fmaf (r, s, -1.66416913e-1f); // -0x1.54d264p-3
    t = fmaf (t, s,  1.99886635e-1f); //  0x1.995e2ap-3
    r = fmaf (r, s, -2.50001878e-1f); // -0x1.00007ep-2
    r = fmaf (t, m, r);
    r = fmaf (r, m,  3.33335280e-1f); //  0x1.5555d8p-2
    r = fmaf (r, m, -5.00000000e-1f); // -0x1.000000p-1
    r = fmaf (r, s, m);
    r = fmaf ((float)e, LOG_TWO * TWO_TO_M23, r);
    if (!((a != 0.0f) && (u > 0.0f) && (a < INF_F))) {
        asm ("lg2.approx.ftz.f32 %0,%1;" : "=f"(r) : "f"(u));
        r = __fadd_rd (r, a); // handle negative zero
    }
    return r;
}

/* Compute sin(pi*a) accurately. Maximum ulp error = 0.96677 */
__device__ float my_sinpif (float a)
{
    float r, s, t, u;
    int i;

    /* Reduce argument to [-0.25, +0.25] */
    t = rintf (a + a);
    i = (int)t;
    t = fmaf (t, -0.5f, a);
    /* Apply minimax polynomial approximations to reduced argument */
    s = t * t;
    /* Approximate sin(pi*x) for x in [-0.25,0.25] */
    u =             -5.95458984e-1f;  // -0x1.30e000p-1
    u = fmaf (u, s,  2.55037689e+0f); //  0x1.4672c0p+1
    u = fmaf (u, s, -5.16772366e+0f); // -0x1.4abbfcp+2
    u = (s * t) * u;
    u = fmaf (t, 3.14159274e+0f, u);  //  0x1.921fb6p+1
    /* Approximate cos(pi*x) for x in [-0.25,0.25] */
    r =              2.30957031e-1f;  //  0x1.d90000p-3
    r = fmaf (r, s, -1.33492684e+0f); // -0x1.55bdc4p+0
    r = fmaf (r, s,  4.05869961e+0f); //  0x1.03c1bcp+2
    r = fmaf (r, s, -4.93480206e+0f); // -0x1.3bd3ccp+2
    r = fmaf (r, s,  1.00000000e+0f); //  0x1.000000p+0
    /* Select cosine or sine polynomial based on quadrant */
    r = (i & 1) ? r : u;
    /* Don't change "sign" of NaNs or create negative zeros */
    r = (i & 2) ? (0.0f - r) : r; 
    /* IEEE-754: sinPi(+n) is +0 and sinPi(-n) is -0 for positive integers n */
    r = (a == floorf (a)) ? (a * 0.0f) : r;
    return r;
}

#if !FAST_DP
typedef float2 dflt; // double-float, .x = tail, .y = head

/* Compute the product of two floats a, b and return result as double-float */
__device__ dflt mul_float_to_dflt (float a, float b)
{
    dflt z;
    z.y = a * b;
    z.x = fmaf (a, b, -z.y);
    return z;
}

/* Compute product of double-float a with float b */
__device__ dflt mul_dflt_by_float (dflt a, float b)
{
    dflt t, z;
    float e;
    t.y = a.y * b;
    t.x = fmaf (a.y, b, -t.y);
    t.x = fmaf (a.x, b,  t.x);
    z.y = e = t.y + t.x;
    z.x = (t.y - e) + t.x;
    return z;
}

/* Compute product of two double-floats a, b */
__device__ dflt mul_dflt (dflt a, dflt b)
{
    dflt t, z;
    float e;
    t.y = a.y * b.y;
    t.x = fmaf (a.y, b.y, -t.y);
    t.x = fmaf (a.y, b.x, t.x);
    t.x = fmaf (a.x, b.y, t.x);
    z.y = e = t.y + t.x;
    z.x = (t.y - e) + t.x;
    return z;
}

/* Compute sum of two double-floats a, b */
__device__ dflt add_dflt (dflt a, dflt b)
{
    dflt s, t, z;
    float e;
    s.y = (fabsf (a.y) > fabsf (b.y)) ? a.y : b.y;  // famax
    s.x = (fabsf (b.y) < fabsf (a.y)) ? b.y : a.y;  // famin
    t.y = a.y + b.y;
    t.x = t.y - s.y;
    t.x = s.x - t.x;
    t.x = t.x + a.x;
    t.x = t.x + b.x;
    z.y = e = t.y + t.x;
    z.x = (t.y - e) + t.x;
    return z;
}

/* Compute a*b+c, with float b. Assumes |c| > |a*b|. Result is unnormalized! */
__device__ dflt mad_dflt_flt (dflt a, float b, dflt c)
{
    dflt s, t;
    s.y = a.y * b;
    s.x = fmaf (a.y, b, -s.y);
    s.x = fmaf (a.x, b,  s.x);
    t.y = c.y + s.y;
    t.x = t.y - c.y;
    t.x = s.y - t.x;
    t.x = t.x + c.x;
    t.x = t.x + s.x;
    return t;
}

/* Compute the reciprocal of double-float b */
__device__ dflt rcp_dflt (dflt b)
{
    dflt t, z;
    float e, r;
    asm ("rcp.approx.ftz.f32 %0,%1;" : "=f"(r) : "f"(b.y)); // 1.0f / b.y
    t.y = r;
    e = fmaf (b.y, -t.y, 1.0f);
    t.y = fmaf (r, e, t.y);
    t.x = fmaf (b.y, -t.y, 1.0f);
    t.x = fmaf (b.x, -t.y, t.x);
    e = r * t.x;
    t.x = fmaf (b.y, -e, t.x);
    t.x = fmaf (r, t.x, e);
    z.y = e = t.y + t.x;
    z.x = (t.y - e) +  t.x;
    return z;
}

/* Compute absolute value of double-float a */
__device__ dflt abs_dflt (dflt a)
{
    dflt z;
    z.x = (a.y <= 0.0f) ? (0.0f - a.x) : a.x;
    z.y = (a.y <= 0.0f) ? (0.0f - a.y) : a.y;
    return z;
}

/* Construct a double-float from head, tail data (assumed to be normalized) */
__device__ dflt mk_dflt (float head, float tail)
{
    dflt z;
    z.x = tail;
    z.y = head;
    return z;
}

__device__ float head_dflt (dflt a)
{
    return a.y;
}

__device__ float tail_dflt (dflt a)
{
    return a.x;
}
#endif // FAST_DP

/* Compute logarithm of the absolute value of the gamma function.
   [-INF, -10.0001]:    maximum ulp error = 5.2886
   (-10.0001, -2.2699): maximum ulp error = 7.1639 (NEG_ACCURATE=1), 6.6313e+6 (NEG_ACCURATE=0)
   [-2.2699, -0]:       maximum ulp error = 5.3067
   [0, 0.64453125]:     maximum ulp error = 2.8788
   [0.64453125, 1.5]:   maximum ulp error = 2.8771
   [1.5, 3.5859375]:    maximum ulp error = 2.8815
   [3.5859375, +INF]:   maximum ulp error = 3.0739
*/
__device__ float my_lgammaf (float a)
{
    const float LOG_PI = 1.14472988f; // 0x1.250d04p+0
    const float LOG_SQRT_2PI = 0.91893853f; // 0x1.d67f1cp-1
    const float TWO_TO_24 = 16777216.0f; // 0x1.0p24
    const float TWO_TO_M21 = 4.76837158e-7f; // 0x1.0p-21
    const float INF_F = __int_as_float (0x7f800000);
    float r, s, t, p, e, fa;

    fa = fabsf (a);
    t = fa - 1.0f;
    if (fa >= 3.5859375f) { // [3.5859375, INF]: Stirling formula
        /* log(Gamma(x)) ~= log(sqrt(2*pi))+(x-0.5)*log(x)-x+(1/x)*P(1/x**2) */
        asm ("rcp.approx.ftz.f32 %0,%1;" : "=f"(r) : "f"(fa)); // 1.0f / fa
        s = r * r; // 1/(fa*fa)
        p =             -4.13298607e-4f;  // -0x1.b16000p-12
        p = fmaf (p, s,  7.77681125e-4f); //  0x1.97ba98p-11
        p = fmaf (p, s, -2.77749589e-3f); // -0x1.6c0d4cp-9
        p = fmaf (p, s,  8.33333135e-2f); //  0x1.555550p-4
        e = my_log1pf (t); // log(fa)
        r = fmaf (p, r, fmaf (e, fa - 0.5f, -fa)) + LOG_SQRT_2PI;
        if (fa == INF_F) r = fa; // handle argument of infinity
    } else if (fa >= 1.5f) { // [1.5,3.5859375)
        t = fa - 2.0f;
        p =              2.25007534e-5f;  //  0x1.798000p-16
        p = fmaf (p, t, -1.53736837e-4f); // -0x1.4268d6p-13
        p = fmaf (p, t,  5.05698787e-4f); //  0x1.0921bep-11
        p = fmaf (p, t, -1.23138563e-3f); // -0x1.42cce4p-10
        p = fmaf (p, t,  2.90764729e-3f); //  0x1.7d1c74p-9
        p = fmaf (p, t, -7.37980939e-3f); // -0x1.e3a4a8p-8
        p = fmaf (p, t,  2.05773599e-2f); //  0x1.5123b4p-6
        p = fmaf (p, t, -6.73526227e-2f); // -0x1.13e058p-4
        p = fmaf (p, t,  3.22467178e-1f); //  0x1.4a34d6p-2
        p = fmaf (p, t,  4.22784328e-1f); //  0x1.b0ee60p-2
        r = p * t;
    } else if (fa >= 0.64453125f) { // [0.64453125, 1.5) 
        p =              6.74438477e-2f;  //  0x1.144000p-4
        p = fmaf (p, t, -1.31759942e-1f); // -0x1.0dd828p-3
        p = fmaf (p, t,  1.39101863e-1f); //  0x1.1ce170p-3
        p = fmaf (p, t, -1.43345997e-1f); // -0x1.259296p-3
        p = fmaf (p, t,  1.68158934e-1f); //  0x1.5863b6p-3
        p = fmaf (p, t, -2.07328245e-1f); // -0x1.a89bb6p-3
        p = fmaf (p, t,  2.70633340e-1f); //  0x1.1520e8p-2
        p = fmaf (p, t, -4.00688589e-1f); // -0x1.9a4e1cp-2
        p = fmaf (p, t,  8.22466493e-1f); //  0x1.a51a54p-1
        p = fmaf (p, t, -5.77215672e-1f); // -0x1.2788d0p-1
        r = fmaf (p, t, 0.0f); // don't generate spurious negative zero
    } else { // [-0x1.0p-21, 0.64453125)
        /* log(Gamma(x)) ~= log (1 / (x + x**2 * P(x))) */
        p =               2.75421143e-3f;  //  0x1.690000p-9
        p = fmaf (p, fa, -5.20337597e-2f); // -0x1.aa42b4p-5
        p = fmaf (p, fa,  1.70620918e-1f); //  0x1.5d6e80p-3
        p = fmaf (p, fa, -4.28910032e-2f); // -0x1.5f5cf4p-5
        p = fmaf (p, fa, -6.55785263e-1f); // -0x1.4fc316p-1
        p = fmaf (p, fa,  5.77212155e-1f); //  0x1.27885ap-1
        p = fmaf (p * fa, fa, fa);
        r = 0.0f - my_logf (p);
    }
    if (a < -TWO_TO_M21) {
        r = LOG_PI - my_logf (fabsf (fa * my_sinpif (fa))) - r;
        if (fa >= TWO_TO_24) r = INF_F;
#if NEG_ACCURATE
        if ((fa > 2.2699f) && (fa < 10.0001f) && (fa != floorf (fa)) && 
            (r > -3.8977f) && (r < 1.9937f)) { // vicinity of zero crossings
            /* scale argument into (0,1) */
#if FAST_DP  
            double da = (double)a;
            double pr = da;
            da = da + 1.0;
#pragma unroll 1
            for (int i = 0; i < (int)fa; i++) {
                pr = pr * da;
                da = da + 1.0;
            }
            /* compute abs(gamma(da)) for da in (0, 1) */
            double pp =        1.6784122146158307e-8;
            pp = fma (pp, da, -2.1600292130263686e-7);
            pp = fma (pp, da,  1.1295363349961504e-6);
            pp = fma (pp, da, -1.2258690654347639e-6);
            pp = fma (pp, da, -2.0171353502240220e-5);
            pp = fma (pp, da,  1.2808205617755296e-4);
            pp = fma (pp, da, -2.1526018690421332e-4);
            pp = fma (pp, da, -1.1651601138724383e-3);
            pp = fma (pp, da,  7.2189411499234959e-3);
            pp = fma (pp, da, -9.6219711282132837e-3);
            pp = fma (pp, da, -4.2197734605117540e-2);
            pp = fma (pp, da,  1.6653861138599796e-1);
            pp = fma (pp, da, -4.2002635034240961e-2);
            pp = fma (pp, da, -6.5587807152025157e-1);
            pp = fma (pp, da,  5.7721566490153287e-1);
            pp = fma (pp, da * da, da);
            double g = 1.0 / (pp * fabs (pr));
            /* compute log(abs(gamma(a))) */
            r = my_log1pf ((float)(g - 1.0));
#else // FAST_DP
            dflt pr = mk_dflt (a, 0.0f);
            a = a + 1.0f;
#pragma unroll 1
            do {
                pr = mul_dflt_by_float (pr, a);
                a = a + 1.0f;
            } while (a < 0.0f);
            dflt q =                mk_dflt( 1.67841225e-8f, -3.90092527e-16f);
            q = mad_dflt_flt (q, a, mk_dflt(-2.16002917e-7f, -4.41635482e-15f));
            q = mad_dflt_flt (q, a, mk_dflt( 1.12953637e-6f, -3.69948280e-14f));
            q = mad_dflt_flt (q, a, mk_dflt(-1.22586903e-6f, -3.78135362e-14f));
            q = mad_dflt_flt (q, a, mk_dflt(-2.01713538e-5f,  2.51299524e-13f));
            q = mad_dflt_flt (q, a, mk_dflt( 1.28082058e-4f, -1.74925092e-12f));
            q = mad_dflt_flt (q, a, mk_dflt(-2.15260181e-4f, -5.91027617e-12f));
            q = mad_dflt_flt (q, a, mk_dflt(-1.16516009e-3f, -1.94846084e-11f));
            q = mad_dflt_flt (q, a, mk_dflt( 7.21894111e-3f,  3.50806051e-11f));
            q = mad_dflt_flt (q, a, mk_dflt(-9.62197129e-3f,  1.58620006e-10f));
            q = mad_dflt_flt (q, a, mk_dflt(-4.21977341e-2f, -4.87609619e-10f));
            q = mad_dflt_flt (q, a, mk_dflt( 1.66538611e-1f,  3.31577488e-10f));
            q = mad_dflt_flt (q, a, mk_dflt(-4.20026332e-2f, -1.82024407e-09f));
            q = mad_dflt_flt (q, a, mk_dflt(-6.55878067e-1f, -4.50364990e-09f));
            q = mad_dflt_flt (q, a, mk_dflt( 5.77215672e-1f, -6.63777389e-09f));
            q = mul_dflt (q, mul_float_to_dflt (a, a));
            q = add_dflt (q, mk_dflt (a, 0.0f));
            dflt g = rcp_dflt (mul_dflt (q, abs_dflt (pr)));
            /* compute log(abs(gamma(a))) */
            r = my_log1pf (head_dflt (add_dflt (g, mk_dflt (-1.0f, 0.0f))));
#endif // FAST_DP
        }
#endif // NEG_ACCURATE
    }
    return r;
}
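The ulp error figures quoted in the comments compare each float result against a more accurate reference value. The test harness used for them is not shown. One common way to compute the error of a single result, assuming a finite double-precision reference (an assumption here, not a stated part of the author's method), is sketched below; float_ulp and ulp_error are hypothetical names.

```c
#include <math.h>

/* Spacing of floats (one ulp) in the binade of the finite reference value ref */
static double float_ulp (double ref)
{
    int e;
    if (ref == 0.0) return ldexp (1.0, -149);  /* smallest float subnormal */
    frexp (ref, &e);                           /* |ref| = m * 2**e, m in [0.5,1) */
    int k = e - 24;                            /* float has 24 significand bits */
    if (k < -149) k = -149;                    /* subnormal range: fixed spacing */
    return ldexp (1.0, k);
}

/* Error of float result res, in ulps, relative to double reference ref */
static double ulp_error (float res, double ref)
{
    return fabs ((double)res - ref) / float_ulp (ref);
}
```

Tracking the maximum of ulp_error over all float inputs of an interval gives per-interval figures of the kind listed in the comment header of my_lgammaf, although the exact reference and methodology behind the quoted numbers are not stated.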
