In PyTorch, assume I have a chain of computations:
import torch

x = torch.tensor(1.0)
x.requires_grad_(True)
for _ in range(10):
    x = x + 1  # some complex computations
x_to_detach_later = x  # cannot call x.detach() now because I want to backward later
for _ in range(10):
    x = x + 1  # more complex computations
x.backward(retain_graph=True)  # backward through the entire chain
# Now RETROSPECTIVELY detach x_to_detach_later to free the beginning of the graph
x.backward()  # backward through only part of the chain
Is there any way to achieve this in PyTorch?
It is easy to delete the end of the graph by dropping all references to its tensors (see PyTorch: When using backward(), how can I retain only part of the graph?). But this does not work for the beginning, because the part of the graph I need to retain for the second backward call holds backward-pointing references into it.
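To make the asymmetry concrete, here is a minimal sketch (the toy `head`/`tail` names are mine, not from the question): dropping the last reference to the tail frees its autograd nodes, but the tail's `grad_fn.next_functions` point back into the head's nodes, which is exactly what keeps the beginning of the graph alive.

import torch

# Deleting the *end* of the graph works: once nothing references the tail
# tensor, its autograd nodes become unreachable and are freed.
x = torch.tensor(1.0, requires_grad=True)
head = x + 1     # beginning of the chain
tail = head * 2  # end of the chain

head.backward(retain_graph=True)
del tail         # the tail's autograd nodes can now be garbage-collected
head.backward()  # still fine: the beginning of the graph is intact

# The reverse fails: the tail's grad_fn holds references *back* into the
# head's nodes, so dropping the Python reference to the head frees nothing.
x = torch.tensor(1.0, requires_grad=True)
head = x + 1
tail = head * 2
print(tail.grad_fn.next_functions)  # ((<AddBackward0 ...>, 0), ...)
del head       # the graph itself still keeps head's nodes alive
tail.backward()  # backward still traverses the whole chain
print(x.grad)    # tensor(2.)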