I’m trying to modify the scqubits Python package to use JAX’s custom_vjp for differentiable programming. While doing so, I encountered the following error:
```
TypeError: 'Add' object is not iterable
```
Here’s a minimal example that reproduces the issue:
```python
import sympy as sm

x, y, z = sm.symbols('x y z')
expr = x**2 + 2*y + z

try:
    for term in expr:
        print(term)
except TypeError as e:
    print(f"Error: {e}")
```
This code outputs:

```
Error: 'Add' object is not iterable
```
In the scqubits codebase, specifically in scqubits/core/circuit_routines.py, there’s a function that iterates over a SymPy expression:
```python
def _constants_in_subsys(self, H_sys: sm.Expr, constants_expr: sm.Expr) -> sm.Expr:
    """
    Returns an expression of constants that belong to the subsystem with the
    Hamiltonian H_sys

    Parameters
    ----------
    H_sys:
        Subsystem Hamiltonian

    Returns
    -------
        Expression of constants belonging to the subsystem
    """
    constant_expr = 0
    subsys_free_symbols = set(H_sys.free_symbols)
    constant_terms = constants_expr.copy()
    for term in constant_terms:
        if set(term.free_symbols) & subsys_free_symbols == set(term.free_symbols):
            constant_expr += term
    return constant_expr
```
When I run this function with JAX’s custom_vjp, it raises the same TypeError.
My questions are:

1. Is iterating over a SymPy expression directly (as in `for term in expr`) incorrect?
2. Why does this raise a `TypeError` saying `'Add' object is not iterable`?
3. How can I modify the code to correctly iterate over the terms of a SymPy expression, especially in the context of using JAX?
Additional information:

- SymPy version: 1.9
- JAX version: 0.2.25
- scqubits version: (If applicable)
I've tried replacing the iteration with `expr.args`, like so:

```python
for term in expr.args:
    print(term)
```
This works without errors, but I’m unsure if this is the best or most reliable approach, especially when integrating with JAX.
Any insights or suggestions would be greatly appreciated!
Seems like you found a bug in scqubits. I suggest opening an issue in that repository; apparently, that function is untested.
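A minimal sketch of what a fix could look like, written as a standalone function so it runs without scqubits (the name `constants_in_subsys` and the symbols in the usage lines are illustrative, not scqubits API). It loops over `sm.Add.make_args(constants_expr)`, a SymPy classmethod that returns the terms of a sum and a 1-tuple for any other expression, so a constant that is a single product is not split into its factors.

```python
import sympy as sm


def constants_in_subsys(H_sys: sm.Expr, constants_expr: sm.Expr) -> sm.Expr:
    """Sum the terms of constants_expr whose free symbols all occur in H_sys.

    Hypothetical standalone rewrite of scqubits' _constants_in_subsys.
    """
    subsys_free_symbols = H_sys.free_symbols
    constant_expr = sm.S.Zero
    # Add.make_args gives the additive terms of an Add, and (expr,) for
    # anything else, so a lone term such as 2*EJ is kept whole
    for term in sm.Add.make_args(constants_expr):
        # Subset test, equivalent to the original `A & B == A` check
        if term.free_symbols <= subsys_free_symbols:
            constant_expr += term
    return constant_expr


# Illustrative symbols, not taken from scqubits
EJ, EC, EL, phi = sm.symbols('EJ EC EL phi')
H_sys = EJ*sm.cos(phi) + EC*phi**2

print(constants_in_subsys(H_sys, EJ + EC/2 + EL) == EJ + EC/2)  # True
print(constants_in_subsys(H_sys, 2*EJ))  # 2*EJ
```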
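In SymPy, you can't iterate over an expression directly: expression classes such as `Add` don't implement Python's iteration protocol, so a `for` loop raises `TypeError: 'Add' object is not iterable`. This has nothing to do with JAX; your minimal example triggers it with plain SymPy. However, you can iterate over the arguments of an expression, just as you found above.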
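A quick way to confirm this, assuming any recent SymPy version: the sum is an instance of `sm.Add`, calling `iter()` on it (which is what a `for` loop does implicitly) fails, and its terms are stored in the `.args` tuple.

```python
import sympy as sm

x, y, z = sm.symbols('x y z')
expr = x**2 + 2*y + z

# The top-level node of the expression tree is an Add
print(type(expr).__name__)  # Add

# SymPy expressions are not iterable, so iter() raises TypeError
try:
    iter(expr)
except TypeError as e:
    print(e)  # 'Add' object is not iterable

# The operands of the Add, i.e. the three terms, live in .args
print(len(expr.args))  # 3
```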
Note that every SymPy expression exposes the `args` attribute: an addition as in your example, but also a multiplication, a power, a function, a derivative, and so on.
```python
import sympy as sm

x, y, z = sm.symbols('x y z')
expr = x**2 + 2*y + z

try:
    for term in expr.args:
        print(term)
except TypeError as e:
    print(f"Error: {e}")
# z
# x**2
# 2*y
```
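Because `.args` holds the operands of whatever the top-level node is, it gives you the *terms* only when that node is an `Add`. The sketch below (symbol names are illustrative) prints `.args` for a few other node types, then shows that a single-term expression such as `2*y` would be split into its factors, whereas `sm.Add.make_args` always returns the additive terms, wrapping a non-`Add` expression in a 1-tuple.

```python
import sympy as sm

x, y = sm.symbols('x y')

# .func is the node's class and .args its operands, for every node type
examples = [
    2*x*y,                   # Mul: args are the factors
    x**3,                    # Pow: args are (base, exponent)
    sm.sin(x),               # function application: args are its arguments
    sm.Derivative(x**2, x),  # unevaluated derivative
]
for e in examples:
    print(type(e).__name__, e.args)
    # Rebuilding a node from its class and args gives an equal expression
    assert e.func(*e.args) == e

# Atoms such as symbols and numbers have no args
print(x.args, sm.Integer(5).args)  # () ()

# Pitfall when looping over "terms": a single term is not an Add,
# so .args returns its factors rather than the term itself
single_term = 2*y
print(single_term.args == (2, y))  # True

# Add.make_args returns the terms of a sum, and (expr,) for anything else
print(sm.Add.make_args(x**2 + 2*y) == (x**2 + 2*y).args)  # True
print(sm.Add.make_args(single_term))  # (2*y,)
```

This is why `for term in constants_expr.args` is not a fully safe drop-in for the scqubits loop: if the constants happen to reduce to a single product, the loop would visit its factors separately.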