I want to compute the gradient of a loss with respect to a large neural network in PyTorch. The loss has the form `L = g(f(x1), f(x2), f(x3), ..., f(xn))`, where there are potentially many `x`s, `f` is the neural network, and `g` is a relatively simple nonlinear function. A naive implementation is:
```python
zs = []
for x in inputs:
    z = f(x)
    zs.append(z)
zs = torch.stack(zs)
L = g(zs)
L.backward()
```
But this will eat up VRAM, as there will be `len(inputs)` copies of the computation graph attached to `L`. I decided instead to compute the gradient manually through the chain rule. Let `dL` be the gradient of the loss w.r.t. `f`'s parameters, and `dzi` be the Jacobian of `zi = f(xi)` w.r.t. `f`'s parameters. Then `dL = dL/dz1 * dz1 + dL/dz2 * dz2 + ... + dL/dzn * dzn`. If we know each `dL/dzi` beforehand, we can simply accumulate each `dL/dzi * dzi` into the `.grad` of the network parameters, without ever keeping more than one computation graph in VRAM. Here is the code:
```python
zs = []
# First pass: do not build computation graphs.
with torch.no_grad():
    for x in inputs:
        z = f(x)  # NOTE #1
        zs.append(z)
# Compute dL/dz through the cheap function g only.
zs = torch.stack(zs).requires_grad_()
L = g(zs)
L.backward()
zs_grad = zs.grad.clone()
# Second pass: each computation graph is freed as soon as its backward() finishes.
for x, dL_dz in zip(inputs, zs_grad):
    z = f(x)  # NOTE #2
    z.backward(gradient=dL_dz)
```
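To make the intended equivalence concrete, here is a minimal self-contained sanity check. The `f`, `g`, and `inputs` here are toy stand-ins purely for illustration (my real `f` is a large pre-trained network); it just verifies that the two-pass scheme accumulates the same `.grad` as the naive version:

```python
import torch
import torch.nn as nn

# Toy stand-ins for illustration only.
f = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))

def g(zs):
    return (zs ** 2).sum()

inputs = [torch.randn(4) for _ in range(5)]

# Reference: the naive implementation that keeps every graph alive.
L = g(torch.stack([f(x) for x in inputs]))
L.backward()
ref_grads = [p.grad.clone() for p in f.parameters()]

# The two-pass scheme from above.
f.zero_grad()
with torch.no_grad():
    zs = torch.stack([f(x) for x in inputs])
zs.requires_grad_()
g(zs).backward()
for x, dL_dz in zip(inputs, zs.grad):
    f(x).backward(gradient=dL_dz)

# The accumulated .grad should match the naive gradients.
for g_ref, p in zip(ref_grads, f.parameters()):
    assert torch.allclose(g_ref, p.grad)
```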
If `f` were a pure function, this would work. In reality, however, `f` may contain batch normalization, dropout, etc. that carry some internal state. In fact, in my use case `f` is a pre-trained model, and I can't modify its components arbitrarily. I need the `z` computed at NOTE #1 to be identical to the `z` computed at NOTE #2 for each `x`, but the randomness in dropout breaks that identity. Even if the randomness were somehow controlled, passing `x` through a batch norm twice would still corrupt its running statistics.
I thought implementing this gradient computation would be fairly easy, but I got stuck. I asked the same question on the PyTorch forum, but it hasn't gotten much attention, so I'm asking again here. Thank you so much for your help!