I have fine-tuned Llama-3 with QLoRA and now want to run inference. Below is the code I use:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

access_token = ''  # HF access token (redacted)

# Load the base model in 4-bit (NF4) precision
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,  # base Llama-3 checkpoint id
    quantization_config=bnb_config,
    device_map="auto",
    token=access_token,
)

# Attach the QLoRA adapter weights saved during fine-tuning
peft_model = PeftModel.from_pretrained(base_model, path)  # path = adapter directory
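For completeness, this is roughly how I run generation afterwards (just a minimal sketch; the prompt and generation settings are placeholders, not my actual inputs):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=access_token)

prompt = "Explain QLoRA in one sentence."  # placeholder prompt
inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)

with torch.no_grad():
    output_ids = peft_model.generate(**inputs, max_new_tokens=128)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))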
When I load the base model in 4-bit precision, the VRAM usage is around 6 GB. However, after loading the PEFT adapter on top of it with PeftModel.from_pretrained, the VRAM usage spikes to around 19 GB. During fine-tuning, only about 11% of the parameters were trainable.
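Here is roughly how I'm reading those numbers and checking what dtype the adapter weights end up in (a minimal sketch, assuming a single CUDA device and that the LoRA weights appear as parameters whose names contain "lora_"):

def report(stage):
    # allocated tensor memory on GPU 0, in GB
    print(f"{stage}: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB allocated")

report("after base model")                                # ~6 GB here
peft_model = PeftModel.from_pretrained(base_model, path)
report("after loading PEFT adapter")                      # ~19 GB here

# inspect the dtype of the first LoRA matrix
for name, param in peft_model.named_parameters():
    if "lora_" in name:
        print(name, param.dtype, tuple(param.shape))
        break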
I’m trying to understand why there’s such a large increase in VRAM usage after loading the adapter weights. Could it be because the QLoRA adapter weights are being loaded in 16-bit precision? Is there a way to load these adapter weights in 8-bit?
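One thing I considered trying is casting just the adapter parameters down to the compute dtype after loading, but I'm not sure it's the right approach (the sketch below assumes the LoRA weights load in float32 and can be identified by "lora_" in their parameter names):

# experiment: downcast only the LoRA adapter weights; the 4-bit base weights are untouched
for name, param in peft_model.named_parameters():
    if "lora_" in name:
        param.data = param.data.to(torch.bfloat16)

torch.cuda.empty_cache()
print(f"{torch.cuda.memory_allocated(0) / 1024**3:.2f} GB allocated after casting")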
Any insights or suggestions would be greatly appreciated!