Currently you can have `SFTTrainer` train your model to predict every token in your dataset, or you can train on “completions only” using the `DataCollatorForCompletionOnlyLM` class.
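For reference, the completion-only setup normally looks roughly like this (a minimal sketch; the `response_template` marker depends on your prompt format):

```python
from trl import DataCollatorForCompletionOnlyLM

# "### Answer:" is just a placeholder for whatever marks the start of the completion.
response_template = "### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
# Passing this collator to SFTTrainer sets the labels of all prompt tokens to -100,
# so the loss is only computed on the completion tokens.
```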
I would like something in between, where certain tokens have a higher weight than others.
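Conceptually, that boils down to a per-token weighted cross-entropy, something like this sketch (ignoring label shifting and padding; `token_weights` is a hypothetical per-token weight tensor built from my data):

```python
import torch.nn as nn

def weighted_token_loss(logits, labels, token_weights):
    # logits: (batch, seq_len, vocab_size); labels, token_weights: (batch, seq_len)
    vocab_size = logits.size(-1)
    per_token = nn.CrossEntropyLoss(reduction="none")(
        logits.view(-1, vocab_size), labels.view(-1)
    )
    weights = token_weights.view(-1).to(per_token.dtype)
    # Weighted mean over tokens instead of a plain mean.
    return (per_token * weights).sum() / weights.sum()
```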
I thought it would be fairly trivial, but nope.
Here’s what I currently came up with (using Unsloth, so I can try this out on Google Colab):
```python
import transformers
import torch.nn as nn
import torch
from datetime import datetime
from transformers import PreTrainedTokenizerBase
from typing import List, Dict, Any
from unsloth import is_bfloat16_supported
from trl import SFTTrainer
from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger("transformers.modeling_utils")


class WeightedLossTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        logger.info("Compute loss starts")
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        weight_ranges = inputs.get("weight_ranges")
        batch_size, seq_len, num_classes = logits.shape
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        total_weighted_loss = 0.0
        total_weights = 0.0
        logger.info(f"Doing {batch_size} batch sizes")
        for batch_idx in range(batch_size):
            # Collect weights and losses.
            batch_weighted_losses = []
            for start_idx, end_idx, weight in weight_ranges[batch_idx]:
                logit_chunk = logits[batch_idx, start_idx:end_idx + 1]
                label_chunk = labels[batch_idx, start_idx:end_idx + 1]
                loss = loss_fct(logit_chunk.view(-1, num_classes), label_chunk.view(-1))
                weighted_loss = loss * weight
                batch_weighted_losses.append(weighted_loss.sum())
                total_weights += weight * (end_idx - start_idx + 1)  # Total token count in this range
            # Sum the weighted losses for the batch.
            batch_weighted_loss_sum = torch.stack(batch_weighted_losses).sum()
            total_weighted_loss += batch_weighted_loss_sum.detach()
        # Compute the mean loss.
        mean_loss = total_weighted_loss / total_weights
        mean_loss = torch.tensor(mean_loss, dtype=torch.float32, device=logits.device, requires_grad=True)
        logger.info(f"Mean loss: {mean_loss}")
        return (mean_loss, outputs) if return_outputs else mean_loss


class WeightedDataCollator:
    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def __call__(self, examples: List):
        all_input_ids = []
        all_attention_masks = []
        all_weight_ranges = []
        for entry in examples:
            example_input_ids = []
            example_attention_masks = []
            example_weight_ranges = []
            current_length = 0  # Initialize length counter
            for item in entry['pieces']:
                tokenized = self.tokenizer(item['text'], truncation=True, padding=False, return_tensors='pt')
                input_ids = tokenized.input_ids.squeeze()  # Get tensor, remove batch dimension
                attention_mask = tokenized.attention_mask.squeeze()  # Get tensor, remove batch dimension
                start_idx = current_length
                end_idx = start_idx + len(input_ids) - 1
                example_input_ids.append(input_ids)
                example_attention_masks.append(attention_mask)
                example_weight_ranges.append((start_idx, end_idx, item['weight']))
                current_length = end_idx + 1  # Update current length
            concatenated_input_ids = torch.cat(example_input_ids, dim=0) if example_input_ids else torch.tensor([], dtype=torch.long)
            concatenated_attention_masks = torch.cat(example_attention_masks, dim=0) if example_attention_masks else torch.tensor([], dtype=torch.long)
            pad_length = max_seq_length - len(concatenated_input_ids)  # Pad up to max_seq_length (defined elsewhere)
            if pad_length > 0:
                concatenated_input_ids = torch.cat([concatenated_input_ids, torch.tensor([self.tokenizer.pad_token_id] * pad_length)])
                concatenated_attention_masks = torch.cat([concatenated_attention_masks, torch.tensor([0] * pad_length)])
            all_input_ids.append(concatenated_input_ids)
            all_attention_masks.append(concatenated_attention_masks)
            all_weight_ranges.append(example_weight_ranges)
        logger.info(f"All ranges: {all_weight_ranges}")
        return {
            "input_ids": torch.stack(all_input_ids),
            "attention_mask": torch.stack(all_attention_masks),
            "labels": torch.stack(all_input_ids).clone(),
            "weight_ranges": all_weight_ranges,
        }


# Define data collator
data_collator = WeightedDataCollator(tokenizer=tokenizer)

# Prepare dataset for the data collator
# collated_data = data_collator(dataset)

training_args = transformers.TrainingArguments(
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    warmup_steps = 5,
    max_steps = 60,
    learning_rate = 2e-4,
    fp16 = not is_bfloat16_supported(),
    bf16 = is_bfloat16_supported(),
    logging_steps = 5,
    optim = "adamw_8bit",
    weight_decay = 0.01,
    lr_scheduler_type = "linear",
    seed = 3407,
    output_dir = "outputs",
    remove_unused_columns = False,
)

trainer = WeightedLossTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    data_collator = data_collator,
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    args = training_args,
    packing = False,
    dataset_text_field = 'text',
    dataset_kwargs = {'skip_prepare_dataset': True},
)

trainer_stats = trainer.train()
```
Each entry in my dataset is an object with a single property, `pieces`. `pieces` is an array of objects, and each of those objects has a `text` and a `weight` property.
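For example, a single entry looks roughly like this (made-up values):

```python
{
    "pieces": [
        {"text": "### Question:\nWhat is the capital of France?\n", "weight": 0.1},
        {"text": "### Answer:\nParis.", "weight": 1.0},
    ]
}
```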
As soon as it starts computing the loss, it seems to hang for a while (a few seconds) until it eventually just OOMs with a CUDA out-of-memory error.
So what exactly am I doing wrong, and how can I fix it?