I have a piece of code that accelerates text generation using `past_key_values`. The simplified version is as follows:

```python
prefix_output = model(prefix_input_ids)
generation_output = model.generate(postfix_input_ids, num_beams=1, use_cache=True, past_key_values=prefix_output.past_key_values)
```
Here the variable `model` is a `GPT2LMHeadModel` that has loaded `gpt2-xl`. The code works perfectly fine. The problem is that if `num_beams` is set to a value greater than 1, I get the exception below (in this example I set `num_beams` to 3):

> Sizes of tensors must match except in dimension 2. Expected size 1 but got size 3 for tensor number 1 in the list.

I suspect that I should somehow pre-process the values of `prefix_output.past_key_values` before passing them to `model.generate()`, but I am not sure. Does anybody know how to fix this? Thanks.
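From the error message, my guess is that beam search expands the batch dimension to `num_beams` while the cached tensors still have batch size 1, so the two cannot be concatenated. I wonder whether the fix is to repeat the cache along the batch dimension before passing it in. A minimal sketch of that idea, assuming `past_key_values` is a tuple of per-layer tuples of tensors whose first dimension is the batch (the helper name is my own, not part of the library):

```python
import torch

def expand_past_key_values(past_key_values, num_beams):
    # Hypothetical helper: repeat every cached key/value tensor along
    # the batch dimension (dim 0) so its batch size matches the
    # beam-expanded batch that beam search operates on.
    return tuple(
        tuple(t.repeat_interleave(num_beams, dim=0) for t in layer)
        for layer in past_key_values
    )

# e.g. pass expand_past_key_values(prefix_output.past_key_values, 3)
# to model.generate() instead of the raw cache.
```

I have not verified that this is the intended way to combine a precomputed cache with beam search, so I would appreciate confirmation or a better approach.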