I’ve been trying to implement PyTorch’s nn.TransformerEncoder and nn.TransformerDecoder solutions into a simple model, but I’m running into an issue that I’m unable to resolve where during inference the model only produces the last token fed into it.
For example lets say I have a tensor [1,2,3,4,5] the model will continue the sequence with [1,2,3,4,5,5,5,5,5,5,…] or if I had [5,2,8,3] it would continue to produce [5,2,8,3,3,3,3,3,3,3,…] even when using training data as input although when using a new randomly initialized model it will produce diverse output although since not trained is useless.
Although it produces the above results, the loss continues to decrease as I train it further indicating that its managing to learn the dataset. Due to this I initially thought this was just a problem with the dataset where the target was the same as the input which would cause it to produce the same tokens, but after further testing I’m sure that the targets are definitely the next token in the sequence, for example the input would be [1,2,3,4] and the target would be [2,3,4,5].
This lead me to my current standing theory that there is something wrong with the seq2seq implementation but after much research and trying different implementations of the common components such as positional encoding, adjusting hyper-parameters and removing / adding masks to the encoder and decoder, but regardless still weeks later and I’m still zero progress towards identifying the issue.
For reference here is the model and training step I’m using:
class TextEmbedding(nn.Module):
def __init__(self, vocab_size: int, embed_dim: int, padding_index: int):
super(TextEmbedding, self).__init__()
self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=padding_index)
def forward(self, x):
return self.embedding(x)
class TextTransformer(nn.Module):
def __init__(self, vocab_size, embed_dim = 512, nhead = 8, num_encoder_layers = 6, num_decoder_layers = 6, max_length = 5000, padding_index = 0):
super(TextTransformer, self).__init__()
self.vocab_size = vocab_size
self.max_length = max_length
self.text_embedding = TextEmbedding(vocab_size, embed_dim, padding_index)
self.positional_encoding = nn.Parameter(torch.zeros(1, max_length, embed_dim))
encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=2048)
self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_encoder_layers)
decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=2048)
self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=num_decoder_layers)
self.fc = nn.Sequential(
nn.Linear(embed_dim, vocab_size)
)
def forward(self, src, tgt, src_mask, tgt_mask):
#Embedding + Positional Encoding
src_embedding = self.text_embedding(src) + self.positional_encoding[:, :src.size(1), :]
tgt_embedding = self.text_embedding(tgt) + self.positional_encoding[:, :tgt.size(1), :]
tgt_square_mask = create_square_mask(tgt.size(1)).to(src.device)
#Encoder
memory = self.encoder(src_embedding.permute(1, 0, 2), src_key_padding_mask=src_mask)
#Decoder
decoder_out = self.decoder(tgt_embedding.permute(1, 0, 2), memory, tgt_mask=tgt_square_mask, tgt_key_padding_mask=tgt_mask)
decoder_out = decoder_out.permute(1, 0, 2)
#FC output
output = self.fc(decoder_out)
return output
def seq2seq(self, src, src_mask, stop_token, max_length = 500):
src_embedding = self.text_embedding(src) + self.positional_encoding[:, :src.size(1), :]
memory = self.encoder(src_embedding.permute(1, 0, 2), src_key_padding_mask=src_mask)
sequence = src
stop = False
while sequence.shape[1] < min(self.max_length, max_length) and not stop:
tgt_embedding = self.text_embedding(sequence) + self.positional_encoding[:, :sequence.size(1), :]
tgt_square_mask = create_square_mask(sequence.size(1)).to(src.device)
dec_output = self.decoder(tgt_embedding.permute(1, 0, 2), memory, tgt_mask=tgt_square_mask)
dec_output = dec_output.permute(1, 0, 2)
out = self.fc(dec_output)[:, -1, :]
predicted = out.argmax(dim=1)
if predicted.item() == stop_token:
stop = True
sequence = torch.cat((sequence, predicted.unsqueeze(dim=0)),dim=1)
return sequence
def create_square_mask(size):
mask = torch.triu(torch.ones(size, size), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, float(0.0))
return mask
def train_step(model, dataloader, criterion, optimizer, device):
avg_loss = 0
model.train()
for batch, (text_data, text_pad_mask) in enumerate(dataloader):
text_data, text_pad_mask = text_data.to(device), text_pad_mask.to(device)
#shift data so that the in_text is the initial tokens and that tgt_text is the next predicted token in the sequence
in_text = text_data[:, :-1]
in_mask = text_pad_mask[:, :-1]
tgt_text = text_data[:, 1:]
tgt_mask = text_pad_mask[:, 1:]
out = model(in_text, tgt_text, in_mask, tgt_mask)
outputs = out[:, :].reshape(-1, model.vocab_size)# Reshape to [batch_size * steps, vocab_size]
targets = tgt_text[:, :].reshape(-1)# Reshape to [batch_size * steps]
loss = criterion(outputs, targets)
avg_loss += loss.item()
loss.backward()
optimizer.step()
optimizer.zero_grad()
return avg_loss / len(dataloader)
The loss function is CrossEntropyLoss, the optimizer is AdamW and the dataloader returns tokenized texts in the shape of (batch, sequence). I think this is all that is necessary to try diagnose the issue as I’m 100% sure the tokenizer and data loader is working perfectly as I’ve done a lot of testing on them and don’t want to flood this post with too much code but I can provide the code for them upon request if it helps at all.
Thanks for your time.
Fox is a new contributor to this site. Take care in asking for clarification, commenting, and answering.
Check out our Code of Conduct.