The idea is that whenever the warm_start_tuning_job_name parameter in the config is set to auto (instead of null), we use the SageMaker API to automatically retrieve relevant HP tuning jobs and include the most recent 5 as parent tuning jobs.
I can do something like this:
def get_recent_tuning_jobs(self, sagemaker_session=None):
    sagemaker_session = sagemaker_session or sagemaker.Session()
    response = sagemaker_session.sagemaker_client.list_hyper_parameter_tuning_jobs(
        SortBy="CreationTime", SortOrder="Descending", MaxResults=5
    )
    tuning_jobs = [
        job["HyperParameterTuningJobName"]
        for job in response["HyperParameterTuningJobSummaries"]
    ]
    return tuning_jobs
(Reference: the boto3 documentation for list_hyper_parameter_tuning_jobs.)
However, how do we know these tuning jobs are related to this specific model training? We need to ensure they used the same training data, etc. I might need to pull roughly the last 100 tuning jobs, run a few checks on them, and see whether they meet the criteria to be used as parent warm-start jobs for the current training effort.
The checks I can think of, which each candidate tuning job would need to pass:
- the data is the same (InputDataConfig)
- the HyperParameterTuningJobObjective is the same
- the count of static plus tunable hyperparameters matches
- the type of each hyperparameter (continuous, integer, categorical) matches
- the number of total changes to HPs is within the allowed limit
- while iterating over candidate jobs, the total number of training jobs must not exceed 500: count the training jobs for the current tuning job (max_jobs) plus the training jobs of each accepted parent tuning job (ResourceLimits -> MaxNumberOfTrainingJobs), and stop before exceeding 500
I am not sure exactly how to implement this, e.g. with a bool flag like criteria_met, so that only tuning jobs belonging to this same training effort are used as parent warm-start jobs. A rough attempt at such a check is below.
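Here is an untested sketch of what the criteria_met check could look like against a describe_hyper_parameter_tuning_job response. The current_* arguments are hypothetical stand-ins for whatever describes the new tuning job (the train-set S3 URI, the objective metric name from the config, the hyperparameter_ranges dict that gets passed to get_tuner, and the static hyperparameters set on the estimator), and it assumes the parent job was created with a single TrainingJobDefinition rather than the TrainingJobDefinitions list:

from sagemaker.parameter import (
    CategoricalParameter,
    ContinuousParameter,
    IntegerParameter,
)

# Maps the API's ParameterRanges keys to the SDK classes used in hyperparameter_ranges
RANGE_TYPES = {
    "ContinuousParameterRanges": ContinuousParameter,
    "IntegerParameterRanges": IntegerParameter,
    "CategoricalParameterRanges": CategoricalParameter,
}


def tuning_job_matches(
    description,
    current_train_s3_uri,
    current_objective_metric,
    current_hyperparameter_ranges,
    current_static_hyperparameters,
):
    """criteria_met check for one candidate parent tuning job.

    `description` is a describe_hyper_parameter_tuning_job response; the
    current_* arguments describe the tuning job about to be launched.
    Assumes the parent was created with a single TrainingJobDefinition.
    """
    tuning_config = description["HyperParameterTuningJobConfig"]
    training_def = description.get("TrainingJobDefinition", {})

    # Objective metric must be identical to the new job's objective metric
    objective = tuning_config.get("HyperParameterTuningJobObjective") or training_def.get(
        "TuningObjective", {}
    )
    if objective.get("MetricName") != current_objective_metric:
        return False

    # Training data must be the same: compare the S3 URI of the "train" channel
    parent_train_uris = {
        channel["DataSource"]["S3DataSource"]["S3Uri"]
        for channel in training_def.get("InputDataConfig", [])
        if channel["ChannelName"] == "train"
    }
    if parent_train_uris != {current_train_s3_uri}:
        return False

    # Tunable hyperparameters: same names of each type (continuous, integer, categorical)
    parent_ranges = tuning_config.get("ParameterRanges") or training_def.get(
        "HyperParameterRanges", {}
    )
    for api_key, sdk_cls in RANGE_TYPES.items():
        parent_names = {r["Name"] for r in parent_ranges.get(api_key, [])}
        current_names = {
            name
            for name, rng in current_hyperparameter_ranges.items()
            if isinstance(rng, sdk_cls)
        }
        if parent_names != current_names:
            return False

    # Static hyperparameters: same set of names, so static + tunable counts line up
    # (defensively ignoring any system-added keys that start with an underscore)
    parent_static = {
        k for k in training_def.get("StaticHyperParameters", {}) if not k.startswith("_")
    }
    if parent_static != set(current_static_hyperparameters):
        return False

    return True

This only covers the first four checks; the "total changes to HPs" restriction and the running 500-training-job budget depend on all parents together, so they would have to live in the loop that collects candidates (sketched at the end of this post).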
The complete file I am working on, with the relevant code, looks like this:
import logging
import os
import shutil
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Union
import boto3
import lightgbm as lgb
import pandas as pd
import sagemaker
import xgboost as xgb
import yaml
from botocore.exceptions import ClientError
from sagemaker import image_uris
from sagemaker.tuner import HyperparameterTuner, WarmStartConfig, WarmStartTypes
from branch_ml.train_evaluate.hyperparameters import (
LightGBMHyperparameterBuilder,
XGBoostHyperparameterBuilder,
)
from branch_ml.train_evaluate.utils_constants import CONTENT_TYPE_CSV, CONTENT_TYPE_JSON
from branch_ml.train_evaluate.utils_data import (
ASSET_DIR_MODEL_FILENAMES,
VALID_HP_EARLY_STOPPING_TYPES,
VALID_HP_OBJECTIVE_METRICS,
VALID_HP_TUNING_STRATEGIES,
VALID_INSTANCE_TYPES_BATCH_TRANSFORM,
VALID_INSTANCE_TYPES_TRAIN,
VALID_MODELS,
copy_folder_contents_on_s3,
dict_remove_none_values,
load_data,
read_from_s3,
read_s3_tar_member,
)
logging.basicConfig(
format="{asctime} : {levelname} : {filename} : {funcName} : {message}", style="{"
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
MODEL_FILE = "model.tar.gz"
DEFAULT_TRAIN_INSTANCE_COUNT = 1
DEFAULT_TRAIN_INSTANCE_TYPE = "ml.m5.xlarge" # ml.m5.xlarge or ml.m5.2xlarge
DEFAULT_TRAIN_VOLUME_SIZE = 20 # GB
DEFAULT_TRAIN_MAX_RUN = 8 * 60 * 60 # seconds; default: 8 hours
DEFAULT_TRAIN_HP_MAX_JOBS = 128
DEFAULT_TRAIN_HP_MAX_PARALLEL_JOBS = 4
DEFAULT_TRAIN_HP_OBJECTIVE_METRIC_NAME = "validation:auc"
DEFAULT_TRAIN_HP_EARLY_STOPPING_TYPE = "Auto"
DEFAULT_TRAIN_HP_TUNING_STRATEGY = "Bayesian"
DEFAULT_BATCH_TRANSFORM_INSTANCE_COUNT = 1
DEFAULT_BATCH_TRANSFORM_INSTANCE_TYPE = "ml.c5.xlarge"
DEFAULT_MODEL_PACKAGE_BUCKET = "branch-in-models"
DEFAULT_MODEL_REPO_NAME = "xgboost"
DEFAULT_MODEL_REPO_VERSION = "1.7-1"
DEFAULT_MODEL_LABEL = "xgb"
DEFAULT_IMPORTANCE_TYPE_XGB = "total_gain"
DEFAULT_IMPORTANCE_TYPE_LGB = "gain"
PACKAGE_CONFIG_FILENAME = "package_model_config.yaml"
class ModelConfig:
def __init__(
self,
cfg: Dict,
training_sample_config: Dict,
preprocessor_config: Dict,
sagemaker_session=None,
):
"""Take in a cfg dict and save the values setting defaults as necessary"""
# Remove config keys where value is None so `get` method falls back to defaults
cfg = dict_remove_none_values(cfg)
sagemaker_session = sagemaker_session or sagemaker.Session()
self.model_repo_name = cfg.get("model_repo_name") or DEFAULT_MODEL_REPO_NAME
self.model_repo_version = (
cfg.get("model_repo_version") or DEFAULT_MODEL_REPO_VERSION
)
self.model_label = cfg.get("model_label") or DEFAULT_MODEL_LABEL
self.train_instance_count = (
cfg.get("train_instance_count") or DEFAULT_TRAIN_INSTANCE_COUNT
)
self.train_instance_type = (
cfg.get("train_instance_type") or DEFAULT_TRAIN_INSTANCE_TYPE
)
self.train_volume_size = (
cfg.get("train_volume_size") or DEFAULT_TRAIN_VOLUME_SIZE # GB
)
self.train_max_run = (
cfg.get("train_max_run") or DEFAULT_TRAIN_MAX_RUN # seconds
)
self.train_hp_max_jobs = (
cfg.get("train_hp_max_jobs") or DEFAULT_TRAIN_HP_MAX_JOBS
)
self.train_hp_max_parallel_jobs = (
cfg.get("train_hp_max_parallel_jobs") or DEFAULT_TRAIN_HP_MAX_PARALLEL_JOBS
)
self.train_hp_objective_metric_name = (
cfg.get("train_hp_objective_metric_name")
or DEFAULT_TRAIN_HP_OBJECTIVE_METRIC_NAME
)
# Set the direction for this objective metric name
self.train_hp_objective_type = VALID_HP_OBJECTIVE_METRICS[self.model_repo_name][
self.train_hp_objective_metric_name
]["direction"]
self.train_hp_early_stopping_type = (
cfg.get("train_hp_early_stopping_type")
or DEFAULT_TRAIN_HP_EARLY_STOPPING_TYPE
)
self.train_hp_tuning_strategy = (
cfg.get("train_hp_tuning_strategy") or DEFAULT_TRAIN_HP_TUNING_STRATEGY
)
self.input_mode = cfg.get("input_mode") or "File"
self.image_name = cfg.get("image_name")
self.is_custom_image = True if self.image_name else False
self.metric_definitions = (
VALID_MODELS[self.model_repo_name]["metric_definitions"]
if self.is_custom_image
else None
)
self.output_file_format = cfg.get("output_file_format") or "csv"
self.training_job_name = cfg.get("training_job_name")
self.tuning_job_name = cfg.get("tuning_job_name")
self.best_training_job_name = cfg.get("best_training_job_name")
self.warm_start_tuning_job_name = cfg.get("warm_start_tuning_job_name")
self.base_job_name = (
f"{training_sample_config.model_type}"
f"-{self.model_label}"
f"-{training_sample_config.currency}"
f"-{training_sample_config.loan_type}"
)
def __iter__(self):
"""Iteration method is used to convert this object to things like a dict"""
return iter(vars(self).items())
def serialize(self):
return self.__dict__
@staticmethod
def load(
bucket: str = "",
parent_dir: str = "",
filename: str = None,
training_sample_config: Dict = None,
preprocessor_config: Dict = None,
sagemaker_session=None,
):
"""Load a model config file and return a ModelConfig object"""
if filename is None:
raise ValueError("Must provide filename")
if training_sample_config is None:
raise ValueError("Must provide training_sample_config")
if preprocessor_config is None:
raise ValueError("Must provide preprocessor_config")
filepath = os.path.join(parent_dir, filename)
if bucket:
logger.info(f"Loading model config: s3://{bucket}/{filepath}")
config = yaml.load(read_from_s3(bucket, filepath), Loader=yaml.FullLoader)
else:
logger.info(f"Loading model config: {filepath}")
with open(filepath, "r") as fp:
config = yaml.load(fp, Loader=yaml.FullLoader)
return ModelConfig(
config, training_sample_config, preprocessor_config, sagemaker_session
)
@staticmethod
def save(
config: Dict,
bucket: str = "",
parent_dir: str = "",
filename: str = None,
allow_overwrite: bool = True,
):
"""Save the configuration to a yaml file"""
if filename is None:
raise ValueError("Must supply config filename")
filepath = os.path.join(parent_dir, filename)
save = True
if bucket:
s3 = boto3.client("s3")
if not allow_overwrite:
if "Contents" in s3.list_objects(Bucket=bucket, Prefix=filepath):
# file already exists
save = False
logger.warning(
"Not saving config file, already exists:"
f" s3://{bucket}/{filepath}"
)
if save:
logger.info(f"Saving config file: s3://{bucket}/{filepath}")
s3.put_object(
Body=yaml.dump(dict(config), default_flow_style=False),
Bucket=bucket,
Key=filepath,
)
else:
if not allow_overwrite:
if Path(filepath).exists():
save = False
logger.warning(
f"Not saving config file, already exists: {filepath}"
)
if save:
logger.info(f"Saving config file: {filepath}")
with open(filepath, "w") as fp:
fp.write(yaml.dump(dict(config), default_flow_style=False))
class CreditModel:
def __init__(
self,
bucket: str = "",
parent_dir: str = "",
config_filename: str = None,
training_sample_config: Dict = None,
preprocessor_config: Dict = None,
sagemaker_session=None,
valid_models: Dict = None,
):
if config_filename is None:
raise ValueError("Must provide config_filename")
self.config = ModelConfig.load(
bucket=bucket,
parent_dir=parent_dir,
filename=config_filename,
training_sample_config=training_sample_config,
preprocessor_config=preprocessor_config,
sagemaker_session=sagemaker_session,
)
self.training_sample_config = training_sample_config
self.preprocessor_config = preprocessor_config
self._valid_models = valid_models or VALID_MODELS
# get model filename
self.model_filename = self.get_model_filename()
def get_warm_start_config(self, tuning_job_name: Union[str, List[str]] = None):
"""Some restrictions to running a warm start tuning are mentioned here:
https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-warm-start.html#warm-start-tuning-restrictions
TLDR; some important ones are:
- A tuning job can have a maximum of 5 parent jobs
- The objective metric used in the new tuning job must be the same as the
objective metric used in the parent jobs.
- Warm start tuning is not recursive.
"""
tuning_job_name = tuning_job_name or self.config.warm_start_tuning_job_name
if tuning_job_name:
if tuning_job_name == "auto":
tuning_job_name = self.get_recent_tuning_jobs()
if isinstance(tuning_job_name, str):
tuning_job_name = [tuning_job_name]
assert (
len(tuning_job_name) <= 5
), "A tuning job can have a maximum of 5 parent jobs"
warm_start_config = WarmStartConfig(
warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM,
parents={*tuning_job_name},
)
else:
warm_start_config = None
return warm_start_config
def get_recent_tuning_jobs(self, sagemaker_session=None):
sagemaker_session = sagemaker_session or sagemaker.Session()
response = sagemaker_session.sagemaker_client.list_hyper_parameter_tuning_jobs(
SortBy="CreationTime", SortOrder="Descending", MaxResults=5
)
tuning_jobs = [
job["HyperParameterTuningJobName"]
for job in response["HyperParameterTuningJobSummaries"]
]
return tuning_jobs
def get_estimator(
self, sagemaker_session=None, role: str = None, hyperparameters: Dict = None
):
sagemaker_session = sagemaker_session or sagemaker.Session()
# Get a SageMaker-compatible role
role = role or sagemaker.get_execution_role(sagemaker_session=sagemaker_session)
hyperparameters = hyperparameters or {}
if self.config.model_repo_name not in self._valid_models:
raise ValueError(
f"model_repo_name {self.config.model_repo_name} is not supported."
" Only these are supported: "
",".join(self._valid_models),
)
if (
self.config.model_repo_version
not in self._valid_models[self.config.model_repo_name]["versions"]
):
raise ValueError(
f"{self.config.model_repo_name} model_repo_version"
f" {self.config.model_repo_version} is not supported. Only these"
" are supported: "
",".join(self._valid_models[self.config.model_repo_name]["versions"]),
)
if self.config.train_instance_type not in VALID_INSTANCE_TYPES_TRAIN:
raise ValueError(
f"train_instance_type {self.config.train_instance_type} is not"
" supported. Only these are supported: "
",".join(VALID_INSTANCE_TYPES_TRAIN),
)
image_name = self.config.image_name or image_uris.retrieve(
framework=self.config.model_repo_name,
region=sagemaker_session.boto_region_name,
version=self.config.model_repo_version,
)
estimator = sagemaker.estimator.Estimator(
image_name,
role,
instance_count=self.config.train_instance_count,
instance_type=self.config.train_instance_type,
volume_size=self.config.train_volume_size, # GB
max_run=self.config.train_max_run, # seconds
input_mode=self.config.input_mode,
output_path=self.training_sample_config.model_output_path,
base_job_name=self.config.base_job_name,
sagemaker_session=sagemaker_session,
)
estimator.set_hyperparameters(**hyperparameters)
return estimator
def fit_estimator(
self, estimator, dataset_train: str = None, dataset_validation: str = None
):
dataset_train = dataset_train or self.training_sample_config.train_set_label
dataset_validation = (
dataset_validation or self.training_sample_config.validation_set_label
)
distribution = "FullyReplicated"
content_type = CONTENT_TYPE_CSV
s3_data_type = "S3Prefix" # single file
# Set up data
train_data = sagemaker.inputs.TrainingInput(
self.training_sample_config.preprocessed_data_batch[dataset_train],
distribution=distribution,
content_type=content_type,
s3_data_type=s3_data_type,
)
data_channels = {"train": train_data}
if self.training_sample_config.preprocessed_data_batch[dataset_validation]:
validation_data = sagemaker.inputs.TrainingInput(
self.training_sample_config.preprocessed_data_batch[dataset_validation],
distribution=distribution,
content_type=content_type,
s3_data_type=s3_data_type,
)
data_channels["validation"] = validation_data
estimator.fit(inputs=data_channels)
self.config.training_job_name = estimator.latest_training_job.job_name
return estimator.latest_training_job.job_name
def attach_estimator(self, training_job_name: str = None, sagemaker_session=None):
training_job_name = training_job_name or self.config.training_job_name
return sagemaker.estimator.Estimator.attach(
training_job_name, sagemaker_session=sagemaker_session
)
def get_tuner(
self,
estimator,
objective_metric_name: str = None,
hyperparameter_ranges: Dict[str, sagemaker.parameter.ParameterRange] = None,
strategy: str = None,
objective_type: str = None,
max_jobs: int = None,
max_parallel_jobs: int = None,
early_stopping_type: str = None,
warm_start_config=None,
):
if self.config.model_repo_name not in VALID_HP_OBJECTIVE_METRICS:
raise ValueError(
f"model_repo_name {self.config.model_repo_name} not defined in"
" VALID_HP_OBJECTIVE_METRICS"
)
if (
self.config.train_hp_objective_metric_name
not in VALID_HP_OBJECTIVE_METRICS[self.config.model_repo_name]
):
raise ValueError(
"train_hp_objective_metric_name"
f" {self.config.train_hp_objective_metric_name} is not supported"
f" for {self.config.model_repo_name}. Only these are supported: "
",".join(VALID_HP_OBJECTIVE_METRICS[self.config.model_repo_name]),
)
if (
self.config.train_hp_early_stopping_type
not in VALID_HP_EARLY_STOPPING_TYPES
):
raise ValueError(
"train_hp_early_stopping_type"
f" {self.config.train_hp_early_stopping_type} is not supported."
" Only these are supported: "
",".join(VALID_HP_EARLY_STOPPING_TYPES),
)
if self.config.train_hp_tuning_strategy not in VALID_HP_TUNING_STRATEGIES:
raise ValueError(
"train_hp_tuning_strategy"
f" {self.config.train_hp_tuning_strategy} is not supported. Only"
" these are supported: "
",".join(VALID_HP_TUNING_STRATEGIES),
)
objective_metric_name = (
objective_metric_name or self.config.train_hp_objective_metric_name
)
if hyperparameter_ranges is None or len(hyperparameter_ranges) == 0:
raise ValueError("Need to specify hyperparameter ranges")
strategy = strategy or self.config.train_hp_tuning_strategy
objective_type = objective_type or self.config.train_hp_objective_type
max_jobs = max_jobs or self.config.train_hp_max_jobs
max_parallel_jobs = max_parallel_jobs or self.config.train_hp_max_parallel_jobs
early_stopping_type = (
early_stopping_type or self.config.train_hp_early_stopping_type
)
base_tuning_job_name = f"{self.config.base_job_name}-tuner"
tuner = HyperparameterTuner(
estimator,
objective_metric_name,
hyperparameter_ranges,
metric_definitions=self.config.metric_definitions,
strategy=strategy,
objective_type=objective_type,
max_jobs=max_jobs,
max_parallel_jobs=max_parallel_jobs,
# tags=None,
base_tuning_job_name=base_tuning_job_name,
warm_start_config=warm_start_config,
early_stopping_type=early_stopping_type,
# estimator_name=None,
)
return tuner
def fit_tuner(
self,
tuner,
dataset_train: str = None,
dataset_validation: str = None,
wait: bool = False,
):
dataset_train = dataset_train or self.training_sample_config.train_set_label
dataset_validation = (
dataset_validation or self.training_sample_config.validation_set_label
)
distribution = "FullyReplicated"
s3_data_type = "S3Prefix" # single file
if "csv" in self.preprocessor_config.preprocessed_file_format:
content_type = CONTENT_TYPE_CSV
elif "json" in self.preprocessor_config.preprocessed_file_format:
content_type = CONTENT_TYPE_JSON
# Set up data
train_data = sagemaker.inputs.TrainingInput(
self.training_sample_config.preprocessed_data_batch[dataset_train],
distribution=distribution,
content_type=content_type,
s3_data_type=s3_data_type,
)
validation_data = sagemaker.inputs.TrainingInput(
self.training_sample_config.preprocessed_data_batch[dataset_validation],
distribution=distribution,
content_type=content_type,
s3_data_type=s3_data_type,
)
data_channels = {
"train": train_data,
"validation": validation_data,
}
tuner.fit(
inputs=data_channels,
include_cls_metadata=False,
wait=wait,
)
self.config.tuning_job_name = tuner.latest_tuning_job.job_name
return tuner.latest_tuning_job.job_name
def attach_tuner(self, tuning_job_name: str = None, sagemaker_session=None):
tuning_job_name = tuning_job_name or self.config.tuning_job_name
return HyperparameterTuner.attach(
tuning_job_name, sagemaker_session=sagemaker_session
)
def get_tuning_job_info(
self, tuning_job_name: str = None, sagemaker_session=None
) -> pd.DataFrame:
tuning_job_name = tuning_job_name or self.config.tuning_job_name
if tuning_job_name is None:
logger.warning("No tuning job name provided or found in config")
return None
return sagemaker.analytics.HyperparameterTuningJobAnalytics(
tuning_job_name,
sagemaker_session=sagemaker_session,
).dataframe()
def get_training_job_info(
self,
training_job_name: str = None,
minimize_objective: bool = None,
require_completed: bool = None,
use_config_first: bool = None,
sagemaker_session=None,
) -> pd.DataFrame:
training_job_name = training_job_name or self.get_best_training_job_name(
tuning_job_name=self.config.tuning_job_name,
minimize_objective=minimize_objective,
require_completed=require_completed,
use_config_first=use_config_first,
)
if training_job_name is None:
logger.warning(
"No best training job found in tuning job: "
f"{self.config.tuning_job_name}"
)
return None
return sagemaker.analytics.TrainingJobAnalytics(
training_job_name=training_job_name,
sagemaker_session=sagemaker_session,
).dataframe()
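On top of that file, I think get_warm_start_config would need to grow a couple of extra arguments so the "auto" path can run these checks. A rough, untested sketch (the new parameters are my guesses, and it assumes get_recent_tuning_jobs gets reworked along the lines sketched at the end of this post):

def get_warm_start_config(
    self,
    tuning_job_name: Union[str, List[str]] = None,
    hyperparameter_ranges: Dict = None,
    static_hyperparameters: Dict = None,
    max_jobs: int = None,
):
    """Same restrictions as before; "auto" now triggers the criteria checks."""
    tuning_job_name = tuning_job_name or self.config.warm_start_tuning_job_name
    if not tuning_job_name:
        return None
    if tuning_job_name == "auto":
        # Hypothetical: resolve parents by listing and filtering recent tuning jobs
        tuning_job_name = self.get_recent_tuning_jobs(
            hyperparameter_ranges=hyperparameter_ranges,
            static_hyperparameters=static_hyperparameters,
            max_jobs=max_jobs or self.config.train_hp_max_jobs,
        )
        if not tuning_job_name:
            # No suitable parents found; fall back to a cold start
            return None
    if isinstance(tuning_job_name, str):
        tuning_job_name = [tuning_job_name]
    assert (
        len(tuning_job_name) <= 5
    ), "A tuning job can have a maximum of 5 parent jobs"
    return WarmStartConfig(
        warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM,
        parents={*tuning_job_name},
    )

The caller (wherever get_tuner is invoked) would then pass in the same hyperparameter_ranges and static hyperparameters it already has in hand.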
Here is a dummy script I was trying out to achieve this:
from datetime import datetime, timedelta, timezone
import boto3
import sagemaker
sagemaker_session = sagemaker.Session()
base_job_name = "credit-xgb-INR-repea"
name_contains = base_job_name
# Search jobs from the last 30 days
creation_time_after = datetime.now(tz=timezone.utc) - timedelta(days=30)
creation_time_after = creation_time_after.strftime("%Y-%m-%dT%H:%M:%SZ")
def get_recent_tuning_jobs(name_contains: str):
sagemaker_client = boto3.client(
"sagemaker", region_name=sagemaker_session.boto_region_name
)
    # Get all the relevant tuning jobs that were created in the last 30 days.
    # MaxResults only caps one page (max 100); follow NextToken to keep searching.
    tuning_jobs = []
    summaries = []
    next_token = None
    while True:
        kwargs = dict(
            SortBy="CreationTime",
            SortOrder="Descending",
            MaxResults=100,
            NameContains=name_contains[:20],
            StatusEquals="Completed",
            CreationTimeAfter=creation_time_after,
        )
        if next_token:
            kwargs["NextToken"] = next_token
        response = sagemaker_client.list_hyper_parameter_tuning_jobs(**kwargs)
        summaries.extend(response["HyperParameterTuningJobSummaries"])
        next_token = response.get("NextToken")
        if not next_token:
            break
    if not summaries:
        print("No tuning jobs found")
        return tuning_jobs
    for job in summaries:
name = job["HyperParameterTuningJobName"]
description = sagemaker_client.describe_hyper_parameter_tuning_job(
HyperParameterTuningJobName=name
)
criteria_met = True
# Add code to check `description` for various restrictions
# https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-warm-start.html#warm-start-tuning-restrictions
# you'll need to pass a few more arguments into get_warm_start_config:
# hyperparameter_ranges
# max_jobs
# important restrictions to check:
# data is the same (InputDataConfig)
# HyperParameterTuningJobObjective
# count of static plus tunable hyperparameters
# type of each hyperparameter (continuous, integer, categorical)
# number of total changes to HPs
# as you iterate over each job, you'll need to make sure the total number of training jobs doesn't exceed 500. so count the training jobs for current tuning job (max_jobs) plus number of training jobs for previous tuning job (ResourceLimits -> MaxNumberOfTrainingJobs) and break if you're going to exceed 500
if criteria_met:
tuning_jobs.append(job["HyperParameterTuningJobName"])
return tuning_jobs
get_recent_tuning_jobs(name_contains)
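Putting it together, here is a rough, untested sketch of what get_recent_tuning_jobs on CreditModel might turn into once it takes the extra arguments and enforces the 500-training-job budget across parents. It reuses the hypothetical tuning_job_matches helper sketched above, and it uses the boto3 paginator, which also answers my MaxResults question: MaxResults only caps a single page, and NextToken / the paginator is what lets the search go further back. static_hyperparameters is a stand-in for the dict that eventually goes into estimator.set_hyperparameters, and the NameContains truncation follows the dummy script above:

from datetime import datetime, timedelta, timezone
from typing import Dict

import sagemaker


def get_recent_tuning_jobs(
    self,
    hyperparameter_ranges: Dict,
    static_hyperparameters: Dict,
    max_jobs: int,
    sagemaker_session=None,
    max_total_training_jobs: int = 500,
    max_parents: int = 5,
):
    """Hypothetical: return up to 5 completed tuning jobs usable as warm-start parents."""
    sagemaker_session = sagemaker_session or sagemaker.Session()
    client = sagemaker_session.sagemaker_client

    train_s3_uri = self.training_sample_config.preprocessed_data_batch[
        self.training_sample_config.train_set_label
    ]

    # Paginate so the search is not limited to a single page of results
    paginator = client.get_paginator("list_hyper_parameter_tuning_jobs")
    pages = paginator.paginate(
        SortBy="CreationTime",
        SortOrder="Descending",
        NameContains=self.config.base_job_name[:20],
        StatusEquals="Completed",
        CreationTimeAfter=datetime.now(tz=timezone.utc) - timedelta(days=30),
    )

    parents = []
    # Budget: training jobs the new tuning job will run plus all accepted parents' jobs
    total_training_jobs = max_jobs
    for page in pages:
        for summary in page["HyperParameterTuningJobSummaries"]:
            name = summary["HyperParameterTuningJobName"]
            description = client.describe_hyper_parameter_tuning_job(
                HyperParameterTuningJobName=name
            )
            if not tuning_job_matches(
                description,
                train_s3_uri,
                self.config.train_hp_objective_metric_name,
                hyperparameter_ranges,
                static_hyperparameters,
            ):
                continue
            # Stop before the combined total of training jobs exceeds the limit
            parent_jobs = description["HyperParameterTuningJobConfig"][
                "ResourceLimits"
            ]["MaxNumberOfTrainingJobs"]
            if total_training_jobs + parent_jobs > max_total_training_jobs:
                return parents
            total_training_jobs += parent_jobs
            parents.append(name)
            if len(parents) == max_parents:
                return parents
    return parents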