I’m trying to implement a training loop for a transformer, using an encoder/decoder pair. Based on my understanding of how transformers work, I was expecting that to generate my output sequence, I would have to iterate the decoder, feeding its output back as input on each step, and building up my output sequence one token at a time.
Following the notation used in the TensorFlow documentation, assume the output sequence has dimensions (B, T, dim), where B is the batch size, T is the target/output sequence length and dim is the size of the target/output embedding vector.
So to generate the complete output sequence, I would have to iterate the decoder T times. To prevent the decoder's self-attention mechanism from attending to tokens ahead of the current step t, I have to mask all sequence tokens from position t+1 onwards.
For example:
T = 4 (where 1 marks a masked position):
generate token at position 0: look_ahead_mask = [0, 1, 1, 1]
generate token at position 1: look_ahead_mask = [0, 0, 1, 1]
generate token at position 2: look_ahead_mask = [0, 0, 0, 1]
generate token at position 3: look_ahead_mask = [0, 0, 0, 0]
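To make that convention concrete, here is a rough sketch of how I picture building the per-step mask (step_mask is a hypothetical helper of my own, with 1 meaning "blocked", as in the example above):

```python
import tensorflow as tf

def step_mask(t, T):
    # 1 = position blocked from self-attention, i.e. everything after step t.
    return tf.cast(tf.range(T) > t, tf.float32)  # shape (T,)

for t in range(4):
    print(t, step_mask(t, 4).numpy())
# 0 [0. 1. 1. 1.]
# 1 [0. 0. 1. 1.]
# 2 [0. 0. 0. 1.]
# 3 [0. 0. 0. 0.]
```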
My confusion comes when looking at how the MultiHeadAttention layer implements the look-ahead mask. In online examples (e.g. Jason Brownlee's blog), the mask is passed as an array with dimensions (T, T), and it seems to represent all T masks that would be needed if generating the output sequence one token at a time.
For example:
T = 4
look_ahead_mask = [[0, 1, 1, 1],
                   [0, 0, 1, 1],
                   [0, 0, 0, 1],
                   [0, 0, 0, 0]]
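As far as I can tell, this matrix is produced by something like the create_look_ahead_mask helper from the TensorFlow transformer tutorial (reproduced here as I understand it):

```python
import tensorflow as tf

def create_look_ahead_mask(size):
    # Keep the lower triangle (incl. diagonal) of a ones matrix, then invert it,
    # so row t has ones at columns t+1 .. size-1.
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

print(create_look_ahead_mask(4).numpy())
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]
```

Row t of this matrix matches the per-step mask I wrote out above, which is what prompted my question.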
Does this mean that a single call to the decoder is internally looping through the entire output sequence? And if this is the case, how does it know when to stop generating any given sequence in the batch once an EOS token is generated? Do I not then need to manually loop over tokens in the output sequence?
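For reference, this is roughly the kind of inference loop I was expecting to have to write myself (a sketch only; decoder, encoder_output, start_id, eos_id and max_len are placeholder names, not part of any actual API):

```python
import tensorflow as tf

def greedy_decode(decoder, encoder_output, start_id, eos_id, max_len):
    # Start from the start-of-sequence token and grow the output one token per step.
    output = [start_id]
    for _ in range(max_len):
        dec_input = tf.constant([output])            # (1, current_length)
        logits = decoder(dec_input, encoder_output)  # assumed shape (1, current_length, vocab_size)
        next_id = int(tf.argmax(logits[0, -1]))      # greedy pick for the next position
        output.append(next_id)
        if next_id == eos_id:                        # stop this sequence once EOS is produced
            break
    return output
```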