I found a way to do this with the Curiously Recurring Template Pattern (CRTP). It is a form of static (compile-time) polymorphism that avoids the use of virtual functions.
Here is a basic sample!
#include <vector>
#include <memory>
#include <typeinfo>
#include <iostream>
#include <variant>
#include <utility>
// CRTP base class: each call is forwarded to the method of the same name on
// Derived, resolved at compile time, so no virtual functions are involved.
//
// Derived must inherit from Base<Derived> and define its own implementation()
// and test(); if it does not, the forwarded call resolves back to the Base
// version and recurses without end.
template <typename Derived>
class Base {
public:
    // Forwards to Derived::implementation(). Carries no execution-space
    // qualifier, so it is host code.
    double implementation()
    {
        Derived& self = *static_cast<Derived*>(this);
        return self.implementation();
    }

    // Forwards to Derived::test(). Callable from device code only.
    __device__ void test()
    {
        Derived& self = *static_cast<Derived*>(this);
        self.test();
    }
};
// Concrete CRTP type "A": supplies the methods that Base<DerivedA> forwards to.
class DerivedA : public Base<DerivedA> {
public:
    // Host-side value reported by this type.
    double implementation()
    {
        return 2.0;
    }

    // Device-side: prints a message identifying this type.
    __device__ void test()
    {
        printf("i am A from GPU\n");
    }
};
// Concrete CRTP type "B": supplies the methods that Base<DerivedB> forwards to.
class DerivedB : public Base<DerivedB> {
public:
    // Host-side value reported by this type.
    double implementation()
    {
        return 1.0;
    }

    // Device-side: prints a message identifying this type.
    __device__ void test()
    {
        printf("i am B from GPU\n");
    }
};
// Closed set of concrete types: a Child holds either a DerivedA or a DerivedB,
// which lets both be stored in one container and dispatched with std::visit.
using Child = std::variant<DerivedA, DerivedB>;
// Generic kernel: every thread of the launch calls test() on the object it
// was given.
//
// T must provide a __device__ member function test(). The object is passed
// by value as a kernel argument.
template <typename T>
__global__ void kernel(T a)
{
    a.test();
}
// Builds a list holding one DerivedA and one DerivedB, then launches `kernel`
// once per element with the concrete type selected by std::visit.
//
// Returns 0 on success, or 1 after printing a message if a kernel launch or
// its execution reports a CUDA error.
int main() {
    // The objects are constructed directly inside the variants, so there is
    // no heap allocation to leak.
    std::vector<Child> forces;
    forces.emplace_back(DerivedA{});
    forces.emplace_back(DerivedB{});

    // Launch configuration is the same for every element.
    const int grid = 1;
    const int threads = 1;

    // Iterate by const reference to avoid copying each variant.
    for (const auto& f : forces) {
        const cudaError_t status = std::visit(
            [grid, threads](const auto& e) -> cudaError_t {
                // T is deduced from the active alternative; the object is
                // copied into the kernel argument.
                kernel<<<grid, threads>>>(e);

                // A launch does not return a status itself, so query it.
                const cudaError_t launchStatus = cudaGetLastError();
                if (launchStatus != cudaSuccess) {
                    return launchStatus;
                }

                // Wait for the kernel to finish; this also reports errors
                // raised while it ran.
                return cudaDeviceSynchronize();
            },
            f);

        if (status != cudaSuccess) {
            std::cerr << "CUDA error: " << cudaGetErrorString(status) << '\n';
            return 1;
        }
    }
    return 0;
}
Compile with: nvcc file.cu -std=c++17
Since the question of class inheritance in CUDA comes up so often, I hope this example will help!