diff --git a/examples/zero_shot_classification.rs b/examples/zero_shot_classification.rs index ec08178..3391026 100644 --- a/examples/zero_shot_classification.rs +++ b/examples/zero_shot_classification.rs @@ -22,14 +22,16 @@ fn main() -> anyhow::Result<()> { let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; let candidate_labels = &["politics", "public health", "economy", "sports"]; - let output = sequence_classification_model.predict_multilabel( - [input_sentence, input_sequence_2], - candidate_labels, - Some(Box::new(|label: &str| { - format!("This example is about {}.", label) - })), - 128, - ); + let output = sequence_classification_model + .predict_multilabel( + [input_sentence, input_sequence_2], + candidate_labels, + Some(Box::new(|label: &str| { + format!("This example is about {}.", label) + })), + 128, + ) + .unwrap(); println!("{:?}", output); diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index b0a62be..998cece 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -590,7 +590,7 @@ impl ZeroShotClassificationModel { labels: T, template: Option, max_len: usize, - ) -> (Tensor, Tensor) + ) -> Result<(Tensor, Tensor), RustBertError> where S: AsRef<[&'a str]>, T: AsRef<[&'a str]>, @@ -628,7 +628,8 @@ impl ZeroShotClassificationModel { .iter() .map(|input| input.token_ids.len()) .max() - .unwrap(); + .ok_or_else(|| RustBertError::ValueError("Got empty iterator as input".to_string()))?; + let pad_id = self .tokenizer .get_pad_id() @@ -651,7 +652,7 @@ impl ZeroShotClassificationModel { .expect("The Tokenizer used for zero shot classification should contain a PAD id")) .to_kind(Bool); - (tokenized_input_tensors, mask) + Ok((tokenized_input_tensors, mask)) } /// Zero shot classification with 1 (and exactly 1) true label. @@ -665,7 +666,7 @@ impl ZeroShotClassificationModel { /// /// # Returns /// - /// * `Vec