An accurate single-precision implementation of the Lambert W function

The Lambert W function is the solution w of w*exp(w)=z. It is named after the Swiss mathematician Johann Heinrich Lambert (1728-1777), although (in keeping with Stigler’s law of eponymy) it was originally examined by his more famous compatriot Leonhard Euler, who, however, credited Lambert with some of the underlying ideas in the relevant publication:

Leonhard Euler, “De serie Lambertina plurimisque eius insignibus proprietatibus,” (On Lambert’s series and its many distinctive properties) Acta Academiae Scientiarum Imperialis Petropolitanae pro Anno MDCCLXXIX, Tomus III, Pars II, (Proceedings of the Imperial Academy of Sciences of St. Petersburg for the Year 1779, volume 3, part 2, Jul. - Dec.), St. Petersburg: Academy of Sciences 1783, pp. 29-51 (scan by Bavarian State Library, Munich)

After a slumber of some 200 years, the function was rediscovered (and named after Lambert) in the 1990s for use in computer algebra packages, but there have since been various applications in the physical sciences, a useful overview of which can be found in Wikipedia. My own interest arose from using it as a building block for the computation of the inverse of the principal branch of the gamma function.

The Lambert W function is multi-branched, but when restricted to the real numbers, solving y*exp(y)=x, there are just two branches, commonly referred to as W0 (the principal branch) for x in [-exp(-1), ∞) and W-1 for x in [-exp(-1), 0). As in cases of other multi-branched functions such as the square root, the principal branch is usually the one of interest, and is the one computed by the single-precision implementation lambert_w0f below.

The numerical computation of the Lambert W function is an area of active research, with most of the focus either on piecewise rational approximations or on starting approximations of various kinds coupled with functional iteration schemes.

I created lambert_w0f as a starting point for further investigations. In its present form it provides results with good accuracy, with maximum error of 2.56303 ulps across the entire input domain. Performance is reasonably good, with the cost of logarithmic and exponential functions mitigated as best as possible, for example by using device intrinsics. I am using simple Newton iterations, as the iterations of higher order, like Halley’s (third order) and Schröder’s method (fifth order), often advocated in the literature, seem not to perform as well as one might expect when executed in finite-precision floating-point arithmetic. In particular three Newton iterations provide more accurate results than two Halley iterations.

[ Code below updated 7/21/2022, 7/25/2022, 7/26/2022, 7/27/2022, 7/29/2022, 8/3/2022, 9/11/2022, 9/18/2022, 1/17/2023 ]

/*
  Copyright (c) 2022-2023, Norbert Juffa
  All rights reserved.

  Redistribution and use in source and binary forms, with or without 
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright 
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/* Thin wrapper around the PTX instruction rcp.approx.ftz.f32: returns the
   hardware's approximate reciprocal of a.
*/
__forceinline__ __device__ float raw_rcp (float a)
{
    float recip;

    asm ("rcp.approx.ftz.f32 %0,%1;" : "=f"(recip) : "f"(a));
    return recip;
}

/* Thin wrapper around the PTX instruction sqrt.approx.ftz.f32: returns the
   hardware's approximate square root of a.
*/
__forceinline__ __device__ float raw_sqrt (float a)
{
    float root;

    asm ("sqrt.approx.ftz.f32 %0,%1;" : "=f"(root) : "f"(a));
    return root;
}

/* Thin wrapper around the PTX instruction ex2.approx.ftz.f32: returns the
   hardware's approximation of 2**a.
*/
__forceinline__ __device__ float raw_ex2 (float a)
{
    float pow2;

    asm ("ex2.approx.ftz.f32 %0,%1;" : "=f"(pow2) : "f"(a));
    return pow2;
}

/* Thin wrapper around the PTX instruction lg2.approx.ftz.f32: returns the
   hardware's approximation of log2(a).
*/
__forceinline__ __device__ float raw_lg2 (float a)
{
    float log2a;

    asm ("lg2.approx.ftz.f32 %0,%1;" : "=f"(log2a) : "f"(a));
    return log2a;
}

/* Compute exp(a) * 2**scale. Max ulp err = 0.86565

   a     : argument of the exponential
   scale : power of two folded into the result's exponent field

   The scaled result is assembled by integer addition on the bit pattern of
   the polynomial result, with no range checks. Callers presumably must make
   sure exp(a) * 2**scale is a normal float (NOTE(review): confirm).
*/
__forceinline__ __device__ float expf_scale (float a, int scale)
{
    const float flt_int_cvt = 12582912.0f; // 0x1.8p23
    float f, r, j, t;
    int i;

    // exp(a) = 2**i * exp(f); i = rintf (a / log(2))
    // adding 0x1.8p23 rounds a*log2(e) to an integer held in the low bits of j
    j = fmaf (1.442695f, a, flt_int_cvt); // 0x1.715476p0 // log2(e)
    t = j - flt_int_cvt;
    // f = a - t * log(2), with log(2) split into a high and a low part
    f = fmaf (t, -6.93145752e-1f, a); // -0x1.62e400p-1  // log_2_hi
    f = fmaf (t, -1.42860677e-6f, f); // -0x1.7f7d1cp-20 // log_2_lo
    i = __float_as_int (j);
    // approximate r = exp(f) on interval [-log(2)/2, +log(2)/2]
    r =             1.37805939e-3f;  // 0x1.694000p-10
    r = fmaf (r, f, 8.37312452e-3f); // 0x1.125edcp-7
    r = fmaf (r, f, 4.16695364e-2f); // 0x1.555b5ap-5
    r = fmaf (r, f, 1.66664720e-1f); // 0x1.555450p-3
    r = fmaf (r, f, 4.99999851e-1f); // 0x1.fffff6p-2
    r = fmaf (r, f, 1.00000000e+0f); // 0x1.000000p+0
    r = fmaf (r, f, 1.00000000e+0f); // 0x1.000000p+0
    // exp(a) = 2**(i+scale) * r;
    // i still carries the bit pattern of 0x1.8p23 in its upper bits; these are
    // shifted out on purpose. Do the shift and the addition on unsigned
    // operands so the intended wrap-around is well defined (left-shifting a
    // signed int out of range is undefined behavior before C++20).
    r = __int_as_float ((int)(((unsigned int)(i + scale) << 23) +
                              (unsigned int)__float_as_int (r)));
    return r;
}

/*
  Compute the principal branch of the Lambert W function, W_0. The maximum
  error in the positive half-plane is 1.49923 ulps and the maximum error in
  the negative half-plane is 2.56303 ulps; a total of 76403840 results are 
  not correctly rounded.

  z       : argument
  returns : approximation to w with w * exp(w) = z

  Structure of the computation:
  - NaN, +infinity and zero arguments return z + z
  - |z| < 0x1.0p-13 returns z - z*z
  - z + exp(-1) < 0.0625 uses a polynomial in sqrt (z + exp(-1)); a negative
    z + exp(-1) feeds a negative value to raw_sqrt()
  - otherwise a starting approximation is refined by three Newton steps
*/
__device__ float lambert_w0f (float z) 
{
    const float MY_INF = __int_as_float (0x7f800000);
    /* exp(-1) is supplied as the product of two factors, so that the fmaf()
       below adds it to z with a single rounding */
    const float em1_fact_0 = 0.625529587f; // exp(-1)_factor_0
    const float em1_fact_1 = 0.588108778f; // exp(-1)_factor_1
    const float qe1 = 2.71828183f / 4.0f; // exp(1)/4
    float e, w, num, den, rden, redz, y, r;

    /* special cases: NaN, +infinity, zero */
    if (isnan (z) || (z == MY_INF) || (z == 0.0f)) return z + z;
    /* tiny arguments: z - z*z */
    if (fabsf (z) < 1.220703125e-4f) return fmaf (-z, z, z); // 0x1.0p-13
    redz = fmaf (em1_fact_0, em1_fact_1, z); // z + exp(-1)
    if (redz < 0.0625f) { // expansion at -(exp(-1))
        r = raw_sqrt (redz);
        /* polynomial in r = sqrt (z + exp(-1)), evaluated by Horner's scheme */
        w =             -1.23046875f;  // -0x1.3b0000p+0
        w = fmaf (w, r,  2.17185670f); //  0x1.15ff66p+1
        w = fmaf (w, r, -2.19554094f); // -0x1.19077cp+1 
        w = fmaf (w, r,  1.92107077f); //  0x1.ebcb4cp+0
        w = fmaf (w, r, -1.81141856f); // -0x1.cfb920p+0
        w = fmaf (w, r,  2.33162979f); //  0x1.2a72d8p+1
        w = fmaf (w, r, -1.00000000f); // -0x1.000000p+0
    } else {
        /* Compute initial approximation. Based on: Roberto Iacono and John 
           Philip Boyd, "New approximations to the principal real-valued branch
           of the Lambert W function", Advances in Computational Mathematics, 
           Vol. 43, No. 6, December 2017, pp. 1403-1436
        */
        y = fmaf (2.0f, raw_sqrt (fmaf (qe1, z, 0.25f)), 1.0f);
        y = raw_lg2 (raw_rcp (fmaf (0.31819974f, raw_lg2 (y), 1.0f)) * 
                     fmaf (1.15262585f, y, -0.15262585f));
        w = fmaf (1.41337042f, y, -1.0f);

        /* perform Newton iterations to refine approximation to full accuracy */
        const float l2e = 1.44269504f; // log2(exp(1))

        /* Each step works on quantities scaled by 1/8:
             e   = exp(w) / 8
             num = w * e - z / 8
             den = w * e + e = (w + 1) * e
             w   = w - num / den
           NOTE(review): the 1/8 scaling presumably guards against overflow
           for large z -- confirm.
        */
        e = raw_ex2 (fmaf (l2e, w, -3.0f));
        num = fmaf (w, e, -0.125f * z);
        den = fmaf (w, e, e);
        rden = raw_rcp (den);
        w = fmaf (-num, rden, w);

        e = raw_ex2 (fmaf (l2e, w, -3.0f));
        num = fmaf (w, e, -0.125f * z);
        den = fmaf (w, e, e);
        rden = raw_rcp (den);
        w = fmaf (-num, rden, w);

        /* the final step uses expf_scale() instead of raw_ex2() for exp(w)/8 */
        e = expf_scale (w, -3);
        num = fmaf (w, e, -0.125f * z);
        den = fmaf (w, e, e);
        rden = raw_rcp (den);
        w = fmaf (-num, rden, w);
    }
    return w;
}

2 Likes

This is an interesting hobby of yours.

Ikr :-)

One of the things that brought me to computers decades ago was me wondering what happens under the hood after pressing the e^x button on my calculator. And in the days before the internet, that was not trivial to figure out. As I recall it took me several years and access to a university library to discover all the details.

Now that I am retired and have plenty of time on my hands I am at times wondering much the same about slightly more complicated functions. Since the Lambert W function was named in the 1990s, after I was out of school, I did not know of its existence until the early 2000s. Recently, I discovered another useful function I had never heard of before: Owen’s T function. It came up in the context of some question on one of the Stackexchange sites.

1 Like

When you find the right fit for an intellect like that, it is like striking gold. That’s probably true for most people and their work.

We were lucky to have njuffa during the time spent working at NVIDIA.

Here is a double-precision implementation of the principal branch of the Lambert W function W0. It is fully functional and accurate (maximum error found so far is 2.68410 ulp) and is fully optimized within the limitations of the algorithm I chose to use.

[ Code below updated 8/16/2022, 10/7/2022, 1/22/2023, 2/14/2023 ]

/*
  Copyright 2022-2023, Norbert Juffa

  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/* exp(a) * 2**scale; pos. normal results only! Max. err. found: 0.89028 ulp

   a     : argument of the exponential
   scale : power of two folded into the result's exponent field

   The result is assembled by integer arithmetic on the high word of the
   polynomial result, with no range checks; hence the restriction to
   positive normal results.
*/
__device__ double exp_scale_pos_normal (double a, int scale)
{
    const double ln2_hi = 6.9314718055829871e-01; // 0x1.62e42fefa00000p-01
    const double ln2_lo = 1.6465949582897082e-12; // 0x1.cf79abc9e3b3a0p-40
    const double l2e = 1.4426950408889634; // 0x1.71547652b82fe0p+00 // log2(e)
    const double cvt = 6755399441055744.0; // 0x1.80000000000000p+52 // 3*2**51
    double f, r;
    int i;

    /* exp(a) = 2**i * exp(f); i = rint (a / log(2)). Adding cvt rounds
       a*log2(e) to an integer whose value is read from the low word. */
    r = fma (l2e, a, cvt);
    i = __double2loint (r);
    r = r - cvt;
    /* f = a - r * log(2), with log(2) split into a high and a low part */
    f = fma (r, -ln2_hi, a);
    f = fma (r, -ln2_lo, f);

    /* approximate r = exp(f) on interval [-log(2)/2,+log(2)/2] */
    r =            2.5022018235176802e-8;  // 0x1.ade0000000000p-26
    r = fma (r, f, 2.7630903497145818e-7); // 0x1.28af3fcbbf09bp-22
    r = fma (r, f, 2.7557514543490574e-6); // 0x1.71dee623774fap-19
    r = fma (r, f, 2.4801491039409158e-5); // 0x1.a01997c8b50d7p-16
    r = fma (r, f, 1.9841269589068419e-4); // 0x1.a01a01475db8cp-13
    r = fma (r, f, 1.3888888945916566e-3); // 0x1.6c16c1852b805p-10
    r = fma (r, f, 8.3333333334557735e-3); // 0x1.11111111224c7p-7
    r = fma (r, f, 4.1666666666519782e-2); // 0x1.55555555502a5p-5
    r = fma (r, f, 1.6666666666666477e-1); // 0x1.5555555555511p-3
    r = fma (r, f, 5.0000000000000122e-1); // 0x1.000000000000bp-1
    r = fma (r, f, 1.0000000000000000e+0); // 0x1.0000000000000p+0
    r = fma (r, f, 1.0000000000000000e+0); // 0x1.0000000000000p+0

    /* exp(a) = 2**(i+scale) * r. i + scale is frequently negative, and
       left-shifting a negative signed int is undefined behavior before
       C++20, so the exponent adjustment is done on unsigned operands. */
    r = __hiloint2double ((int)((unsigned int)__double2hiint (r) +
                                ((unsigned int)(i + scale) << 20)),
                          __double2loint (r));
    return r;
}

/* Reciprocal 1/a: PTX rcp.approx.ftz.f64 starting value followed by one
   refinement step
*/
__forceinline__ __device__ double rcp64 (double a)
{
    double approx, err;

    asm ("rcp.approx.ftz.f64 %0,%1;" : "=d"(approx) : "d"(a));
    err = fma (-a, approx, 1.0);       // residual 1 - a * approx
    err = fma (err, err, err);         // err + err**2
    return fma (err, approx, approx);  // approx * (1 + err + err**2)
}

/* Reciprocal square root 1/sqrt(a): PTX rsqrt.approx.ftz.f64 starting value
   followed by one refinement step. Arguments less than zero yield NaN.
*/
__forceinline__ __device__ double rsq64 (double a)
{
    const double MY_NAN = __hiloint2double (0xfff80000, 0x00000000);
    double approx, err, refined;

    asm ("rsqrt.approx.ftz.f64 %0,%1;" : "=d"(approx) : "d"(a));
    err = fma (a * approx, -approx, 1.0);  // residual 1 - a * approx**2
    refined = fma (fma (0.375, err, 0.5), err * approx, approx);
    return (a < 0.0) ? MY_NAN : refined;
}

/* compute natural logarithm of positive normals; max. err. found: 0.87008 ulp

   a       : argument; the exponent is extracted by integer arithmetic on the
             high word, hence the restriction to positive normal numbers
   returns : approximation to log(a)
*/
__device__ double log_pos_normal (double a)
{
    const double ln2_hi = 6.9314718055994529e-01; // 0x1.62e42fefa39efp-01
    const double ln2_lo = 2.3190468138462996e-17; // 0x1.abc9e3b39803fp-56
    double m, r, i, s, t, p, q;
    int e;

    /* log(a) = log(m * 2**i) = i * log(2) + log(m) */
    e = (__double2hiint (a) - __double2hiint (0.70703125)) & 0xfff00000;
    m = __hiloint2double (__double2hiint (a) - e, __double2loint (a));
    /* Convert the exponent held in bits 20..31 of e to a double: place e,
       with its sign bit flipped, in the low word of a double with high word
       0x41f00000, then subtract the same construct built for e = 0. */
    t = __hiloint2double (0x41f00000, 0x80000000 ^ e);
    i = t - (__hiloint2double (0x41f00000, 0x80000000));

    /* m now in [181/256, 362/256]. Compute q = (m-1) / (m+1) */
    p = m + 1.0;
    r = rcp64 (p);
    q = fma (m, r, -r);
    m = m - 1.0;

    /* Compute (2*atanh(q)/q - 2) as p(q**2), q in [-75/437, 53/309] */
    s = q * q;
    r =            1.4794533702205467e-1;  // 0x1.2efdf700d7e8p-3
    r = fma (r, s, 1.5314187748152153e-1); // 0x1.39a272db730bp-3
    r = fma (r, s, 1.8183559141305675e-1); // 0x1.746637f2f174p-3
    r = fma (r, s, 2.2222198669309617e-1); // 0x1.c71c522a6458p-3
    r = fma (r, s, 2.8571428741489424e-1); // 0x1.24924941c9a4p-2
    r = fma (r, s, 3.9999999999418556e-1); // 0x1.999999998007p-2
    r = fma (r, s, 6.6666666666667340e-1); // 0x1.555555555559p-1
    r = r * s;

    /* log(a) = 2*atanh(q) + i*log(2) = ln2_lo*i + p(q**2)*q + 2q + ln2_hi * i.
       Use K.C. Ng's trick to improve the accuracy of the computation, like so:
       p(q**2)*q + 2q = p(q**2)*q + q*t - t + m, where t = m**2/2.
    */
    t = m * m * 0.5;
    r = fma (q, t, fma (q, r, ln2_lo * i)) - t + m;
    r = fma (ln2_hi, i, r);

    return r;
}

/* Compute the principal branch of the Lambert W function, W_0. Maximum error:
   positive half-plane: 1.49283 ulp @  3.8007153063827201e-6
   negative half-plane: 2.68410 ulp @ -3.4564955331259656e-1

   z       : argument
   returns : approximation to w with w * exp(w) = z

   Structure of the computation:
   - NaN, +infinity and zero arguments return z + z
   - |z| < 1.9073486328125e-6 returns z - z**2 + 1.5 * z**3
   - z + exp(-1) < 0.01025390625 uses a polynomial in sqrt (z + exp(-1));
     rsq64() returns NaN when z + exp(-1) is negative
   - otherwise a starting approximation is refined by three steps of a
     logarithm-based iteration and one final Newton step
*/
__device__ double lambert_w0 (double z) 
{
    const double MY_INF = __hiloint2double (0x7ff00000, 0x00000000);
    /* exp(-1) is supplied as the product of two factors, so that the fma()
       below adds it to z with a single rounding */
    const double em1_fact_0 = 0.57086272525975246; // 0x1.24481e7efdfcep-1 // exp(-1)_factor_0
    const double em1_fact_1 = 0.64442715366299452; // 0x1.49f25b1b461b7p-1 // exp(-1)_factor_1
    const double qe1 = 2.7182818284590452 * 0.25;  // 0x1.5bf0a8b145769p-1 // exp(1)/4
    double e, r, t, w, y, num, den, rden, redz;
    int i;
    
    /* special cases: NaN, +infinity, zero */
    if (isnan (z) || (z == MY_INF) || (z == 0.0)) return z + z;
    /* tiny arguments: z - z**2 + 1.5 * z**3 */
    if (fabs (z) < 1.9073486328125e-6) return fma (fma (1.5, z, -1.) * z, z, z);
    redz = fma (em1_fact_0, em1_fact_1, z); // z + exp(-1)
    if (redz < 0.01025390625) { // expansion at -(exp(-1))
        r = rsq64 (redz) * redz; // sqrt (redz)
        /* polynomial in r = sqrt (z + exp(-1)), evaluated by Horner's scheme */
        w =            -7.8466654751155138;
        w = fma (w, r, 10.0241581340373877);
        w = fma (w, r, -8.1029379749359691);
        w = fma (w, r,  5.8322883145113726);
        w = fma (w, r, -4.1738796362609882);
        w = fma (w, r,  3.0668053943936471);
        w = fma (w, r, -2.3535499689514934);
        w = fma (w, r,  1.9366310979331112);
        w = fma (w, r, -1.8121878855270763);
        w = fma (w, r,  2.3316439815968506);
        w = fma (w, r, -1.0000000000000000);
        return w;
    }
    /* Starting approximation from: Roberto Iacono and John Philip Boyd, 
       "New approximations to the principal real-valued branch of the 
       Lambert W function", Advances in Computational Mathematics, Vol. 
       43, No. 6, December 2017, pp. 1403-1436
     */
    y = fma (2.0, sqrt (fma (qe1, z, 0.25)), 1.0);
    y = log_pos_normal (fma (1.14956131, y, -0.14956131) / 
                        fma (0.4549574, log_pos_normal (y), 1.0));
    w = fma (2.036, y, -1.0);

    /* Use iteration scheme w = (w / (1 + w)) * (1 + log (z / w)) from
       Roberto Iacono and John Philip Boyd, "New approximations to the 
       principal real-valued branch of the Lambert W function", Advances
       in Computational Mathematics, Vol. 43, No. 6, December 2017, pp. 
       1403-1436
    */
    for (i = 0; i < 3; i++) {
        t = w * rcp64 (1.0 + w);                // w / (1 + w)
        r = rcp64 (w);                          // 1 / w
        w = fma (log_pos_normal (z * r), t, t); // t * (1 + log (z / w))
    }

    /* Fine tune approximation with a single Newton iteration */
    /* The step works on quantities scaled by 1/8:
         e   = exp(w) / 8
         num = w * e - z / 8
         den = w * e + e = (w + 1) * e
         w   = w - num / den
    */
    e = exp_scale_pos_normal (w, -3);
    num = fma (w, e, -0.125 *z);
    den = fma (w, e, e);
    rden = rcp64 (den);
    w = fma (-num, rden, w);

    return w;
}
3 Likes

Since I came across a half dozen recent papers that make use of the negative branch of the real-valued Lambert W function, W-1, it looks like this function branch is of more utility than I previously thought. Given that, I thought I would try my hand at an initial single-precision implementation over the weekend to see what is involved. It turns out it is not overly complicated once one has a good overview of available literature. lambert_wm1f() below achieves reasonable accuracy at a maximum error of 2.65 ulps and relies quite heavily on the GPU’s multi-function unit for good performance.

[code below updated 7/28/2023]

/*
  Copyright (c) 2023, Norbert Juffa

  Redistribution and use in source and binary forms, with or without 
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright 
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/* Thin wrapper around the PTX instruction rcp.approx.ftz.f32: returns the
   hardware's approximate reciprocal of a.
*/
__forceinline__ __device__ float raw_rcp (float a)
{
    float recip;

    asm ("rcp.approx.ftz.f32 %0,%1;" : "=f"(recip) : "f"(a));
    return recip;
}

/* Thin wrapper around the PTX instruction sqrt.approx.ftz.f32: returns the
   hardware's approximate square root of a.
*/
__forceinline__ __device__ float raw_sqrt (float a)
{
    float root;

    asm ("sqrt.approx.ftz.f32 %0,%1;" : "=f"(root) : "f"(a));
    return root;
}

/* Thin wrapper around the PTX instruction ex2.approx.ftz.f32: returns the
   hardware's approximation of 2**a.
*/
__forceinline__ __device__ float raw_ex2 (float a)
{
    float pow2;

    asm ("ex2.approx.ftz.f32 %0,%1;" : "=f"(pow2) : "f"(a));
    return pow2;
}

/* Thin wrapper around the PTX instruction lg2.approx.ftz.f32: returns the
   hardware's approximation of log2(a).
*/
__forceinline__ __device__ float raw_lg2 (float a)
{
    float log2a;

    asm ("lg2.approx.ftz.f32 %0,%1;" : "=f"(log2a) : "f"(a));
    return log2a;
}

/* Compute natural logarithm, for positive normal arguments only!
   Max ulp error = 0.85095 

   a       : argument; the exponent is extracted by integer arithmetic on the
             bit pattern, hence the restriction to positive normal numbers
   returns : approximation to log(a)
*/
__device__ float logf_pos_normal (float a)
{
    float f, i, m, r, s, t;
    int e;

    /* split a = m * 2**i: e holds the exponent difference relative to 2/3 in
       the exponent field position, m is a with that exponent removed */
    e = (__float_as_int (a) - __float_as_int (0.666666667f)) & 0xff800000;
    m = __int_as_float (__float_as_int (a) - e);
    /* scale e down by 2**23 to obtain the exponent i as a float */
    i = (float)e * 1.19209290e-7f; // 0x1.0p-23
    /* m in [2/3, 4/3] */
    f = m - 1.0f;
    s = f * f;
    /* Compute log1p(f) for f in [-1/3, 1/3] */
    /* two coefficient chains in s = f*f are evaluated side by side and then
       combined as r + t * f */
    r =             -0.130310059f;  // -0x1.0ae000p-3
    t =              0.140869141f;  //  0x1.208000p-3
    r = fmaf (r, s, -0.121483512f); // -0x1.f198b2p-4
    t = fmaf (t, s,  0.139814854f); //  0x1.1e5740p-3
    r = fmaf (r, s, -0.166846126f); // -0x1.55b36cp-3
    t = fmaf (t, s,  0.200120345f); //  0x1.99d8b2p-3
    r = fmaf (r, s, -0.249996200f); // -0x1.fffe02p-3
    r = fmaf (t, f, r);
    r = fmaf (r, f,  0.333331972f); //  0x1.5554fap-2
    r = fmaf (r, f, -0.500000000f); // -0x1.000000p-1  
    r = fmaf (r, s, f);
    /* log (a) = log (m) + i * log (2) = log1p (f) + i * log (2) */
    r = fmaf (i,  0.693147182f, r); //  0x1.62e430p-1 // log(2)
    return r;
}

/*
  Compute the negative branch of the real-valued Lambert W function, W_{-1}.
  Maximum error 2.65003 ulp.

  x       : argument
  returns : approximation to w with w * exp(w) = x on the negative branch

  Structure of the computation:
  - NaN returns x + x; zero returns -infinity; positive x returns NaN
  - x + exp(-1) <= 0.09765625 uses a polynomial in sqrt (x + exp(-1)); a
    negative x + exp(-1) feeds a negative value to raw_sqrt()
  - otherwise a starting approximation is refined by a single iteration
    step, whose kind depends on the magnitude of x
*/
__device__ float lambert_wm1f (float x) 
{
    const float MY_INF = __int_as_float (0x7f800000);
    const float MY_NAN = __int_as_float (0x7fffffff);
    /* exp(-1) is supplied as the product of two factors, so that the fmaf()
       below adds it to x with a single rounding */
    const float em1_fact_0 = 0.625529587f; // exp(-1)_factor_0
    const float em1_fact_1 = 0.588108778f; // exp(-1)_factor_1
    const float exp2_32 = 4294967296.0f; // 0x1.0p32
    const float ln2 = 0.69314718f; // log(2)
    const float l2e = 1.44269504f; // log2(exp(1))
    /* coefficients of the starting approximation */
    const float c0 =  1.68090820e-1f; //  0x1.584000p-3
    const float c1 = -2.96497345e-3f; // -0x1.84a000p-9
    const float c2 = -2.87322998e-2f; // -0x1.d6c000p-6
    const float c3 =  7.07275391e-1f; //  0x1.6a2000p-1
    float redx, r, s, t, w, e, num, den, rden;
    
    /* special cases */
    if (isnan (x)) return x + x;
    if (x == 0.0f) return -MY_INF;
    if (x > 0.0f) return MY_NAN;
    redx = fmaf (em1_fact_0, em1_fact_1, x); // x + exp(-1)
    if (redx <= 0.09765625f) { // expansion at -(exp(-1))
        r = raw_sqrt (redx);
        /* polynomial in r = sqrt (x + exp(-1)), evaluated by Horner's scheme */
        w =             -3.30250000e+2f;  // -0x1.4a4000p+8
        w = fmaf (w, r,  3.53563141e+2f); //  0x1.61902ap+8
        w = fmaf (w, r, -1.91617889e+2f); // -0x1.7f3c5cp+7
        w = fmaf (w, r,  4.94172478e+1f); //  0x1.8b5686p+5
        w = fmaf (w, r, -1.23464909e+1f); // -0x1.8b1674p+3
        w = fmaf (w, r, -1.38704872e+0f); // -0x1.6315a0p+0
        w = fmaf (w, r, -1.99431837e+0f); // -0x1.fe8ba6p+0
        w = fmaf (w, r, -1.81044364e+0f); // -0x1.cf793cp+0
        w = fmaf (w, r, -2.33166337e+0f); // -0x1.2a73f2p+1
        w = fmaf (w, r, -1.00000000e+0f); // -0x1.000000p+0
    } else {
        /* Initial approximation based on: D. A. Barry, L. Li, and D.-S. Jeng, 
           "Comments on 'Numerical Evaluation of the Lambert Function and 
           Application to Generation of Generalized Gaussian Noise with 
           Exponent 1/2'", IEEE Transactions on Signal Processing, Vol. 52, 
           No. 5, May 2004, pp. 1456-1457
        */
        /* s = -log(-x) - 1. The argument of raw_lg2() is scaled by 2**32 and
           the 32 is subtracted again afterwards. NOTE(review): presumably
           this keeps tiny -x out of the flush-to-zero range of the .ftz
           instruction -- confirm. */
        s = fmaf (raw_lg2 (-x * exp2_32) - 32.0f, -ln2, -1.0f);
        t = raw_sqrt (s);
        /* w = -1 - s - 1 / (c0 + c3 / t + c1 * t * 2**(c2 * t)) */
        w = -1.0f - s - raw_rcp (fmaf (raw_ex2 (c2 * t), c1 * t, 
                                       fmaf (raw_rcp (t), c3, c0)));

        if (x > -7.703719778e-34f) { // -0x1.0p-110
            /* Newton iteration */
            /* works on quantities scaled by 2**32:
                 e   = exp(w) * 2**32
                 num = w * e - x * 2**32
                 den = w * e + e = (w + 1) * e
                 w   = w - num / den
            */
            e = raw_ex2 (fmaf (w, l2e, 32.0f));
            num = fmaf (w, e, -exp2_32 * x);
            den = fmaf (w, e, e);
            rden = raw_rcp (den);
            w = fmaf (-num, rden, w);
        } else {
            /* Roberto Iacono and John Philip Boyd, "New approximations to the 
               principal real-valued branch of the Lambert W function", 
               Advances in Computational Mathematics, Vol. 43, No. 6, 
               December 2017, pp. 1403-1436
            */
            /* w = t * (1 + log (x / w)) with t = w / (1 + w); the raw
               reciprocal r of 1 + w is corrected using its residual e */
            t = 1.0f + w;
            r = raw_rcp (t);
            e = fmaf (-t, r, 1.0f);     // residual 1 - (1 + w) * r
            t = fmaf (w, r, e * r * w); // t = w / (1.0f + w);
            w = fmaf (logf_pos_normal (x * raw_rcp (w)), t, t);
        } 
    }
    return w;
}
2 Likes

I finally got around to designing a double-precision implementation of W-1. Since the starting approximation I took from the literature provides a relative error of 2.5e-4, two iterations with quadratic convergence are not quite sufficient to reach full double-precision accuracy. However, three simple Newton iterations provide excellent accuracy. As with my other implementations of the Lambert W function, a polynomial approximation is used near -exp(-1). Double-precision implementations cannot be tested exhaustively; with one billion random test vectors, the maximum error found was 2.58435 ulps.

/*
  Copyright 2023, Norbert Juffa

  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/* exp(a) * 2**scale; pos. normal results only! Max. err. found: 0.89028 ulp

   a     : argument of the exponential
   scale : power of two folded into the result's exponent field

   The result is assembled by integer arithmetic on the high word of the
   polynomial result, with no range checks; hence the restriction to
   positive normal results.
*/
__device__ double exp_scale_pos_normal (double a, int scale)
{
    const double ln2_hi = 6.9314718055829871e-01; // 0x1.62e42fefa00000p-01
    const double ln2_lo = 1.6465949582897082e-12; // 0x1.cf79abc9e3b3a0p-40
    const double l2e = 1.4426950408889634; // 0x1.71547652b82fe0p+00 // log2(e)
    const double cvt = 6755399441055744.0; // 0x1.80000000000000p+52 // 3*2**51
    double f, r;
    int i;

    /* exp(a) = 2**i * exp(f); i = rint (a / log(2)). Adding cvt rounds
       a*log2(e) to an integer whose value is read from the low word. */
    r = fma (l2e, a, cvt);
    i = __double2loint (r);
    r = r - cvt;
    /* f = a - r * log(2), with log(2) split into a high and a low part */
    f = fma (r, -ln2_hi, a);
    f = fma (r, -ln2_lo, f);

    /* approximate r = exp(f) on interval [-log(2)/2,+log(2)/2] */
    r =            2.5022018235176802e-8;  // 0x1.ade0000000000p-26
    r = fma (r, f, 2.7630903497145818e-7); // 0x1.28af3fcbbf09bp-22
    r = fma (r, f, 2.7557514543490574e-6); // 0x1.71dee623774fap-19
    r = fma (r, f, 2.4801491039409158e-5); // 0x1.a01997c8b50d7p-16
    r = fma (r, f, 1.9841269589068419e-4); // 0x1.a01a01475db8cp-13
    r = fma (r, f, 1.3888888945916566e-3); // 0x1.6c16c1852b805p-10
    r = fma (r, f, 8.3333333334557735e-3); // 0x1.11111111224c7p-7
    r = fma (r, f, 4.1666666666519782e-2); // 0x1.55555555502a5p-5
    r = fma (r, f, 1.6666666666666477e-1); // 0x1.5555555555511p-3
    r = fma (r, f, 5.0000000000000122e-1); // 0x1.000000000000bp-1
    r = fma (r, f, 1.0000000000000000e+0); // 0x1.0000000000000p+0
    r = fma (r, f, 1.0000000000000000e+0); // 0x1.0000000000000p+0

    /* exp(a) = 2**(i+scale) * r. i + scale is frequently negative, and
       left-shifting a negative signed int is undefined behavior before
       C++20, so the exponent adjustment is done on unsigned operands. */
    r = __hiloint2double ((int)((unsigned int)__double2hiint (r) +
                                ((unsigned int)(i + scale) << 20)),
                          __double2loint (r));
    return r;
}

/* Reciprocal 1/a: PTX rcp.approx.ftz.f64 starting value followed by one
   refinement step
*/
__forceinline__ __device__ double rcp64 (double a)
{
    double approx, err;

    asm ("rcp.approx.ftz.f64 %0,%1;" : "=d"(approx) : "d"(a));
    err = fma (-a, approx, 1.0);       // residual 1 - a * approx
    err = fma (err, err, err);         // err + err**2
    return fma (err, approx, approx);  // approx * (1 + err + err**2)
}

/* Reciprocal square root 1/sqrt(a): PTX rsqrt.approx.ftz.f64 starting value
   followed by one refinement step. Arguments less than zero yield NaN.
*/
__forceinline__ __device__ double rsq64 (double a)
{
    const double MY_NAN = __hiloint2double (0xfff80000, 0x00000000);
    double approx, err, refined;

    asm ("rsqrt.approx.ftz.f64 %0,%1;" : "=d"(approx) : "d"(a));
    err = fma (a * approx, -approx, 1.0);  // residual 1 - a * approx**2
    refined = fma (fma (0.375, err, 0.5), err * approx, approx);
    return (a < 0.0) ? MY_NAN : refined;
}

/* compute natural logarithm of positive normals; max. err. found: 0.87008 ulp

   a       : argument; the exponent is extracted by integer arithmetic on the
             high word, hence the restriction to positive normal numbers
   returns : approximation to log(a)
*/
__device__ double log_pos_normal (double a)
{
    const double ln2_hi = 6.9314718055994529e-01; // 0x1.62e42fefa39efp-01
    const double ln2_lo = 2.3190468138462996e-17; // 0x1.abc9e3b39803fp-56
    double m, r, i, s, t, p, q;
    int e;

    /* log(a) = log(m * 2**i) = i * log(2) + log(m) */
    e = (__double2hiint (a) - __double2hiint (0.70703125)) & 0xfff00000;
    m = __hiloint2double (__double2hiint (a) - e, __double2loint (a));
    /* Convert the exponent held in bits 20..31 of e to a double: place e,
       with its sign bit flipped, in the low word of a double with high word
       0x41f00000, then subtract the same construct built for e = 0. */
    t = __hiloint2double (0x41f00000, 0x80000000 ^ e);
    i = t - (__hiloint2double (0x41f00000, 0x80000000));

    /* m now in [181/256, 362/256]. Compute q = (m-1) / (m+1) */
    p = m + 1.0;
    r = rcp64 (p);
    q = fma (m, r, -r);
    m = m - 1.0;

    /* Compute (2*atanh(q)/q - 2) as p(q**2), q in [-75/437, 53/309] */
    s = q * q;
    r =            1.4794533702205467e-1;  // 0x1.2efdf700d7e8p-3
    r = fma (r, s, 1.5314187748152153e-1); // 0x1.39a272db730bp-3
    r = fma (r, s, 1.8183559141305675e-1); // 0x1.746637f2f174p-3
    r = fma (r, s, 2.2222198669309617e-1); // 0x1.c71c522a6458p-3
    r = fma (r, s, 2.8571428741489424e-1); // 0x1.24924941c9a4p-2
    r = fma (r, s, 3.9999999999418556e-1); // 0x1.999999998007p-2
    r = fma (r, s, 6.6666666666667340e-1); // 0x1.555555555559p-1
    r = r * s;

    /* log(a) = 2*atanh(q) + i*log(2) = ln2_lo*i + p(q**2)*q + 2q + ln2_hi * i.
       Use K.C. Ng's trick to improve the accuracy of the computation, like so:
       p(q**2)*q + 2q = p(q**2)*q + q*t - t + m, where t = m**2/2.
    */
    t = m * m * 0.5;
    r = fma (q, t, fma (q, r, ln2_lo * i)) - t + m;
    r = fma (ln2_hi, i, r);

    return r;
}

/*
  Compute the negative branch of the real-valued Lambert W function, W_{-1}.
  Maximum error found using 1B test cases: 2.58435 ulps.

  z       : argument
  returns : approximation to w with w * exp(w) = z on the negative branch

  Structure of the computation:
  - NaN returns z + z; zero returns -infinity; positive z returns NaN
  - z + exp(-1) < 0.04296875 uses a polynomial in sqrt (z + exp(-1));
    rsq64() returns NaN when z + exp(-1) is negative
  - otherwise a starting approximation is refined by three Newton steps
*/
__device__ double lambert_wm1 (double z) 
{
    const double MY_INF = __hiloint2double (0x7ff00000, 0x00000000);
    const double MY_NAN = __hiloint2double (0xfff80000, 0x00000000);
    /* coefficients of the starting approximation */
    const double c0 =  1.6729676723480225e-1; //  0x1.569fbp-3
    const double c1 = -2.7966443449258804e-3; // -0x1.6e8fdp-9
    const double c2 = -2.1342277526855469e-2; // -0x1.5dac0p-6
    const double c3 =  7.0781660079956055e-1; //  0x1.6a66fp-1
    const double ln2 = 0.6931471805599453094172;
    const double exp2_64 = 1.8446744073709552e+19; // 0x1.0p64
    /* exp(-1) is supplied as the product of two factors, so that the fma()
       below adds it to z with a single rounding */
    const double em1_fact_0 = 0.57086272525975246; // 0x1.24481e7efdfcep-1 // exp(-1)_factor_0
    const double em1_fact_1 = 0.64442715366299452; // 0x1.49f25b1b461b7p-1 // exp(-1)_factor_1
    double redz, r, s, t, w, e, num, den, rden;
    int i;
    
    /* special cases */
    if (isnan (z)) return z + z;
    if (z == 0.0) return -MY_INF;
    if (z > 0.0) return MY_NAN;
    redz = fma (em1_fact_0, em1_fact_1, z); // z + exp(-1)
    if (redz < 0.04296875) { // expansion at -(exp(-1))
        r = rsq64 (redz) * redz; // sqrt (redz)
        /* polynomial in r = sqrt (z + exp(-1)), evaluated by Horner's scheme */
        w =            -3.1102051749530146e+3;
        w = fma (w, r,  3.3514583413659661e+3);
        w = fma (w, r, -2.0376505203571792e+3);
        w = fma (w, r,  6.6470674321336662e+2);
        w = fma (w, r, -1.9891092047488328e+2);
        w = fma (w, r,  1.1792563777850908e+1);
        w = fma (w, r, -1.6044530625408662e+1);
        w = fma (w, r, -8.0468700837359766e+0);
        w = fma (w, r, -5.8822749101364442e+0);
        w = fma (w, r, -4.1741334842726321e+0);
        w = fma (w, r, -3.0669009628894104e+0);
        w = fma (w, r, -2.3535502058947735e+0);
        w = fma (w, r, -1.9366311293653595e+0);
        w = fma (w, r, -1.8121878855163436e+0);
        w = fma (w, r, -2.3316439815975385e+0);
        w = fma (w, r, -1.0000000000000000e+0);
        return w;
    } else {
        /* Initial approximation based on: D. A. Barry, L. Li, D.-S. Jeng, 
           "Comments on 'Numerical Evaluation of the Lambert Function and 
           Application to Generation of Generalized Gaussian Noise with 
           Exponent 1/2'", IEEE Transactions on Signal Processing, Vol. 52, 
           No. 5, May 2004, pp. 1456-1457
        */
        /* s = -log(-z) - 1. The argument of log_pos_normal() is scaled by
           2**64 and 64 * log(2) is subtracted again afterwards.
           NOTE(review): presumably this keeps the argument a normal number
           for subnormal z -- confirm. */
        s = - fma (ln2, -64.0, log_pos_normal (-exp2_64 * z)) - 1.0;
        t = rsq64(s) * s; // sqrt (s);
        /* w = -1 - s - 1 / (c0 + c3 / t + c1 * t * exp (c2 * t)) */
        w = -1.0 - s - rcp64 (fma (exp_scale_pos_normal (c2 * t, 0), 
                                   c1 * t, fma (rcp64 (t), c3, c0)));
        /* Newton iterations */
        /* Each step works on quantities scaled by 2**64:
             e   = exp(w) * 2**64
             num = w * e - z * 2**64
             den = w * e + e = (w + 1) * e
             w   = w - num / den
        */
        for (i = 0; i < 3; i++) {
            e = exp_scale_pos_normal (w, 64);
            num = fma (w, e, -exp2_64 * z);
            den = fma (w, e, e);
            rden = rcp64 (den);
            w = fma (-num, rden, w);
        }
    }
    return w;
}
1 Like