I am fine-tuning an embedding model using contrastive learning. For the loss function, I'm using torch.nn.CrossEntropyLoss.
The training process initially seems to work fine — the loss decreases steadily on average. However, at some point during training (usually around step 16,000 in this case), the loss and gradients explode. After this, the model is unable to stabilize, and training becomes unusable.
Here is a graph showing the behavior:
[TensorBoard loss-by-step graph]
What I have tried so far:
- Data preprocessing:
  - Removed outliers (e.g., very long texts).
  - Cleaned and filtered the dataset for consistency.
- Hyperparameter tuning:
  - Adjusted the learning rate and tried different values.
  - Changed the optimizer (e.g., switching from Adam to SGD).
- Gradient clipping:
  - Clipped gradients to a max norm of 1 using torch.nn.utils.clip_grad_norm_ (see the sketch after this list).
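As a side note, torch.nn.utils.clip_grad_norm_ returns the total gradient norm it computes before clipping, so logging that value shows exactly when the gradients start to spike. A minimal sketch (meant to sit inside the training step shown further down; the step and model names match that loop):

total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if step % 100 == 0:
    # total_norm is the pre-clip norm; a sudden jump here marks the onset of the explosion
    print(f"step {step}: pre-clip grad norm = {total_norm.item():.3f}")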
My setup:
- Dataset size: ~14,000 samples
- Model architecture: Transformer-based embedding model
- Batch size: 1 (given my GPU capacity)
- Learning rate: 1e-5
- Optimizer: Adam with weight decay
Training loop (relevant part):
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for step, batch in enumerate(dataset_train):
        temperature = max(0.1, 0.05 * (1 - step / num_training_steps))

        # Move data to device
        anchor_input_ids = batch["anchor_input_ids"].to(device)
        anchor_attention_mask = batch["anchor_attention_mask"].to(device)
        positive_input_ids = batch["positive_input_ids"].to(device)
        positive_attention_mask = batch["positive_attention_mask"].to(device)
        negative_input_ids = batch["negative_input_ids"].to(device)
        negative_attention_mask = batch["negative_attention_mask"].to(device)

        # Add a batch dimension (batch size = 1)
        anchor_input_ids = anchor_input_ids.unsqueeze(0)
        anchor_attention_mask = anchor_attention_mask.unsqueeze(0)
        positive_input_ids = positive_input_ids.unsqueeze(0)
        positive_attention_mask = positive_attention_mask.unsqueeze(0)
        negative_input_ids = negative_input_ids.unsqueeze(0)
        negative_attention_mask = negative_attention_mask.unsqueeze(0)

        optimizer.zero_grad()

        # Generate embeddings
        anchor_embeddings = model.forward(anchor_input_ids, anchor_attention_mask)
        positive_embeddings = model.forward(positive_input_ids, positive_attention_mask)
        negative_embeddings = model.forward(negative_input_ids, negative_attention_mask)

        # Calculate cosine similarities
        pos_sim = cosine_similarity(anchor_embeddings, positive_embeddings)
        neg_sim = cosine_similarity(anchor_embeddings, negative_embeddings)

        # Calculate logits
        logits = torch.cat([pos_sim.unsqueeze(1), neg_sim.unsqueeze(1)], dim=1) / temperature
        labels = torch.zeros(logits.size(0), dtype=torch.long).to(device)  # The positive class is always the first

        # Calculate InfoNCE loss
        loss = torch.nn.CrossEntropyLoss()(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
Dataset generation:
import torch
from torch.utils.data import Dataset
import random


class FineTuneContrastiveDataset(Dataset):
    def __init__(self, pairs_data, df_1, df_2, tokenizer, max_tokens=512):
        """
        pairs_data: List of tuples (id_1, id_2, label).
        df_1: DataFrame containing text data associated with id_1.
        df_2: DataFrame containing text data associated with id_2.
        tokenizer: Hugging Face tokenizer.
        max_tokens: Maximum allowed length for the tokenized text.
        """
        self.pairs_data = pairs_data
        self.df_1 = df_1.set_index("id_1")
        self.df_2 = df_2.set_index("id_2")
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens
        self.id_2_list = list(self.df_2.index)  # For selecting negative samples

    def __len__(self):
        return len(self.pairs_data)

    def __getitem__(self, idx):
        # Retrieve data from the pair
        id_1, id_2_positive, label = self.pairs_data[idx]

        # Text associated with id_1 (anchor)
        text_1 = " ".join(self.df_1.loc[id_1]["chunks"])

        # Positive text associated with id_2
        text_2_positive = " ".join(self.df_2.loc[id_2_positive]["chunks"])

        # Generate a negative sample from id_2
        id_2_negative = random.choice(
            [candidate_id for candidate_id in self.id_2_list if candidate_id != id_2_positive]
        )
        text_2_negative = " ".join(self.df_2.loc[id_2_negative]["chunks"])

        # Tokenize inputs
        inputs_anchor = self.tokenizer(
            text_1, truncation=True, max_length=self.max_tokens,
            padding="max_length", return_tensors="pt"
        )
        inputs_positive = self.tokenizer(
            text_2_positive, truncation=True, max_length=self.max_tokens,
            padding="max_length", return_tensors="pt"
        )
        inputs_negative = self.tokenizer(
            text_2_negative, truncation=True, max_length=self.max_tokens,
            padding="max_length", return_tensors="pt"
        )

        return {
            "anchor_input_ids": inputs_anchor["input_ids"].squeeze(0),
            "anchor_attention_mask": inputs_anchor["attention_mask"].squeeze(0),
            "positive_input_ids": inputs_positive["input_ids"].squeeze(0),
            "positive_attention_mask": inputs_positive["attention_mask"].squeeze(0),
            "negative_input_ids": inputs_negative["input_ids"].squeeze(0),
            "negative_attention_mask": inputs_negative["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.float),
            "id_1": id_1,
        }
Batch size = 1 looks like a problem to me. You are feeding one data point at a time, so the weight updates have high variance, which makes convergence difficult and unstable.
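One common way to emulate a larger effective batch without more GPU memory is gradient accumulation; this is an illustrative sketch only, not something the question's loop already does. compute_infonce_loss is a hypothetical helper wrapping the per-sample forward pass and InfoNCE loss from the question, and model, optimizer, scheduler, and dataset_train are as defined there:

import torch

accumulation_steps = 16  # illustrative value; effective batch size = 16 with physical batch size 1

optimizer.zero_grad()
for step, batch in enumerate(dataset_train):
    loss = compute_infonce_loss(model, batch)  # hypothetical helper, see note above
    (loss / accumulation_steps).backward()     # scale so the accumulated gradient averages over the group
    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()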
Also, try using gradient scaling before clipping.
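This mainly applies if you adopt mixed-precision training (the question's loop does not show it): with torch.cuda.amp, gradients are scaled, so they must be unscaled before clipping or the max_norm=1.0 threshold is applied to the wrong magnitudes. A minimal sketch under that assumption, reusing the hypothetical compute_infonce_loss helper from above:

import torch

scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid fp16 gradient underflow

for step, batch in enumerate(dataset_train):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = compute_infonce_loss(model, batch)  # hypothetical helper, as above
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale BEFORE clipping so max_norm applies to the true gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)      # skips the update if the unscaled gradients contain inf/NaN
    scaler.update()
    scheduler.step()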