I am training a model using YOLO NAS pytorch. It’s working fine but my kernel keeps dying at epoch20. So i am trying to resume training from epoch 21 by loading epoch number and other details from average checkpoint file(average_model.pth) but no matter what it always resumes training from epoch 0. PFB code details
————————————-
#below part deals with respective YOLO image,label path declarations
from super_gradients.training.dataloaders.dataloaders import coco_detection_yolo_format_train, coco_detection_yolo_format_val
BATCH_SIZE = 1
CLASSES = ['product']
CLASSES += [str(i) for i in range(80 - len(CLASSES))]
dataset_params = {
'data_dir': r"C:UsersGirirajDocumentsPrernas ML ModelsSKU110KDatasetSKU110K_fixed",
'train_images_dir':r'C:UsersGirirajDocumentsPrernas ML ModelsSKU110KDatasetSKU110K_fixedimagestrain',
'train_labels_dir':r'C:UsersGirirajDocumentsPrernas ML ModelsSKU110KDatasetSKU110K_fixedlabelstrain',
'val_images_dir':r'C:UsersGirirajDocumentsPrernas ML ModelsSKU110KDatasetSKU110K_fixedimagesval',
'val_labels_dir':r'C:UsersGirirajDocumentsPrernas ML ModelsSKU110KDatasetSKU110K_fixedlabelsval',
'test_images_dir':r'C:UsersGirirajDocumentsPrernas ML ModelsSKU110KDatasetSKU110K_fixedimagestest',
'test_labels_dir':r'C:UsersGirirajDocumentsPrernas ML ModelsSKU110KDatasetSKU110K_fixedlabelstest',
'classes': CLASSES
}
train_data = coco_detection_yolo_format_train(
dataset_params={
'data_dir': dataset_params['data_dir'],
'images_dir': dataset_params['train_images_dir'],
'labels_dir': dataset_params['train_labels_dir'],
'classes': dataset_params['classes']
},
dataloader_params={
'batch_size': BATCH_SIZE,
'num_workers': 2
}
)
val_data = coco_detection_yolo_format_val(
dataset_params={
'data_dir': dataset_params['data_dir'],
'images_dir': dataset_params['val_images_dir'],
'labels_dir': dataset_params['val_labels_dir'],
'classes': dataset_params['classes']
},
dataloader_params={
'batch_size': BATCH_SIZE,
'num_workers': 2
}
)
test_data = coco_detection_yolo_format_val(
dataset_params={
'data_dir': dataset_params['data_dir'],
'images_dir': dataset_params['test_images_dir'],
'labels_dir': dataset_params['test_labels_dir'],
'classes': dataset_params['classes']
},
dataloader_params={
'batch_size': BATCH_SIZE,
'num_workers': 2
}
)
#below part deals with respective model,device declarations
import torch
from super_gradients.training import models
from super_gradients.training import Trainer
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = models.get('yolo_nas_s', pretrained_weights="coco").to(DEVICE)
#optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
#optimizer = torch.optim.Adam(model.parameters())
trainer = Trainer(experiment_name="SKU110K", ckpt_root_dir="./weights")
---------------------------------------------------------------
#below part deals with respective parameter declarations like max epoch etc
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
MAX_EPOCHS = 11
train_params = {
'silent_mode': False,
"average_best_models":True,
"warmup_mode": "linear_epoch_step",
"warmup_initial_lr": 1e-6,
"lr_warmup_epochs": 3,
"initial_lr": 5e-4,
"lr_mode": "cosine",
"cosine_final_lr_ratio": 0.1,
"optimizer": "Adam",
"optimizer_params": {"weight_decay": 0.0001},
"zero_weight_decay_on_bias_and_bn": True,
"ema": True,
"ema_params": {"decay": 0.9, "decay_type": "threshold"},
"max_epochs": MAX_EPOCHS,
"mixed_precision": True,
"loss": PPYoloELoss(
use_static_assigner=False,
num_classes=len(dataset_params['classes']),
reg_max=16
),
"valid_metrics_list": [
DetectionMetrics_050(
score_thres=0.1,
top_k_predictions=50,
num_cls=len(dataset_params['classes']),
normalize_targets=True,
post_prediction_callback=PPYoloEPostPredictionCallback(
score_threshold=0.01,
nms_top_k=100,
max_predictions=20,
nms_threshold=0.7
)
)
],
"metric_to_watch": '[email protected]'
}
#below is training model part and same is used during retraining by changing start epoch number
import torch
# Set the desired starting epoch
# Load the model with the weights from the average checkpoint file
checkpoint_file = torch.load(r'C:UsersGirirajDocumentsPrernas ML Modelsyolo-nas-retail-training-mainyolo-nas-retail-training-mainweightsSKU110Kaverage_model.pth')
model.load_state_dict(checkpoint_file['net'])
START_EPOCH = checkpoint_file['epoch'] + 1
max_epochs=3
# Your training loop
for epoch in range(START_EPOCH, max_epochs):
# Train the model
trainer.train(
model=model,
training_params=train_params,
train_loader=train_data,
valid_loader=val_data
)
print(f"Epoch {epoch + 1})