I am attempting to implement a temperature parameter within a contrastive loss, and I am not sure whether I am doing it correctly. It does seem to have an effect, though: learning appears to be harder than it was without the temperature parameter.
First, here’s my training loop.
initial_temperature = 1.0
temperature_decay = 0.1
for epoch in range(num_epochs_vae):
    contrastive_loss_total = 0  # Initialize total contrastive loss
    epoch_counter += 1
    # Inverse-decay schedule: the temperature shrinks as the epoch index grows
    temperature = initial_temperature * (1 / (1 + temperature_decay * epoch))
    for batch_idx, data in enumerate(tqdm(train_dataloader)):
        optimizer_vae.zero_grad()
        # Forward pass
        mu, logvar = vae(data)
        # Generate pairs of latent vectors and their labels for the contrastive loss
        z1, z2, labels = generate_contrastive_pairs(mu, logvar, data, temperature)
        # Calculate the contrastive loss
        contrastive_loss = vae.module.contrastive_loss(z1, z2, labels)
        # Backpropagation
        contrastive_loss.backward()
        optimizer_vae.step()
        # Accumulate the total contrastive loss
        contrastive_loss_total += contrastive_loss.item()
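For reference, the temperature line above is a simple inverse-decay schedule. As a quick standalone sanity check (reusing only the two constants above, nothing else from the training code), it produces:

initial_temperature = 1.0
temperature_decay = 0.1

# Quick check of the inverse-decay schedule used in the training loop
for epoch in range(0, 50, 10):
    temperature = initial_temperature * (1 / (1 + temperature_decay * epoch))
    print(f"epoch {epoch:2d}: temperature = {temperature:.3f}")
# epoch  0: temperature = 1.000
# epoch 10: temperature = 0.500
# epoch 20: temperature = 0.333
# epoch 30: temperature = 0.250
# epoch 40: temperature = 0.200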
Then, here’s my generate_contrastive_pairs function.
threshold = 1.0
def generate_contrastive_pairs(mu, logvar, data, temperature):
    batch_size = data.size(0)
    # Reparameterize to obtain latent vectors
    z = vae.module.reparameterize(mu, logvar)
    # Randomly pair up samples from the batch (labels are assigned below)
    idx1 = torch.randint(0, batch_size, (batch_size,))
    idx2 = torch.randint(0, batch_size, (batch_size,))
    # Ensure that idx1 and idx2 never select the same index
    idx2 = torch.where(idx1 == idx2, (idx2 + 1) % batch_size, idx2)
    # Obtain the latent vectors corresponding to the selected indices
    z1 = z[idx1]
    z2 = z[idx2]
    # Calculate Euclidean distances between the latent vectors
    distances = torch.norm(z1 - z2, dim=1)
    # Apply temperature scaling to the distances
    scaled_distances = distances / temperature
    # Define labels: 1 for similar pairs, 0 for dissimilar pairs
    labels = (scaled_distances < threshold).float().unsqueeze(1)  # Extra dimension to match the shape of sim_scores
    return z1, z2, labels
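To make the thresholding concrete, here is a standalone toy check (dummy distances only, nothing from the VAE) of how the temperature moves pairs across the threshold: lowering the temperature inflates the scaled distances, so fewer pairs get labeled as similar.

import torch

threshold = 1.0
distances = torch.tensor([0.4, 0.9, 1.5, 2.0])  # made-up pairwise distances

for temperature in (1.0, 0.5):
    scaled = distances / temperature
    labels = (scaled < threshold).float()
    print(f"temperature={temperature}: labels={labels.tolist()}")
# temperature=1.0: labels=[1.0, 1.0, 0.0, 0.0]
# temperature=0.5: labels=[1.0, 0.0, 0.0, 0.0]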
And finally, my loss function.
def contrastive_loss(self, z1, z2, label):
    projection1 = self.projector(z1)
    projection2 = self.projector(z2)
    # Predict a similarity logit from the absolute difference of the two projections
    sim_scores = self.contrastive_head(torch.abs(projection1 - projection2))
    criterion = nn.BCEWithLogitsLoss()
    contrastive_loss = criterion(sim_scores, label)
    return contrastive_loss
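And just to illustrate the shapes the loss sees, here is a toy stand-in (the Linear layers below are placeholders, not my actual projector and contrastive_head):

import torch
import torch.nn as nn

latent_dim, proj_dim, batch_size = 8, 4, 6
projector = nn.Linear(latent_dim, proj_dim)          # placeholder projector
contrastive_head = nn.Linear(proj_dim, 1)            # placeholder head producing one logit per pair

z1 = torch.randn(batch_size, latent_dim)
z2 = torch.randn(batch_size, latent_dim)
labels = torch.randint(0, 2, (batch_size, 1)).float()  # same shape as sim_scores

sim_scores = contrastive_head(torch.abs(projector(z1) - projector(z2)))  # (batch_size, 1) raw logits
loss = nn.BCEWithLogitsLoss()(sim_scores, labels)
print(sim_scores.shape, labels.shape, loss.item())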