I’m trying to extract the past key/value pair for a particular layer using the attention layers and the corresponding hidden state:
import torch
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers import LlamaModel, LlamaTokenizer, LlamaForCausalLM
tokenizer = LlamaTokenizer.from_pretrained(path_to_llama2)
# Load the configuration and enable required outputs
config = LlamaConfig.from_pretrained(path_to_llama2)
config.output_hidden_states = True
config.output_attentions = True # To get the attention weights if needed
config.use_cache = True # To get past_key_values
model = LlamaForCausalLM.from_pretrained(path_to_llama2, config=config)
model.eval()
input_text = "Once upon a time"
inputs = tokenizer(input_text, return_tensors='pt')
outputs = model(**inputs)
hidden_states = outputs.hidden_states # List of hidden states from each layer
state_dict = model.state_dict()
# Function to compute past_key_values for a single layer
def compute_past_key_values_for_layer(layer_idx, hidden_state):
    attention_layers = [layer.self_attn for layer in model.model.layers]
    # Projection weights for this layer
    W_q = state_dict[f'model.layers.{layer_idx}.self_attn.q_proj.weight']
    W_k = state_dict[f'model.layers.{layer_idx}.self_attn.k_proj.weight']
    W_v = state_dict[f'model.layers.{layer_idx}.self_attn.v_proj.weight']
    # Project the hidden state to queries, keys and values
    queries = torch.matmul(hidden_state, W_q.T)
    keys = torch.matmul(hidden_state, W_k.T)
    values = torch.matmul(hidden_state, W_v.T)
    # Reshape to (batch, num_heads, seq_len, head_dim)
    batch_size, seq_length, hidden_dim = hidden_state.size()
    num_attention_heads = attention_layers[layer_idx].num_heads
    head_dim = hidden_dim // num_attention_heads
    keys = keys.view(batch_size, seq_length, num_attention_heads, head_dim)
    keys = keys.permute(0, 2, 1, 3)
    values = values.view(batch_size, seq_length, num_attention_heads, head_dim)
    values = values.permute(0, 2, 1, 3)
    return keys, values
past_key_values = []
for i, hidden_state in enumerate(hidden_states[1:]): # Skip the embedding layer
    keys, values = compute_past_key_values_for_layer(i, hidden_state)
    past_key_values.append((keys, values))
past_key_values = tuple(past_key_values)
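For reference, this is roughly the kind of check I'm using to compare the two (a minimal sketch, assuming each element of outputs.past_key_values can be indexed as a (key, value) pair of shape (batch, num_heads, seq_len, head_dim)):
layer_idx = 0  # the layer I'm inspecting
ref_keys, ref_values = outputs.past_key_values[layer_idx]
my_keys, my_values = past_key_values[layer_idx]
print(ref_keys.shape, my_keys.shape)  # shapes line up for me
print(torch.allclose(ref_keys, my_keys, atol=1e-5))
print(torch.allclose(ref_values, my_values, atol=1e-5))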
But these past_key_values don’t match the values I get from outputs.past_key_values for the corresponding layer.
Why is this happening? Are there any suggestions?