Hello, is there any example of using JAX with the Warp library? As far as I understand, one needs to first convert the JAX arrays to PyTorch and then pass those to Warp, which should not really lead to actual memory allocations. What problems should one be aware of? Any problems with differentiability? Stability? Performance?
Hi @jakub.mitura14. Let me check with the warp team.
We recently added direct interop utilities for JAX via DLPack. These let JAX and Warp arrays share the same data without copying it, so it should be fast. Check out these tests for sample usage:
We don't have more complete examples at this time. One thing to keep in mind is that JAX arrays are immutable, so if you run a Warp kernel that modifies a JAX array in place, JAX might not be aware of the change. In the tests, we got around this by running a dummy operation on the arrays (adding 0), but that's a bit hacky. JAX interop is still a work in progress.