I’ve built a dynamics model in Dymos and use JAX to compute the partial derivatives by automatic differentiation. The code looks something like the following:
import openmdao.api as om
import dymos as dm
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial


# Define dynamics
class Dynamics(om.ExplicitComponent):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._compute_primal_vec = jax.vmap(self._compute_primal)
        self._compute_partials_vec = jax.jit(jax.vmap(jax.jacfwd(self._compute_primal, argnums=np.arange(5))))

    def initialize(self):
        self.options.declare('num_nodes', types=int)

    def setup(self):
        nn = self.options['num_nodes']
        # States
        self.add_input('theta', shape=(nn,), desc='orientation, anticlockwise from positive x', units='rad')
        self.add_input('omega', shape=(nn,), desc='angular velocity, positive anticlockwise', units='rad/s')
        # Controls
        self.add_input('rho', shape=(nn,), desc='rotational control input')
        # Parameters
        self.add_input('k', shape=(nn,), desc='rotational resistance coefficient')
        self.add_input('I', shape=(nn,), desc='rotational moment of inertia', units='kg*m**2')
        # Outputs
        self.add_output('omega_dot', val=np.zeros(nn), desc='rate of change of angular velocity', units='rad/s**2')
        # Partials declared analytically
        arange = np.arange(nn)
        self.declare_partials(of='*', wrt='*', method='exact', rows=arange, cols=arange)

    # Dynamics go here
    @partial(jax.jit, static_argnums=(0,))
    def _compute_primal(self, theta, omega, rho, k, I):
        # For some reason, need to assign these otherwise they have the value 0/0J when checking the partials.
        I = 1
        k = 1
        # Calculate moments
        tau = rho - k*omega**2  # Rotational torque
        # Calculate state rates of change (dynamics)
        # Rotational
        omega_dot = tau / I
        return omega_dot

    def compute(self, inputs, outputs):
        omega_dot = self._compute_primal_vec(*inputs.values())
        if np.isnan(np.sum(omega_dot)):
            raise Exception("NaN values found in rates")
        outputs['omega_dot'] = omega_dot

    def compute_partials(self, inputs, partials):
        output_names = ['omega_dot']
        input_names = ['theta', 'omega', 'rho', 'k', 'I']
        computed_partials = self._compute_partials_vec(*inputs.values())
        # Cycle through computed partials
        for out_ind, output_name in enumerate(output_names):
            for in_ind, input_name in enumerate(input_names):
                partials[output_name, input_name] = computed_partials[out_ind][in_ind]
This code may look a bit strange because it’s been cut down from a larger dynamics model, but it effectively models the rotation of an object with moment of inertia I under a torque tau, which is a function of the control input rho and the angular velocity omega. theta is the orientation.
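Written out, the dynamics in the code above amount to the following. The equations are read directly off `_compute_primal` and the state definitions; the partial derivatives are worked by hand and are what a check of the `omega_dot` partials would be expected to show for each node:

```latex
\begin{aligned}
\tau &= \rho - k\,\omega^{2}, \\
\dot{\theta} &= \omega, \\
\dot{\omega} &= \frac{\tau}{I} = \frac{\rho - k\,\omega^{2}}{I}, \\[4pt]
\frac{\partial \dot{\omega}}{\partial \theta} &= 0, \quad
\frac{\partial \dot{\omega}}{\partial \omega} = -\frac{2k\,\omega}{I}, \quad
\frac{\partial \dot{\omega}}{\partial \rho} = \frac{1}{I}, \\
\frac{\partial \dot{\omega}}{\partial k} &= -\frac{\omega^{2}}{I}, \quad
\frac{\partial \dot{\omega}}{\partial I} = -\frac{\rho - k\,\omega^{2}}{I^{2}}.
\end{aligned}
```

Note that because `k` and `I` are overwritten with the constant 1 inside `_compute_primal`, automatic differentiation of the code as written will give zero for the last two derivatives, and the others are evaluated at k = I = 1.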
Another possibly unusual aspect of my setup is the use of JAX to calculate the partial derivatives automatically. I’ve seen this done in the OpenMDAO and Dymos examples, e.g. using wrap_ode (OpenMDAO example, Dymos wrap_ode example), but my approach is slightly different, so I don’t know whether it plays any part in the strange behaviour I’m experiencing.
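For reference, here is a minimal standalone sketch of the `jax.vmap(jax.jacfwd(...))` pattern outside of OpenMDAO. The function name `omega_dot_fn` and the input values are made up for illustration and are not part of the model above. It shows how the result is laid out when the function returns a single scalar and several `argnums` are requested: a tuple with one entry per input, each entry holding one derivative value per node.

```python
import jax
import jax.numpy as jnp


def omega_dot_fn(theta, omega, rho, k, I):
    """Hypothetical standalone copy of the dynamics (no overrides of k or I)."""
    tau = rho - k * omega**2
    return tau / I


# Differentiate w.r.t. all five inputs, then vectorise over the nodes.
jac_vec = jax.vmap(jax.jacfwd(omega_dot_fn, argnums=(0, 1, 2, 3, 4)))

nn = 4  # illustrative number of nodes
theta = jnp.zeros(nn)
omega = jnp.arange(1.0, nn + 1.0)
rho = jnp.full(nn, 2.0)
k = jnp.full(nn, 0.5)
I = jnp.full(nn, 2.0)

partials = jac_vec(theta, omega, rho, k, I)

# The outer index runs over the inputs (in argnums order),
# and each entry has one value per node.
print(type(partials), len(partials))
for name, p in zip(["theta", "omega", "rho", "k", "I"], partials):
    print(name, p.shape)
```

In this layout `partials[1]` is the whole array of d(omega_dot)/d(omega) values, whereas `partials[0][1]` is a single element of the d(omega_dot)/d(theta) array, so the order of indexing matters when copying the results into a partials dictionary.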
The problem is built and set up as follows (initial state, control and parameter values aren’t set as I’m only interested in checking the partial derivatives):
num_segments = 10
order = 3
# Build problem
prob = om.Problem()
traj = dm.Trajectory()
prob.model.add_subsystem('traj', traj)
phase = dm.Phase(ode_class=Dynamics,
                 transcription=dm.GaussLobatto(num_segments=num_segments, order=order),
                 ode_init_kwargs={})
traj.add_phase('phase0', phase)
# Add states, controls, parameters and objective
# States
phase.add_state('theta', rate_source='omega', targets=['theta'], units='rad')
phase.add_state('omega', rate_source='omega_dot', targets=['omega'], units='rad/s')
# Controls
phase.add_control('rho', continuity=True, rate_continuity=True, targets=['rho'])
# Parameters
phase.add_parameter('k', targets=['k'])
phase.add_parameter('I', units='kg*m**2', targets=['I'])
# Configure and set up
prob.setup(force_alloc_complex=True)
When I run check_partials using prob.check_partials(method='cs', compact_print=True), I get some significant and systematic-looking errors:
Stranger still, in this example I don’t get an error for omega_dot w.r.t. theta, whereas in the full model that partial has an absolute error of 9.5238e-01 and a relative error of 1.0000e+00.
What could be causing this? Is it an issue with the model itself, or perhaps the way I’m using Jax? There are a few peculiarities:
- The fact that the errors are for different partials in the original model and this minimum working example.
- In _compute_primal(…), the values of k and I don’t ‘carry through’ into the function – they end up being 0 or 0J when check_partials is run (which is why I have to set them explicitly in _compute_primal, to avoid division by zero errors) – why would this be?
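The 0J values suggest that complex numbers are reaching the function, which is what one would expect from a complex-step check (method='cs'). As a rough, library-free illustration of what such a check computes, here is a sketch in plain NumPy. The names `omega_dot_fn` and `complex_step` and the input values are hypothetical and are not part of OpenMDAO or the model above.

```python
import numpy as np


def omega_dot_fn(omega, rho, k, I):
    """Hypothetical standalone copy of the dynamics."""
    return (rho - k * omega**2) / I


def complex_step(f, x, h=1e-40):
    """Complex-step approximation of df/dx: Im(f(x + i*h)) / h."""
    return np.imag(f(x + 1j * h)) / h


# Illustrative values only
omega, rho, k, I = 3.0, 2.0, 0.5, 2.0

# Derivative w.r.t. omega from a complex perturbation of omega
d_cs = complex_step(lambda w: omega_dot_fn(w, rho, k, I), omega)

# Hand-derived value: d(omega_dot)/d(omega) = -2*k*omega/I
d_exact = -2.0 * k * omega / I

print(d_cs, d_exact)
```

The function under test has to accept complex inputs and carry the imaginary part through to the output for this to work, which is presumably why the problem is set up with force_alloc_complex=True.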
As an aside, what is the difference between the rhs_disc and rhs_col components of the phase?
Many thanks, I really appreciate any assistance.
Here are more details for the non-zero/non-NaN partials (these were screenshotted on a different run, and the values seem to have changed in between):