I am trying to fine-tune a Mistral model for 5 epochs. The estimated training time is 72 hours with 8-bit quantized fine-tuning but only 48 hours when fine-tuning the original (unquantized) model. The memory footprint is also higher for the 8-bit quantized run. Below is the code where I load the model with 8-bit quantization.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 8-bit quantization config; fp32 CPU offload is enabled for any
# layers that do not fit on the GPU
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
    llm_int8_has_fp16_weight=False,
)

# Load the base model with the 8-bit config applied
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
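For comparison, the unquantized run simply loads the model without a quantization config. This is a minimal sketch of that baseline, assuming the same base_model checkpoint, dtype, and device map as above:

# Baseline: load the original model without quantization
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)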
Can someone explain why this is happening? Am I making a mistake somewhere?