In the official implementation of SiFT in Microsoft/DeBERTa here, the symmetric_kl function is implemented as:
import torch
import torch.nn.functional as F

def symmetric_kl(logits, target):
    # Flatten to (N, num_classes) and cast to float32
    logit_stu = logits.view(-1, logits.size(-1)).float()
    logit_tea = target.view(-1, target.size(-1)).float()
    logprob_stu = F.log_softmax(logit_stu, -1)
    logprob_tea = F.log_softmax(logit_tea, -1)
    # Probabilities are detached from the computation graph
    prob_tea = logprob_tea.exp().detach()
    prob_stu = logprob_stu.exp().detach()
    floss = (prob_tea * (-logprob_stu)).sum(-1)  # cross entropy (teacher -> student)
    bloss = (prob_stu * (-logprob_tea)).sum(-1)  # cross entropy (student -> teacher)
    loss = floss + bloss
    return loss
My question is: why is detach() needed when computing prob_tea and prob_stu?
I checked, and its presence does change the gradients that are computed.
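For reference, this is roughly how I compared the two variants (the no-detach helper and the random inputs below are just mine for the comparison, and it assumes the symmetric_kl function above is defined):

def symmetric_kl_no_detach(logits, target):
    # Same computation as above, but without detaching the probabilities
    logprob_stu = F.log_softmax(logits.view(-1, logits.size(-1)).float(), -1)
    logprob_tea = F.log_softmax(target.view(-1, target.size(-1)).float(), -1)
    prob_tea = logprob_tea.exp()
    prob_stu = logprob_stu.exp()
    return (prob_tea * (-logprob_stu)).sum(-1) + (prob_stu * (-logprob_tea)).sum(-1)

logits = torch.randn(4, 10, requires_grad=True)
target = torch.randn(4, 10, requires_grad=True)

# Gradient of the student logits with detach()
symmetric_kl(logits, target).sum().backward()
grad_with_detach = logits.grad.clone()

# Gradient of the student logits without detach()
logits.grad = None
target.grad = None
symmetric_kl_no_detach(logits, target).sum().backward()
grad_without_detach = logits.grad.clone()

# The two gradients do not match
print(torch.allclose(grad_with_detach, grad_without_detach))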