mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Make predict methods in ZeroShot pipeline return Result instead of panicking on unwrap (#301)
* Add checked prediction methods - Add checked prediction methods to ZeroShotClassificationModel. These methods return Option and convert any underlying errors into None, to allow callers to implement appropriate error handling logic. * Update ZeroShot example to use checked method. * Add tests for ZeroShot checked methods * Change checked prediction methods to return Result * refactor: rename *_checked into try_* Rename *_checked methods into try_* methods. This is more idiomatic vis-a-vis the Rust standard library. * refactor: remove try_ prefix from predict methods * refactor: change return from Option to Result Change return type of ZeroShotClassificationModel.prepare_for_model from option into Result. This simplifies the code, and returns the error closer to its origin. This addresses comments from @guillaume-be. * refactor: address clippy lints in tests Co-authored-by: guillaume-be <guillaume.becquin@gmail.com>
This commit is contained in:
parent
a0ef06bccf
commit
a34cf9f8e4
@ -22,14 +22,16 @@ fn main() -> anyhow::Result<()> {
|
||||
let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
|
||||
let candidate_labels = &["politics", "public health", "economy", "sports"];
|
||||
|
||||
let output = sequence_classification_model.predict_multilabel(
|
||||
[input_sentence, input_sequence_2],
|
||||
candidate_labels,
|
||||
Some(Box::new(|label: &str| {
|
||||
format!("This example is about {}.", label)
|
||||
})),
|
||||
128,
|
||||
);
|
||||
let output = sequence_classification_model
|
||||
.predict_multilabel(
|
||||
[input_sentence, input_sequence_2],
|
||||
candidate_labels,
|
||||
Some(Box::new(|label: &str| {
|
||||
format!("This example is about {}.", label)
|
||||
})),
|
||||
128,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
println!("{:?}", output);
|
||||
|
||||
|
@ -590,7 +590,7 @@ impl ZeroShotClassificationModel {
|
||||
labels: T,
|
||||
template: Option<ZeroShotTemplate>,
|
||||
max_len: usize,
|
||||
) -> (Tensor, Tensor)
|
||||
) -> Result<(Tensor, Tensor), RustBertError>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
T: AsRef<[&'a str]>,
|
||||
@ -628,7 +628,8 @@ impl ZeroShotClassificationModel {
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
.ok_or_else(|| RustBertError::ValueError("Got empty iterator as input".to_string()))?;
|
||||
|
||||
let pad_id = self
|
||||
.tokenizer
|
||||
.get_pad_id()
|
||||
@ -651,7 +652,7 @@ impl ZeroShotClassificationModel {
|
||||
.expect("The Tokenizer used for zero shot classification should contain a PAD id"))
|
||||
.to_kind(Bool);
|
||||
|
||||
(tokenized_input_tensors, mask)
|
||||
Ok((tokenized_input_tensors, mask))
|
||||
}
|
||||
|
||||
/// Zero shot classification with 1 (and exactly 1) true label.
|
||||
@ -665,7 +666,7 @@ impl ZeroShotClassificationModel {
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Vec<Label>` containing with the most likely label for each input sentence.
|
||||
/// * `Result<Vec<Label>, RustBertError>` containing the most likely label for each input sentence or error, if any.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -679,7 +680,7 @@ impl ZeroShotClassificationModel {
|
||||
/// let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
|
||||
/// let candidate_labels = &["politics", "public health", "economics", "sports"];
|
||||
///
|
||||
/// let output = sequence_classification_model.predict(
|
||||
/// let output = sequence_classification_model.predict(
|
||||
/// &[input_sentence, input_sequence_2],
|
||||
/// candidate_labels,
|
||||
/// None,
|
||||
@ -692,7 +693,7 @@ impl ZeroShotClassificationModel {
|
||||
/// outputs:
|
||||
/// ```no_run
|
||||
/// # use rust_bert::pipelines::sequence_classification::Label;
|
||||
/// let output = [
|
||||
/// let output = Ok([
|
||||
/// Label {
|
||||
/// text: "politics".to_string(),
|
||||
/// score: 0.959,
|
||||
@ -706,7 +707,7 @@ impl ZeroShotClassificationModel {
|
||||
/// sentence: 1,
|
||||
/// },
|
||||
/// ]
|
||||
/// .to_vec();
|
||||
/// .to_vec());
|
||||
/// ```
|
||||
pub fn predict<'a, S, T>(
|
||||
&self,
|
||||
@ -714,14 +715,15 @@ impl ZeroShotClassificationModel {
|
||||
labels: T,
|
||||
template: Option<ZeroShotTemplate>,
|
||||
max_length: usize,
|
||||
) -> Vec<Label>
|
||||
) -> Result<Vec<Label>, RustBertError>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
T: AsRef<[&'a str]>,
|
||||
{
|
||||
let num_inputs = inputs.as_ref().len();
|
||||
let (input_tensor, mask) =
|
||||
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length);
|
||||
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
|
||||
|
||||
let output = no_grad(|| {
|
||||
let output = self.zero_shot_classifier.forward_t(
|
||||
Some(&input_tensor),
|
||||
@ -739,8 +741,8 @@ impl ZeroShotClassificationModel {
|
||||
let scores = scores
|
||||
.gather(1, &label_indices.unsqueeze(-1), false)
|
||||
.squeeze_dim(1);
|
||||
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
||||
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
||||
let label_indices = label_indices.iter::<i64>()?.collect::<Vec<i64>>();
|
||||
let scores = scores.iter::<f64>()?.collect::<Vec<f64>>();
|
||||
|
||||
let mut output_labels: Vec<Label> = vec![];
|
||||
for sentence_idx in 0..label_indices.len() {
|
||||
@ -753,7 +755,7 @@ impl ZeroShotClassificationModel {
|
||||
};
|
||||
output_labels.push(label)
|
||||
}
|
||||
output_labels
|
||||
Ok(output_labels)
|
||||
}
|
||||
|
||||
/// Zero shot multi-label classification with 0, 1 or more true labels.
|
||||
@ -767,7 +769,7 @@ impl ZeroShotClassificationModel {
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Vec<Vec<Label>>` containing a vector of labels and their probability for each input text
|
||||
/// * `Result<Vec<Vec<Label>>, RustBertError>` containing a vector of labels and their probability for each input text, or error, if any.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -781,7 +783,7 @@ impl ZeroShotClassificationModel {
|
||||
/// let input_sequence_2 = "The central bank is meeting today to discuss monetary policy.";
|
||||
/// let candidate_labels = &["politics", "public health", "economics", "sports"];
|
||||
///
|
||||
/// let output = sequence_classification_model.predict_multilabel(
|
||||
/// let output = sequence_classification_model.predict_multilabel(
|
||||
/// &[input_sentence, input_sequence_2],
|
||||
/// candidate_labels,
|
||||
/// None,
|
||||
@ -793,7 +795,7 @@ impl ZeroShotClassificationModel {
|
||||
/// outputs:
|
||||
/// ```no_run
|
||||
/// # use rust_bert::pipelines::sequence_classification::Label;
|
||||
/// let output = [
|
||||
/// let output = Ok([
|
||||
/// [
|
||||
/// Label {
|
||||
/// text: "politics".to_string(),
|
||||
@ -847,7 +849,7 @@ impl ZeroShotClassificationModel {
|
||||
/// },
|
||||
/// ],
|
||||
/// ]
|
||||
/// .to_vec();
|
||||
/// .to_vec());
|
||||
/// ```
|
||||
pub fn predict_multilabel<'a, S, T>(
|
||||
&self,
|
||||
@ -855,14 +857,15 @@ impl ZeroShotClassificationModel {
|
||||
labels: T,
|
||||
template: Option<ZeroShotTemplate>,
|
||||
max_length: usize,
|
||||
) -> Vec<Vec<Label>>
|
||||
) -> Result<Vec<Vec<Label>>, RustBertError>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
T: AsRef<[&'a str]>,
|
||||
{
|
||||
let num_inputs = inputs.as_ref().len();
|
||||
let (input_tensor, mask) =
|
||||
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length);
|
||||
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
|
||||
|
||||
let output = no_grad(|| {
|
||||
let output = self.zero_shot_classifier.forward_t(
|
||||
Some(&input_tensor),
|
||||
@ -882,8 +885,7 @@ impl ZeroShotClassificationModel {
|
||||
|
||||
for (label_index, score) in scores
|
||||
.select(0, sentence_idx as i64)
|
||||
.iter::<f64>()
|
||||
.unwrap()
|
||||
.iter::<f64>()?
|
||||
.enumerate()
|
||||
{
|
||||
let label_string = labels.as_ref()[label_index].to_string();
|
||||
@ -897,7 +899,7 @@ impl ZeroShotClassificationModel {
|
||||
}
|
||||
output_labels.push(sentence_labels);
|
||||
}
|
||||
output_labels
|
||||
Ok(output_labels)
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
|
@ -7,7 +7,7 @@ use rust_bert::pipelines::zero_shot_classification::{
|
||||
ZeroShotClassificationConfig, ZeroShotClassificationModel,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::{Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
@ -218,7 +218,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
|
||||
format!("This example is about {}.", label)
|
||||
})),
|
||||
128,
|
||||
);
|
||||
)?;
|
||||
|
||||
assert_eq!(output.len(), 2);
|
||||
|
||||
@ -230,6 +230,34 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn bart_zero_shot_classification_try_error() -> anyhow::Result<()> {
|
||||
// Set-up model
|
||||
let zero_shot_config = ZeroShotClassificationConfig {
|
||||
device: Device::Cpu,
|
||||
..Default::default()
|
||||
};
|
||||
let sequence_classification_model = ZeroShotClassificationModel::new(zero_shot_config)?;
|
||||
|
||||
let output = sequence_classification_model.predict(
|
||||
[],
|
||||
[],
|
||||
Some(Box::new(|label: &str| {
|
||||
format!("This example is about {}.", label)
|
||||
})),
|
||||
128,
|
||||
);
|
||||
|
||||
let output_is_error = match output {
|
||||
Err(RustBertError::ValueError(_)) => true,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
assert!(output_is_error);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
|
||||
@ -251,7 +279,7 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
|
||||
format!("This example is about {}.", label)
|
||||
})),
|
||||
128,
|
||||
);
|
||||
)?;
|
||||
|
||||
assert_eq!(output.len(), 2);
|
||||
assert_eq!(output[0].len(), candidate_labels.len());
|
||||
@ -276,3 +304,31 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
|
||||
assert!((output[1][3].score - 0.0004).abs() < 1e-4);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn bart_zero_shot_classification_multilabel_try_error() -> anyhow::Result<()> {
|
||||
// Set-up model
|
||||
let zero_shot_config = ZeroShotClassificationConfig {
|
||||
device: Device::Cpu,
|
||||
..Default::default()
|
||||
};
|
||||
let sequence_classification_model = ZeroShotClassificationModel::new(zero_shot_config)?;
|
||||
|
||||
let output = sequence_classification_model.predict_multilabel(
|
||||
[],
|
||||
[],
|
||||
Some(Box::new(|label: &str| {
|
||||
format!("This example is about {}.", label)
|
||||
})),
|
||||
128,
|
||||
);
|
||||
|
||||
let output_is_error = match output {
|
||||
Err(RustBertError::ValueError(_)) => true,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
assert!(output_is_error);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user