Optimization of generation pipeline

parent 6aa6a4b2a2
commit dbf6841610
@@ -39,7 +39,7 @@ fn main() -> failure::Fallible<()> {
     // Set-up masked LM model
     let device = Device::cuda_if_available();
     let generate_config = GenerateConfig {
-        max_length: 30,
+        max_length: 20,
         do_sample: true,
         num_beams: 5,
         temperature: 1.1,
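For context on the `temperature: 1.1` setting above: the @@ -785 hunk further down divides the logits by the temperature before `log_softmax`, so values above 1 flatten the next-token distribution and values below 1 sharpen it. A minimal standalone illustration in plain Rust (not part of the commit, no tch involved):

// Softmax over raw logits vs. temperature-scaled logits.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
    let z: f64 = exps.iter().sum();
    exps.into_iter().map(|e| e / z).collect()
}

fn main() {
    let logits = [2.0, 1.0, 0.1];
    let temperature = 1.1;
    let scaled: Vec<f64> = logits.iter().map(|l| l / temperature).collect();
    println!("{:?}", softmax(&logits)); // sharper distribution
    println!("{:?}", softmax(&scaled)); // slightly flatter: more diverse samples
}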
@@ -465,7 +465,7 @@ impl LanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer>

 mod private_generation_utils {
     use rust_tokenizers::{Vocab, Tokenizer, TruncationStrategy};
-    use tch::{nn, Tensor};
+    use tch::{nn, Tensor, Device};
     use rust_tokenizers::preprocessing::tokenizer::tokenization_utils::truncate_sequences;
     use std::collections::HashMap;
     use tch::kind::Kind::{Int64, Float, Bool};
@@ -555,30 +555,25 @@ mod private_generation_utils {
             if cur_len + 1 < no_repeat_ngram_size {
                 vec!(vec!())
             } else {
+                let input_ids = input_ids.to(Device::Cpu);
                 let num_hypothesis = *input_ids.size().first().unwrap();
                 let mut banned_tokens: Vec<Vec<i64>> = Vec::with_capacity(num_hypothesis as usize);
                 for hypothesis_index in 0..num_hypothesis {
                     let hypothesis_input_ids = input_ids.get(hypothesis_index);
                     let mut generated_ngram: HashMap<Vec<i64>, Vec<i64>> = HashMap::new();
                     let input: Vec<i64> = (0..hypothesis_input_ids.size1().unwrap()).collect();
-                    let query = hypothesis_input_ids
-                        .slice(0,
-                               cur_len + 1 - no_repeat_ngram_size,
-                               *hypothesis_input_ids.size().last().unwrap(),
-                               1).iter::<i64>()
+                    let hypothesis_input_ids = hypothesis_input_ids
+                        .iter::<i64>()
                         .unwrap()
                         .collect::<Vec<i64>>();
+                    let query = &hypothesis_input_ids[cur_len as usize + 1 - no_repeat_ngram_size as usize..].to_vec();
                     let ngram_indices: Vec<(i64, i64)> = input
-                        .windows(3)
+                        .windows(no_repeat_ngram_size as usize)
                         .map(|win| (*win.first().unwrap(), *win.last().unwrap()))
                         .collect();
                     for ngram in ngram_indices.into_iter() {
-                        let ngram = hypothesis_input_ids
-                            .slice(0, ngram.0, ngram.1 + 1, 1)
-                            .iter::<i64>()
-                            .unwrap()
-                            .collect::<Vec<i64>>();
-                        let key = ngram[..ngram.len() - 1].to_vec();
+                        let ngram = &hypothesis_input_ids[ngram.0 as usize..ngram.1 as usize + 1];
+                        let key = ngram[..no_repeat_ngram_size as usize - 1].to_vec();
                         let value = *ngram.last().unwrap();
                         if generated_ngram.contains_key(&key) {
                             generated_ngram.get_mut(&key).unwrap().push(value)
@@ -586,7 +581,7 @@ mod private_generation_utils {
                             generated_ngram.insert(key, vec!(value));
                         }
                     }
-                    let hypothesis_banned_tokens = match generated_ngram.get(&query) {
+                    let hypothesis_banned_tokens = match generated_ngram.get(query) {
                         Some(banned_tokens) => banned_tokens.clone(),
                         None => vec!()
                     };
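The rewritten `get_banned_tokens` moves each hypothesis to CPU once and then works on plain `Vec<i64>` slices instead of issuing a tensor `slice` per n-gram; it also fixes the hard-coded `windows(3)` to honor `no_repeat_ngram_size`. A self-contained sketch of the same algorithm on a plain token vector, standalone and not part of the commit (function name and types are illustrative):

use std::collections::HashMap;

// Tokens that would complete an already-seen n-gram if generated next.
fn banned_next_tokens(tokens: &[i64], no_repeat_ngram_size: usize) -> Vec<i64> {
    let cur_len = tokens.len();
    if cur_len + 1 < no_repeat_ngram_size {
        return vec![];
    }
    // Map each (n-1)-token prefix seen so far to the tokens that followed it.
    let mut generated_ngram: HashMap<Vec<i64>, Vec<i64>> = HashMap::new();
    for window in tokens.windows(no_repeat_ngram_size) {
        let key = window[..no_repeat_ngram_size - 1].to_vec();
        let value = *window.last().unwrap();
        generated_ngram.entry(key).or_insert_with(Vec::new).push(value);
    }
    // The query is the trailing (n-1)-gram: any token that previously
    // followed it would repeat an n-gram, so it must be banned.
    let query = &tokens[cur_len + 1 - no_repeat_ngram_size..];
    generated_ngram.get(query).cloned().unwrap_or_default()
}

fn main() {
    // "1 2 3" already occurred, so after "... 1 2" the token 3 is banned.
    let tokens = vec![1, 2, 3, 4, 1, 2];
    assert_eq!(banned_next_tokens(&tokens, 3), vec![3]);
}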
@@ -785,20 +780,16 @@ mod private_generation_utils {
                 if temperature > 1f64 {
                     next_token_logits = next_token_logits / temperature;
                 }

                 let mut scores = next_token_logits.log_softmax(-1, Float);

                 // Do not allow eos token if min length is not reached
                 if (&eos_token_ids.is_some()) & (current_length < min_length) {
                     &scores.index_fill_(1, &Tensor::of_slice(eos_token_ids.as_ref().unwrap()), std::f64::NEG_INFINITY);
                 }

                 // Get banned tokens and set their probability to 0
                 let banned_tokens = self.get_banned_tokens(&input_ids, no_repeat_ngram_size as i64, current_length as i64);
                 for (batch_index, index_banned_token) in (0..banned_tokens.len() as i64).zip(banned_tokens) {
                     &scores.get(batch_index).index_fill_(0, &Tensor::of_slice(&index_banned_token).to_device(next_token_logits.device()), std::f64::NEG_INFINITY);
                 }

                 let (next_scores, next_tokens) = if do_sample {
                     let mut _scores: Tensor = &scores + &beam_scores.unsqueeze(-1).expand_as(&scores);
                     self.top_k_top_p_filtering(&mut _scores, top_k as i64, top_p, 2);
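Right after banned tokens are masked, sampled candidates pass through `top_k_top_p_filtering` (the trailing `2` presumably plays the role of a minimum number of tokens to keep). As a standalone illustration of the technique, not the pipeline's actual method, here is a sketch over plain `f64` slices rather than `tch::Tensor`s:

// Mask logits so that sampling only sees the top-k tokens and the smallest
// set of tokens whose cumulative probability exceeds top_p (nucleus filtering).
fn top_k_top_p_filter(logits: &mut [f64], top_k: usize, top_p: f64, min_keep: usize) {
    let neg_inf = f64::NEG_INFINITY;
    // Rank token indices by logit, highest first.
    let mut order: Vec<usize> = (0..logits.len()).collect();
    order.sort_by(|&a, &b| logits[b].partial_cmp(&logits[a]).unwrap());

    // Top-k: keep at most max(top_k, min_keep) highest-scoring tokens (0 = disabled).
    if top_k > 0 {
        for &i in order.iter().skip(top_k.max(min_keep)) {
            logits[i] = neg_inf;
        }
    }

    // Top-p: walk candidates in rank order and cut once the mass exceeds top_p.
    if top_p < 1.0 {
        let max = logits[order[0]];
        let z: f64 = logits.iter().filter(|l| l.is_finite()).map(|l| (l - max).exp()).sum();
        let mut cumulative = 0.0;
        for (rank, &i) in order.iter().enumerate() {
            if !logits[i].is_finite() {
                break; // already removed by the top-k pass
            }
            cumulative += (logits[i] - max).exp() / z;
            if cumulative > top_p && rank + 1 >= min_keep {
                for &j in &order[rank + 1..] {
                    logits[j] = neg_inf;
                }
                break;
            }
        }
    }
}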
@@ -815,10 +806,7 @@ mod private_generation_utils {
                     let next_scores = next_scores.contiguous().view((batch_size, num_beams * vocab_size));
                     next_scores.topk(2 * num_beams, 1, true, true)
                 };


                 let mut next_batch_beam: Vec<(f64, i64, i64)> = vec!();

                 for batch_index in 0..batch_size {
                     if done[batch_index as usize] {
                         assert!(hypotheses[batch_index as usize].len() >= num_beams,
@@ -829,6 +817,7 @@ mod private_generation_utils {
                         next_batch_beam.append(&mut
                             (0..num_beams).map(|_| (0f64, pad_token_id.unwrap(), 0i64)).collect::<Vec<(f64, i64, i64)>>()
                         );
                         continue;
                     }

                     let mut next_sentence_beam: Vec<(f64, i64, i64)> = vec!();
@@ -845,7 +834,8 @@ mod private_generation_utils {

                         if eos_token_ids.as_ref().is_some() {
                             if eos_token_ids.as_ref().unwrap().contains(&token_id) {
-                                if beam_token_rank > num_beams {
+                                if beam_token_rank >= num_beams {
+                                    beam_token_rank += 1;
                                     continue;
                                 }
                                 hypotheses[batch_index as usize].add(input_ids.get(effective_beam_id).copy(), beam_token_score)
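The @@ -845 hunk fixes an off-by-one (`>` becomes `>=`) and makes the manual rank loop advance before skipping a candidate. The surrounding logic draws `2 * num_beams` candidates per batch entry (see the `topk` call above) so that `num_beams` non-EOS continuations always survive. A simplified, Vec-based sketch of that invariant, with illustrative names only:

// Select the next `num_beams` continuations from candidates sorted by score.
// Each candidate is (score, token_id, beam_index), mirroring the tuples in
// `next_batch_beam` above.
fn select_beams(
    candidates: &[(f64, i64, usize)],
    num_beams: usize,
    eos_token: i64,
    finished: &mut Vec<(f64, usize)>, // (score, beam) of completed hypotheses
) -> Vec<(f64, i64, usize)> {
    let mut next_beams = Vec::with_capacity(num_beams);
    for (rank, &(score, token, beam)) in candidates.iter().enumerate() {
        if token == eos_token {
            // An EOS candidate ranked past the first num_beams can never be
            // needed: this is the `beam_token_rank >= num_beams` check.
            if rank >= num_beams {
                continue;
            }
            finished.push((score, beam));
        } else {
            next_beams.push((score, token, beam));
        }
        if next_beams.len() == num_beams {
            break;
        }
    }
    // 2 * num_beams candidates guarantee num_beams non-EOS survivors even if
    // every beam's top-ranked candidate emitted EOS.
    assert_eq!(next_beams.len(), num_beams, "Beam incomplete");
    next_beams
}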
@@ -860,7 +850,6 @@ mod private_generation_utils {
                             (beam_token_rank == beam_token_rank_max_value) {
                                 break;
                             }
-
                             beam_token_rank += 1;
                         }

@@ -872,7 +861,6 @@ mod private_generation_utils {
                     assert_eq!(next_sentence_beam.len() as i64, num_beams, "Beam incomplete");
                     next_batch_beam.append(&mut next_sentence_beam);
                 }
-
                 if done.iter().all(|&x| x) {
                     break;
                 }
@@ -882,7 +870,6 @@ mod private_generation_utils {

                 input_ids = input_ids.index_select(0, &beam_indices);
                 input_ids = Tensor::cat(&[input_ids, beam_tokens.unsqueeze(1)], -1);
-
                 past = match past {
                     Some(past_values) => Some(self.reorder_cache(past_values, &beam_indices)),
                     None => None
@@ -890,13 +877,17 @@ mod private_generation_utils {

                 attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
                                                                                      (Int64, attention_mask.device())).as_ref()], -1);
-                current_length += 1
+                current_length += 1;
             }

             let mut batch_index = 0i64;

             loop {
+                if batch_index == batch_size {
+                    break;
+                }
                 if done[batch_index as usize] {
                     batch_index += 1;
                     continue;
                 }
                 for beam_index in 0..num_beams {
@@ -906,9 +897,6 @@ mod private_generation_utils {
                     hypotheses[batch_index as usize].add(final_tokens, final_score);
                 }
                 batch_index += 1;
-                if batch_index == batch_size {
-                    break;
-                }
             }

             let (output_batch_size, output_num_return_sequences_per_batch) = if do_sample {
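These last two hunks restructure the flush of still-open beams into the hypothesis pool: the `loop` now checks `batch_index == batch_size` on entry instead of after the increment. A Vec-based sketch of that finalization pass, with simplified stand-in types rather than the crate's tensors:

// After decoding stops, every beam of a batch entry that never finished is
// added to its hypothesis pool as a (score, tokens) candidate.
fn finalize_open_beams(
    done: &[bool],
    num_beams: usize,
    beam_scores: &[f64],
    sequences: &[Vec<i64>],
    hypotheses: &mut Vec<Vec<(f64, Vec<i64>)>>,
) {
    for (batch_index, batch_done) in done.iter().enumerate() {
        if *batch_done {
            continue; // every hypothesis already ended with EOS
        }
        for beam_index in 0..num_beams {
            let effective_beam_id = batch_index * num_beams + beam_index;
            hypotheses[batch_index].push((
                beam_scores[effective_beam_id],
                sequences[effective_beam_id].clone(),
            ));
        }
    }
}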
@@ -93,7 +93,7 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
     let output = model.generate(Some(vec!(input_context)), None);

     assert_eq!(output.len(), 1);
-    assert_eq!(output[0], "it was an intense machine dialogue. \n \" i 'm sorry, \" i said. \" i 'm not sure what you're talking about. \" \n \" you're not a vampire, \" he said");
+    assert_eq!(output[0], "it was an intense machine dialogue. \n \" i \'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");

     Ok(())
 }