Here are the main functions for this. Basically, I need to calculate the surprisal for each of two options given the prompt.
The problem is that Llama-3-8B-Instruct model does not return the proper scores, only -inf values, while other models do.
Here’s the code. I’d greatly appreciate help.
def load_model(model_name):
if 'llama' in model_name.lower() or 'mistral' in model_name.lower():
model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, token=hf_token).to(device)
print(f"Successfully loaded model ({model_name})")
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False, token=hf_token)
print(f"Successfully loaded tokenizer ({model_name})")
elif 'olmo' in model_name.lower():
model = OLMoForCausalLM.from_pretrained(args.model_name).to(device)
print(f"Successfully loaded model ({model_name})")
tokenizer = OLMoTokenizerFast.from_pretrained(args.model_name)
print(f"Successfully loaded tokenizer ({model_name})")
return model, tokenizer
def calculate_surprisal(model, tokenizer, prompt, context, options, label, **kwargs):
full_context = f"{prompt}. {context} ({options[0]}/{options[1]})."
# Encode the full context using the tokenizer
input_ids = tokenizer.encode(full_context, return_tensors='pt', add_special_tokens=False)
print(input_ids)
surprisals = {}
option_ids_list = []
for option in options:
# Encode the option using the tokenizer
option_tokens = tokenizer(option, return_tensors="pt", add_special_tokens=False)
# Get the option's token IDs
option_ids = option_tokens.input_ids.squeeze().tolist()
option_ids_list.append(option_ids)
outputs = model.generate(
input_ids=input_ids.to('cuda'),
max_new_tokens=5,
output_scores=True,
num_return_sequences=1,
return_dict_in_generate=True,
pad_token_id=tokenizer.eos_token_id
)
# Extract logit scores
if isinstance(outputs.scores, tuple):
logits = outputs.scores[0][0]
else:
logits = outputs.scores
# Calculate option logit scores
option_logits = []
for tokens in option_ids_list:
if isinstance(tokens, int):
option_logits.append(logits[tokens].item()) # Handle single token options
else:
pair_logits = sum(logits[token_id].item() for token_id in tokens) # Handle multi-token options (sum logit scores)
option_logits.append(pair_logits)
# Convert logits to probabilities using softmax
answer_logits_tensor = torch.tensor(option_logits)
probs = torch.nn.functional.softmax(torch.tensor(option_logits).unsqueeze(0)).squeeze(0).numpy()
# Calculate surprisal for each option
for i, prob in enumerate(probs):
surprisal = -math.log2(prob)
surprisals[options[i]] = surprisal
# Find the option with the lowest surprisal (most likely answer)
generated_response = min(surprisals, key=surprisals.get)
# Predict if the generated response matches the label
prediction = 1 if generated_response == label else 0
return generated_response, prediction, surprisals
**The output for Llama-3-8B-Instruct looks like this:
**
GenerateDecoderOnlyOutput(sequences=tensor([[25017, 279, 3072, 304, 279, 40029, 430, 1888, 45695, 279,
11914, 13, 578, 3838, 48841, 706, 459, 11245, 40902, 430,
374, 2216, 320, 396, 99594, 1113, 14, 20111, 570, 13220,
9, 510, 16, 60]], device=’cuda:0′), scores=(tensor([[-inf, -inf, -inf, …, -inf, -inf, -inf]], device=’cuda:0′), tensor([[-inf, -inf, -inf, …, -inf, -inf, -inf]], device=’cuda:0′), tensor([[-inf, -inf, -inf, …, -inf, -inf, -inf]], device=’cuda:0′), tensor([[-inf, -inf, -inf, …, -inf, -inf, -inf]], device=’cuda:0′), tensor([[-inf, -inf, -inf, …, -inf, -inf, -inf]], device=’cuda:0′)), logits=None, attentions=None, hidden_states=None, past_key_values=(and so on)..
**The output for Mistral-7B-Instruct-v0.3, e.g, is:
**
GenerateDecoderOnlyOutput(sequences=tensor([[37923, 253, 4500, 275, 253, 26609, 326, 1682, 29141, 253,
6197, 15, 3808, 952, 1333, 326, 247, 9479, 14018, 476,
313, 579, 249, 16, 7317, 481, 19810, 1333, 326, 247,
9950]], device=’cuda:0′), scores=(tensor([[-0.2477, -0.7886, 3.1052, …, -0.7890, -0.7896, -0.7898]],
device=’cuda:0′), tensor([[-0.8094, -0.9550, 5.2297, …, -0.9548, -0.9548, -0.9546]],
device=’cuda:0′), tensor([[-1.2616, -1.4411, 4.9028, …, -1.4407, -1.4408, -1.4418]],
device=’cuda:0′), tensor([[-0.8956, -0.8624, 2.6877, …, -0.8620, -0.8626, -0.8629]],
device=’cuda:0′), tensor([[-0.6234, -0.7014, 3.1742, …, -0.7013, -0.7019, -0.7020]],
device=’cuda:0′)), logits=None, attentions=None, hidden_states=None, past_key_values=(and so on)…
A Nh is a new contributor to this site. Take care in asking for clarification, commenting, and answering.
Check out our Code of Conduct.