I’m writing code for multi-GPU training, and I need to gather the embeddings from all GPUs to calculate the loss and then propagate the gradients back to each GPU. However, when the program reaches optimizer.step(), memory usage increases dramatically and I run into an out-of-memory error. The code is below, thanks!
<code>
import torch
import torch.distributed as dist
from tqdm import tqdm


def training_stage_2(model, optimizer, train_dataloader, val_dataloader, tokenizer, accelerator, epochs):
    # fine-tune the model with query-doc pairs
    for epoch in range(epochs):
        model.train()
        for batch in tqdm(train_dataloader):
            query, doc = batch
            with accelerator.accumulate(model):
                # take the last hidden state at the eos token as the embedding
                query_inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
                query_embeds_ = model(**query_inputs).last_hidden_state[:, -1, :]  # shape: (batch_size, hidden_size)
                doc_inputs = tokenizer(doc, return_tensors="pt", padding=True, truncation=True, max_length=512)
                doc_embeds_ = model(**doc_inputs).last_hidden_state[:, -1, :]  # shape: (batch_size, hidden_size)

                # collect embeddings from all GPUs
                query_embeds = torch.zeros(
                    (query_embeds_.shape[0] * accelerator.num_processes, query_embeds_.shape[1]),
                    device=accelerator.device, dtype=query_embeds_.dtype,
                )
                doc_embeds = torch.zeros(
                    (doc_embeds_.shape[0] * accelerator.num_processes, doc_embeds_.shape[1]),
                    device=accelerator.device, dtype=doc_embeds_.dtype,
                )
                dist.all_gather(list(query_embeds.chunk(accelerator.num_processes, dim=0)), query_embeds_.data)
                dist.all_gather(list(doc_embeds.chunk(accelerator.num_processes, dim=0)), doc_embeds_.data)

                # the gathered embeddings need grad so their gradients can be read after backward
                query_embeds.requires_grad = True
                doc_embeds.requires_grad = True
                loss = classification_loss_single_vector(query_embeds, doc_embeds)
                accelerator.backward(loss)

                # scatter the gradients back to all GPUs
                if query_embeds.grad is not None:
                    query_embeds.grad.detach_()
                    doc_embeds.grad.detach_()
                query_grad = torch.zeros_like(query_embeds_)
                doc_grad = torch.zeros_like(doc_embeds_)
                # feature-gradient reduce-scatter: each rank receives the summed gradient of its own shard
                dist.reduce_scatter(query_grad, list(query_embeds.grad.chunk(accelerator.num_processes, dim=0)))
                dist.reduce_scatter(doc_grad, list(doc_embeds.grad.chunk(accelerator.num_processes, dim=0)))
                query_grad.mul_(accelerator.num_processes)
                doc_grad.mul_(accelerator.num_processes)

                # backward through the model with the scattered gradients
                query_embeds_.backward(query_grad)
                doc_embeds_.backward(doc_grad)

                optimizer.step()
                optimizer.zero_grad()
</code>
I tried replacing dist.all_gather with accelerator.gather, but it doesn’t work either.
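For reference, the accelerator.gather version I tried looked roughly like this (a sketch, not the exact code; it reuses the variable names from above and sits inside the same accelerator.accumulate block):
<code>
# Rough sketch of the accelerator.gather attempt (names as in the code above).
# accelerator.gather concatenates the per-process tensors along dim 0,
# so the shapes match the dist.all_gather version.
query_embeds = accelerator.gather(query_embeds_)  # (batch_size * num_processes, hidden_size)
doc_embeds = accelerator.gather(doc_embeds_)      # (batch_size * num_processes, hidden_size)
loss = classification_loss_single_vector(query_embeds, doc_embeds)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
</code>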