I am building a neural network in PyTorch with an inner loop that recursively applies a series of matrix operations to a BxNxN matrix.
Below is a minimal working example demonstrating this inner loop, along with the mean gradient as a function of the number of recursions applied:
import torch
import torch.nn.functional as f
import matplotlib.pyplot as plt
import numpy as np

def Normalise(matrix):
    # L1-normalise each column (sum over dim=-2); epsilon avoids division by zero
    l1_norm = torch.sum(torch.abs(matrix), dim=-2, keepdim=True)
    epsilon = torch.finfo(matrix.dtype).eps
    l1_norm = l1_norm + epsilon
    norm = matrix / l1_norm
    return norm

def NormAndSquare(matrix):
    # Normalise, then square element-wise
    norm = Normalise(matrix)
    mp = torch.pow(norm, 2)
    return mp
B, N = 1, 10
logits = torch.randn(B, N, N, requires_grad=True)
y = torch.eye(N).unsqueeze(0).repeat(B, 1, 1)

# Initialize y_hat with a column-wise softmax over the logits
y_hat = f.softmax(logits, dim=1)

mean_gradients = []
for i in range(13):
    # Zero the gradients from the previous iteration
    if logits.grad is not None:
        logits.grad.zero_()
    # Apply NormAndSquare recursively
    y_hat = NormAndSquare(y_hat)
    # Compute loss against the identity target
    loss = f.mse_loss(y_hat, y)
    # Backward pass; retain the graph since y_hat is reused next iteration
    loss.backward(retain_graph=True)
    # Record the mean absolute gradient reaching the logits
    mean_gradients.append(torch.abs(logits.grad).mean().detach().item())
plt.figure()
plt.plot(np.arange(13), mean_gradients)
plt.yscale("log")
plt.ylim([1e-16, 0.1])
plt.grid(which="both")
plt.ylabel("Mean Gradient")
plt.xlabel("Iteration")
plt.show()
In this example, after around iteration 6 the gradients fed back to logits vanish rapidly with each successive iteration.
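My suspicion (not confirmed) is that repeatedly normalising and squaring drives each column of y_hat towards a one-hot vector, so the mapping saturates and each extra iteration multiplies the backward pass by ever smaller factors. A quick diagnostic I ran along those lines, reusing the functions above (this snippet is just my own check, not part of the model):

# Rough diagnostic: track how concentrated the columns of y_hat become.
# A mean column max approaching 1.0 means the columns are nearly one-hot,
# which would be consistent with the gradients saturating.
with torch.no_grad():
    z = f.softmax(logits, dim=1)
    for i in range(13):
        z = NormAndSquare(z)
        col_max = z.max(dim=-2).values.mean().item()
        print(f"iteration {i}: mean column max = {col_max:.4f}")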
What methods exist to handle this sort of situation? Are there ways I can keep an optimizer updated within the inner loop to avoid this effect?
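For concreteness, the kind of thing I have in mind is the rough sketch below, where a loss is taken and the optimizer is stepped at every inner iteration rather than only once at the end. The Adam optimizer and learning rate here are placeholders, not what I actually use, and I am not sure this is even a sensible way to structure it:

# Rough sketch of "updating within the inner loop": step the optimizer on a
# per-iteration loss so early iterations still receive a usable gradient.
# Optimizer choice and learning rate are placeholders.
optimizer = torch.optim.Adam([logits], lr=1e-3)

y_hat = f.softmax(logits, dim=1)
for i in range(13):
    optimizer.zero_grad()
    y_hat = NormAndSquare(y_hat)
    loss = f.mse_loss(y_hat, y)
    loss.backward(retain_graph=True)
    optimizer.step()
    # After stepping, y_hat still references the graph built from the old
    # logits values, so it is unclear to me whether it should be detached
    # or recomputed here.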