I have this Python code (from QuantPy) that generates stock paths under the Heston model using NumPy. I am trying to convert it to JAX. For some reason, the NumPy version runs in about 2 seconds whereas the JAX version takes about 45 seconds. I would be grateful if someone could point out why this is and suggest improvements to make the JAX version run faster.
import numpy as np
import jax.numpy as jnp
from jax import random

# Parameters
S0 = 100.0 # initial asset price
K = 64 # strike price
T = 1.0 # time in years
r = 0.05 # risk-free rate
N = 252 # number of time steps in simulation
M = 100000 # number of simulations
# Heston dependent parameters
kappa = 2 # rate of mean reversion of variance under risk-neutral dynamics
theta = 0.05 # long-term mean of variance under risk-neutral dynamics
v0 = 0.05 # initial variance under risk-neutral dynamics
rho = -0.5 # correlation between returns and variances under risk-neutral dynamics
sigma = 0.3 # volatility of volatility
def heston_model_sim(S0, v0, rho, kappa, theta, sigma, r, T, N, M):
    # initialise other parameters
    dt = T/N
    mu = np.array([0, 0])
    cov = np.array([[1, rho],
                    [rho, 1]])
    # arrays for storing prices and variances
    S = np.full(shape=(N+1, M), fill_value=S0)
    v = np.full(shape=(N+1, M), fill_value=v0)
    # sampling correlated brownian motions under risk-neutral measure
    Z = np.random.multivariate_normal(mu, cov, (N, M))
    for i in range(1, N+1):
        S[i] = S[i-1] * np.exp((r - 0.5*v[i-1])*dt + np.sqrt(v[i-1] * dt) * Z[i-1, :, 0])
        v[i] = np.maximum(v[i-1] + kappa*(theta - v[i-1])*dt + sigma*np.sqrt(v[i-1]*dt)*Z[i-1, :, 1], 0)
    return S, v
def heston_model_sim_jax(S0, v0, rho, kappa, theta, sigma, r, T, N, M):
    # Initialize other parameters
    dt = T / N
    mu = jnp.array([0, 0])
    cov = jnp.array([[1, rho], [rho, 1]])
    # Arrays for storing prices and variances
    S = jnp.full((N+1, M), S0)
    v = jnp.full((N+1, M), v0)
    # Sampling correlated Brownian motions under risk-neutral measure
    key = random.PRNGKey(0)
    Z = random.multivariate_normal(key, mean=mu, cov=cov, shape=(N, M))
    for i in range(1, N + 1):
        S = S.at[i].set(S[i-1] * jnp.exp((r - 0.5 * v[i-1]) * dt + jnp.sqrt(v[i-1] * dt) * Z[i-1, :, 0]))
        v = v.at[i].set(jnp.maximum(v[i-1] + kappa * (theta - v[i-1]) * dt + sigma * jnp.sqrt(v[i-1] * dt) * Z[i-1, :, 1], 0))
    return S, v
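I'm timing the two versions roughly like this, calling block_until_ready() on the JAX output so that JAX's asynchronous dispatch is included in the measurement:

import time

start = time.perf_counter()
S_np, v_np = heston_model_sim(S0, v0, rho, kappa, theta, sigma, r, T, N, M)
print(f"NumPy: {time.perf_counter() - start:.2f}s")   # about 2 s for me

start = time.perf_counter()
S_j, v_j = heston_model_sim_jax(S0, v0, rho, kappa, theta, sigma, r, T, N, M)
S_j.block_until_ready()
print(f"JAX: {time.perf_counter() - start:.2f}s")     # about 45 s for me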
I have read the JAX docs and other people's questions online about JAX vs NumPy for their specific functions, but I couldn't find anything that helped me understand my case.
I am wondering if it is to do with the array updates: perhaps NumPy's in-place S[i] = ... is quicker than JAX's functional S = S.at[i].set(...).
For comparing performance of JAX and NumPy, you should keep in mind the general discussion at FAQ: is JAX faster than NumPy?. In particular:
In summary: if you’re doing microbenchmarks of individual array operations on CPU, you can generally expect NumPy to outperform JAX due to its lower per-operation dispatch overhead. If you’re running your code on GPU or TPU, or are benchmarking more complicated JIT-compiled sequences of operations on CPU, you can generally expect JAX to outperform NumPy.
Your code appears to be a non-jit-compiled sequence of array operations on CPU, which is exactly the regime where we’d expect NumPy to be faster than JAX.
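You can see this dispatch overhead directly with a microbenchmark on a small array, where per-operation cost dominates the actual compute; a rough sketch (exact numbers will vary by machine):

import numpy as np
import jax.numpy as jnp

x_np = np.ones(1000)
x_jx = jnp.ones(1000)

# In IPython/Jupyter:
# %timeit x_np * 2.0                         # NumPy: eager C loop, very low per-op cost
# %timeit (x_jx * 2.0).block_until_ready()   # JAX on CPU: noticeably more per-op dispatch overhead

Each of your 252 iterations dispatches only a handful of such operations, and outside of jit each .at[i].set additionally copies the whole (N+1, M) array, so this overhead accumulates.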
You may be able to improve the JAX runtime by wrapping your function in jax.jit (with appropriate arguments marked static), but you'll likely find that compilation is very slow because of the use of Python for loops in your code. You could address this by switching to an XLA-friendly iteration construct such as scan or fori_loop (see JAX sharp bits: control flow for some discussion); it would look something like this:
from functools import partial

import jax
import jax.numpy as jnp
from jax import random

@partial(jax.jit, static_argnames=['N', 'M'])
def heston_model_sim_jax_2(S0, v0, rho, kappa, theta, sigma, r, T, N, M):
    # Initialize other parameters
    dt = T / N
    mu = jnp.array([0, 0])
    cov = jnp.array([[1, rho], [rho, 1]])
    # Arrays for storing prices and variances
    S = jnp.full((N+1, M), S0)
    v = jnp.full((N+1, M), v0)
    # Sampling correlated Brownian motions under risk-neutral measure
    key = random.PRNGKey(0)
    Z = random.multivariate_normal(key, mean=mu, cov=cov, shape=(N, M))

    def body(i, carry):
        S, v = carry
        S = S.at[i].set(S[i-1] * jnp.exp((r - 0.5 * v[i-1]) * dt + jnp.sqrt(v[i-1] * dt) * Z[i-1, :, 0]))
        v = v.at[i].set(jnp.maximum(v[i-1] + kappa * (theta - v[i-1]) * dt + sigma * jnp.sqrt(v[i-1] * dt) * Z[i-1, :, 1], 0))
        return (S, v)

    S, v = jax.lax.fori_loop(1, N + 1, body, (S, v))
    return S, v
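One caveat when benchmarking this: the first call triggers compilation, and JAX dispatches work asynchronously, so time a second call and block on the result:

import time

# Warm-up call: triggers JIT compilation, so exclude it from the timing
S, v = heston_model_sim_jax_2(S0, v0, rho, kappa, theta, sigma, r, T, N, M)
S.block_until_ready()

start = time.perf_counter()
S, v = heston_model_sim_jax_2(S0, v0, rho, kappa, theta, sigma, r, T, N, M)
S.block_until_ready()
print(f"jitted JAX, steady state: {time.perf_counter() - start:.2f}s")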
But even with this I wouldn’t expect JAX on CPU to be significantly faster than NumPy for this operation. The overarching issue is that your function is dispatching a small number of array operations per iteration with strict sequential dependence, so there’s no way for the XLA compiler to do the kind of hardware-tuned parallelization/vectorization that makes other JAX programs fast.
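For completeness, here is roughly what the scan version would look like; it carries only the current (S, v) state and stacks the per-step outputs, rather than scattering into preallocated (N+1, M) arrays with .at[i].set. This is a sketch, with the initial values prepended at the end so the output layout matches the versions above:

from functools import partial

import jax
import jax.numpy as jnp
from jax import random

@partial(jax.jit, static_argnames=['N', 'M'])
def heston_model_sim_jax_scan(S0, v0, rho, kappa, theta, sigma, r, T, N, M):
    dt = T / N
    mu = jnp.array([0, 0])
    cov = jnp.array([[1, rho], [rho, 1]])
    key = random.PRNGKey(0)
    Z = random.multivariate_normal(key, mean=mu, cov=cov, shape=(N, M))

    def step(carry, z):
        S_prev, v_prev = carry
        S_next = S_prev * jnp.exp((r - 0.5 * v_prev) * dt + jnp.sqrt(v_prev * dt) * z[:, 0])
        v_next = jnp.maximum(v_prev + kappa * (theta - v_prev) * dt
                             + sigma * jnp.sqrt(v_prev * dt) * z[:, 1], 0)
        # Carry the new state forward and also emit it as this step's output
        return (S_next, v_next), (S_next, v_next)

    init = (jnp.full(M, S0), jnp.full(M, v0))
    _, (S_steps, v_steps) = jax.lax.scan(step, init, Z)
    # Prepend the initial values so the result has shape (N+1, M) like before
    S = jnp.concatenate([jnp.full((1, M), S0), S_steps])
    v = jnp.concatenate([jnp.full((1, M), v0), v_steps])
    return S, v

Either way, each step still depends on the previous one, so the fundamental sequential bottleneck remains.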