I’m training a model with JAX and Optax on four GPUs and need to save and restore the optimizer state, but I’m running into a problem loading it. The optimizer state is initialized like this:
optimizer = optax.adamw(0.0001)

class Runner:
    def __init__(self, params, devices):
        self.params = params
        self.opt_state = optimizer.init(self.params)  # initialize the optimizer
        self.opt_state = jax.device_put_replicated(self.opt_state, devices)  # replicate the opt_state to all devices
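For context, inspecting the state right after optimizer.init() (before replication) shows that Optax returns a tuple rather than a dict. A quick check I ran, sketched here with a toy params tree standing in for the real model params:

import jax
import jax.numpy as jnp
import optax

# toy stand-in for the real model params, just to inspect the state's type
toy_params = {'w': jnp.zeros(3)}
toy_state = optax.adamw(0.0001).init(toy_params)

print(type(toy_state))
# <class 'tuple'> -- adamw is a chain of gradient transformations,
# so its state is a tuple of per-transform NamedTuples
print(jax.tree_util.tree_structure(toy_state))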
The state is then saved only from the first device:
def save(self, save_dir):
    params_path = save_dir + '/{}_params.pickle'.format(self.train_step)
    opt_state_path = save_dir + '/{}_opt_state.pickle'.format(self.train_step)
    # keep only the first-device copy of each leaf, then pull it to host memory
    with open(params_path, 'wb') as file:
        pickle.dump(jax.device_get(jax.tree_map(lambda x: x[0], self.params)), file)
    with open(opt_state_path, 'wb') as file:
        pickle.dump(jax.device_get(jax.tree_map(lambda x: x[0], self.opt_state)), file)
    logging.info('saved to {}, step {}'.format(save_dir, self.train_step))
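To confirm what actually lands on disk, I reloaded the pickle right after the dump in save() (a minimal check; opt_state_path is whatever save() just wrote):

with open(opt_state_path, 'rb') as file:
    saved = pickle.load(file)
print(type(saved))  # <class 'tuple'>, unlike params, which round-trips as a dict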
Later on, the parameters and optimizer state are loaded:
def restore(self, save_dir, step, restore_opt_state=True):
    opt_state_path = save_dir + '/{}_opt_state.pickle'.format(step)
    if restore_opt_state:  # True
        with open(opt_state_path, 'rb') as file:
            opt_state = pickle.load(file)
        # replicate the opt_state to all devices
        replicate_opt_state = jax.device_put_replicated(opt_state, self.devices)
        self.opt_state = update_dict(self.opt_state, replicate_opt_state)
        logging.info('restored opt state from {}, step {}'.format(save_dir, step))
where update_dict() is defined:
def update_dict(params1, params2):
    assert type(params1) == type(params2)
    if type(params1) == dict:
        params1.update(params2)
    elif type(params1) == FrozenDict:
        p = dict(params1['params'])
        p.update(params2['params'])
        params1 = {'params': FrozenDict(p)}
    else:
        raise ValueError('params1 type {} not implemented'.format(type(params1)))
    return params1
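For comparison, here is roughly why the same helper works for params but not for the optimizer state (a toy repro; FrozenDict is assumed to be flax.core.FrozenDict, as in my code above):

import jax.numpy as jnp
import optax
from flax.core import FrozenDict

# a FrozenDict pytree, like params: the FrozenDict branch handles it
p1 = FrozenDict({'params': {'w': jnp.zeros(3)}})
p2 = FrozenDict({'params': {'w': jnp.ones(3)}})
update_dict(p1, p2)  # works

# the optax state is a plain tuple of NamedTuples: no branch matches
s1 = optax.adamw(0.0001).init({'w': jnp.zeros(3)})
update_dict(s1, s1)  # ValueError: params1 type <class 'tuple'> not implemented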
When I run the restore method on the correct save_dir and step, though, update_dict() receives a tuple and cannot restore the optimizer state (ValueError: params1 type <class 'tuple'> not implemented). I’m not sure why, since save() explicitly saves only the first-device copy of the optimizer state. Additionally, this same setup works fine when loading the parameters from a single device and replicating them across multiple devices. I would very much appreciate some guidance here!
I was expecting the same behavior as with params, which gets saved and loaded as a dictionary. I’ve tried explicitly adding another jax.tree_map(lambda x: x[0], opt_state) inside the restore_opt_state branch, which didn’t work. It also looks like opt_state is being saved as a tuple rather than as a dictionary, which is how params is saved. Would it be enough to pass one of the elements of the tuple through update_dict()? For what I had in mind, see the sketch below.
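For concreteness, this is the kind of tuple-aware replacement for update_dict() I was considering (a sketch only; update_opt_state is my own hypothetical helper, and it assumes both states share the exact same pytree structure):

def update_opt_state(old_state, new_state):
    # Both states are pytrees with identical structure (for adamw,
    # a tuple of per-transform NamedTuples), so take every leaf
    # from the restored, replicated state.
    return jax.tree_map(lambda old, new: new, old_state, new_state)

# in restore(), this would replace the update_dict call:
# self.opt_state = update_opt_state(self.opt_state, replicate_opt_state)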