Model forward pass for generation & GPT2 fix

Guillaume B 2020-03-10 18:50:22 +01:00
parent 8f044b5bcc
commit c7727b0b97
3 changed files with 62 additions and 7 deletions

View File

@@ -34,8 +34,8 @@ fn main() -> failure::Fallible<()> {
     let model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path, device)?;
     let input_context = "The dog";
-    let output = model.generate(Some(input_context), 40, true, 5, 1.0,
-                                50, 1.0, 1.0, 1.0, 3);
+    let output = model.generate(Some(input_context), 40, true, 1, 1.0,
+                                50, 1.0, 1.0, 1.0, 1);
-    println!("{:?}", output);
+    output.print();
     Ok(())
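For readability, the positional arguments of the updated call can be annotated against the generate signature that appears later in this diff. The names through repetition_penalty come from that signature; the ninth argument is assumed to be a length penalty, since the signature is truncated here:

    let output = model.generate(
        Some(input_context), // prompt_text
        40,                  // max_length
        true,                // do_sample
        1,                   // num_beams (was 5; beam search is not exercised yet)
        1.0,                 // temperature
        50,                  // top_k
        1.0,                 // top_p
        1.0,                 // repetition_penalty
        1.0,                 // assumed: length penalty (not visible in the truncated signature)
        1,                   // num_return_sequences (was 3)
    );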

View File

@@ -141,7 +141,7 @@ impl Attention {
         let (key, value) = match layer_past {
             Some(past) => {
                 let key = Tensor::cat(&[past.get(0).transpose(-2, -1), key], -1);
-                let value = Tensor::cat(&[past.get(1), value], -1);
+                let value = Tensor::cat(&[past.get(1), value], -2);
                 (key, value)
             }
             None => (key, value)
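The one-character fix above matters because the cached key and value do not share a layout: keys are cached transposed (sequence axis last, ready for the QK^T matmul, as the transpose(-2, -1) on past.get(0) shows), while values keep the sequence axis at position -2. A minimal shape check with tch illustrating why each concatenation uses a different axis; the batch, head and head-dimension sizes below are illustrative GPT-2 values, not taken from this commit:

    use tch::{Device, Kind, Tensor};

    fn main() {
        // Cached key: (batch, heads, head_dim, past_seq); cached value: (batch, heads, past_seq, head_dim)
        let past_key = Tensor::zeros(&[1, 12, 64, 5], (Kind::Float, Device::Cpu));
        let past_value = Tensor::zeros(&[1, 12, 5, 64], (Kind::Float, Device::Cpu));
        // Key/value for one newly generated token
        let key = Tensor::zeros(&[1, 12, 64, 1], (Kind::Float, Device::Cpu));
        let value = Tensor::zeros(&[1, 12, 1, 64], (Kind::Float, Device::Cpu));

        let key = Tensor::cat(&[past_key, key], -1);       // sequence axis is last for keys
        let value = Tensor::cat(&[past_value, value], -2); // sequence axis is -2 for values
        assert_eq!(key.size(), vec![1, 12, 64, 6]);
        assert_eq!(value.size(), vec![1, 12, 6, 64]);
    }

Storing keys pre-transposed avoids a transpose inside every attention matmul during incremental decoding, at the cost of the asymmetric concatenation axes fixed here.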

View File

@@ -90,6 +90,14 @@ impl LanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT2Generator {
     fn get_bos_id(&self) -> &i64 { &self.bos_token_id }
     fn get_eos_ids(&self) -> &Vec<i64> { &self.eos_token_ids }
     fn get_pad_id(&self) -> &i64 { &self.pad_token_id }
+    fn prepare_inputs_for_generation(&self, input_ids: Tensor, past: Option<Vec<Tensor>>) -> (Tensor, Option<Vec<Tensor>>) {
+        if past.is_some() {
+            (input_ids.slice(1, input_ids.size()[1] - 1, input_ids.size()[1], 1).unsqueeze(-1), past)
+        } else {
+            (input_ids, past)
+        }
+    }
 }
 pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
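The GPT2-specific override above exists because, once a past cache is available, only the most recent token needs to go through the model; earlier positions are already summarised in the cached keys and values. A small sketch of what the slice selects, with illustrative shapes (batch of 2, sequence length 4):

    use tch::Tensor;

    fn main() {
        let input_ids = Tensor::of_slice(&[0i64, 1, 2, 3, 4, 5, 6, 7]).view((2, 4));
        let seq_len = input_ids.size()[1];
        // Keep only the last position of each sequence: (2, 4) -> (2, 1)
        let last_tokens = input_ids.slice(1, seq_len - 1, seq_len, 1);
        assert_eq!(last_tokens.size(), vec![2, 1]);
        // The override above additionally calls unsqueeze(-1) on this result.
    }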
@@ -100,6 +108,10 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
     fn get_eos_ids(&self) -> &Vec<i64>;
     fn get_pad_id(&self) -> &i64;
+    fn prepare_inputs_for_generation(&self, input_ids: Tensor, past: Option<Vec<Tensor>>) -> (Tensor, Option<Vec<Tensor>>) {
+        (input_ids, past)
+    }
     fn encode_prompt_text(&self, prompt_text: &str, max_len: u64) -> Tensor {
         let token_ids = self.get_tokenizer().convert_tokens_to_ids(&self.get_tokenizer().tokenize(prompt_text));
         let num_truncated_tokens = if token_ids.len() > max_len as usize { token_ids.len() - max_len as usize } else { 0 };
@@ -109,6 +121,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
                               &TruncationStrategy::LongestFirst,
                               0).unwrap();
         Tensor::of_slice(&token_ids).unsqueeze(0).to(self.get_var_store().device())
     }
+    fn enforce_repetition_penalty(&self, next_token_logits: Tensor, batch_size: i64, num_beams: u64, prev_output_tokens: &Tensor, repetition_penalty: f64) {
+    }
     fn generate(&self, prompt_text: Option<&str>, max_length: u64, do_sample: bool, num_beams: u64, temperature: f64,
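enforce_repetition_penalty is only stubbed out at this stage. For orientation, a hedged sketch of the CTRL-style rule such a hook typically applies (divide positive logits of already-generated tokens by the penalty, multiply negative ones); this is not the commit's implementation, and the function name and signature below are illustrative:

    use std::collections::HashSet;
    use tch::Tensor;

    // Sketch: dampen the logits of every token id already present in the output,
    // so repeated tokens become less likely to be sampled again.
    fn apply_repetition_penalty(next_token_logits: &Tensor, prev_output_tokens: &Tensor, penalty: f64) {
        for i in 0..prev_output_tokens.size()[0] {
            // Deduplicate so each token is penalised exactly once per step
            let tokens: HashSet<i64> = Vec::<i64>::from(&prev_output_tokens.get(i)).into_iter().collect();
            for token in tokens {
                let logit = next_token_logits.double_value(&[i, token]);
                // Positive logits are divided, negative ones multiplied:
                // both moves reduce the token's probability.
                let updated = if logit < 0.0 { logit * penalty } else { logit / penalty };
                let _ = next_token_logits.get(i).get(token).fill_(updated);
            }
        }
    }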
@@ -136,23 +154,60 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
         }
         let cur_len = *input_ids.size().last().unwrap();
+        let batch_size = *input_ids.size().first().unwrap();
+        let vocab_size = self.get_tokenizer().vocab().values().len();
         let (effective_batch_size, effective_batch_mult) = match do_sample {
-            true => (input_ids.size()[0] * num_return_sequences as i64, num_return_sequences as i64),
-            false => (input_ids.size()[0], 1)
+            true => (batch_size * num_return_sequences as i64, num_return_sequences as i64),
+            false => (batch_size, 1)
         };
         let input_ids = if (num_return_sequences > 1) | (num_beams > 1) {
             input_ids
                 .unsqueeze(1)
-                .expand(&[input_ids.size()[0], effective_batch_mult * num_beams as i64, cur_len], true)
+                .expand(&[batch_size, effective_batch_mult * num_beams as i64, cur_len], true)
                 .contiguous()
                 .view((effective_batch_size * num_beams as i64, cur_len))
         } else {
             input_ids
         };
-        input_ids
+        self.generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature,
+                                     top_k, top_p, repetition_penalty, batch_size);
+        Tensor::new()
     }
+    fn generate_no_beam_search(&self, input_ids: Tensor, cur_len: i64, max_len: u64, do_sample: bool, temperature: f64,
+                               top_k: u64, top_p: f64, repetition_penalty: f64, batch_size: i64) {
+        let unfinished_sentences = Tensor::ones(&[batch_size], (Int64, self.get_var_store().device()));
+        let sentence_lengths: Tensor = Tensor::ones(&[batch_size], (Int64, self.get_var_store().device())) * max_len as i64;
+        let mut past: Option<Vec<Tensor>> = None;
+        let mut outputs: Tensor = Tensor::new();
+        let mut cur_len = cur_len as u64;
+        // ToDo: remove when loop is fixed
+        let mut input_ids = input_ids.copy();
+        let input_ids_back = input_ids.copy();
+        while cur_len < max_len {
+            let (prepared_input, prepared_past) = self.prepare_inputs_for_generation(input_ids.copy(), past);
+            let temp = self.get_model().forward_t(&Some(prepared_input), &prepared_past, &None, &None, &None, &None, false).unwrap();
+            outputs = temp.0;
+            past = temp.1;
+            let next_token_logits = outputs.slice(1, outputs.size()[1] - 1, outputs.size()[1], 1);
+            if repetition_penalty > 1f64 {
+                self.enforce_repetition_penalty(next_token_logits, batch_size, 1, &input_ids, repetition_penalty)
+            }
+            // ToDo: remove when loop is fixed
+            input_ids = input_ids_back.copy();
+            cur_len += 1;
+        }
+    }
 }
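As the ToDo markers note, the loop above runs the forward pass and threads the past cache through each step, but it does not yet pick a next token or grow input_ids. A rough sketch of the greedy selection step such a loop needs, assuming the logits have already been squeezed to a (batch, vocab) shape; the function name is illustrative and this is not the crate's eventual implementation:

    use tch::Tensor;

    // Sketch: greedy decoding step, appending the arg-max token to each sequence.
    fn greedy_step(input_ids: &Tensor, next_token_logits: &Tensor) -> Tensor {
        // (batch, vocab) -> (batch, 1): index of the highest-scoring token per row
        let next_token = next_token_logits.argmax(-1, false).unsqueeze(-1);
        // (batch, seq) + (batch, 1) -> (batch, seq + 1)
        Tensor::cat(&[input_ids.copy(), next_token], -1)
    }

Sampling with temperature, top_k and top_p would replace the argmax with a draw from the filtered, re-normalised distribution, which is presumably why those parameters are already threaded through generate_no_beam_search.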