According to this comment in the huggingface/peft package, if a model is loaded in fp16, the trainable weights must be cast to fp32. From this I understand that, in general, the torch_dtype used to load a model and the precision used for the trainable weights during training should differ. Why is it necessary to change the precision? And does this principle apply to both fine-tuning and continual pretraining?
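For reference, my reading of what that cast means in practice is roughly the following. This is a manual sketch (not the actual peft helper); the model id is simply the one I use later in the MWE:

import torch
from transformers import AutoModelForCausalLM

# Load the base model in half precision to keep memory low.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    torch_dtype=torch.float16,
)

# Upcast every parameter that will receive gradients to fp32.
# (With PEFT adapters this loop would only touch the adapter weights,
# since the frozen base weights have requires_grad=False.)
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.float()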
As a minimal working example, I am attempting continual pretraining of microsoft/Phi-3-mini-128k-instruct, whose default torch_dtype is bfloat16. When the model is loaded with torch_dtype=torch.float16, training starts if the precision for the trainable weights is set to TrainingArguments(fp16=False, bf16=True) (i.e. a different precision for model loading and for the trainable weights). However, when the precision for the trainable weights is set to TrainingArguments(fp16=True, bf16=False) (i.e. the same precision for model loading and the trainable weights), a ValueError("Attempting to unscale FP16 gradients.") is raised and training never starts. The execution environment is an NVIDIA RTX 3060 with only 12 GB of VRAM.
For continual pretraining of the Phi-3 model, how should the torch_dtype used to load the model and the precision of the trainable weights be set so that VRAM usage is minimized? For instance, should the model be loaded with torch_dtype=torch.float32 and trained with TrainingArguments(fp16=True, bf16=False), or should it be loaded with load_in_8bit and trained with TrainingArguments(fp16=True, bf16=False)? I would like to know which combinations are effective and actually feasible.
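For context, the 8-bit variant I have in mind would look roughly like the sketch below. It is untested on my machine and uses BitsAndBytesConfig rather than the plain load_in_8bit argument; everything apart from the model id is my assumption:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantize the frozen base weights to 8-bit to reduce VRAM usage.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    quantization_config=bnb_config,
    device_map={"": 0},
)

# The idea would then be to train with TrainingArguments(fp16=True, bf16=False),
# but I am not sure whether this combination is valid for full continual
# pretraining rather than adapter-based fine-tuning.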
MWE
train_deepspeed.py
import argparse
import os
import warnings
from typing import Dict, List
import deepspeed
import torch
from datasets import load_dataset
from omegaconf import OmegaConf
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedTokenizer,
Trainer,
TrainingArguments,
)
import gc
from utils import seed_everything
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def preprocess_function(
examples: Dict[str, List[str]],
tokenizer: PreTrainedTokenizer,
max_length: int,
) -> Dict[str, List[int]]:
inputs = tokenizer(
examples["text"],
truncation=True,
padding="max_length",
max_length=max_length,
)
inputs["labels"] = inputs.input_ids.copy()
return inputs
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--train_config",
"-p",
type=str,
default="./configs/train_configs/train_base.yaml",
)
parser.add_argument(
"--local_rank",
"-l",
type=int,
default=0,
)
args = parser.parse_args()
config = OmegaConf.load(args.train_config)
# distributed learning
deepspeed.init_distributed()
# set seed
seed_everything(config.seed)
# load model
model = AutoModelForCausalLM.from_pretrained(
config.model.model,
torch_dtype=torch.float16,
use_cache=config.model.use_cache,
device_map={"": 0},
attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(
config.model.tokenizer,
add_eos_token=True,
)
# load dataset
dataset = load_dataset(
path=config.dataset.path,
name=config.dataset.subset,
split=config.dataset.split,
cache_dir=config.dataset.cache_dir,
)
# transform dataset
dataset = dataset.map(
lambda examples: preprocess_function(
examples, tokenizer, config.model.max_length
),
batched=True,
remove_columns=dataset.column_names,
num_proc=32,
)
dataset = dataset.train_test_split(test_size=0.2)
# initiate training
training_args = TrainingArguments(**config.train)
trainer = Trainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
args=training_args,
# data_collator=data_collator,
)
with torch.autocast("cuda"):
trainer.train()
del dataset
del trainer
gc.collect()
deepspeed.runtime.utils.empty_cache()
torch.cuda.empty_cache()
if __name__ == "__main__":
main()
train_base.yaml
model:
model: microsoft/Phi-3-mini-128k-instruct
tokenizer: microsoft/Phi-3-mini-128k-instruct
use_cache: False
max_length: 512
train:
output_dir: ./outputs
evaluation_strategy: steps
logging_strategy: steps
save_strategy: steps
learning_rate: 1e-6
num_train_epochs: 3
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
gradient_accumulation_steps: 256 # per_device_train_batch_size * gradient_accumulation_steps = 256
gradient_checkpointing: True
weight_decay: 0.01
warmup_ratio: 0.1
optim: adamw_bnb_8bit # adamw_torch
fp16: False
bf16: True
dataloader_num_workers: 1
eval_steps: 50
save_steps: 100
logging_steps: 5
run_name: test
save_total_limit: 2
save_on_each_node: False
neftune_noise_alpha: 5 # NEFTune
# deepspeed: ./configs/deepspeed/ds_config_zero2.json
report_to: wandb
torch_compile: True
logging_dir: ./outputs/log
seed: 42
dataset:
path: hotchpotch/wikipedia-ja-20231030
subset: chunked #!!null
split: train
cache_dir: /mnt/d/huggingface/datasets
pyproject.toml
[tool.poetry]
name = "continual-pretrain"
version = "0.1.0"
description = ""
authors = ["Carlos Luis Rivera"]
license = "MIT"
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"
fsspec = "2024.3.1"
datasets = "^2.19.2"
accelerate = "^0.31.0"
aiohttp = "^3.9.5"
aiosignal = "^1.3.1"
annotated-types = "^0.7.0"
appdirs = "^1.4.4"
async-timeout = "^4.0.3"
attrs = "^23.2.0"
bitsandbytes = "^0.43.1"
certifi = "^2024.6.2"
charset-normalizer = "^3.3.2"
click = "^8.1.7"
deepspeed = "^0.14.2"
dill = "^0.3.8"
docker-pycreds = "^0.4.0"
docstring-parser = "^0.16"
filelock = "^3.14.0"
frozenlist = "^1.4.1"
gitdb = "^4.0.11"
gitpython = "^3.1.43"
hjson = "^3.1.0"
huggingface-hub = "^0.23.3"
idna = "^3.7"
jinja2 = "^3.1.4"
markdown-it-py = "^3.0.0"
markupsafe = "^2.1.5"
mdurl = "^0.1.2"
mpmath = "^1.3.0"
multidict = "^6.0.5"
multiprocess = "^0.70.16"
networkx = "^3.3"
ninja = "^1.11.1.1"
numpy = "^1.26.4"
nvidia-ml-py = "^12.555.43"
packaging = "^24.0"
pandas = "^2.2.2"
peft = "0.6.0"
protobuf = "<5.0.0"
psutil = "^5.9.8"
py-cpuinfo = "^9.0.0"
pyarrow = "^16.1.0"
pyarrow-hotfix = "^0.6"
pydantic = "^2.7.3"
pydantic-core = "^2.18.4"
pygments = "^2.18.0"
pynvml = "^11.5.0"
python-dateutil = "^2.9.0.post0"
pytz = "^2024.1"
pyyaml = "^6.0.1"
regex = "^2024.5.15"
requests = "^2.32.3"
rich = "^13.7.1"
safetensors = "^0.4.3"
scipy = "^1.13.1"
sentencepiece = "^0.2.0"
sentry-sdk = "^2.5.1"
setproctitle = "^1.3.3"
shtab = "^1.7.1"
six = "^1.16.0"
smmap = "^5.0.1"
sympy = "^1.12.1"
tokenizers = "^0.19.1"
tqdm = "^4.66.4"
transformers = "^4.41.2"
trl = "^0.9.4"
typing-extensions = "^4.12.2"
tyro = "^0.8.4"
tzdata = "^2024.1"
urllib3 = "^2.2.1"
wandb = "^0.17.1"
xxhash = "^3.4.1"
yarl = "^1.9.4"
omegaconf = "^2.3.0"
llama-cpp-python = { version = "^0.2.77", source = "llama_cpp_python_cu121" }
torch = { version = "^2.3.1+cu121", source = "torch_cu121" }
nvidia-cublas-cu12 = { version = "^12.1.3.1", source = "torch_cu121" }
nvidia-cuda-cupti-cu12 = { version = "^12.1.105", source = "torch_cu121" }
nvidia-cuda-nvrtc-cu12 = { version = "^12.1.105", source = "torch_cu121" }
nvidia-cuda-runtime-cu12 = { version = "^12.1.105", source = "torch_cu121" }
nvidia-cudnn-cu12 = { version = "^8.9.2.26", source = "torch_cu121" }
nvidia-cufft-cu12 = { version = "^11.0.2.54", source = "torch_cu121" }
nvidia-curand-cu12 = { version = "^10.3.2.106", source = "torch_cu121" }
nvidia-cusolver-cu12 = { version = "^11.4.5.107", source = "torch_cu121" }
nvidia-cusparse-cu12 = { version = "^12.1.0.106", source = "torch_cu121" }
nvidia-nccl-cu12 = { version = "^2.20.5", source = "torch_cu121" }
nvidia-nvtx-cu12 = { version = "^12.1.105", source = "torch_cu121" }
optimum = "^1.20.0"
tensorboard = "^2.17.0"
wheel = "^0.43.0"
pytorch-triton = { version = "^2.3.0", source = "torch_cu121" }
[tool.poetry.group.dev.dependencies]
black = "^24.4.2"
flake8 = "^7.0.0"
ipykernel = "^6.29.4"
ipywidgets = "^8.1.3"
seedir = "^0.4.2"
emoji = "^2.12.1"
nbformat = "^5.10.4"
nbclient = "^0.10.0"
nbconvert = "^7.16.4"
[[tool.poetry.source]]
name = "torch_cu121"
url = "https://download.pytorch.org/whl/cu121"
priority = "explicit"
[[tool.poetry.source]]
name = "llama_cpp_python_cu121"
url = "https://abetlen.github.io/llama-cpp-python/whl/cu121"
priority = "explicit"
[[tool.poetry.source]]
name = "torch_nightly_cu121"
url = "https://download.pytorch.org/whl/nightly/cu121/"
priority = "explicit"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"