I would like to be able to solve an ODE with Forward Euler and save the results every nth iteration.
Here is an example that stores all the data:
import jax.numpy as jnp
from jax import lax

def f(y0, t, alpha, beta, gamma, delta):
    x, y = y0
    return jnp.array([alpha*x - beta*x*y, delta*x*y - gamma*y])

def euler_integration(df_dt, y0, t, *args):
    def make_step(state, t):
        y_prev, t_prev = state
        h = t - t_prev
        y = y_prev + h*df_dt(y_prev, t_prev, *args)
        return (y, t), y
    _, ys = lax.scan(make_step, (y0, t[0]), t)
    return ys

t_grid = jnp.linspace(0, 15., 10000)
ys = euler_integration(f, jnp.array([1., 5.]), t_grid,
                       1.1, 2.2, 3.3, 4.4)
How can I store the results only at every 100th step?
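One possible approach, sketched here under stated assumptions rather than as a definitive answer, is to nest two `lax.scan` calls: an inner scan advances the state through a block of steps without storing anything, and an outer scan stores only the state at the end of each block. The names `euler_integration_strided`, `make_block`, and `save_every` are hypothetical and introduced for illustration, and the sketch assumes that the length of `t` is an exact multiple of `save_every`.

```python
import jax.numpy as jnp
from jax import lax


def f(y0, t, alpha, beta, gamma, delta):
    # Same right-hand side as in the question.
    x, y = y0
    return jnp.array([alpha*x - beta*x*y, delta*x*y - gamma*y])


def euler_integration_strided(df_dt, y0, t, save_every, *args):
    # Hypothetical helper; assumes len(t) % save_every == 0.
    def make_step(state, t_next):
        y_prev, t_prev = state
        h = t_next - t_prev
        y = y_prev + h*df_dt(y_prev, t_prev, *args)
        return (y, t_next), None  # the inner scan stores no outputs

    def make_block(state, t_block):
        # Advance through one block of `save_every` time points.
        state, _ = lax.scan(make_step, state, t_block)
        return state, state[0]  # store only the state at the end of the block

    t_blocks = t.reshape(-1, save_every)
    _, ys = lax.scan(make_block, (y0, t[0]), t_blocks)
    return ys


t_grid = jnp.linspace(0, 15., 10000)
ys = euler_integration_strided(f, jnp.array([1., 5.]), t_grid, 100,
                               1.1, 2.2, 3.3, 4.4)
```

With these settings `ys` has shape `(100, 2)`, and `ys[k]` is the state at `t_grid[100*k + 99]`, which should correspond to slicing the full output as `ys[99::100]`. The difference is that the intermediate states are never materialized, so the stored array is smaller by a factor of `save_every`.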