Addition of TextOutput and IndicesOutput, updated pipelines and tests

Guillaume B 2021-06-16 18:15:22 +02:00
parent c40a218b37
commit f29e02ecbc
9 changed files with 240 additions and 101 deletions

View File

@@ -50,10 +50,11 @@ fn main() -> anyhow::Result<()> {
None,
target_language,
None,
false,
);
for sentence in output {
println!("{:?}", sentence);
println!("{:?}", sentence.text);
}
Ok(())
}
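
With `generate` now returning structured outputs, migrating a caller comes down to reading the `text` field and, when scores were requested, the optional `score`. A minimal standalone sketch of the consuming pattern, with the struct mirrored locally (in the crate it ships with this commit's generation utilities):

#[derive(Debug, Clone)]
pub struct TextOutput {
    pub text: String,
    pub score: Option<f64>,
}

fn print_outputs(outputs: &[TextOutput]) {
    for sentence in outputs {
        // `score` is only populated when generation ran with output_scores = true;
        // the example above passes `false`, so only the text is printed.
        match sentence.score {
            Some(score) => println!("{} (score: {:.4})", sentence.text, score),
            None => println!("{}", sentence.text),
        }
    }
}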

View File

@@ -41,7 +41,16 @@ fn main() -> anyhow::Result<()> {
// Define input
let input = ["translate English to German: This sentence will get translated to German"];
let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None, None);
let output = t5_model.generate(
Some(input.to_vec()),
None,
None,
None,
None,
None,
None,
false,
);
println!("{:?}", output);
Ok(())

View File

@@ -716,16 +716,20 @@ impl ConversationOption {
attention_mask: Option<Tensor>,
) -> Vec<Vec<i64>> {
match *self {
Self::GPT2(ref model) => model.generate_from_ids_and_past(
input_ids,
attention_mask,
None,
None,
None,
None,
None,
false,
),
Self::GPT2(ref model) => model
.generate_from_ids_and_past(
input_ids,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.indices)
.collect(),
}
}
}
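
The conversation pipeline keeps its existing `Vec<Vec<i64>>` return type by unwrapping the new struct at the boundary, as the arm above shows. The same projection applies to any caller that only needs raw token ids; a standalone sketch with the struct mirrored locally:

#[derive(Debug, Clone)]
pub struct IndicesOutput {
    pub indices: Vec<i64>,
    pub score: Option<f64>,
}

fn into_raw_indices(outputs: Vec<IndicesOutput>) -> Vec<Vec<i64>> {
    // The optional scores are dropped here; this pipeline's public API
    // predates them and only exposes the generated token ids.
    outputs.into_iter().map(|output| output.indices).collect()
}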

View File

@@ -48,6 +48,7 @@
//! decoder_start_id,
//! forced_bos_token_id,
//! None,
//! false,
//! );
//! # Ok(())
//! # }
@@ -1215,6 +1216,22 @@ pub(crate) mod private_generation_utils {
}
}
#[derive(Debug, Clone)]
/// # Generated text output
/// Contains generated text and an optional log-likelihood score for the generated sequence
pub struct TextOutput {
pub text: String,
pub score: Option<f64>,
}
#[derive(Debug, Clone)]
/// # Generated indices output
/// Contains generated indices and an optional log-likelihood score for the generated sequence
pub struct IndicesOutput {
pub indices: Vec<i64>,
pub score: Option<f64>,
}
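// Both structs above share a payload-plus-score shape. A generic alternative
// (purely hypothetical, shown for comparison only; not what this commit does)
// would be a single parameterized type:
//
//     pub struct GeneratedOutput<T> {
//         pub payload: T,
//         pub score: Option<f64>,
//     }
//
// Two concrete structs keep self-describing field names (`text`, `indices`)
// at the pipeline boundaries, at the cost of some duplication.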
/// # Common trait for text generation models.
/// Main API for text generation
pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
@@ -1232,7 +1249,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of the token_ids already generated, and return a `Vec<i64>` of allowed tokens.
///
/// # Returns
/// * `Vec<String>` Vector of generated strings based on the prompts of length *number_of_prompts* x *num_return_sequences*.
/// * `Vec<TextOutput>` Vector of length *number_of_prompts* x *num_return_sequences*, containing the generated texts and, if `output_scores` is true, the per-sequence generation scores.
///
/// # Example
///
@@ -1268,6 +1285,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// let max_length = 128;
/// let decoder_start_token_id = None;
/// let forced_bos_token_id = None;
/// let output_scores = true;
///
/// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@@ -1294,6 +1312,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// decoder_start_token_id,
/// forced_bos_token_id,
/// Some(&force_one_paragraph),
/// output_scores
/// );
/// # Ok(())
/// # }
@@ -1320,11 +1339,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
decoder_start_token_id: impl Into<Option<i64>>,
forced_bos_token_id: impl Into<Option<i64>>,
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
) -> (Vec<String>, Option<Vec<f64>>)
output_scores: bool,
) -> Vec<TextOutput>
where
S: AsRef<[&'a str]>,
{
let (generated_indices, scores) = self.generate_indices(
let indices_outputs = self.generate_indices(
prompt_texts,
attention_mask,
min_length,
@@ -1332,13 +1352,18 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
decoder_start_token_id,
forced_bos_token_id,
prefix_allowed_tokens_fn,
true,
output_scores,
);
let mut output = Vec::with_capacity(generated_indices.len());
for generated_sequence in generated_indices {
output.push(self._get_tokenizer().decode(generated_sequence, true, true));
let mut output = Vec::with_capacity(indices_outputs.len());
for generated_sequence in indices_outputs {
output.push(TextOutput {
text: self
._get_tokenizer()
.decode(generated_sequence.indices, true, true),
score: generated_sequence.score,
});
}
(output, scores)
output
}
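// Condensed usage sketch for the new signature (hypothetical caller; assumes
// a generator built as in the doc example above):
//
//     let outputs = model.generate(
//         Some(vec!["The dog"]),
//         None,  // attention_mask
//         None,  // min_length
//         None,  // max_length
//         None,  // decoder_start_token_id
//         None,  // forced_bos_token_id
//         None,  // prefix_allowed_tokens_fn
//         true,  // output_scores: populate TextOutput::score
//     );
//     for output in outputs {
//         println!("{} ({:?})", output.text, output.score);
//     }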
/// Generate token indices without decoding (useful for token-level operations before returning the final text, or as a validation step during training).
@@ -1353,7 +1378,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of the token_ids already generated, and return a `Vec<i64>` of allowed tokens.
///
/// # Returns
/// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*.
/// * `Vec<IndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences*, containing the generated token indices and, if `output_scores` is true, the per-sequence generation scores.
///
/// # Example
///
@@ -1388,6 +1413,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// let max_length = 128;
/// let decoder_start_token_id = None;
/// let forced_bos_token_id = None;
/// let output_scores = true;
///
/// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@@ -1414,6 +1440,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// decoder_start_token_id,
/// forced_bos_token_id,
/// Some(&force_one_paragraph),
/// output_scores
/// );
/// # Ok(())
/// # }
@@ -1428,7 +1455,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
forced_bos_token_id: impl Into<Option<i64>>,
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
output_scores: bool,
) -> (Vec<Vec<i64>>, Option<Vec<f64>>)
) -> Vec<IndicesOutput>
where
S: AsRef<[&'a str]>,
{
@@ -1482,7 +1509,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of the token_ids already generated, and return a `Vec<i64>` of allowed tokens.
///
/// # Returns
/// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*.
/// * `Vec<IndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing IndicesOutput with the generated indices and the generation score if `output_scores` is true.
///
/// # Example
///
@@ -1517,6 +1544,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// let max_length = 128;
/// let decoder_start_token_id = None;
/// let forced_bos_token_id = None;
/// let output_scores = true;
///
/// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@@ -1543,6 +1571,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// decoder_start_token_id,
/// forced_bos_token_id,
/// Some(&force_one_paragraph),
/// output_scores
/// );
/// # Ok(())
/// # }
@@ -1557,7 +1586,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
forced_bos_token_id: impl Into<Option<i64>>,
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
output_scores: bool,
) -> (Vec<Vec<i64>>, Option<Vec<f64>>) {
) -> Vec<IndicesOutput> {
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
let config = PrivateLanguageGenerator::get_config(self);
@@ -1714,17 +1743,20 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
}
});
let num_sequences = *decoded.size().first().unwrap();
let mut output_ids = Vec::with_capacity(num_sequences as usize);
let mut output = Vec::with_capacity(num_sequences as usize);
for sequence_index in 0..num_sequences {
let sequence_output_ids = decoded
let indices = decoded
.as_ref()
.get(sequence_index)
.iter::<i64>()
.unwrap()
.collect::<Vec<i64>>();
output_ids.push(sequence_output_ids.clone());
let score = scores
.as_ref()
.map(|scores_value| scores_value[sequence_index as usize]);
output.push(IndicesOutput { indices, score });
}
(output_ids, scores)
output
}
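// For callers migrating from the old `(Vec<Vec<i64>>, Option<Vec<f64>>)`
// return shape, the new struct can be split back apart. A hypothetical
// adapter, not part of this commit:
fn split_outputs(outputs: Vec<IndicesOutput>) -> (Vec<Vec<i64>>, Option<Vec<f64>>) {
    // Scores are only reported as a batch if every sequence carries one.
    let has_scores = outputs.iter().all(|output| output.score.is_some());
    let mut indices = Vec::with_capacity(outputs.len());
    let mut scores = Vec::with_capacity(outputs.len());
    for output in outputs {
        indices.push(output.indices);
        if let Some(score) = output.score {
            scores.push(score);
        }
    }
    (indices, if has_scores { Some(scores) } else { None })
}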
/// Returns a reference to the text generator's tokenizer

View File

@@ -263,18 +263,62 @@ impl SummarizationOption {
S: AsRef<[&'a str]>,
{
match *self {
Self::Bart(ref model) => {
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
}
Self::T5(ref model) => {
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
}
Self::ProphetNet(ref model) => {
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
}
Self::Pegasus(ref model) => {
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
}
Self::Bart(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::T5(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::ProphetNet(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::Pegasus(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
}
}
}
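
Each arm above repeats the same projection from `TextOutput` to `String`. A hypothetical helper, not part of this commit, could factor that repetition out once each arm yields its `Vec<TextOutput>`:

fn collect_texts(outputs: Vec<TextOutput>) -> Vec<String> {
    // TextOutput is the struct introduced earlier in this commit.
    outputs.into_iter().map(|output| output.text).collect()
}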

View File

@@ -249,56 +249,76 @@ impl TextGenerationOption {
S: AsRef<[&'a str]>,
{
match *self {
Self::GPT(ref model) => model.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
),
Self::GPT2(ref model) => model.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
),
Self::GPTNeo(ref model) => model.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
),
Self::XLNet(ref model) => model.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
),
Self::Reformer(ref model) => model.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
),
Self::GPT(ref model) => model
.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.indices)
.collect(),
Self::GPT2(ref model) => model
.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.indices)
.collect(),
Self::GPTNeo(ref model) => model
.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.indices)
.collect(),
Self::XLNet(ref model) => model
.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.indices)
.collect(),
Self::Reformer(ref model) => model
.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.indices)
.collect(),
}
}
}

View File

@@ -674,12 +674,34 @@ impl TranslationOption {
S: AsRef<[&'a str]>,
{
match *self {
Self::Marian(ref model) => {
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
}
Self::T5(ref model) => {
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
}
Self::Marian(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::T5(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
}
}
}

View File

@@ -428,17 +428,20 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
None,
None,
Some(&force_one_paragraph),
true,
);
assert_eq!(output.len(), 2);
assert_eq!(
output[0],
output[0].text,
"Rust is a very simple and powerful library for building and running web applications. It is a simple, fast, and lightweight library that can be used to build web applications in a number of different ways.\n"
);
assert!(output[0].score.unwrap().is_nan());
assert_eq!(
output[1],
output[1].text,
"There was a urn in the room, and I was sitting on it. I was like, \'What the hell is going on?\' And he said, \'Well, I\'m not sure. I\'m just going to go back to my room and get some coffee.\' And"
);
assert!(output[1].score.unwrap().is_nan());
Ok(())
}
@@ -493,17 +496,20 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
None,
None,
Some(&force_one_paragraph),
true,
);
assert_eq!(output.len(), 2);
assert_eq!(
output[0],
output[0].text,
"Rust is a simple, fast, and easy-to-use framework for building web applications. It is designed to be easy to use and maintain, and"
);
assert!((output[0].score.unwrap() - (-1.2750)).abs() < 1e-4);
assert_eq!(
output[1],
output[1].text,
"There was a urn in the back of the room, and I was sitting on it, and it looked like it was going to explode. And then I"
);
assert!((output[1].score.unwrap() - (-1.3326)).abs() < 1e-4);
Ok(())
}
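
These two tests also pin down the score semantics: greedy search reports a NaN score, while beam search reports finite per-sequence log-likelihoods (here roughly -1.28 and -1.33). Code that ranks generations by score should therefore filter for finite values first; a hypothetical helper:

fn best_output(outputs: Vec<TextOutput>) -> Option<TextOutput> {
    outputs
        .into_iter()
        .filter(|output| output.score.map_or(false, f64::is_finite))
        .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
}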

View File

@@ -97,11 +97,12 @@ fn mbart_translation() -> anyhow::Result<()> {
None,
target_language,
None,
false,
);
assert_eq!(output.len(), 1);
assert_eq!(
output[0],
output[0].text,
"de_DE Der schnelle braune Fuchs springt über den faulen Hund."
);