I'm looking to pass a JAX-calculated Jacobian to my objective function (for an lmfit minimizer); the function incorporates a BSpline object, which is not supported in JAX.
I've defined a custom JVP that uses a finite-difference approximation for this part of the code:
@convert.defjvp
def convert_jvp(primals, tangents):
    alpha_values, knots = primals
    alpha_tangent = tangents[0]
    epsilon = 1e-4
    primal_output = convert(alpha_values, knots)
    derivatives = {key: jnp.zeros_like(primal_output[key]) for key in primal_output}
    for i in range(alpha_values.shape[0]):
        perturbed_alpha = alpha_values.at[i].add(epsilon * alpha_tangent[i])
        perturbed_output = convert(perturbed_alpha, knots)  # evaluate at perturbed alpha; returns a dict of 4 keys and 4 arrays
        for key in primal_output:
            derivative = (perturbed_output[key] - primal_output[key]) / epsilon
            derivatives[key] += derivative * alpha_tangent[i]
    return primal_output, derivatives
This gets called during a minimization routine from the lmfit library, where I pass the Jacobian:
def get_alpha(params):
    return jnp.array([params[f'alpha{i}'].value for i in range(len(alpha_initial))])

def objective(alpha_values, knots, call_prices, strikes, forward):
    poly_data = convert(alpha_values, knots)
    model_prices = jnp.array([poly_data['a'][i] for i in range(len(call_prices))])  # placeholder for actual pricer
    errors = jnp.array(call_prices) - model_prices
    return errors

jacobian_func = jax.jit(jax.jacfwd(objective, argnums=0))

def jacobian(params, knots, call_prices, strikes, forward):
    alpha_values = get_alpha(params)
    return jnp.array(jacobian_func(alpha_values, knots, call_prices, strikes, forward))

# Wrapper for lmfit objective
def wrapped_objective(params, knots, call_prices, strikes, forward):
    alpha_values = get_alpha(params)
    return objective(alpha_values, knots, call_prices, strikes, forward)

result = minimize(wrapped_objective, params, args=(knots, call_prices, strikes, forward), Dfun=jacobian, method='leastsq')
However, the JVP function only seems to get called once. Furthermore, the resulting fit is far worse than a fit without the Jacobian as input (swapping the first result line for the commented-out one in the full code below yields a better fit without the Jacobian).
I cannot tell whether this is due to an incorrect structuring of the JVP function, or to how lmfit interprets the Jacobian's input. Any help would be greatly appreciated.
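For reference, one quick way to localize the problem is to compare the jacfwd Jacobian against a plain finite difference of the objective itself. This is just a throwaway sketch (check_jacobian is not part of my code; it assumes the definitions above are in scope):

import numpy as np

def check_jacobian(alpha_values, knots, call_prices, strikes, forward, eps=1e-6):
    # Jacobian from jax.jacfwd via the custom JVP
    J_jax = np.asarray(jacobian_func(alpha_values, knots, call_prices, strikes, forward))
    # Plain forward-difference Jacobian of the same objective
    f0 = np.asarray(objective(alpha_values, knots, call_prices, strikes, forward))
    J_fd = np.zeros_like(J_jax)
    for i in range(alpha_values.shape[0]):
        fi = np.asarray(objective(alpha_values.at[i].add(eps), knots, call_prices, strikes, forward))
        J_fd[:, i] = (fi - f0) / eps
    print("max abs deviation from finite differences:", np.abs(J_jax - J_fd).max())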
Full code (reproducible):
from lmfit import Parameters, minimize
from jax import custom_jvp
import jax.numpy as jnp
from scipy.interpolate import PPoly, BSpline
import jax

def calibrate_jax(bspl_initial, call_prices, strikes, forward):
    alpha_initial = bspl_initial.c
    knots = bspl_initial.t

    params = Parameters()
    params.add('alpha0', value=alpha_initial[0], vary=True)
    for i in range(1, len(alpha_initial)):
        params.add(f'delta{i}', value=max(alpha_initial[i] - alpha_initial[i-1], 1e-4), min=0, vary=True)
        params.add(f'alpha{i}', expr=f'alpha{i-1} + delta{i}')

    @custom_jvp
    def convert(alpha_values, knots):
        len_ = len(knots) - 1
        result_shape_dtypes = [
            jax.ShapeDtypeStruct(shape=(len_,), dtype=jnp.float32),
            jax.ShapeDtypeStruct(shape=(len_,), dtype=jnp.float32),
            jax.ShapeDtypeStruct(shape=(len_,), dtype=jnp.float32),
            jax.ShapeDtypeStruct(shape=(len(knots),), dtype=jnp.float32),
        ]
        # Use pure_callback to interface with non-JAX-compatible code
        a, b, c, x = jax.pure_callback(get_bspl, result_shape_dtypes, alpha_values, knots, 2)
        return {'a': a, 'b': b, 'c': c, 'x': x}

    def get_bspl(alpha, knots, degree=2):
        alpha = jnp.asarray(alpha)
        knots = jnp.asarray(knots)
        bspl = BSpline(knots, alpha, degree)
        poly = PPoly.from_spline(bspl, extrapolate=False)
        return poly.c[2], poly.c[1], poly.c[0], poly.x

    @convert.defjvp
    def convert_jvp(primals, tangents):
        alpha_values, knots = primals
        alpha_tangent = tangents[0]
        epsilon = 1e-4
        primal_output = convert(alpha_values, knots)
        derivatives = {key: jnp.zeros_like(primal_output[key]) for key in primal_output}
        for i in range(alpha_values.shape[0]):
            perturbed_alpha = alpha_values.at[i].add(epsilon * alpha_tangent[i])
            perturbed_output = convert(perturbed_alpha, knots)  # outputs a dictionary with 4 keys and 4 arrays
            for key in primal_output:
                derivative = (perturbed_output[key] - primal_output[key]) / epsilon
                derivatives[key] += derivative * alpha_tangent[i]
        return primal_output, derivatives

    def get_alpha(params):
        return jnp.array([params[f'alpha{i}'].value for i in range(len(alpha_initial))])

    def objective(alpha_values, knots, call_prices, strikes, forward):
        poly_data = convert(alpha_values, knots)
        model_prices = jnp.array([poly_data['a'][i] for i in range(len(call_prices))])  # placeholder for actual pricer
        errors = jnp.array(call_prices) - model_prices
        return errors

    jacobian_func = jax.jit(jax.jacfwd(objective, argnums=0))

    def jacobian(params, knots, call_prices, strikes, forward):
        alpha_values = get_alpha(params)
        return jnp.array(jacobian_func(alpha_values, knots, call_prices, strikes, forward))

    # Wrapper for lmfit objective
    def wrapped_objective(params, knots, call_prices, strikes, forward):
        alpha_values = get_alpha(params)
        return objective(alpha_values, knots, call_prices, strikes, forward)

    result = minimize(wrapped_objective, params, args=(knots, call_prices, strikes, forward), Dfun=jacobian, method='leastsq')
    # result = minimize(wrapped_objective, params, args=(knots, call_prices, strikes, forward), method='leastsq')  # uncomment for the desired output

    final_alpha_values = [result.params[f'alpha{i}'].value for i in range(len(alpha_initial))]
    final_bspl = BSpline(t=knots, c=final_alpha_values, k=2)
    final_poly = PPoly.from_spline(final_bspl)
    return final_poly, result

prices = [337.94, 333.22, 310.49, 306.57, 288.3, 266.1, 249.4, 244.47, 232.02, 223.57, 215.47, 203.49, 199.4]
strikes = [20, 25, 50, 55, 75, 100, 120, 125, 140, 150, 160, 175, 180]
t = [-4.63, -4.62, -4.61, -1.09, -0.360, 0.05, 0.51]
c = [2.9, 4.09, 5.42, 5.78, 6.05, 6.37, 6.55]
forward = 356.73063159822254
bspl = BSpline(t, c, 2)
spl, res = calibrate_jax(bspl, prices, strikes, forward)
Something looks strange in your finite difference code: you're perturbing by epsilon * alpha_tangent[i] but then dividing by plain epsilon (and multiplying by alpha_tangent[i] again afterward), so the tangent enters twice. I suspect this is producing incorrect gradients. You'd probably want this instead:
perturbed_alpha = alpha_values.at[i].add(epsilon)
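For what it's worth, since a JVP only ever needs the directional derivative along the tangent, you could also drop the per-coordinate loop entirely and perturb along the tangent direction once. A sketch, with the same accuracy caveats as any finite difference:

@convert.defjvp
def convert_jvp(primals, tangents):
    alpha_values, knots = primals
    alpha_tangent = tangents[0]
    epsilon = 1e-4
    primal_output = convert(alpha_values, knots)
    # (f(x + eps * v) - f(x)) / eps approximates the JVP J(x) @ v directly
    perturbed_output = convert(alpha_values + epsilon * alpha_tangent, knots)
    tangent_out = {key: (perturbed_output[key] - primal_output[key]) / epsilon
                   for key in primal_output}
    return primal_output, tangent_out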
Stepping back, though: using finite differences within custom_jvp is very strange. The whole point of autodiff is to avoid the need to compute gradients via finite differences, which are relatively slow and inaccurate even when correctly implemented.
Unless you can figure out how to compute a closed-form version of the B-spline gradient, I suspect your best option will be to remove all the custom_jvp code and not use Dfun at all when computing your result:
result = minimize(wrapped_objective, params, args=(knots, call_prices,strikes,forward), method='leastsq')
The minimize call here already uses an accurate, well-tuned finite-difference algorithm to estimate the Jacobian; there's no need to try to recreate that yourself.
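As an aside on the closed-form option: if I'm reading get_bspl right, the BSpline-to-PPoly conversion is linear in the coefficient vector for fixed knots, so the exact JVP of convert with respect to alpha_values should just be convert applied to the tangent itself, with a zero tangent for the breakpoints x (which don't depend on alpha). A sketch, under that linearity assumption:

@convert.defjvp
def convert_jvp(primals, tangents):
    alpha_values, knots = primals
    alpha_tangent = tangents[0]
    primal_output = convert(alpha_values, knots)
    # For fixed knots the conversion is linear in the coefficients, so the
    # tangents of a, b, c are the conversion applied to the tangent vector;
    # the breakpoints 'x' do not depend on alpha, so their tangent is zero.
    tangent_coeffs = convert(alpha_tangent, knots)
    tangent_out = {key: jnp.zeros_like(primal_output['x']) if key == 'x'
                   else tangent_coeffs[key]
                   for key in primal_output}
    return primal_output, tangent_out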