Performance tweak for single-precision square root

Since the built-in reciprocal square root instruction of the GPU, MUFU.RSQ, is very accurate, computing a correctly rounded single-precision square root is very simple for non-exceptional cases: multiply the reciprocal square root approximation by the function argument to get a square root approximation, compute the error of that approximate square root with one FMA (fused multiply-add), and finally use the computed error to correct the square root approximation (one more FMA) to yield the correctly rounded square root. This should require four FPU instructions in total in addition to the MUFU.RSQ.

Inexplicably, in CUDA 11.1 (with an sm_75 target, but it seems to be true for all architectures from sm_35 up), five FPU instructions are used. There is a negation that should be folded into an FMA but isn't:

// slowpath or fastpath ?
/*00b0*/                   IADD3 R4, R3, -0xd000000, RZ ;                     /* 0xf300000003047810 */
/*00c0*/                   ISETP.GT.U32.AND P0, PT, R4, 0x727fffff, PT ;      /* 0x727fffff0400780c */
/*00d0*/              @!P0 BRA 0x140 ;                                        /* 0x0000006000008947 */

// slowpath
/*00e0*/                   BMOV.32.CLEAR RZ, B1 ;                             /* 0x0000000001ff7355 */
/*00f0*/                   BSSY B1, 0x130 ;                                   /* 0x0000003000017945 */
/*0100*/                   MOV R8, 0x120 ;                                    /* 0x0000012000087802 */
/*0110*/                   CALL.REL.NOINC 0x230 ;                             /* 0x0000011000007944 */
/*0120*/                   BSYNC B1 ;                                         /* 0x0000000000017941 */
/*0130*/                   BRA 0x1a0 ;                                        /* 0x0000006000007947 */

// fastpath
/*0140*/                   MUFU.RSQ R2, R3 ;                                  /* 0x0000000300027308 */
/*0150*/                   FMUL.FTZ R5, R3, R2 ;                              /* 0x0000000203057220 */
/*0160*/                   FMUL.FTZ R2, R2, 0.5 ;                             /* 0x3f00000002027820 */
/*0170*/                   FADD.FTZ R4, -R5, -RZ ;                            /* 0x800000ff05047221 */   <<<<<<<<<<<<<
/*0180*/                   FFMA R4, R5, R4, R3 ;                              /* 0x0000000405047223 */
/*0190*/                   FFMA R5, R4, R2, R5 ;                              /* 0x0000000204057223 */
/*01a0*/                   BSYNC B0 ;                                         /* 0x0000000000007941 */

I built my own implementation of sqrtf() to confirm that only four instructions are needed in addition to MUFU.RSQ and my code passes an exhaustive test with both ftz=true and ftz=false. The resulting speedup at app level is modest, only a few percent, but sometimes every percent counts.

/*
  Copyright (c) 2021, Norbert Juffa
  All rights reserved.

  Redistribution and use in source and binary forms, with or without 
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright 
     notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

__host__ __device__ uint32_t float_as_uint32 (float);
__host__ __device__ float uint32_as_float (uint32_t);
__host__ __device__ __noinline__ float sqrtf_slowpath (float);

__host__ __device__ __forceinline__ float my_sqrtf (float arg)
{
    /* Correctly rounded single-precision square root.

       Fast path (arguments in [0x1.0p-102, 0x1.0p+128)) uses only four FPU
       instructions in addition to the hardware reciprocal square root
       approximation: one multiply to form the sqrt estimate, one multiply
       to form half the rsqrt, one FMA to compute the residual, and one FMA
       to apply the correction. Negative, zero, subnormal, infinite, and
       NaN arguments fall through to the out-of-line slow path.
    */
    const uint32_t upper = float_as_uint32(3.402823470e+38f); // 0x1.fffffep+127
    const uint32_t lower = float_as_uint32(1.972152262e-31f); // 0x1.000000p-102
    float rsq, sqt, err;

    /* use fastpath computation if argument in [0x1.0p-102, 0x1.0p+128).
       Single unsigned range check: after subtracting 'lower', any bit
       pattern outside the interval (negatives, zero, subnormals, Inf, NaN)
       wraps around to a value greater than 'upper - lower', selecting the
       slow path. This compiles to one IADD3 plus one ISETP. */
    if ((uint32_t(float_as_uint32 (arg)) - lower) <= (upper - lower)) {

#if (__CUDA_ARCH__ >= 300)
        /* hardware reciprocal square root approximation (MUFU.RSQ) */
        asm ("rsqrt.approx.ftz.f32 %0,%1; \n\t" : "=f"(rsq) : "f"(arg));
#else // __CUDA_ARCH__
        /* generate low-accuracy approximation to rsqrt(arg) */
        _mm_store_ss (&rsq, _mm_rsqrt_ss (_mm_set_ss (arg)));
        /* apply Newton-Raphson iteration with quadratic convergence */
        rsq = fmaf (fmaf (-0.5f * arg * rsq, rsq, 0.5f), rsq, rsq);
#endif // __CUDA_ARCH__
        
        /* compute sqrt from rsqrt, round result to nearest or even */
        sqt = rsq * arg;                   /* approximate sqrt(arg) */
        err = fmaf (sqt, -sqt, arg);       /* residual arg - sqt*sqt; negation folds into the FMA */
        sqt = fmaf (0.5f * rsq, err, sqt); /* final rounding correction */
    } else {
        sqt = sqrtf_slowpath (arg);
    }
    return sqt;
}

__host__ __device__ uint32_t float_as_uint32 (float a)
{
    /* Reinterpret the bit pattern of a float as a 32-bit unsigned integer. */
#if (__CUDA_ARCH__ >= 300)
    /* device code: hardware bit-cast intrinsic */
    return (uint32_t)__float_as_int (a);
#else // __CUDA_ARCH__
    /* host code: memcpy is the portable, strict-aliasing-safe bit cast */
    uint32_t result;
    memcpy (&result, &a, sizeof result);
    return result;
#endif // __CUDA_ARCH__
}

__host__ __device__ float uint32_as_float (uint32_t a)
{
    /* Reinterpret the bit pattern of a 32-bit unsigned integer as a float. */
#if (__CUDA_ARCH__ >= 300)
    /* device code: hardware bit-cast intrinsic */
    return __int_as_float ((int)a);
#else // __CUDA_ARCH__
    /* host code: memcpy is the portable, strict-aliasing-safe bit cast */
    float result;
    memcpy (&result, &a, sizeof result);
    return result;
#endif // __CUDA_ARCH__
}

__host__ __device__ __noinline__ float sqrtf_slowpath (float arg)
{
    /* Out-of-line handling of arguments the fast path rejects: negative
       values, zero, arguments below 0x1.0p-102 (including subnormals),
       infinity, and NaN. Marked __noinline__ so the common fast path in
       my_sqrtf() stays compact. */
    const float FP32_INFINITY = uint32_as_float (0x7f800000);
#if (__CUDA_ARCH__ >= 300)
    const float FP32_QNAN = uint32_as_float (0x7fffffff); /* canonical NaN */
#else // __CUDA_ARCH__
    const float FP32_QNAN = uint32_as_float (0xffc00000); /* QNaN INDEFINITE */
#endif // __CUDA_ARCH__
    /* scale_in * scale_out^2 == 1, so sqrt(arg*scale_in)*scale_out == sqrt(arg) */
    const float scale_in  = 67108864.0f;     // 0x1.0p+26
    const float scale_out = 1.220703125e-4f; // 0x1.0p-13
    float rsq, err, sqt;

    if (arg < 0.0f) {
        /* square root of a negative number: return a quiet NaN */
        return FP32_QNAN;
    } else if ((arg == 0.0f) || !(fabsf (arg) < FP32_INFINITY)) { /* Inf, NaN */
        /* arg + arg preserves +/-0 and +Inf, and quiets a signaling NaN */
        return arg + arg;
    } else {
        /* scale subnormal arguments towards unity; the even power of two
           keeps the exact square root representable via counter-scaling */
        arg = arg * scale_in;

#if (__CUDA_ARCH__ >= 300)
        /* hardware reciprocal square root approximation (MUFU.RSQ) */
        asm ("rsqrt.approx.ftz.f32 %0,%1; \n\t" : "=f"(rsq) : "f"(arg));
#else // __CUDA_ARCH__
        /* generate low-accuracy approximation to rsqrt(arg) */
        _mm_store_ss (&rsq, _mm_rsqrt_ss (_mm_set_ss (arg)));
        /* apply Newton-Raphson iteration with quadratic convergence */
        rsq = fmaf (fmaf (-0.5f * arg * rsq, rsq, 0.5f), rsq, rsq);
#endif  // __CUDA_ARCH__
        
        /* compute sqrt from rsqrt, round to nearest or even */
        sqt = rsq * arg;                   /* approximate sqrt of scaled arg */
        err = fmaf (sqt, -sqt, arg);       /* residual arg - sqt*sqt via one FMA */
        sqt = fmaf (0.5f * rsq, err, sqt); /* final rounding correction */

        /* compensate scaling of argument by counter-scaling the result:
           sqrt(arg * 0x1.0p+26) * 0x1.0p-13 == sqrt(arg) */
        sqt = sqt * scale_out;
        
        return sqt;
    }
}

The disassembly from cuobjdump --dump-sass confirms the tighter code generated:

// fastpath or slowpath?
/*00b0*/                   IADD3 R4, R7, -0xc800000, RZ ;                 /* 0xf380000007047810 */
/*00c0*/                   ISETP.GT.U32.AND P0, PT, R4, 0x72ffffff, PT ;  /* 0x72ffffff0400780c */

 // fastpath
/*00d0*/              @!P0 MUFU.RSQ R4, R7 ;                              /* 0x0000000700048308 */
/*00e0*/              @!P0 FMUL.FTZ R5, R7, R4 ;                          /* 0x0000000407058220 */
/*00f0*/              @!P0 FMUL.FTZ R6, R4, 0.5 ;                         /* 0x3f00000004068820 */
/*0100*/              @!P0 FFMA.FTZ R4, R5, -R5, R7 ;                     /* 0x8000000505048223 */
/*0110*/              @!P0 FFMA.FTZ R5, R4, R6, R5 ;                      /* 0x0000000604058223 */
/*0120*/              @!P0 BRA 0x180 ;                                    /* 0x0000005000008947 */

// slowpath
/*0130*/                   BMOV.32.CLEAR RZ, B1 ;                         /* 0x0000000001ff7355 */
/*0140*/                   BSSY B1, 0x180 ;                               /* 0x0000003000017945 */
/*0150*/                   MOV R6, 0x170 ;                                /* 0x0000017000067802 */
/*0160*/                   CALL.REL.NOINC 0x210 ;                         /* 0x000000a000007944 */
/*0170*/                   BSYNC B1 ;                                     /* 0x0000000000017941 */