Accuracy-optimized implementation of erff(), without performance impact

I set out to create a faithfully rounded implementation of erff() without losing performance relative to CUDA 11’s built-in implementation. Alas, the hardware’s implementation of exp2f() (the MUFU.EX2 instruction) foiled my attempt: for 27 arguments with |x| in [1.001, 1.083], the error exceeds 1 ulp.
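Error figures in ulps compare each float result against a higher-precision reference. The following is a minimal host-side sketch of that metric. It assumes a double-precision reference value and measures the error in units of the float spacing at the magnitude of the reference. The helper names float_ulp() and ulp_error() are hypothetical. Under this metric a correctly rounded result has an error of at most 0.5 ulp, and a faithfully rounded result is commonly taken to have an error below 1 ulp.

```cpp
#include <cassert>
#include <cmath>

// Spacing of float values (one ulp) at the magnitude of the reference value y.
// Assumed convention: 2**(floor(log2|y|) - 23) for normal floats, 2**-149 below that.
double float_ulp (double y)
{
    y = std::fabs (y);
    if (y < std::ldexp (1.0, -126)) return std::ldexp (1.0, -149); // subnormal range, incl. 0
    int e;
    std::frexp (y, &e);                 // y = m * 2**e with m in [0.5, 1)
    return std::ldexp (1.0, e - 24);    // floor(log2 y) = e - 1, so ulp = 2**(e - 1 - 23)
}

// Error of a float result relative to a higher-precision reference, in float ulps.
double ulp_error (float res, double ref)
{
    assert (std::isfinite (ref));       // ulp error is not defined for infinite or NaN references
    return std::fabs ((double)res - ref) / float_ulp (ref);
}
```

For example, with ref = std::erf(1.05), ulp_error((float)ref, ref) is at most 0.5, while the next float above (float)ref has an error between 0.5 and 1.5 ulps.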

Still, the maximum error drops from 1.14 ulps in CUDA 11 to 1.04 ulps for my_erff(). Across the entire input domain, the number of arguments for which the result is not correctly rounded drops from 364.7M in CUDA 11 to 30.3M with the code below, so average accuracy improves noticeably. Performance is exactly the same as that of erff() in CUDA 11.
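Counting the arguments for which a correctly rounded result is not returned amounts to an exhaustive sweep over float bit patterns. The sketch below is a host-side illustration under two assumptions: (float)std::erf((double)x) stands in for the correctly rounded reference (double-precision erf rounded to float, which is nearly always, though not provably, correctly rounded), and the function under test is callable on the host. The names count_not_correctly_rounded(), bits_to_float() and float_to_bits() are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Host equivalents of __int_as_float() / __float_as_int()
static float bits_to_float (std::uint32_t u) { float f; std::memcpy (&f, &u, sizeof f); return f; }
static std::uint32_t float_to_bits (float f) { std::uint32_t u; std::memcpy (&u, &f, sizeof u); return u; }

// Count inputs whose bit patterns lie in [first, last] for which func() does not
// return the correctly rounded result of erf(). NaN inputs are skipped.
// Assumption: (float)std::erf((double)x) serves as the correctly rounded reference.
std::uint64_t count_not_correctly_rounded (float (*func)(float),
                                           std::uint32_t first, std::uint32_t last)
{
    assert (first <= last);
    std::uint64_t count = 0;
    for (std::uint64_t u = first; u <= last; u++) {  // 64-bit counter: last may be 0xffffffff
        float x = bits_to_float ((std::uint32_t)u);
        if (std::isnan (x)) continue;
        float ref = (float)std::erf ((double)x);
        // compare bit patterns so that a wrong sign of zero also counts as an error
        if (float_to_bits (func (x)) != float_to_bits (ref)) count++;
    }
    return count;
}
```

Passing first = 0 and last = 0xffffffff covers the entire input domain. That is roughly 4.3 billion erf() evaluations, so in practice one would split the range across threads or run the sweep on the GPU.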

[Code below updated 6/29/2022, 12/29/2022, 1/3/2023]

/*
  Copyright (c) 2022-2023, Norbert Juffa
  All rights reserved.

  Redistribution and use in source and binary forms, with or without 
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright 
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

__device__ float my_erff (float a)
{
    const float switchover = 1.0f;
    // approximation on [0,1]; maximum error 0.96922 ulp
    const float c0_lower =  8.05854797e-5f; //  0x1.520000p-14
    const float c1_lower = -8.07280652e-4f; // -0x1.a73f60p-11
    const float c2_lower =  5.19511057e-3f; //  0x1.54777ep-8
    const float c3_lower = -2.68568154e-2f; // -0x1.b805a6p-6
    const float c4_lower =  1.12836286e-1f; //  0x1.ce2d6cp-4
    const float c5_lower = -3.76126260e-1f; // -0x1.81273ep-2
    const float c6_lower =  1.28379166e-1f; //  0x1.06eba8p-3
    // approximation on [1,4]; maximum error 1.03872 ulp
    const float c0_upper =  2.58684158e-5f; //  0x1.b20000p-16
    const float c1_upper =  5.64930320e-4f; //  0x1.282faap-11
    const float c2_upper =  5.66849299e-3f; //  0x1.737d88p-8
    const float c3_upper =  3.51767987e-2f; //  0x1.202b18p-5
    const float c4_upper =  1.54329687e-1f; //  0x1.3c1134p-3
    const float c5_upper = -9.15674686e-1f; // -0x1.d4d350p-1
    const float c6_upper =  6.28459632e-1f; //  0x1.41c576p-1
    float fa, s, t, r, c0, c1, c2, c3, c4, c5, c6;

    fa = fabsf (a);
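    // select approximation for upper or lower interval as appropriate:
    // lower interval, |a| < 1:  erf(a) ~= a + a * p(a*a), with s = a*a, t = a
    // upper interval, |a| >= 1: log2(erfc(|a|)) ~= -|a| - |a| * q(-|a|), with s = t = -|a|
    s  = (fa >= switchover) ? (-fa) : (a * a);
    t  = (fa >= switchover) ? (-fa) : (a);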
    c0 = (fa >= switchover) ? c0_upper : c0_lower;
    c1 = (fa >= switchover) ? c1_upper : c1_lower;
    c2 = (fa >= switchover) ? c2_upper : c2_lower;
    c3 = (fa >= switchover) ? c3_upper : c3_lower;
    c4 = (fa >= switchover) ? c4_upper : c4_lower;
    c5 = (fa >= switchover) ? c5_upper : c5_lower;
    c6 = (fa >= switchover) ? c6_upper : c6_lower;
    // evaluate polynomial approximation for either interval
    r =             c0;
    r = fmaf (r, s, c1);
    r = fmaf (r, s, c2);
    r = fmaf (r, s, c3);
    r = fmaf (r, s, c4);
    r = fmaf (r, s, c5);
    r = fmaf (r, s, c6);
    r = fmaf (r, t, t);
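    // finish approximation for upper interval: erf(|a|) = 1 - erfc(|a|) = 1 - 2**r,
    // where 2**r is computed by the hardware approximation (MUFU.EX2); then transfer
    // the sign of a to the result, since erf() is an odd function
    if (fa >= switchover) {
        asm ("ex2.approx.ftz.f32 %0,%0;" : "+f"(r));
        r = 1.0f - r;
        r = __int_as_float ((__float_as_int(a) & 0x80000000) | __float_as_int(r));
    }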
    return r;
}
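To study the algorithm away from the GPU, the device code can be ported to host C++. The sketch below is an illustrative port, not the author's code: the function name erff_host() is hypothetical, std::exp2() replaces the ex2.approx.ftz.f32 instruction (MUFU.EX2), and memcpy-based bit casts replace __float_as_int() and __int_as_float(). Because host exp2() implementations differ from MUFU.EX2, host results are not guaranteed to match device results bit for bit. The flush-to-zero behavior of .ftz does not affect the final result, because 1 - 2**r rounds to 1.0f whenever 2**r is that small.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical host-side port of my_erff(); std::exp2 stands in for ex2.approx.ftz.f32
float erff_host (float a)
{
    const float switchover = 1.0f;
    // approximation on [0,1]: erf(a) ~= a + a * p(a*a); coefficients c0..c6
    const float cl[7] = { 8.05854797e-5f, -8.07280652e-4f,  5.19511057e-3f,
                         -2.68568154e-2f,  1.12836286e-1f, -3.76126260e-1f,
                          1.28379166e-1f };
    // approximation on [1,4]: log2(erfc(|a|)) ~= -|a| - |a| * q(-|a|); coefficients c0..c6
    const float cu[7] = { 2.58684158e-5f,  5.64930320e-4f,  5.66849299e-3f,
                          3.51767987e-2f,  1.54329687e-1f, -9.15674686e-1f,
                          6.28459632e-1f };
    const float fa = std::fabs (a);
    const bool upper = (fa >= switchover);
    const float s = upper ? -fa : a * a;
    const float t = upper ? -fa : a;
    const float *c = upper ? cu : cl;
    // Horner scheme with fused multiply-add, as in the device code
    float r = c[0];
    for (int i = 1; i < 7; i++) r = std::fma (r, s, c[i]);
    r = std::fma (r, t, t);
    if (upper) {
        r = 1.0f - std::exp2 (r);       // erf(|a|) = 1 - erfc(|a|)
        std::uint32_t ia, ir;           // transfer the sign of a to the result
        std::memcpy (&ia, &a, sizeof ia);
        std::memcpy (&ir, &r, sizeof ir);
        ir = (ia & 0x80000000u) | ir;
        std::memcpy (&r, &ir, sizeof r);
    }
    return r;
}
```

Comparing such a port against a high-precision reference helps separate the error of the polynomial approximations from the error contributed by the exponentiation step.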