Hi guys!
Thanks for your help, but I think I found a solution to my problem using template classes.
For everyone who is still interested, I wrote a detailed minimal example.
Pros:
[*] Using a template class [font=“Courier New”]Class<T>[/font], one can derive specialized classes [font=“Courier New”]Class<1>, Class<2>, …[/font] which inherit attributes and methods from the main template [font=“Courier New”]Class<T>[/font], and get [font=“Courier New”][virtual][/font]-like behavior by specializing/reimplementing individual methods for specific values of [font=“Courier New”]T[/font].
[*] Inspired by “Loop Unrolling with template classes”, I got a base class [font=“Courier New”]Class<0>[/font] working in host code that can dispatch to the methods of the specialized classes depending on a private attribute [font=“Courier New”]unsigned int mType[/font].
[*] The kernels can now be implemented as template functions for a fixed type [font=“Courier New”]T[/font].
[*] A user can add new classes by simply increasing a define ([font=“Courier New”]NUM_TYPES[/font] in my case) and implementing the specialized methods.
Cons:
[*] More complex, harder-to-read code.
[*] Functions cannot be compiled separately with external linkage (template specializations have to be included/inlined).
[*] Every [font=“Courier New”]OtherClass[/font] that needs a pointer to [font=“Courier New”]Class<T>[/font] has to be a template class [font=“Courier New”]OtherClass<T>[/font] itself (in host code one can always use a base class pointer [font=“Courier New”]Class<0> * ptr[/font], but for classes used inside a [font=“Courier New”]kernel[/font] this restriction applies).
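For anyone who has not used member specialization as a replacement for virtual functions, here is a minimal host-only C++ sketch of the idea in the first pro. The names Shape, id(), name() and selfCheck() are hypothetical and not part of my code. The primary template supplies the shared members plus a default, and an explicit specialization of one member replaces that default for a single T. Unlike a real virtual function the choice is made at compile time, which is why a runtime value like mType needs the unrolling trick.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical stand-in for Class<T>: every specialization shares these members.
template<unsigned int T = 0>
class Shape
{
public:
    // Shared ("inherited") method: the same code for every T.
    unsigned int id() const { return T; }
    // "Virtual" method: default definition below, may be replaced per T.
    const char * name() const;
};

// Default definition, used for every T that is not specialized.
template<unsigned int T>
const char * Shape<T>::name() const { return "base"; }

// Explicit specialization of one member acts like an override for Shape<1> only.
// It is an ordinary function, hence 'inline' when it lives in a header.
template<>
inline const char * Shape<1>::name() const { return "one"; }

// Usage checks for the sketch above.
inline void selfCheck()
{
    assert( std::strcmp( Shape<>().name(), "base" ) == 0 );
    assert( std::strcmp( Shape<1>().name(), "one" ) == 0 );
    assert( std::strcmp( Shape<2>().name(), "base" ) == 0 ); // falls back to default
    assert( Shape<2>().id() == 2 );                          // shared member still available
}
```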
This is a rather long “minimal” example, so you have to unroll it ;-) [spoiler]
Metric.h:
#ifndef METRIC_H
#define METRIC_H
// include guard: Metric.h and Metric.inl include each other
#define _h__ __host__
#define _hd_ __host__ __device__
// Template argument "MetricType"
typedef unsigned int MetricType;
#define Base 0
#define Type1 1
#define Type2 2
#define NUM_TYPES 3
// Use type "Base" by default
template<MetricType T = Base>
class Metric
{
public : //......... Constructors
    _hd_ Metric();
    _h__ Metric( MetricType Type );
public : //......... Public Methods
    _h__ void setType( MetricType Type );
    _hd_ MetricType getType();
    _hd_ const char * getName();
private : //........ Specialized Methods
    _hd_ const char * SpecializedName();
private : //........ Private Attributes
    MetricType mType;
};
// we have to include all template definitions in the
// header because they are only compiled as needed !!
#include "Metric.inl"
#endif // METRIC_H
The file Metric.inl is only a placeholder that includes the various definitions:
Metric.inl:
#ifndef METRIC_INL
#define METRIC_INL
#include "Metric.h"
// Metric<Base> definitions
#include "MetricBase.inl"
// Add custom specializations here
#include "MetricType1.inl"
#include "MetricType2.inl"
#endif // METRIC_INL
The specialization/reimplementation for the MetricTypes [font=“Courier New”]Type1[/font] and [font=“Courier New”]Type2[/font] is straightforward. Here the [font=“Courier New”]inline[/font] qualifier is needed: a full specialization is an ordinary function rather than a template, so without [font=“Courier New”]inline[/font] every object file that includes the header would define it again, causing a “multiple definition” linker error.
MetricType1.inl:
template<>
inline __host__ __device__ Metric<Type1>::Metric()
{
    mType = Type1;
}
template<>
inline __host__ __device__ const char * Metric<Type1>::SpecializedName()
{
    return "Metric1 [reimplemented]";
}
MetricType2.inl:
template<>
inline __host__ __device__ Metric<Type2>::Metric()
{
    mType = Type2;
}
template<>
inline __host__ __device__ const char * Metric<Type2>::SpecializedName()
{
    return "Metric2 [reimplemented]";
}
The basic magic happens in the definitions of the public methods, which are the same for every [font=“Courier New”]Metric<T>[/font] instance. For [font=“Courier New”][virtual][/font] behavior, loop unrolling gives access to the (private) specialized method that matches the [font=“Courier New”]MetricType[/font] stored in [font=“Courier New”]mType[/font] (see [font=“Courier New”]Metric<T>::getName()[/font]): the template argument is incremented from [font=“Courier New”]Metric<T>[/font] to [font=“Courier New”]Metric<T+1>[/font] until it equals [font=“Courier New”]mType[/font]. This detour is necessary because template arguments T must be compile-time constants, NO variables!!
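The same unrolling idea, stripped down to free functions, as a host-only C++ sketch. The names nameOf(), specializedName(), kNumTypes and selfCheck() are hypothetical stand-ins for getName(), SpecializedName() and NUM_TYPES. A runtime value is compared against the compile-time T, each step instantiates T+1, and a terminating specialization stops the instantiation chain.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical stand-in for NUM_TYPES.
constexpr unsigned int kNumTypes = 3;

// Per-type behaviour (plays the role of SpecializedName()).
template<unsigned int T>
const char * specializedName() { return "base"; }
template<>
inline const char * specializedName<1>() { return "type 1"; }
template<>
inline const char * specializedName<2>() { return "type 2"; }

// Recursive unroller: walks the compile-time T up to the runtime 'type'.
template<unsigned int T>
const char * nameOf( unsigned int type )
{
    if( T < type )
        return nameOf<T + 1>( type ); // instantiated for every T, taken or not
    return specializedName<T>();      // here T == type
}

// Terminating specialization: the compiler cannot know the runtime range of
// 'type', so without this nameOf<T + 1> would be instantiated forever.
template<>
inline const char * nameOf<kNumTypes>( unsigned int )
{
    return "error in unrolling types";
}

// Usage checks for the sketch above.
inline void selfCheck()
{
    assert( std::strcmp( nameOf<0>( 0 ), "base" ) == 0 );
    assert( std::strcmp( nameOf<0>( 2 ), "type 2" ) == 0 );
    assert( std::strcmp( nameOf<0>( 7 ), "error in unrolling types" ) == 0 );
}
```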
MetricBase.inl:
// BASE: Standard constructor with no arguments
template<MetricType T>
__host__ __device__ Metric<T>::Metric()
{
    mType = T;
}
// BASE: Constructor that sets the metric type
template<MetricType T>
__host__ Metric<T>::Metric( MetricType Type )
{
    mType = Type;
}
// BASE: Set metric type
template<MetricType T>
__host__ void Metric<T>::setType( MetricType Type )
{
    mType = Type;
}
// BASE: Get metric type
template<MetricType T>
__host__ __device__ MetricType Metric<T>::getType()
{
    return mType;
}
/* ----------------------------------------------------- *
 * BASE : Get metric names
 * ----------------------------------------------------- */
template<MetricType T>
__host__ __device__ const char * Metric<T>::getName()
{
#if defined(__CUDA_ARCH__)
    // DEVICE : We will only use specialized class instances
    // Metric<MetricType T> [with T > 0] in device code !!
    // (__CUDA_ARCH__ is defined only in the device compilation pass;
    // __CUDACC__ is defined for the host pass of a .cu file as well)
    return SpecializedName();
#else // !defined(__CUDA_ARCH__)
    // HOST : In host code we will always use a base instance
    // Metric<Base> [T = 0] and use recursive loop unrolling over the
    // template class argument T :
    if( T < mType )
    {
        // cast "this"-pointer to a template of argument T+1
        // (all Metric<T> have the same layout: one MetricType)
        // and do recursive call of getName() ...
        Metric<T+1> * NextMetric = reinterpret_cast< Metric<T+1> * >( this );
        return NextMetric->getName();
    }
    else
    {
        // ... until T is desired type (T == mType).
        // now return specialization !
        return SpecializedName();
    }
#endif
}
// The compiler must assume that mType can hold any value in the
// range of MetricType (unsigned int). To prevent the compiler from
// infinite recursion ( Metric<T>::getName() has to be compiled
// separately for every T ), we have to stop it at the end of valid
// metric types < NUM_TYPES >.
template<>
inline __host__ __device__ const char * Metric< NUM_TYPES >::getName()
{
    // if this specialization was reached something went wrong
    return "error in unrolling metric types";
}
// BASE: Default specialized name
template<MetricType T>
__host__ __device__ const char * Metric<T>::SpecializedName()
{
    // this function must be reimplemented in specialized templates
    return "Base [to be reimplemented]";
}
Assuming we implemented a kernel
template<MetricType T> __global__ void Kernel( Metric<T> metric );
we can launch the right instantiation by applying the same strategy used in [font=“Courier New”]Metric<T>::getName()[/font].
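device.cu:
#include <cstdlib> // exit, EXIT_FAILURE
#include "Metric.h"
// Kernel<T>, GRID_SIZE and BLOCK_SIZE are assumed to be defined elsewhere.
// forward declarations
template<MetricType T> void KernelUnroller( Metric<T> * MetricT );
template<> void KernelUnroller<NUM_TYPES>( Metric<NUM_TYPES> * MetricT );
// this wrapper calls the kernel from host code
extern "C" void KernelWrapper( Metric<Base> * MetricBase )
{
    KernelUnroller( MetricBase );
}
template<MetricType T>
void KernelUnroller( Metric<T> * MetricT )
{
    if( T < MetricT->getType() )
    {
        // Template does not have the right type yet ...
        // increment type and recurse !!
        Metric<T+1> * nextMetric = reinterpret_cast< Metric<T+1> * >( MetricT );
        KernelUnroller( nextMetric );
    }
    else if( T == MetricT->getType() )
    {
        // Template has the desired type ... call template kernel.
        // Kernel takes Metric<T> by value, so the object is copied as a
        // kernel argument and no device pointer is needed.
        Kernel<<<GRID_SIZE,BLOCK_SIZE>>>( *MetricT );
    }
    else
    {
        // For safety ... if type exceeds defined types, exit(fatal) !
        exit( EXIT_FAILURE );
    }
}
// Stop the recursion at NUM_TYPES (same reason as for getName()):
// reaching it means mType holds an invalid type.
template<>
void KernelUnroller<NUM_TYPES>( Metric<NUM_TYPES> * )
{
    exit( EXIT_FAILURE );
}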
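Since the same unroller has to be written once for getName() and again for every kernel, a generic dispatcher might remove the duplication. This is an alternative I have not tested on the device: a host-only sketch that assumes a C++17 compiler and uses if constexpr in place of the terminating specialization. dispatchType() and selfCheck() are hypothetical names.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Maps a runtime 'type' in [0, N) to a compile-time T and calls
// f(std::integral_constant<unsigned int, T>{}). Returns false if type >= N.
// 'if constexpr' discards the recursive call once T + 1 == N, so no
// terminating specialization is required (C++17 assumed).
template<unsigned int N, unsigned int T = 0, class F>
bool dispatchType( unsigned int type, F && f )
{
    static_assert( N > 0, "need at least one type" );
    if( T == type )
    {
        f( std::integral_constant<unsigned int, T>{} );
        return true;
    }
    if constexpr( T + 1 < N )
        return dispatchType<N, T + 1>( type, std::forward<F>( f ) );
    else
        return false; // type out of range
}

// Usage checks for the sketch above.
inline void selfCheck()
{
    unsigned int seen = 99;
    // Record which compile-time T the runtime value 2 was mapped to.
    const bool found = dispatchType<3>( 2u, [&]( auto t ) { seen = decltype( t )::value; } );
    assert( found && seen == 2 );
    // Out-of-range values are rejected instead of recursing forever.
    assert( !dispatchType<3>( 5u, []( auto ) {} ) );
}
```

In device.cu the lambda would launch the kernel instead, e.g. [&]( auto t ) { Kernel<<<GRID_SIZE,BLOCK_SIZE>>>( Metric<decltype( t )::value>() ); } (an untested sketch), and a false return value would take the place of exit( EXIT_FAILURE ).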
This means major modifications to my project, but I think I'll reimplement everything using the strategy shown. It has been proven to work in host code and tested with a fixed type T on the device. The time needed to compile the problematic kernel dropped from 3 minutes with “if-branches” to 1:30 with the template unrolling method.
If you have suggestions for improvement or questions, feel free to ask.
[/spoiler]
Yours, spy!