I need to use pytorch_forecasting’s TemporalFusionTransformer model: https://pytorch-forecasting.readthedocs.io/en/stable/api/pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer.html#pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer on large-volume retail data.
I’m training my model on databricks on a machine with 128GB memory and 32 cores.
My data is saved on blob storage in parquet format in groups of 30 data ids (named chunk below).
I’m training my model chunk by chunk and I see my memory usage increase as I go along, until it kills itself with a typical error: The Python process exited with exit code 137 (SIGKILL: Killed). This may have been caused by an OOM error. Check your command’s memory usage.
I don’t see where my memory problem lies on this model?
from datetime import date
from typing import Dict, Any, List
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss
from lightning.pytorch.callbacks import (
ModelCheckpoint,
EarlyStopping,
LearningRateMonitor
)
import torch
from lightning.pytorch.trainer import Trainer
import pandas as pd
import gc
import sys
import tracemalloc
from src.spark_session import get_spark_session
from src.modelisation.utils import (
inverse_transform_scalers,
)
from conf.config import (
target_column,
columns_ids,
MODELISATION_TFT,
CATEGORICAL_ENCODERS,
TFT_MODEL,
num_workers_to_use,
)
def train_model_by_chunks(params: Dict[str, Any], list_chunks_path: List[str]):
tft = None
print("Starting iterate over chunks")
# Load encoders
categorical_label_encoders = torch.load("/dbfs" + CATEGORICAL_ENCODERS.as_posix())
nchunks_alread_trained = 0
# Create the trainer
early_stop_callback = EarlyStopping(
monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
)
lr_logger = LearningRateMonitor()
model_checkpoint = ModelCheckpoint(
dirpath=("/dbfs" + TFT_MODEL.as_posix()),
filename="best-checkpoint-tft",
save_top_k=1,
verbose=True,
monitor="val_loss",
mode="min",
enable_version_counter=False,
)
trainer = Trainer(
max_epochs=params["max_epochs"],
accelerator="cpu",
enable_model_summary=True,
gradient_clip_val=0.1,
callbacks=[model_checkpoint],
default_root_dir=("/dbfs" + MODELISATION_TFT.as_posix()),
)
# Démarrer le suivi de la mémoire
tracemalloc.start()
for j, chunk_path in enumerate(list_chunks_path[nchunks_alread_trained:]):
i = j + nchunks_alread_trained
print(
f"Dealing with chunk {i+1} : {i*params['chunk_size_train']} eans have been treated."
)
chunk = pd.read_parquet(chunk_path)
chunk = preparation_timeseries_dataset_training_chunk(
chunk_df=chunk, params=params
)
# Train model
training = TimeSeriesDataSet(
chunk,
time_idx="time_index",
target=target_column,
group_ids=columns_ids,
static_categoricals=params["static_categoricals"],
static_reals=params["static_reals"],
time_varying_known_categoricals=params["time_varying_known_categoricals"],
time_varying_known_reals=params["time_varying_known_reals"],
time_varying_unknown_reals=params["time_varying_unknown_reals"],
target_normalizer=GroupNormalizer(
groups=columns_ids, transformation="softplus"
),
add_relative_time_idx=True,
add_target_scales=True,
max_prediction_length=params["max_prediction_length"],
max_encoder_length=params["max_encoder_length"],
categorical_encoders=categorical_label_encoders,
)
validation = TimeSeriesDataSet.from_dataset(
training, chunk, predict=True, stop_randomization=True
)
train_dataloader = training.to_dataloader(
train=True, batch_size=params["batch_size"], num_workers=num_workers_to_use
)
val_dataloader = validation.to_dataloader(
train=False,
batch_size=params["batch_size"] * 10,
num_workers=num_workers_to_use,
)
if i == 0:
# Initialize the model for the first chunk
tft = TemporalFusionTransformer.from_dataset(
training,
learning_rate=0.03,
hidden_size=16,
attention_head_size=2,
dropout=0.1,
hidden_continuous_size=8,
loss=QuantileLoss(),
log_interval=-1,
reduce_on_plateau_patience=4,
)
else:
# Load the model for subsequent chunks
tft = TemporalFusionTransformer.load_from_checkpoint(
"/dbfs" + (TFT_MODEL / "best-checkpoint-tft.ckpt").as_posix()
)
print(
f"Model loaded -- Number of parameters in network: {tft.size()/1e3:.1f}k"
)
trainer.fit(
tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader
)
# Capturer un instantané avant la suppression des objets
snapshot_before = tracemalloc.take_snapshot()
top_stats_before = snapshot_before.statistics("lineno")
print("[ Top 10 Before deleting]")
for stat in top_stats_before[:10]:
print(stat)
del chunk
del train_dataloader
del val_dataloader
del training
del validation
del tft
# Delete unused
print(
f"Before gc collect - Nombre total d'objets en mémoire: {len(gc.get_objects())}"
)
print(
f"Before gc collect - Somme des objets en mémoire : {sum(sys.getsizeof(obj) for obj in gc.get_objects())*10e-9} GB"
)
gc.collect()
print(
f"After gc collect - Nombre total d'objets en mémoire: {len(gc.get_objects())}"
)
print(
f"After gc collect - Somme des objets en mémoire : {sum(sys.getsizeof(obj) for obj in gc.get_objects())*10e-9} GB"
)
# Compare memoire après la suppresson des objets
snapshot_after = tracemalloc.take_snapshot()
stats = snapshot_after.compare_to(snapshot_before, "lineno")
print("[ Top 10 différences ]")
for stat in stats[:10]:
print(stat)
tracemalloc.stop()
Robin Vandamme is a new contributor to this site. Take care in asking for clarification, commenting, and answering.
Check out our Code of Conduct.