Change generate return type to Result (#437)

* Changed the return type of the `generate` method to `Result` and removed fallible unwraps

* Fix doctests
guillaume-be 2023-12-04 17:58:21 +00:00 committed by GitHub
parent 9f2cd17e91
commit 1f4d344668
31 changed files with 134 additions and 113 deletions
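
For downstream users the migration is mechanical: call sites that previously received the generated output directly now receive a `Result` and propagate failures with `?`. A minimal before/after sketch (hypothetical caller; it mirrors the example updates in the hunks below):

use rust_bert::pipelines::text_generation::TextGenerationModel;

fn run(model: &TextGenerationModel) -> anyhow::Result<()> {
    // Before (0.21): `generate` returned Vec<String> and unwrapped internally.
    // let output = model.generate(&["The dog"], None);

    // After this change: `generate` returns Result, so the caller decides.
    let output = model.generate(&["The dog"], None)?;
    for sentence in output {
        println!("{sentence}");
    }
    Ok(())
}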


@ -16,6 +16,7 @@ All notable changes to this project will be documented in this file. The format
## Changed
- (BREAKING) Upgraded to `torch` 2.1 (via `tch` 0.14.0).
- (BREAKING) Text generation traits and pipelines (including conversation, summarization and translation) now return a `Result` for improved error handling
## [0.21.0] - 2023-06-03
## Added


@ -49,7 +49,7 @@ about exoplanets like K2-18b."];
let summarization_model = SummarizationModel::new(config(Device::Cpu, weights.clone()))?;
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input);
let output = summarization_model.summarize(&input)?;
for sentence in output {
println!("{sentence}");
}
@ -58,7 +58,7 @@ about exoplanets like K2-18b."];
SummarizationModel::new(config(Device::cuda_if_available(), weights))?;
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input);
let output = summarization_model.summarize(&input)?;
for sentence in output {
println!("{sentence}");
}


@ -30,7 +30,7 @@ fn main() -> anyhow::Result<()> {
let input_context = "The dog";
// let second_input_context = "The cat was";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
for sentence in output {
println!("{sentence:?}");


@ -52,7 +52,7 @@ fn main() -> anyhow::Result<()> {
let input_context = "The dog";
// let second_input_context = "The cat was";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
for sentence in output {
println!("{sentence:?}");


@ -57,7 +57,7 @@ fn main() -> anyhow::Result<()> {
let input_context_1 = "It was a very nice and sunny";
let input_context_2 = "It was a gloom winter night, and";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
for sentence in output {
println!("{sentence}");


@ -89,7 +89,7 @@ fn main() -> anyhow::Result<()> {
"It was a very nice and sunny",
"It was a gloom winter night, and",
];
let output = model.generate(&prompts, None);
let output = model.generate(&prompts, None)?;
assert_eq!(output.len(), 2);
assert_eq!(output[0], "It was a very nice and sunny day, and I was sitting in the garden of my house, enjoying the sun and the fresh air. I was thinking");


@ -52,7 +52,7 @@ fn main() -> anyhow::Result<()> {
let input_context_1 = "The really great men must, I think,";
let input_context_2 = "It was a gloom winter night, and";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
for sentence in output {
println!("{sentence}");


@ -47,7 +47,7 @@ fn main() -> anyhow::Result<()> {
let model = TextGenerationModel::new(generate_config)?;
let input_context = "Once upon a time,";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
for sentence in output {
println!("{sentence}");


@ -72,7 +72,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}


@ -66,7 +66,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}


@ -68,7 +68,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}


@ -54,7 +54,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}


@ -55,7 +55,7 @@
//!
//! let input_context_1 = "It was a very nice and sunny";
//! let input_context_2 = "It was a gloom winter night, and";
//! let output = model.generate(&[input_context_1, input_context_2], None);
//! let output = model.generate(&[input_context_1, input_context_2], None)?;
//!
//! for sentence in output {
//! println!("{}", sentence);


@ -72,7 +72,7 @@
//! about exoplanets like K2-18b."];
//!
//! // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
//! let _output = summarization_model.summarize(&input);
//! let _output = summarization_model.summarize(&input)?;
//! for sentence in _output {
//! println!("{}", sentence);
//! }


@ -763,14 +763,14 @@ impl ConversationOption {
&self,
input_ids: Tensor,
attention_mask: Option<Tensor>,
) -> Vec<Vec<i64>> {
match *self {
) -> Result<Vec<Vec<i64>>, RustBertError> {
Ok(match *self {
Self::GPT2(ref model) => model
.generate_from_ids_and_past(input_ids, attention_mask, None)
.generate_from_ids_and_past(input_ids, attention_mask, None)?
.into_iter()
.map(|output| output.indices)
.collect(),
}
})
}
}
@ -887,9 +887,9 @@ impl ConversationModel {
pub fn generate_responses<'a>(
&self,
conversation_manager: &'a mut ConversationManager,
) -> HashMap<&'a Uuid, &'a str> {
) -> Result<HashMap<&'a Uuid, &'a str>, RustBertError> {
let (active_uuid, active_conversations) = conversation_manager.get_active_conversations();
if !active_uuid.is_empty() {
let updated_conversations = if !active_uuid.is_empty() {
let texts = active_conversations
.iter()
.map(|c| c.new_user_input.as_ref().unwrap().as_str())
@ -906,7 +906,7 @@ impl ConversationModel {
let input_length = *input_tensor.size().last().unwrap() as usize;
let mut generated = self
.model
.generate_from_ids_and_past(input_tensor, Some(attention_mask));
.generate_from_ids_and_past(input_tensor, Some(attention_mask))?;
let removed_padding_quantities = self.clean_padding_indices(&mut generated);
let mut output = HashMap::with_capacity(active_uuid.len());
@ -936,7 +936,8 @@ impl ConversationModel {
output
} else {
HashMap::new()
}
};
Ok(updated_conversations)
}
fn clean_padding_indices(&self, model_output: &mut Vec<Vec<i64>>) -> Vec<(usize, usize)> {
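
`generate_responses` is now fallible too, so conversation loops propagate errors instead of hitting the former internal unwrap. A minimal sketch of the updated usage (default `ConversationConfig` assumed; the calls mirror the tests further down):

use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};

fn chat() -> anyhow::Result<()> {
    let model = ConversationModel::new(Default::default())?;
    let mut manager = ConversationManager::new();
    let conversation_id = manager.create("Going to the movies tonight - any suggestions?");
    // `?` surfaces generation failures that previously caused a panic.
    let responses = model.generate_responses(&mut manager)?;
    if let Some(reply) = responses.get(&conversation_id) {
        println!("{reply}");
    }
    Ok(())
}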


@ -1775,11 +1775,11 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
&self,
prompt_texts: Option<&[S]>,
generate_options: Option<GenerateOptions>,
) -> Vec<GeneratedTextOutput>
) -> Result<Vec<GeneratedTextOutput>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
let indices_outputs = self.generate_indices(prompt_texts, generate_options);
let indices_outputs = self.generate_indices(prompt_texts, generate_options)?;
let mut output = Vec::with_capacity(indices_outputs.len());
for generated_sequence in indices_outputs {
output.push(GeneratedTextOutput {
@ -1789,7 +1789,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
score: generated_sequence.score,
});
}
output
Ok(output)
}
/// Generate token indices without decoding (useful for token-level operations before returning final text or as validation step during training).
@ -1869,7 +1869,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
&self,
prompt_texts: Option<&[S]>,
generate_options: Option<GenerateOptions>,
) -> Vec<GeneratedIndicesOutput>
) -> Result<Vec<GeneratedIndicesOutput>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
@ -1896,11 +1896,12 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
}
None => match self.get_bos_id() {
Some(bos_id) => Tensor::ones([1, 1], (Int64, self.get_device())) * bos_id,
None => panic!(
None => return Err(RustBertError::ValueError(
"A model with a BOS token must be used to start generation with an empty input"
),
.to_string(),
)),
},
_ => return Vec::new(),
_ => return Ok(Vec::new()),
};
self.generate_from_ids_and_past(input_ids, None, generate_options)
}
@ -1960,7 +1961,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
mut input_ids: Tensor,
mut attention_mask: Option<Tensor>,
generate_options: Option<GenerateOptions>,
) -> Vec<GeneratedIndicesOutput> {
) -> Result<Vec<GeneratedIndicesOutput>, RustBertError> {
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).cloned();
let config = PrivateLanguageGenerator::get_config(self);
@ -2033,7 +2034,9 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
};
let encoder_outputs = if self.is_encoder_decoder() {
let encoder_outputs = self.encode(&input_ids, Some(&attention_mask)).unwrap();
let encoder_outputs = self
.encode(&input_ids, Some(&attention_mask))
.ok_or(RustBertError::UnsupportedError)?;
let expanded_batch_indices = Tensor::arange(batch_size, (Int64, input_ids.device()))
.view((-1, 1))
.repeat([1, num_beams * effective_batch_mult])
@ -2067,10 +2070,11 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
(input_ids, attention_mask)
}
} else {
let decoder_start_token_id = decoder_start_token_id.unwrap_or_else(|| {
self.get_decoder_start_id()
.expect("decoder start id must be specified for encoder decoders")
});
let decoder_start_token_id = decoder_start_token_id
.or(self.get_decoder_start_id())
.ok_or(RustBertError::ValueError(
"decoder start id must be specified for encoder decoders".to_string(),
))?;
let input_ids = Tensor::full(
[effective_batch_size * num_beams, 1],
decoder_start_token_id,
@ -2103,9 +2107,16 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
config.max_length
};
if let Some(max_length) = max_length {
if input_ids.size2()?.1 > max_length {
return Err(RustBertError::ValueError("The input ids exceeds the maximum length for generation.\
Reduce the size of the provided input ids or increase the allowable maximum generation length.".to_string()));
}
}
if max_length.is_none() & eos_token_ids.is_none() {
panic!("No maximum length given for a model without an EOS token. \
This would lead to an infinite generation loop. Please provide a `max_length` or `max_new_tokens`")
return Err(RustBertError::InvalidConfigurationError("No maximum length given for a model without an EOS token. \
This would lead to an infinite generation loop. Please provide a `max_length` or `max_new_tokens`".to_string()));
}
let gen_opt = InternalGenerateOptions {
@ -2182,7 +2193,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
token_scores,
});
}
output
Ok(output)
}
/// Returns a reference to the text generator's tokenizer
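
The panics removed above now surface as typed `RustBertError` variants, so callers can match on them. A sketch of handling the new configuration error (hypothetical; `GPT2Generator` stands in for any `LanguageGenerator` implementor, and bounding generation with `max_new_tokens` sidesteps the missing EOS/`max_length` error path):

use rust_bert::gpt2::GPT2Generator;
use rust_bert::pipelines::generation_utils::{GenerateOptions, LanguageGenerator};
use rust_bert::RustBertError;

fn generate_safely(generator: &GPT2Generator) -> Result<(), RustBertError> {
    let generate_options = GenerateOptions {
        // A bounded length avoids the InvalidConfigurationError returned when
        // neither a maximum length nor an EOS token is available.
        max_new_tokens: Some(32),
        ..Default::default()
    };
    match generator.generate(Some(&["The dog"]), Some(generate_options)) {
        Ok(output) => {
            for sequence in output {
                println!("{}", sequence.text);
            }
            Ok(())
        }
        Err(RustBertError::InvalidConfigurationError(msg)) => {
            eprintln!("invalid generation configuration: {msg}");
            Ok(())
        }
        Err(e) => Err(e),
    }
}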


@ -338,43 +338,43 @@ impl SummarizationOption {
}
/// Interface method to generate() of the particular models.
pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Vec<String>
pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Result<Vec<String>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
match *self {
Ok(match *self {
Self::Bart(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::T5(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::LongT5(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::ProphetNet(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::Pegasus(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
}
})
}
}
@ -506,7 +506,7 @@ impl SummarizationModel {
/// # }
/// ```
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
pub fn summarize<S>(&self, texts: &[S]) -> Vec<String>
pub fn summarize<S>(&self, texts: &[S]) -> Result<Vec<String>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
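
A condensed, standalone sketch of the updated `summarize` call (default summarization config assumed; input abbreviated from the WikiNews sample used throughout):

use rust_bert::pipelines::summarization::SummarizationModel;

fn summarize_demo() -> anyhow::Result<()> {
    let summarization_model = SummarizationModel::new(Default::default())?;
    let input = ["In findings published Tuesday, astronomers reported water vapour \
                  in the atmosphere of the exoplanet K2-18b."];
    // summarize() now forwards generation errors instead of unwrapping them.
    for sentence in summarization_model.summarize(&input)? {
        println!("{sentence}");
    }
    Ok(())
}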


@ -335,7 +335,7 @@ impl TextGenerationOption {
prompt_texts: Option<&[S]>,
min_length: Option<i64>,
max_length: Option<i64>,
) -> Vec<Vec<i64>>
) -> Result<Vec<Vec<i64>>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
@ -344,49 +344,49 @@ impl TextGenerationOption {
max_length,
..Default::default()
});
match *self {
Ok(match *self {
Self::GPT(ref model) => model
.generate_indices(prompt_texts, generate_options)
.generate_indices(prompt_texts, generate_options)?
.into_iter()
.map(|output| output.indices)
.collect(),
Self::GPT2(ref model) => model
.generate_indices(prompt_texts, generate_options)
.generate_indices(prompt_texts, generate_options)?
.into_iter()
.map(|output| output.indices)
.collect(),
Self::GPTNeo(ref model) => model
.generate_indices(prompt_texts, generate_options)
.generate_indices(prompt_texts, generate_options)?
.into_iter()
.map(|output| output.indices)
.collect(),
Self::GPTJ(ref model) => model
.generate_indices(prompt_texts, generate_options)
.generate_indices(prompt_texts, generate_options)?
.into_iter()
.map(|output| output.indices)
.collect(),
Self::XLNet(ref model) => model
.generate_indices(prompt_texts, generate_options)
.generate_indices(prompt_texts, generate_options)?
.into_iter()
.map(|output| output.indices)
.collect(),
Self::Reformer(ref model) => model
.generate_indices(prompt_texts, generate_options)
.generate_indices(prompt_texts, generate_options)?
.into_iter()
.map(|output| output.indices)
.collect(),
Self::T5(ref model) => model
.generate_indices(prompt_texts, generate_options)
.generate_indices(prompt_texts, generate_options)?
.into_iter()
.map(|output| output.indices)
.collect(),
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => model
.generate_indices(prompt_texts, generate_options)
.generate_indices(prompt_texts, generate_options)?
.into_iter()
.map(|output| output.indices)
.collect(),
}
})
}
pub fn half(&mut self) -> Result<(), RustBertError> {
@ -599,7 +599,11 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
/// # Ok(())
/// # }
/// ```
pub fn generate<'a, S>(&self, texts: &[S], prefix: impl Into<Option<&'a str>>) -> Vec<String>
pub fn generate<'a, S>(
&self,
texts: &[S],
prefix: impl Into<Option<&'a str>>,
) -> Result<Vec<String>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
@ -625,8 +629,10 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
self.max_length.map(|max_length| max_length + prefix_length),
)
}
_ => panic!("Prefix length not defined but prefix provided!"),
};
_ => Err(RustBertError::ValueError(
"Prefix length not defined but prefix provided!".to_string(),
)),
}?;
let mut output = Vec::with_capacity(generated_indices.len());
for generated_sequence in generated_indices {
@ -636,7 +642,7 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
true,
));
}
output
Ok(output)
}
}
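
The prefix path in `generate` now returns a `ValueError` instead of panicking when a prefix is supplied but the model configuration defines no prefix length. A sketch of recovering from that case (hypothetical fallback logic):

use rust_bert::pipelines::text_generation::TextGenerationModel;
use rust_bert::RustBertError;

fn generate_with_prefix(model: &TextGenerationModel) -> Result<Vec<String>, RustBertError> {
    match model.generate(&["Once upon a time,"], "A fairy tale: ") {
        Err(RustBertError::ValueError(msg)) => {
            // Prefix unsupported by this model configuration: retry without it.
            eprintln!("{msg}");
            model.generate(&["Once upon a time,"], None)
        }
        other => other,
    }
}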


@ -1219,18 +1219,18 @@ impl TranslationOption {
&self,
prompt_texts: Option<&[S]>,
forced_bos_token_id: Option<i64>,
) -> Vec<String>
) -> Result<Vec<String>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
match *self {
Ok(match *self {
Self::Marian(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::T5(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
@ -1240,7 +1240,7 @@ impl TranslationOption {
..Default::default()
};
model
.generate(prompt_texts, Some(generate_options))
.generate(prompt_texts, Some(generate_options))?
.into_iter()
.map(|output| output.text)
.collect()
@ -1251,7 +1251,7 @@ impl TranslationOption {
..Default::default()
};
model
.generate(prompt_texts, Some(generate_options))
.generate(prompt_texts, Some(generate_options))?
.into_iter()
.map(|output| output.text)
.collect()
@ -1264,12 +1264,12 @@ impl TranslationOption {
..Default::default()
});
model
.generate(prompt_texts, generate_options)
.generate(prompt_texts, generate_options)?
.into_iter()
.map(|output| output.text)
.collect()
}
}
})
}
}
@ -1484,7 +1484,7 @@ impl TranslationModel {
&self.supported_target_languages,
)?;
Ok(match prefix {
match prefix {
Some(value) => {
let texts = texts
.iter()
@ -1493,7 +1493,7 @@ impl TranslationModel {
self.model.generate(Some(&texts), forced_bos_token_id)
}
None => self.model.generate(Some(texts), forced_bos_token_id),
})
}
}
}
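
`translate` simply forwards the generator's `Result`, so translation call sites gain the same `?` ergonomics. A sketch using the crate's builder API (weights are downloaded by `create_model`):

use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};

fn translate_demo() -> anyhow::Result<()> {
    let model = TranslationModelBuilder::new()
        .with_source_languages(vec![Language::English])
        .with_target_languages(vec![Language::French])
        .create_model()?;
    let output = model.translate(&["The dog did not wake up."], None, Language::French)?;
    for sentence in output {
        println!("{sentence}");
    }
    Ok(())
}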


@ -128,7 +128,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = model.summarize(&input);
let output = model.summarize(&input)?;
assert_eq!(output.len(), 1);
assert_eq!(output[0], " K2-18b is not too hot and not too cold for liquid water to exist. \
@ -189,7 +189,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = model.summarize(&input);
let output = model.summarize(&input)?;
assert_eq!(output.len(), 1);
assert_eq!(output[0], " K2-18b, a planet circling a star in the constellation Leo, is not too hot and not too cold for liquid water to exist. \


@ -121,7 +121,7 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
let model = TextGenerationModel::new(generate_config)?;
let input_context = "The cat";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
assert_eq!(output.len(), 1);
assert_eq!(output[0], "The cat was found in a field near the town of Keflavik, about 30 miles (48 kilometers) south-east of Moscow.\n\n\n");
@ -154,7 +154,7 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
let model = TextGenerationModel::new(generate_config)?;
let input_context = "The dog";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
assert_eq!(output.len(), 3);
assert_eq!(
@ -199,7 +199,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
let input_context_1 = "The dog";
let input_context_2 = "The cat";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
assert_eq!(output.len(), 6);
assert_eq!(
@ -255,7 +255,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
let input_context_1 = "The dog";
let input_context_2 = "The cat was";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
assert_eq!(output.len(), 6);
assert_eq!(
@ -313,7 +313,7 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()
let input_context_1 = "It was a nice and";
let input_context_2 = "Language models can generate";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
assert_eq!(output.len(), 6);
assert_eq!(
@ -392,7 +392,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
let output = model.generate(
Some(&[input_context_1, input_context_2]),
Some(generate_options),
);
)?;
assert_eq!(output.len(), 2);
assert_eq!(
@ -455,8 +455,9 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
..Default::default()
};
let baseline_output = model.generate(Some(&[input_context_1]), Some(baseline_generate_options));
let output = model.generate(Some(&[input_context_1]), Some(test_generate_options));
let baseline_output =
model.generate(Some(&[input_context_1]), Some(baseline_generate_options))?;
let output = model.generate(Some(&[input_context_1]), Some(test_generate_options))?;
assert_eq!(baseline_output.len(), 1);
assert_eq!(
@ -521,8 +522,9 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
..Default::default()
};
let baseline_output = model.generate(Some(&[input_context_1]), Some(baseline_generate_options));
let output = model.generate(Some(&[input_context_1]), Some(test_generate_options));
let baseline_output =
model.generate(Some(&[input_context_1]), Some(baseline_generate_options))?;
let output = model.generate(Some(&[input_context_1]), Some(test_generate_options))?;
assert_eq!(baseline_output.len(), 1);
assert_eq!(
@ -589,7 +591,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
let output = model.generate(
Some(&[input_context_1, input_context_2]),
Some(generate_options),
);
)?;
assert_eq!(output.len(), 2);
assert_eq!(
@ -638,7 +640,7 @@ fn gpt2_greedy_token_scores() -> anyhow::Result<()> {
let output = model.generate_indices(
Some(&[input_context_1, input_context_2]),
Some(generate_options),
);
)?;
assert_eq!(output.len(), 2);
assert_eq!(
@ -694,7 +696,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> {
let output = model.generate_indices(
Some(&[input_context_1, input_context_2]),
Some(generate_options),
);
)?;
assert_eq!(output.len(), 2);
assert_eq!(
@ -735,7 +737,7 @@ fn dialogpt_single_multi_turn_conversation() -> anyhow::Result<()> {
conversation_manager.create("Going to the movies tonight - any suggestions?");
// Turn 1
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 1);
assert_eq!(output.get(&conversation_id).unwrap(), &"The Big Lebowski");
@ -744,12 +746,12 @@ fn dialogpt_single_multi_turn_conversation() -> anyhow::Result<()> {
.get(&conversation_id)
.unwrap()
.add_user_input("Is it an action movie?");
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 1);
assert_eq!(output.get(&conversation_id).unwrap(), &"It\'s a comedy.");
// Turn 3 (no new user input)
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 0);
Ok(())
@ -773,7 +775,7 @@ fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> {
let conversation_2_id = conversation_manager.create("What's the last book you have read?");
// Turn 1
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 2);
assert_eq!(output.get(&conversation_1_id).unwrap(), &"The Big Lebowski");
assert_eq!(
@ -786,12 +788,12 @@ fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> {
.get(&conversation_1_id)
.unwrap()
.add_user_input("Is it an action movie?");
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 1);
assert_eq!(output.get(&conversation_1_id).unwrap(), &"It\'s a comedy.");
// Turn 3 (no new user input)
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 0);
Ok(())
@ -817,7 +819,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result
let conversation_2_id = conversation_manager.create("Hello how are you?");
// Turn 1
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 2);
assert_eq!(output.get(&conversation_1_id).unwrap(), &"The Big Lebowski");
assert_eq!(
@ -835,12 +837,12 @@ fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result
.unwrap()
.add_user_input("Fine.");
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 2);
assert_eq!(output.get(&conversation_1_id).unwrap(), &"It\'s a comedy.");
// Turn 3 (no new user input)
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 0);
Ok(())
@ -864,7 +866,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> any
let conversation_2_id = conversation_manager.create("What's the last book you have read?");
// Turn 1
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 2);
assert_eq!(output.get(&conversation_1_id).unwrap(), &"The Big Lebowski");
assert_eq!(
@ -878,7 +880,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> any
.get(&conversation_2_id)
.unwrap()
.add_user_input("Why do you recommend it?");
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 1);
assert_eq!(
output.get(&conversation_2_id).unwrap(),
@ -886,7 +888,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> any
);
// Turn 3 (no new user input)
let output = conversation_model.generate_responses(&mut conversation_manager);
let output = conversation_model.generate_responses(&mut conversation_manager)?;
assert_eq!(output.len(), 0);
Ok(())
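
The index-level API follows the same pattern; a sketch of `generate_indices` with token scores requested, mirroring the `gpt2_greedy_token_scores` test above (`output_scores` is assumed to be the existing `GenerateOptions` flag those tests enable):

use rust_bert::gpt2::GPT2Generator;
use rust_bert::pipelines::generation_utils::{GenerateOptions, LanguageGenerator};
use rust_bert::RustBertError;

fn indices_with_scores(generator: &GPT2Generator) -> Result<(), RustBertError> {
    let generate_options = GenerateOptions {
        output_scores: true,
        ..Default::default()
    };
    let output = generator.generate_indices(Some(&["The dog"]), Some(generate_options))?;
    for generated in output {
        println!("{:?} (score: {:?})", generated.indices, generated.score);
    }
    Ok(())
}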


@ -143,7 +143,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> {
let input_context_1 = "It was a very nice and sunny";
let input_context_2 = "It was a gloom winter night, and";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
assert_eq!(output.len(), 2);
assert_eq!(output[0], "It was a very nice and sunny day. The sun was shining through the clouds, and the sky was clear. The wind was blowing through the trees,");


@ -53,7 +53,7 @@ mod tests {
let model = TextGenerationModel::new_with_tokenizer(generate_config, tokenizer)?;
let input_context = "The dog";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
assert_eq!(output.len(), 3);
assert_eq!(


@ -49,7 +49,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
let output = model.summarize(&input);
let output = model.summarize(&input)?;
assert_eq! (
output[0],


@ -222,7 +222,7 @@ mod tests {
..Default::default()
})?;
let prompts = ["It was a very nice and sunny"];
let output = text_generation_model.generate(&prompts, None);
let output = text_generation_model.generate(&prompts, None)?;
assert_eq!(output.len(), 1);
assert_eq!(output[0], "It was a very nice and sunny day. I was very happy with the weather. I was very happy with the weather. I was very happy with");
Ok(())


@ -134,7 +134,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
let model = TextGenerationModel::new(generate_config)?;
let input_context = "It was an intense machine dialogue. ";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
assert_eq!(output.len(), 1);
assert_eq!(output[0], "it was an intense machine dialogue. \n \" i\'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");
@ -176,7 +176,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
let model = TextGenerationModel::new(generate_config)?;
let input_context = "The dog is";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
assert_eq!(output.len(), 3);
assert_eq!(
@ -230,7 +230,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
let input_context_1 = "The dog is";
let input_context_2 = "The cat";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
assert_eq!(output.len(), 6);
@ -298,7 +298,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
let input_context_1 = "The dog is";
let input_context_2 = "The cat was in";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
assert_eq!(output.len(), 6);
// Left padding impacts the generated sentences output


@ -54,7 +54,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input);
let output = summarization_model.summarize(&input)?;
assert_eq!(output.len(), 1);
assert_eq!(


@ -57,7 +57,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input);
let output = summarization_model.summarize(&input)?;
assert_eq!(output.len(), 1);
assert_eq!(


@ -64,7 +64,7 @@ fn test_generation_reformer() -> anyhow::Result<()> {
let input_context_1 = "The really great men must, I think,";
let input_context_2 = "It was a gloom winter night, and";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
assert_eq!(output.len(), 2);
assert_eq!(output[0], " The really great men must, I think, anyway waiting for some unknown reason, but Nikodim Fomitch and Ilya Petrovitch looked at him anguish invitable incidently at him. He could not resist an impression which might be setting");


@ -102,7 +102,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
let output = model.summarize(&input);
let output = model.summarize(&input)?;
assert_eq! (
output[0],


@ -222,7 +222,7 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
let model = TextGenerationModel::new(generate_config)?;
let input_context = "Once upon a time,";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
assert_eq!(output.len(), 1);
assert_eq!(