Hi, I am training a ViT backbone (patch 16, 224, ImageNet pretrained) on satellite imagery (the Million-AID dataset, ~900,000 images of varying sizes) in a self-supervised fashion using a Masked Autoencoder (MAE). The loss I get during training is very volatile, and I wanted to confirm whether this behavior is normal when training a ViT, or specifically when training a ViT with the MAE approach.
I have pasted the code below in case I have made a mistake. I am using lightly-ssl to ease training the MAE.
<code># ==============================================================================
# MAE Implementation : THE PAPER HAS TWO DIFFERENT SETTINGS FOR PRETRAINING & FINETUNING
# What the paper uses | Any changes that deviate from paper
# Backbone : ViT-L/16 | ViT-B/16
# Decoder : depth - 8 blocks, width - 512d
# Encoder : w/o mask tokens
# Optimizer
# linear lr scaling (cosine decay) : lr = base_lr x bs / 256, base_lr = 1.5e-4/1e-3, warmup = 40/5 epochs
# AdamW : beta1,2 = 0.9,0.95 / 0.9,0.999
# Batch Size : 4096/1024 | 1024/1024
# Loss func : MSE
# Masking : 75%
# Augmentation : RandomResizedCrop
# Crop Size : 224x224
# Fine tuning : 50 epochs vs Pretraining : 200 epochs
# Note that these are for ViT-L
# ==============================================================================
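# Worked example of the scaling rule above (illustrative numbers only, assuming
# base_lr = 1.5e-4 and an effective batch size of 1024):
#   lr = 1.5e-4 * 1024 / 256 = 6e-4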
mae_params = config["mae_params"]
class MaeBBViT(pl.LightningModule):
def __init__(self, model_params:dict, backbone:VisionTransformer):
"""
Inputs
- model_params: dictionary containing model parameters such as lr, batch_size
        - backbone : ViT model with the last (classification) layer removed
"""
super().__init__()
# Saving hyperparameters
hyper_dict = {}
hyper_dict.update(mae_params)
hyper_dict.update(model_params)
self.save_hyperparameters(hyper_dict)
self.model_params = model_params
self.backbone = backbone # we will need this later
self.mask_ratio = mae_params["mask_ratio"] # 0.75
self.patch_size = backbone.patch_embed.patch_size[0] #(16,16) so index [0] returns 16
self.masked_encoder = MaskedVisionTransformerTIMM(vit = backbone)
self.sequence_length = self.masked_encoder.sequence_length #197
self.decoder = MAEDecoderTIMM(
num_patches = backbone.patch_embed.num_patches, #196
            patch_size = self.patch_size, # 16
embed_dim = backbone.embed_dim, #768
decoder_embed_dim = mae_params["decoder_dim"], #512
decoder_depth = mae_params["decoder_depth"], #8
mlp_ratio = 4.0,
            proj_drop_rate = 0.0, # dropout rate for the projection/MLP outputs
            attn_drop_rate = 0.0  # dropout rate inside the attention layers
)
self.criterion = nn.MSELoss()
        self.apply_lr_scheduler = not model_params["lr"]
def forward_encoder(self, images, idx_keep = None):
# shape returned from .encode = (bs, num_unmasked_patches, embed size)
        # Full images go in; idx_keep selects the unmasked tokens that the encoder actually processes
        # With a 0.75 mask ratio, 50 of the 197 tokens (196 patches + CLS) are kept, giving the shape below
return self.masked_encoder.encode(images = images, idx_keep = idx_keep) #(*,3,224,224) -> (*,50,768)
def forward_decoder(self, x_encoded, idx_keep, idx_mask):
# Build decoder input
batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded) # (*,50,512) where 50 is the number of unmasked tokens and 512 is the decoder dim declared above
x_masked = repeat_token(
self.decoder.mask_token,
(batch_size, self.sequence_length)
        ) # (*,197,512) where 197 is the full sequence length (CLS + 196 patches) and 512 is the decoder dim
x_masked = set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked)) #(*,197,512)
# decoder forward pass
x_decoded = self.decoder.decode(x_masked) #(*,197,512)
# Predict pixel values for masked tokens
x_pred = get_at_index(x_decoded, idx_mask) #(*,147,512)
        x_pred = self.decoder.predict(x_pred) # (*,147,768) where 768 = 16*16*3 predicted pixel values per patch
return x_pred
def training_step(self, batch, batch_idx):
X,y,f = batch
        images = X[0] # (*,3,224,224); there is only a single view, but it is wrapped in a list
batch_size = images.shape[0]
        # These are the kept and masked token indices. Remember there are 197 tokens including the
        # CLS token (14*14 + 1); 14 = 224/16 patches per side, each patch being 16x16 pixels
idx_keep, idx_mask = random_token_mask(
size = (batch_size, self.sequence_length),
mask_ratio = self.mask_ratio,
device = images.device
) #(*,50),(*,147)
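        # Sanity check on the shapes above (my numbers, assuming the CLS token is never masked):
        # int(0.75 * 196) = 147 patch tokens masked; 196 - 147 = 49 patches kept, plus CLS = 50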
x_encoded = self.forward_encoder(
images = images,
idx_keep = idx_keep
        ) # (*,50,768) where 50 is the number of unmasked tokens
x_pred = self.forward_decoder(
x_encoded = x_encoded,
idx_keep = idx_keep,
idx_mask = idx_mask
        ) # (*,147,768) where 147 is the number of masked patches
# get image patches for masked tokens
patches = patchify(images, self.patch_size) #(*,196,768)
        # must shift idx_mask down by one because patches has no CLS token
target = get_at_index(patches, idx_mask - 1) #(*, 147, 768)
self.loss = self.criterion(x_pred, target) #(,)
return self.loss
def on_train_epoch_end(self) -> None:
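        # Note: self.loss holds only the loss of the final batch of the epoch, so the value logged
        # below is a single-batch loss rather than an epoch average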
self.log("training loss" , self.loss)
if self.apply_lr_scheduler:
self.log("current lr", self.scheduler.get_lr()[0])
else:
self.log("current lr", self.model_params["lr"])
def configure_optimizers(self):
if self.apply_lr_scheduler:
optimizer = torch.optim.AdamW(
params= self.parameters(),
lr = mae_params["base_lr"] * self.model_params["eff_batch_size"] / 256,
weight_decay=mae_params["weight_decay"]
)
self.scheduler = LinearWarmupCosineAnnealingLR(
optimizer = optimizer,
                warmup_epochs=mae_params["warmup_epochs"], # linearly ramp up the lr, then decay it with a cosine schedule, as in the paper
max_epochs=self.model_params["epochs"],
                warmup_start_lr=mae_params["base_lr"], # warmup starts at base_lr and ramps up to the scaled lr set on the optimizer
                eta_min=mae_params["eta_min"] # we keep eta_min at 0 since the DINO paper does not specify a value
)
return [optimizer],[{"scheduler" : self.scheduler, "interval" : "epoch"}]
else:
optimizer = torch.optim.AdamW(
params = self.parameters(), lr = self.model_params["lr"], weight_decay=mae_params["weight_decay"]
)
return optimizer
</code>
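For reference, this is roughly how I build the backbone and launch training. It is only a minimal sketch to show how the pieces fit together; `config` is loaded from my config dict before the class definition above, and the timm model name, the model_params values, the trainer flags, and the dataloader here are placeholders rather than my exact setup:
<code>import timm
import pytorch_lightning as pl

# ViT-B/16 backbone from timm with the classification head removed (num_classes=0)
backbone = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)

# lr=None makes the module fall back to the base_lr scaling rule plus the warmup/cosine scheduler
model_params = {"lr": None, "eff_batch_size": 1024, "epochs": 200, "batch_size": 256}
model = MaeBBViT(model_params=model_params, backbone=backbone)

trainer = pl.Trainer(max_epochs=model_params["epochs"], accelerator="gpu", devices=1)
trainer.fit(model, train_dataloaders=train_dataloader)  # train_dataloader is a placeholder for my lightly dataloader
</code>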