Imbalance Between Generator and Discriminator Losses in GAN Training for Super-Resolution

Question:
I am currently training a Generative Adversarial Network (GAN) for the task of image super-resolution. However, I am encountering a persistent issue where the generator’s loss decreases significantly over time, but the discriminator’s loss remains high and does not show any signs of improvement. This imbalance in training progress is preventing the discriminator from effectively distinguishing between real and generated high-resolution images.

I have attempted several strategies to address this problem, including:

Adjusting Learning Rates: I have experimented with different learning rates for both the generator and the discriminator to balance their training dynamics.

Gradient Clipping: To stabilize the training process, I have implemented gradient clipping to prevent exploding gradients.

Architecture Changes: I have modified the architectures of both the generator and the discriminator to ensure they are suitably matched for the task.

Despite these efforts, the issue remains unresolved. The generator continues to improve, producing higher-quality images, while the discriminator struggles to provide meaningful feedback.

Has anyone else faced a similar issue when training GANs for image super-resolution? If so, could you please share your insights or solutions that helped to balance the training of the generator and discriminator? Any advice on additional techniques or modifications to the training process would be greatly appreciated.

Current Training Output:

Training Epoch 0 : 100%
613/613 [05:26<00:00, 2.01it/s, disc_loss=0.506, gen_loss=0.0292]

Epoch [1/25]
Training - Generator Loss: 0.0292, Discriminator Loss: 0.5064

Testing Epoch 0 : 100%
17/17 [00:06<00:00, 3.12it/s, disc_loss=0.541, gen_loss=0.0218]

Testing - Generator Loss: 0.0218, Discriminator Loss: 0.5412
Average PSNR: 16.7563, Average SSIM: 0.6819    
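
When reading these numbers, it helps to know the chance-level value of this loss. With `BCEWithLogitsLoss`, a discriminator that outputs logit 0 (probability 0.5) for every patch scores ln 2 on both the real and the fake term, so `loss_D` at chance is ln 2 ≈ 0.693. The second half of this sketch estimates how much of the logged generator loss the adversarial term can contribute at weight `1e-3`, under the hypothetical assumption that `loss_GAN` also sits near that chance level; the per-term values are not in the post.

```python
import math

import torch
import torch.nn as nn

criterion_GAN = nn.BCEWithLogitsLoss()

# A discriminator that cannot tell real from fake outputs logit 0 (probability 0.5)
# everywhere: the loss is ln 2 on the real term and on the fake term, so the
# averaged loss_D = (loss_real + loss_fake) / 2 is also ln 2.
logits = torch.zeros(8, 1, 4, 4)
loss_real = criterion_GAN(logits, torch.ones_like(logits))
loss_fake = criterion_GAN(logits, torch.zeros_like(logits))
chance_loss_D = ((loss_real + loss_fake) / 2).item()
print(f"chance-level loss_D = {chance_loss_D:.4f}")

# Hypothetical: if loss_GAN were near ln 2, the weighted adversarial term would be
# 1e-3 * ln 2. Compare it with the epoch-1 training gen_loss from the log above.
reported_gen_loss = 0.0292
adversarial_part = 1e-3 * math.log(2)
adversarial_share = adversarial_part / reported_gen_loss
print(f"adversarial share of gen_loss ~ {adversarial_share:.1%}")
```

The logged discriminator loss of about 0.51 is below that chance level. On the stated assumption, the adversarial term would be only a few percent of the logged `gen_loss`, so that number mostly reflects the VGG content loss.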

Training Code:

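import random

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.models import vgg19, VGG19_Weights
from torchvision.utils import make_grid, save_image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm


class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13)
        vgg19_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
        # Adapt the first conv to 1-channel input by summing the pretrained RGB filters.
        # A fresh nn.Conv2d(1, 64, ...) would be randomly initialised and never trained,
        # because the VGG is not in any optimizer.
        rgb_conv = vgg19_model.features[0]
        gray_conv = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        with torch.no_grad():
            gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
            gray_conv.bias.copy_(rgb_conv.bias)
        vgg19_model.features[0] = gray_conv
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])
        # Fixed loss network: freeze it so no gradients are stored for its weights
        for p in self.feature_extractor.parameters():
            p.requires_grad = False

    def forward(self, img):
        return self.feature_extractor(img)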


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        return x + self.conv_block(x)


class GeneratorResNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()

        # First layer
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU())

        # Residual blocks
        res_blocks = []
        for _ in range(n_residual_blocks):
            res_blocks.append(ResidualBlock(64))
        self.res_blocks = nn.Sequential(*res_blocks)

        # Second conv layer post residual blocks
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(64))

        # Upsampling layers
        upsampling = []
        for _ in range(2):  # two x2 stages -> x4 upscaling overall
            upsampling += [
                # 64 -> 256 channels; PixelShuffle(2) then turns 256 channels into 64 at 2x height and width
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)

        # Final output layer
        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())


    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out


class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        self.output_shape = (1, int(in_height / (2 ** 6)), int(in_width / (2 ** 6)))

        
        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 64, 128, 256, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)


# Initialize generator and discriminator
generator = GeneratorResNet()
discriminator = Discriminator(input_shape=(channels, *hr_shape))
feature_extractor = FeatureExtractor()

# Set feature extractor to inference mode
feature_extractor.eval()

# Losses
criterion_GAN = torch.nn.BCEWithLogitsLoss()  # Changed from MSELoss to BCEWithLogitsLoss
criterion_content = torch.nn.L1Loss()

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    feature_extractor = feature_extractor.cuda()
    criterion_GAN = criterion_GAN.cuda()
    criterion_content = criterion_content.cuda()

# Load pretrained models
if load_pretrained_models:
    generator.load_state_dict(torch.load("/kaggle/working/saved_models/generator3.pth"))
    discriminator.load_state_dict(torch.load("/kaggle/working/saved_models/discriminator3.pth"))

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0004, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(b1, b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Fit
train_gen_losses, train_disc_losses, train_counter = [], [], []
test_gen_losses, test_disc_losses = [], []
test_counter = [idx*len(train_dataloader.dataset) for idx in range(1, n_epochs+1)]

for epoch in range(n_epochs):
    ### Training
    gen_loss, disc_loss = 0, 0
    psnr_values = []  # List to store PSNR values
    ssim_values = []  # List to store SSIM values
    tqdm_bar = tqdm(train_dataloader, desc=f'Training Epoch {epoch} ', total=int(len(train_dataloader)))
    for batch_idx, imgs in enumerate(tqdm_bar):
        generator.train(); discriminator.train()
        # Configure model input
        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))
        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
        
        ### Train Generator
        optimizer_G.zero_grad()
        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)
        # Adversarial loss
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())
        # Total loss
        loss_G = loss_content + 1e-3 * loss_GAN
        loss_G.backward()
        
        # Apply gradient clipping
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        
        optimizer_G.step()

        ### Train Discriminator
        optimizer_D.zero_grad()
        # Loss of real and fake images
        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
        # Total loss
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        
        # Apply gradient clipping
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
        
        optimizer_D.step()

        gen_loss += loss_G.item()
        train_gen_losses.append(loss_G.item())
        disc_loss += loss_D.item()
        train_disc_losses.append(loss_D.item())
        train_counter.append(batch_idx*batch_size + imgs_lr.size(0) + epoch*len(train_dataloader.dataset))
        tqdm_bar.set_postfix(gen_loss=gen_loss/(batch_idx+1), disc_loss=disc_loss/(batch_idx+1))

        
    print(f"Epoch [{epoch+1}/{n_epochs}]")
    print(f"Training - Generator Loss: {gen_loss/len(train_dataloader):.4f}, Discriminator Loss: {disc_loss/len(train_dataloader):.4f}")

    # Testing
    gen_loss, disc_loss = 0, 0
    psnr_values = []  # List to store PSNR values
    ssim_values = []  # List to store SSIM values
    tqdm_bar = tqdm(test_dataloader, desc=f'Testing Epoch {epoch} ', total=int(len(test_dataloader)))
    for batch_idx, imgs in enumerate(tqdm_bar):
        generator.eval(); discriminator.eval()
        # Configure model input
        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))
        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
        
        ### Eval Generator
        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)
        
        for i in range(imgs_hr.size(0)):
            gen_hr_numpy = gen_hr[i].permute(1, 2, 0).detach().cpu().numpy()
            gt_hr_numpy = imgs_hr[i].permute(1, 2, 0).detach().cpu().numpy()
            gen_hr_numpy = np.squeeze(gen_hr_numpy)
            gt_hr_numpy = np.squeeze(gt_hr_numpy)
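
            # Clip before casting: converting out-of-range floats (e.g. negative values from
            # the Tanh output) to uint8 does not saturate and produces garbage pixel values.
            # This assumes images are in [0, 1]; if the data pipeline normalises to [-1, 1],
            # map with (x + 1) / 2 before clipping.
            gen_hr_numpy = (np.clip(gen_hr_numpy, 0, 1) * 255).astype(np.uint8)
            gt_hr_numpy = (np.clip(gt_hr_numpy, 0, 1) * 255).astype(np.uint8)

            psnr_values.append(peak_signal_noise_ratio(gt_hr_numpy, gen_hr_numpy, data_range=255))
            # `multichannel` is deprecated in scikit-image and unnecessary for 2-D grayscale arrays
            ssim_values.append(structural_similarity(gt_hr_numpy, gen_hr_numpy, data_range=255))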

        # Adversarial loss
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())
        # Total loss
        loss_G = loss_content + 1e-3 * loss_GAN

        ### Eval Discriminator
        # Loss of real and fake images
        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        gen_loss += loss_G.item()
        disc_loss += loss_D.item()
        tqdm_bar.set_postfix(gen_loss=gen_loss/(batch_idx+1), disc_loss=disc_loss/(batch_idx+1))
        
        # Save image grid with upsampled inputs and SRGAN outputs
        if random.uniform(0,1)<0.1:
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            imgs_hr = make_grid(imgs_hr, nrow=1, normalize=True)
            gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
            imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
            img_grid = torch.cat((imgs_hr, imgs_lr, gen_hr), -1)
            save_image(img_grid, f"images3/epoch_{epoch}_batch_{batch_idx}.png", normalize=False)

    test_gen_losses.append(gen_loss/len(test_dataloader))
    test_disc_losses.append(disc_loss/len(test_dataloader))
    avg_psnr = sum(psnr_values) / len(psnr_values)
    avg_ssim = sum(ssim_values) / len(ssim_values)

    print(f"Testing - Generator Loss: {gen_loss/len(test_dataloader):.4f}, Discriminator Loss: {disc_loss/len(test_dataloader):.4f}")
    print(f"Average PSNR: {avg_psnr:.4f}, Average SSIM: {avg_ssim:.4f}")
    print("--------------------")
    
    # Save model checkpoints
    if np.argmin(test_gen_losses) == len(test_gen_losses)-1:
        torch.save(generator.state_dict(), "saved_models/generator4.pth")
        torch.save(discriminator.state_dict(), "saved_models/discriminator4.pth")
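
A small sanity check of the shape arithmetic in the code above, as a sketch: the helper functions are hypothetical, and `conv_out` is the output-size formula from the PyTorch `nn.Conv2d` documentation (dilation 1). Each `discriminator_block` shrinks the map only through its stride-2 conv, and the generator's two `PixelShuffle(2)` stages give the x4 that matches `interpolate(..., scale_factor=4)` in the test loop.

```python
import torch
import torch.nn as nn


def conv_out(size, kernel_size=3, stride=2, padding=1):
    """Output size of nn.Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def discriminator_out(size, n_blocks=6):
    """Spatial size after the six discriminator blocks.

    Each block has a stride-1 conv (size unchanged) and a stride-2 conv
    (kernel 3, padding 1); only the stride-2 convs shrink the feature map.
    """
    for _ in range(n_blocks):
        size = conv_out(size)
    return size


for hr_size in (256, 96):
    actual = discriminator_out(hr_size)
    assumed = int(hr_size / (2 ** 6))  # formula used for Discriminator.output_shape
    print(f"HR {hr_size}: real output {actual}, output_shape assumes {assumed}")

# Generator: each PixelShuffle(2) turns (C*4, H, W) into (C, 2H, 2W); two stages give x4.
shuffled = nn.PixelShuffle(2)(torch.zeros(1, 256, 8, 8))
print(tuple(shuffled.shape))  # (1, 64, 16, 16)
```

By this arithmetic, `int(h / 2**6)` matches the real discriminator output only when the HR height and width are multiples of 64 (a 96x96 crop gives a 2x2 output, while `output_shape` says 1x1, and `BCEWithLogitsLoss` would reject the mismatched targets), so it is worth keeping `hr_shape`, which is not shown in the post, at multiples of 64.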

Adjusting Learning Rates: I experimented with different learning rates for both the generator and the discriminator to find a balance that would allow both networks to learn effectively without one overpowering the other. I expected that by tuning these rates, the discriminator would start to improve its ability to distinguish real from generated images.
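
The posted code gives the generator a 4x larger learning rate than the discriminator (`lr=0.0004` vs `lr=0.0001`). When comparing ratios, one option is to change the rate in place through `optimizer.param_groups`, which keeps Adam's moment estimates instead of recreating the optimizers. A minimal sketch with tiny hypothetical stand-in modules and assumed `betas=(0.9, 0.999)`, since `b1` and `b2` are not shown in the post:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the real networks, only so the optimizers have parameters.
generator = nn.Conv2d(1, 1, 3, padding=1)
discriminator = nn.Conv2d(1, 1, 3, padding=1)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=4e-4, betas=(0.9, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.9, 0.999))


def set_lr(optimizer, lr):
    """Change the learning rate in place, keeping the optimizer's internal state."""
    for group in optimizer.param_groups:
        group["lr"] = lr


def lr_of(optimizer):
    return optimizer.param_groups[0]["lr"]


print("G:D learning-rate ratio =", lr_of(optimizer_G) / lr_of(optimizer_D))

# Example comparison run: give the discriminator the larger step size instead.
set_lr(optimizer_D, 4e-4)
set_lr(optimizer_G, 1e-4)
print("G:D learning-rate ratio =", lr_of(optimizer_G) / lr_of(optimizer_D))
```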

Gradient Clipping: To stabilize the training process, I implemented gradient clipping to prevent exploding gradients, which can disrupt the learning process. I expected this to help the discriminator learn more effectively by ensuring smoother updates.
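
To see whether `max_norm=1.0` is actually reached, the value returned by `torch.nn.utils.clip_grad_norm_` can be logged: it is the total gradient norm measured before clipping. A toy sketch with a hypothetical `nn.Linear` stand-in and deliberately large inputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 1)          # hypothetical stand-in for the discriminator
x = torch.randn(32, 16) * 100.0   # large inputs -> large gradients
loss = model(x).pow(2).mean()
loss.backward()

max_norm = 1.0
# clip_grad_norm_ rescales the gradients in place and returns the total norm
# measured *before* clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
clipped = total_norm.item() > max_norm

# Total norm after clipping, recomputed from the rescaled gradients.
after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()])).item()
print(f"norm before: {total_norm.item():.2f}, after: {after:.2f}, clipped: {clipped}")
```

In the training loop this would mean keeping the return value of the existing `clip_grad_norm_` calls and tracking how often it exceeds 1.0 for each network.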

Architecture Changes: I modified the architectures of both the generator and the discriminator, ensuring that they were suitably matched for the task. I expected that by making the networks more balanced, the discriminator would be able to catch up in terms of learning progress.
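
One crude but easy-to-check proxy for whether the two networks are matched is their trainable parameter count. A sketch; `count_params` is a hypothetical helper, not a PyTorch API:

```python
import torch.nn as nn


def count_params(module: nn.Module, trainable_only: bool = True) -> int:
    """Number of parameters in a module (only trainable ones by default)."""
    return sum(p.numel() for p in module.parameters()
               if p.requires_grad or not trainable_only)


# With the real models this would be, e.g.:
#   count_params(generator), count_params(discriminator)
layer = nn.Linear(10, 5)   # 10*5 weights + 5 biases
print(count_params(layer))  # 55
```

Parameter count says nothing about receptive field or normalisation choices, so it is at best a first check.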

What actually happened is that the generator's loss kept decreasing significantly, which suggested it was learning to produce better high-resolution images. However, the discriminator's loss stayed high and did not improve, suggesting it was not learning to distinguish real from generated images. This imbalance is preventing the GAN from producing high-quality super-resolution images.
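
Because a loss value alone is hard to interpret, the discriminator's accuracy can also be tracked directly. This sketch thresholds the raw logits at 0, which corresponds to probability 0.5 under the sigmoid inside `BCEWithLogitsLoss`; the function name and the toy logits are hypothetical:

```python
import torch


def discriminator_accuracy(real_logits: torch.Tensor,
                           fake_logits: torch.Tensor) -> tuple[float, float]:
    """Fraction of patch outputs classified correctly, for real and for generated images."""
    real_acc = (real_logits > 0).float().mean().item()
    fake_acc = (fake_logits < 0).float().mean().item()
    return real_acc, fake_acc


# Toy logits standing in for discriminator(imgs_hr) and discriminator(gen_hr.detach()).
real_logits = torch.tensor([[2.0, -0.5], [1.0, 3.0]])
fake_logits = torch.tensor([[-1.0, 0.5], [-2.0, -0.1]])
print(discriminator_accuracy(real_logits, fake_logits))  # (0.75, 0.75)
```

In the training loop, the inputs would be the outputs already computed for `loss_real` and `loss_fake`.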
