From f12e8ef475e86b7d537bfd1b00275cda1477e308 Mon Sep 17 00:00:00 2001 From: guillaume-be Date: Sun, 15 Jan 2023 11:10:38 +0000 Subject: [PATCH] Aligned ModelForTokenClassification and ModelForSequenceClassification APIs (#323) --- CHANGELOG.md | 1 + .../natural_language_inference_deberta.rs | 2 +- src/albert/albert_model.rs | 17 +++++++---- src/bart/bart_model.rs | 25 +++++++++++------ src/bert/bert_model.rs | 17 +++++++---- src/deberta/deberta_model.rs | 17 +++++++---- src/deberta_v2/deberta_v2_model.rs | 17 +++++++---- src/distilbert/distilbert_model.rs | 20 +++++++++---- src/fnet/fnet_model.rs | 17 +++++++---- src/longformer/longformer_model.rs | 28 +++++++++++++------ src/mbart/mbart_model.rs | 26 ++++++++++------- src/mobilebert/mobilebert_model.rs | 17 +++++++---- src/pipelines/sequence_classification.rs | 24 ++++++++-------- src/pipelines/zero_shot_classification.rs | 20 ++++++------- src/reformer/reformer_model.rs | 17 +++++------ src/roberta/roberta_model.rs | 25 +++++++++++------ src/xlnet/xlnet_model.rs | 8 ++++-- tests/albert.rs | 2 +- tests/bert.rs | 2 +- tests/deberta.rs | 2 +- tests/deberta_v2.rs | 2 +- tests/longformer.rs | 2 +- tests/mobilebert.rs | 2 +- tests/roberta.rs | 2 +- 24 files changed, 199 insertions(+), 113 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f7d2e2..84eccab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ All notable changes to this project will be documented in this file. The format - Allow mixing local and remote resources in pipelines. - Upgraded to `torch` 1.13 (via `tch` 0.9.0). - (BREAKING) Made the `max_length` argument for generation methods and pipelines optional. +- (BREAKING) Changed return type of `ModelForSequenceClassification` and `ModelForTokenClassification` to `Result` allowing error handling if no labels are provided in the configuration. ## Fixed - Fixed configuration check for RoBERTa models for sentence classification. diff --git a/examples/natural_language_inference_deberta.rs b/examples/natural_language_inference_deberta.rs index e779a07..32a7639 100644 --- a/examples/natural_language_inference_deberta.rs +++ b/examples/natural_language_inference_deberta.rs @@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> { false, )?; let config = DebertaConfig::from_file(config_path); - let model = DebertaForSequenceClassification::new(vs.root(), &config); + let model = DebertaForSequenceClassification::new(vs.root(), &config)?; vs.load(weights_path)?; // Define input diff --git a/src/albert/albert_model.rs b/src/albert/albert_model.rs index 720273c..c87644f 100644 --- a/src/albert/albert_model.rs +++ b/src/albert/albert_model.rs @@ -505,9 +505,12 @@ impl AlbertForSequenceClassification { /// let p = nn::VarStore::new(device); /// let config = AlbertConfig::from_file(config_path); /// let albert: AlbertForSequenceClassification = - /// AlbertForSequenceClassification::new(&p.root(), &config); + /// AlbertForSequenceClassification::new(&p.root(), &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertForSequenceClassification + pub fn new<'p, P>( + p: P, + config: &AlbertConfig, + ) -> Result where P: Borrow>, { @@ -519,7 +522,11 @@ impl AlbertForSequenceClassification { let num_labels = config .id2label .as_ref() - .expect("num_labels not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let classifier = nn::linear( p / "classifier", @@ -528,11 +535,11 @@ impl AlbertForSequenceClassification { Default::default(), ); - AlbertForSequenceClassification { + Ok(AlbertForSequenceClassification { albert, dropout, classifier, - } + }) } /// Forward pass through the model diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs index 2b9eb32..3cf30dd 100644 --- a/src/bart/bart_model.rs +++ b/src/bart/bart_model.rs @@ -695,7 +695,7 @@ pub struct BartClassificationHead { } impl BartClassificationHead { - pub fn new<'p, P>(p: P, config: &BartConfig) -> BartClassificationHead + pub fn new<'p, P>(p: P, config: &BartConfig) -> Result where P: Borrow>, { @@ -703,7 +703,11 @@ impl BartClassificationHead { let num_labels = config .id2label .as_ref() - .expect("num_labels not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let dense = nn::linear( p / "dense", @@ -719,11 +723,11 @@ impl BartClassificationHead { Default::default(), ); - BartClassificationHead { + Ok(BartClassificationHead { dense, dropout, out_proj, - } + }) } pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor { @@ -768,22 +772,25 @@ impl BartForSequenceClassification { /// let p = nn::VarStore::new(device); /// let config = BartConfig::from_file(config_path); /// let bart: BartForSequenceClassification = - /// BartForSequenceClassification::new(&p.root() / "bart", &config); + /// BartForSequenceClassification::new(&p.root() / "bart", &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &BartConfig) -> BartForSequenceClassification + pub fn new<'p, P>( + p: P, + config: &BartConfig, + ) -> Result where P: Borrow>, { let p = p.borrow(); let base_model = BartModel::new(p / "model", config); - let classification_head = BartClassificationHead::new(p / "classification_head", config); + let classification_head = BartClassificationHead::new(p / "classification_head", config)?; let eos_token_id = config.eos_token_id.unwrap_or(3); - BartForSequenceClassification { + Ok(BartForSequenceClassification { base_model, classification_head, eos_token_id, - } + }) } /// Forward pass through the model diff --git a/src/bert/bert_model.rs b/src/bert/bert_model.rs index b8d9956..7dde511 100644 --- a/src/bert/bert_model.rs +++ b/src/bert/bert_model.rs @@ -682,9 +682,12 @@ impl BertForSequenceClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config); + /// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForSequenceClassification + pub fn new<'p, P>( + p: P, + config: &BertConfig, + ) -> Result where P: Borrow>, { @@ -695,7 +698,11 @@ impl BertForSequenceClassification { let num_labels = config .id2label .as_ref() - .expect("num_labels not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let classifier = nn::linear( p / "classifier", @@ -704,11 +711,11 @@ impl BertForSequenceClassification { Default::default(), ); - BertForSequenceClassification { + Ok(BertForSequenceClassification { bert, dropout, classifier, - } + }) } /// Forward pass through the model diff --git a/src/deberta/deberta_model.rs b/src/deberta/deberta_model.rs index b39e357..6901b40 100644 --- a/src/deberta/deberta_model.rs +++ b/src/deberta/deberta_model.rs @@ -732,9 +732,12 @@ impl DebertaForSequenceClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = DebertaConfig::from_file(config_path); - /// let model = DebertaForSequenceClassification::new(&p.root(), &config); + /// let model = DebertaForSequenceClassification::new(&p.root(), &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &DebertaConfig) -> DebertaForSequenceClassification + pub fn new<'p, P>( + p: P, + config: &DebertaConfig, + ) -> Result where P: Borrow>, { @@ -751,7 +754,11 @@ impl DebertaForSequenceClassification { let num_labels = config .id2label .as_ref() - .expect("num_labels not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let classifier = nn::linear( @@ -761,12 +768,12 @@ impl DebertaForSequenceClassification { Default::default(), ); - DebertaForSequenceClassification { + Ok(DebertaForSequenceClassification { deberta, pooler, classifier, dropout, - } + }) } /// Forward pass through the model diff --git a/src/deberta_v2/deberta_v2_model.rs b/src/deberta_v2/deberta_v2_model.rs index 6826d85..2164e0a 100644 --- a/src/deberta_v2/deberta_v2_model.rs +++ b/src/deberta_v2/deberta_v2_model.rs @@ -594,9 +594,12 @@ impl DebertaV2ForSequenceClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = DebertaV2Config::from_file(config_path); - /// let model = DebertaV2ForSequenceClassification::new(&p.root(), &config); + /// let model = DebertaV2ForSequenceClassification::new(&p.root(), &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2ForSequenceClassification + pub fn new<'p, P>( + p: P, + config: &DebertaV2Config, + ) -> Result where P: Borrow>, { @@ -613,7 +616,11 @@ impl DebertaV2ForSequenceClassification { let num_labels = config .id2label .as_ref() - .expect("num_labels not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let classifier = nn::linear( @@ -623,12 +630,12 @@ impl DebertaV2ForSequenceClassification { Default::default(), ); - DebertaV2ForSequenceClassification { + Ok(DebertaV2ForSequenceClassification { deberta, pooler, classifier, dropout, - } + }) } /// Forward pass through the model diff --git a/src/distilbert/distilbert_model.rs b/src/distilbert/distilbert_model.rs index 5859c5f..714d7fd 100644 --- a/src/distilbert/distilbert_model.rs +++ b/src/distilbert/distilbert_model.rs @@ -287,9 +287,12 @@ impl DistilBertModelClassifier { /// let p = nn::VarStore::new(device); /// let config = DistilBertConfig::from_file(config_path); /// let distil_bert: DistilBertModelClassifier = - /// DistilBertModelClassifier::new(&p.root() / "distilbert", &config); + /// DistilBertModelClassifier::new(&p.root() / "distilbert", &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelClassifier + pub fn new<'p, P>( + p: P, + config: &DistilBertConfig, + ) -> Result where P: Borrow>, { @@ -300,7 +303,11 @@ impl DistilBertModelClassifier { let num_labels = config .id2label .as_ref() - .expect("id2label must be provided for classifiers") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let pre_classifier = nn::linear( @@ -312,12 +319,12 @@ impl DistilBertModelClassifier { let classifier = nn::linear(p / "classifier", config.dim, num_labels, Default::default()); let dropout = Dropout::new(config.seq_classif_dropout); - DistilBertModelClassifier { + Ok(DistilBertModelClassifier { distil_bert_model, pre_classifier, classifier, dropout, - } + }) } /// Forward pass through the model @@ -680,7 +687,8 @@ impl DistilBertForTokenClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = DistilBertConfig::from_file(config_path); - /// let distil_bert = DistilBertForTokenClassification::new(&p.root() / "distilbert", &config).unwrap(); + /// let distil_bert = + /// DistilBertForTokenClassification::new(&p.root() / "distilbert", &config).unwrap(); /// ``` pub fn new<'p, P>( p: P, diff --git a/src/fnet/fnet_model.rs b/src/fnet/fnet_model.rs index eb3a02c..9a2d9dd 100644 --- a/src/fnet/fnet_model.rs +++ b/src/fnet/fnet_model.rs @@ -499,9 +499,12 @@ impl FNetForSequenceClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = FNetConfig::from_file(config_path); - /// let fnet = FNetForSequenceClassification::new(&p.root() / "fnet", &config); + /// let fnet = FNetForSequenceClassification::new(&p.root() / "fnet", &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForSequenceClassification + pub fn new<'p, P>( + p: P, + config: &FNetConfig, + ) -> Result where P: Borrow>, { @@ -512,7 +515,11 @@ impl FNetForSequenceClassification { let num_labels = config .id2label .as_ref() - .expect("num_labels not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let classifier = nn::linear( p / "classifier", @@ -521,11 +528,11 @@ impl FNetForSequenceClassification { Default::default(), ); - FNetForSequenceClassification { + Ok(FNetForSequenceClassification { fnet, dropout, classifier, - } + }) } /// Forward pass through the model diff --git a/src/longformer/longformer_model.rs b/src/longformer/longformer_model.rs index f2ffeef..b9a0346 100644 --- a/src/longformer/longformer_model.rs +++ b/src/longformer/longformer_model.rs @@ -791,7 +791,10 @@ pub struct LongformerClassificationHead { } impl LongformerClassificationHead { - pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerClassificationHead + pub fn new<'p, P>( + p: P, + config: &LongformerConfig, + ) -> Result where P: Borrow>, { @@ -808,7 +811,11 @@ impl LongformerClassificationHead { let num_labels = config .id2label .as_ref() - .expect("num_labels not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let out_proj = nn::linear( p / "out_proj", @@ -817,11 +824,11 @@ impl LongformerClassificationHead { Default::default(), ); - LongformerClassificationHead { + Ok(LongformerClassificationHead { dense, dropout, out_proj, - } + }) } pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor { @@ -865,21 +872,24 @@ impl LongformerForSequenceClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = LongformerConfig::from_file(config_path); - /// let longformer_model = LongformerForSequenceClassification::new(&p.root(), &config); + /// let longformer_model = LongformerForSequenceClassification::new(&p.root(), &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerForSequenceClassification + pub fn new<'p, P>( + p: P, + config: &LongformerConfig, + ) -> Result where P: Borrow>, { let p = p.borrow(); let longformer = LongformerModel::new(p / "longformer", config, false); - let classifier = LongformerClassificationHead::new(p / "classifier", config); + let classifier = LongformerClassificationHead::new(p / "classifier", config)?; - LongformerForSequenceClassification { + Ok(LongformerForSequenceClassification { longformer, classifier, - } + }) } /// Forward pass through the model diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs index f09c706..a3057e4 100644 --- a/src/mbart/mbart_model.rs +++ b/src/mbart/mbart_model.rs @@ -178,7 +178,7 @@ pub struct MBartClassificationHead { } impl MBartClassificationHead { - pub fn new<'p, P>(p: P, config: &MBartConfig) -> MBartClassificationHead + pub fn new<'p, P>(p: P, config: &MBartConfig) -> Result where P: Borrow>, { @@ -194,9 +194,12 @@ impl MBartClassificationHead { let num_labels = config .id2label .as_ref() - .expect("id2label not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; - let out_proj = nn::linear( p / "out_proj", config.d_model, @@ -206,11 +209,11 @@ impl MBartClassificationHead { let dropout = Dropout::new(config.classifier_dropout.unwrap_or(0.0)); - MBartClassificationHead { + Ok(MBartClassificationHead { dense, dropout, out_proj, - } + }) } pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor { @@ -592,22 +595,25 @@ impl MBartForSequenceClassification { /// let p = nn::VarStore::new(device); /// let config = MBartConfig::from_file(config_path); /// let mbart: MBartForSequenceClassification = - /// MBartForSequenceClassification::new(&p.root(), &config); + /// MBartForSequenceClassification::new(&p.root(), &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &MBartConfig) -> MBartForSequenceClassification + pub fn new<'p, P>( + p: P, + config: &MBartConfig, + ) -> Result where P: Borrow>, { let p = p.borrow(); let base_model = MBartModel::new(p / "model", config); - let classification_head = MBartClassificationHead::new(p / "classification_head", config); + let classification_head = MBartClassificationHead::new(p / "classification_head", config)?; let eos_token_id = config.eos_token_id.unwrap_or(3); - MBartForSequenceClassification { + Ok(MBartForSequenceClassification { base_model, classification_head, eos_token_id, - } + }) } /// Forward pass through the model diff --git a/src/mobilebert/mobilebert_model.rs b/src/mobilebert/mobilebert_model.rs index 6753360..1ec2677 100644 --- a/src/mobilebert/mobilebert_model.rs +++ b/src/mobilebert/mobilebert_model.rs @@ -690,9 +690,12 @@ impl MobileBertForSequenceClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = MobileBertConfig::from_file(config_path); - /// let mobilebert = MobileBertForSequenceClassification::new(&p.root(), &config); + /// let mobilebert = MobileBertForSequenceClassification::new(&p.root(), &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForSequenceClassification + pub fn new<'p, P>( + p: P, + config: &MobileBertConfig, + ) -> Result where P: Borrow>, { @@ -703,7 +706,11 @@ impl MobileBertForSequenceClassification { let num_labels = config .id2label .as_ref() - .expect("num_labels not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let classifier = nn::linear( p / "classifier", @@ -711,11 +718,11 @@ impl MobileBertForSequenceClassification { num_labels, Default::default(), ); - MobileBertForSequenceClassification { + Ok(MobileBertForSequenceClassification { mobilebert, dropout, classifier, - } + }) } /// Forward pass through the model diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs index db241f3..667c236 100644 --- a/src/pipelines/sequence_classification.rs +++ b/src/pipelines/sequence_classification.rs @@ -232,7 +232,7 @@ impl SequenceClassificationOption { ModelType::Bert => { if let ConfigOption::Bert(config) = config { Ok(SequenceClassificationOption::Bert( - BertForSequenceClassification::new(p, config), + BertForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -243,7 +243,7 @@ impl SequenceClassificationOption { ModelType::Deberta => { if let ConfigOption::Deberta(config) = config { Ok(SequenceClassificationOption::Deberta( - DebertaForSequenceClassification::new(p, config), + DebertaForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -254,7 +254,7 @@ impl SequenceClassificationOption { ModelType::DebertaV2 => { if let ConfigOption::DebertaV2(config) = config { Ok(SequenceClassificationOption::DebertaV2( - DebertaV2ForSequenceClassification::new(p, config), + DebertaV2ForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -265,7 +265,7 @@ impl SequenceClassificationOption { ModelType::DistilBert => { if let ConfigOption::DistilBert(config) = config { Ok(SequenceClassificationOption::DistilBert( - DistilBertModelClassifier::new(p, config), + DistilBertModelClassifier::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -276,7 +276,7 @@ impl SequenceClassificationOption { ModelType::MobileBert => { if let ConfigOption::MobileBert(config) = config { Ok(SequenceClassificationOption::MobileBert( - MobileBertForSequenceClassification::new(p, config), + MobileBertForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -287,7 +287,7 @@ impl SequenceClassificationOption { ModelType::Roberta => { if let ConfigOption::Roberta(config) = config { Ok(SequenceClassificationOption::Roberta( - RobertaForSequenceClassification::new(p, config), + RobertaForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -298,7 +298,7 @@ impl SequenceClassificationOption { ModelType::XLMRoberta => { if let ConfigOption::Roberta(config) = config { Ok(SequenceClassificationOption::XLMRoberta( - RobertaForSequenceClassification::new(p, config), + RobertaForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -309,7 +309,7 @@ impl SequenceClassificationOption { ModelType::Albert => { if let ConfigOption::Albert(config) = config { Ok(SequenceClassificationOption::Albert( - AlbertForSequenceClassification::new(p, config), + AlbertForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -320,7 +320,7 @@ impl SequenceClassificationOption { ModelType::XLNet => { if let ConfigOption::XLNet(config) = config { Ok(SequenceClassificationOption::XLNet( - XLNetForSequenceClassification::new(p, config).unwrap(), + XLNetForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -331,7 +331,7 @@ impl SequenceClassificationOption { ModelType::Bart => { if let ConfigOption::Bart(config) = config { Ok(SequenceClassificationOption::Bart( - BartForSequenceClassification::new(p, config), + BartForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -353,7 +353,7 @@ impl SequenceClassificationOption { ModelType::Longformer => { if let ConfigOption::Longformer(config) = config { Ok(SequenceClassificationOption::Longformer( - LongformerForSequenceClassification::new(p, config), + LongformerForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -364,7 +364,7 @@ impl SequenceClassificationOption { ModelType::FNet => { if let ConfigOption::FNet(config) = config { Ok(SequenceClassificationOption::FNet( - FNetForSequenceClassification::new(p, config), + FNetForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index 0a69f40..e8d2870 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -262,7 +262,7 @@ impl ZeroShotClassificationOption { ModelType::Bart => { if let ConfigOption::Bart(config) = config { Ok(ZeroShotClassificationOption::Bart( - BartForSequenceClassification::new(p, config), + BartForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -273,7 +273,7 @@ impl ZeroShotClassificationOption { ModelType::Deberta => { if let ConfigOption::Deberta(config) = config { Ok(ZeroShotClassificationOption::Deberta( - DebertaForSequenceClassification::new(p, config), + DebertaForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -284,7 +284,7 @@ impl ZeroShotClassificationOption { ModelType::Bert => { if let ConfigOption::Bert(config) = config { Ok(ZeroShotClassificationOption::Bert( - BertForSequenceClassification::new(p, config), + BertForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -295,7 +295,7 @@ impl ZeroShotClassificationOption { ModelType::DistilBert => { if let ConfigOption::DistilBert(config) = config { Ok(ZeroShotClassificationOption::DistilBert( - DistilBertModelClassifier::new(p, config), + DistilBertModelClassifier::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -306,7 +306,7 @@ impl ZeroShotClassificationOption { ModelType::MobileBert => { if let ConfigOption::MobileBert(config) = config { Ok(ZeroShotClassificationOption::MobileBert( - MobileBertForSequenceClassification::new(p, config), + MobileBertForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -317,7 +317,7 @@ impl ZeroShotClassificationOption { ModelType::Roberta => { if let ConfigOption::Bert(config) = config { Ok(ZeroShotClassificationOption::Roberta( - RobertaForSequenceClassification::new(p, config), + RobertaForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -328,7 +328,7 @@ impl ZeroShotClassificationOption { ModelType::XLMRoberta => { if let ConfigOption::Bert(config) = config { Ok(ZeroShotClassificationOption::XLMRoberta( - RobertaForSequenceClassification::new(p, config), + RobertaForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -339,7 +339,7 @@ impl ZeroShotClassificationOption { ModelType::Albert => { if let ConfigOption::Albert(config) = config { Ok(ZeroShotClassificationOption::Albert( - AlbertForSequenceClassification::new(p, config), + AlbertForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -350,7 +350,7 @@ impl ZeroShotClassificationOption { ModelType::XLNet => { if let ConfigOption::XLNet(config) = config { Ok(ZeroShotClassificationOption::XLNet( - XLNetForSequenceClassification::new(p, config).unwrap(), + XLNetForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( @@ -361,7 +361,7 @@ impl ZeroShotClassificationOption { ModelType::Longformer => { if let ConfigOption::Longformer(config) = config { Ok(ZeroShotClassificationOption::Longformer( - LongformerForSequenceClassification::new(p, config), + LongformerForSequenceClassification::new(p, config)?, )) } else { Err(RustBertError::InvalidConfigurationError( diff --git a/src/reformer/reformer_model.rs b/src/reformer/reformer_model.rs index c367254..8f3ae37 100644 --- a/src/reformer/reformer_model.rs +++ b/src/reformer/reformer_model.rs @@ -709,14 +709,15 @@ impl ReformerClassificationHead { config.hidden_size, Default::default(), ); - let num_labels = match &config.id2label { - Some(value) => value.len() as i64, - None => { - return Err(RustBertError::InvalidConfigurationError( - "an id to label mapping must be provided for classification tasks".to_string(), - )); - } - }; + let num_labels = config + .id2label + .as_ref() + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? + .len() as i64; let out_proj = nn::linear( p / "out_proj", config.hidden_size, diff --git a/src/roberta/roberta_model.rs b/src/roberta/roberta_model.rs index d2f1108..8bce245 100644 --- a/src/roberta/roberta_model.rs +++ b/src/roberta/roberta_model.rs @@ -381,7 +381,7 @@ pub struct RobertaClassificationHead { } impl RobertaClassificationHead { - pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaClassificationHead + pub fn new<'p, P>(p: P, config: &BertConfig) -> Result where P: Borrow>, { @@ -395,7 +395,11 @@ impl RobertaClassificationHead { let num_labels = config .id2label .as_ref() - .expect("num_labels not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let out_proj = nn::linear( p / "out_proj", @@ -405,11 +409,11 @@ impl RobertaClassificationHead { ); let dropout = Dropout::new(config.hidden_dropout_prob); - RobertaClassificationHead { + Ok(RobertaClassificationHead { dense, dropout, out_proj, - } + }) } pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor { @@ -453,21 +457,24 @@ impl RobertaForSequenceClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = RobertaConfig::from_file(config_path); - /// let roberta = RobertaForSequenceClassification::new(&p.root() / "roberta", &config); + /// let roberta = RobertaForSequenceClassification::new(&p.root() / "roberta", &config).unwrap(); /// ``` - pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForSequenceClassification + pub fn new<'p, P>( + p: P, + config: &BertConfig, + ) -> Result where P: Borrow>, { let p = p.borrow(); let roberta = BertModel::::new_with_optional_pooler(p / "roberta", config, false); - let classifier = RobertaClassificationHead::new(p / "classifier", config); + let classifier = RobertaClassificationHead::new(p / "classifier", config)?; - RobertaForSequenceClassification { + Ok(RobertaForSequenceClassification { roberta, classifier, - } + }) } /// Forward pass through the model diff --git a/src/xlnet/xlnet_model.rs b/src/xlnet/xlnet_model.rs index 0b142c1..c62a5d2 100644 --- a/src/xlnet/xlnet_model.rs +++ b/src/xlnet/xlnet_model.rs @@ -919,7 +919,7 @@ impl XLNetForSequenceClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = XLNetConfig::from_file(config_path); - /// let xlnet_model = XLNetForSequenceClassification::new(&p.root(), &config); + /// let xlnet_model = XLNetForSequenceClassification::new(&p.root(), &config).unwrap(); /// ``` pub fn new<'p, P>( p: P, @@ -936,7 +936,11 @@ impl XLNetForSequenceClassification { let num_labels = config .id2label .as_ref() - .expect("num_labels not provided in configuration") + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "num_labels not provided in configuration".to_string(), + ) + })? .len() as i64; let logits_proj = nn::linear( diff --git a/tests/albert.rs b/tests/albert.rs index 978bec0..3465a70 100644 --- a/tests/albert.rs +++ b/tests/albert.rs @@ -109,7 +109,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> { config.id2label = Some(dummy_label_mapping); config.output_attentions = Some(true); config.output_hidden_states = Some(true); - let albert_model = AlbertForSequenceClassification::new(vs.root(), &config); + let albert_model = AlbertForSequenceClassification::new(vs.root(), &config)?; // Define input let input = [ diff --git a/tests/bert.rs b/tests/bert.rs index 9217e6f..832e84d 100644 --- a/tests/bert.rs +++ b/tests/bert.rs @@ -162,7 +162,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> { config.id2label = Some(dummy_label_mapping); config.output_attentions = Some(true); config.output_hidden_states = Some(true); - let bert_model = BertForSequenceClassification::new(vs.root(), &config); + let bert_model = BertForSequenceClassification::new(vs.root(), &config)?; // Define input let input = [ diff --git a/tests/deberta.rs b/tests/deberta.rs index 80841d7..1d4643e 100644 --- a/tests/deberta.rs +++ b/tests/deberta.rs @@ -41,7 +41,7 @@ fn deberta_natural_language_inference() -> anyhow::Result<()> { false, )?; let config = DebertaConfig::from_file(config_path); - let model = DebertaForSequenceClassification::new(vs.root(), &config); + let model = DebertaForSequenceClassification::new(vs.root(), &config)?; vs.load(weights_path)?; // Define input diff --git a/tests/deberta_v2.rs b/tests/deberta_v2.rs index c3a2ade..c8ebd02 100644 --- a/tests/deberta_v2.rs +++ b/tests/deberta_v2.rs @@ -88,7 +88,7 @@ fn deberta_v2_for_sequence_classification() -> anyhow::Result<()> { dummy_label_mapping.insert(1, String::from("Neutral")); dummy_label_mapping.insert(2, String::from("Negative")); config.id2label = Some(dummy_label_mapping); - let model = DebertaV2ForSequenceClassification::new(vs.root(), &config); + let model = DebertaV2ForSequenceClassification::new(vs.root(), &config)?; // Define input let inputs = ["Where's Paris?", "In Kentucky, United States"]; diff --git a/tests/longformer.rs b/tests/longformer.rs index 9b156d9..2db9bc7 100644 --- a/tests/longformer.rs +++ b/tests/longformer.rs @@ -197,7 +197,7 @@ fn longformer_for_sequence_classification() -> anyhow::Result<()> { dummy_label_mapping.insert(1, String::from("Negative")); dummy_label_mapping.insert(3, String::from("Neutral")); config.id2label = Some(dummy_label_mapping); - let model = LongformerForSequenceClassification::new(vs.root(), &config); + let model = LongformerForSequenceClassification::new(vs.root(), &config)?; // Define input let input = ["Very positive sentence", "Second sentence input"]; diff --git a/tests/mobilebert.rs b/tests/mobilebert.rs index 67f04d5..1976ffd 100644 --- a/tests/mobilebert.rs +++ b/tests/mobilebert.rs @@ -130,7 +130,7 @@ fn mobilebert_for_sequence_classification() -> anyhow::Result<()> { dummy_label_mapping.insert(1, String::from("Negative")); dummy_label_mapping.insert(3, String::from("Neutral")); config.id2label = Some(dummy_label_mapping); - let model = MobileBertForSequenceClassification::new(vs.root(), &config); + let model = MobileBertForSequenceClassification::new(vs.root(), &config)?; // Define input let input = ["Very positive sentence", "Second sentence input"]; diff --git a/tests/roberta.rs b/tests/roberta.rs index b29d708..e8d87df 100644 --- a/tests/roberta.rs +++ b/tests/roberta.rs @@ -136,7 +136,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> { config.id2label = Some(dummy_label_mapping); config.output_attentions = Some(true); config.output_hidden_states = Some(true); - let roberta_model = RobertaForSequenceClassification::new(vs.root(), &config); + let roberta_model = RobertaForSequenceClassification::new(vs.root(), &config)?; // Define input let input = [