Creating a function call in device code and linking to PTX

I have a use case where I want to call a device function defined in PTX. I don’t know exactly how to write the call in my existing device code so that it matches the function’s signature in my PTX.
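One way to get the call to match is to declare the PTX function in the .cu file as `extern "C" __device__`. `extern "C"` stops C++ name mangling, so the declared name is the same symbol as the PTX one. The parameter types must have the same widths as the PTX `.param` entries. The helper below is a hypothetical sketch that reads a `.visible .func` header and prints such a declaration. It assumes the PTX exports a device function (`.func`) rather than a kernel (`.entry`), because a kernel cannot be called like an ordinary device function. PTX often gives parameters only a bit width (`.b32`, `.b64`), so the C types in `PTX_TO_C` are assumptions you may need to change. For example, a `.b64` may really be a pointer.

```python
import re

# Assumed PTX -> C type mapping. Width-only types (.bN) map to unsigned
# integers; override them if the PTX value is really a float or a pointer.
PTX_TO_C = {
    ".b8": "unsigned char", ".u8": "unsigned char", ".s8": "signed char",
    ".b16": "unsigned short", ".u16": "unsigned short", ".s16": "short",
    ".b32": "unsigned int", ".u32": "unsigned int", ".s32": "int",
    ".f32": "float",
    ".b64": "unsigned long long", ".u64": "unsigned long long",
    ".s64": "long long", ".f64": "double",
}

# Matches ".visible .func (.param .f32 ret) name(.param ... , ...)".
# The single return-value clause is optional (void functions omit it).
FUNC_RE = re.compile(
    r"\.visible\s+\.func\s*"
    r"(?:\(\s*\.param\s+(\.\w+)\s+[\w$]+\s*\)\s*)?"  # optional return value
    r"([\w$]+)\s*"                                  # function name
    r"\(([^)]*)\)"                                  # parameter list
)

# One scalar parameter, e.g. ".param .f32 name". Aggregates such as
# ".param .align 8 .b8 name[16]" deliberately do not match.
PARAM_RE = re.compile(r"\.param\s+(\.\w+)\s+[\w$%]+$")


def c_type(ptx_type):
    try:
        return PTX_TO_C[ptx_type]
    except KeyError:
        raise ValueError(f"unsupported PTX type {ptx_type!r}") from None


def cuda_decl_for(ptx_text, name):
    """Return an extern "C" __device__ declaration for .visible .func `name`."""
    for m in FUNC_RE.finditer(ptx_text):
        ret_type, func_name, params = m.groups()
        if func_name != name:
            continue
        args = []
        for p in params.split(","):
            p = p.strip()
            if not p:
                continue  # empty parameter list
            pm = PARAM_RE.match(p)
            if pm is None:
                raise ValueError(f"cannot map PTX parameter {p!r}")
            args.append(c_type(pm.group(1)))
        ret = c_type(ret_type) if ret_type else "void"
        return f'extern "C" __device__ {ret} {name}({", ".join(args) or "void"});'
    raise ValueError(f"no .visible .func named {name!r} in the PTX")
```

You would paste the generated line into the CUDA file and call the function like any other `__device__` function. Because its body lives in another PTX file, the caller generally has to be compiled with relocatable device code (`nvcc -rdc=true`). Without it, the compiler rejects calls to external device functions.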

Here’s the tool flow I’m attempting. If you know of an alternative tool flow, I’m happy to try it.

The approach I’m pursuing at the moment is:

  1. Lower a specialized PyTorch model to the StableHLO dialect of MLIR.
  2. Generate PTX using IREE.
  3. Write a C/CUDA file containing a function call that matches the signature of the function in the PTX, and generate PTX for that file.
  4. Link the PTX files for execution.
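Steps 3 and 4 could be driven from a small Python script. This sketch only builds the command lines. The file names, the `sm_80` target and the helper name `link_commands` are assumptions, and the flags should be checked against the nvcc, ptxas and nvlink documentation for your CUDA version. The approach is to emit relocatable PTX for the caller and assemble each PTX file into a relocatable device object (`ptxas -c`). The device linker (`nvlink`) can then resolve the call from the caller into the IREE-generated function.

```python
from pathlib import Path


def link_commands(caller_cu, iree_ptx, arch="sm_80", out="linked.cubin"):
    """Build (not run) the commands for steps 3 and 4 of the flow."""
    caller_ptx = str(Path(caller_cu).with_suffix(".ptx"))
    caller_obj = str(Path(caller_cu).with_suffix(".cubin"))
    iree_obj = str(Path(iree_ptx).with_suffix(".cubin"))
    return [
        # Step 3: relocatable PTX, so the call to an external .func is allowed.
        ["nvcc", f"-arch={arch}", "-rdc=true", "-ptx", caller_cu, "-o", caller_ptx],
        # Step 4a: assemble each PTX file into a relocatable device object.
        ["ptxas", f"-arch={arch}", "-c", caller_ptx, "-o", caller_obj],
        ["ptxas", f"-arch={arch}", "-c", iree_ptx, "-o", iree_obj],
        # Step 4b: the device linker resolves the cross-file call.
        ["nvlink", f"-arch={arch}", caller_obj, iree_obj, "-o", out],
    ]
```

Each command could then be run with `subprocess.run(cmd, check=True)`. You can also do the link at run time instead of with command-line tools, using the CUDA driver API: `cuLinkCreate`, `cuLinkAddData` with `CU_JIT_INPUT_PTX` for each PTX file, and `cuLinkComplete`, followed by `cuModuleLoadData` on the resulting image.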