I have a PyTorch tensor with NaN values inside. When I calculate the loss using a simple MSE loss, the gradient becomes NaN even if I mask out the NaN values.
Weirdly, this happens only when the mask is applied after calculating the loss, and only when the loss has a pow operation inside. The various cases follow:
import torch
torch.autograd.set_detect_anomaly(True)
x = torch.rand(10, 10)
y = torch.rand(10, 10)
w = torch.rand(10, 10, requires_grad=True)
y[y > 0.5] = torch.nan
o = w @ x
l = (y - o)**2
l = l[~y.isnan()]
try:
    l.mean().backward(retain_graph=True)
except RuntimeError:
    print('(y-o)**2 caused nan gradient')
l = (y - o)
l = l[~y.isnan()]
try:
    l.mean().backward(retain_graph=True)
except RuntimeError:
    pass
else:
    print('y-o does not cause nan gradient')
l = (y[~y.isnan()] - o[~y.isnan()])**2
l.mean().backward()
print('masking before pow does not propagate nan gradient')
What makes NaN gradients propagate when passing through the backward pass of the pow function?
The nans don't come from the gradient; they come from the forward pass. Those forward values are then multiplied by gradient values in the backward pass (chain rule).
Take a simpler example. Set exactly one value in y to nan:
x = torch.rand(10, 10)
y = torch.rand(10, 10)
w = torch.rand(10, 10, requires_grad=True)
y[0,0] = torch.nan
Now compute your intermediates and retain gradients
o = w@x
o.retain_grad()
l = (y - o).pow(2)
l.retain_grad()
l_nonnan = l[~y.isnan()]
l_nonnan.retain_grad()
l_nonnan.mean().backward()
Inspect the gradients
- l_nonnan has full gradients
- l has full gradients except for l.grad[0,0], which is 0
- o has a nan gradient at o.grad[0,0]
- w has nan gradients for the entire first row
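To check this yourself, print the retained gradients after the backward call. The print statements below are my own addition, continuing the snippet above (the backward call only completes if anomaly detection is off; with torch.autograd.set_detect_anomaly(True) it raises the RuntimeError seen in the question):
print(l_nonnan.grad)  # every entry is 1/99 (mean over the 99 kept elements)
print(l.grad[0, 0])   # tensor(0.) -- the masked-out entry receives a zero gradient
print(o.grad[0, 0])   # tensor(nan)
print(w.grad[0])      # the whole first row is nan
print(w.grad[1])      # every other row is finite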
This is due to how the computation propagates. We set y[0,0] = torch.nan. We compute l = (y - o).pow(2); this means l[0,0] is nan because it directly interacts with the nan from y. In the backward pass, the local derivative of the pow at that entry is computed from that nan forward value, so even though the upstream gradient for the masked-out entry is 0, the product is still nan (0 * nan = nan), and o.grad[0,0] ends up nan.
o is created via o = w@x. This means the value at o[0,0] = (w[0] * x[:,0]).sum(). When we run the computation in reverse in backprop, the gradient of o[0,0] (which we know to be nan) propagates back to all elements of w[0]. This is why the entire first row has nan gradients.
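Here is a tiny standalone sketch of that mechanism (my own example, not part of the original snippet): the backward of pow is computed from the forward values it saved, so a nan input poisons the local derivative even where the upstream gradient is zero, while the backward of a plain subtraction does not use the forward values at all, which is why the y - o case in the question stays clean.
import torch

a = torch.tensor([1.0, float('nan')], requires_grad=True)

# pow: the backward is 2 * a * upstream, and the saved nan makes 0 * nan = nan
(a ** 2).backward(torch.tensor([1.0, 0.0]))
print(a.grad)  # tensor([2., nan])

# subtraction: the backward just passes the upstream gradient through
a.grad = None
(a - 1.0).backward(torch.tensor([1.0, 0.0]))
print(a.grad)  # tensor([1., 0.]) -- the zero upstream gradient stays zero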
When you set a bunch of nans randomly, you get the same effect on more elements.
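For example (my own sketch, reusing x from the snippet above, with anomaly detection off): every row of the weight gradient whose target row contains at least one nan ends up entirely nan.
y2 = torch.rand(10, 10)
y2[y2 > 0.5] = torch.nan
w2 = torch.rand(10, 10, requires_grad=True)

l2 = ((y2 - w2 @ x) ** 2)[~y2.isnan()]
l2.mean().backward()

print(torch.isnan(w2.grad).any(dim=1))  # rows of w2.grad containing nans...
print(y2.isnan().any(dim=1))            # ...match the rows of y2 that contain a nan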
You can avoid this via l = (y[~y.isnan()] - o[~y.isnan()])**2, because when you do that you prevent the nans in y from entering the computation in the first place.
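If you want that as a reusable loss, a minimal sketch (masked_mse is a hypothetical helper name, not from the original post) is to apply the mask before any operation whose backward depends on its forward input, such as pow:
def masked_mse(pred, target):
    # hypothetical helper: drop nan targets before the pow, so the nans
    # never enter the graph and the backward pass stays finite
    mask = ~target.isnan()
    return (target[mask] - pred[mask]).pow(2).mean()

pred = torch.rand(4, requires_grad=True)
target = torch.tensor([0.5, float('nan'), 0.25, float('nan')])
masked_mse(pred, target).backward()
print(pred.grad)  # finite everywhere; the masked entries simply get zero gradient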