Addition of TextOutput and IndicesOutput, updated pipelines and tests

Guillaume B 2021-06-16 18:15:22 +02:00
parent c40a218b37
commit f29e02ecbc
9 changed files with 240 additions and 101 deletions

View File

@@ -50,10 +50,11 @@ fn main() -> anyhow::Result<()> {
         None,
         target_language,
         None,
+        false,
     );

     for sentence in output {
-        println!("{:?}", sentence);
+        println!("{:?}", sentence.text);
     }
     Ok(())
 }
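
The `false` added here is the new `output_scores` argument, and the returned elements are now `TextOutput` values rather than plain strings, hence the `.text` field access. A minimal sketch of the same loop with scores requested instead (illustrative, not part of this commit):

    for sentence in output {
        // With output_scores set to true, `score` holds the sequence
        // log-likelihood; with false it is None.
        match sentence.score {
            Some(score) => println!("{} (score: {:.4})", sentence.text, score),
            None => println!("{}", sentence.text),
        }
    }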

View File

@@ -41,7 +41,16 @@ fn main() -> anyhow::Result<()> {
     // Define input
     let input = ["translate English to German: This sentence will get translated to German"];

-    let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None, None);
+    let output = t5_model.generate(
+        Some(input.to_vec()),
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        false,
+    );

     println!("{:?}", output);
     Ok(())
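
The widened call now takes eight positional arguments, six of them `None`. A readability sketch with the positions bound to named locals (types spelled out so the bare `None` literals infer; hypothetical, not part of the commit):

    let min_length: Option<i64> = None;
    let max_length: Option<i64> = None;
    let decoder_start_token_id: Option<i64> = None;
    let forced_bos_token_id: Option<i64> = None;
    let prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &tch::Tensor) -> Vec<i64>> = None;
    let output_scores = false;
    let output = t5_model.generate(
        Some(input.to_vec()),
        None, // attention_mask
        min_length,
        max_length,
        decoder_start_token_id,
        forced_bos_token_id,
        prefix_allowed_tokens_fn,
        output_scores,
    );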

View File

@@ -716,16 +716,20 @@ impl ConversationOption {
         attention_mask: Option<Tensor>,
     ) -> Vec<Vec<i64>> {
         match *self {
-            Self::GPT2(ref model) => model.generate_from_ids_and_past(
-                input_ids,
-                attention_mask,
-                None,
-                None,
-                None,
-                None,
-                None,
-                false,
-            ),
+            Self::GPT2(ref model) => model
+                .generate_from_ids_and_past(
+                    input_ids,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.indices)
+                .collect(),
         }
     }
 }
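
`ConversationOption` keeps its `Vec<Vec<i64>>` contract by projecting the new `IndicesOutput` values back to bare token ids and dropping the scores. The same projection could live in a small helper (hypothetical, not part of this commit):

    // Hypothetical helper: collapse IndicesOutput values back to bare
    // token-id vectors when the caller has no use for the scores.
    fn strip_scores(outputs: Vec<IndicesOutput>) -> Vec<Vec<i64>> {
        outputs.into_iter().map(|output| output.indices).collect()
    }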

View File

@@ -48,6 +48,7 @@
 //!     decoder_start_id,
 //!     forced_bos_token_id,
 //!     None,
+//!     false,
 //! );
 //! # Ok(())
 //! # }
@@ -1215,6 +1216,22 @@ pub(crate) mod private_generation_utils {
     }
 }

+#[derive(Debug, Clone)]
+/// # Generated text output
+/// Contains generated text and an optional log-likelihood score for the generated sequence
+pub struct TextOutput {
+    pub text: String,
+    pub score: Option<f64>,
+}
+
+#[derive(Debug, Clone)]
+/// # Generated indices output
+/// Contains generated indices and an optional log-likelihood score for the generated sequence
+pub struct IndicesOutput {
+    pub indices: Vec<i64>,
+    pub score: Option<f64>,
+}
+
 /// # Common trait for text generation models.
 /// Main API for text generation
 pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
@@ -1232,7 +1249,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
     ///
     /// # Returns
-    /// * `Vec<String>` Vector of generated strings based on the prompts of length *number_of_prompts* x *num_return_sequences*.
+    /// * `Vec<TextOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing TextOutput with the generated texts and the generation score if `output_scores` is true.
     ///
     /// # Example
     ///
@@ -1268,6 +1285,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// let max_length = 128;
     /// let decoder_start_token_id = None;
     /// let forced_bos_token_id = None;
+    /// let output_scores = true;
     ///
     /// //Example custom function for fine-grained generation control
     /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@@ -1294,6 +1312,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     ///     decoder_start_token_id,
     ///     forced_bos_token_id,
     ///     Some(&force_one_paragraph),
+    ///     output_scores
     /// );
     /// # Ok(())
     /// # }
@@ -1320,11 +1339,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
         decoder_start_token_id: impl Into<Option<i64>>,
         forced_bos_token_id: impl Into<Option<i64>>,
         prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
-    ) -> (Vec<String>, Option<Vec<f64>>)
+        output_scores: bool,
+    ) -> Vec<TextOutput>
     where
         S: AsRef<[&'a str]>,
     {
-        let (generated_indices, scores) = self.generate_indices(
+        let indices_outputs = self.generate_indices(
             prompt_texts,
             attention_mask,
             min_length,
@@ -1332,13 +1352,18 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
             decoder_start_token_id,
             forced_bos_token_id,
             prefix_allowed_tokens_fn,
-            true,
+            output_scores,
         );
-        let mut output = Vec::with_capacity(generated_indices.len());
-        for generated_sequence in generated_indices {
-            output.push(self._get_tokenizer().decode(generated_sequence, true, true));
+        let mut output = Vec::with_capacity(indices_outputs.len());
+        for generated_sequence in indices_outputs {
+            output.push(TextOutput {
+                text: self
+                    ._get_tokenizer()
+                    .decode(generated_sequence.indices, true, true),
+                score: generated_sequence.score,
+            });
         }
-        (output, scores)
+        output
     }

     /// Generate token indices without decoding (useful for token-level operations before returning final text or as validation step during training).
@@ -1353,7 +1378,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
     ///
     /// # Returns
-    /// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*.
+    /// * `Vec<IndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing IndicesOutput with the generated indices and the generation score if `output_scores` is true.
     ///
     /// # Example
     ///
@@ -1388,6 +1413,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// let max_length = 128;
     /// let decoder_start_token_id = None;
     /// let forced_bos_token_id = None;
+    /// let output_scores = true;
     ///
     /// //Example custom function for fine-grained generation control
     /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@@ -1414,6 +1440,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     ///     decoder_start_token_id,
     ///     forced_bos_token_id,
     ///     Some(&force_one_paragraph),
+    ///     output_scores
     /// );
     /// # Ok(())
     /// # }
@@ -1428,7 +1455,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
         forced_bos_token_id: impl Into<Option<i64>>,
         prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
         output_scores: bool,
-    ) -> (Vec<Vec<i64>>, Option<Vec<f64>>)
+    ) -> Vec<IndicesOutput>
     where
         S: AsRef<[&'a str]>,
     {
@@ -1482,7 +1509,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
     ///
     /// # Returns
-    /// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*.
+    /// * `Vec<IndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing IndicesOutput with the generated indices and the generation score if `output_scores` is true.
     ///
     /// # Example
     ///
@@ -1517,6 +1544,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// let max_length = 128;
     /// let decoder_start_token_id = None;
     /// let forced_bos_token_id = None;
+    /// let output_scores = true;
     ///
     /// //Example custom function for fine-grained generation control
     /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@@ -1543,6 +1571,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     ///     decoder_start_token_id,
     ///     forced_bos_token_id,
     ///     Some(&force_one_paragraph),
+    ///     output_scores
     /// );
     /// # Ok(())
     /// # }
@@ -1557,7 +1586,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
         forced_bos_token_id: impl Into<Option<i64>>,
         prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
         output_scores: bool,
-    ) -> (Vec<Vec<i64>>, Option<Vec<f64>>) {
+    ) -> Vec<IndicesOutput> {
         let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();

         let config = PrivateLanguageGenerator::get_config(self);
@@ -1714,17 +1743,20 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
             }
         });
         let num_sequences = *decoded.size().first().unwrap();
-        let mut output_ids = Vec::with_capacity(num_sequences as usize);
+        let mut output = Vec::with_capacity(num_sequences as usize);
         for sequence_index in 0..num_sequences {
-            let sequence_output_ids = decoded
+            let indices = decoded
                 .as_ref()
                 .get(sequence_index)
                 .iter::<i64>()
                 .unwrap()
                 .collect::<Vec<i64>>();
-            output_ids.push(sequence_output_ids.clone());
+            let score = scores
+                .as_ref()
+                .map(|scores_value| scores_value[sequence_index as usize]);
+            output.push(IndicesOutput { indices, score });
         }
-        (output_ids, scores)
+        output
     }

     /// Returns a reference to the text generator's tokenizer
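
Taken together, `generate` now returns `Vec<TextOutput>` while `generate_indices` and `generate_from_ids_and_past` return `Vec<IndicesOutput>`, with `score` only meaningful when `output_scores` is true. A consumption sketch (model setup elided; the `gpt2_model` binding and prompt are illustrative):

    let outputs = gpt2_model.generate(
        Some(&["The quick brown fox"]),
        None, // attention_mask
        None, // min_length
        None, // max_length
        None, // decoder_start_token_id
        None, // forced_bos_token_id
        None, // prefix_allowed_tokens_fn
        true, // output_scores: populate TextOutput::score
    );
    for TextOutput { text, score } in outputs {
        println!("{} (log-likelihood: {:?})", text, score);
    }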

View File

@@ -263,18 +263,62 @@ impl SummarizationOption {
         S: AsRef<[&'a str]>,
     {
         match *self {
-            Self::Bart(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
-            }
-            Self::T5(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
-            }
-            Self::ProphetNet(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
-            }
-            Self::Pegasus(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
-            }
+            Self::Bart(ref model) => model
+                .generate(
+                    prompt_texts,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.text)
+                .collect(),
+            Self::T5(ref model) => model
+                .generate(
+                    prompt_texts,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.text)
+                .collect(),
+            Self::ProphetNet(ref model) => model
+                .generate(
+                    prompt_texts,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.text)
+                .collect(),
+            Self::Pegasus(ref model) => model
+                .generate(
+                    prompt_texts,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.text)
+                .collect(),
         }
     }
 }
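
All four arms now go through the same generate-and-project motion, so the pipeline surface is unchanged: `SummarizationModel::summarize` still yields plain strings because each arm maps `TextOutput` back to its `text` field. A usage sketch (construction shortened, defaults assumed):

    let summarization_model = SummarizationModel::new(Default::default())?;
    let summaries: Vec<String> =
        summarization_model.summarize(&["A long article to condense ..."]);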

View File

@@ -249,56 +249,76 @@ impl TextGenerationOption {
         S: AsRef<[&'a str]>,
     {
         match *self {
-            Self::GPT(ref model) => model.generate_indices(
-                prompt_texts,
-                attention_mask,
-                min_length,
-                max_length,
-                None,
-                None,
-                None,
-                false,
-            ),
-            Self::GPT2(ref model) => model.generate_indices(
-                prompt_texts,
-                attention_mask,
-                min_length,
-                max_length,
-                None,
-                None,
-                None,
-                false,
-            ),
-            Self::GPTNeo(ref model) => model.generate_indices(
-                prompt_texts,
-                attention_mask,
-                min_length,
-                max_length,
-                None,
-                None,
-                None,
-                false,
-            ),
-            Self::XLNet(ref model) => model.generate_indices(
-                prompt_texts,
-                attention_mask,
-                min_length,
-                max_length,
-                None,
-                None,
-                None,
-                false,
-            ),
-            Self::Reformer(ref model) => model.generate_indices(
-                prompt_texts,
-                attention_mask,
-                min_length,
-                max_length,
-                None,
-                None,
-                None,
-                false,
-            ),
+            Self::GPT(ref model) => model
+                .generate_indices(
+                    prompt_texts,
+                    attention_mask,
+                    min_length,
+                    max_length,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.indices)
+                .collect(),
+            Self::GPT2(ref model) => model
+                .generate_indices(
+                    prompt_texts,
+                    attention_mask,
+                    min_length,
+                    max_length,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.indices)
+                .collect(),
+            Self::GPTNeo(ref model) => model
+                .generate_indices(
+                    prompt_texts,
+                    attention_mask,
+                    min_length,
+                    max_length,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.indices)
+                .collect(),
+            Self::XLNet(ref model) => model
+                .generate_indices(
+                    prompt_texts,
+                    attention_mask,
+                    min_length,
+                    max_length,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.indices)
+                .collect(),
+            Self::Reformer(ref model) => model
+                .generate_indices(
+                    prompt_texts,
+                    attention_mask,
+                    min_length,
+                    max_length,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.indices)
+                .collect(),
         }
     }
 }
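
The five decoder-only arms differ only in the enum variant; the bodies are otherwise identical. If the repetition ever became a maintenance burden, a small macro could express the shared body (a hypothetical refactoring, not part of this commit):

    macro_rules! indices_arm {
        ($model:expr, $prompts:expr, $mask:expr, $min:expr, $max:expr) => {
            $model
                .generate_indices($prompts, $mask, $min, $max, None, None, None, false)
                .into_iter()
                .map(|output| output.indices)
                .collect()
        };
    }
    // e.g. Self::GPT(ref model) =>
    //     indices_arm!(model, prompt_texts, attention_mask, min_length, max_length),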

View File

@@ -674,12 +674,34 @@ impl TranslationOption {
         S: AsRef<[&'a str]>,
     {
         match *self {
-            Self::Marian(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
-            }
-            Self::T5(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
-            }
+            Self::Marian(ref model) => model
+                .generate(
+                    prompt_texts,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.text)
+                .collect(),
+            Self::T5(ref model) => model
+                .generate(
+                    prompt_texts,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.text)
+                .collect(),
         }
     }
 }
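
As with summarization, the translation pipeline drops the score before returning, so `TranslationModel::translate` keeps returning `Vec<String>`. A usage sketch (`translation_config` stands in for an assumed, pre-built configuration):

    let translation_model = TranslationModel::new(translation_config)?;
    let translations: Vec<String> = translation_model.translate(&["The dog is happy."]);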

View File

@@ -428,17 +428,20 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
         None,
         None,
         Some(&force_one_paragraph),
+        true,
     );

     assert_eq!(output.len(), 2);
     assert_eq!(
-        output[0],
+        output[0].text,
         "Rust is a very simple and powerful library for building and running web applications. It is a simple, fast, and lightweight library that can be used to build web applications in a number of different ways.\n"
     );
+    assert!(output[0].score.unwrap().is_nan());
     assert_eq!(
-        output[1],
+        output[1].text,
         "There was a urn in the room, and I was sitting on it. I was like, \'What the hell is going on?\' And he said, \'Well, I\'m not sure. I\'m just going to go back to my room and get some coffee.\' And"
     );
+    assert!(output[1].score.unwrap().is_nan());

     Ok(())
 }
@@ -493,17 +496,20 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
         None,
         None,
         Some(&force_one_paragraph),
+        true,
     );

     assert_eq!(output.len(), 2);
     assert_eq!(
-        output[0],
+        output[0].text,
         "Rust is a simple, fast, and easy-to-use framework for building web applications. It is designed to be easy to use and maintain, and"
     );
+    assert!((output[0].score.unwrap() - (-1.2750)).abs() < 1e-4);
     assert_eq!(
-        output[1],
+        output[1].text,
         "There was a urn in the back of the room, and I was sitting on it, and it looked like it was going to explode. And then I"
     );
+    assert!((output[1].score.unwrap() - (-1.3326)).abs() < 1e-4);

     Ok(())
 }
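
Judging by these assertions, the score semantics are: greedy search reports `Some(NaN)`, since no beam log-probabilities are tracked along the single decoding path, while beam search reports a finite log-likelihood (apparently length-normalized, given the per-token magnitudes; negative, with values closer to zero more likely). A sketch of consuming the field defensively:

    // Interpreting TextOutput::score, based on the assertions above.
    fn describe(output: &TextOutput) -> String {
        match output.score {
            Some(score) if score.is_nan() => {
                format!("{} [no score: greedy/sampling]", output.text)
            }
            Some(score) => format!("{} [log-likelihood {:.4}]", output.text, score),
            None => format!("{} [scores not requested]", output.text),
        }
    }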

View File

@@ -97,11 +97,12 @@ fn mbart_translation() -> anyhow::Result<()> {
         None,
         target_language,
         None,
+        false,
     );

     assert_eq!(output.len(), 1);
     assert_eq!(
-        output[0],
+        output[0].text,
         "de_DE Der schnelle braune Fuchs springt über den faulen Hund."
     );