What I am doing:
I am currently trying to set up a training routine that will require DDP. After doing some research, I decided to use PyTorch Lightning for this because of its ease of use, and because I was already quite familiar with plain PyTorch.
I am using the LightningCLI to start my training; passing --trainer.strategy=ddp makes it very easy to switch over to DDP training.
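For context, the launch command looks roughly like this (the script name `main.py` and the config file are placeholders for my actual entry point):

```shell
# Invoke the LightningCLI entry point with the DDP strategy.
# fit = run training; the strategy/devices flags override the config file.
python main.py fit \
    --config config.yaml \
    --trainer.strategy=ddp \
    --trainer.devices=4 \
    --trainer.num_nodes=2
```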
After implementing my LightningDataModule, I am now implementing the required LightningModule.
What the problem is:
I can't find anywhere in the documentation where the torch.nn.Module should be instantiated to set this up properly for DDP training. I can think of two possibilities, but neither of them seems right.
class TestModel(torch.nn.Module):
    ...  # This is some Transformer model


class TestModule(lightning.LightningModule):
    def __init__(self, torch_model_params):
        super().__init__()
        self.torch_model_params = torch_model_params
        # OPTION 1: build the model eagerly in __init__
        self.model = TestModel(**self.torch_model_params)

    def setup(self, stage):
        # OPTION 2: build the model lazily in the setup hook instead
        self.model = TestModel(**self.torch_model_params)

    ...


if __name__ == '__main__':
    LightningCLI(
        model_class=TestModule,
        datamodule_class=DataModule,  # implemented somewhere else
    )
I expect a single model to be trained across multiple SLURM nodes with multiple GPUs per node. How do I achieve that with this setup? Is it even possible, or do I need to change my approach?