I am trying to fine-tune the SegFormer mit-b0 model on satellite images to segment rice paddies, using an RTX 2070 with 8 GB of VRAM, but I get a CUDA out of memory error at the beginning of the first epoch. I suspect a memory allocation problem somewhere in my training loop; can someone point out what the issue is? I would expect the b0 variant to fit comfortably on my GPU.
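To see where the memory actually goes, this is the kind of check I could drop in around the forward pass (it is not in my script yet and only uses the standard torch.cuda memory helpers):

import torch

def report_cuda_memory(tag: str) -> None:
    # Print currently allocated and peak CUDA memory in MiB.
    allocated = torch.cuda.memory_allocated() / 1024**2
    peak = torch.cuda.max_memory_allocated() / 1024**2
    print(f"[{tag}] allocated: {allocated:.1f} MiB | peak: {peak:.1f} MiB")

# Hypothetical placement inside the batch loop:
# report_cuda_memory("before forward")
# z = model(pixel_values=pixel_values, labels=labels)
# report_cuda_memory("after forward")

Here is my training code: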
# Imports assumed by the code below (RiceDataset is my own Dataset class, defined elsewhere).
import os

import albumentations as A
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor


def train_step(model: nn.Module,
               dataloader: torch.utils.data.DataLoader,
               optimizer: torch.optim.Optimizer,
               device: torch.device):
    model.train()
    loss = 0.0
    for i, batch in enumerate(dataloader):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        z = model(pixel_values=pixel_values, labels=labels)
        # logits_resized = nn.functional.interpolate(z.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
        # predicted_labels = torch.argmax(logits_resized, dim=1)
        loss += z.loss
    loss = loss / len(dataloader)
    return loss
def eval_step(model: nn.Module,
              dataloader: torch.utils.data.DataLoader,
              device: torch.device):
    model.eval()
    loss = 0.0
    with torch.inference_mode():
        for i, batch in enumerate(dataloader):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            z = model(pixel_values=pixel_values, labels=labels)
            # logits_resized = nn.functional.interpolate(z.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            # predicted_labels = torch.argmax(logits_resized, dim=1)
            loss += z.loss
    loss = loss / len(dataloader)
    return loss
def train_loop(dataset_loc: str = None,
               num_epochs: int = 1,
               batch_size: int = 32,
               num_workers: int = 1,
               model_path: str = None):
    train_images = os.path.join(dataset_loc, "images/train")
    train_masks = os.path.join(dataset_loc, "masks/train")
    list_of_train_images = os.listdir(train_images)
    val_images = os.path.join(dataset_loc, "images/val")
    val_masks = os.path.join(dataset_loc, "masks/val")
    list_of_val_images = os.listdir(val_images)

    train_transform = A.Compose([
        A.HorizontalFlip(p=0.3),
        A.VerticalFlip(p=0.3),
        A.RandomRotate90(p=0.3),
    ])
    image_processor = SegformerImageProcessor(do_resize=True, size={"height": 256, "width": 256})

    train_dataset = RiceDataset(images=list_of_train_images,
                                image_folder=train_images,
                                mask_folder=train_masks,
                                transform=train_transform,
                                image_processor=image_processor)
    val_dataset = RiceDataset(images=list_of_val_images,
                              image_folder=val_images,
                              mask_folder=val_masks,
                              image_processor=image_processor)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  shuffle=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=batch_size,
                                num_workers=num_workers,
                                shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_checkpoint = "nvidia/mit-b0"
    id2label = {0: "outer", 1: "rice_paddy"}
    label2id = {label: id for id, label in id2label.items()}
    num_labels = len(id2label)
    model = SegformerForSemanticSegmentation.from_pretrained(
        model_checkpoint,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
        reshape_last_stage=True
    )
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

    best_val_loss = float("inf")
    best_model_state_dict = None
    for epoch in tqdm(range(num_epochs)):
        train_loss = train_step(model=model,
                                dataloader=train_dataloader,
                                optimizer=optimizer,
                                device=device)
        val_loss = eval_step(model=model,
                             dataloader=val_dataloader,
                             device=device)
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"val_loss: {val_loss:.4f} | "
        )
        scheduler.step()
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state_dict = model.state_dict()

    if best_model_state_dict is not None:
        if model_path.endswith(".pth") or model_path.endswith(".pt"):
            torch.save(best_model_state_dict, model_path)
        else:
            torch.save(best_model_state_dict, model_path + ".pth")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print("DONE")
I also tried running the same code with a batch size of 2 on Google Colab, but I get the same out of memory error there.
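For comparison, this is a variant of train_step that I am considering but have not tested (I am calling it train_step_v2 here); it steps the optimizer on every batch and accumulates the loss as a plain float via .item() instead of keeping the loss tensors around. I am not sure whether this is the right fix, so feedback on it is welcome:

def train_step_v2(model: nn.Module,
                  dataloader: torch.utils.data.DataLoader,
                  optimizer: torch.optim.Optimizer,
                  device: torch.device):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        z = model(pixel_values=pixel_values, labels=labels)
        z.loss.backward()            # backpropagate this batch's loss
        optimizer.step()             # update the weights
        total_loss += z.loss.item()  # keep a float, not the graph-attached tensor
    return total_loss / len(dataloader)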