I have a question about the beam search implementation in the OpenNMT library, specifically the file translate/beam.py. I am using the Beam class for a custom seq2seq LM. To reconstruct which beam and which word each log probability came from, the authors divide the indices returned by topk on the flattened scores by the num_words variable, which, according to the comment, is the second dimension of a tensor of size (K x words); I assume K is the beam size. At the very first step, however, when no previous backpointers have been computed (line 125), the tensor is indexed with word_probs[0], giving a tensor of size [words]. The topk indices are then divided by the vocab size in line 136 to compute prev_k, and the result is a float tensor, which cannot be used as indices in line 138.
Can somebody explain what the issue is here? What is the division by num_words supposed to achieve in the BOS case?
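Just to check that I am reading the code correctly: my understanding of what the division is meant to achieve on later steps (when the scores really are a K x words tensor) is roughly the following toy sketch, which is my own code, not the library's:

import torch

K, words = 3, 5                       # beam size, vocab size
scores = torch.randn(K, words)        # scores for advancing each of the K beams
flat_scores = scores.view(-1)         # flattened view of length K*words
best_scores, best_ids = flat_scores.topk(K)

prev_k = best_ids // words            # which beam each top score came from
next_y = best_ids % words             # which word inside that beam
print(prev_k, next_y)

So the floor division by the vocab size should recover the beam index, and the remainder the word index.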
I have created a minimal example with a simple LSTM encoder-decoder model to reproduce the error.
A “minimal” code example:
import tensorflow as tf
import torch
from typing import List
import random
import time
# text data
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
class Dataset(torch.utils.data.Dataset):
def __init__(self, text, seq_length):
super(Dataset, self).__init__()
self.text_as_token_sequence = text.split()
# extend tokenization with more complex
# preprocessing if you like
self.seq_length = seq_length
self.w2index = {"<PAD>":0, "<UNK>":1, "SOS":2, "EOS":3}
self.build_vocabulary(text)
self.index2w = {val:key for key, val in self.w2index.items()}
self.text_as_index_sequence = [self.w2index[x] for x in self.text_as_token_sequence]
def indices_to_words(self, index_sequence:torch.Tensor):
# returns sequences of words
print(index_sequence.dim())
if index_sequence.dim() == 1:
return [self.index2w[x.item()] for x in index_sequence]
elif index_sequence.size(dim=1) == 1:
return [self.index2w[x.item()] for x in index_sequence]
else:
return [[self.index2w[x.item()] for x in y] for y in index_sequence]
def words_to_indices(self, token_sequence:List[str]):
# returns sequences of indices
if type(token_sequence[0]) == str:
preliminary = [self.w2index[x] for x in token_sequence]
index_tensor = torch.tensor(preliminary)
return index_tensor
else:
preliminary = [[self.w2index[x] for x in y] for y in token_sequence]
index_tensor = torch.tensor(preliminary)
return index_tensor
def get_vocabulary_size(self):
return len(self.w2index.keys())
def build_vocabulary(self, text):
counter = 4
for token in self.text_as_token_sequence:
if token not in self.w2index.keys():
self.w2index[token] = counter
counter += 1
else:
continue
# the following functions are needed for compatibility
# with troch Datasets and Dataloaders
# returns the number of data points or more precisely (input, target)-tuples
def __len__(self):
return len(self.text_as_index_sequence) - (2*self.seq_length)
# This function returns sequences of text with respective continuations
# of the same length
# the first sequence serves as the prompt (conditioning sentence)
# while the second serves as the target sequence
def __getitem__(self, index):
return (
torch.tensor(self.text_as_index_sequence[index:index+self.seq_length]),
torch.tensor(self.text_as_index_sequence[index+self.seq_length:index+(2*self.seq_length)]),
)
class Beam(object):
"""
Class for managing the internals of the beam search process.
Takes care of beams, back pointers, and scores.
Args:
size (int): beam size
pad, bos, eos (int): indices of padding, beginning, and ending.
n_best (int): nbest size to use
cuda (bool): use gpu
global_scorer (:obj:`GlobalScorer`)
"""
def __init__(self, size, pad, bos, eos,
n_best=1, cuda=False,
global_scorer=None,
min_length=0,
stepwise_penalty=False,
block_ngram_repeat=0,
exclusion_tokens=set()):
self.size = size
self.tt = torch.cuda if cuda else torch
# The score for each translation on the beam.
self.scores = self.tt.FloatTensor(size).zero_()
self.all_scores = []
# The backpointers at each time-step.
self.prev_ks = []
# The outputs at each time-step.
self.next_ys = [self.tt.LongTensor(size)
.fill_(pad)]
self.next_ys[0][0] = bos
# Has EOS topped the beam yet.
self._eos = eos
self.eos_top = False
# The attentions (matrix) for each time.
self.attn = []
# Time and k pair for finished.
self.finished = []
self.n_best = n_best
# Information for global scoring.
self.global_scorer = global_scorer
self.global_state = {}
# Minimum prediction length
self.min_length = min_length
# Apply Penalty at every step
self.stepwise_penalty = stepwise_penalty
self.block_ngram_repeat = block_ngram_repeat
self.exclusion_tokens = exclusion_tokens
def get_current_state(self):
"Get the outputs for the current timestep."
return self.next_ys[-1]
def get_current_origin(self):
"Get the backpointers for the current timestep."
return self.prev_ks[-1]
def advance(self, word_probs, attn_out):
"""
Given prob over words for every last beam `wordLk` and attention
`attn_out`: Compute and update the beam search.
Parameters:
* `word_probs`- probs of advancing from the last step (K x words)
* `attn_out`- attention at the last step
Returns: True if beam search is complete.
"""
num_words = word_probs.size(1)
print("NUM_WORDS: ", num_words)
if self.stepwise_penalty:
self.global_scorer.update_score(self, attn_out)
# force the output to be longer than self.min_length
cur_len = len(self.next_ys)
if cur_len < self.min_length:
for k in range(len(word_probs)):
word_probs[k][self._eos] = -1e20
# Sum the previous scores.
if len(self.prev_ks) > 0:
beam_scores = word_probs + self.scores.unsqueeze(1).expand_as(word_probs)
# Don't let EOS have children.
for i in range(self.next_ys[-1].size(0)):
if self.next_ys[-1][i] == self._eos:
beam_scores[i] = -1e20
# Block ngram repeats
if self.block_ngram_repeat > 0:
ngrams = []
le = len(self.next_ys)
for j in range(self.next_ys[-1].size(0)):
hyp, _ = self.get_hyp(le - 1, j)
ngrams = set()
fail = False
gram = []
for i in range(le - 1):
# Last n tokens, n = block_ngram_repeat
gram = (gram +
[hyp[i].item()])[-self.block_ngram_repeat:]
# Skip the blocking if it is in the exclusion list
if set(gram) & self.exclusion_tokens:
continue
if tuple(gram) in ngrams:
fail = True
ngrams.add(tuple(gram))
if fail:
beam_scores[j] = -10e20
else:
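# first step: there are no backpointers yet, so only the scores of beam 0
# (the beam that starts from BOS) are kept; beam_scores is 1-D of size [words] here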
beam_scores = word_probs[0]
print("INSIDE BEAM: beam scores at sos", beam_scores.shape)
flat_beam_scores = beam_scores.view(-1)
print("INSIDE BEAM: flat beam scores at sos", flat_beam_scores.shape)
best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0,
True, True)
print("BEST_SCORES: ", best_scores, best_scores.shape)
print("BEST_SCORES_ID: ", best_scores_id, best_scores_id.shape)
self.all_scores.append(self.scores)
self.scores = best_scores
# best_scores_id is flattened beam x word array, so calculate which
# word and beam each score came from
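# note: on the first step beam_scores was only word_probs[0] (shape [words]),
# so every flat index is < num_words and prev_k presumably should come out as all zeros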
prev_k = best_scores_id / num_words
print("PREV_K: ", prev_k)
self.prev_ks.append(prev_k)
self.next_ys.append((best_scores_id - prev_k * num_words))
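# index_select below requires integer (Long) indices, so a float prev_k fails here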
self.attn.append(attn_out.index_select(0, prev_k))
self.global_scorer.update_global_state(self)
for i in range(self.next_ys[-1].size(0)):
if self.next_ys[-1][i] == self._eos:
global_scores = self.global_scorer.score(self, self.scores)
s = global_scores[i]
self.finished.append((s, len(self.next_ys) - 1, i))
# End condition is when top-of-beam is EOS and no global score.
if self.next_ys[-1][0] == self._eos:
self.all_scores.append(self.scores)
self.eos_top = True
def done(self):
return self.eos_top and len(self.finished) >= self.n_best
def sort_finished(self, minimum=None):
if minimum is not None:
i = 0
# Add from beam until we have minimum outputs.
while len(self.finished) < minimum:
global_scores = self.global_scorer.score(self, self.scores)
s = global_scores[i]
self.finished.append((s, len(self.next_ys) - 1, i))
i += 1
self.finished.sort(key=lambda a: -a[0])
scores = [sc for sc, _, _ in self.finished]
ks = [(t, k) for _, t, k in self.finished]
return scores, ks
def get_hyp(self, timestep, k):
"""
Walk back to construct the full hypothesis.
"""
hyp, attn = [], []
for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
hyp.append(self.next_ys[j + 1][k])
attn.append(self.attn[j][k])
k = self.prev_ks[j][k]
return hyp[::-1], torch.stack(attn[::-1])
# class GNMTGlobalScorer(object):
# """
# NMT re-ranking score from
# "Google's Neural Machine Translation System" :cite:`wu2016google`
# Args:
# alpha (float): length parameter
# beta (float): coverage parameter
# """
# def __init__(self, alpha, beta, cov_penalty, length_penalty):
# self.alpha = alpha
# self.beta = beta
# penalty_builder = penalties.PenaltyBuilder(cov_penalty,
# length_penalty)
# # Term will be subtracted from probability
# self.cov_penalty = penalty_builder.coverage_penalty()
# # Probability will be divided by this
# self.length_penalty = penalty_builder.length_penalty()
# def score(self, beam, logprobs):
# """
# Rescores a prediction based on penalty functions
# """
# normalized_probs = self.length_penalty(beam,
# logprobs,
# self.alpha)
# if not beam.stepwise_penalty:
# penalty = self.cov_penalty(beam,
# beam.global_state["coverage"],
# self.beta)
# normalized_probs -= penalty
# return normalized_probs
# def update_score(self, beam, attn):
# """
# Function to update scores of a Beam that is not finished
# """
# if "prev_penalty" in beam.global_state.keys():
# beam.scores.add_(beam.global_state["prev_penalty"])
# penalty = self.cov_penalty(beam,
# beam.global_state["coverage"] + attn,
# self.beta)
# beam.scores.sub_(penalty)
# def update_global_state(self, beam):
# "Keeps the coverage vector as sum of attentions"
# if len(beam.prev_ks) == 1:
# beam.global_state["prev_penalty"] = beam.scores.clone().fill_(0.0)
# beam.global_state["coverage"] = beam.attn[-1]
# self.cov_total = beam.attn[-1].sum(1)
# else:
# self.cov_total += torch.min(beam.attn[-1],
# beam.global_state['coverage']).sum(1)
# beam.global_state["coverage"] = beam.global_state["coverage"]
# .index_select(0, beam.prev_ks[-1]).add(beam.attn[-1])
# prev_penalty = self.cov_penalty(beam,
# beam.global_state["coverage"],
# self.beta)
# beam.global_state["prev_penalty"] = prev_penalty
class RNNEncoder(torch.nn.Module):
def __init__(self, batch_size, num_layers, d_input, d_embed, d_hidden, d_enc_vocab, device, dropout_prob=0.1):
super(RNNEncoder, self).__init__()
self.d_hidden = d_hidden
self.d_input = d_input
self.d_embed = d_embed
self.d_enc_vocab = d_enc_vocab
self.batch_size = batch_size
self.num_layers = num_layers
self.embedding_layer = torch.nn.Embedding(self.d_enc_vocab, self.d_embed)
self.lstm_layer = torch.nn.LSTM(d_embed,
d_hidden,
num_layers=self.num_layers,
bias=True,
batch_first=True,
dropout=0.0,
bidirectional=False)
# optionally, some dropout for preventing overfitting
self.dropout = torch.nn.Dropout(dropout_prob)
def forward(self, input, prev_state=False):
#print(input)
embedded_input = self.embedding_layer(input)
embedded_input = self.dropout(embedded_input)
# the prev state are h and c components of the lstm,
# initialized to zero vectors
# returns output matrix of vector representations of each
# input word
# also returns h and c of the last encodere step
if not prev_state:
output, (state_h, state_c) = self.lstm_layer(embedded_input) # 4P
else:
output, (state_h, state_c) = self.lstm_layer(embedded_input, prev_state)
return output, (state_h, state_c)
class Attention(torch.nn.Module):
def __init__(self, hidden_size):
super(Attention, self).__init__()
# three matrices needed for Bahdanau attention:
# W_a, U_a and V_a
self.Wa = torch.nn.Linear(hidden_size, hidden_size, bias=False)
self.Ua = torch.nn.Linear(hidden_size, hidden_size, bias=False)
self.Va = torch.nn.Linear(hidden_size, 1, bias=False)
def forward(self, s_tneg1, h_t):
# s: [batch_size x num_layer x d_hidden]
# h: [batch_size x d_input x d_hidden]
scores = self.Va(torch.tanh(self.Wa(s_tneg1) + self.Ua(h_t)))
scores = scores.squeeze(2).unsqueeze(1)
weights = torch.nn.functional.softmax(scores, dim=-1)
context = torch.bmm(weights, h_t)
return context, weights
class RNNDecoder(torch.nn.Module):
def __init__(self, d_embed, d_hidden, d_dec_vocab, device, dropout_prob=0.1):
super(RNNDecoder, self).__init__()
self.dec_embedding_layer = torch.nn.Embedding(d_dec_vocab, d_embed)  # use the decoder vocab size passed in, not the global d_enc_vocab
self.attention_layer = Attention(d_hidden)
self.dec_lstm_layer = torch.nn.LSTM(2*d_embed,
d_hidden,
num_layers=1,
bias=True,
batch_first=True,
dropout=0.0,
bidirectional=False)
self.output_layer = torch.nn.Linear(d_hidden, d_dec_vocab)
self.dropout = torch.nn.Dropout(dropout_prob)
def forward(self, encoder_outputs, encoder_hidden_state, max_output_length, target_tensor=None):
batch_size = encoder_outputs.size(0)
decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(2)
decoder_hidden_state = encoder_hidden_state
decoder_outputs = []
attention_weights = []
teacher_forcing_ratio = 0.5
use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
for i in range(max_output_length):
# compute one iteration step
decoder_output, decoder_hidden_state, attn_weights = self.forward_step(
decoder_input, decoder_hidden_state, encoder_outputs
)
# collect decoder output and attention weights for this step
decoder_outputs.append(decoder_output)
attention_weights.append(attn_weights)
if target_tensor is not None and use_teacher_forcing:
decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
else:
_, topi = decoder_output.topk(1)
decoder_input = topi.squeeze(-1).detach() # detach from history as input
decoder_outputs = torch.cat(decoder_outputs, dim=1)
decoder_outputs = torch.nn.functional.log_softmax(decoder_outputs, dim=-1)
attention_weights = torch.cat(attention_weights, dim=1)
return decoder_outputs, decoder_hidden_state, attention_weights
def forward_step(self, input_t, decoder_hidden_state, encoder_outputs):
embedded_dec_input = self.dec_embedding_layer(input_t)
embedded_dec_input = self.dropout(embedded_dec_input)
decoder_h, decoder_c = decoder_hidden_state
s_tneg1 = decoder_h.permute(1, 0, 2)
context, attn_weights = self.attention_layer(s_tneg1, encoder_outputs) # 1P
lstm_input = torch.cat((embedded_dec_input, context), dim=2) # 1P
decoder_output, (decoder_h, decoder_c) = self.dec_lstm_layer(lstm_input, (decoder_h, decoder_c))
output = self.output_layer(decoder_output) # 1P
return output, (decoder_h, decoder_c), attn_weights
class RNNEncoderDecoder(torch.nn.Module):
def __init__(self, batch_size, num_layers, d_embed, d_hidden, d_input, d_enc_vocab, d_dec_vocab, device, dropout_prob=0.1):
super(RNNEncoderDecoder, self).__init__()
self.encoder = RNNEncoder(batch_size, num_layers, d_input, d_embed, d_hidden, d_enc_vocab, device, dropout_prob=0.1)
self.decoder = RNNDecoder(d_embed, d_hidden, d_dec_vocab, device, dropout_prob=0.1)
def forward(self, x, max_output_length):
encoder_out, (encoder_h, encoder_c) = self.encoder(x)
dec_out, dec_hidden_state, weights = self.decoder(encoder_out, (encoder_h, encoder_c), max_output_length)
return dec_out, dec_hidden_state, weights
class RNNGenerator(torch.nn.Module):
def __init__(self,
model,
dataset,
n_best = 1,
max_length = 100,
min_length = 0,
global_scorer = None,
copy_attn = False,
logger = None,
dump_beam = "",
beam_size = 3,
stepwise_penalty = False,
block_ngram_repeat = 0,
ignore_when_blocking = [],
sample_rate = 16000,
window_size = .02,
window_stride = .01,
window = 'hamming',
cuda = False,
):
super(RNNGenerator, self).__init__()
self.model = model
self.n_best = n_best
self.max_length = max_length
self.global_scorer = global_scorer
self.copy_attn = copy_attn
self.beam_size = beam_size
self.min_length = min_length
self.stepwise_penalty = stepwise_penalty
self.dump_beam = dump_beam
self.block_ngram_repeat = block_ngram_repeat
self.ignore_when_blocking = set(ignore_when_blocking)
self.sample_rate = sample_rate
self.window_size = window_size
self.window_stride = window_stride
self.window = window
self.cuda = cuda
def beam_decode(self, batch, data, max_output_length):
beam_size = self.beam_size
batch_size = batch.size(0)
print("beam decode batch size: ", batch_size)
#data_type = data.data_type
vocab = data.w2index
exclusion_tokens = set([vocab[t] for t in self.ignore_when_blocking])  # vocab is already the w2index dict
beam = [Beam(beam_size, n_best=self.n_best,
cuda=self.cuda,
global_scorer=self.global_scorer,
pad=0,
eos=3,
bos=2,
min_length=self.min_length,
stepwise_penalty=self.stepwise_penalty,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=exclusion_tokens)
for __ in range(batch_size)]
def var(a):
return torch.tensor(a, requires_grad=False)
def rvar(a):
return var(a.repeat(1, beam_size, 1))
def bottle(m):
return m.view(batch_size * beam_size, -1)
def unbottle(m):
return m.view(beam_size, batch_size, -1)
def _repeat_beam_size_times(x, dim):
repeats = [1] * x.dim()
repeats[dim] = beam_size
return x.repeat(*repeats)
# Run encoder to produce output representation on the input batch
# for generating text
encoder_out, memory_bank = self.model.encoder(batch)
print("INPUT SIZE: ", batch.size(1))
src_lengths = torch.LongTensor((batch.size(1),)).fill_(batch.size(1))
print(src_lengths)
# repeat the encoder output *beam_size* times in order to compute all beams simultaneously
if isinstance(memory_bank, tuple):
memory_bank = tuple(rvar(x.data) for x in memory_bank)
else:
memory_bank = rvar(memory_bank.data)
memory_lengths = src_lengths.repeat(beam_size)
encoder_out = encoder_out.repeat(beam_size, 1, 1)
for i in range(self.max_length):
if all((b.done() for b in beam)):
break
inp = var(torch.stack([b.get_current_state() for b in beam])
.t().contiguous().view(1, -1))
inp = inp.unsqueeze(2).squeeze(0)
print("beam decoder input: ", inp.shape)
print("beam memory bank: ", memory_bank[0].shape, memory_bank[1].shape)
print("beam decoder encoder_outputs: ", encoder_out.shape)
dec_out, dec_hidden_state, attn_weights = self.model.decoder.forward_step(inp, memory_bank, encoder_out)
print("beam dec_out: ", dec_out.shape)
print("beam dec hidden state: ", dec_hidden_state[0].shape, dec_hidden_state[1].shape)
print("beam attn: ", attn_weights.shape)
dec_out = dec_out.squeeze(0)
probs = torch.nn.functional.log_softmax(dec_out, dim=-1)
out = unbottle(probs)
# beam x tgt_vocab
beam_attn = unbottle(attn_weights)
print("beam final probs: ", out.shape)
print("beam attn head: ", beam_attn.shape)
# (c) Advance each beam.
select_indices_array = []
for j, b in enumerate(beam):
print("beamwise prob ", j, out[:, j].shape)
b.advance(out[:, j],
attn_weights.data[:, j, :memory_lengths[j]])
select_indices_array.append(
b.get_current_origin() * batch_size + j)
select_indices = (torch.cat(select_indices_array)
.view(batch_size, beam_size)
.transpose(0, 1)
.contiguous()
.view(-1))
self.model.decoder.map_state(
lambda state, dim: state.index_select(dim, select_indices))
max_sequence_length = 4
batch_size = 64
d_hidden = 128
d_input = 4
dataset = Dataset(text, max_sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=True)
d_embed = 128
d_enc_vocab = dataset.get_vocabulary_size()
d_dec_vocab = dataset.get_vocabulary_size()
num_layers = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rnn_encdec = RNNEncoderDecoder(batch_size, num_layers, d_embed, d_hidden, d_input, d_enc_vocab, d_dec_vocab, device, dropout_prob=0.1)
rnn_encdec.to(device)
# training loop omitted here; untrained weights are enough to reproduce the beam search issue
beam_gen = RNNGenerator(
rnn_encdec,
dataset,
n_best = 1,
max_length = 100,
min_length = 0,
global_scorer = None,
copy_attn = False,
logger = None,
dump_beam = "",
beam_size = 4,
stepwise_penalty = False,
block_ngram_repeat = 0,
ignore_when_blocking = [],
sample_rate = 16000,
window_size = .02,
window_stride = .01,
window = 'hamming',
cuda=False,
)
# fake some input word indices with batch_size = 3 and input_size = 4
x = torch.randint(0, 100, (3, d_input))
beam_gen.beam_decode(x, dataset, max_sequence_length)
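For what it's worth, the behaviour I suspect is the culprit can be reproduced in isolation (assuming a recent PyTorch, where / on integer tensors performs true division; the variable names below are just made up for illustration):

import torch

num_words = 7                                 # pretend vocab size
best_scores_id = torch.tensor([5, 12, 20])    # flat indices into a (K x words) layout

prev_k_float = best_scores_id / num_words     # true division -> float tensor
prev_k_floor = best_scores_id // num_words    # floor division -> tensor([0, 1, 2]), usable as indices
print(prev_k_float, prev_k_floor)

Is the fix simply to use floor division (or torch.div with rounding_mode='floor') here, or am I missing what num_words is supposed to do at the BOS step?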