mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Addition of TextOutput and IndicesOutput, updated pipelines and tests
This commit is contained in:
parent
c40a218b37
commit
f29e02ecbc
@ -50,10 +50,11 @@ fn main() -> anyhow::Result<()> {
|
||||
None,
|
||||
target_language,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
|
||||
for sentence in output {
|
||||
println!("{:?}", sentence);
|
||||
println!("{:?}", sentence.text);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -41,7 +41,16 @@ fn main() -> anyhow::Result<()> {
|
||||
// Define input
|
||||
let input = ["translate English to German: This sentence will get translated to German"];
|
||||
|
||||
let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None, None);
|
||||
let output = t5_model.generate(
|
||||
Some(input.to_vec()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
println!("{:?}", output);
|
||||
|
||||
Ok(())
|
||||
|
@ -716,16 +716,20 @@ impl ConversationOption {
|
||||
attention_mask: Option<Tensor>,
|
||||
) -> Vec<Vec<i64>> {
|
||||
match *self {
|
||||
Self::GPT2(ref model) => model.generate_from_ids_and_past(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
),
|
||||
Self::GPT2(ref model) => model
|
||||
.generate_from_ids_and_past(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.indices)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -48,6 +48,7 @@
|
||||
//! decoder_start_id,
|
||||
//! forced_bos_token_id,
|
||||
//! None,
|
||||
//! false,
|
||||
//! );
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
@ -1215,6 +1216,22 @@ pub(crate) mod private_generation_utils {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// # Generated text output
|
||||
/// Contains generated text and an optional log-likelihood score for the generated sequence
|
||||
pub struct TextOutput {
|
||||
pub text: String,
|
||||
pub score: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// # Generated indices output
|
||||
/// Contains generated indices and an optional log-likelihood score for the generated sequence
|
||||
pub struct IndicesOutput {
|
||||
pub indices: Vec<i64>,
|
||||
pub score: Option<f64>,
|
||||
}
|
||||
|
||||
/// # Common trait for text generation models.
|
||||
/// Main API for text generation
|
||||
pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
@ -1232,7 +1249,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Vec<String>` Vector of generated strings based on the prompts of length *number_of_prompts* x *num_return_sequences*.
|
||||
/// * `Vec<TextOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing TextOutput with the generated texts and the generation score if `output_scores` is true.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -1268,6 +1285,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// let max_length = 128;
|
||||
/// let decoder_start_token_id = None;
|
||||
/// let forced_bos_token_id = None;
|
||||
/// let output_scores = true;
|
||||
///
|
||||
/// //Example custom function for fine-grained generation control
|
||||
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||
@ -1294,6 +1312,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// decoder_start_token_id,
|
||||
/// forced_bos_token_id,
|
||||
/// Some(&force_one_paragraph),
|
||||
/// output_scores
|
||||
/// );
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
@ -1320,11 +1339,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
decoder_start_token_id: impl Into<Option<i64>>,
|
||||
forced_bos_token_id: impl Into<Option<i64>>,
|
||||
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
||||
) -> (Vec<String>, Option<Vec<f64>>)
|
||||
output_scores: bool,
|
||||
) -> Vec<TextOutput>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let (generated_indices, scores) = self.generate_indices(
|
||||
let indices_outputs = self.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
@ -1332,13 +1352,18 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
decoder_start_token_id,
|
||||
forced_bos_token_id,
|
||||
prefix_allowed_tokens_fn,
|
||||
true,
|
||||
output_scores,
|
||||
);
|
||||
let mut output = Vec::with_capacity(generated.len());
|
||||
for generated_sequence in generated {
|
||||
output.push(self._get_tokenizer().decode(generated_sequence, true, true));
|
||||
let mut output = Vec::with_capacity(indices_outputs.len());
|
||||
for generated_sequence in indices_outputs {
|
||||
output.push(TextOutput {
|
||||
text: self
|
||||
._get_tokenizer()
|
||||
.decode(generated_sequence.indices, true, true),
|
||||
score: generated_sequence.score,
|
||||
});
|
||||
}
|
||||
(output, scores)
|
||||
output
|
||||
}
|
||||
|
||||
/// Generate token indices without decoding (useful for token-level operations before returning final text or as validation step during training).
|
||||
@ -1353,7 +1378,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*.
|
||||
/// * `Vec<IndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing IndicesOutput with the generated indices and the generation score if `output_scores` is true.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -1388,6 +1413,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// let max_length = 128;
|
||||
/// let decoder_start_token_id = None;
|
||||
/// let forced_bos_token_id = None;
|
||||
/// let output_scores = true;
|
||||
///
|
||||
/// //Example custom function for fine-grained generation control
|
||||
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||
@ -1414,6 +1440,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// decoder_start_token_id,
|
||||
/// forced_bos_token_id,
|
||||
/// Some(&force_one_paragraph),
|
||||
/// output_scores
|
||||
/// );
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
@ -1428,7 +1455,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
forced_bos_token_id: impl Into<Option<i64>>,
|
||||
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
||||
output_scores: bool,
|
||||
) -> (Vec<Vec<i64>>, Option<Vec<f64>>)
|
||||
) -> Vec<IndicesOutput>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
@ -1482,7 +1509,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*.
|
||||
/// * `Vec<IndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing IndicesOutput with the generated indices and the generation score if `output_scores` is true.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -1517,6 +1544,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// let max_length = 128;
|
||||
/// let decoder_start_token_id = None;
|
||||
/// let forced_bos_token_id = None;
|
||||
/// let output_scores = true;
|
||||
///
|
||||
/// //Example custom function for fine-grained generation control
|
||||
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||
@ -1543,6 +1571,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// decoder_start_token_id,
|
||||
/// forced_bos_token_id,
|
||||
/// Some(&force_one_paragraph),
|
||||
/// output_scores
|
||||
/// );
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
@ -1557,7 +1586,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
forced_bos_token_id: impl Into<Option<i64>>,
|
||||
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
||||
output_scores: bool,
|
||||
) -> (Vec<Vec<i64>>, Option<Vec<f64>>) {
|
||||
) -> Vec<IndicesOutput> {
|
||||
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
|
||||
|
||||
let config = PrivateLanguageGenerator::get_config(self);
|
||||
@ -1714,17 +1743,20 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
}
|
||||
});
|
||||
let num_sequences = *decoded.size().first().unwrap();
|
||||
let mut output_ids = Vec::with_capacity(num_sequences as usize);
|
||||
let mut output = Vec::with_capacity(num_sequences as usize);
|
||||
for sequence_index in 0..num_sequences {
|
||||
let sequence_output_ids = decoded
|
||||
let indices = decoded
|
||||
.as_ref()
|
||||
.get(sequence_index)
|
||||
.iter::<i64>()
|
||||
.unwrap()
|
||||
.collect::<Vec<i64>>();
|
||||
output_ids.push(sequence_output_ids.clone());
|
||||
let score = scores
|
||||
.as_ref()
|
||||
.map(|scores_value| scores_value[sequence_index as usize]);
|
||||
output.push(IndicesOutput { indices, score });
|
||||
}
|
||||
(output_ids, scores)
|
||||
output
|
||||
}
|
||||
|
||||
/// Returns a reference to the text generator's tokenizer
|
||||
|
@ -263,18 +263,62 @@ impl SummarizationOption {
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
match *self {
|
||||
Self::Bart(ref model) => {
|
||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
||||
}
|
||||
Self::T5(ref model) => {
|
||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
||||
}
|
||||
Self::ProphetNet(ref model) => {
|
||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
||||
}
|
||||
Self::Pegasus(ref model) => {
|
||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
||||
}
|
||||
Self::Bart(ref model) => model
|
||||
.generate(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
Self::T5(ref model) => model
|
||||
.generate(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
Self::ProphetNet(ref model) => model
|
||||
.generate(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
Self::Pegasus(ref model) => model
|
||||
.generate(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -249,56 +249,76 @@ impl TextGenerationOption {
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
match *self {
|
||||
Self::GPT(ref model) => model.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
max_length,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
),
|
||||
Self::GPT2(ref model) => model.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
max_length,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
),
|
||||
Self::GPTNeo(ref model) => model.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
max_length,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
),
|
||||
Self::XLNet(ref model) => model.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
max_length,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
),
|
||||
Self::Reformer(ref model) => model.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
max_length,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
),
|
||||
Self::GPT(ref model) => model
|
||||
.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
max_length,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.indices)
|
||||
.collect(),
|
||||
Self::GPT2(ref model) => model
|
||||
.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
max_length,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.indices)
|
||||
.collect(),
|
||||
Self::GPTNeo(ref model) => model
|
||||
.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
max_length,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.indices)
|
||||
.collect(),
|
||||
Self::XLNet(ref model) => model
|
||||
.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
max_length,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.indices)
|
||||
.collect(),
|
||||
Self::Reformer(ref model) => model
|
||||
.generate_indices(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
min_length,
|
||||
max_length,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.indices)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -674,12 +674,34 @@ impl TranslationOption {
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
match *self {
|
||||
Self::Marian(ref model) => {
|
||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
||||
}
|
||||
Self::T5(ref model) => {
|
||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
||||
}
|
||||
Self::Marian(ref model) => model
|
||||
.generate(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
Self::T5(ref model) => model
|
||||
.generate(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -428,17 +428,20 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
|
||||
None,
|
||||
None,
|
||||
Some(&force_one_paragraph),
|
||||
true,
|
||||
);
|
||||
|
||||
assert_eq!(output.len(), 2);
|
||||
assert_eq!(
|
||||
output[0],
|
||||
output[0].text,
|
||||
"Rust is a very simple and powerful library for building and running web applications. It is a simple, fast, and lightweight library that can be used to build web applications in a number of different ways.\n"
|
||||
);
|
||||
assert!(output[0].score.unwrap().is_nan());
|
||||
assert_eq!(
|
||||
output[1],
|
||||
output[1].text,
|
||||
"There was a urn in the room, and I was sitting on it. I was like, \'What the hell is going on?\' And he said, \'Well, I\'m not sure. I\'m just going to go back to my room and get some coffee.\' And"
|
||||
);
|
||||
assert!(output[1].score.unwrap().is_nan());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -493,17 +496,20 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
|
||||
None,
|
||||
None,
|
||||
Some(&force_one_paragraph),
|
||||
true,
|
||||
);
|
||||
|
||||
assert_eq!(output.len(), 2);
|
||||
assert_eq!(
|
||||
output[0],
|
||||
output[0].text,
|
||||
"Rust is a simple, fast, and easy-to-use framework for building web applications. It is designed to be easy to use and maintain, and"
|
||||
);
|
||||
assert!((output[0].score.unwrap() - (-1.2750)).abs() < 1e-4);
|
||||
assert_eq!(
|
||||
output[1],
|
||||
output[1].text,
|
||||
"There was a urn in the back of the room, and I was sitting on it, and it looked like it was going to explode. And then I"
|
||||
);
|
||||
assert!((output[1].score.unwrap() - (-1.3326)).abs() < 1e-4);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -97,11 +97,12 @@ fn mbart_translation() -> anyhow::Result<()> {
|
||||
None,
|
||||
target_language,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
|
||||
assert_eq!(output.len(), 1);
|
||||
assert_eq!(
|
||||
output[0],
|
||||
output[0].text,
|
||||
"de_DE Der schnelle braune Fuchs springt über den faulen Hund."
|
||||
);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user