I am using the transformers library (Hugging Face) to extract all hidden states of LLaVA 1.5. The Hugging Face documentation indicates that it is possible to extract image hidden states from the vision component.
Unfortunately, the outputs
object returned by generate only exposes the following keys:
odict_keys(['sequences', 'attentions', 'hidden_states', 'past_key_values'])
How do I also extract image_hidden_states
from this LLaVA implementation along with the existing outputs?
I have implemented the following code in the hope of doing so.
import torch
from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig, LlavaProcessor
from torchinfo import summary  # needed for the model summary printed below
from PIL import Image
import requests
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_id = 'llava-hf/llava-1.5-7b-hf'
# Initializing a CLIP-vision config with hidden states and attentions enabled
vision_config = CLIPVisionConfig(output_hidden_states=True, output_attentions=True, return_dict=True)
# Initializing a Llama config with hidden states and attentions enabled
text_config = LlamaConfig(output_hidden_states=True, output_attentions=True, return_dict=True)
# Initializing a llava-1.5-7b style configuration
configuration = LlavaConfig(vision_config, text_config, output_hidden_states=True, output_attentions=True, return_dict=True)
# Loading the pretrained llava-1.5-7b weights; from_pretrained builds its own
# config from the checkpoint, so the hand-built configuration above is not passed in
model = LlavaForConditionalGeneration.from_pretrained(model_id, output_hidden_states=True, output_attentions=True, return_dict=True)
# Accessing the model configuration
configuration = model.config
model = model.to(device)
print(summary(model))
processor = LlavaProcessor.from_pretrained(model_id)  # the output_* flags are model/generate arguments, not processor arguments
prompt = "USER: <image>\nIs there sun in the image? ASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors="pt")
inputs = inputs.to(device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        output_hidden_states=True,
        output_attentions=True,
        return_dict_in_generate=True,  # generate uses return_dict_in_generate; return_dict is a forward-pass argument
        max_new_tokens=1,
        min_new_tokens=1,
    )
print(outputs.keys())
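For reference, this is the kind of workaround I have been experimenting with: calling the vision tower directly on the processed pixel values. I am assuming here that model.vision_tower is the underlying CLIPVisionModel and that its hidden_states correspond to the image_hidden_states mentioned in the documentation, but I am not sure this is equivalent to what generate should return.
# Workaround sketch (assumption): run the CLIP vision tower directly on the
# processed pixel values to get per-layer image hidden states.
with torch.no_grad():
    vision_outputs = model.vision_tower(
        pixel_values=inputs["pixel_values"],
        output_hidden_states=True,
    )
# One tensor per vision transformer layer (plus the embedding output)
print(len(vision_outputs.hidden_states), vision_outputs.hidden_states[-1].shape)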