I am trying to reproduce this project: https://github.com/feizc/Gradient-Free-Textual-Inversion, but I have run into a problem:
Loading pipeline components...: 100%|█████████████| 6/6 [00:14<00:00, 2.43s/it]
convert text inversion: <sks> in id: 49408
load data length: 50
Traceback (most recent call last):
  File "/root/autodl-tmp/train_inversion.py", line 298, in <module>
    main()
  File "/root/autodl-tmp/train_inversion.py", line 291, in main
    fitnesses = pipeline.eval(solutions)
  File "/root/autodl-tmp/train_inversion.py", line 108, in eval
    z = z + self.init_text_inversion
RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1
I also found that the shape of z changes with the model: with 512-base-ema.ckpt it is [1024], and with sd-v1-4.ckpt it is [768], while self.init_text_inversion stays [512] in both cases.
In addition, as far as I know, the project's initialize_inversion.py uses CLIP to initialize the textual inversion from the text templates and the training images. Here I use openai/clip-vit-base-patch16.
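If I read the error right, this is a dimensionality mismatch between the CLIP used for initialization and the text encoder inside the Stable Diffusion checkpoint. Here is a minimal check (the ./ckpt and ./save paths follow the layout in the code below):

import torch
from transformers import CLIPTextModel

# embedding saved by initialize_inversion.py
init = torch.load('./save/initialize_emb.bin', map_location='cpu')['initialize']
print('init dim:', init.shape[-1])  # 512 for openai/clip-vit-base-patch16

# token-embedding width of the checkpoint actually used for training
text_encoder = CLIPTextModel.from_pretrained('./ckpt/text_encoder')
print('model dim:', text_encoder.get_input_embeddings().weight.shape[-1])
# 768 for sd-v1-4 (CLIP ViT-L/14), 1024 for 512-base-ema (OpenCLIP ViT-H)

clip-vit-base-patch16 has a 512-dim token-embedding space, so the saved initialization can never match either checkpoint, which seems to explain the 1024-vs-512 error above. Is that reading correct?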
1. The following is the initialization code:
(https://github.com/feizc/Gradient-Free-Textual-Inversion/blob/main/initialize_inversion.py)
"""
automatically initialize the textual inversion with CLIP and no-parameter cross-attention
使用 CLIP 和无参数交叉注意力自动初始化文本反转
"""
import torch
import os
import argparse
from PIL import Image
import torch.nn.functional as F
from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor, CLIPTextModel
from utils import imagenet_template, automatic_subjective_classnames
def embedding_generate(model, tokenizer, text_encoder, classnames, templates, device):
"""
pre-caculate the template sentence, token embeddings
预计算模板句子、令牌嵌入
"""
with torch.no_grad():
sentence_weights = []
token_weights = []
token_embedding_table = text_encoder.get_input_embeddings().weight.data
for classname in classnames:
texts = [template(classname) for template in templates] # format with class
texts = tokenizer(texts, padding="max_length", max_length=77, truncation=True, return_tensors="pt") # tokenize
texts = texts['input_ids'].to(device)
class_embeddings = model.get_text_features(texts)
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
class_embedding /= class_embedding.norm()
sentence_weights.append(class_embedding)
token_ids = tokenizer.encode(classname,add_special_tokens=False)
token_embedding_list = []
for token_id in token_ids:
token_embedding_list.append(token_embedding_table[token_id])
token_weights.append(torch.mean(torch.stack(token_embedding_list), dim=0))
sentence_weights = torch.stack(sentence_weights, dim=1).to(device)
token_weights = torch.stack(token_weights, dim=0).to(device)
return sentence_weights, token_weights
def image_condition_embed_initialize(image_feature_list, sentence_embeddings, token_embeddings):
"""
no-parameter cross-attention: query: image, key: sentence, value: token
无参数交叉注意力:查询:图像,键:句子,值:令牌
"""
inversion_emb_list = []
for image_features in image_feature_list:
cross_attention = image_features @ sentence_embeddings
attention_probs = F.softmax(cross_attention, dim=-1)
inversion_emb = torch.matmul(attention_probs, token_embeddings)
inversion_emb_list.append(inversion_emb)
final_inversion = torch.mean(torch.stack(inversion_emb_list), dim=0)
final_inversion = final_inversion / final_inversion.norm()
return final_inversion
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--save_path", default='./save', type=str)
parser.add_argument("--data_path", default='./coconut_seed_fruit_stage', type=str)
args = parser.parse_args()
save_path = args.save_path
data_path = args.data_path
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = CLIPTokenizer.from_pretrained('/root/.cache/huggingface/patch16b')
model = CLIPModel.from_pretrained('/root/.cache/huggingface/patch16b').to(device)
text_encoder = CLIPTextModel.from_pretrained('/root/.cache/huggingface/patch16b')
processor = CLIPProcessor.from_pretrained('/root/.cache/huggingface/patch16b')
sentence_embeddings, token_embeddings = embedding_generate(model,
tokenizer,
text_encoder,
automatic_subjective_classnames,
imagenet_template,
device)
print('sentence embedding size: ', sentence_embeddings.size(), ' token embedding size: ', token_embeddings.size())
image_feature_list = []
name_list = os.listdir(data_path)
for name in name_list:
image_path = os.path.join(data_path, name)
image = Image.open(image_path)
inputs = processor(images=image, return_tensors="pt").to(device)
image_features = model.get_image_features(**inputs)
image_features = F.normalize(image_features, dim=-1)
image_feature_list.append(image_features)
print('image size: ', len(image_feature_list))
inversion_emb = image_condition_embed_initialize(image_feature_list, sentence_embeddings, token_embeddings)
inversion_emb_dict = {"initialize": inversion_emb.detach().cpu()}
torch.save(inversion_emb_dict, os.path.join(save_path, 'initialize_emb.bin'))
if __name__ == "__main__":
main()
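My current guess at a fix is that the attention values in embedding_generate (the token_embedding_table) have to come from the Stable Diffusion checkpoint's own tokenizer and text encoder rather than from patch16b, so the saved inversion gets the width the training model expects. A sketch of what I mean (assuming the SD tokenizer shares CLIP's BPE vocabulary, which I believe it does for both checkpoints; patch16b is still used for the image and sentence features, which only set the attention weights):

from transformers import CLIPTokenizer, CLIPTextModel

# hypothetical change inside main(): take the value vectors from the target
# checkpoint so the saved embedding has the right width (768 or 1024)
sd_tokenizer = CLIPTokenizer.from_pretrained('./ckpt/tokenizer')
sd_text_encoder = CLIPTextModel.from_pretrained('./ckpt/text_encoder')
sentence_embeddings, token_embeddings = embedding_generate(model,  # patch16b CLIPModel, unchanged
                                                           sd_tokenizer,
                                                           sd_text_encoder,
                                                           automatic_subjective_classnames,
                                                           imagenet_template,
                                                           device)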
2. The following is the training code:
(https://github.com/feizc/Gradient-Free-Textual-Inversion/blob/main/train_inversion.py)
import cma
import argparse
import torch
import os
import numpy as np
import copy
from sklearn.decomposition import PCA
from diffusers import StableDiffusionPipeline, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from utils import TextualInversionDataset
from tqdm import tqdm
class GradientFreePipeline:
    def __init__(self, model_path, args, init_text_inversion=None):
        self.tokenizer = CLIPTokenizer.from_pretrained(
            os.path.join(model_path, 'tokenizer')
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            os.path.join(model_path, 'text_encoder')
        )
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_path,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
        ).to(args.device)

        if args.projection_modeling == 'prior_normal':
            self.linear = torch.nn.Linear(args.intrinsic_dim, args.model_dim, bias=False).to(args.device)
            embedding = self.text_encoder.get_input_embeddings().weight.clone().cpu()
            mu_hat = np.mean(embedding.reshape(-1).detach().cpu().numpy())
            std_hat = np.std(embedding.reshape(-1).detach().cpu().numpy())
            mu = 0.0
            std = args.alpha * std_hat / (np.sqrt(args.intrinsic_dim) * args.sigma)
            # incorporate temperature factor
            # temp = intrinsic_dim - std_hat * std_hat
            # mu = mu_hat / temp
            # std = std_hat / np.sqrt(temp)
            print('[Embedding] mu: {} | std: {} [RandProj] mu: {} | std: {}'.format(mu_hat, std_hat, mu, std))
            for p in self.linear.parameters():
                torch.nn.init.normal_(p, mu, std)
        elif args.projection_modeling == 'pca':
            embedding = self.text_encoder.get_input_embeddings().weight.clone().cpu()
            embedding = embedding.detach().cpu().numpy()  # (49408, 768)
            self.pca_model = PCA(n_components=args.intrinsic_dim)
            self.pca_model.fit(embedding)

        # Add the placeholder token in tokenizer
        num_added_tokens = self.tokenizer.add_tokens(args.placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

        # Convert the initializer_token and placeholder_token to ids
        token_ids = self.tokenizer.encode(args.initializer_token, add_special_tokens=False)
        initializer_token_id = token_ids[0]
        placeholder_token_id = self.tokenizer.convert_tokens_to_ids(args.placeholder_token)

        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))

        # Initialise the newly added placeholder token with the embeddings of the initializer token
        token_embeds = self.text_encoder.get_input_embeddings().weight.data
        token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
        print('convert text inversion: ', args.placeholder_token, 'in id: ', str(placeholder_token_id))

        self.placeholder_token_id = placeholder_token_id
        self.placeholder_token = args.placeholder_token
        self.num_call = 0

        train_dataset = TextualInversionDataset(
            data_root=args.train_data_dir,
            tokenizer=self.tokenizer,
            size=args.resolution,
            placeholder_token=args.placeholder_token,
            repeats=args.repeats,
            learnable_property=args.learnable_property,
            center_crop=args.center_crop,
            set="train",
        )
        self.dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.repeats, shuffle=True)
        self.batch_size = args.repeats
        self.device = args.device
        print('load data length: ', len(self.dataloader))

        # optimize incremental elements or the original inversion
        if init_text_inversion is not None:
            self.init_text_inversion = init_text_inversion.to(args.device)
        else:
            self.init_text_inversion = token_embeds[initializer_token_id].to(args.device)
        self.args = args
        self.best_inversion = None
    def eval(self, inversion_embedding):
        self.num_call += 1
        pe_list = []
        if isinstance(inversion_embedding, list):  # multiple queries
            for pe in inversion_embedding:
                if self.args.projection_modeling == 'prior_normal':
                    z = torch.tensor(pe).type(torch.float32).to(self.device)  # z
                    with torch.no_grad():
                        z = self.linear(z)  # W_p Q
                    if self.init_text_inversion is not None:
                        z = z + self.init_text_inversion  # W_p Q + p_0
                elif self.args.projection_modeling == 'pca':
                    z = self.pca_model.inverse_transform(pe)  # project back to the original text embedding space
                    z = torch.tensor(z).type(torch.float32).to(self.device)
                    print(z.shape)                         # my debug print: [1024] with 512-base-ema, [768] with sd-v1-4
                    print(self.init_text_inversion.shape)  # my debug print: [1, 512] from initialize_emb.bin
                    if self.init_text_inversion is not None:
                        z = z + self.init_text_inversion   # line 108 in the traceback above
                pe_list.append(z)
        elif isinstance(inversion_embedding, np.ndarray):  # single query or None
            if self.args.projection_modeling == 'prior_normal':
                inversion_embedding = torch.tensor(inversion_embedding).type(torch.float32).to(self.device)  # z
                with torch.no_grad():
                    inversion_embedding = self.linear(inversion_embedding)  # W_p Q
            elif self.args.projection_modeling == 'pca':
                inversion_embedding = self.pca_model.inverse_transform(inversion_embedding)
                inversion_embedding = torch.tensor(inversion_embedding).type(torch.float32).to(self.device)
            if self.init_text_inversion is not None:
                inversion_embedding = inversion_embedding + self.init_text_inversion  # W_p Q + p_0
            pe_list.append(inversion_embedding)
        else:
            raise ValueError(
                f'[Inversion Embedding] Only support [list, numpy.ndarray], got `{type(inversion_embedding)}` instead.'
            )

        loss_list = []
        print('begin to calculate loss')
        # fixed time steps for fair evaluation
        noise_scheduler = DDPMScheduler.from_config('./ckpt/scheduler')
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (self.batch_size,), device=self.device
        ).long()
        best_loss = 1000
        best_inversion = None
        for pe in tqdm(pe_list):
            token_embeds = self.text_encoder.get_input_embeddings().weight.data
            pe = pe.to(self.text_encoder.get_input_embeddings().weight.dtype)  # note: the repo's `pe.to(...)` does not assign the result
            token_embeds[self.placeholder_token_id] = pe
            loss = calculate_mse_loss(self.pipe, self.dataloader, self.device, noise_scheduler, timesteps)
            if loss < best_loss:
                best_loss = loss
                best_inversion = pe
            loss_list.append(loss)
        # update the best point so far
        self.best_inversion = best_inversion
        return loss_list
    def save(self, output_path):
        learned_embeds_dict = {self.placeholder_token: self.best_inversion.detach().cpu()}
        torch.save(learned_embeds_dict, os.path.join(output_path, "learned_embeds.bin"))
def calculate_mse_loss(image_generator, dataloader, device, noise_scheduler, timesteps):
    # print(image_generator.text_encoder.get_input_embeddings().weight.data[49408])
    loss_cum = .0
    with torch.no_grad():
        for batch in dataloader:
            # Convert images to latent space
            latents = image_generator.vae.encode(batch["pixel_values"].to(device)).latent_dist.sample().detach()
            latents = latents * 0.18215
            # Sample noise that we'll add to the latents
            noise = torch.randn(latents.shape).to(latents.device)
            # Add noise to the latents according to the noise magnitude at each (fixed) timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            # Get the text embedding for conditioning
            encoder_hidden_states = image_generator.text_encoder(batch["input_ids"].to(device))[0]
            # Predict the noise residual
            noise_pred = image_generator.unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
            loss_cum += loss.item()
    return loss_cum / len(dataloader)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--intrinsic_dim", default=256, type=int)
    parser.add_argument("--k_shot", default=16, type=int)
    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--budget", default=5000, type=int)  # number of iterations
    parser.add_argument("--popsize", default=20, type=int)  # number of candidates
    parser.add_argument("--bound", default=0, type=int)
    parser.add_argument("--sigma", default=1, type=float)
    parser.add_argument("--alpha", default=1, type=float)
    parser.add_argument("--print_every", default=50, type=int)
    parser.add_argument("--eval_every", default=100, type=int)
    parser.add_argument("--alg", default='CMA', type=str)  # supports other advanced evolution strategies
    parser.add_argument("--projection_modeling", default='pca', type=str)  # decomposition method {'pca', 'prior_normal'}
    parser.add_argument("--model_dim", default=512, type=int)  # dim of the textual inversion
    parser.add_argument("--inversion_initialize", default='./save/initialize_emb.bin', type=str)  # path to the initialization embedding
    parser.add_argument("--seed", default=2023, type=int)
    parser.add_argument("--loss_type", default='noise', type=str)
    parser.add_argument("--cat_or_add", default='add', type=str)
    parser.add_argument("--device", default=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    parser.add_argument("--parallel", default=False, type=bool, help='Whether to allow parallel evaluation')
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default='<sks>',
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        default='coconut_seed_fruit_stage_CT',
        help="A token to use as the initializer word."
    )
    parser.add_argument(
        "--inference_framework",
        default='pt',
        type=str,
        help='''Which inference framework to use.
        Currently supports `pt` and `ort`, standing for PyTorch and Microsoft onnxruntime respectively'''
    )
    parser.add_argument(
        "--onnx_model_path",
        default=None,
        type=str,
        help='Path to your onnx model.'
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default='./coconut_seed_fruit_stage',
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--learnable_property",
        type=str,
        default="object",
        help="Choose between 'object' and 'style'"
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument("--repeats", type=int, default=5, help="How many times to repeat the training data.")
    args = parser.parse_args()

    cma_opts = {
        'seed': args.seed,
        'popsize': args.popsize,
        'maxiter': args.budget if args.parallel else args.budget // args.popsize,
        'verbose': -1,
    }
    if args.bound > 0:
        cma_opts['bounds'] = [-1 * args.bound, 1 * args.bound]

    if args.inversion_initialize is not None:
        print('initialize textual inversion')
        init_text_inversion = torch.load(args.inversion_initialize, map_location="cpu")["initialize"]
    else:
        init_text_inversion = None

    pipeline = GradientFreePipeline(model_path='./ckpt', args=args, init_text_inversion=init_text_inversion)
    es = cma.CMAEvolutionStrategy(args.intrinsic_dim * [0], args.sigma, inopts=cma_opts)
    while not es.stop():
        solutions = es.ask()  # (popsize, intrinsic_dim)
        fitnesses = pipeline.eval(solutions)
        print(fitnesses)  # loss for each point
        es.tell(solutions, fitnesses)
    pipeline.save('./save')


if __name__ == "__main__":
    main()
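To summarize my understanding: in the pca branch, self.pca_model is fit on the pipeline's own token-embedding table, so inverse_transform returns vectors of that table's width (1024 for 512-base-ema, 768 for sd-v1-4), while init_text_inversion loaded from initialize_emb.bin is the 512-dim patch16b embedding; that is exactly the size mismatch in the traceback. For now I am considering a guard in main(), right after the pipeline is built (my own addition, not from the repo):

# fail fast if the loaded initialization does not live in the
# text encoder's token-embedding space
model_dim = pipeline.text_encoder.get_input_embeddings().weight.shape[-1]
if init_text_inversion is not None and init_text_inversion.shape[-1] != model_dim:
    raise ValueError(
        f'initialize_emb.bin has dim {init_text_inversion.shape[-1]}, '
        f'but this checkpoint expects {model_dim}'
    )

Is regenerating initialize_emb.bin with the checkpoint's own text encoder (as sketched after the initialization code) the intended workflow, or am I misusing initialize_inversion.py?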