JAX (and PyTorch) are already "integrated" with Mathematica through the Python ExternalEvaluate/ExternalFunction interface, so the question becomes: which parts of JAX/PyTorch would you want to integrate more closely?
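For example, here is a minimal sketch of calling a JAX-differentiated function from Mathematica through an external Python session. It assumes jax is installed in the Python environment that ExternalEvaluate finds; the `gradTanh` name is just for illustration.

```
(* Start a Python session and define a JAX gradient function in it *)
session = StartExternalSession["Python"];
ExternalEvaluate[session, "
import jax
import jax.numpy as jnp

# JAX's autodiff: the gradient of tanh as a plain Python callable
grad_tanh = jax.grad(jnp.tanh)

def gradTanh(x):
    # Use plain floats so values round-trip cleanly between the two systems
    return float(grad_tanh(float(x)))
"];

(* Wrap the Python function as a Mathematica-callable symbol *)
gradTanh = ExternalFunction[session, "gradTanh"];
gradTanh[0.5]  (* derivative of tanh at 0.5 *)
```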
If the goal is to replace the MXNet parts implementing the NN interface, then it would be the modeling code.
For PyTorch, that would mean integrating the nn.Module abstraction, which gives you a building block that you wire together with other nn.Module blocks to build your neural net. This abstraction has been around for 8+ years, so it probably won't go away soon. For instance, models from the PyTorch Hub and most repos under Papers with Code use the nn.Module interface. So one interesting use case would be to provide a way to load a Hub model into a Mathematica symbol that can be called as a function, roughly as sketched below.
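A rough sketch of that use case, assuming torch and torchvision are available in the external Python environment; the resnet18 entry point, the weights argument, and the predict/resNet names are only illustrative and may need adjusting for your torchvision version.

```
session = StartExternalSession["Python"];
ExternalEvaluate[session, "
import torch

# Any nn.Module from the hub exposes the same callable interface;
# here we pull a pretrained resnet18 as an example
model = torch.hub.load('pytorch/vision', 'resnet18', weights='DEFAULT')
model.eval()

def predict(x):
    # x arrives as a nested Wolfram list; run one forward pass without gradients
    with torch.no_grad():
        out = model(torch.tensor(x, dtype=torch.float32))
    return out.tolist()
"];

(* The hub model is now callable as an ordinary Mathematica symbol *)
resNet = ExternalFunction[session, "predict"];
resNet[RandomReal[1, {1, 3, 224, 224}]]  (* a 1 x 1000 list of class logits *)
```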
JAX started as "autodiff for NumPy", so neural network abstractions didn't come until later. There's Flax, created by Google, and Haiku, created by DeepMind. Integrating JAX modeling would mean deciding whether you want to integrate Flax's Dense layers or Haiku's Linear layers.
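To make the difference concrete, here is a rough side-by-side, assuming flax and haiku are installed in the external Python environment. The two layers compute the same thing; what differs is the module system and initialization convention an integration would have to target.

```
session = StartExternalSession["Python"];
ExternalEvaluate[session, "
import jax
import jax.numpy as jnp
import flax.linen as fnn
import haiku as hk

x = jnp.ones((1, 4))

# Flax: a Dense layer inside a linen Module, initialized with .init
class FlaxModel(fnn.Module):
    @fnn.compact
    def __call__(self, x):
        return fnn.Dense(features=2)(x)

flax_params = FlaxModel().init(jax.random.PRNGKey(0), x)

# Haiku: a Linear layer inside a function wrapped by hk.transform
def haiku_model(x):
    return hk.Linear(2)(x)

haiku_params = hk.transform(haiku_model).init(jax.random.PRNGKey(0), x)
"];
```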