I’m trying to use JAX’s `custom_vjp` to define custom gradient computations for a function that takes a SymPy expression as an input. However, I’m encountering errors because JAX doesn’t support non-JAX types as inputs to transformed functions (e.g., with `grad`, `jit`, or `custom_vjp`). I am currently modifying code in ScQubits to add a new JAX backend for better efficiency, and I have run into this problem combining JAX and SymPy.
Here’s a minimal example of what I’m trying to do:
```python
import jax
import sympy as sm

# Define symbols and expression
x, y, z = sm.symbols('x y z')
expr = x**2 + 2*y + z

# Attempt to iterate over expr (SymPy expressions are not iterable, so this raises TypeError)
try:
    for term in expr:
        print(term)
except TypeError as e:
    print(f"Error: {e}")

# Define a function that takes a SymPy expression and a value
def sympy_function(expr, x_value):
    x = sm.Symbol('x')
    result = expr.subs(x, x_value)
    return float(result)

# Attempt to apply custom_vjp
sympy_function = jax.custom_vjp(sympy_function)

def sympy_function_fwd(expr, x_value):
    y = sympy_function(expr, x_value)
    return y, (expr, x_value)

def sympy_function_bwd(residual, grad_y):
    expr, x_value = residual
    x = sm.Symbol('x')
    derivative_expr = sm.diff(expr, x)
    grad_x_value = float(derivative_expr.subs(x, x_value))
    grad_expr = None  # no gradient with respect to the SymPy expression
    return grad_expr, grad_y * grad_x_value

sympy_function.defvjp(sympy_function_fwd, sympy_function_bwd)

# Test the function
x = sm.Symbol('x')
expr = x**2 + 3*x + 2
x_value = 1.0

# This will raise an error
y = sympy_function(expr, x_value)
```
When I run this code, I get an error like:

```
TypeError: Value x**2 + 3*x + 2 with type <class 'sympy.core.add.Add'> is not a valid JAX type
```

How can I use `jax.custom_vjp` with functions that take non-JAX types like SymPy expressions as inputs? Is there a way to work around this limitation, or to make JAX accept such functions?
There is no way to do what you’re trying to do: JAX transformations like `custom_jvp` and `custom_vjp` can only differentiate with respect to JAX-compatible values such as arrays and pytrees of arrays. Non-JAX values can only be used as `nondiff_argnums` within the `custom_vjp`; in this case, the only arguments to your function are non-differentiable arguments, since `x_value` only flows through SymPy’s `subs` and `float()`, neither of which JAX can trace. As a result, the `custom_vjp` is unnecessary, because your function is not differentiable within JAX’s autodiff framework.
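To make "JAX-compatible values like arrays and pytrees" concrete: roughly speaking, JAX flattens the arguments of a transformed function into pytree leaves and expects each leaf to be an array-like value. The sketch below only shows how JAX sees the question's arguments (it is not a trace of exactly where the check fires): a SymPy `Add` is not a registered pytree node, so the whole expression becomes one opaque leaf, which is the value named in the error message.

```python
import jax
import sympy as sm

x = sm.Symbol("x")
expr = x**2 + 3*x + 2

# Flatten an argument tuple the way JAX flattens a function's arguments.
# SymPy's Add is not a registered pytree node, so the whole expression
# comes back as a single leaf next to the float.
leaves = jax.tree_util.tree_leaves((expr, 1.0))

for leaf in leaves:
    print(type(leaf).__name__)
# Add
# float
```

The `float` leaf is a valid JAX input; the `Add` leaf is not, and JAX has no way to look inside it.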
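For illustration, here is a sketch of the `nondiff_argnums` route, under the assumption that the function body is rewritten so that `x_value` stays a JAX value. The expression is declared non-differentiable, and `sympy.lambdify` with the `jax.numpy` module as its namespace turns the expression (and its symbolic derivative) into JAX-traceable callables. Using `lambdify` here is a choice made for this sketch, not something from the original post, and `dfdx` is just a local name.

```python
import functools

import jax
import jax.numpy as jnp
import sympy as sm

x = sm.Symbol("x")

# Argument 0 (the SymPy expression) is declared non-differentiable, so JAX
# passes it through as-is instead of treating it as an array input.
@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def sympy_function(expr, x_value):
    # lambdify builds a plain Python function whose names resolve in the
    # jax.numpy namespace, so it accepts JAX arrays and tracers.
    f = sm.lambdify(x, expr, modules=jnp)
    return f(x_value)

def sympy_function_fwd(expr, x_value):
    # Keep the input point as the residual for the backward pass.
    return sympy_function(expr, x_value), x_value

def sympy_function_bwd(expr, x_value, grad_y):
    # Non-differentiable arguments come first, then the residual, then the
    # output cotangent. Differentiate symbolically, evaluate with JAX.
    dfdx = sm.lambdify(x, sm.diff(expr, x), modules=jnp)
    # One cotangent per differentiable argument (only x_value here).
    return (grad_y * dfdx(x_value),)

sympy_function.defvjp(sympy_function_fwd, sympy_function_bwd)

expr = x**2 + 3*x + 2
print(sympy_function(expr, 1.0))                       # 6.0
print(jax.grad(sympy_function, argnums=1)(expr, 1.0))  # 5.0
```

Only `x_value` is ever traced; the expression never becomes an array leaf. Calling `lambdify` on every evaluation is slow, so for repeated use you would probably want to convert each expression once and reuse the result.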
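If a hand-written VJP is not actually needed, a simpler workaround (again an assumption of this sketch, not part of the answer above) is to convert the SymPy expression into a JAX function once, outside any transformation, and let JAX's ordinary autodiff handle the derivative. The example uses `sin` so that the `jax.numpy` namespace actually matters.

```python
import jax
import jax.numpy as jnp
import sympy as sm

x = sm.Symbol("x")
expr = sm.sin(x) * x**2 + 3*x + 2

# Convert the expression once, outside any JAX transformation. Passing the
# jax.numpy module as the namespace makes names like sin resolve to jnp.sin,
# so the generated function works on JAX arrays and tracers.
f = sm.lambdify(x, expr, modules=jnp)

# Ordinary JAX autodiff (and jit) now apply; no custom_vjp is involved.
df = jax.jit(jax.grad(f))

# Cross-check against SymPy's symbolic derivative at one point.
symbolic = float(sm.diff(expr, x).subs(x, 1.0))
print(float(df(1.0)), symbolic)
```

The two printed values should agree up to float32 rounding.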