Updated generation process to allow for direct ids input

Guillaume B 2020-06-25 21:32:45 +02:00
parent 9081bc3318
commit c045b3212f
3 changed files with 111 additions and 18 deletions

examples/conversation.rs

@@ -17,7 +17,7 @@ use rust_bert::pipelines::conversation::ConversationModel;
 fn main() -> failure::Fallible<()> {
     let conversation_model = ConversationModel::new(Default::default())?;
-    let input = ["Hello, how are you? <|endoftext|>"];
+    let input = ["Hello, how are you?"];
     let output = conversation_model.reply(&input);
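Since `reply` now tokenizes, truncates, and left-pads its inputs and appends the model's EOS token internally (see the pipeline changes below), the example no longer adds the `<|endoftext|>` marker by hand. A minimal usage sketch under that assumption (the second prompt is illustrative, not part of this commit):

    // Prompts of different lengths are handled internally; no manual EOS markers.
    let input = ["Hello, how are you?", "What is the weather like today?"];
    let output = conversation_model.reply(&input);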

src/pipelines/conversation.rs

@@ -27,8 +27,11 @@ use crate::common::resources::{RemoteResource, Resource};
 use crate::gpt2::{
     Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
 };
+use crate::pipelines::generation::private_generation_utils::PrivateLanguageGenerator;
 use crate::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};
-use tch::Device;
+use rust_tokenizers::preprocessing::tokenizer::tokenization_utils::truncate_sequences;
+use rust_tokenizers::{Tokenizer, TruncationStrategy};
+use tch::{Device, Tensor};

 /// # Configuration for multi-turn classification
 /// Contains information regarding the model to load, mirrors the GenerationConfig, with a
@@ -105,6 +108,7 @@ impl Default for ConversationConfig {
 /// # Conversation model
 pub struct ConversationModel {
     model: GPT2Generator,
+    eos_token_id: i64,
 }

 impl ConversationModel {
@@ -146,8 +150,11 @@ impl ConversationModel {
         };
         let model = GPT2Generator::new(generate_config)?;
-        Ok(ConversationModel { model })
+        let eos_token_id = *model.get_eos_ids().as_ref().unwrap().first().unwrap();
+        Ok(ConversationModel {
+            model,
+            eos_token_id,
+        })
     }

     /// Perform a multi-turn conversation based on user input
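The constructor now caches the generator's first EOS id up front, since `reply` uses it on every call, both to terminate each prompt and as the padding value when no dedicated pad token exists (GPT-2 defines none). The `unwrap` chain assumes the generator always configures EOS ids; a more defensive lookup might read as follows (a sketch only, with the GPT-2 `<|endoftext|>` id 50256 as an assumed fallback, not taken from the diff):

    // Sketch: fall back to the standard GPT-2 end-of-text id instead of panicking.
    let eos_token_id = model
        .get_eos_ids()
        .as_ref()
        .and_then(|ids| ids.first().copied())
        .unwrap_or(50256);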
@@ -180,6 +187,72 @@ impl ConversationModel {
-        // ToDo: update base `generate` function to perform some preparation steps and then delegate to the lower level `generate` taking input ids & cache as input
-        // ToDo: update return of function to return a Vec<String> and a History
-        self.model.generate(Some(texts.to_vec()), None)
+        let tokens = self.model.get_tokenizer().tokenize_list(texts.to_vec());
+        let max_len = self.model.get_config().max_length;
+        let pad_token = match self.model.get_pad_id() {
+            Some(value) => *value,
+            None => self.eos_token_id,
+        };
+        let token_ids = tokens
+            .into_iter()
+            .map(|prompt_tokens| {
+                self.model
+                    .get_tokenizer()
+                    .convert_tokens_to_ids(&prompt_tokens)
+            })
+            .map(|mut tokens| {
+                tokens.push(self.eos_token_id);
+                tokens
+            })
+            .collect::<Vec<Vec<i64>>>();
+        let num_truncated_tokens = token_ids
+            .iter()
+            .map(|token_ids| {
+                if token_ids.len() > max_len as usize {
+                    token_ids.len() - max_len as usize
+                } else {
+                    0
+                }
+            })
+            .collect::<Vec<usize>>();
+        let token_ids = token_ids
+            .into_iter()
+            .zip(num_truncated_tokens)
+            .map(|(tokens, num_truncated_tokens)| {
+                truncate_sequences(
+                    tokens,
+                    None,
+                    vec![],
+                    None,
+                    vec![],
+                    None,
+                    vec![],
+                    None,
+                    num_truncated_tokens,
+                    &TruncationStrategy::LongestFirst,
+                    0,
+                )
+                .unwrap()
+                .0
+            })
+            .collect::<Vec<Vec<i64>>>();
+        let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
+        let token_ids = token_ids
+            .into_iter()
+            .map(|input| {
+                let mut temp = vec![pad_token; max_len - input.len()];
+                temp.extend(input);
+                temp
+            })
+            .map(|tokens| Tensor::of_slice(&tokens).to(self.model.get_var_store().device()))
+            .collect::<Vec<Tensor>>();
+        let prompt_ids = Tensor::stack(&token_ids, 0);
+        self.model.generate_from_ids_and_past(prompt_ids, None)
     }
 }
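The new `reply` body walks four steps: tokenize each prompt, append the EOS id, truncate anything beyond `max_length`, then left-pad every sequence to the batch maximum so all prompts end flush at the right edge, where generation begins. A standalone sketch of that padding step with hypothetical token ids (not taken from the diff):

    // Hypothetical ids; 50256 stands in for the GPT-2 EOS id used as pad fallback.
    let pad_token: i64 = 50256;
    let batch: Vec<Vec<i64>> = vec![
        vec![15496, 11, 703, 389, 345, 30, 50256], // longer prompt + EOS
        vec![10814, 50256],                        // shorter prompt + EOS
    ];
    let batch_max = batch.iter().map(|ids| ids.len()).max().unwrap();
    let padded: Vec<Vec<i64>> = batch
        .into_iter()
        .map(|ids| {
            let mut row = vec![pad_token; batch_max - ids.len()];
            row.extend(ids); // pad on the left so the prompt stays right-aligned
            row
        })
        .collect();
    assert_eq!(padded[1], vec![50256, 50256, 50256, 50256, 50256, 10814, 50256]);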

src/pipelines/generation.rs

@@ -1071,7 +1071,7 @@ pub enum Cache {
     None,
 }

-mod private_generation_utils {
+pub(crate) mod private_generation_utils {
     use super::ordered_float::OrderedFloat;
     use crate::pipelines::generation::{BeamHypotheses, Cache, GenerateConfig, LMHeadModel};
     use itertools::Itertools;
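Promoting the module to `pub(crate)` is what lets the conversation pipeline import `PrivateLanguageGenerator` above: the module becomes nameable anywhere inside the crate while staying hidden from downstream users. A minimal illustration of the visibility rule (names reused for illustration only):

    // Visible crate-wide, invisible to external crates.
    pub(crate) mod private_generation_utils {
        pub(crate) trait PrivateLanguageGenerator {}
    }
    // Any module in the same crate can now write:
    // use crate::pipelines::generation::private_generation_utils::PrivateLanguageGenerator;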
@@ -1949,24 +1949,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
         let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
         let config = PrivateLanguageGenerator::get_config(self);
-        let do_sample = config.do_sample;
-        let num_return_sequences = config.num_return_sequences;
-        let num_beams = config.num_beams;
-        let min_length = config.min_length;
         let max_length = config.max_length;
         let encoding_max_len = if self.is_encoder_decoder() {
             1024u64
         } else {
             max_length
         };
-        let early_stopping = config.early_stopping;
-        let temperature = config.temperature;
-        let top_k = config.top_k;
-        let top_p = config.top_p;
-        let repetition_penalty = config.repetition_penalty;
-        let length_penalty = config.length_penalty;
-        let no_repeat_ngram_size = config.no_repeat_ngram_size;
         let pad_token_id = match self.get_pad_id() {
             Some(value) => Some(*value),
             None => match &eos_token_ids {
@@ -1986,6 +1974,37 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
                 ),
             },
         };
+        self.generate_from_ids_and_past(input_ids, attention_mask)
+    }
+
+    fn generate_from_ids_and_past(
+        &self,
+        input_ids: Tensor,
+        attention_mask: Option<Tensor>,
+    ) -> Vec<String> {
+        let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
+        let config = PrivateLanguageGenerator::get_config(self);
+        let do_sample = config.do_sample;
+        let num_return_sequences = config.num_return_sequences;
+        let num_beams = config.num_beams;
+        let min_length = config.min_length;
+        let max_length = config.max_length;
+        let early_stopping = config.early_stopping;
+        let temperature = config.temperature;
+        let top_k = config.top_k;
+        let top_p = config.top_p;
+        let repetition_penalty = config.repetition_penalty;
+        let length_penalty = config.length_penalty;
+        let no_repeat_ngram_size = config.no_repeat_ngram_size;
+        let pad_token_id = match self.get_pad_id() {
+            Some(value) => Some(*value),
+            None => match &eos_token_ids {
+                Some(eos_ids) => Some(eos_ids[0]),
+                None => None,
+            },
+        };

         let cur_len = if !self.is_encoder_decoder() {
             *input_ids.size().last().unwrap()
@@ -2056,6 +2075,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
             (input_ids, attention_mask)
         };
+        input_ids.print();
         let decoded = no_grad(|| {
             if num_beams > 1 {
                 self.generate_beam_search(
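The net effect of this file's changes: `generate` keeps only the text-encoding work, while the sampling and beam-search parameters are now read inside the new `generate_from_ids_and_past`, so the text path and the direct-ids path share a single decoding routine. A hypothetical caller sketch for the new entry point (the token ids and device handling are assumptions, not taken from the diff):

    // `generator` is any type implementing LanguageGenerator, e.g. GPT2Generator.
    let prompt_ids = Tensor::of_slice(&[15496i64, 11, 703, 389, 345, 30, 50256])
        .view([1, -1]) // shape [batch_size, sequence_length]
        .to(Device::cuda_if_available());
    let replies: Vec<String> = generator.generate_from_ids_and_past(prompt_ids, None);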