I’m trying to implement data parallelism from scratch in PyTorch. To do this, I’ve implemented the following steps (a rough sketch of my code for the first four is below the list):
- Making model copies for each device
- Re-batching the data for each model copy
- Running forward passes
- Accumulating the gradients and averaging them
- Optimizer step on the averaged gradient <- you are here
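For context, here is a stripped-down sketch of the first four steps. The tiny `nn.Linear` model, the hard-coded device list, and the random batch are just placeholders for my real setup:

import copy
import torch
import torch.nn as nn

# Placeholder model, devices, and data; my real code uses a bigger model
# and whatever GPUs are actually available.
devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() > 1 else ["cpu", "cpu"]
base_model = nn.Linear(10, 1)

# Step 1: make a model copy for each device
models = [copy.deepcopy(base_model).to(d) for d in devices]

# Step 2: re-batch the data, one shard per replica
inputs = torch.randn(8, 10)
targets = torch.randn(8, 1)
shards = list(zip(inputs.chunk(len(models)), targets.chunk(len(models))))

# Step 3: run a forward/backward pass on each replica with its shard
loss_fn = nn.MSELoss()
for model, device, (x, y) in zip(models, devices, shards):
    loss = loss_fn(model(x.to(device)), y.to(device))
    loss.backward()

# Step 4: average the gradients across replicas, parameter by parameter
for params in zip(*(m.parameters() for m in models)):
    avg = torch.stack([p.grad.to(devices[0]) for p in params]).mean(dim=0)
    for p in params:
        p.grad = avg.to(p.grad.device)

After the last loop, every replica holds the same averaged gradients, which brings me to the optimizer step.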
I am trying to figure out how to combine the PyTorch optimizer step with this manual data parallelism. Currently, the only way I can see to do this is to keep a separate optimizer for each data-parallel model replica. Here’s a simplified version of that code:
# Keep one optimizer per model replica so that stepping all of them updates every copy
optimizers = [torch.optim.SGD(model.parameters(), lr=0.1) for model in models]
# Step all of the optimizers
for optimizer in optimizers:
    optimizer.step()
I suspect this is not how people implement the optimizer step, since keeping a full optimizer per replica does not seem very memory efficient. How is this done in practice? Is there a way to attach a single optimizer to multiple model copies? I’d like to understand the proper way to do this!