Jax (and PyTorch) are already "integrated" with Mathematica through the Python ExternalFunction interface, so the question is which parts of Jax/PyTorch you would want to integrate more closely.
If the goal is to replace the MXNet parts implementing the neural network interface, then it would be the modeling code.
For PyTorch, that would mean integrating the nn.Module abstraction, which gives you a building block that you wire together with other nn.Module blocks to build your neural net. This abstraction has been around for 8+ years, so it probably won't go away soon. For instance, models from the PyTorch model hub and most repos on Papers with Code use the nn.Module interface. So one interesting use case would be to provide a way to load model hub models into a Mathematica symbol that could be called as a function.
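To make the "building block" idea concrete, here is a minimal sketch of how nn.Module pieces are composed in PyTorch. The class names MLPBlock and TinyNet and the layer sizes are hypothetical, chosen only for illustration. The point relevant to a Mathematica integration is the last few lines: a module instance is callable like a function, and a model obtained from the model hub (for example via torch.hub.load) is typically an nn.Module of the same kind, so it would be wrapped the same way.

```python
import torch
from torch import nn


class MLPBlock(nn.Module):
    """Hypothetical reusable building block: Linear followed by ReLU."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.linear(x))


class TinyNet(nn.Module):
    """Hypothetical network built by wiring several nn.Module blocks together."""

    def __init__(self):
        super().__init__()
        # Submodules assigned as attributes are registered automatically,
        # so their parameters show up in model.parameters().
        self.body = nn.Sequential(MLPBlock(4, 16), MLPBlock(16, 16))
        self.head = nn.Linear(16, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.body(x))


model = TinyNet().eval()
with torch.no_grad():
    # Calling the module instance runs forward(): it behaves like a function.
    y = model(torch.randn(2, 4))

n_params = sum(p.numel() for p in model.parameters())
print(y.shape)   # torch.Size([2, 3])
print(n_params)  # 403
```

Because every layer, block, and full model shares this one interface, a bridge would only need to understand nn.Module once rather than each architecture separately.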
Jax started as "autodiff for numpy", so neural network abstractions came later. There's Flax, created by Google, and Haiku, created by DeepMind. Integrating Jax modeling would mean deciding whether you want to integrate Flax Dense or Haiku Linear layers.