Gradient accumulation: how the loss is computed
Suppose we have data of shape [b, s, dim]. I recently noticed that CrossEntropyLoss (1) computes the average over all tokens (b * s) in a batch, instead of (2) computing the loss for each sentence and then averaging over sentences.
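To make the difference between (1) and (2) concrete, here is a minimal pure-Python sketch. It uses made-up per-token loss values and a hypothetical padded batch, with `None` standing in for positions that CrossEntropyLoss would skip via `ignore_index`. The function names are illustrative and not part of any library.

```python
# Hypothetical per-token losses for b=2 sentences, s=4 positions each.
# None marks a padded position (one that ignore_index would skip).

def token_level_mean(losses):
    """(1) Average over all non-padded tokens in the batch at once."""
    valid = [l for sent in losses for l in sent if l is not None]
    return sum(valid) / len(valid)

def sentence_level_mean(losses):
    """(2) Average within each sentence first, then average over sentences."""
    per_sentence = []
    for sent in losses:
        valid = [l for l in sent if l is not None]
        per_sentence.append(sum(valid) / len(valid))
    return sum(per_sentence) / len(per_sentence)

batch = [
    [1.0, 2.0, 3.0, 2.0],     # 4 real tokens
    [4.0, None, None, None],  # 1 real token, 3 padded
]

print(token_level_mean(batch))     # (1+2+3+2+4)/5 = 2.4
print(sentence_level_mean(batch))  # (2.0 + 4.0)/2 = 3.0
```

With equal-length, unpadded sentences the two schemes give the same value, because every sentence then contributes the same number of tokens under (1). The same effect shows up in gradient accumulation: averaging each micro-batch's token-level loss and dividing by the number of accumulation steps weights micro-batches equally, which behaves like (2) rather than like the token-level mean of one large batch whenever micro-batches contain different numbers of tokens.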
Here is the code that computes the loss in Hugging Face transformers' BertLMHeadModel: