I’m training a DL model using Lightning; this is the general outline of the module corresponding to the model:
import pytorch_lightning as pl
import torch
import torchmetrics

class CustomModel(pl.LightningModule):
    def __init__(...):
        super().__init__()
        # Create the required model
        self.model = ...  # model initialization
        # Metrics (one instance per stage so their internal states don't mix)
        self.train_prec = torchmetrics.classification.Precision(task="binary")
        self.train_recall = torchmetrics.classification.Recall(task="binary")
        self.train_f1 = torchmetrics.classification.F1Score(task="binary")
        self.val_prec = torchmetrics.classification.Precision(task="binary")
        self.val_recall = torchmetrics.classification.Recall(task="binary")
        self.val_f1 = torchmetrics.classification.F1Score(task="binary")
        self.test_prec = torchmetrics.classification.Precision(task="binary")
        self.test_recall = torchmetrics.classification.Recall(task="binary")
        self.test_f1 = torchmetrics.classification.F1Score(task="binary")
    def forward(self, batch_video):
        # implementation
        ...

    def training_step(self, batch, batch_idx):
        mode = "train"
        loss = self._mode_step(batch, mode)
        return loss

    def validation_step(self, batch, batch_idx):
        mode = "val"
        loss = self._mode_step(batch, mode)
        return loss

    def test_step(self, batch, batch_idx):
        mode = "test"
        loss = self._mode_step(batch, mode)
        return loss

    def predict_step(self, batch, batch_idx, *args, **kwargs):
        logit = self(batch["video"])
        y_hat = torch.sigmoid(logit)
        return y_hat, batch["label"]

    def on_train_epoch_start(self) -> None:
        """Remove the data before a new epoch starts."""
        self._reset_pred_label()  # helper defined elsewhere

    def _mode_step(self, batch, mode):
        """Generic step for train/val/test."""
        # compute inference and loss
        prob, label, loss = self._compute_pred_loss(batch)
        # update metrics after every batch
        self._update_metrics(mode, prob, label)
        return loss

    def _compute_pred_loss(self, batch):
        # implementation
        ...
    def _update_metrics(self, mode, prob, label):
        """Update all metrics per batch."""
        if mode == "train":
            self.train_prec.update(prob, label)
            self.train_recall.update(prob, label)
            self.train_f1.update(prob, label)
        elif mode == "val":
            self.val_prec.update(prob, label)
            self.val_recall.update(prob, label)
            self.val_f1.update(prob, label)
        elif mode == "test":
            self.test_prec.update(prob, label)
            self.test_recall.update(prob, label)
            self.test_f1.update(prob, label)

    def _log_metrics(self, mode):
        """Compute and log all metrics."""
        if mode == "train":
            precision = self.train_prec.compute()
            recall = self.train_recall.compute()
            f1 = self.train_f1.compute()
        elif mode == "val":
            precision = self.val_prec.compute()
            recall = self.val_recall.compute()
            f1 = self.val_f1.compute()
        elif mode == "test":
            precision = self.test_prec.compute()
            recall = self.test_recall.compute()
            f1 = self.test_f1.compute()
        # Log metrics at the end of the epoch
        self.log(f"{mode}/precision", precision, on_epoch=True, on_step=False)
        self.log(f"{mode}/recall", recall, on_epoch=True, on_step=False)
        self.log(f"{mode}/f1", f1, on_epoch=True, on_step=False)

    def _reset_metrics(self, mode):
        if mode == "train":
            self.train_prec.reset()
            self.train_recall.reset()
            self.train_f1.reset()
        elif mode == "val":
            self.val_prec.reset()
            self.val_recall.reset()
            self.val_f1.reset()
        elif mode == "test":
            self.test_prec.reset()
            self.test_recall.reset()
            self.test_f1.reset()
    def on_train_epoch_end(self):
        """Compute and log training metrics at epoch end."""
        self._log_metrics("train")
        # Reset metrics for the next epoch
        self._reset_metrics("train")

    def on_validation_epoch_end(self):
        """Compute and log validation metrics at epoch end."""
        self._log_metrics("val")
        # Reset metrics for the next epoch
        self._reset_metrics("val")

    def on_test_epoch_end(self):
        """Compute and log test metrics at epoch end."""
        self._log_metrics("test")
        # Reset metrics for the next run
        self._reset_metrics("test")
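(As an aside, the per-stage bookkeeping above could be collapsed with torchmetrics.MetricCollection; this is just a minimal sketch of the idea, not my actual code:)

from torchmetrics import MetricCollection

# In __init__: build one collection and clone it per stage, so states stay separate
metrics = MetricCollection({
    "precision": torchmetrics.classification.Precision(task="binary"),
    "recall": torchmetrics.classification.Recall(task="binary"),
    "f1": torchmetrics.classification.F1Score(task="binary"),
})
self.train_metrics = metrics.clone(prefix="train/")
self.val_metrics = metrics.clone(prefix="val/")

# In _update_metrics / _log_metrics / _reset_metrics, respectively:
self.train_metrics.update(prob, label)
self.log_dict(self.train_metrics.compute(), on_epoch=True, on_step=False)
self.train_metrics.reset()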
I defined an MLFlowLogger and passed it to the trainer:
mlflow_logger = MLFlowLogger(experiment_name="some_name", run_id=mlflow_run_id, tracking_uri="some_uri")
trainer = pytorch_lightning.Trainer(logger=mlflow_logger, profiler=train_profiler, callbacks=callbacks, **trainer_kwargs)
Everything works fine and the metrics are uploaded to MLflow correctly. The problem is that they are not logged as a function of the epoch number, even though logging happens at the end of each epoch and I specified:
on_epoch=True, on_step=False
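For example, this is how I inspect which step values MLflow actually recorded for one of the epoch-level metrics (a minimal sketch; the "val/f1" key and the run_id/tracking URI are the ones from above):

from mlflow.tracking import MlflowClient

client = MlflowClient(tracking_uri="some_uri")
# get_metric_history returns one entry per logged point; point.step is
# the value MLflow plots on the x-axis in the UI
for point in client.get_metric_history(mlflow_run_id, "val/f1"):
    print(point.step, point.value)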
Any suggestions on how to get these metrics logged against the epoch number?