I was wondering why the gradient in the scaled dot product attention example below does not flow to the key and value tensors. What am I doing wrong? More generally, how can I use padded batches with different target sequence lengths? Can I pad the keys/values and the queries together?
import torch
from torch.nn.functional import scaled_dot_product_attention
# Keys/values: batch 3, source length 4, embedding dim 8 (note that k and v refer to the same tensor here).
k = v = torch.rand(3, 4, 8)
# Queries: batch 3, target length 5, embedding dim 8.
q = torch.rand(3, 5, 8)
q.requires_grad = True
k.requires_grad = True
v.requires_grad = True
# Boolean attention mask of shape (batch, target_len, source_len): True = attend, False = masked out.
mask = torch.ones(3, 5, 4, dtype=torch.bool)
mask[:, :, -1] = False  # last key/value position is padding
mask[:, -1, :] = False  # last query position is padding, so its whole row is masked
out = scaled_dot_product_attention(q, k, v, attn_mask=mask)
# Exclude the padded target position from the loss before backpropagating.
torch.mean(out[:, :-1, :]).backward()
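
To show what I mean, the gradients can be inspected right after backward() (the print calls are only for illustration, using the variable names above):

# Check which inputs actually received a gradient.
print(q.grad)
print(k.grad)
print(v.grad)  # k and v are the same tensor above, so k.grad and v.grad are identical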