Is it necessary for torch_dtype when loading a model and the precision for trainable weights to be different? If so, why?

According to this comment in the huggingface/peft package, if a model is loaded in fp16, the trainable weights must be cast to fp32. From this comment, I understand that generally, the torch_dtype used when loading a model and the precision used for training must be different. Why is it necessary to change the precision? Also, does this principle apply to both fine-tuning and continual pretraining?

As a minimal working example, I’m attempting to perform a continual pretraining on microsoft/Phi-3-mini-128k-instruct, whose default torch_dtype is bfloat16. When loading the model with torch_dtype=torch.float16, training commenced when the precision for trainable weights was set to TrainingArguments(fp16=False, bf16=True) (i.e. different precision for model loading and trainable weights). However, when the precision for trainable weights was set to TrainingArguments(fp16=True, bf16=False) (i.e. same precision for model loading and trainable weights), an error raise ValueError("Attempting to unscale FP16 gradients.") occurred, preventing the start of training. The execution environment was an NVIDIA RTX3060 with only 12GB of vRAM. For continual pretraining of the Phi-3 model, how should the torch_dtype be set when loading the model and for trainable weights to minimize vRAM usage? For instance, should the model be loaded with torch_dtype=fp32 and the precision for trainable weights set to TrainingArguments(fp16=True, bf16=False), or should the model be loaded with load_in_8bit and the precision for trainable weights set to TrainingArguments(fp16=True, bf16=False)? I would like to know effective and feasible combinations.

MWE

train_deepspeed.py

import argparse
import os
import warnings
from typing import Dict, List
import deepspeed
import torch
from datasets import load_dataset
from omegaconf import OmegaConf
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizer,
    Trainer,
    TrainingArguments,
)
import gc
from utils import seed_everything

warnings.filterwarnings("ignore")

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def preprocess_function(
    examples: Dict[str, List[str]],
    tokenizer: PreTrainedTokenizer,
    max_length: int,
) -> Dict[str, List[int]]:
    inputs = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    inputs["labels"] = inputs.input_ids.copy()
    return inputs


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_config",
        "-p",
        type=str,
        default="./configs/train_configs/train_base.yaml",
    )
    parser.add_argument(
        "--local_rank",
        "-l",
        type=int,
        default=0,
    )
    args = parser.parse_args()

    config = OmegaConf.load(args.train_config)

    # distributed learning
    deepspeed.init_distributed()

    # set seed
    seed_everything(config.seed)

    # load model
    model = AutoModelForCausalLM.from_pretrained(
        config.model.model,
        torch_dtype=torch.float16,
        use_cache=config.model.use_cache,
        device_map={"": 0},
        attn_implementation="flash_attention_2",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        config.model.tokenizer,
        add_eos_token=True,
    )

    # load dataset
    dataset = load_dataset(
        path=config.dataset.path,
        name=config.dataset.subset,
        split=config.dataset.split,
        cache_dir=config.dataset.cache_dir,
    )

    # transform dataset
    dataset = dataset.map(
        lambda examples: preprocess_function(
            examples, tokenizer, config.model.max_length
        ),
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=32,
    )

    dataset = dataset.train_test_split(test_size=0.2)

    # initiate training
    training_args = TrainingArguments(**config.train)
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        args=training_args,
        # data_collator=data_collator,
    )

    with torch.autocast("cuda"):
        trainer.train()

    del dataset
    del trainer

    gc.collect()
    deepspeed.runtime.utils.empty_cache()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()

train_base.yaml

model:
  model: microsoft/Phi-3-mini-128k-instruct
  tokenizer: microsoft/Phi-3-mini-128k-instruct
  use_cache: False
  max_length: 512


train:
  output_dir: ./outputs
  evaluation_strategy: steps
  logging_strategy: steps
  save_strategy: steps
  learning_rate: 1e-6
  num_train_epochs: 3
  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 256 # per_device_train_bath_size*gradient_accumulation_steps=256
  gradient_checkpointing: True
  weight_decay: 0.01
  warmup_ratio: 0.1
  optim: adamw_bnb_8bit # adamw_torch
  fp16: False
  bf16: True
  dataloader_num_workers: 1
  eval_steps: 50
  save_steps: 100
  logging_steps: 5
  run_name: test
  save_total_limit: 2
  save_on_each_node: False
  neftune_noise_alpha: 5 # NEFTTune
  # deepspeed: ./configs/deepspeed/ds_config_zero2.json
  report_to: wandb
  torch_compile: True
  logging_dir: ./outputs/log
  
seed: 42

dataset:
  path: hotchpotch/wikipedia-ja-20231030
  subset: chunked #!!null
  split: train
  cache_dir: /mnt/d/huggingface/datasets

pyproject.toml

[tool.poetry]
name = "continual-pretrain"
version = "0.1.0"
description = ""
authors = ["Carlos Luis Rivera"]
license = "MIT"
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.11"
fsspec = "2024.3.1"
datasets = "^2.19.2"
accelerate = "^0.31.0"
aiohttp = "^3.9.5"
aiosignal = "^1.3.1"
annotated-types = "^0.7.0"
appdirs = "^1.4.4"
async-timeout = "^4.0.3"
attrs = "^23.2.0"
bitsandbytes = "^0.43.1"
certifi = "^2024.6.2"
charset-normalizer = "^3.3.2"
click = "^8.1.7"
deepspeed = "^0.14.2"
dill = "^0.3.8"
docker-pycreds = "^0.4.0"
docstring-parser = "^0.16"
filelock = "^3.14.0"
frozenlist = "^1.4.1"
gitdb = "^4.0.11"
gitpython = "^3.1.43"
hjson = "^3.1.0"
huggingface-hub = "^0.23.3"
idna = "^3.7"
jinja2 = "^3.1.4"
markdown-it-py = "^3.0.0"
markupsafe = "^2.1.5"
mdurl = "^0.1.2"
mpmath = "^1.3.0"
multidict = "^6.0.5"
multiprocess = "^0.70.16"
networkx = "^3.3"
ninja = "^1.11.1.1"
numpy = "^1.26.4"
nvidia-ml-py = "^12.555.43"
packaging = "^24.0"
pandas = "^2.2.2"
peft = "0.6.0"
protobuf = "<5.0.0"
psutil = "^5.9.8"
py-cpuinfo = "^9.0.0"
pyarrow = "^16.1.0"
pyarrow-hotfix = "^0.6"
pydantic = "^2.7.3"
pydantic-core = "^2.18.4"
pygments = "^2.18.0"
pynvml = "^11.5.0"
python-dateutil = "^2.9.0.post0"
pytz = "^2024.1"
pyyaml = "^6.0.1"
regex = "^2024.5.15"
requests = "^2.32.3"
rich = "^13.7.1"
safetensors = "^0.4.3"
scipy = "^1.13.1"
sentencepiece = "^0.2.0"
sentry-sdk = "^2.5.1"
setproctitle = "^1.3.3"
shtab = "^1.7.1"
six = "^1.16.0"
smmap = "^5.0.1"
sympy = "^1.12.1"
tokenizers = "^0.19.1"
tqdm = "^4.66.4"
transformers = "^4.41.2"
trl = "^0.9.4"
typing-extensions = "^4.12.2"
tyro = "^0.8.4"
tzdata = "^2024.1"
urllib3 = "^2.2.1"
wandb = "^0.17.1"
xxhash = "^3.4.1"
yarl = "^1.9.4"
omegaconf = "^2.3.0"
llama-cpp-python = { version = "^0.2.77", source = "llama_cpp_python_cu121" }
torch = { version = "^2.3.1+cu121", source = "torch_cu121" }
nvidia-cublas-cu12 = { version = "^12.1.3.1", source = "torch_cu121" }
nvidia-cuda-cupti-cu12 = { version = "^12.1.105", source = "torch_cu121" }
nvidia-cuda-nvrtc-cu12 = { version = "^12.1.105", source = "torch_cu121" }
nvidia-cuda-runtime-cu12 = { version = "^12.1.105", source = "torch_cu121" }
nvidia-cudnn-cu12 = { version = "^8.9.2.26", source = "torch_cu121" }
nvidia-cufft-cu12 = { version = "^11.0.2.54", source = "torch_cu121" }
nvidia-curand-cu12 = { version = "^10.3.2.106", source = "torch_cu121" }
nvidia-cusolver-cu12 = { version = "^11.4.5.107", source = "torch_cu121" }
nvidia-cusparse-cu12 = { version = "^12.1.0.106", source = "torch_cu121" }
nvidia-nccl-cu12 = { version = "^2.20.5", source = "torch_cu121" }
nvidia-nvtx-cu12 = { version = "^12.1.105", source = "torch_cu121" }
optimum = "^1.20.0"
tensorboard = "^2.17.0"
wheel = "^0.43.0"
pytorch-triton = { version = "^2.3.0", source = "torch_cu121" }

[tool.poetry.group.dev.dependencies]
black = "^24.4.2"
flake8 = "^7.0.0"
ipykernel = "^6.29.4"
ipywidgets = "^8.1.3"
seedir = "^0.4.2"
emoji = "^2.12.1"
nbformat = "^5.10.4"
nbclient = "^0.10.0"
nbconvert = "^7.16.4"


[[tool.poetry.source]]
name = "torch_cu121"
url = "https://download.pytorch.org/whl/cu121"
priority = "explicit"


[[tool.poetry.source]]
name = "llama_cpp_python_cu121"
url = "https://abetlen.github.io/llama-cpp-python/whl/cu121"
priority = "explicit"


[[tool.poetry.source]]
name = "torch_nightly_cu121"
url = "https://download.pytorch.org/whl/nightly/cu121/"
priority = "explicit"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

Trang chủ Giới thiệu Sinh nhật bé trai Sinh nhật bé gái Tổ chức sự kiện Biểu diễn giải trí Dịch vụ khác Trang trí tiệc cưới Tổ chức khai trương Tư vấn dịch vụ Thư viện ảnh Tin tức - sự kiện Liên hệ Chú hề sinh nhật Trang trí YEAR END PARTY công ty Trang trí tất niên cuối năm Trang trí tất niên xu hướng mới nhất Trang trí sinh nhật bé trai Hải Đăng Trang trí sinh nhật bé Khánh Vân Trang trí sinh nhật Bích Ngân Trang trí sinh nhật bé Thanh Trang Thuê ông già Noel phát quà Biểu diễn xiếc khỉ Xiếc quay đĩa Dịch vụ tổ chức sự kiện 5 sao Thông tin về chúng tôi Dịch vụ sinh nhật bé trai Dịch vụ sinh nhật bé gái Sự kiện trọn gói Các tiết mục giải trí Dịch vụ bổ trợ Tiệc cưới sang trọng Dịch vụ khai trương Tư vấn tổ chức sự kiện Hình ảnh sự kiện Cập nhật tin tức Liên hệ ngay Thuê chú hề chuyên nghiệp Tiệc tất niên cho công ty Trang trí tiệc cuối năm Tiệc tất niên độc đáo Sinh nhật bé Hải Đăng Sinh nhật đáng yêu bé Khánh Vân Sinh nhật sang trọng Bích Ngân Tiệc sinh nhật bé Thanh Trang Dịch vụ ông già Noel Xiếc thú vui nhộn Biểu diễn xiếc quay đĩa Dịch vụ tổ chức tiệc uy tín Khám phá dịch vụ của chúng tôi Tiệc sinh nhật cho bé trai Trang trí tiệc cho bé gái Gói sự kiện chuyên nghiệp Chương trình giải trí hấp dẫn Dịch vụ hỗ trợ sự kiện Trang trí tiệc cưới đẹp Khởi đầu thành công với khai trương Chuyên gia tư vấn sự kiện Xem ảnh các sự kiện đẹp Tin mới về sự kiện Kết nối với đội ngũ chuyên gia Chú hề vui nhộn cho tiệc sinh nhật Ý tưởng tiệc cuối năm Tất niên độc đáo Trang trí tiệc hiện đại Tổ chức sinh nhật cho Hải Đăng Sinh nhật độc quyền Khánh Vân Phong cách tiệc Bích Ngân Trang trí tiệc bé Thanh Trang Thuê dịch vụ ông già Noel chuyên nghiệp Xem xiếc khỉ đặc sắc Xiếc quay đĩa thú vị
Trang chủ Giới thiệu Sinh nhật bé trai Sinh nhật bé gái Tổ chức sự kiện Biểu diễn giải trí Dịch vụ khác Trang trí tiệc cưới Tổ chức khai trương Tư vấn dịch vụ Thư viện ảnh Tin tức - sự kiện Liên hệ Chú hề sinh nhật Trang trí YEAR END PARTY công ty Trang trí tất niên cuối năm Trang trí tất niên xu hướng mới nhất Trang trí sinh nhật bé trai Hải Đăng Trang trí sinh nhật bé Khánh Vân Trang trí sinh nhật Bích Ngân Trang trí sinh nhật bé Thanh Trang Thuê ông già Noel phát quà Biểu diễn xiếc khỉ Xiếc quay đĩa
Thiết kế website Thiết kế website Thiết kế website Cách kháng tài khoản quảng cáo Mua bán Fanpage Facebook Dịch vụ SEO Tổ chức sinh nhật