Question:
I am currently training a Generative Adversarial Network (GAN) for image super-resolution. However, I am encountering a persistent issue: the generator's loss decreases significantly over time, but the discriminator's loss remains high and shows no sign of improvement. As a result, the discriminator never learns to reliably distinguish real high-resolution images from generated ones, so the adversarial feedback it gives the generator is of little use.
I have attempted several strategies to address this problem (a condensed sketch of how the first two are wired is shown right after this list), including:
Adjusting Learning Rates: I have experimented with different learning rates for both the generator and the discriminator to balance their training dynamics.
Gradient Clipping: To stabilize the training process, I have implemented gradient clipping to prevent exploding gradients.
Architecture Changes: I have modified the architectures of both the generator and the discriminator to ensure they are suitably matched for the task.
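For reference, this is a condensed sketch of how the first two changes are currently wired. It is lifted from the full listing under "Training Code:" below, so the optimizer, model, and loss names match; b1 and b2 are the Adam betas defined elsewhere in my script.

# Separate Adam optimizers with different learning rates for G and D
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0004, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(b1, b2))

# Gradient-norm clipping applied between backward() and step() for each network
loss_G.backward()
torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
optimizer_G.step()

loss_D.backward()
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
optimizer_D.step()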
Despite these efforts, the issue remains unresolved. The generator continues to improve, producing higher quality images, while the discriminator struggles to provide meaningful feedback.
Has anyone else faced a similar issue when training GANs for image super-resolution? If so, could you please share your insights or solutions that helped to balance the training of the generator and discriminator? Any advice on additional techniques or modifications to the training process would be greatly appreciated.
Current Training Output:
Training Epoch 0 : 100%
613/613 [05:26<00:00, 2.01it/s, disc_loss=0.506, gen_loss=0.0292]
Epoch [1/25]
Training - Generator Loss: 0.0292, Discriminator Loss: 0.5064
Testing Epoch 0 : 100%
17/17 [00:06<00:00, 3.12it/s, disc_loss=0.541, gen_loss=0.0218]
Testing - Generator Loss: 0.0218, Discriminator Loss: 0.5412
Average PSNR: 16.7563, Average SSIM: 0.6819
Training Code:
import random

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.models import vgg19
from torchvision.utils import make_grid, save_image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm


class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        # Replace the first conv so the VGG features accept 1-channel (grayscale) input
        vgg19_model.features[0] = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])

    def forward(self, img):
        return self.feature_extractor(img)
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        return x + self.conv_block(x)
class GeneratorResNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()
        # First layer
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU())
        # Residual blocks
        res_blocks = []
        for _ in range(n_residual_blocks):
            res_blocks.append(ResidualBlock(64))
        self.res_blocks = nn.Sequential(*res_blocks)
        # Second conv layer post residual blocks
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(64))
        # Upsampling layers (two 2x PixelShuffle stages -> 4x total upscaling)
        upsampling = []
        for _ in range(2):
            upsampling += [
                # nn.Upsample(scale_factor=2),
                nn.Conv2d(64, 256, 3, 1, 1),  # 256 channels = 64 * 2**2, as required by PixelShuffle(2)
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)
        # Final output layer
        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        # Six stride-2 convolutions downsample each spatial dimension by 2**6 = 64
        self.output_shape = (1, int(in_height / (2 ** 6)), int(in_width / (2 ** 6)))

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 64, 128, 256, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters
        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
# Initialize generator and discriminator
generator = GeneratorResNet()
discriminator = Discriminator(input_shape=(channels, *hr_shape))
feature_extractor = FeatureExtractor()
# Set feature extractor to inference mode
feature_extractor.eval()

# Losses
criterion_GAN = torch.nn.BCEWithLogitsLoss()  # Changed from MSELoss to BCEWithLogitsLoss
criterion_content = torch.nn.L1Loss()

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    feature_extractor = feature_extractor.cuda()
    criterion_GAN = criterion_GAN.cuda()
    criterion_content = criterion_content.cuda()

# Load pretrained models
if load_pretrained_models:
    generator.load_state_dict(torch.load("/kaggle/working/saved_models/generator3.pth"))
    discriminator.load_state_dict(torch.load("/kaggle/working/saved_models/discriminator3.pth"))

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0004, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(b1, b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Fit
train_gen_losses, train_disc_losses, train_counter = [], [], []
test_gen_losses, test_disc_losses = [], []
test_counter = [idx * len(train_dataloader.dataset) for idx in range(1, n_epochs + 1)]
for epoch in range(n_epochs):
    ### Training
    gen_loss, disc_loss = 0, 0
    psnr_values = []  # List to store PSNR values
    ssim_values = []  # List to store SSIM values
    tqdm_bar = tqdm(train_dataloader, desc=f'Training Epoch {epoch} ', total=int(len(train_dataloader)))
    for batch_idx, imgs in enumerate(tqdm_bar):
        generator.train(); discriminator.train()
        # Configure model input
        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))
        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)

        ### Train Generator
        optimizer_G.zero_grad()
        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)
        # Adversarial loss
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())
        # Total loss
        loss_G = loss_content + 1e-3 * loss_GAN
        loss_G.backward()
        # Apply gradient clipping
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        optimizer_G.step()

        ### Train Discriminator
        optimizer_D.zero_grad()
        # Loss of real and fake images
        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
        # Total loss
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        # Apply gradient clipping
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
        optimizer_D.step()

        gen_loss += loss_G.item()
        train_gen_losses.append(loss_G.item())
        disc_loss += loss_D.item()
        train_disc_losses.append(loss_D.item())
        train_counter.append(batch_idx * batch_size + imgs_lr.size(0) + epoch * len(train_dataloader.dataset))
        tqdm_bar.set_postfix(gen_loss=gen_loss / (batch_idx + 1), disc_loss=disc_loss / (batch_idx + 1))

    print(f"Epoch [{epoch+1}/{n_epochs}]")
    print(f"Training - Generator Loss: {gen_loss/len(train_dataloader):.4f}, Discriminator Loss: {disc_loss/len(train_dataloader):.4f}")
    # Testing
    gen_loss, disc_loss = 0, 0
    psnr_values = []  # List to store PSNR values
    ssim_values = []  # List to store SSIM values
    tqdm_bar = tqdm(test_dataloader, desc=f'Testing Epoch {epoch} ', total=int(len(test_dataloader)))
    for batch_idx, imgs in enumerate(tqdm_bar):
        generator.eval(); discriminator.eval()
        # Configure model input
        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))
        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)

        ### Eval Generator
        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)
        for i in range(imgs_hr.size(0)):
            gen_hr_numpy = gen_hr[i].permute(1, 2, 0).detach().cpu().numpy()
            gt_hr_numpy = imgs_hr[i].permute(1, 2, 0).detach().cpu().numpy()
            gen_hr_numpy = np.squeeze(gen_hr_numpy)
            gt_hr_numpy = np.squeeze(gt_hr_numpy)
            # print(f"Generated image shape: {gen_hr_numpy.shape}")
            # print(f"Ground truth image shape: {gt_hr_numpy.shape}")
            gen_hr_numpy = (gen_hr_numpy * 255).astype(np.uint8)  # Scale to [0, 255] range
            gt_hr_numpy = (gt_hr_numpy * 255).astype(np.uint8)  # Scale to [0, 255] range
            psnr_values.append(peak_signal_noise_ratio(gt_hr_numpy, gen_hr_numpy))
            ssim_values.append(structural_similarity(gt_hr_numpy, gen_hr_numpy, multichannel=False))

        # Adversarial loss
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())
        # Total loss
        loss_G = loss_content + 1e-3 * loss_GAN

        ### Eval Discriminator
        # Loss of real and fake images
        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        gen_loss += loss_G.item()
        disc_loss += loss_D.item()
        tqdm_bar.set_postfix(gen_loss=gen_loss / (batch_idx + 1), disc_loss=disc_loss / (batch_idx + 1))

        # Save image grid with upsampled inputs and SRGAN outputs
        if random.uniform(0, 1) < 0.1:
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            imgs_hr = make_grid(imgs_hr, nrow=1, normalize=True)
            gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
            imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
            img_grid = torch.cat((imgs_hr, imgs_lr, gen_hr), -1)
            save_image(img_grid, f"images3/epoch_{epoch}_batch_{batch_idx}.png", normalize=False)

    test_gen_losses.append(gen_loss / len(test_dataloader))
    test_disc_losses.append(disc_loss / len(test_dataloader))
    avg_psnr = sum(psnr_values) / len(psnr_values)
    avg_ssim = sum(ssim_values) / len(ssim_values)
    print(f"Testing - Generator Loss: {gen_loss/len(test_dataloader):.4f}, Discriminator Loss: {disc_loss/len(test_dataloader):.4f}")
    print(f"Average PSNR: {avg_psnr:.4f}, Average SSIM: {avg_ssim:.4f}")
    print("--------------------")

    # Save model checkpoints
    if np.argmin(test_gen_losses) == len(test_gen_losses) - 1:
        torch.save(generator.state_dict(), "saved_models/generator4.pth")
        torch.save(discriminator.state_dict(), "saved_models/discriminator4.pth")
What I Tried and What I Expected:
Adjusting Learning Rates: I experimented with different learning rates for the generator and the discriminator to find a balance that lets both networks learn without one overpowering the other. I expected that tuning these rates would let the discriminator start improving its ability to distinguish real from generated images.
Gradient Clipping: To stabilize training, I implemented gradient clipping to prevent exploding gradients, which can disrupt learning. I expected the smoother updates to help the discriminator learn more effectively.
Architecture Changes: I modified the architectures of both the generator and the discriminator so that they are suitably matched for the task. I expected that with more balanced networks, the discriminator would be able to catch up in terms of learning progress.
What actually happened: the generator's loss continued to decrease significantly, indicating that it is learning to produce better high-resolution images, while the discriminator's loss stayed high and did not improve, suggesting that it is not learning to distinguish real from generated images. This imbalance is preventing the GAN from reaching its goal of producing high-quality super-resolution images.
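To make the failure mode easier to pin down, here is a minimal diagnostic sketch (not part of my current code) that could be dropped into the training loop above. The names discriminator, imgs_hr, and gen_hr are the same as in my code, and since criterion_GAN is BCEWithLogitsLoss, the discriminator outputs raw logits, so a sigmoid is applied before averaging.

# Hypothetical diagnostic, not in the training loop above: report the
# discriminator's mean probability on real vs. generated batches.
with torch.no_grad():
    p_real = torch.sigmoid(discriminator(imgs_hr)).mean().item()
    p_fake = torch.sigmoid(discriminator(gen_hr.detach())).mean().item()
# A discriminator stuck at chance keeps both values near 0.5; one that is
# actually learning pushes p_real toward 1 and p_fake toward 0.
print(f"D(real)={p_real:.3f}  D(fake)={p_fake:.3f}")

If anyone has seen this pattern before, I would also appreciate knowing which of these quantities you monitor when diagnosing a stalled discriminator.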