Gamma function implementation with significantly improved accuracy

A recent publication compares the accuracy of standard math library functions across multiple toolchains, among them CUDA 11:

Vincenzo Innocente and Paul Zimmermann, “Accuracy of mathematical functions in single, double, extended double and quadruple precision”, HAL Preprint hal-03141101v2, January 2022

The data collected for CUDA seems to jibe nicely with the accuracy data published in the CUDA Programming Guide. One function that sticks out negatively is the single-precision gamma function, tgammaf(). I double-checked its accuracy in CUDA 11 and found a maximum error of 10.2368 ulp in the positive half-plane and a maximum error of 11.4591 ulp in the negative half-plane.

Even with a simple, straightforward implementation the maximum error can be reduced significantly. To prove the point, I whipped up a new implementation, my_tgammaf(), with significantly improved error bounds. Maximum error drops to 2.5583 ulp in the positive half-plane and 5.2922 ulp in the negative half-plane. Substituting my_tgammaf() for CUDA 11’s built-in tgammaf() should generally be roughly performance neutral. In fact, on an sm_75 platform I see minimum, average, and maximum execution times all reduced by about 6%–7% when making the switch. But since the function has some argument-dependent branches, performance will tend to fluctuate a bit depending on argument distribution.

I tested exhaustively with -ftz={true|false} and -prec-div={true|false}.

[Code below updated 3/21/2022, 4/7/2022]

/*
  Copyright (c) 2022, Norbert Juffa
  All rights reserved.

  Redistribution and use in source and binary forms, with or without 
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright 
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/* Returns 1 if 'a' is an even integer (including +/-0), otherwise 0.
   Uses a - 2*floor(a/2): the doubling is exact, so the difference is zero
   exactly when a is a multiple of two. Infinities and NaN produce NaN in the
   difference and therefore report 0.
*/
__device__ int is_even_int (float a)
{
    float half_floor = floorf (0.5f * a);
    float remainder = a - 2.0f * half_floor;
    return remainder == 0.0f;
}

/* Exponential function e**a in single precision. Maximum ulp error = 0.86565.

   Method:
   - Write a = k*log(2) + f with k = rint(a/log(2)).
   - Approximate exp(f) on [-log(2)/2, +log(2)/2] with a degree-6 polynomial.
   - Scale by 2**k as the product of two power-of-two factors, chosen so that
     neither factor overflows or underflows on its own.
   - Arguments with |a| >= 104 are forced to +INF (a > 0) or 0 (a < 0) by
     squaring the first scale factor.
*/
__device__ float my_expf (float a)
{
    const float ROUND_MAGIC = 12582912.f; // 0x1.8p23, forces rounding to integer

    /* k = rint (a / log(2)) */
    float k = fmaf (1.442695f, a, ROUND_MAGIC) - ROUND_MAGIC; // 0x1.715476p0
    /* f = a - k * log(2), with log(2) split into high and low parts */
    float f = fmaf (k, -6.93145752e-1f, a); // -0x1.62e400p-1  // log_2_hi
    f = fmaf (k, -1.42860677e-6f, f);       // -0x1.7f7d1cp-20 // log_2_lo
    int ik = (int)k;

    /* p = exp(f) on [-log(2)/2, +log(2)/2], Horner scheme */
    float p = 1.37805939e-3f;          // 0x1.694000p-10
    p = fmaf (p, f, 8.37312452e-3f);   // 0x1.125edcp-7
    p = fmaf (p, f, 4.16695364e-2f);   // 0x1.555b5ap-5
    p = fmaf (p, f, 1.66664720e-1f);   // 0x1.555450p-3
    p = fmaf (p, f, 4.99999851e-1f);   // 0x1.fffff6p-2
    p = fmaf (p, f, 1.00000000e+0f);   // 0x1.000000p+0
    p = fmaf (p, f, 1.00000000e+0f);   // 0x1.000000p+0

    /* 2**ik = scale_a * scale_b.
       ik > 0:  scale_a = 2**127,  scale_b = 2**(ik-127)
       ik <= 0: scale_a = 2**-123, scale_b = 2**(ik+123) */
    unsigned int bias;
    if (ik > 0) {
        bias = 0u;
    } else {
        bias = 0x83000000u;
    }
    float scale_a = __int_as_float (0x7f000000u + bias);
    float scale_b = __int_as_float (((unsigned int)ik << 23) - bias);
    float result = (p * scale_a) * scale_b;

    /* severe overflow / underflow: 2**127 squared -> INF, 2**-123 squared -> 0 */
    if (fabsf (a) >= 104.0f) {
        result = scale_a * scale_a;
    }
    return result;
}

/* Compute sin(pi*a) accurately. Maximum ulp error = 0.96677.

   Method:
   - Write a = q/2 + x with q = rint(2a) and x in [-0.25, +0.25].
   - Evaluate either a sine or a cosine minimax polynomial in x, selected by
     the quadrant q mod 4.
   - Integer arguments return a zero carrying the sign of a (a * 0.0f).
     Infinities therefore return NaN, and NaN propagates.
*/
__device__ float my_sinpif (float a)
{
    /* Argument reduction; q may be out of int range for huge |a|, but every
       such a is an integer and is overridden by the final integer check. */
    float q = rintf (a + a);
    int quadrant = (int)q;
    float x = fmaf (q, -0.5f, a);
    float x2 = x * x;

    /* sin(pi*x) ~= pi*x + x**3 * P(x**2), x in [-0.25,0.25] */
    float sin_poly = -5.95458984e-1f;                // -0x1.30e000p-1
    sin_poly = fmaf (sin_poly, x2,  2.55037689e+0f); //  0x1.4672c0p+1
    sin_poly = fmaf (sin_poly, x2, -5.16772366e+0f); // -0x1.4abbfcp+2
    sin_poly = (x2 * x) * sin_poly;
    sin_poly = fmaf (x, 3.14159274e+0f, sin_poly);   //  0x1.921fb6p+1

    /* cos(pi*x) ~= Q(x**2), x in [-0.25,0.25] */
    float cos_poly = 2.30957031e-1f;                 //  0x1.d90000p-3
    cos_poly = fmaf (cos_poly, x2, -1.33492684e+0f); // -0x1.55bdc4p+0
    cos_poly = fmaf (cos_poly, x2,  4.05869961e+0f); //  0x1.03c1bcp+2
    cos_poly = fmaf (cos_poly, x2, -4.93480206e+0f); // -0x1.3bd3ccp+2
    cos_poly = fmaf (cos_poly, x2,  1.00000000e+0f); //  0x1.000000p+0

    /* Odd quadrants use the cosine polynomial */
    float result;
    if (quadrant & 1) {
        result = cos_poly;
    } else {
        result = sin_poly;
    }
    /* Quadrants 2 and 3 negate; 0 - r (not -r) keeps NaN "sign" and avoids -0 */
    if (quadrant & 2) {
        result = 0.0f - result;
    }
    /* IEEE-754: sinPi(+n) is +0 and sinPi(-n) is -0 for positive integers n */
    if (a == floorf (a)) {
        result = a * 0.0f;
    }
    return result;
}

/* Compute log(a) with extended precision, returned as a double-float value
   loghi:loglo. The pair is normalized at the end so that loghi is the rounded
   sum and loglo the remaining low-order part. Maximum relative error about
   3.8324e-9.

   Method: a = m * 2**i with m in [sqrt(0.5), sqrt(2)), then
   log(a) = 2 * atanh((m-1)/(m+1)) + i * log(2), with the quotient and the
   i*log(2) term carried in double-float arithmetic.

   Precondition: intended for positive, finite a (my_tgammaf only passes
   |a| >= 1). Zero, negative, infinite and NaN inputs are not special-cased;
   for example +INF yields 128*log(2) rather than INF.
*/
__device__ void my_logf_ext_tgamma (float a, float *loghi, float *loglo)
{
    const float LOG2_HI =  6.93147182e-1f; //  0x1.62e430p-1
    const float LOG2_LO = -1.90465421e-9f; // -0x1.05c610p-29
    float m, r, i, s, t, p, qhi, qlo;
    int e;

    /* Reduce argument to m in [sqrt(0.5), sqrt(2.0)] */
    i = 0.0f;
    /* fix up denormal inputs: scale into the normal range, compensate in i */
    if (a < 1.175494351e-38f){ // 0x1.0p-126
        a = a * 8388608.0f; // 0x1.0p+23
        i = -23.0f;
    }
    /* e = exponent-field adjustment (already shifted to bit 23) mapping a to m */
    e = (__float_as_int (a) - __float_as_int (0.70710678f)) & 0xff800000;
    m = __int_as_float (__float_as_int (a) - e);
    i = fmaf ((float)e, 1.19209290e-7f, i); // 0x1.0p-23
    /* Compute q = (m-1)/(m+1) as a double-float qhi:qlo */
    p = m + 1.0f;
    m = m - 1.0f;
    asm ("rcp.approx.ftz.f32 %0,%1;" : "=f"(r) : "f"(p)); // 1.0f / p
    qhi = r * m;
    /* residual (m-1) - qhi*(m+1), computed as m - 2*qhi - qhi*m (m is now m-1) */
    qlo = r * fmaf (qhi, -m, fmaf (qhi, -2.0f, m));
    /* Approximate atanh(q), q in [sqrt(0.5)-1, sqrt(2)-1] */
    s = qhi * qhi;
    r =             0.129211426f;  // 0x1.08a000p-3
    r = fmaf (r, s, 0.142010465f); // 0x1.22d662p-3
    r = fmaf (r, s, 0.200014427f); // 0x1.99a12ap-3
    r = fmaf (r, s, 0.333333254f); // 0x1.555550p-2
    r = r * s;
    s = fmaf (r, qhi, fmaf (r, qlo, qlo)); // (q**2*r)*(qhi:qlo) + qlo
    r = 2.0f * qhi;
    /* log(a) = 2 * atanh(q) + i * log(2) */
    t = fmaf ( LOG2_HI, i, r);
    p = fmaf (-LOG2_HI, i, t);             // part of r that made it into t
    s = fmaf ( LOG2_LO, i, fmaf (2.f, s, r - p));
    *loghi = p = t + s;    // normalize double-float result
    *loglo = (t - p) + s;
}

/* Compute the gamma function. Maximum error positive half-plane: 2.55831 ulp.
   Maximum error negative half-plane: 5.29219 ulp.

   Approximations, with y = |x|:
   |x| >= 1:  gamma(y) = sqrt(2*pi)*pow(y,y-0.5)*exp(-y)*(1+(1/y)*poly(1/y))
   |x| < 1:   gamma(b) = 1.0 / (b + b**2 * P(b)),
              where b = x for x >= -0.03125, else b = |x|

   Negative arguments:
   x < -0.03125: gamma(x) = pi / (sin (pi*x) * |x| * gamma (|x|))  (reflection)

   Special cases:
   x > OVERFLOW_BOUND                    -> +INF
   x < UNDERFLOW_BOUND                   -> signed zero
   negative integers and -INF            -> NaN
   +0 / -0                               -> +INF / -INF
   NaN                                   -> propagates through the |x| < 1 path
*/
__device__ float my_tgammaf (float a)
{
    float fa, b, e, f, r, s, t, loglo, loghi, thi, tlo;
    const float PI = 3.14159265f;
    const float SQRT_2PI_HI = 2.50662804e+0f; // 0x1.40d930p+1
    const float SQRT_2PI_LO = 2.38131975e-7f; // 0x1.ff6270p-23
    const float CANONICAL_NAN = __int_as_float (0x7fffffffu);
    const float CENTRAL_INTVL_UBOUND = 1.0f;
    const float CENTRAL_INTVL_LBOUND = -0.03125f;
    // below this, only sqrt of the exp() factor is formed; it is divided out twice
    const float PREMATURE_UNDERFLOW_BOUND = -34.1875f;
    // below this, the result is returned as a signed zero
    const float UNDERFLOW_BOUND = -41.125f;
    // above this, the result is returned as +INF
    const float OVERFLOW_BOUND = 35.0400963f;

    fa = fabsf (a);
    if (fa >= CENTRAL_INTVL_UBOUND) { // Stirling (Laplace) formula for gamma(|a|)
        asm ("rcp.approx.ftz.f32 %0,%1;" : "=f"(r) : "f"(fa)); // 1.0f / fa
        /* compute log(a) as a normalized double-float */
        my_logf_ext_tgamma (fa, &loghi, &loglo);
        /* compute (a - 0.5) * log(a) as an unnormalized double-float thi:tlo */
        b = fa - 0.5f;
        thi = loghi * b;
        tlo = fmaf (loghi, b, -thi);
        tlo = fmaf (loglo, b, tlo);
        /* compute ((a - 0.5) * log(a) - a) as unnormalized double-float t:s
           (error-free sum of thi and -fa, plus tlo) */
        t = thi - fa;
        e = t + fa;
        f = (e - t) - fa;
        e = thi - e;
        s = (e + f) + tlo;
        /* compute exp((a - 0.5) * log(a) - a) */
        if (a < PREMATURE_UNDERFLOW_BOUND) {
            /* compute sqrt (exp((a-0.5) * log(a) - a)) to prevent underflow
               of the final reflected result; the second factor is divided
               out again in the reflection step below */
            t = t * 0.5f;
            s = s * 0.5f;
        }
        e = my_expf (t);
        e = fmaf (s, e, e);        // exp(t + s) ~= exp(t) * (1 + s)
        /* gamma(a) = (1+(1/a)*poly(1/a))*sqrt(2*pi)*exp((a-0.5)*log(a)-a) */
        s =             -7.90506601e-6f;  // -0x1.094000p-17
        s = fmaf (s, r,  1.54409601e-4f); //  0x1.43d206p-13
        s = fmaf (s, r, -6.40788290e-4f); // -0x1.4ff526p-11
        s = fmaf (s, r,  1.10814150e-3f); //  0x1.227e1ep-10
        s = fmaf (s, r, -3.10655858e-4f); // -0x1.45bf0cp-12
        s = fmaf (s, r, -2.67055770e-3f); // -0x1.5e090cp-9
        s = fmaf (s, r,  3.47157661e-3f); //  0x1.c706c8p-9
        s = fmaf (s, r,  8.33333358e-2f); //  0x1.555556p-4
        t = fmaf (e, SQRT_2PI_HI, e * SQRT_2PI_LO);
        r = fmaf (r * s, t, t);
        if (a > OVERFLOW_BOUND) r = INFINITY;
    } else { // near 0, NaNs
        /* for a in (-1, -0.03125) evaluate gamma(|a|) and reflect below */
        b = (a < CENTRAL_INTVL_LBOUND) ? fa : a;
        /* compute gamma(b) = 1 / (b * poly (b)) */
        r =             -1.08385086e-3f;  // -0x1.1c2000p-10
        r = fmaf (r, b,  6.55601546e-3f); //  0x1.ada7b0p-8
        r = fmaf (r, b, -8.75671301e-3f); // -0x1.1ef0a2p-7
        r = fmaf (r, b, -4.27411087e-2f); // -0x1.5e229ap-5
        r = fmaf (r, b,  1.66718066e-1f); //  0x1.557048p-3
        r = fmaf (r, b, -4.20317464e-2f); // -0x1.5852f6p-5
        r = fmaf (r, b, -6.55876338e-1f); // -0x1.4fcf06p-1
        r = fmaf (r, b,  5.77215672e-1f); //  0x1.2788d0p-1
        r = fmaf (fmaf (r, b, 0.0f), b, b); // handle -0 properly: b=-0 yields -0, so 1/r = -INF
        asm ("rcp.approx.f32 %0,%0;" : "+f"(r)); // 1.0f / r (non-.ftz: subnormal r not flushed)
    }
    if (a < CENTRAL_INTVL_LBOUND) {
        t = truncf (a);
        /* reflection: gamma(a) = pi / (sin(pi*a) * |a| * gamma(|a|)) */
        r = __fdiv_rn (PI, my_sinpif (a) * fa * r);
        if (a < PREMATURE_UNDERFLOW_BOUND) {
            /* divide out the second sqrt factor; e was set in the |a| >= 1
               branch, which is always taken when a < PREMATURE_UNDERFLOW_BOUND */
            r = __fdividef (r, e);
        }
        if (a < UNDERFLOW_BOUND) {
            /* underflow: gamma(a) is negative when trunc(a) is even */
            r = __int_as_float (((unsigned int)is_even_int (t)) << 31);
        }
        if (t == a) {
            /* poles at negative integers; also catches a = -INF */
            r = CANONICAL_NAN;
        }
    }
    return r;
}
2 Likes