I don’t understand how to pass the hidden_states of the ViTMAEModel encoder into the U-Net decoder. I saw a diagram of a similar architecture, and it involved “reshaping” the encoder outputs. How do I pass in the hidden states, and how do I reshape them? Is it like the unpatchify function, where I drop the CLS token? What do I do with the CLS token?
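From the diagram, my guess is that the reshape is supposed to drop the CLS token and turn each hidden state of shape [batch, 197, 768] into a [batch, 768, 14, 14] feature map. This is just my own sketch (assuming the default vit-mae-base sizes: 224x224 images, 16x16 patches, so a 14x14 grid of 196 patches and hidden size 768), and I don’t know if it is correct:

import torch

hidden_state = torch.randn(1, 197, 768)       # [batch, CLS token + 196 patch tokens, hidden]
tokens = hidden_state[:, 1:, :]               # drop the CLS token -> [1, 196, 768]
tokens = tokens.permute(0, 2, 1)              # channels first -> [1, 768, 196]
feature_map = tokens.reshape(1, 768, 14, 14)  # arrange the 196 tokens as a 14x14 grid

My full code is below.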
from transformers import ViTMAEModel, ViTMAEConfig, ViTMAEForPreTraining, AutoImageProcessor
import torch
import torch.nn as nn
from torchvision.utils import save_image
from PIL import Image
# No segmentation head yet
class MAEEncoderWithSkipConnections(ViTMAEModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.vit = ViTMAEModel.from_pretrained("facebook/vit-mae-base", config=config)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        feature_maps = []
        patch_size = self.config.patch_size  # unused
        img_size = self.config.image_size  # unused
        print(hidden_states[2])
        for hs in hidden_states:
            hs = hs[:, 1:, :]  # drop the CLS token
            hs = hs.permute(2, 0, 1).contiguous()
            # print(hs)
            # num_patches_per_dim = img_size // patch_size
            print(hs.shape)  # torch.Size([768, 1, 196])
            feature_maps.append(hs)
        return feature_maps
config = ViTMAEConfig.from_pretrained("facebook/vit-mae-base", mask_ratio=0.0)
mae_encoder = MAEEncoderWithSkipConnections(config)
class UNetDecoder(nn.Module):
    def __init__(self, config):
        super(UNetDecoder, self).__init__()
        # self.config = config
        self.up1 = nn.ConvTranspose2d(in_channels=768, out_channels=512, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.out_conv = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, skip_connection):
        x = self.up1(skip_connection[-1])
        # print(x)
        # print(x.shape)
        x = torch.cat([x, skip_connection[-3]], dim=1)
        x = self.up2(x)
        x = torch.cat([x, skip_connection[-6]], dim=1)
        x = self.up3(x)
        x = torch.cat([x, skip_connection[-9]], dim=1)
        x = self.up4(x)
        x = self.out_conv(x)
        print(x)
        return x
class MAEUnet(nn.Module):
    def __init__(self, encoder, decoder):
        super(MAEUnet, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        skip_connections = self.encoder(x)
        output = self.decoder(skip_connections)
        return output
img = Image.open("/vast/home/mayolo/Downloads/mayolos_face.jpg").convert("RGB")
# Has Dataset STD and mean
processor = AutoImageProcessor.from_pretrained('facebook/vit-mae-base')
mae_model = ViTMAEModel.from_pretrained("facebook/vit-mae-base", config=config)
unet_model = MAEUnet(mae_encoder, UNetDecoder(config))
input_data = torch.randn(1,3,224,224)
input_img = processor(img, return_tensors="pt")
output = unet_model(input_img.pixel_values)
save_image(output,'outputimg.png')
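When I run this I get the error below. To see what the decoder actually receives, I also print the skip-connection shapes with a small debugging snippet like this (just a sketch, reusing input_img and mae_encoder from above):

with torch.no_grad():
    skips = mae_encoder(input_img.pixel_values)
for i, s in enumerate(skips):
    print(i, s.shape)  # every entry prints torch.Size([768, 1, 196]) for me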
Error:
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 512 but got size 768 for tensor number 1 in the list.