I am trying to load a model from a checkpoint and use it for inference. The checkpoint folder looks like this. How do I load the model in torch from this folder? The resources I could find cover loading from a checkpoint file, not a folder.
import torch
from transformers import WhisperForConditionalGeneration
from peft import PeftConfig, PeftModel
peft_model_id = "aben118/finetuned_model/checkpoint-3900"
language = "en"
task = "transcribe"
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path, load_in_8bit=False, device_map="auto"
)
model = PeftModel.from_pretrained(model, peft_model_id)
print(model)
model = model.merge_and_unload()
model.save_pretrained(<model_path>)
But this saves the weights in .safetensors format. I want a file that I can load using torch.load.
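For context, here is a minimal sketch of two ways to get a torch.load-compatible file, demonstrated on a tiny stand-in module (assumption: the same calls apply to the merged WhisperForConditionalGeneration, and that your transformers version supports the safe_serialization flag):

```python
import torch
import torch.nn as nn

# Stand-in for the merged model (assumption: any nn.Module behaves the same).
model = nn.Linear(4, 2)

# Option 1 (Hugging Face models): pass safe_serialization=False so
# save_pretrained writes pytorch_model.bin, a pickle file torch.load can read:
#   model.save_pretrained(model_path, safe_serialization=False)

# Option 2 (any nn.Module): save the state dict directly with torch.save.
torch.save(model.state_dict(), "merged_model.pt")
state_dict = torch.load("merged_model.pt", map_location="cpu")

# Reload into a fresh instance of the same architecture.
reloaded = nn.Linear(4, 2)
reloaded.load_state_dict(state_dict)
```

Whether this fits your use case depends on why you need torch.load specifically; safetensors can also be read back with safetensors.torch.load_file if the format itself is the only obstacle.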