Btw, regarding Jax: I think if Mathematica implemented tensor differentiation natively, it would obviate the need to rely on external autodiff.
The issue is that neural network training involves differentiating expressions like norm(W1 W2 W3) w.r.t. the matrix W2. Mathematica's differentiation is at its core scalar differentiation; it can't treat W2 as an atom. Autodiff frameworks like Jax find these derivatives efficiently by treating matmuls as atomic operations.
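To make that concrete, here is a minimal Jax sketch of exactly this case: the gradient of a Frobenius norm of a matrix product chain with respect to the middle factor. The matrix names and shapes are illustrative assumptions, not taken from any particular model.

```python
import jax
import jax.numpy as jnp

# Illustrative shapes only.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
W1 = jax.random.normal(k1, (4, 5))
W2 = jax.random.normal(k2, (5, 6))
W3 = jax.random.normal(k3, (6, 3))

def loss(W2):
    # Frobenius norm of the product W1 W2 W3; matmul stays an
    # atomic differentiable op rather than being expanded to scalars.
    return jnp.linalg.norm(W1 @ W2 @ W3)

# Reverse-mode autodiff returns d loss / d W2 as a matrix with the
# same shape as W2, in a single backward pass.
grad_W2 = jax.grad(loss)(W2)
print(grad_W2.shape)  # (5, 6)
```

The point is that `jax.grad` hands back the whole matrix-shaped derivative at once, which is what scalar-by-scalar symbolic differentiation can't do without flattening W2 into individual entries.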
I have thought a lot about the proper design of such a system and have a couple of prototypes; getting it into Mathematica is on my to-do list: https://community.wolfram.com/groups/-/m/t/2437093