Optimization of generation pipeline

Guillaume B 2020-04-04 09:40:13 +02:00
parent 6aa6a4b2a2
commit dbf6841610
3 changed files with 19 additions and 31 deletions

View File

@@ -39,7 +39,7 @@ fn main() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
- max_length: 30,
+ max_length: 20,
do_sample: true,
num_beams: 5,
temperature: 1.1,

View File

@@ -465,7 +465,7 @@ impl LanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokeni
mod private_generation_utils {
use rust_tokenizers::{Vocab, Tokenizer, TruncationStrategy};
- use tch::{nn, Tensor};
+ use tch::{nn, Tensor, Device};
use rust_tokenizers::preprocessing::tokenizer::tokenization_utils::truncate_sequences;
use std::collections::HashMap;
use tch::kind::Kind::{Int64, Float, Bool};
@@ -555,30 +555,25 @@ mod private_generation_utils {
if cur_len + 1 < no_repeat_ngram_size {
vec!(vec!())
} else {
+ let input_ids = input_ids.to(Device::Cpu);
let num_hypothesis = *input_ids.size().first().unwrap();
let mut banned_tokens: Vec<Vec<i64>> = Vec::with_capacity(num_hypothesis as usize);
for hypothesis_index in 0..num_hypothesis {
let hypothesis_input_ids = input_ids.get(hypothesis_index);
let mut generated_ngram: HashMap<Vec<i64>, Vec<i64>> = HashMap::new();
let input: Vec<i64> = (0..hypothesis_input_ids.size1().unwrap()).collect();
- let query = hypothesis_input_ids
- .slice(0,
- cur_len + 1 - no_repeat_ngram_size,
- *hypothesis_input_ids.size().last().unwrap(),
- 1).iter::<i64>()
+ let hypothesis_input_ids = hypothesis_input_ids
+ .iter::<i64>()
.unwrap()
.collect::<Vec<i64>>();
+ let query = &hypothesis_input_ids[cur_len as usize + 1 - no_repeat_ngram_size as usize..].to_vec();
let ngram_indices: Vec<(i64, i64)> = input
- .windows(3)
+ .windows(no_repeat_ngram_size as usize)
.map(|win| (*win.first().unwrap(), *win.last().unwrap()))
.collect();
for ngram in ngram_indices.into_iter() {
- let ngram = hypothesis_input_ids
- .slice(0, ngram.0, ngram.1 + 1, 1)
- .iter::<i64>()
- .unwrap()
- .collect::<Vec<i64>>();
- let key = ngram[..ngram.len() - 1].to_vec();
+ let ngram = &hypothesis_input_ids[ngram.0 as usize..ngram.1 as usize + 1];
+ let key = ngram[..no_repeat_ngram_size as usize - 1].to_vec();
let value = *ngram.last().unwrap();
if generated_ngram.contains_key(&key) {
generated_ngram.get_mut(&key).unwrap().push(value)
@@ -586,7 +581,7 @@ mod private_generation_utils {
generated_ngram.insert(key, vec!(value));
}
}
- let hypothesis_banned_tokens = match generated_ngram.get(&query) {
+ let hypothesis_banned_tokens = match generated_ngram.get(query) {
Some(banned_tokens) => banned_tokens.clone(),
None => vec!()
};
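As a side note on what the new Vec-based path in this hunk computes: the hypothesis is materialised once as a plain Vec<i64>, every n-gram is enumerated with `windows`, and banned continuations are looked up by the trailing (n-1)-gram, so no per-ngram tensor `slice` calls are needed. The sketch below is illustrative only (the standalone function and variable names are not part of the crate):

```rust
use std::collections::HashMap;

/// Illustrative re-implementation of the no-repeat-ngram ban on plain vectors:
/// given the tokens generated so far, return the tokens that would complete an
/// n-gram which has already been produced.
fn banned_tokens(generated: &[i64], no_repeat_ngram_size: usize, cur_len: usize) -> Vec<i64> {
    if cur_len + 1 < no_repeat_ngram_size {
        return vec![];
    }
    // Map every (n-1)-gram prefix to the tokens that followed it.
    let mut generated_ngrams: HashMap<&[i64], Vec<i64>> = HashMap::new();
    for window in generated.windows(no_repeat_ngram_size) {
        let (prefix, last) = window.split_at(no_repeat_ngram_size - 1);
        generated_ngrams.entry(prefix).or_default().push(last[0]);
    }
    // The query is the trailing (n-1)-gram of the current hypothesis.
    let query = &generated[cur_len + 1 - no_repeat_ngram_size..];
    generated_ngrams.get(query).cloned().unwrap_or_default()
}

fn main() {
    let generated = vec![5, 7, 9, 5, 7];
    // With trigrams banned, "5 7" was already followed by 9, so 9 is banned next.
    assert_eq!(banned_tokens(&generated, 3, generated.len()), vec![9]);
}
```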
@@ -785,20 +780,16 @@ mod private_generation_utils {
if temperature > 1f64 {
next_token_logits = next_token_logits / temperature;
}
let mut scores = next_token_logits.log_softmax(-1, Float);
// Do not allow eos token if min length is not reached
if (&eos_token_ids.is_some()) & (current_length < min_length) {
&scores.index_fill_(1, &Tensor::of_slice(eos_token_ids.as_ref().unwrap()), std::f64::NEG_INFINITY);
}
// Get banned tokens and set their probability to 0
let banned_tokens = self.get_banned_tokens(&input_ids, no_repeat_ngram_size as i64, current_length as i64);
for (batch_index, index_banned_token) in (0..banned_tokens.len() as i64).zip(banned_tokens) {
&scores.get(batch_index).index_fill_(0, &Tensor::of_slice(&index_banned_token).to_device(next_token_logits.device()), std::f64::NEG_INFINITY);
}
let (next_scores, next_tokens) = if do_sample {
let mut _scores: Tensor = &scores + &beam_scores.unsqueeze(-1).expand_as(&scores);
self.top_k_top_p_filtering(&mut _scores, top_k as i64, top_p, 2);
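For readers skimming the diff: the two `index_fill_` calls above implement the same masking idea on device tensors; conceptually, disallowed token ids just receive a log-probability of negative infinity before the next token is picked. A tensor-free sketch of that idea (the helper name and data are illustrative, not the crate's API):

```rust
/// Illustrative masking of a single row of next-token log-probabilities.
fn mask_scores(
    scores: &mut [f64],
    eos_token_ids: &[i64],
    banned_tokens: &[i64],
    current_length: usize,
    min_length: usize,
) {
    // Do not allow EOS tokens while the minimum length has not been reached.
    if current_length < min_length {
        for &eos in eos_token_ids {
            scores[eos as usize] = f64::NEG_INFINITY;
        }
    }
    // Banned n-gram continuations can never be selected either.
    for &token in banned_tokens {
        scores[token as usize] = f64::NEG_INFINITY;
    }
}

fn main() {
    let mut scores = vec![-1.2, -0.4, -2.3, -0.9];
    mask_scores(&mut scores, &[0], &[2], 1, 5);
    assert!(scores[0].is_infinite() && scores[2].is_infinite());
}
```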
@@ -815,10 +806,7 @@ mod private_generation_utils {
let next_scores = next_scores.contiguous().view((batch_size, num_beams * vocab_size));
next_scores.topk(2 * num_beams, 1, true, true)
};
let mut next_batch_beam: Vec<(f64, i64, i64)> = vec!();
for batch_index in 0..batch_size {
if done[batch_index as usize] {
assert!(hypotheses[batch_index as usize].len() >= num_beams,
@@ -829,6 +817,7 @@ mod private_generation_utils {
next_batch_beam.append(&mut
(0..num_beams).map(|_| (0f64, pad_token_id.unwrap(), 0i64)).collect::<Vec<(f64, i64, i64)>>()
);
+ continue;
}
let mut next_sentence_beam: Vec<(f64, i64, i64)> = vec!();
@@ -845,7 +834,8 @@ mod private_generation_utils {
if eos_token_ids.as_ref().is_some() {
if eos_token_ids.as_ref().unwrap().contains(&token_id) {
- if beam_token_rank > num_beams {
+ if beam_token_rank >= num_beams {
+ beam_token_rank += 1;
continue;
}
hypotheses[batch_index as usize].add(input_ids.get(effective_beam_id).copy(), beam_token_score)
@@ -860,7 +850,6 @@ mod private_generation_utils {
(beam_token_rank == beam_token_rank_max_value) {
break;
}
beam_token_rank += 1;
}
@@ -872,7 +861,6 @@ mod private_generation_utils {
assert_eq!(next_sentence_beam.len() as i64, num_beams, "Beam incomplete");
next_batch_beam.append(&mut next_sentence_beam);
}
if done.iter().all(|&x| x) {
break;
}
@@ -882,7 +870,6 @@ mod private_generation_utils {
input_ids = input_ids.index_select(0, &beam_indices);
input_ids = Tensor::cat(&[input_ids, beam_tokens.unsqueeze(1)], -1);
past = match past {
Some(past_values) => Some(self.reorder_cache(past_values, &beam_indices)),
None => None
@@ -890,13 +877,17 @@ mod private_generation_utils {
attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
(Int64, attention_mask.device())).as_ref()], -1);
- current_length += 1
+ current_length += 1;
}
let mut batch_index = 0i64;
loop {
+ if batch_index == batch_size {
+ break;
+ }
if done[batch_index as usize] {
batch_index += 1;
continue;
}
for beam_index in 0..num_beams {
@@ -906,9 +897,6 @@ mod private_generation_utils {
hypotheses[batch_index as usize].add(final_tokens, final_score);
}
batch_index += 1;
- if batch_index == batch_size {
- break;
- }
}
let (output_batch_size, output_num_return_sequences_per_batch) = if do_sample {

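The restructured finalisation loop (the @@ -890 and @@ -906 hunks above) boils down to moving the termination check to the top of the loop and stepping the index before `continue`, so a batch that is already done can be skipped without indexing past the end. A dummy-data sketch of that shape, not the library code itself:

```rust
fn main() {
    // Illustrative stand-ins for the real state: three batches, two already finished.
    let batch_size = 3i64;
    let done = vec![true, false, true];
    let mut finalised: Vec<i64> = Vec::new();

    let mut batch_index = 0i64;
    loop {
        // Termination is checked first, so the `continue` path below cannot run off the end.
        if batch_index == batch_size {
            break;
        }
        if done[batch_index as usize] {
            batch_index += 1;
            continue;
        }
        // In the real pipeline the surviving beams of this batch are pushed into `hypotheses` here.
        finalised.push(batch_index);
        batch_index += 1;
    }
    assert_eq!(finalised, vec![1]);
}
```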
View File

@@ -93,7 +93,7 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
let output = model.generate(Some(vec!(input_context)), None);
assert_eq!(output.len(), 1);
assert_eq!(output[0], "it was an intense machine dialogue. \n \" i 'm sorry, \" i said. \" i 'm not sure what you're talking about. \" \n \" you're not a vampire, \" he said");
assert_eq!(output[0], "it was an intense machine dialogue. \n \" i \'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");
Ok(())
}