Fairly accurate and robust implementation of the logarithm of the beta function

The beta function and its logarithm are useful in various statistical computations. For a, b > 0, the beta function is defined as B(a,b) = Γ(a) * Γ(b) / Γ(a+b). When computing with IEEE-754 floating-point arithmetic, evaluating this expression directly overflows easily for arguments of modest magnitude. This is why much practical computation is performed in logarithmic space, i.e. it involves log (B(a,b)).
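To make the overflow concrete, here is a small host-side sketch in plain C++ (which nvcc also compiles). The function name beta_direct is hypothetical. For a = b = 20 the intermediate Γ(40) = 39! ≈ 2.04·10^46 exceeds the binary32 maximum of about 3.4·10^38, so the quotient collapses to zero, although B(20,20) ≈ 7.25·10^-13 is comfortably representable and log (B(20,20)) ≈ -27.952.

```cpp
#include <cmath>

// Direct evaluation of B(a,b) = Γ(a) * Γ(b) / Γ(a+b) in binary32 arithmetic.
// Each gamma value is a separate intermediate result, so the computation
// fails as soon as any of them overflows, regardless of the size of B(a,b).
float beta_direct(float a, float b)
{
    float ga  = std::tgamma(a);      // Γ(a)
    float gb  = std::tgamma(b);      // Γ(b)
    float gab = std::tgamma(a + b);  // Γ(a+b): the largest, overflows first
    return ga * gb / gab;
}
```

With these arguments, beta_direct(20.0f, 20.0f) returns 0.0f: the finite numerator Γ(20)² ≈ 1.48·10^34 is divided by an infinite Γ(40).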

One might think that all that is needed is to compose the logarithm of the beta function from the logarithm of the gamma function, for which C++ and CUDA offer the standard math function lgamma(). However, even this is prone to overflow in intermediate computation and can give rise to accuracy issues. In the 1990s Nico Temme showed how the scaled gamma function Γ* (see section 5.11.3 in NIST's Digital Library of Mathematical Functions) can be used advantageously to largely avoid these problems, and this is the approach I adopted for the single-precision lbetaf() implementation below.
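Both problems are easy to reproduce with a naive lgamma()-based composition. The host-side C++ sketch below uses a hypothetical function name, lbeta_naive, and the host library's lgamma rather than CUDA's lgammaf(). For log (B(1, 10^6)) = -ln (10^6) ≈ -13.8155, the two large terms lgamma(10^6) ≈ 1.28·10^7 and lgamma(10^6 + 1) are binary32 numbers in a range where adjacent floats are 1.0 apart, so their difference is necessarily an integer and most significant digits are lost to cancellation. For a = b = 10^37, lgamma() itself overflows and inf - inf produces NaN, although log (B(a,b)) ≈ -1.39·10^37 is representable.

```cpp
#include <cmath>

// Naive log(B(a,b)) composed from the logarithm of the gamma function.
// Shown only to illustrate its failure modes; not a recommended approach.
float lbeta_naive(float a, float b)
{
    return (std::lgamma(a) + std::lgamma(b)) - std::lgamma(a + b);
}
```

Here lbeta_naive(1.0f, 1.0e6f) yields an integer (such as -13) instead of -13.8155, and lbeta_naive(1.0e37f, 1.0e37f) yields NaN.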

A few caveats remain. Like all other implementations of the logarithm of the beta function that I have come across, lbetaf() has poor relative accuracy for arguments in the immediate neighborhood of unity, say 1 ± 1/16, where function results approach zero. Good absolute accuracy, however, is maintained. In terms of robustness, note that all accuracy is lost when min (a, b) / (a+b) underflows to zero in finite-precision floating-point arithmetic, e.g. when computing lbetaf (1e-30f, 1e+30f), even though the mathematical result of log (B(a,b)) is well within the representable range of the IEEE-754 binary32 (single precision) format. Again, this defect is shared with other widely used implementations of the logarithm of the beta function, and I am not aware of any use case where this causes issues.
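The underflow can be traced with a few lines of host-side C++ that mirror how lbetaf() forms the ratio, namely (mn/2) / (mx/2 + mn/2), so that the denominator cannot overflow. The function name min_over_sum is hypothetical. For lbetaf (1e-30f, 1e+30f) the exact ratio is 10^-60, far below the smallest binary32 subnormal of about 1.4·10^-45, so it rounds to zero and its logarithm is -INF. The true result, by contrast, is log (B(a,b)) ≈ -ln (10^-30) ≈ 69.08, since B(a,b) ≈ Γ(a)·b^(-a) ≈ 1/a for these arguments.

```cpp
#include <cmath>

// Reproduces the ratio mn / (mn + mx) the way lbetaf() computes it,
// i.e. with halved operands so that the denominator cannot overflow.
float min_over_sum(float x, float y)
{
    float mn = std::fmin(x, y);
    float mx = std::fmax(x, y);
    float half_mn = 0.5f * mn;
    return half_mn / std::fma(0.5f, mx, half_mn);
}
```

In lbetaf(), a zero ratio makes the term (mn - 0.5) * logf (mn_over_sum) infinite, and that infinity propagates to the result.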

The code below was tested with the default nvcc compiler settings, that is -ftz=false -prec-div=true -prec-sqrt=true. Adjustments might be needed to maintain accuracy and robustness with other settings. Apart from the accuracy issues noted above, I see a maximum error of around 8 ulp, with the average error closer to 3 ulp.
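Error figures of this kind are commonly obtained by comparing each binary32 result with a higher-precision reference and expressing the difference in units of the spacing of binary32 numbers at the reference value. The host-side C++ sketch below shows one common way to define that measure. The name ulp_error and the choice of a double-precision reference are assumptions for illustration, not a description of the test framework behind the numbers above.

```cpp
#include <cfloat>
#include <cmath>

// Error of a binary32 result 'res' relative to a finite higher-precision
// reference 'ref', in units of the binary32 ulp at the magnitude of 'ref'.
double ulp_error(float res, double ref)
{
    double aref = std::fabs(ref);
    // Exponent of the binade containing |ref|. Subnormals have the same
    // spacing as the smallest normal binade, 2^-149.
    int e = (aref < FLT_MIN) ? (FLT_MIN_EXP - 1) : std::ilogb(aref);
    double ulp = std::ldexp(1.0, e - (FLT_MANT_DIG - 1));
    return std::fabs(static_cast<double>(res) - ref) / ulp;
}
```

Tracking the maximum and the mean of such a measure over many argument pairs gives statistics of the kind quoted above.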

NOTE: Computing the beta function B(a,b) as expf(lbetaf(a,b)) is not advisable in the general case due to the well-known error magnification properties of exponentiation.
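The magnification follows from d(e^y) = e^y dy: an absolute error δ in y = lbetaf(a,b) becomes a relative error of about δ in exp(y). Because the absolute size of a one-ulp error in y grows with |y|, results with large |y| are affected most. For y ≈ -80, one ulp of y is 2^-17, so even a one-ulp error in lbetaf() turns into a relative error of about 2^-17 in exp(y), which is 64 to 128 binary32 ulp depending on where the result falls within its binade. The host-side C++ sketch below (hypothetical function name) evaluates the exponentials in double so that only the propagated error is measured.

```cpp
#include <cmath>

// Relative change of exp(y), in units of 2^-23, caused by moving the float
// argument y by one ulp toward zero.
double exp_error_magnification(float y)
{
    float y_err = std::nextafter(y, 0.0f);  // y off by one ulp
    double rel = std::exp(static_cast<double>(y_err)) /
                 std::exp(static_cast<double>(y)) - 1.0;
    return std::fabs(rel) / std::ldexp(1.0, -23);
}
```

For y = -80.0f this comes out at about 64, while for y = -1.0f it stays below 1.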

/*
  Copyright (c) 2023, Norbert Juffa

  Redistribution and use in source and binary forms, with or without 
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright 
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/* Compute scaled gamma function, Γ*(x) = sqrt(x/(2*pi))*exp(x)*pow(x,-x)*Γ(x) */
__device__ float gammastarf (float x)
{
    const float MY_NAN_F = __int_as_float (0x7fffffff);
    float r, s;

    if (x <= 0.0f) {
        r = MY_NAN_F;
    } else if (x < 1.0f) {
        /* (0, 1): maximum error 4.17 ulp */
        if (x < 3.5e-9f) {
            const float oosqrt2pi = 0.39894228040143267794f; // 1/sqrt(2pi)
            r = oosqrt2pi / sqrtf (x);
        } else {
            r = 1.0f / x;
            r = gammastarf (x + 1.0f) * expf (fmaf (x, log1pf (r), -1.0f)) * 
                sqrtf (1.0f + r);
        }
    } else {
        /* [1, INF]: maximum error 0.56 ulp */
        r = 1.0f / x;
        s =              1.24335289e-4f;  //  0x1.04c000p-13
        s = fmaf (s, r, -5.94899990e-4f); // -0x1.37e620p-11
        s = fmaf (s, r,  1.07218279e-3f); //  0x1.1910f8p-10
        s = fmaf (s, r, -2.95283855e-4f); // -0x1.35a0a8p-12
        s = fmaf (s, r, -2.67404946e-3f); // -0x1.5e7e36p-9
        s = fmaf (s, r,  3.47193284e-3f); //  0x1.c712bcp-9
        s = fmaf (s, r,  8.33333358e-2f); //  0x1.555556p-4
        r = fmaf (s, r,  1.00000000e+0f); //  0x1.000000p+0
    }
    return r;
}

/* Compute the logarithm of the beta function, ln(B(x,y)) */
__device__ float lbetaf (float x, float y)
{
    const float MY_NAN_F = __int_as_float (0x7fffffff);
    const float sqrt_2pi = 2.506628274631000502416f;
    float mn, mx, sum, r;

    mx = fmaxf (x, y);
    mn = fminf (x, y);
    sum = x + y;
    if (isnan (x) || isnan (y)) {
        r = sum;
    } else if ((mn <= 0) || (isinf (x) && isinf (y))) {
        r = MY_NAN_F;
    } else if (isinf (mx)) {
        r = -mx; // B(a,b) -> 0 as max(a,b) -> INF with min(a,b) finite, so log(B(a,b)) = -INF
    } else if (mx >= 1.7265625f) {
        float half_mn, mn_over_sum;
        // compute mn/sum avoiding premature overflow in sum
        half_mn = 0.5f * mn;
        mn_over_sum = half_mn / fmaf (0.5f, mx, half_mn);
        r = fmaf (mn - 0.5f, logf (mn_over_sum), 
                  fmaf (mx, log1pf (-(mn_over_sum)),
                        fmaf (-0.5f, logf (mx),
                              logf (sqrt_2pi * gammastarf (mn) *
                                    gammastarf (mx) / gammastarf (sum)))));
    } else {
        r = (lgammaf (mn) - lgammaf (sum)) + lgammaf (mx);
    }
    return r;
}
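The large-argument branch of lbetaf() (taken when mx >= 1.7265625) follows from the definition of the scaled gamma function in DLMF 5.11.3, which is the same relation stated in the comment above gammastarf(). The derivation below is reconstructed from the code and that definition; it is not quoted from Temme's work. Starting from

$$
\Gamma(x) = \sqrt{2\pi}\, x^{x-1/2}\, e^{-x}\, \Gamma^*(x)
$$

and writing s = a + b, the factors e^(-a) e^(-b) / e^(-s) cancel, and (a - 1/2) + (b - 1/2) = s - 1 allows the log s term to be split:

$$
\begin{aligned}
\log B(a,b) &= (a-\tfrac{1}{2})\log a + (b-\tfrac{1}{2})\log b - (s-\tfrac{1}{2})\log s + \log\frac{\sqrt{2\pi}\,\Gamma^*(a)\,\Gamma^*(b)}{\Gamma^*(s)} \\
&= (a-\tfrac{1}{2})\log\frac{a}{s} + (b-\tfrac{1}{2})\log\frac{b}{s} - \tfrac{1}{2}\log s + \log\frac{\sqrt{2\pi}\,\Gamma^*(a)\,\Gamma^*(b)}{\Gamma^*(s)} \\
&= (a-\tfrac{1}{2})\log\frac{a}{s} + b\,\mathrm{log1p}\!\left(-\frac{a}{s}\right) - \tfrac{1}{2}\log b + \log\frac{\sqrt{2\pi}\,\Gamma^*(a)\,\Gamma^*(b)}{\Gamma^*(s)}
\end{aligned}
$$

With a = mn, b = mx and s = sum, the last line corresponds term by term to the nested fmaf() expression in lbetaf(), where a/s is mn_over_sum. Expressing the logarithmic terms through the ratio a/s keeps them moderate in size instead of subtracting large quantities such as (s - 1/2) log s.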