KL divergence loss too high
I'm trying to perform knowledge distillation. For my student loss I use cross-entropy loss, and for my knowledge-distillation loss I'm trying to use KL divergence loss.
Here is the code that I used for my KL divergence loss.
import torch
import torch.nn as nn
import torch.nn.functional as F

class KLDivLoss(nn.Module):
    def __init__(self, ignore_index=-1, reduction="batchmean", log_target=False):
        super(KLDivLoss, self).__init__()
        self.reduction = reduction
        self.log_target = log_target
        self.ignore_index = ignore_index

    def forward(self, preds_S, preds_T, T=1.0, alpha=1.0):
        preds_T[0] = preds_T[0].detach()          # detach teacher predictions from the graph
        pred_1 = torch.sigmoid(preds_T[0] / T)    # teacher probability of class "white"
        pred_0 = 1 - pred_1                       # teacher probability of class "black"
        preds_teacher = torch.cat((pred_0, pred_1), dim=1)
        assert preds_S[0].shape == preds_teacher.shape, \
            "Input and target shapes must match for KLDivLoss"
        stu_prob = F.log_softmax(preds_S[0] / T, dim=1)
        kd_loss = F.kl_div(stu_prob,
                           preds_teacher,
                           reduction='batchmean',
                           ) * T * T
        return {'loss': kd_loss}
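For context, this is roughly how I call it. The tensor shapes match my real data, but the values below are random and the variable names (student_out, teacher_out) are just placeholders; the list wrapping is there because forward() indexes preds_S[0] / preds_T[0]:

import torch

student_out = torch.randn(8, 2, 224, 224)   # student logits: channel 0 = black, channel 1 = white
teacher_out = torch.randn(8, 1, 224, 224)   # teacher logits: single channel for "white"

criterion = KLDivLoss()
kd = criterion([student_out], [teacher_out], T=1.0)
print(kd['loss'])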
The values I am getting from this are extremely large. My total loss is simply the sum of the knowledge-distillation loss and the student's cross-entropy loss, and since my CE loss is very small, almost all of the total comes from the KL-divergence term. Could you tell me how to reduce the loss, or whether I am doing something wrong?
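A rough sketch of how I combine the two terms (the cross-entropy setup and the dummy tensors here are simplified stand-ins for my actual training loop):

import torch
import torch.nn as nn

ce_criterion = nn.CrossEntropyLoss()
kd_criterion = KLDivLoss()

# dummy stand-ins for one training batch
student_out = torch.randn(8, 2, 224, 224)      # student logits
teacher_out = torch.randn(8, 1, 224, 224)      # teacher logits
labels = torch.randint(0, 2, (8, 224, 224))    # per-pixel ground-truth classes (0 = black, 1 = white)

ce_loss = ce_criterion(student_out, labels)
kd_loss = kd_criterion([student_out], [teacher_out], T=1.0)['loss']

total_loss = ce_loss + kd_loss                 # plain sum, no weighting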
I tried using the KL divergence loss with temperature T = 1.
My teacher model outputs a tensor of shape [8, 1, 224, 224], since it was trained for binary per-pixel prediction, while my student model outputs a tensor of shape [8, 2, 224, 224], where channel 0 corresponds to the class black and channel 1 to white.
So, in order to match them up for the KL divergence loss, I applied a sigmoid to the teacher output to get the probability of class white, used 1 minus that probability for black, and concatenated the two to form a tensor of shape [8, 2, 224, 224], matching the student tensor.
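Here is a minimal, self-contained snippet showing just that conversion on random dummy tensors (teacher_logits / teacher_probs are illustrative names, not from my actual code):

import torch

teacher_logits = torch.randn(8, 1, 224, 224)     # dummy single-channel teacher logits
T = 1.0

pred_white = torch.sigmoid(teacher_logits / T)   # probability of class "white"
pred_black = 1 - pred_white                      # probability of class "black"
teacher_probs = torch.cat((pred_black, pred_white), dim=1)

print(teacher_probs.shape)                                                 # torch.Size([8, 2, 224, 224])
print(torch.allclose(teacher_probs.sum(dim=1), torch.ones(8, 224, 224)))   # True: channels sum to 1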
I then computed the KL divergence as shown above, and the losses I got were extremely high.