I’m working on a Vision Transformer using the PyTorch Lightning framework. I ran into an issue where the gradients and predictions end up as NaN after a few iterations of the training step.
After much trial and error, I discovered that adding these lines of preprocessing code (before the data is passed to the model) to set non-finite values to 0 fixed the issue:
non_fin = ~torch.isfinite(x)
x[non_fin] = 0
The model now trains and runs as expected.
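For reference, the same replacement can be sketched without modifying `x` in place, using `torch.nan_to_num`. This is only a minimal sketch: the helper name `zero_non_finite` and the sample tensor are made up for illustration and are not part of my actual pipeline.

```python
import torch


def zero_non_finite(x: torch.Tensor) -> torch.Tensor:
    """Return a copy of x with NaN, +inf and -inf replaced by 0.

    Hypothetical helper; unlike `x[non_fin] = 0`, it leaves the
    original tensor untouched.
    """
    return torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)


# Made-up sample input containing one of each non-finite value.
x = torch.tensor([1.0, float("nan"), float("inf"), float("-inf")])
cleaned = zero_non_finite(x)
```

One difference from the masked assignment is that the original tensor keeps its non-finite values, so any other code holding a reference to `x` still sees them.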
However, what still confuses me is that, earlier in my debugging process, I had added these lines to the forward pass of the model class to try to check for the exact same issue:
if not torch.all(torch.isfinite(x)):
    print("THERE ARE EITHER NANS OR INF IN THE INPUT ITSELF")
    raise Exception("THERE ARE NANS OR INFS IN THE INPUT X")
When doing this (and without the preprocessing step), the gradients and predictions still end up as NaN, but nothing is printed and no errors are thrown. This confuses me because, logically, my reasoning is:
preprocessing step solves the issue ->
at least some values in x are being replaced with 0 ->
the if statement should trigger in the forward pass (when the preprocessing isn’t added)
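To test that chain directly, one option is to count the non-finite values at both places (just before the preprocessing step and at the top of `forward`) and compare the numbers. The following is a minimal sketch under my own assumptions: `count_non_finite` is a hypothetical helper name, and the sample tensor is invented.

```python
import torch


def count_non_finite(x: torch.Tensor) -> dict:
    """Count NaN, +inf and -inf entries in x (hypothetical helper)."""
    return {
        "nan": int(torch.isnan(x).sum()),
        "posinf": int(torch.isposinf(x).sum()),
        "neginf": int(torch.isneginf(x).sum()),
    }


# Made-up sample: two NaNs, one +inf, no -inf.
sample = torch.tensor([[0.5, float("nan")], [float("inf"), float("nan")]])
counts = count_non_finite(sample)
```

If the counts differ between the two call sites, then the tensor reaching `forward` is not the same tensor the preprocessing step sees.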
Am I missing something? I can provide more details if necessary.