I’m writing code for multi-GPU training, and I need to gather embeddings from different GPUs to calculate the loss and then propagate the gradients back to the different GPUs. However, when the program reaches optimizer.step(), the memory usage increases dramatically and results in an out-of-memory error. The code is below, thanks!
def training_stage_2(model, optimizer, train_dataloader, val_dataloader, tokenizer, accelerator, epochs):
    """Fine-tune ``model`` on query-doc pairs with a loss over embeddings gathered from all processes.

    For each batch, the local query and doc embeddings are all-gathered
    (detached from autograd), the loss is computed on the gathered copies, and
    the resulting embedding gradients are reduce-scattered back so that each
    process backpropagates its own slice through its local model graph.

    Args:
        model: Called as ``model(**inputs)``; the output must expose
            ``last_hidden_state``.
        optimizer: ``step()`` and ``zero_grad()`` are called once per batch.
        train_dataloader: Yields ``(query, doc)`` batches.
        val_dataloader: Not used by this function.
        tokenizer: Called on the raw queries/docs to build the model inputs.
        accelerator: Supplies ``accumulate``, ``backward``, ``device`` and
            ``num_processes``.
        epochs: Number of passes over ``train_dataloader``.
    """
    world_size = accelerator.num_processes

    # finetuning the model with query-doc pairs
    for _ in range(epochs):
        model.train()
        for batch in tqdm(train_dataloader):
            query, doc = batch
            with accelerator.accumulate(model):
                # Use the hidden state at the last position as the embedding.
                # NOTE(review): this is the EOS token only if the tokenizer
                # pads on the left (or appends EOS after padding) - confirm.
                # NOTE(review): the tokenizer outputs are not moved to
                # accelerator.device here - confirm the model accepts them.
                query_inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
                query_embeds_ = model(**query_inputs).last_hidden_state[:, -1, :]  # shape: (batch_size, hidden_size)
                doc_inputs = tokenizer(doc, return_tensors="pt", padding=True, truncation=True, max_length=512)
                doc_embeds_ = model(**doc_inputs).last_hidden_state[:, -1, :]  # shape: (batch_size, hidden_size)

                # Collect embeddings from all processes. The buffer sizing
                # assumes every process has the same local batch size.
                query_embeds = torch.zeros(
                    (query_embeds_.shape[0] * world_size, query_embeds_.shape[1]),
                    device=accelerator.device,
                    dtype=query_embeds_.dtype,
                )
                doc_embeds = torch.zeros(
                    (doc_embeds_.shape[0] * world_size, doc_embeds_.shape[1]),
                    device=accelerator.device,
                    dtype=doc_embeds_.dtype,
                )
                # The [:, -1, :] slice is a strided view, so make it contiguous
                # before handing it to the collective; detach() keeps the
                # gathered copies out of the local autograd graph.
                dist.all_gather(
                    list(query_embeds.chunk(world_size, dim=0)),
                    query_embeds_.detach().contiguous(),
                )
                dist.all_gather(
                    list(doc_embeds.chunk(world_size, dim=0)),
                    doc_embeds_.detach().contiguous(),
                )

                # The gathered buffers are leaf tensors; enabling grad on them
                # lets the loss backward populate query_embeds.grad/doc_embeds.grad.
                query_embeds.requires_grad = True
                doc_embeds.requires_grad = True
                loss = classification_loss_single_vector(query_embeds, doc_embeds)
                accelerator.backward(loss)

                # Scatter the embedding gradients back to their owning process.
                if query_embeds.grad is not None:
                    query_grad = torch.zeros_like(query_embeds_, memory_format=torch.contiguous_format)
                    doc_grad = torch.zeros_like(doc_embeds_, memory_format=torch.contiguous_format)
                    # Sum each slice's gradient over all processes and keep
                    # only the slice that belongs to this process.
                    dist.reduce_scatter(query_grad, list(query_embeds.grad.chunk(world_size, dim=0)))
                    dist.reduce_scatter(doc_grad, list(doc_embeds.grad.chunk(world_size, dim=0)))
                    # NOTE(review): reduce_scatter already sums over processes;
                    # confirm this extra factor matches how the model wrapper
                    # averages parameter gradients and how the loss is normalised.
                    query_grad.mul_(world_size)
                    doc_grad.mul_(world_size)
                    # Backpropagate both outputs through the model in a single
                    # pass instead of two separate backward() calls, so the
                    # shared graph is traversed (and any gradient hooks fire)
                    # once per step.
                    torch.autograd.backward(
                        [query_embeds_, doc_embeds_],
                        [query_grad, doc_grad],
                    )

                # NOTE(review): if the optimizer keeps per-parameter state
                # (e.g. Adam), that state is presumably first allocated on the
                # first step() call, which would explain a memory jump here.
                optimizer.step()
                optimizer.zero_grad()
I tried using accelerator.gather in place of dist.all_gather, but it doesn’t work.
Drack Young is a new contributor to this site. Take care in asking for clarification, commenting, and answering.
Check out our Code of Conduct.