JAX vs NumPy for generating Heston paths
I have this Python code (from QuantPy) that generates stock paths under the Heston model using NumPy, and I am trying to convert it to JAX. For some reason, the NumPy version runs in about 2 seconds, whereas the JAX version takes about 45 seconds. I would be very grateful if someone could point out why this is and suggest any improvements to make the JAX version run faster.
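The ported code is not shown, so the cause can't be confirmed. One common cause is keeping the NumPy-style Python `for` loop and updating arrays with `x.at[i].set(...)` outside `jax.jit`. In that case every iteration dispatches several small operations separately, and each out-of-JIT `.at[].set()` returns a fresh copy of the whole array. A sketch of a more JAX-idiomatic version follows. It assumes the scheme is a log-Euler step for S with the variance truncated at zero, as in the QuantPy code. It compiles the whole simulation with `jax.jit` and replaces the Python loop with `jax.lax.scan`. The function name `heston_paths`, its argument order, and the parameter values in the usage block are hypothetical. The correlated normals come from an explicit Cholesky factor of `[[1, rho], [rho, 1]]` instead of a multivariate-normal call.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: n_steps and n_paths fix array shapes, so they must be static under jit.
@partial(jax.jit, static_argnames=("n_steps", "n_paths"))
def heston_paths(key, S0, v0, r, rho, kappa, theta, sigma, T, n_steps, n_paths):
    dt = T / n_steps
    # Draw all independent standard normals at once: shape (n_steps, n_paths, 2).
    z = jax.random.normal(key, (n_steps, n_paths, 2))
    # Correlate them with the Cholesky factor of [[1, rho], [rho, 1]].
    z_s = z[..., 0]
    z_v = rho * z[..., 0] + jnp.sqrt(1.0 - rho**2) * z[..., 1]

    def step(carry, zs):
        S, v = carry
        zS, zV = zs
        # Log-Euler step for the price, driven by the previous variance.
        S_new = S * jnp.exp((r - 0.5 * v) * dt + jnp.sqrt(v * dt) * zS)
        # Euler step for the variance, truncated at zero to keep sqrt(v) real.
        v_new = jnp.maximum(
            v + kappa * (theta - v) * dt + sigma * jnp.sqrt(v * dt) * zV, 0.0
        )
        # Carry the new state forward and also emit it so scan stacks the path.
        return (S_new, v_new), (S_new, v_new)

    S_init = jnp.full((n_paths,), S0)
    v_init = jnp.full((n_paths,), v0)
    # scan runs the time loop inside the compiled program instead of in Python.
    _, (S_path, v_path) = jax.lax.scan(step, (S_init, v_init), (z_s, z_v))
    # Prepend the initial values so the output has n_steps + 1 rows, like the NumPy code.
    S_all = jnp.concatenate([S_init[None, :], S_path], axis=0)
    v_all = jnp.concatenate([v_init[None, :], v_path], axis=0)
    return S_all, v_all


if __name__ == "__main__":
    import time

    # Illustrative parameters only.
    args = (jax.random.PRNGKey(42), 100.0, 0.04, 0.02, -0.7, 2.0, 0.04, 0.3, 1.0, 252, 100_000)
    t0 = time.perf_counter()
    heston_paths(*args)[0].block_until_ready()  # first call includes compilation
    t1 = time.perf_counter()
    heston_paths(*args)[0].block_until_ready()  # later calls reuse the compiled program
    t2 = time.perf_counter()
    print(f"compile+run: {t1 - t0:.3f}s, run only: {t2 - t1:.3f}s")
```

When comparing timings, note that JAX dispatches work asynchronously. Without `block_until_ready()` a timer may stop before the computation finishes. The first call to a jitted function also includes compilation, so a fair comparison times a second call. JAX also defaults to 32-bit floats, whereas NumPy defaults to 64-bit, so results can differ slightly unless `jax_enable_x64` is turned on.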