Updated integration tests

Guillaume B 2020-10-27 18:52:48 +01:00
parent 297c9d2c8d
commit 3d4bb6535d
9 changed files with 99 additions and 55 deletions

View File

@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
// Define input
let input = ["translate English to German: This sentence will get translated to German"];
- let output = t5_model.generate(Some(input.to_vec()), None);
+ let output = t5_model.generate(Some(input.to_vec()), None, None, None, None);
println!("{:?}", output);
Ok(())
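
Note: `generate` now takes three additional optional arguments, in the order `min_length`, `max_length`, `decoder_start_token_id` (see the trait change below). A minimal sketch of overriding the maximum output length for a single call, assuming that argument order; the value 32 is purely illustrative:

    // Passing None keeps the values configured in GenerateConfig;
    // here only max_length is overridden for this call.
    let output = t5_model.generate(Some(input.to_vec()), None, None, Some(32), None);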

View File

@ -706,7 +706,9 @@ impl ConversationOption {
attention_mask: Option<Tensor>,
) -> Vec<Vec<i64>> {
match *self {
- Self::GPT2(ref model) => model.generate_from_ids_and_past(input_ids, attention_mask),
+ Self::GPT2(ref model) => {
+ model.generate_from_ids_and_past(input_ids, attention_mask, None, None, None)
+ }
}
}
}

View File

@ -1412,7 +1412,7 @@ impl XLNetGenerator {
/// # Ok(())
/// # }
/// ```
- pub fn new(mut generate_config: GenerateConfig) -> Result<XLNetGenerator, RustBertError> {
+ pub fn new(generate_config: GenerateConfig) -> Result<XLNetGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
@ -2434,11 +2434,20 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
&self,
prompt_texts: Option<S>,
attention_mask: Option<Tensor>,
+ min_length: impl Into<Option<i64>>,
+ max_length: impl Into<Option<i64>>,
+ decoder_start_token_id: impl Into<Option<i64>>,
) -> Vec<String>
where
S: AsRef<[&'a str]>,
{
- let generated = self.generate_indices(prompt_texts, attention_mask);
+ let generated = self.generate_indices(
+ prompt_texts,
+ attention_mask,
+ min_length,
+ max_length,
+ decoder_start_token_id,
+ );
let mut output = Vec::with_capacity(generated.len());
for generated_sequence in generated {
output.push(self.get_tokenizer().decode(generated_sequence, true, true));
@ -2490,6 +2499,9 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
&self,
prompt_texts: Option<S>,
attention_mask: Option<Tensor>,
+ min_length: impl Into<Option<i64>>,
+ max_length: impl Into<Option<i64>>,
+ decoder_start_token_id: impl Into<Option<i64>>,
) -> Vec<Vec<i64>>
where
S: AsRef<[&'a str]>,
@ -2497,7 +2509,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
let config = PrivateLanguageGenerator::get_config(self);
- let max_length = config.max_length;
+ let max_length = max_length.into().unwrap_or(config.max_length);
let encoding_max_len = if self.is_encoder_decoder() {
1024i64
} else {
@ -2522,13 +2534,22 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
),
},
};
- self.generate_from_ids_and_past(input_ids, attention_mask)
+ self.generate_from_ids_and_past(
+ input_ids,
+ attention_mask,
+ min_length,
+ max_length,
+ decoder_start_token_id,
+ )
}
fn generate_from_ids_and_past(
&self,
input_ids: Tensor,
attention_mask: Option<Tensor>,
+ min_length: impl Into<Option<i64>>,
+ max_length: impl Into<Option<i64>>,
+ decoder_start_token_id: impl Into<Option<i64>>,
) -> Vec<Vec<i64>> {
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
@ -2536,8 +2557,8 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let do_sample = config.do_sample;
let num_return_sequences = config.num_return_sequences;
let num_beams = config.num_beams;
- let min_length = config.min_length;
- let max_length = config.max_length;
+ let min_length = min_length.into().unwrap_or(config.min_length);
+ let max_length = max_length.into().unwrap_or(config.max_length);
let early_stopping = config.early_stopping;
let temperature = config.temperature;
let top_k = config.top_k;
@ -2613,9 +2634,10 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
(input_ids, attention_mask)
}
} else {
- let decoder_start_token_id = self
- .get_decoder_start_id()
- .expect("decoder start id must be specified for encoder decoders");
+ let decoder_start_token_id = decoder_start_token_id.into().unwrap_or(
+ self.get_decoder_start_id()
+ .expect("decoder start id must be specified for encoder decoders"),
+ );
let input_ids = Tensor::full(
&[effective_batch_size * num_beams as i64, 1],
decoder_start_token_id,
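
The new parameters are taken as `impl Into<Option<i64>>` and resolved with `unwrap_or` against the values stored in `GenerateConfig`, so a caller may pass `None`, `Some(n)`, or a bare `n`, and the configured default applies when nothing is given. A self-contained sketch of this pattern (illustration only, not rust-bert code; the names are made up):

    struct Config {
        max_length: i64,
    }

    // Resolve an optional per-call override against a configured default.
    fn resolve_max_length(max_length: impl Into<Option<i64>>, config: &Config) -> i64 {
        max_length.into().unwrap_or(config.max_length)
    }

    fn main() {
        let config = Config { max_length: 56 };
        assert_eq!(resolve_max_length(None, &config), 56); // fall back to the config
        assert_eq!(resolve_max_length(20i64, &config), 20); // a bare i64 converts via Into
        assert_eq!(resolve_max_length(Some(10), &config), 10); // an explicit Option works too
    }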

View File

@ -262,8 +262,8 @@ impl SummarizationOption {
S: AsRef<[&'a str]>,
{
match *self {
- Self::Bart(ref model) => model.generate(prompt_texts, attention_mask),
- Self::T5(ref model) => model.generate(prompt_texts, attention_mask),
+ Self::Bart(ref model) => model.generate(prompt_texts, attention_mask, None, None, None),
+ Self::T5(ref model) => model.generate(prompt_texts, attention_mask, None, None, None),
}
}
}

View File

@ -96,14 +96,22 @@ impl TextGenerationOption {
&self,
prompt_texts: Option<S>,
attention_mask: Option<Tensor>,
+ min_length: Option<i64>,
+ max_length: Option<i64>,
) -> Vec<Vec<i64>>
where
S: AsRef<[&'a str]>,
{
match *self {
- Self::GPT2(ref model) => model.generate_indices(prompt_texts, attention_mask),
- Self::GPT(ref model) => model.generate_indices(prompt_texts, attention_mask),
- Self::XLNet(ref model) => model.generate_indices(prompt_texts, attention_mask),
+ Self::GPT2(ref model) => {
+ model.generate_indices(prompt_texts, attention_mask, min_length, max_length, None)
+ }
+ Self::GPT(ref model) => {
+ model.generate_indices(prompt_texts, attention_mask, min_length, max_length, None)
+ }
+ Self::XLNet(ref model) => {
+ model.generate_indices(prompt_texts, attention_mask, min_length, max_length, None)
+ }
}
}
}
@ -113,6 +121,8 @@ pub struct TextGenerationModel {
model: TextGenerationOption,
prefix: Option<String>,
prefix_length: Option<i64>,
+ min_length: i64,
+ max_length: i64,
}
impl TextGenerationModel {
@ -132,9 +142,7 @@ impl TextGenerationModel {
/// # Ok(())
/// # }
/// ```
- pub fn new(
- mut generation_config: GenerateConfig,
- ) -> Result<TextGenerationModel, RustBertError> {
+ pub fn new(generation_config: GenerateConfig) -> Result<TextGenerationModel, RustBertError> {
let prefix = match generation_config.model_type {
ModelType::XLNet => Some(
"In 1991, the remains of Russian Tsar Nicholas II and his family \
@ -151,23 +159,22 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
),
_ => None,
};
+ let min_length = generation_config.min_length;
+ let max_length = generation_config.max_length;
+ let model = TextGenerationOption::new(generation_config)?;
let prefix_length = if let Some(prefix) = &prefix {
Some(model.get_tokenizer().tokenize(prefix).len() as i64)
} else {
None
};
- let model = TextGenerationOption::new(generation_config)?;
- if let Some(prefix_length) = prefix_length {
- generation_config.min_length += prefix_length;
- generation_config.max_length += prefix_length;
- }
Ok(TextGenerationModel {
model,
prefix,
prefix_length,
+ min_length,
+ max_length,
})
}
@ -206,9 +213,9 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
(None, Some(pipeline_prefix)) => (Some(pipeline_prefix.as_str()), self.prefix_length),
(None, None) => (None, None),
};
- let generated_indices = match prefix {
- None => self.model.generate_indices(Some(texts), None),
- Some(prefix) => {
+ let generated_indices = match (prefix, prefix_length) {
+ (None, _) => self.model.generate_indices(Some(texts), None, None, None),
+ (Some(prefix), Some(prefix_length)) => {
let texts = texts
.as_ref()
.iter()
@ -217,8 +224,11 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
self.model.generate_indices(
Some(texts.iter().map(|x| &**x).collect::<Vec<&str>>()),
None,
+ Some(self.min_length + prefix_length),
+ Some(self.max_length + prefix_length),
)
}
+ _ => panic!("Prefix length not defined but prefix provided!"),
};
let mut output = Vec::with_capacity(generated_indices.len());
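
Rather than mutating the `GenerateConfig` inside `new`, the model now keeps the user-facing `min_length`/`max_length` and shifts both bounds by the prefix length at call time, so the XLNet prompt prefix does not eat into the caller's length budget. A rough sketch of the arithmetic, with hypothetical values:

    // Hypothetical numbers: a 95-token prefix and a user-facing max_length of 32
    // yield per-call bounds that cover the prefix plus the requested output.
    let (min_length, max_length, prefix_length) = (10i64, 32i64, 95i64);
    let bounds = (Some(min_length + prefix_length), Some(max_length + prefix_length));
    assert_eq!(bounds, (Some(105), Some(127)));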

View File

@ -556,8 +556,10 @@ impl TranslationOption {
S: AsRef<[&'a str]>,
{
match *self {
- Self::Marian(ref model) => model.generate(prompt_texts, attention_mask),
- Self::T5(ref model) => model.generate(prompt_texts, attention_mask),
+ Self::Marian(ref model) => {
+ model.generate(prompt_texts, attention_mask, None, None, None)
+ }
+ Self::T5(ref model) => model.generate(prompt_texts, attention_mask, None, None, None),
}
}
}

View File

@ -2,12 +2,12 @@ use rust_bert::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
+ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::conversation::{
ConversationConfig, ConversationManager, ConversationModel,
};
- use rust_bert::pipelines::generation_utils::{
- Cache, GPT2Generator, GenerateConfig, LMHeadModel, LanguageGenerator,
- };
+ use rust_bert::pipelines::generation_utils::{Cache, GenerateConfig, LMHeadModel};
+ use rust_bert::pipelines::text_generation::TextGenerationModel;
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
@ -124,6 +124,7 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
// Set-up masked LM model
let generate_config = GenerateConfig {
+ model_type: ModelType::GPT2,
model_resource,
config_resource,
vocab_resource,
@ -135,10 +136,10 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
repetition_penalty: 1.1,
..Default::default()
};
- let model = GPT2Generator::new(generate_config)?;
+ let model = TextGenerationModel::new(generate_config)?;
let input_context = "The cat";
- let output = model.generate(Some(vec![input_context]), None);
+ let output = model.generate(&[input_context], None);
assert_eq!(output.len(), 1);
assert_eq!(output[0], "The cat was found in a field near the town of Keflavik, about 30 miles (48 kilometers) south-east of Moscow.\n\n\n");
@ -160,6 +161,7 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
// Set-up masked LM model
let generate_config = GenerateConfig {
+ model_type: ModelType::GPT2,
model_resource,
config_resource,
vocab_resource,
@ -171,10 +173,10 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
num_return_sequences: 3,
..Default::default()
};
- let model = GPT2Generator::new(generate_config)?;
+ let model = TextGenerationModel::new(generate_config)?;
let input_context = "The dog";
- let output = model.generate(Some(vec![input_context]), None);
+ let output = model.generate(&[input_context], None);
assert_eq!(output.len(), 3);
assert_eq!(
@ -207,6 +209,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
// Set-up masked LM model
let generate_config = GenerateConfig {
+ model_type: ModelType::GPT2,
model_resource,
config_resource,
vocab_resource,
@ -218,11 +221,11 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
num_return_sequences: 3,
..Default::default()
};
- let model = GPT2Generator::new(generate_config)?;
+ let model = TextGenerationModel::new(generate_config)?;
let input_context_1 = "The dog";
let input_context_2 = "The cat";
- let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
+ let output = model.generate(&[input_context_1, input_context_2], None);
assert_eq!(output.len(), 6);
assert_eq!(
@ -267,6 +270,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
// Set-up masked LM model
let generate_config = GenerateConfig {
+ model_type: ModelType::GPT2,
model_resource,
config_resource,
vocab_resource,
@ -278,11 +282,11 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
num_return_sequences: 3,
..Default::default()
};
- let model = GPT2Generator::new(generate_config)?;
+ let model = TextGenerationModel::new(generate_config)?;
let input_context_1 = "The dog";
let input_context_2 = "The cat was";
- let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
+ let output = model.generate(&[input_context_1, input_context_2], None);
assert_eq!(output.len(), 6);
assert_eq!(
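
The GPT-2 generation tests now go through the `TextGenerationModel` pipeline instead of using `GPT2Generator` directly, the config gains an explicit `model_type`, and prompts are passed as a plain slice rather than `Some(vec![...])`. A condensed sketch of the updated setup, assuming `GenerateConfig::default()` supplies usable GPT-2 resources (the tests themselves pass explicit remote resources):

    let generate_config = GenerateConfig {
        model_type: ModelType::GPT2,
        max_length: 30,
        do_sample: false,
        num_beams: 1,
        ..Default::default()
    };
    let model = TextGenerationModel::new(generate_config)?;
    // Prompts are now a slice of &str; the attention mask stays optional.
    let output = model.generate(&["The cat"], None);
    assert_eq!(output.len(), 1);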

View File

@ -3,9 +3,9 @@ use rust_bert::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
- use rust_bert::pipelines::generation_utils::{
- Cache, GenerateConfig, LMHeadModel, LanguageGenerator, OpenAIGenerator,
- };
+ use rust_bert::pipelines::common::ModelType;
+ use rust_bert::pipelines::generation_utils::{Cache, GenerateConfig, LMHeadModel};
+ use rust_bert::pipelines::text_generation::TextGenerationModel;
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
@ -119,6 +119,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
// Set-up masked LM model
let generate_config = GenerateConfig {
+ model_type: ModelType::OpenAiGpt,
model_resource,
config_resource,
vocab_resource,
@ -131,10 +132,10 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
temperature: 1.1,
..Default::default()
};
- let model = OpenAIGenerator::new(generate_config)?;
+ let model = TextGenerationModel::new(generate_config)?;
let input_context = "It was an intense machine dialogue. ";
- let output = model.generate(Some(vec![input_context]), None);
+ let output = model.generate(&[input_context], None);
assert_eq!(output.len(), 1);
assert_eq!(output[0], "it was an intense machine dialogue. \n \" i\'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");
@ -160,6 +161,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
// Set-up masked LM model
let generate_config = GenerateConfig {
+ model_type: ModelType::OpenAiGpt,
model_resource,
config_resource,
vocab_resource,
@ -171,10 +173,10 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
num_return_sequences: 3,
..Default::default()
};
- let model = OpenAIGenerator::new(generate_config)?;
+ let model = TextGenerationModel::new(generate_config)?;
let input_context = "The dog is";
- let output = model.generate(Some(vec![input_context]), None);
+ let output = model.generate(&[input_context], None);
assert_eq!(output.len(), 3);
assert_eq!(
@ -211,6 +213,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
// Set-up masked LM model
let generate_config = GenerateConfig {
+ model_type: ModelType::OpenAiGpt,
model_resource,
config_resource,
vocab_resource,
@ -222,11 +225,11 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
num_return_sequences: 3,
..Default::default()
};
- let model = OpenAIGenerator::new(generate_config)?;
+ let model = TextGenerationModel::new(generate_config)?;
let input_context_1 = "The dog is";
let input_context_2 = "The cat";
- let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
+ let output = model.generate(&[input_context_1, input_context_2], None);
assert_eq!(output.len(), 6);
@ -278,6 +281,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
// Set-up masked LM model
let generate_config = GenerateConfig {
+ model_type: ModelType::OpenAiGpt,
model_resource,
config_resource,
vocab_resource,
@ -289,11 +293,11 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
num_return_sequences: 3,
..Default::default()
};
- let model = OpenAIGenerator::new(generate_config)?;
+ let model = TextGenerationModel::new(generate_config)?;
let input_context_1 = "The dog is";
let input_context_2 = "The cat was in";
- let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
+ let output = model.generate(&[input_context_1, input_context_2], None);
assert_eq!(output.len(), 6);
// Left padding impacts the generated sentences output

View File

@ -1,5 +1,5 @@
use rust_bert::pipelines::common::ModelType;
- use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator, XLNetGenerator};
+ use rust_bert::pipelines::generation_utils::GenerateConfig;
use rust_bert::pipelines::text_generation::TextGenerationModel;
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::xlnet::{
@ -231,7 +231,7 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
assert_eq!(output.len(), 1);
assert_eq!(
output[0],
" Once upon a time, there was a time when there was only one man in the world who could do all the things he wanted to do. There was no one who"
" Once upon a time, there was a time when there was only one man in the world who could do all the things he wanted to do. There was only"
);
Ok(())