Make predict methods in ZeroShot pipeline return Result instead of panicking on unwrap (#301)

* Add checked prediction methods

- Add checked prediction methods to ZeroShotClassificationModel.
These methods return Option and convert any underlying errors into None,
to allow callers to implement appropriate error handling logic.

* Update ZeroShot example to use checked method.

* Add tests for ZeroShot checked methods

* Change checked prediction methods to return Result

* refactor: rename *_checked into try_*

Rename the *_checked methods to try_* methods.
This is more idiomatic with respect to the Rust standard library.

* refactor: remove try_ prefix from predict methods

* refactor: change return from Option to Result

Change the return type of ZeroShotClassificationModel.prepare_for_model
from Option to Result. This simplifies the code and returns the error
closer to its origin. A caller-side sketch of the resulting API is shown
after this change list.

This addresses comments from @guillaume-be.

* refactor: address clippy lints in tests
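
For illustration, a minimal caller-side sketch of the resulting API (not part of this commit; it mirrors the updated example below and assumes the default model resources can be downloaded):

    use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;

    fn main() -> anyhow::Result<()> {
        // Loading the model is already fallible and propagates errors with `?`.
        let model = ZeroShotClassificationModel::new(Default::default())?;

        let inputs = ["The prime minister has announced a stimulus package which was widely criticized by the opposition."];
        let candidate_labels = &["politics", "public health", "economy", "sports"];

        // `predict` now returns `Result<Vec<Label>, RustBertError>`, so callers
        // can propagate the error with `?` (or match on it) instead of hitting
        // a panic from an internal `unwrap()`.
        let output = model.predict(
            inputs,
            candidate_labels,
            Some(Box::new(|label: &str| {
                format!("This example is about {}.", label)
            })),
            128,
        )?;
        println!("{:?}", output);
        Ok(())
    }

An empty input slice now surfaces as Err(RustBertError::ValueError(_)) rather than a panic, which is what the new tests assert.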

Co-authored-by: guillaume-be <guillaume.becquin@gmail.com>
Anna Melnikov 2022-12-04 03:10:01 -06:00 committed by GitHub
parent a0ef06bccf
commit a34cf9f8e4
3 changed files with 92 additions and 32 deletions


@@ -22,14 +22,16 @@ fn main() -> anyhow::Result<()> {
let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
let candidate_labels = &["politics", "public health", "economy", "sports"];
let output = sequence_classification_model.predict_multilabel(
[input_sentence, input_sequence_2],
candidate_labels,
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)
})),
128,
);
let output = sequence_classification_model
.predict_multilabel(
[input_sentence, input_sequence_2],
candidate_labels,
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)
})),
128,
)
.unwrap();
println!("{:?}", output);


@@ -590,7 +590,7 @@ impl ZeroShotClassificationModel {
labels: T,
template: Option<ZeroShotTemplate>,
max_len: usize,
) -> (Tensor, Tensor)
) -> Result<(Tensor, Tensor), RustBertError>
where
S: AsRef<[&'a str]>,
T: AsRef<[&'a str]>,
@@ -628,7 +628,8 @@ impl ZeroShotClassificationModel {
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
.ok_or_else(|| RustBertError::ValueError("Got empty iterator as input".to_string()))?;
let pad_id = self
.tokenizer
.get_pad_id()
@@ -651,7 +652,7 @@ impl ZeroShotClassificationModel {
.expect("The Tokenizer used for zero shot classification should contain a PAD id"))
.to_kind(Bool);
(tokenized_input_tensors, mask)
Ok((tokenized_input_tensors, mask))
}
/// Zero shot classification with 1 (and exactly 1) true label.
@@ -665,7 +666,7 @@ impl ZeroShotClassificationModel {
///
/// # Returns
///
/// * `Vec<Label>` containing with the most likely label for each input sentence.
/// * `Result<Vec<Label>, RustBertError>` containing the most likely label for each input sentence or error, if any.
///
/// # Example
///
@@ -679,7 +680,7 @@ impl ZeroShotClassificationModel {
/// let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
/// let candidate_labels = &["politics", "public health", "economics", "sports"];
///
/// let output = sequence_classification_model.predict(
/// let output = sequence_classification_model.try_predict(
/// &[input_sentence, input_sequence_2],
/// candidate_labels,
/// None,
@@ -692,7 +693,7 @@ impl ZeroShotClassificationModel {
/// outputs:
/// ```no_run
/// # use rust_bert::pipelines::sequence_classification::Label;
/// let output = [
/// let output = Ok([
/// Label {
/// text: "politics".to_string(),
/// score: 0.959,
@@ -706,7 +707,7 @@ impl ZeroShotClassificationModel {
/// sentence: 1,
/// },
/// ]
/// .to_vec();
/// .to_vec());
/// ```
pub fn predict<'a, S, T>(
&self,
@@ -714,14 +715,15 @@ impl ZeroShotClassificationModel {
labels: T,
template: Option<ZeroShotTemplate>,
max_length: usize,
) -> Vec<Label>
) -> Result<Vec<Label>, RustBertError>
where
S: AsRef<[&'a str]>,
T: AsRef<[&'a str]>,
{
let num_inputs = inputs.as_ref().len();
let (input_tensor, mask) =
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length);
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
let output = no_grad(|| {
let output = self.zero_shot_classifier.forward_t(
Some(&input_tensor),
@@ -739,8 +741,8 @@ impl ZeroShotClassificationModel {
let scores = scores
.gather(1, &label_indices.unsqueeze(-1), false)
.squeeze_dim(1);
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
let label_indices = label_indices.iter::<i64>()?.collect::<Vec<i64>>();
let scores = scores.iter::<f64>()?.collect::<Vec<f64>>();
let mut output_labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.len() {
@@ -753,7 +755,7 @@ impl ZeroShotClassificationModel {
};
output_labels.push(label)
}
output_labels
Ok(output_labels)
}
/// Zero shot multi-label classification with 0, 1 or no true label.
@@ -767,7 +769,7 @@ impl ZeroShotClassificationModel {
///
/// # Returns
///
/// * `Vec<Vec<Label>>` containing a vector of labels and their probability for each input text
/// * `Result<Vec<Vec<Label>>, RustBertError>` containing a vector of labels and their probability for each input text, or error, if any.
///
/// # Example
///
@@ -781,7 +783,7 @@ impl ZeroShotClassificationModel {
/// let input_sequence_2 = "The central bank is meeting today to discuss monetary policy.";
/// let candidate_labels = &["politics", "public health", "economics", "sports"];
///
/// let output = sequence_classification_model.predict_multilabel(
/// let output = sequence_classification_model.try_predict_multilabel(
/// &[input_sentence, input_sequence_2],
/// candidate_labels,
/// None,
@@ -793,7 +795,7 @@ impl ZeroShotClassificationModel {
/// outputs:
/// ```no_run
/// # use rust_bert::pipelines::sequence_classification::Label;
/// let output = [
/// let output = Ok([
/// [
/// Label {
/// text: "politics".to_string(),
@@ -847,7 +849,7 @@ impl ZeroShotClassificationModel {
/// },
/// ],
/// ]
/// .to_vec();
/// .to_vec());
/// ```
pub fn predict_multilabel<'a, S, T>(
&self,
@@ -855,14 +857,15 @@ impl ZeroShotClassificationModel {
labels: T,
template: Option<ZeroShotTemplate>,
max_length: usize,
) -> Vec<Vec<Label>>
) -> Result<Vec<Vec<Label>>, RustBertError>
where
S: AsRef<[&'a str]>,
T: AsRef<[&'a str]>,
{
let num_inputs = inputs.as_ref().len();
let (input_tensor, mask) =
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length);
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
let output = no_grad(|| {
let output = self.zero_shot_classifier.forward_t(
Some(&input_tensor),
@@ -882,8 +885,7 @@ impl ZeroShotClassificationModel {
for (label_index, score) in scores
.select(0, sentence_idx as i64)
.iter::<f64>()
.unwrap()
.iter::<f64>()?
.enumerate()
{
let label_string = labels.as_ref()[label_index].to_string();
@ -897,7 +899,7 @@ impl ZeroShotClassificationModel {
}
output_labels.push(sentence_labels);
}
output_labels
Ok(output_labels)
}
}
#[cfg(test)]


@@ -7,7 +7,7 @@ use rust_bert::pipelines::zero_shot_classification::{
ZeroShotClassificationConfig, ZeroShotClassificationModel,
};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_bert::{Config, RustBertError};
use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@@ -218,7 +218,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
format!("This example is about {}.", label)
})),
128,
);
)?;
assert_eq!(output.len(), 2);
@@ -230,6 +230,34 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
Ok(())
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_zero_shot_classification_try_error() -> anyhow::Result<()> {
// Set-up model
let zero_shot_config = ZeroShotClassificationConfig {
device: Device::Cpu,
..Default::default()
};
let sequence_classification_model = ZeroShotClassificationModel::new(zero_shot_config)?;
let output = sequence_classification_model.predict(
[],
[],
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)
})),
128,
);
let output_is_error = match output {
Err(RustBertError::ValueError(_)) => true,
_ => unreachable!(),
};
assert!(output_is_error);
Ok(())
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
@@ -251,7 +279,7 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
format!("This example is about {}.", label)
})),
128,
);
)?;
assert_eq!(output.len(), 2);
assert_eq!(output[0].len(), candidate_labels.len());
@@ -276,3 +304,31 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
assert!((output[1][3].score - 0.0004).abs() < 1e-4);
Ok(())
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_zero_shot_classification_multilabel_try_error() -> anyhow::Result<()> {
// Set-up model
let zero_shot_config = ZeroShotClassificationConfig {
device: Device::Cpu,
..Default::default()
};
let sequence_classification_model = ZeroShotClassificationModel::new(zero_shot_config)?;
let output = sequence_classification_model.predict_multilabel(
[],
[],
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)
})),
128,
);
let output_is_error = match output {
Err(RustBertError::ValueError(_)) => true,
_ => unreachable!(),
};
assert!(output_is_error);
Ok(())
}