I’m working on testing a U-Net architecture by training it on a “dataset” consisting of a single image-to-image training example. The input image is a noisy version of the output image. Initially, the model’s output starts to look more like the desired output, but then the loss curve plateaus and the model stops improving.
My questions are:
- Should a U-Net (or any CNN without fully connected layers) be able to overfit a constant target image when given a single training example?
- What are common mistakes or things to look at if it can’t manage this simple task?
I’ve simplified the architecture by reducing the depth to essentially a single double-conv block without any encoder/decoder stages, and I’ve tuned the learning rate, but the model still fails to perfectly overfit the single training example.
Input image (zoomed crop)
Ground truth (zoomed crop)
Model output after hitting the plateau (zoomed crop)
Here are the specifics of my experiment:
Architecture (Depth 1):
UNet(
  (encoders): ModuleList(
    (0): Sequential(
      (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
  )
  (decoders): ModuleList(
    (0): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    (1): Sequential(
      (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bottleneck): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (out_conv): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
)
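To make the data flow explicit, here is a minimal sketch of how I assume these modules are wired for depth 1 (forward_depth1 is illustrative, not my exact forward code; the shape comments assume an even-sized input of shape (N, 4, H, W)):

import torch

def forward_depth1(model, x: torch.Tensor) -> torch.Tensor:
    # Encoder double conv at full resolution; kept as the skip connection
    skip = model.encoders[0](x)                       # (N, 64, H, W)
    # Downsample, then bottleneck double conv
    mid = model.bottleneck(model.pool(skip))          # (N, 128, H/2, W/2)
    # Transposed conv back up to the encoder resolution
    up = model.decoders[0](mid)                       # (N, 64, H, W)
    # Pad/crop to the skip size and concatenate along channels
    merged = model.crop_and_concat(up, skip)          # (N, 128, H, W)
    # Decoder double conv, then 1x1 projection back to 4 channels
    return model.out_conv(model.decoders[1](merged))  # (N, 4, H, W)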
Initialization: Kaiming Normal
def _initialize_weights(self) -> None:
    for m in self.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
Upsampling Method: Crop and Concat
def crop_and_concat(self, upsampled: torch.Tensor,
                    bypass: torch.Tensor) -> torch.Tensor:
    # Pad (or crop) the upsampled tensor to the spatial size of the skip
    # connection, then concatenate along the channel dimension.
    diffY = bypass.size()[2] - upsampled.size()[2]
    diffX = bypass.size()[3] - upsampled.size()[3]
    upsampled = F.pad(upsampled, (diffX // 2, diffX - diffX // 2,
                                  diffY // 2, diffY - diffY // 2))
    return torch.cat((upsampled, bypass), dim=1)
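For illustration, a quick shape check of this helper (dummy tensors; model stands for an instance of the UNet above):

import torch

up = torch.zeros(1, 64, 63, 63)    # upsampled map, one pixel short after an odd-sized input
skip = torch.zeros(1, 64, 64, 64)  # skip connection from the encoder
merged = model.crop_and_concat(up, skip)
print(merged.shape)                # torch.Size([1, 128, 64, 64])

Since F.pad receives the difference bypass - upsampled, the upsampled tensor is padded out (or cropped, if it is larger) to match the skip connection before the channel-wise concatenation.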
Loss: MSELoss
Optimizer: Adam with learning rates from 1e-3 to 1e-4 (anything above 1e-3 oscillates; anything below 1e-4 also hits a plateau, just converges more slowly)
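For completeness, the training loop is essentially the standard single-example overfitting check. The sketch below uses random placeholder tensors and a hypothetical UNet() constructor in place of my actual data and model setup:

import torch
import torch.nn as nn

model = UNet()                  # placeholder for the depth-1 model printed above
x = torch.rand(1, 4, 128, 128)  # noisy input (stand-in for the real image)
y = torch.rand(1, 4, 128, 128)  # clean target (stand-in for the real image)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for step in range(5000):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    if step % 500 == 0:
        print(step, loss.item())

With a single example and no regularization, I would expect the loss to keep falling toward zero; instead it flattens out well above that.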