I want to understand the purpose of the detach() method in PyTorch. Below is an example. If you look at the update_delta method, you'll see detach() being used. I don't understand what the author is trying to achieve by calling detach() there.
import torch


class PerturbationLayer(torch.nn.Module):
    def __init__(self, hidden_size, learning_rate=1e-4, init_perturbation=1e-2):
        super().__init__()
        self.learning_rate = learning_rate
        self.init_perturbation = init_perturbation
        self.delta = None
        self.LayerNorm = torch.nn.LayerNorm(hidden_size, 1e-7, elementwise_affine=False)
        self.adversarial_mode = False

    def adversarial_(self, adversarial=True):
        self.adversarial_mode = adversarial
        if not adversarial:
            self.delta = None

    def forward(self, input):
        if not self.adversarial_mode:
            self.input = self.LayerNorm(input)
            return self.input
        else:
            if self.delta is None:
                self.update_delta(requires_grad=True)
            return self.perturbated_input

    def update_delta(self, requires_grad=False):
        if not self.adversarial_mode:
            return True
        if self.delta is None:
            # First call: initialize delta with a small random perturbation.
            delta = torch.clamp(
                self.input.new(self.input.size())
                .normal_(0, self.init_perturbation)
                .float(),
                -2 * self.init_perturbation,
                2 * self.init_perturbation,
            )
        else:
            # Later calls: step delta along its accumulated gradient,
            # normalized by the max absolute gradient value.
            grad = self.delta.grad
            self.delta.grad = None
            delta = self.delta
            norm = grad.norm()
            if torch.isnan(norm) or torch.isinf(norm):
                return False
            eps = self.learning_rate
            with torch.no_grad():
                delta = delta + eps * grad / (
                    1e-6 + grad.abs().max(-1, keepdim=True)[0]
                )
        self.delta = delta.float().detach().requires_grad_(requires_grad)
        self.perturbated_input = (self.input.to(delta).detach() + self.delta).to(
            self.input
        )
        return True
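For context, here is my current understanding of detach() in isolation, as a small standalone snippet (the tensors x and delta below are just toy examples I made up, unrelated to the class above):

import torch

# detach() returns a tensor that shares data with x but is cut out of
# the autograd graph, so it does not track gradients.
x = torch.randn(3, requires_grad=True)
y = x.detach()
print(x.requires_grad, y.requires_grad)  # True False

# The pattern from update_delta: detach() followed by requires_grad_()
# produces a fresh leaf tensor that starts its own gradient history.
delta = (x * 2).detach().requires_grad_(True)
loss = (delta ** 2).sum()
loss.backward()
print(delta.grad)  # gradient w.r.t. delta is populated
print(x.grad)      # None: backward did not reach x through the detached delta

So I can see that detach() cuts a tensor out of the autograd graph, but I'm still not sure why that matters when updating delta in the code above.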
Questions:
What does the detach method do in this context?
Why is it necessary to use detach when updating delta?
Any insights would be appreciated!