mirror of https://github.com/guillaume-be/rust-bert.git
synced 2024-10-03 23:57:15 +03:00

Change generate return type to Result (#437)

* Changed the return type of the generate method to `Result`, removed fallible unwraps
* Fix doctests

parent 9f2cd17e91
commit 1f4d344668
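
For downstream users, the practical effect is that pipeline calls which previously returned their output directly now return a `Result` and should be propagated with `?` or handled explicitly. A minimal migration sketch, assuming the default GPT-2 text generation configuration (the setup here is illustrative, not part of this commit):

    use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};

    fn main() -> anyhow::Result<()> {
        // Build a text generation pipeline with default settings.
        let model = TextGenerationModel::new(TextGenerationConfig::default())?;

        // Before this change: `generate` returned Vec<String> and panicked on internal errors.
        // let output = model.generate(&["The dog"], None);

        // After this change: `generate` returns a Result, so failures propagate with `?`.
        let output = model.generate(&["The dog"], None)?;
        for sentence in output {
            println!("{sentence}");
        }
        Ok(())
    }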
@@ -16,6 +16,7 @@ All notable changes to this project will be documented in this file. The format

 ## Changed
 - (BREAKING) Upgraded to `torch` 2.1 (via `tch` 0.14.0).
+- (BREAKING) Text generation traits and pipelines (including conversation, summarization and translation) now return a `Result` for improved error handling

 ## [0.21.0] - 2023-06-03
 ## Added
@@ -49,7 +49,7 @@ about exoplanets like K2-18b."];
     let summarization_model = SummarizationModel::new(config(Device::Cpu, weights.clone()))?;

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-    let output = summarization_model.summarize(&input);
+    let output = summarization_model.summarize(&input)?;
     for sentence in output {
         println!("{sentence}");
     }

@@ -58,7 +58,7 @@ about exoplanets like K2-18b."];
     SummarizationModel::new(config(Device::cuda_if_available(), weights))?;

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-    let output = summarization_model.summarize(&input);
+    let output = summarization_model.summarize(&input)?;
     for sentence in output {
         println!("{sentence}");
     }
@@ -30,7 +30,7 @@ fn main() -> anyhow::Result<()> {

     let input_context = "The dog";
     // let second_input_context = "The cat was";
-    let output = model.generate(&[input_context], None);
+    let output = model.generate(&[input_context], None)?;

     for sentence in output {
         println!("{sentence:?}");

@@ -52,7 +52,7 @@ fn main() -> anyhow::Result<()> {

     let input_context = "The dog";
     // let second_input_context = "The cat was";
-    let output = model.generate(&[input_context], None);
+    let output = model.generate(&[input_context], None)?;

     for sentence in output {
         println!("{sentence:?}");
@@ -57,7 +57,7 @@ fn main() -> anyhow::Result<()> {

     let input_context_1 = "It was a very nice and sunny";
     let input_context_2 = "It was a gloom winter night, and";
-    let output = model.generate(&[input_context_1, input_context_2], None);
+    let output = model.generate(&[input_context_1, input_context_2], None)?;

     for sentence in output {
         println!("{sentence}");

@@ -89,7 +89,7 @@ fn main() -> anyhow::Result<()> {
         "It was a very nice and sunny",
         "It was a gloom winter night, and",
     ];
-    let output = model.generate(&prompts, None);
+    let output = model.generate(&prompts, None)?;

     assert_eq!(output.len(), 2);
     assert_eq!(output[0], "It was a very nice and sunny day, and I was sitting in the garden of my house, enjoying the sun and the fresh air. I was thinking");
@@ -52,7 +52,7 @@ fn main() -> anyhow::Result<()> {

     let input_context_1 = "The really great men must, I think,";
     let input_context_2 = "It was a gloom winter night, and";
-    let output = model.generate(&[input_context_1, input_context_2], None);
+    let output = model.generate(&[input_context_1, input_context_2], None)?;

     for sentence in output {
         println!("{sentence}");

@@ -47,7 +47,7 @@ fn main() -> anyhow::Result<()> {
     let model = TextGenerationModel::new(generate_config)?;

     let input_context = "Once upon a time,";
-    let output = model.generate(&[input_context], None);
+    let output = model.generate(&[input_context], None)?;

     for sentence in output {
         println!("{sentence}");
@@ -72,7 +72,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
 about exoplanets like K2-18b."];

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-    let _output = summarization_model.summarize(&input);
+    let _output = summarization_model.summarize(&input)?;
     for sentence in _output {
         println!("{sentence}");
     }

@@ -66,7 +66,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
 about exoplanets like K2-18b."];

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-    let _output = summarization_model.summarize(&input);
+    let _output = summarization_model.summarize(&input)?;
     for sentence in _output {
         println!("{sentence}");
     }

@@ -68,7 +68,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
 about exoplanets like K2-18b."];

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-    let _output = summarization_model.summarize(&input);
+    let _output = summarization_model.summarize(&input)?;
     for sentence in _output {
         println!("{sentence}");
     }

@@ -54,7 +54,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
 about exoplanets like K2-18b."];

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-    let _output = summarization_model.summarize(&input);
+    let _output = summarization_model.summarize(&input)?;
     for sentence in _output {
         println!("{sentence}");
     }
@@ -55,7 +55,7 @@
 //!
 //! let input_context_1 = "It was a very nice and sunny";
 //! let input_context_2 = "It was a gloom winter night, and";
-//! let output = model.generate(&[input_context_1, input_context_2], None);
+//! let output = model.generate(&[input_context_1, input_context_2], None)?;
 //!
 //! for sentence in output {
 //!     println!("{}", sentence);

@@ -72,7 +72,7 @@
 //! about exoplanets like K2-18b."];
 //!
 //! // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-//! let _output = summarization_model.summarize(&input);
+//! let _output = summarization_model.summarize(&input)?;
 //! for sentence in _output {
 //!     println!("{}", sentence);
 //! }
@@ -763,14 +763,14 @@ impl ConversationOption {
         &self,
         input_ids: Tensor,
         attention_mask: Option<Tensor>,
-    ) -> Vec<Vec<i64>> {
-        match *self {
+    ) -> Result<Vec<Vec<i64>>, RustBertError> {
+        Ok(match *self {
             Self::GPT2(ref model) => model
-                .generate_from_ids_and_past(input_ids, attention_mask, None)
+                .generate_from_ids_and_past(input_ids, attention_mask, None)?
                 .into_iter()
                 .map(|output| output.indices)
                 .collect(),
-        }
+        })
     }
 }
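
The recurring refactoring pattern in this commit, shown above for `ConversationOption`, is to wrap the entire `match` in `Ok(...)`: each arm can then use `?` on the inner fallible call while the match expression still evaluates to the plain success value. A reduced, self-contained sketch of the shape (the types and names here are illustrative, not the crate's own):

    #[derive(Debug)]
    struct Error(String);

    enum Backend {
        A,
        B,
    }

    impl Backend {
        fn run(&self) -> Result<Vec<i64>, Error> {
            Ok(vec![1, 2, 3])
        }

        fn generate(&self) -> Result<Vec<i64>, Error> {
            // Wrapping the match in `Ok(...)` lets every arm use `?` while the
            // function still returns a single Result.
            Ok(match self {
                Backend::A => self.run()?,
                Backend::B => self.run()?.into_iter().map(|x| x * 2).collect(),
            })
        }
    }

    fn main() {
        let ids = Backend::B.generate().unwrap();
        assert_eq!(ids, vec![2, 4, 6]);
    }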
@@ -887,9 +887,9 @@ impl ConversationModel {
     pub fn generate_responses<'a>(
         &self,
         conversation_manager: &'a mut ConversationManager,
-    ) -> HashMap<&'a Uuid, &'a str> {
+    ) -> Result<HashMap<&'a Uuid, &'a str>, RustBertError> {
         let (active_uuid, active_conversations) = conversation_manager.get_active_conversations();
-        if !active_uuid.is_empty() {
+        let updated_conversations = if !active_uuid.is_empty() {
             let texts = active_conversations
                 .iter()
                 .map(|c| c.new_user_input.as_ref().unwrap().as_str())

@@ -906,7 +906,7 @@ impl ConversationModel {
         let input_length = *input_tensor.size().last().unwrap() as usize;
         let mut generated = self
             .model
-            .generate_from_ids_and_past(input_tensor, Some(attention_mask));
+            .generate_from_ids_and_past(input_tensor, Some(attention_mask))?;
         let removed_padding_quantities = self.clean_padding_indices(&mut generated);

         let mut output = HashMap::with_capacity(active_uuid.len());

@@ -936,7 +936,8 @@ impl ConversationModel {
             output
         } else {
             HashMap::new()
-        }
+        };
+        Ok(updated_conversations)
     }

     fn clean_padding_indices(&self, model_output: &mut Vec<Vec<i64>>) -> Vec<(usize, usize)> {
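
Because the `if`/`else` in `generate_responses` previously produced the `HashMap` from both branches directly, the refactoring binds the branch result to `updated_conversations` and wraps it once at the end, keeping a single exit point for the new `Result`. A simplified sketch of the control flow (names and types are illustrative):

    use std::collections::HashMap;

    fn generate_responses(active: bool) -> Result<HashMap<u32, String>, String> {
        // Bind the branch result instead of returning from each branch,
        // then wrap once in `Ok` at the end of the function.
        let updated_conversations = if active {
            let mut responses = HashMap::new();
            responses.insert(1, "The Big Lebowski".to_string());
            responses
        } else {
            HashMap::new()
        };
        Ok(updated_conversations)
    }

    fn main() {
        assert_eq!(generate_responses(false).unwrap().len(), 0);
    }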
@@ -1775,11 +1775,11 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
         &self,
         prompt_texts: Option<&[S]>,
         generate_options: Option<GenerateOptions>,
-    ) -> Vec<GeneratedTextOutput>
+    ) -> Result<Vec<GeneratedTextOutput>, RustBertError>
     where
         S: AsRef<str> + Send + Sync,
     {
-        let indices_outputs = self.generate_indices(prompt_texts, generate_options);
+        let indices_outputs = self.generate_indices(prompt_texts, generate_options)?;
         let mut output = Vec::with_capacity(indices_outputs.len());
         for generated_sequence in indices_outputs {
             output.push(GeneratedTextOutput {

@@ -1789,7 +1789,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
                 score: generated_sequence.score,
             });
         }
-        output
+        Ok(output)
     }

     /// Generate token indices without decoding (useful for token-level operations before returning final text or as validation step during training).
@@ -1869,7 +1869,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
         &self,
         prompt_texts: Option<&[S]>,
         generate_options: Option<GenerateOptions>,
-    ) -> Vec<GeneratedIndicesOutput>
+    ) -> Result<Vec<GeneratedIndicesOutput>, RustBertError>
     where
         S: AsRef<str> + Send + Sync,
     {

@@ -1896,11 +1896,12 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
             }
             None => match self.get_bos_id() {
                 Some(bos_id) => Tensor::ones([1, 1], (Int64, self.get_device())) * bos_id,
-                None => panic!(
-                    "A model with a BOS token must be used to start generation with an empty input"
-                ),
+                None => return Err(RustBertError::ValueError(
+                    "A model with a BOS token must be used to start generation with an empty input"
+                        .to_string(),
+                )),
             },
-            _ => return Vec::new(),
+            _ => return Ok(Vec::new()),
         };
         self.generate_from_ids_and_past(input_ids, None, generate_options)
     }
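
Where the old code panicked inside a `match` arm, the new code short-circuits the surrounding function with `return Err(...)`; note that the early return of an empty result must likewise become `return Ok(Vec::new())` once the signature changes. A self-contained sketch of the pattern (simplified types, illustrative names):

    fn initial_ids(have_input: bool, bos_id: Option<i64>) -> Result<Vec<i64>, String> {
        let input_ids = match (have_input, bos_id) {
            (true, _) => vec![0],
            (false, Some(bos)) => vec![bos],
            // Previously a panic!; now an early Err return from the whole function.
            (false, None) => {
                return Err(
                    "A model with a BOS token must be used to start generation with an empty input"
                        .to_string(),
                )
            }
        };
        Ok(input_ids)
    }

    fn main() {
        assert!(initial_ids(false, None).is_err());
        assert_eq!(initial_ids(false, Some(50256)).unwrap(), vec![50256]);
    }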
@@ -1960,7 +1961,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
         mut input_ids: Tensor,
         mut attention_mask: Option<Tensor>,
         generate_options: Option<GenerateOptions>,
-    ) -> Vec<GeneratedIndicesOutput> {
+    ) -> Result<Vec<GeneratedIndicesOutput>, RustBertError> {
         let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).cloned();

         let config = PrivateLanguageGenerator::get_config(self);

@@ -2033,7 +2034,9 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
         };

         let encoder_outputs = if self.is_encoder_decoder() {
-            let encoder_outputs = self.encode(&input_ids, Some(&attention_mask)).unwrap();
+            let encoder_outputs = self
+                .encode(&input_ids, Some(&attention_mask))
+                .ok_or(RustBertError::UnsupportedError)?;
             let expanded_batch_indices = Tensor::arange(batch_size, (Int64, input_ids.device()))
                 .view((-1, 1))
                 .repeat([1, num_beams * effective_batch_mult])

@@ -2067,10 +2070,11 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
                 (input_ids, attention_mask)
             }
         } else {
-            let decoder_start_token_id = decoder_start_token_id.unwrap_or_else(|| {
-                self.get_decoder_start_id()
-                    .expect("decoder start id must be specified for encoder decoders")
-            });
+            let decoder_start_token_id = decoder_start_token_id
+                .or(self.get_decoder_start_id())
+                .ok_or(RustBertError::ValueError(
+                    "decoder start id must be specified for encoder decoders".to_string(),
+                ))?;
             let input_ids = Tensor::full(
                 [effective_batch_size * num_beams, 1],
                 decoder_start_token_id,

@@ -2103,9 +2107,16 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
             config.max_length
         };

+        if let Some(max_length) = max_length {
+            if input_ids.size2()?.1 > max_length {
+                return Err(RustBertError::ValueError("The input ids exceeds the maximum length for generation.\
+                Reduce the size of the provided input ids or increase the allowable maximum generation length.".to_string()));
+            }
+        }
+
         if max_length.is_none() & eos_token_ids.is_none() {
-            panic!("No maximum length given for a model without an EOS token. \
-            This would lead to an infinite generation loop. Please provide a `max_length` or `max_new_tokens`")
+            return Err(RustBertError::InvalidConfigurationError("No maximum length given for a model without an EOS token. \
+            This would lead to an infinite generation loop. Please provide a `max_length` or `max_new_tokens`".to_string()));
         }

         let gen_opt = InternalGenerateOptions {

@@ -2182,7 +2193,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
                 token_scores,
             });
         }
-        output
+        Ok(output)
     }

     /// Returns a reference to the text generator's tokenizer
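
The new pre-generation guard above also turns an oversized prompt into an error up front instead of letting it fail later inside the model. A simplified sketch of the check (plain integers stand in for the tensor dimensions):

    fn check_prompt_length(input_len: i64, max_length: Option<i64>) -> Result<(), String> {
        // Only enforce the bound when a maximum length is configured.
        if let Some(max_length) = max_length {
            if input_len > max_length {
                return Err("The input ids exceeds the maximum length for generation. \
                    Reduce the size of the provided input ids or increase the allowable \
                    maximum generation length."
                    .to_string());
            }
        }
        Ok(())
    }

    fn main() {
        assert!(check_prompt_length(512, Some(128)).is_err());
        assert!(check_prompt_length(64, None).is_ok());
    }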
@@ -338,43 +338,43 @@ impl SummarizationOption {
     }

     /// Interface method to generate() of the particular models.
-    pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Vec<String>
+    pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Result<Vec<String>, RustBertError>
     where
         S: AsRef<str> + Send + Sync,
     {
-        match *self {
+        Ok(match *self {
             Self::Bart(ref model) => model
-                .generate(prompt_texts, None)
+                .generate(prompt_texts, None)?
                 .into_iter()
                 .map(|output| output.text)
                 .collect(),
             Self::T5(ref model) => model
-                .generate(prompt_texts, None)
+                .generate(prompt_texts, None)?
                 .into_iter()
                 .map(|output| output.text)
                 .collect(),
             Self::LongT5(ref model) => model
-                .generate(prompt_texts, None)
+                .generate(prompt_texts, None)?
                 .into_iter()
                 .map(|output| output.text)
                 .collect(),
             Self::ProphetNet(ref model) => model
-                .generate(prompt_texts, None)
+                .generate(prompt_texts, None)?
                 .into_iter()
                 .map(|output| output.text)
                 .collect(),
             Self::Pegasus(ref model) => model
-                .generate(prompt_texts, None)
+                .generate(prompt_texts, None)?
                 .into_iter()
                 .map(|output| output.text)
                 .collect(),
             #[cfg(feature = "onnx")]
             Self::ONNX(ref model) => model
-                .generate(prompt_texts, None)
+                .generate(prompt_texts, None)?
                 .into_iter()
                 .map(|output| output.text)
                 .collect(),
-        }
+        })
     }
 }
@@ -506,7 +506,7 @@ impl SummarizationModel {
     /// # }
     /// ```
     /// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
-    pub fn summarize<S>(&self, texts: &[S]) -> Vec<String>
+    pub fn summarize<S>(&self, texts: &[S]) -> Result<Vec<String>, RustBertError>
     where
         S: AsRef<str> + Send + Sync,
     {
@@ -335,7 +335,7 @@ impl TextGenerationOption {
         prompt_texts: Option<&[S]>,
         min_length: Option<i64>,
         max_length: Option<i64>,
-    ) -> Vec<Vec<i64>>
+    ) -> Result<Vec<Vec<i64>>, RustBertError>
     where
         S: AsRef<str> + Send + Sync,
     {

@@ -344,49 +344,49 @@ impl TextGenerationOption {
             max_length,
             ..Default::default()
         });
-        match *self {
+        Ok(match *self {
             Self::GPT(ref model) => model
-                .generate_indices(prompt_texts, generate_options)
+                .generate_indices(prompt_texts, generate_options)?
                 .into_iter()
                 .map(|output| output.indices)
                 .collect(),
             Self::GPT2(ref model) => model
-                .generate_indices(prompt_texts, generate_options)
+                .generate_indices(prompt_texts, generate_options)?
                 .into_iter()
                 .map(|output| output.indices)
                 .collect(),
             Self::GPTNeo(ref model) => model
-                .generate_indices(prompt_texts, generate_options)
+                .generate_indices(prompt_texts, generate_options)?
                 .into_iter()
                 .map(|output| output.indices)
                 .collect(),
             Self::GPTJ(ref model) => model
-                .generate_indices(prompt_texts, generate_options)
+                .generate_indices(prompt_texts, generate_options)?
                 .into_iter()
                 .map(|output| output.indices)
                 .collect(),
             Self::XLNet(ref model) => model
-                .generate_indices(prompt_texts, generate_options)
+                .generate_indices(prompt_texts, generate_options)?
                 .into_iter()
                 .map(|output| output.indices)
                 .collect(),
             Self::Reformer(ref model) => model
-                .generate_indices(prompt_texts, generate_options)
+                .generate_indices(prompt_texts, generate_options)?
                 .into_iter()
                 .map(|output| output.indices)
                 .collect(),
             Self::T5(ref model) => model
-                .generate_indices(prompt_texts, generate_options)
+                .generate_indices(prompt_texts, generate_options)?
                 .into_iter()
                 .map(|output| output.indices)
                 .collect(),
             #[cfg(feature = "onnx")]
             Self::ONNX(ref model) => model
-                .generate_indices(prompt_texts, generate_options)
+                .generate_indices(prompt_texts, generate_options)?
                 .into_iter()
                 .map(|output| output.indices)
                 .collect(),
-        }
+        })
     }

     pub fn half(&mut self) -> Result<(), RustBertError> {
@@ -599,7 +599,11 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
     /// # Ok(())
     /// # }
     /// ```
-    pub fn generate<'a, S>(&self, texts: &[S], prefix: impl Into<Option<&'a str>>) -> Vec<String>
+    pub fn generate<'a, S>(
+        &self,
+        texts: &[S],
+        prefix: impl Into<Option<&'a str>>,
+    ) -> Result<Vec<String>, RustBertError>
     where
         S: AsRef<str> + Send + Sync,
     {

@@ -625,8 +629,10 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
                     self.max_length.map(|max_length| max_length + prefix_length),
                 )
             }
-            _ => panic!("Prefix length not defined but prefix provided!"),
-        };
+            _ => Err(RustBertError::ValueError(
+                "Prefix length not defined but prefix provided!".to_string(),
+            )),
+        }?;

         let mut output = Vec::with_capacity(generated_indices.len());
         for generated_sequence in generated_indices {

@@ -636,7 +642,7 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
                 true,
             ));
         }
-        output
+        Ok(output)
     }
 }
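
In `TextGenerationModel::generate` above, the prefix-handling `match` now evaluates to a `Result`, and a single `?` after the closing brace replaces the old panic. A reduced sketch of that shape (illustrative names and types):

    fn effective_length(prefix: Option<&str>, prefix_length: Option<usize>) -> Result<usize, String> {
        let length = match (prefix, prefix_length) {
            (None, _) => Ok(0),
            (Some(p), Some(len)) => Ok(p.len() + len),
            // Previously a panic!; now an Err arm unwrapped by the trailing `?`.
            (Some(_), None) => Err("Prefix length not defined but prefix provided!".to_string()),
        }?;
        Ok(length)
    }

    fn main() {
        assert_eq!(effective_length(Some("translate: "), Some(4)).unwrap(), 15);
        assert!(effective_length(Some("x"), None).is_err());
    }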
@@ -1219,18 +1219,18 @@ impl TranslationOption {
         &self,
         prompt_texts: Option<&[S]>,
         forced_bos_token_id: Option<i64>,
-    ) -> Vec<String>
+    ) -> Result<Vec<String>, RustBertError>
     where
         S: AsRef<str> + Send + Sync,
     {
-        match *self {
+        Ok(match *self {
             Self::Marian(ref model) => model
-                .generate(prompt_texts, None)
+                .generate(prompt_texts, None)?
                 .into_iter()
                 .map(|output| output.text)
                 .collect(),
             Self::T5(ref model) => model
-                .generate(prompt_texts, None)
+                .generate(prompt_texts, None)?
                 .into_iter()
                 .map(|output| output.text)
                 .collect(),

@@ -1240,7 +1240,7 @@ impl TranslationOption {
                 ..Default::default()
             };
             model
-                .generate(prompt_texts, Some(generate_options))
+                .generate(prompt_texts, Some(generate_options))?
                 .into_iter()
                 .map(|output| output.text)
                 .collect()

@@ -1251,7 +1251,7 @@ impl TranslationOption {
                 ..Default::default()
             };
             model
-                .generate(prompt_texts, Some(generate_options))
+                .generate(prompt_texts, Some(generate_options))?
                 .into_iter()
                 .map(|output| output.text)
                 .collect()

@@ -1264,12 +1264,12 @@ impl TranslationOption {
                 ..Default::default()
             });
             model
-                .generate(prompt_texts, generate_options)
+                .generate(prompt_texts, generate_options)?
                 .into_iter()
                 .map(|output| output.text)
                 .collect()
             }
-        }
+        })
     }
 }

@@ -1484,7 +1484,7 @@ impl TranslationModel {
             &self.supported_target_languages,
         )?;

-        Ok(match prefix {
+        match prefix {
             Some(value) => {
                 let texts = texts
                     .iter()

@@ -1493,7 +1493,7 @@ impl TranslationModel {
                 self.model.generate(Some(&texts), forced_bos_token_id)
             }
             None => self.model.generate(Some(texts), forced_bos_token_id),
-        })
+        }
     }
 }
@@ -128,7 +128,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
 about exoplanets like K2-18b."];

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-    let output = model.summarize(&input);
+    let output = model.summarize(&input)?;

     assert_eq!(output.len(), 1);
     assert_eq!(output[0], " K2-18b is not too hot and not too cold for liquid water to exist. \

@@ -189,7 +189,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
 about exoplanets like K2-18b."];

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-    let output = model.summarize(&input);
+    let output = model.summarize(&input)?;

     assert_eq!(output.len(), 1);
     assert_eq!(output[0], " K2-18b, a planet circling a star in the constellation Leo, is not too hot and not too cold for liquid water to exist. \
@@ -121,7 +121,7 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
     let model = TextGenerationModel::new(generate_config)?;

     let input_context = "The cat";
-    let output = model.generate(&[input_context], None);
+    let output = model.generate(&[input_context], None)?;

     assert_eq!(output.len(), 1);
     assert_eq!(output[0], "The cat was found in a field near the town of Keflavik, about 30 miles (48 kilometers) south-east of Moscow.\n\n\n");

@@ -154,7 +154,7 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
     let model = TextGenerationModel::new(generate_config)?;

     let input_context = "The dog";
-    let output = model.generate(&[input_context], None);
+    let output = model.generate(&[input_context], None)?;

     assert_eq!(output.len(), 3);
     assert_eq!(

@@ -199,7 +199,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res

     let input_context_1 = "The dog";
     let input_context_2 = "The cat";
-    let output = model.generate(&[input_context_1, input_context_2], None);
+    let output = model.generate(&[input_context_1, input_context_2], None)?;

     assert_eq!(output.len(), 6);
     assert_eq!(

@@ -255,7 +255,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result

     let input_context_1 = "The dog";
     let input_context_2 = "The cat was";
-    let output = model.generate(&[input_context_1, input_context_2], None);
+    let output = model.generate(&[input_context_1, input_context_2], None)?;

     assert_eq!(output.len(), 6);
     assert_eq!(

@@ -313,7 +313,7 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()

     let input_context_1 = "It was a nice and";
     let input_context_2 = "Language models can generate";
-    let output = model.generate(&[input_context_1, input_context_2], None);
+    let output = model.generate(&[input_context_1, input_context_2], None)?;

     assert_eq!(output.len(), 6);
     assert_eq!(

@@ -392,7 +392,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
     let output = model.generate(
         Some(&[input_context_1, input_context_2]),
         Some(generate_options),
-    );
+    )?;

     assert_eq!(output.len(), 2);
     assert_eq!(

@@ -455,8 +455,9 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
         ..Default::default()
     };

-    let baseline_output = model.generate(Some(&[input_context_1]), Some(baseline_generate_options));
-    let output = model.generate(Some(&[input_context_1]), Some(test_generate_options));
+    let baseline_output =
+        model.generate(Some(&[input_context_1]), Some(baseline_generate_options))?;
+    let output = model.generate(Some(&[input_context_1]), Some(test_generate_options))?;

     assert_eq!(baseline_output.len(), 1);
     assert_eq!(

@@ -521,8 +522,9 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
         ..Default::default()
     };

-    let baseline_output = model.generate(Some(&[input_context_1]), Some(baseline_generate_options));
-    let output = model.generate(Some(&[input_context_1]), Some(test_generate_options));
+    let baseline_output =
+        model.generate(Some(&[input_context_1]), Some(baseline_generate_options))?;
+    let output = model.generate(Some(&[input_context_1]), Some(test_generate_options))?;

     assert_eq!(baseline_output.len(), 1);
     assert_eq!(

@@ -589,7 +591,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
     let output = model.generate(
         Some(&[input_context_1, input_context_2]),
         Some(generate_options),
-    );
+    )?;

     assert_eq!(output.len(), 2);
     assert_eq!(

@@ -638,7 +640,7 @@ fn gpt2_greedy_token_scores() -> anyhow::Result<()> {
     let output = model.generate_indices(
         Some(&[input_context_1, input_context_2]),
         Some(generate_options),
-    );
+    )?;

     assert_eq!(output.len(), 2);
     assert_eq!(

@@ -694,7 +696,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> {
     let output = model.generate_indices(
         Some(&[input_context_1, input_context_2]),
         Some(generate_options),
-    );
+    )?;

     assert_eq!(output.len(), 2);
     assert_eq!(
@@ -735,7 +737,7 @@ fn dialogpt_single_multi_turn_conversation() -> anyhow::Result<()> {
         conversation_manager.create("Going to the movies tonight - any suggestions?");

     // Turn 1
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 1);
     assert_eq!(output.get(&conversation_id).unwrap(), &"The Big Lebowski");

@@ -744,12 +746,12 @@ fn dialogpt_single_multi_turn_conversation() -> anyhow::Result<()> {
         .get(&conversation_id)
         .unwrap()
         .add_user_input("Is it an action movie?");
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 1);
     assert_eq!(output.get(&conversation_id).unwrap(), &"It\'s a comedy.");

     // Turn 3 (no new user input)
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 0);

     Ok(())

@@ -773,7 +775,7 @@ fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> {
     let conversation_2_id = conversation_manager.create("What's the last book you have read?");

     // Turn 1
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 2);
     assert_eq!(output.get(&conversation_1_id).unwrap(), &"The Big Lebowski");
     assert_eq!(

@@ -786,12 +788,12 @@ fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> {
         .get(&conversation_1_id)
         .unwrap()
         .add_user_input("Is it an action movie?");
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 1);
     assert_eq!(output.get(&conversation_1_id).unwrap(), &"It\'s a comedy.");

     // Turn 3 (no new user input)
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 0);

     Ok(())

@@ -817,7 +819,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result
     let conversation_2_id = conversation_manager.create("Hello how are you?");

     // Turn 1
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 2);
     assert_eq!(output.get(&conversation_1_id).unwrap(), &"The Big Lebowski");
     assert_eq!(

@@ -835,12 +837,12 @@ fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result
         .unwrap()
         .add_user_input("Fine.");

-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 2);
     assert_eq!(output.get(&conversation_1_id).unwrap(), &"It\'s a comedy.");

     // Turn 3 (no new user input)
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 0);

     Ok(())

@@ -864,7 +866,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> any
     let conversation_2_id = conversation_manager.create("What's the last book you have read?");

     // Turn 1
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 2);
     assert_eq!(output.get(&conversation_1_id).unwrap(), &"The Big Lebowski");
     assert_eq!(

@@ -878,7 +880,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> any
         .get(&conversation_2_id)
         .unwrap()
         .add_user_input("Why do you recommend it?");
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 1);
     assert_eq!(
         output.get(&conversation_2_id).unwrap(),

@@ -886,7 +888,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> any
     );

     // Turn 3 (no new user input)
-    let output = conversation_model.generate_responses(&mut conversation_manager);
+    let output = conversation_model.generate_responses(&mut conversation_manager)?;
     assert_eq!(output.len(), 0);

     Ok(())
@@ -143,7 +143,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> {

     let input_context_1 = "It was a very nice and sunny";
     let input_context_2 = "It was a gloom winter night, and";
-    let output = model.generate(&[input_context_1, input_context_2], None);
+    let output = model.generate(&[input_context_1, input_context_2], None)?;

     assert_eq!(output.len(), 2);
     assert_eq!(output[0], "It was a very nice and sunny day. The sun was shining through the clouds, and the sky was clear. The wind was blowing through the trees,");
@@ -53,7 +53,7 @@ mod tests {
     let model = TextGenerationModel::new_with_tokenizer(generate_config, tokenizer)?;

     let input_context = "The dog";
-    let output = model.generate(&[input_context], None);
+    let output = model.generate(&[input_context], None)?;

     assert_eq!(output.len(), 3);
     assert_eq!(
@@ -49,7 +49,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
 telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
 about exoplanets like K2-18b."];

-    let output = model.summarize(&input);
+    let output = model.summarize(&input)?;

     assert_eq!(
         output[0],
@@ -222,7 +222,7 @@ mod tests {
         ..Default::default()
     })?;
     let prompts = ["It was a very nice and sunny"];
-    let output = text_generation_model.generate(&prompts, None);
+    let output = text_generation_model.generate(&prompts, None)?;
     assert_eq!(output.len(), 1);
     assert_eq!(output[0], "It was a very nice and sunny day. I was very happy with the weather. I was very happy with the weather. I was very happy with");
     Ok(())
@@ -134,7 +134,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
     let model = TextGenerationModel::new(generate_config)?;

     let input_context = "It was an intense machine dialogue. ";
-    let output = model.generate(&[input_context], None);
+    let output = model.generate(&[input_context], None)?;

     assert_eq!(output.len(), 1);
     assert_eq!(output[0], "it was an intense machine dialogue. \n \" i\'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");

@@ -176,7 +176,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
     let model = TextGenerationModel::new(generate_config)?;

     let input_context = "The dog is";
-    let output = model.generate(&[input_context], None);
+    let output = model.generate(&[input_context], None)?;

     assert_eq!(output.len(), 3);
     assert_eq!(

@@ -230,7 +230,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho

     let input_context_1 = "The dog is";
     let input_context_2 = "The cat";
-    let output = model.generate(&[input_context_1, input_context_2], None);
+    let output = model.generate(&[input_context_1, input_context_2], None)?;

     assert_eq!(output.len(), 6);

@@ -298,7 +298,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::

     let input_context_1 = "The dog is";
     let input_context_2 = "The cat was in";
-    let output = model.generate(&[input_context_1, input_context_2], None);
+    let output = model.generate(&[input_context_1, input_context_2], None)?;

     assert_eq!(output.len(), 6);
     // Left padding impacts the generated sentences output
@@ -54,7 +54,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
 about exoplanets like K2-18b."];

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-    let output = summarization_model.summarize(&input);
+    let output = summarization_model.summarize(&input)?;

     assert_eq!(output.len(), 1);
     assert_eq!(

@@ -57,7 +57,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
 about exoplanets like K2-18b."];

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
-    let output = summarization_model.summarize(&input);
+    let output = summarization_model.summarize(&input)?;

     assert_eq!(output.len(), 1);
     assert_eq!(
@@ -64,7 +64,7 @@ fn test_generation_reformer() -> anyhow::Result<()> {

     let input_context_1 = "The really great men must, I think,";
     let input_context_2 = "It was a gloom winter night, and";
-    let output = model.generate(&[input_context_1, input_context_2], None);
+    let output = model.generate(&[input_context_1, input_context_2], None)?;

     assert_eq!(output.len(), 2);
     assert_eq!(output[0], " The really great men must, I think, anyway waiting for some unknown reason, but Nikodim Fomitch and Ilya Petrovitch looked at him anguish invitable incidently at him. He could not resist an impression which might be setting");

@@ -102,7 +102,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
 telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
 about exoplanets like K2-18b."];

-    let output = model.summarize(&input);
+    let output = model.summarize(&input)?;

     assert_eq!(
         output[0],
@@ -222,7 +222,7 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
     let model = TextGenerationModel::new(generate_config)?;

     let input_context = "Once upon a time,";
-    let output = model.generate(&[input_context], None);
+    let output = model.generate(&[input_context], None)?;

     assert_eq!(output.len(), 1);
     assert_eq!(