Implementing faster CUDA intrinsics for specific power functions

Hello, I want to use a fast power function to calculate x^p. CUDA provides an intrinsic, __powf(x, p), but it is still slower than I expected. I know the power function is slow because it has to handle many corner cases. In my case, however, x is a float in [0, 1] and p is a fixed, large positive float, for example 15 or 25. Can I write a custom intrinsic that calculates x^25 (for example)? Since I don’t know how __powf is implemented (I know it uses __log2f and __exp2f), it is hard for me to modify the implementation to fit my needs. Can anyone help me? Thank you so much!

Note that I don’t care about high accuracy. In my case, it is enough for the maximum absolute error to be less than 1e-4.

The execution time of the C++ standard math functions pow() and powf() is about twice the combined execution time of the standard math functions exp()/expf() and log()/logf(). This rule of thumb applies to all platforms I am familiar with, including CUDA.

Why twice the time? The additional time is needed to (1) compute the logarithm to extended precision to guarantee an accurate pow()/powf() result and (2) deal with the many special cases prescribed by the language standard.

The device-function intrinsic __powf() is not encumbered by these requirements: it neither computes the logarithm to extended precision nor makes any effort to get special cases right. If you compile with -ftz=true (or -use_fast_math, which includes -ftz=true), the implementation comprises three inlined instructions (here is __powf() disassembled from sm_70 code):

        /*0020*/                   MUFU.LG2 R0, c[0x0][0x160] ;
        /*0040*/                   FMUL.FTZ R0, R0, c[0x0][0x164] ;
        /*0050*/                   MUFU.EX2 R0, R0 ;
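At the source level, these three instructions correspond to the identity pow(x, y) = exp2(y * log2(x)): a base-2 logarithm, one multiply, a base-2 exponential. The following is a minimal sketch of that mapping, not NVIDIA's implementation. The name `fast_powf` is hypothetical; on the device it uses the documented `__log2f()` intrinsic and `exp2f()`, and on the host it falls back to `std::log2`/`std::exp2` so it can be compiled and checked without a GPU. Whether the device path compiles to exactly MUFU.LG2 / FMUL / MUFU.EX2 depends on flags such as -ftz=true, so check the SASS.

```cpp
#include <cmath>

#ifndef __CUDACC__      // allow compiling as plain C++ on the host
#define __host__
#define __device__
#endif

// Hypothetical helper: x**y computed as exp2(y * log2(x)), meant for x >= 0.
// Like __powf(), it makes no attempt to handle special cases correctly.
__host__ __device__ inline float fast_powf(float x, float y)
{
#ifdef __CUDA_ARCH__
    float l = __log2f(x);       // device: approximate base-2 logarithm
    return exp2f(y * l);        // device: base-2 exponential
#else
    float l = std::log2(x);     // host: library base-2 logarithm
    return std::exp2(y * l);    // host: library base-2 exponential
#endif
}
```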

I don’t see how one can get any faster code for a power function that takes two float arguments. I find it hard to believe that this would be a bottleneck in anything but a trivial test app. If you have profiler output to the contrary please share it.

Your question suggests that your use case may involve pow with integer exponents. A call to pow (double, int) or pow (float, int) does not go through the standard math functions but uses a square-and-multiply algorithm that scans the exponent one bit at a time. This may result in code that is optimal for very small exponents, but for larger exponents it may be slower than invoking __powf (float, float).
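The square-and-multiply scheme can be sketched as follows. This illustrates the technique only; it is not the actual code behind the pow (float, int) overload, and `powf_sqmul` is a hypothetical name. The loop scans the exponent from the least significant bit, squaring the base at each step and multiplying it into the result whenever the current bit is set, so it runs once per bit of n.

```cpp
#ifndef __CUDACC__      // allow compiling as plain C++ on the host
#define __host__
#define __device__
#endif

// Hypothetical helper: a**n for an unsigned integer n using right-to-left
// binary exponentiation (square-and-multiply).
__host__ __device__ inline float powf_sqmul(float a, unsigned int n)
{
    float result = 1.0f;
    while (n != 0) {
        if (n & 1u) {          // current bit set: multiply in a^(2^k)
            result *= a;
        }
        n >>= 1;               // advance to the next exponent bit
        if (n != 0) {
            a *= a;            // a now holds a^(2^(k+1)) of the original base
        }
    }
    return result;
}
```

With a run-time n, the loop and its branches cost extra instructions on top of the multiplications; with a compile-time constant n, the compiler may be able to unroll it completely.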

Thank you for your reply. Yes, I agree that the log-exp paradigm can not be faster than simply calling __powf(), but I wonder if there are other approaches. Since for my case p is fixed in advance, the function is thus unary. I would like to learn how other unary function intrinsics (for example, __expf and __logf) are implemented, and derive a similar approximation to power function. I has tried Taylor expansion, but to reach the accuracy the order needed is very large, which leads to a slow calculation.

I really need an extremely fast power function. In my case, the power operation takes at least 90% of the time. I use the power function to calculate the p-norm of millions of high-dimensional vectors.

In my case, the power operation takes at least 90% of the time

You mean __powf(), in code compiled with -ftz=true, takes 90% of the time? I find that hard to believe. There have to be at least a few other instructions, maybe even a load and a store somewhere.

It is very hard to beat the three instruction sequence I showed.

If the exponent is a compile-time constant, the compiler’s constant propagation optimization may help. Take a look at the generated machine code with cuobjdump --dump-sass to check whether that is happening. You could implement specialized versions of pow by template metaprogramming that expand into the minimum number of multiplications. That is obviously only advantageous for small integer exponents.
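One way to do such a template expansion is sketched below, assuming C++17 for `if constexpr`. The name `ipow` is hypothetical. It applies the binary method at compile time, which comes close to, but does not always reach, the minimum number of multiplications: for N = 15 it uses six, while the chain a^2, a^5, a^15 used in powf_int below needs five.

```cpp
#ifndef __CUDACC__      // allow compiling as plain C++ on the host
#define __host__
#define __device__
#endif

// Hypothetical helper: a**N for a compile-time constant N, expanded by the
// binary method: a^N = (a^(N/2))^2, times one extra a when N is odd.
template <unsigned N>
__host__ __device__ inline float ipow(float a)
{
    if constexpr (N == 0) {
        return 1.0f;
    } else if constexpr (N == 1) {
        return a;
    } else {
        float h = ipow<N / 2>(a);      // recursive instantiation for N/2
        if constexpr (N % 2 == 0) {
            return h * h;
        } else {
            return h * h * a;
        }
    }
}
```

A call such as `ipow<25>(x)` expands into six multiplications with no loop or branch; as above, cuobjdump --dump-sass is the way to confirm what the compiler actually emitted.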

I would like to learn how other unary function intrinsics (for example, __expf and __logf) are implemented

Compile a minimal kernel using the function under test in a release build, then use cuobjdump --dump-sass to examine the generated machine code, and all will be revealed. __expf() and __logf() in particular map to the MUFU.EX2 and MUFU.LG2 machine instructions, with additional instructions for base conversion and (optionally) denormal support.

I replaced the p-norm function (which calculates (\sum_i (x_i)^p)^(1/p)) with a summation function (which calculates \sum_i x_i), and the latter takes less than 20% of the time. I use __powf(), but I did not add any other compiler options (-ftz or -use_fast_math). Will -ftz have an effect on the __powf() intrinsic?
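For a fixed integer exponent, the p-norm above can be written so that the per-element power uses multiplications only. This is a minimal sketch under stated assumptions: p = 25 (one of the example exponents from the question), every x_i in [0, 1], and the hypothetical names `pow25` and `pnorm25`; `pow25` reuses the chain from the b == 25 case of powf_int. Only the inner x_i^25 is evaluated per element, while the 1/25 root is a single general power per vector. Whether six multiplications actually beat __powf on a given GPU has to be measured.

```cpp
#include <math.h>

#ifndef __CUDACC__      // allow compiling as plain C++ on the host
#define __host__
#define __device__
#endif

// x**25 with six multiplications: x^2, x^5, x^10, then x^5 * x^10 * x^10.
__host__ __device__ inline float pow25(float x)
{
    float s = x * x;        // x^2
    float t = x * s * s;    // x^5
    float u = t * t;        // x^10
    return t * u * u;       // x^25
}

// Hypothetical helper: (sum_i x_i^25)^(1/25) for an n-element vector,
// assuming every x_i is in [0, 1]. Sequential, one vector per call.
__host__ __device__ inline float pnorm25(const float *x, int n)
{
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += pow25(x[i]);             // per-element work: multiplies only
    }
    return powf(sum, 1.0f / 25.0f);     // one general power per vector
}
```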

Also, do you mean MUFU.LG2 and MUFU.EX2 are single hardware instructions? I thought MUFU.LG2 and MUFU.EX2 consisted of a series of additions and multiplications.

Yes. -ftz=true will eliminate the portion of the code that is needed to handle denormal operands. I goofed in my earlier post: since mathematically pow (x,y) = exp2 (y * log2 (x)), constant propagation cannot help when the exponent is a compile-time constant, because log2 (x) still has to be computed for every x.

Yes. The hardware performs quadratic interpolation in internal ROM-based tables for these. MUFU stands for MUlti Function Unit. See:

S. Oberman and M. Siu, “A High-Performance Area-Efficient Multifunction Interpolator,” in 17th IEEE Symposium on Computer Arithmetic, June 2005, pp. 272-279.
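To make the table-plus-interpolation idea concrete, here is a purely software illustration for 2^x. It is host code only, and it says nothing about the actual table size, coefficient format, or precision of the MUFU hardware, which the paper above describes. The names `Exp2Table` and `exp2_interp` and the choice of 64 segments are assumptions. Each segment of [0, 1) stores three coefficients of a quadratic fitted through 2^f at the segment's start, midpoint, and end; the integer part of x is applied as an exact exponent adjustment.

```cpp
#include <math.h>

// Illustration only: piecewise-quadratic table for 2^f, f in [0, 1).
struct Exp2Table {
    static constexpr int kSegments = 64;       // number of table entries
    float c0[kSegments], c1[kSegments], c2[kSegments];

    Exp2Table() {
        const double h = 1.0 / kSegments;      // segment width
        for (int i = 0; i < kSegments; ++i) {
            // Fit c0 + c1*d + c2*d^2 through 2^f at d = 0, h/2 and h,
            // where d is the offset from the start of segment i.
            double y0 = exp2(i * h);
            double y1 = exp2(i * h + 0.5 * h);
            double y2 = exp2(i * h + h);
            c0[i] = (float)y0;
            c1[i] = (float)((4.0 * y1 - 3.0 * y0 - y2) / h);
            c2[i] = (float)(2.0 * (y2 - 2.0 * y1 + y0) / (h * h));
        }
    }
};

// Hypothetical helper: 2^x = 2^n * 2^f with n = floor(x), f in [0, 1).
float exp2_interp(const Exp2Table &t, float x)
{
    float n = floorf(x);
    float f = x - n;                                  // fractional part
    int i = (int)(f * Exp2Table::kSegments);          // segment index
    if (i >= Exp2Table::kSegments) {
        i = Exp2Table::kSegments - 1;                 // f rounded up to 1.0f
    }
    float d = f - (float)i / Exp2Table::kSegments;    // offset in segment
    float p = t.c0[i] + d * (t.c1[i] + d * t.c2[i]);  // quadratic, Horner form
    return ldexpf(p, (int)n);                         // scale by 2^n exactly
}
```

With 64 segments, the interpolation error of the quadratic is far below float rounding error, which is why a small table plus a few multiply-adds suffices for single-precision results.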

If all your exponents are small non-negative integers known at compile time, you could use something like the following, which relies on constant propagation and dead-code elimination:

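    /* Compute a**b for integer b in [0,25] with a short chain of multiplies.
       Intended to be called with a compile-time constant b, so that after
       inlining, constant propagation and dead-code elimination reduce the
       function to the few multiplications of a single branch. */
    __device__ __forceinline__ float powf_int (float a, int b)
    {
        float r, s, t, u, v;
        if (b == 0) {
            r = 1.0f;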
        } else if (b == 1) {
            r = a;
        } else if (b == 2) {
            r = a * a;
        } else if (b == 3) {
            r = a * a * a;
        } else if (b == 4) {
            s = a * a;
            r = s * s;
        } else if (b == 5) {
            s = a * a;
            r = a * s * s;
        } else if (b == 6) {
            s = a * a;
            r = s * s * s;
        } else if (b == 7) {
            s = a * a;
            t = s * s;
            r = a * s * t;
        } else if (b == 8) {
            s = a * a;
            t = s * s;
            r = t * t;
        } else if (b == 9) {
            t = a * a * a;
            r = t * t * t;
        } else if (b == 10) {
            s = a * a;
            t = s * s;
            r = s * t * t;
        } else if (b == 11) {
            s = a * a;
            t = s * s;
            r = a * s * t * t;
        } else if (b == 12) {
            s = a * a;
            t = s * s;
            r = t * t * t;
        } else if (b == 13) {
            s = a * a;
            t = s * s;
            r = a * t * t * t;
        } else if (b == 14) {
            s = a * a;
            t = s * s;
            r = s * t * t * t;
        } else if (b == 15) {
            s = a * a;
            t = a * s * s;
            r = t * t * t;
        } else if (b == 16) {
            s = a * a;
            t = s * s;
            u = t * t;
            r = u * u;
        } else if (b == 17) {
            s = a * a;
            t = s * s;
            u = t * t;
            r = a * u * u;
        } else if (b == 18) {
            t = a * a * a;
            u = t * t * t;
            r = u * u;
        } else if (b == 19) {
            t = a * a * a;
            u = t * t * t;
            r = a * u * u;
        } else if (b == 20) {
            s = a * a;
            t = s * s;
            u = t * t;
            r = t * u * u;
        } else if (b == 21) {
            s = a * a;
            t = s * s;
            u = t * t;
            r = a * t * u * u;
        } else if (b == 22) {
            s = a * a;
            t = s * s;
            u = t * t;
            r = s * t * u * u;
        } else if (b == 23) {
            s = a * a;
            t = a * s;
            u = s * t;
            v = u * u;
            r = t * v * v;
        } else if (b == 24) {
            s = a * a;
            t = s * s;
            u = t * t;
            r = u * u * u;
        } else if (b == 25) {
            s = a * a;
            t = a * s * s;
            u = t * t;
            r = t * u * u;
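        } else {
            /* b outside [0,25]: fall back to the general-purpose intrinsic */
            r = __powf (a, (float)b);
        }
        return r;
    }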