I was trying to replicate a math reasoning model that uses RAG.
I wanted to run the predict stage of its retrieval model.
So I first wrote a Python script to convert the source model file from .safetensors to .ckpt.
Then I configured the converted .ckpt file as the checkpoint (because the model is a PyTorch Lightning module).
Then it failed with:
self.trainer.strategy.load_model_state_dict(
  File "/root/miniconda3/envs/lean/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 371, in load_model_state_dict
    self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/root/miniconda3/envs/lean/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PremiseRetriever:
    Missing key(s) in state_dict: ...
    ...
    Unexpected key(s) in state_dict: ...
I compared the missing keys with the unexpected keys and found that each missing key is an unexpected key with an extra "encoder." prefix, plus one additional missing key named "encoder.encoder.embed_tokens.weight". Apart from that, the two lists match. For example, for the unexpected key "encoder.block.0.layer.0.SelfAttention.k.weight" there is a missing key "encoder.encoder.block.0.layer.0.SelfAttention.k.weight".
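You can check this pattern directly from the .safetensors file before converting anything. A quick sketch I used to eyeball the keys (the file name is a placeholder for your own path):

from safetensors.torch import load_file

# load on CPU just to inspect the key names
weights = load_file('retriever.safetensors', device='cpu')  # placeholder path
for key in list(weights)[:5]:
    # left: key in the checkpoint; right: key the model reports as missing
    print(key, '->', 'encoder.' + key)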
So I added the prefix to every key before saving the weights into the state_dict. The final script is as follows:
import os
import sys

import torch
from safetensors.torch import load_file

safetensors_path = sys.argv[1]  # source .safetensors file, passed on the command line
ckpt_path = sys.argv[2]         # output .ckpt file

# Check if the input .safetensors file exists
if not os.path.exists(safetensors_path):
    print(f'Error: The .safetensors file does not exist: {safetensors_path}')
    sys.exit(1)

# Create the directory for the output .ckpt file if needed
ckpt_dir = os.path.dirname(ckpt_path)
if ckpt_dir and not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

print(f'Loading weights from: {safetensors_path}')

# Load the .safetensors file; it is a flat dict of tensor name -> tensor
weights = load_file(safetensors_path, device='cpu')

# Add the "encoder." prefix to every key so the names match
# what PremiseRetriever expects
nw = {}
for key, value in weights.items():
    nw['encoder.' + key] = value

# Wrap the state_dict under the 'state_dict' key to match the
# typical .ckpt structure Lightning expects
checkpoint_data = {'state_dict': nw}

# Save the checkpoint to a .ckpt file
torch.save(checkpoint_data, ckpt_path)
print(f'Weights successfully saved to: {ckpt_path}')
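To double-check the conversion, you can reload the .ckpt and look at the keys (a quick sketch; the path is again a placeholder):

import torch

ckpt = torch.load('retriever.ckpt', map_location='cpu')  # placeholder path
state_dict = ckpt['state_dict']
print(len(state_dict), 'tensors')
print(next(iter(state_dict)))  # every key should now start with "encoder."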
But as I mentioned, there is still one missing key ("encoder.encoder.embed_tokens.weight"), so it still could not load.
The new error message said that I should pass strict=False to the load_state_dict function.
But HOW?
I found this in pytorch_lightning/core/module.py:
@property
def strict_loading(self) -> bool:
    """Determines how Lightning loads this model using `.load_state_dict(..., strict=model.strict_loading)`."""
    # We use None as the default internally to determine whether the user has set a value
    return self._strict_loading in (None, True)
So I added just one line in the model's __init__:

self._strict_loading = False
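In context the change looks like this (a minimal sketch; the real PremiseRetriever.__init__ takes more arguments than shown here):

import pytorch_lightning as pl

class PremiseRetriever(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        # make Lightning call load_state_dict(..., strict=False),
        # so the one remaining missing key no longer aborts loading
        self._strict_loading = False

Recent Lightning versions also seem to expose a public setter for this property, so setting model.strict_loading = False from outside the class should work too, if I read module.py correctly.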
And it works!
It was really hard for me to figure all this out, and I hope it helps those who have the same problem as me.
I also saw solutions like:

self.model.load_state_dict(dict([(n, p) for n, p in checkpoint['model'].items()]), strict=False)

but since the original code already defines its own functions to load the checkpoint, I didn't know how to work this in.
Adding one attribute to the model is easy.
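If you cannot touch __init__, another option in the same spirit (an untested sketch, not the project's own code) is to override load_state_dict on the module and ignore the strict flag Lightning passes in:

import pytorch_lightning as pl

class PremiseRetriever(pl.LightningModule):
    def load_state_dict(self, state_dict, strict=True):
        # force lenient loading no matter what the caller asked for
        return super().load_state_dict(state_dict, strict=False)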