I'm writing my diploma project: generating photorealistic images with a GAN. I decided to use a BigGAN-style architecture to generate human faces from noise. Here is my code:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
class SelfAttn(nn.Module):
    """Self-attention layer (SAGAN-style)."""
    def __init__(self, in_channels, eps=1e-12):
        super(SelfAttn, self).__init__()
        self.in_channels = in_channels
        # Note: the "snconv" names suggest spectral norm, but none is applied
        # here, and eps is currently unused.
        self.snconv1x1_theta = nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 8,
                                         kernel_size=1, bias=False)
        self.snconv1x1_phi = nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 8,
                                       kernel_size=1, bias=False)
        self.snconv1x1_g = nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2,
                                     kernel_size=1, bias=False)
        self.snconv1x1_o_conv = nn.Conv2d(in_channels=in_channels // 2, out_channels=in_channels,
                                          kernel_size=1, bias=False)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        _, ch, h, w = x.size()
        # Theta path: queries, (B, ch//8, h*w)
        theta = self.snconv1x1_theta(x)
        theta = theta.view(-1, ch // 8, h * w)
        # Phi path: keys, max-pooled to (B, ch//8, h*w//4)
        phi = self.snconv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(-1, ch // 8, h * w // 4)
        # Attention map: (B, h*w, h*w//4)
        attn = torch.bmm(theta.permute(0, 2, 1), phi)
        attn = self.softmax(attn)
        # g path: values, max-pooled to (B, ch//2, h*w//4)
        g = self.snconv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(-1, ch // 2, h * w // 4)
        # Aggregate values with the attention map, then project back to ch channels
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(-1, ch // 2, h, w)
        attn_g = self.snconv1x1_o_conv(attn_g)
        # Residual connection scaled by a learned gamma (initialized to 0)
        out = x + self.gamma * attn_g
        return out
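As a quick sanity check (my own throwaway snippet, not part of the model), SelfAttn should return a tensor of the same shape it receives; this mirrors how it is used at the 4x4 stage below:

# SelfAttn preserves the input shape: (2, 1536, 4, 4) in and out.
attn = SelfAttn(96 * 16)
dummy = torch.randn(2, 96 * 16, 4, 4)
assert attn(dummy).shape == dummy.shape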
class GenBlock(nn.Module):
    def __init__(self, in_size, out_size, reduction_factor=4, up_sample=False,
                 eps=1e-12):
        super(GenBlock, self).__init__()
        self.up_sample = up_sample
        self.drop_channels = (in_size != out_size)
        middle_size = in_size // reduction_factor
        self.conv_0 = nn.Conv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1)
        self.conv_1 = nn.Conv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1)
        self.conv_3 = nn.Conv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x0 = x
        x = self.relu(x)
        x = self.conv_0(x)
        x = self.relu(x)
        if self.up_sample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv_1(x)
        x = self.relu(x)
        x = self.conv_2(x)
        x = self.relu(x)
        x = self.conv_3(x)
        if self.drop_channels:
            # Skip connection: keep the first half of the input channels
            # (assumes out_size == in_size // 2, as in the blocks below).
            new_channels = x0.shape[1] // 2
            x0 = x0[:, :new_channels, ...]
        if self.up_sample:
            x0 = F.interpolate(x0, scale_factor=2, mode='nearest')
        out = x + x0
        return out
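Likewise for GenBlock (another throwaway check): with up_sample=True the spatial size doubles, and when out_size is half of in_size the skip connection drops the second half of the input channels:

# GenBlock(384 -> 192, up_sample=True): channels halve, H and W double.
block = GenBlock(96 * 4, 96 * 2, up_sample=True)
y = block(torch.randn(2, 96 * 4, 4, 4))
print(y.shape)  # torch.Size([2, 192, 8, 8])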
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.gen_z = nn.Linear(in_features=128, out_features=4 * 4 * 16 * 96)
        layers = []
        layers.append(SelfAttn(96 * 16))
        layers.append(GenBlock(96 * 16, 96 * 8))
        layers.append(GenBlock(96 * 8, 96 * 4))
        layers.append(GenBlock(96 * 4, 96 * 2, up_sample=True))  # the only spatial upsampling
        layers.append(GenBlock(96 * 2, 96))
        self.layers = nn.ModuleList(layers)
        self.conv_to_rgb = nn.Conv2d(in_channels=96, out_channels=3, kernel_size=3, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, z):
        z = self.gen_z(z)
        z = z.view(-1, 4, 4, 16 * 96)
        z = z.permute(0, 3, 1, 2).contiguous()
        for layer in self.layers:
            z = layer(z)
        z = self.conv_to_rgb(z)
        z = self.tanh(z)
        return z
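Tracing the whole generator shows the problem I describe at the end: the latent is reshaped to 4x4 and only one block has up_sample=True, so fake images come out at 8x8 rather than the 64x64 the real images use:

# Only one spatial doubling: 4x4 -> 8x8, so fakes are 8x8 images.
g = Generator()
fake = g(torch.randn(2, 128))
print(fake.shape)  # torch.Size([2, 3, 8, 8])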
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # All layers are plain 3x3 convolutions with stride 1 and padding 1,
        # so the spatial resolution is never reduced.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, padding=1)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.leaky_relu(self.conv1(x))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))
        x = self.leaky_relu(self.conv4(x))
        x = self.leaky_relu(self.conv5(x))
        x = self.leaky_relu(self.conv6(x))
        # Output is a per-pixel probability map of shape (B, 1, H, W), not a scalar
        x = self.sigmoid(self.conv7(x))
        return x
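Because every layer is stride 1, the discriminator returns a probability map with the same height and width as its input instead of one scalar per image:

# Output spatial size always equals input spatial size.
d = Discriminator()
print(d(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 1, 64, 64])
print(d(torch.randn(2, 3, 8, 8)).shape)    # torch.Size([2, 1, 8, 8])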
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_list = os.listdir(root_dir)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_list[idx])
        image = Image.open(img_name).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
discriminator = Discriminator().to(device)
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # real images go in at 64x64
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = CustomDataset(root_dir='/content/Humans', transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
num_epochs = 10
batch_size = 64

for epoch in range(num_epochs):
    for i, batch in tqdm(enumerate(dataloader, 0), total=len(dataloader),
                         desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False):
        # ---- Discriminator step ----
        discriminator.zero_grad()
        real_images = batch.to(device)                  # (B, 3, 64, 64)
        z = torch.randn(batch_size, 128).to(device)     # always 64, even on the last batch
        fake_images = generator(z).detach()             # (64, 3, 8, 8)
        real_outputs = discriminator(real_images)       # (B, 1, 64, 64)
        fake_outputs = discriminator(fake_images)       # (64, 1, 8, 8)
        real_labels = torch.ones_like(real_outputs).to(device)
        fake_labels = torch.zeros_like(fake_outputs).to(device)
        d_loss_real = criterion(real_outputs, real_labels)
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step ----
        generator.zero_grad()
        z = torch.randn(batch_size, 128).to(device)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        # Debugging: the two shapes never match here
        print(outputs.shape)
        print('-----------------------------------------------------')
        print(real_labels.shape)
        # My attempt to force the label tensor into the output's shape --
        # this is the line that crashes:
        real_labels_reshaped = real_labels[:, :, :8, :8].view(outputs.shape)
        g_loss = criterion(outputs, real_labels_reshaped)
        g_loss.backward()
        optimizer_G.step()

        if i % 100 == 0:
            print(f"[{epoch + 1}/{num_epochs}][{i}/{len(dataloader)}] "
                  f"Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}")

    torch.save(generator.state_dict(), f"generator_epoch_{epoch + 1}.pth")
    torch.save(discriminator.state_dict(), f"discriminator_epoch_{epoch + 1}.pth")

print("Training finished")
I always run into the same issue: the output and label shapes don't match, so training never makes it through the first epoch. Please help a student out.
I've tried reshaping the output after getting it, reshaping real_labels, and changing the Generator and Discriminator architectures, but the error persists and I couldn't solve it on my own. My shape trace and the fix I'm currently attempting are below.
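Here is the situation as I understand it, plus what I'm trying now (unverified, I may be wrong about this): real images are 64x64 but fakes are 8x8, and z always uses the fixed batch_size=64, so on the smaller final batch the .view() call fails.

# On the last, smaller batch (say 37 images remain):
#   real_images: (37, 3, 64, 64) -> real_outputs/real_labels: (37, 1, 64, 64)
#   fake_images: (64, 3, 8, 8)   -> outputs:                  (64, 1, 8, 8)
# real_labels[:, :, :8, :8] is (37, 1, 8, 8) and cannot be .view()-ed to (64, 1, 8, 8).
# My attempted fix: match z to the real batch size and build the generator's
# labels from the discriminator output itself instead of slicing real_labels.
z = torch.randn(real_images.size(0), 128, device=device)
fake_images = generator(z)
outputs = discriminator(fake_images)
g_loss = criterion(outputs, torch.ones_like(outputs))
# (Reaching 64x64 fakes would also need four doublings from 4x4, i.e. more
# blocks with up_sample=True -- is that the right way to do it?)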