Kernel functions as parameters?

Is it possible to pass a kernel function as a parameter to a kernel call? I envisage something along the lines of the following simple example:

Wrapper Function:

// Host-side wrapper: picks a launch configuration for numPoints elements and
// launches test_k on testArray.
//   testArray - pointer handed straight to the kernel (presumably device
//               memory; confirm against the caller)
//   numPoints - number of elements the kernel should process
// NOTE: increment_k is a __device__ function; passing it as a kernel argument
// from host code like this is the hypothetical usage the question is about.
void test( int *testArray,
		  unsigned int numPoints)
{
	unsigned int gridBlocks, blockThreads;

	// Launch configuration comes from computeGridSize (helper not shown here;
	// presumably it fills in both outputs for a block size of 256).
	computeGridSize(numPoints, 256, gridBlocks, blockThreads);

	// Run the kernel, then report any CUDA error under the "Test Kernel" label.
	test_k<<< gridBlocks, blockThreads >>>( testArray, increment_k, numPoints );
	checkCUDAError("Test Kernel");
}

That would then allow me to have a single kernel that applies the appropriate device function depending on what I pass in, e.g.:

// Device function: returns its argument plus one.
__device__ int
increment_k( int t )
{
	const int next = t + 1;
	return next;
}

// Kernel: each thread whose index is below numPoints replaces one element of
// testArray with f(element).
//   testArray - array updated in place
//   f         - function applied to each element
//   numPoints - number of valid elements; threads at or beyond it do nothing
__global__ void
test_k( int *testArray, int (*f)(int), uint numPoints )
{
	// Flat thread index across the grid (x dimension only).
	const unsigned int gid = __umul24(blockIdx.x, blockDim.x) + threadIdx.x;

	if (gid < numPoints)
	{
		testArray[gid] = f(testArray[gid]);
	}
}

I am basically trying to apply the same kernel to different situations, with a way of changing how some of the variables are calculated according to my needs. For example, if the test_k kernel uses the area of a shape as part of its calculation, I would like to be able to call different area functions so that the kernel can be applied to different shapes. Thanks in advance.

General note: Invoking functions via function pointers in device code requires compute capability 2.x.

Only `__global__` functions and their addresses are visible inside host code. `__device__` functions, and thus their addresses, are not visible inside host code. Therefore one cannot pass function pointers to `__device__` functions to a kernel call (which is issued from the host portion of the code).

However, one can pass the information needed to select the desired `__device__` function to the kernel. In the following example, a kernel finds either the minimum or the maximum element in an array of floats, with the argument findmin specifying which operation is desired. The kernel then selects a pointer to the appropriate selection function (either minimum or maximum) and passes that to a `__device__` function minmax() that does all the work.

#include <stdio.h>

#include <stdlib.h>

// Number of thread blocks launched for minmax_kernel; also the number of
// per-block partial results that findExtremum allocates and reduces on the host.
#define BLOCK_COUNT  240

// Threads per block used for the launch. minmax() sizes its shared scratch
// array with this value and halves it repeatedly during the in-block reduction,
// so it should be a power of two and must match the block size at launch.
#define THREAD_COUNT 128

// Evaluates a CUDA runtime call exactly once. If the returned status is not
// cudaSuccess, prints the file, line and CUDA error string to stderr and
// terminates the process with EXIT_FAILURE.
// The argument is parenthesised and the local status variable has a
// macro-specific name, so an argument expression that mentions a caller
// variable called `err` is not captured by the macro's own local.
#define CUDA_SAFE_CALL(call)                                          \
do {                                                                  \
    cudaError_t safeCallErr_ = (call);                                \
    if (cudaSuccess != safeCallErr_) {                                \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__,                                  \
                 cudaGetErrorString(safeCallErr_) );                  \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

// Macro to catch CUDA errors in kernel launches. Use it directly after a
// <<<...>>> launch. It performs two checks and, on either failure, prints the
// file, line and CUDA error string to stderr and exits with EXIT_FAILURE:
//   1. cudaGetLastError()      - errors reported at launch time
//   2. cudaDeviceSynchronize() - errors raised while the kernel was running
// cudaDeviceSynchronize() replaces the deprecated cudaThreadSynchronize().
// The synchronisation blocks the host until the device is idle, so this macro
// is intended for debugging / coarse-grained checks rather than hot paths.
#define CHECK_LAUNCH_ERROR()                                          \
do {                                                                  \
    /* Check synchronous errors, i.e. pre-launch */                   \
    cudaError_t launchErr_ = cudaGetLastError();                      \
    if (cudaSuccess != launchErr_) {                                  \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__,                                  \
                 cudaGetErrorString(launchErr_) );                    \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
    /* Check asynchronous errors, i.e. kernel failed (ULF) */         \
    launchErr_ = cudaDeviceSynchronize();                             \
    if (cudaSuccess != launchErr_) {                                  \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__,                                  \
                 cudaGetErrorString(launchErr_) );                    \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

// Pointer to a binary float function (two floats in, one float out); used to
// pass the selection function (minimum or maximum) around.
typedef float (*pf)(float a, float b);

// Device function: returns the smaller of a and b, as computed by fminf.
__device__ float minimum (float a, float b)
{
    const float lower = fminf(a, b);
    return lower;
}

// Device function: returns the larger of a and b, as computed by fmaxf.
__device__ float maximum (float a, float b)
{
    const float upper = fmaxf(a, b);
    return upper;
}

// Device-side table of selection functions, indexed by the kernel's findmin
// argument: entry 0 is maximum, entry 1 is minimum.
__device__ pf func_d[2] = { maximum, minimum };

// Shared-memory scratch array with one slot per thread (THREAD_COUNT entries);
// minmax() uses it to combine the threads' partial results within a block.
__shared__ float partExtr[THREAD_COUNT];

// Device function that reduces x[0..n-1] with the binary function func and
// writes one partial result per block to res[blockIdx.x].
//   x    - input values; x[0] is read unconditionally, so n must be >= 1
//   res  - output array, one element per block
//   n    - number of input elements
//   func - combining function (e.g. minimum or maximum)
// Uses the shared array partExtr and the constant THREAD_COUNT for both the
// indexing and the reduction, so it expects to be run with THREAD_COUNT
// threads per block.
__device__ void minmax(float *x, float *res, int n, pf func)
{
    const int lane = threadIdx.x;
    // Distance between consecutive elements handled by the same thread.
    const unsigned int step = gridDim.x*THREAD_COUNT;

    // Phase 1: each thread folds its own subset of the input, seeded with x[0].
    float best = x[0];
    int idx = THREAD_COUNT*blockIdx.x+lane;
    while (idx < n) {
        best = func (best, x[idx]);
        idx += step;
    }
    partExtr[lane] = best;

    // Phase 2: pairwise reduction in shared memory, halving the number of
    // active threads each round. The barrier at the top of every round makes
    // the previous round's writes visible before they are read.
    int half = THREAD_COUNT >> 1;
    while (half > 0) {
        __syncthreads();
        if (lane < half) {
            partExtr[lane] = func (partExtr[lane], partExtr[lane+half]);
        }
        half >>= 1;
    }

    // Thread 0 holds the block's result in slot 0 and publishes it.
    if (lane == 0) {
        res[blockIdx.x] = partExtr[0];
    }
}

// Kernel entry point: looks up the selection function in func_d using findmin
// as the index (0 -> maximum, 1 -> minimum) and hands the work to minmax().
//   x, res, n - forwarded unchanged to minmax()
//   findmin   - index into the two-entry table func_d; the caller is expected
//               to pass 0 or 1
__global__ void minmax_kernel(float *x, float *res, int n, int findmin)
{
    const pf select = func_d[findmin];
    minmax (x, res, n, select);
}

// Finds the minimum or maximum of a host array using the GPU.
//   x       - host array of n floats
//   n       - number of elements; if n < 1 the function returns NaN
//   findmin - zero selects the maximum, any non-zero value selects the minimum
// Returns the selected extremum of x[0..n-1].
// The input is copied to the device, minmax_kernel produces one partial result
// per block, and the BLOCK_COUNT partial results are combined on the host.
// Any CUDA failure or host allocation failure terminates the process.
float findExtremum (float *x, int n, int findmin)
{
    // Host-side counterpart of the device table: 0 -> max, 1 -> min.
    pf func_h[2] = { fmaxf, fminf };
    // Normalise the flag to 0/1 once so that the kernel argument and the host
    // table lookup always agree and never index outside the two-entry tables.
    const int sel = !!findmin;
    float *res_d;
    float *res_h;
    float *x_d;
    float r;

    if (n < 1) return sqrtf(-1.0f); // NaN

    CUDA_SAFE_CALL (cudaMalloc ((void**)&res_d, BLOCK_COUNT*sizeof(res_d[0])));
    CUDA_SAFE_CALL (cudaMalloc ((void**)&x_d, n * sizeof(x_d[0])));
    CUDA_SAFE_CALL (cudaMemcpy (x_d, x, n * sizeof(x_d[0]),
                                cudaMemcpyHostToDevice));

    minmax_kernel<<<BLOCK_COUNT,THREAD_COUNT>>>(x_d, res_d, n, sel);
    CHECK_LAUNCH_ERROR();

    res_h = (float *)malloc (BLOCK_COUNT * sizeof(res_h[0]));
    if (!res_h) {
        fprintf (stderr, "res_h allocation failed\n");
        exit (EXIT_FAILURE);
    }
    CUDA_SAFE_CALL (cudaMemcpy (res_h, res_d, BLOCK_COUNT * sizeof(res_d[0]),
                                cudaMemcpyDeviceToHost));
    CUDA_SAFE_CALL (cudaFree (res_d));
    CUDA_SAFE_CALL (cudaFree (x_d));

    // Final reduction over the per-block partial results.
    r = res_h[0];
    for (int i = 1; i < BLOCK_COUNT; i++) r = func_h[sel](r, res_h[i]);
    free (res_h);
    return r;
}

// Number of sample values in the demo input.
#define ELEM_COUNT 8

// Demo driver: runs findExtremum twice on a fixed sample array (first for the
// minimum, then for the maximum) and prints both results.
int main (void)
{
    float samples[ELEM_COUNT] = {-1.3f, 2.4f, 3.5f, -2.3f, 4.5f, 0.4f, -5.3f, -1.6f};

    const float lowest  = findExtremum (samples, ELEM_COUNT, 1);
    const float highest = findExtremum (samples, ELEM_COUNT, 0);

    printf ("min=% 13.6e  max=% 13.6e\n", lowest, highest);
    return EXIT_SUCCESS;
}

Thank you for your reply. I have implemented the following to see if I can get your method to work:

// Pointer to a unary int function (one int in, one int out); the element type
// of the device function table func_d.
typedef int (*pf)(int a);

// Device function: returns t + 1.
__device__ int
increment_k( int t )
{
	const int bumped = t + 1;
	return bumped;
}

// Device function: returns t - 1.
__device__ int
decrement_k( int t )
{
	const int lowered = t - 1;
	return lowered;
}

// Device-side table of the selectable functions, indexed by the funcId kernel
// argument: entry 0 is increment_k, entry 1 is decrement_k.
__device__ pf func_d[2] = { increment_k, decrement_k };

// Kernel: each thread whose index is below numPoints replaces one element of
// testArray with the result of the function selected by funcId.
//   testArray - array updated in place
//   funcId    - index into the two-entry table func_d (0 or 1)
//   numPoints - number of valid elements; threads at or beyond it do nothing
__global__ void
test_k( int *testArray, int funcId, uint numPoints )
{
	// Flat thread index across the grid (x dimension only).
	const unsigned int gid = __umul24(blockIdx.x, blockDim.x) + threadIdx.x;

	if (gid < numPoints)
	{
		// Fetch the function pointer from the device table, then call through it.
		const pf selected = func_d[funcId];
		testArray[gid] = selected(testArray[gid]);
	}
}

// Host-side wrapper: picks a launch configuration for numPoints elements and
// launches test_k with the given function selector.
//   testArray - pointer handed straight to the kernel (presumably device
//               memory; confirm against the caller)
//   funcId    - selector forwarded to the kernel
//   numPoints - number of elements the kernel should process
void test( int *testArray,
		  int funcId,
		  unsigned int numPoints)
{
	unsigned int gridBlocks, blockThreads;

	// Launch configuration comes from computeGridSize (helper not shown here;
	// presumably it fills in both outputs for a block size of 256).
	computeGridSize(numPoints, 256, gridBlocks, blockThreads);

	// Run the kernel, then report any CUDA error under the "Test Kernel" label.
	test_k<<< gridBlocks, blockThreads >>>( testArray, funcId, numPoints );
	checkCUDAError("Test Kernel");
}

// Test driver: fills a 10-element host array with 0..9, copies it to the
// device, runs test() with the INCREMENT selector, copies the result back and
// prints every element.
// Returns 0 on success, 1 if the host buffer cannot be allocated.
int main (void)
{
	const int numPoints = 10;

	// malloc returns void*, which needs an explicit cast in C++/CUDA code.
	int *testArray = (int *)malloc(numPoints*sizeof(int));
	if (!testArray)
	{
		fprintf(stderr, "testArray allocation failed\n");
		return 1;
	}

	// NOTE(review): allocateArray is not shown here; this only works if it
	// takes the pointer by reference. Confirm its signature.
	int *d_test = NULL;
	allocateArray(d_test, numPoints*sizeof(int));//in device memory

	// Signed index to match numPoints and the %i format used below.
	for(int i = 0; i < numPoints; i++)
	{
		testArray[i] = i;
	}

	copyArrayToDevice(d_test, testArray, numPoints*sizeof(int));
	test(d_test, INCREMENT, numPoints);
	copyArrayFromDevice(testArray, d_test, numPoints*sizeof(int));

	for(int i = 0; i < numPoints; i++)
	{
		printf("test %i = %i\n", i, testArray[i]);
	}

	// NOTE(review): d_test is not released because the matching deallocation
	// helper for allocateArray is not visible here.
	free(testArray);
	return 0;
}

I’m using the GeForce GTX 460 2GB card which has compute capability 2.1 but I keep getting the following error:

“Error: Function pointers and function template parameters are not supported in sm_1x.”

Does this mean it won’t work?

Quoting myself:

General note: Invoking functions via function pointers in device code requires compute capability 2.x.

Prior to compute capability 2.0 the hardware did not support function calls through pointers. If your GPU is in fact Fermi-based, add -arch=sm_20 or an equivalent -gencode flag to the nvcc commandline. This is necessary because the compiler defaults to building for sm_10.

I had the same problem as sunburntfish:

“Error: Function pointers and Function template parameters are not supported in sm_1x.”, although I am using a gtx 460 with sm 2.1.

I realized that I needed to change some properties, so I went to the project properties and, under CUDA Runtime API → GPU, changed GPU Architecture (1) to sm_20. (There is another GTX 460 in the box for debugging, and GPU Architecture (2) was already set to sm_20, but it still gave the error.) Now the compiler is happy, but the function pointer no longer seems to work: the function it points to appears to be empty.

I will see what’s wrong, but if it’s because of the property change I made, please let me know.

Thank you, this resolved my problem and it works fine now.

Note to anyone else who might try this: make sure you are using the CUDA 3.2 build rules, as the 3.0 build rules do not allow building for anything other than sm_1x.