Aligned ModelForTokenClassification and ModelForSequenceClassification APIs (#323)

guillaume-be 2023-01-15 11:10:38 +00:00 committed by GitHub
parent 445b76fe7b
commit f12e8ef475
24 changed files with 199 additions and 113 deletions

View File

@ -14,6 +14,7 @@ All notable changes to this project will be documented in this file. The format
- Allow mixing local and remote resources in pipelines.
- Upgraded to `torch` 1.13 (via `tch` 0.9.0).
- (BREAKING) Made the `max_length` argument for generation methods and pipelines optional.
- (BREAKING) Changed return type of `ModelForSequenceClassification` and `ModelForTokenClassification` to `Result<Self, RustBertError>` allowing error handling if no labels are provided in the configuration.
## Fixed
- Fixed configuration check for RoBERTa models for sentence classification.
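
For callers, the breaking change above means the classification constructors now return Result<_, RustBertError> instead of panicking when `id2label` is missing from the model configuration. A minimal caller-side sketch, mirroring the updated examples and tests in this commit (the file paths and the anyhow wrapper are placeholders, not part of the diff):

use rust_bert::bert::{BertConfig, BertForSequenceClassification};
use rust_bert::Config;
use tch::{nn, Device};

fn main() -> anyhow::Result<()> {
    // Placeholder paths; the repository examples resolve these from downloaded resources.
    let config_path = "path/to/config.json";
    let weights_path = "path/to/model.ot";
    let mut vs = nn::VarStore::new(Device::Cpu);
    let config = BertConfig::from_file(config_path);
    // `new` now returns RustBertError::InvalidConfigurationError when `id2label` is absent,
    // where it previously panicked via `expect`.
    let model = BertForSequenceClassification::new(vs.root(), &config)?;
    vs.load(weights_path)?;
    let _ = model;
    Ok(())
}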

View File

@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
false,
)?;
let config = DebertaConfig::from_file(config_path);
let model = DebertaForSequenceClassification::new(vs.root(), &config);
let model = DebertaForSequenceClassification::new(vs.root(), &config)?;
vs.load(weights_path)?;
// Define input

View File

@ -505,9 +505,12 @@ impl AlbertForSequenceClassification {
/// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForSequenceClassification =
/// AlbertForSequenceClassification::new(&p.root(), &config);
/// AlbertForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &AlbertConfig,
) -> Result<AlbertForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -519,7 +522,11 @@ impl AlbertForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -528,11 +535,11 @@ impl AlbertForSequenceClassification {
Default::default(),
);
AlbertForSequenceClassification {
Ok(AlbertForSequenceClassification {
albert,
dropout,
classifier,
}
})
}
/// Forward pass through the model

View File

@ -695,7 +695,7 @@ pub struct BartClassificationHead {
}
impl BartClassificationHead {
pub fn new<'p, P>(p: P, config: &BartConfig) -> BartClassificationHead
pub fn new<'p, P>(p: P, config: &BartConfig) -> Result<BartClassificationHead, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -703,7 +703,11 @@ impl BartClassificationHead {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let dense = nn::linear(
p / "dense",
@ -719,11 +723,11 @@ impl BartClassificationHead {
Default::default(),
);
BartClassificationHead {
Ok(BartClassificationHead {
dense,
dropout,
out_proj,
}
})
}
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
@ -768,22 +772,25 @@ impl BartForSequenceClassification {
/// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path);
/// let bart: BartForSequenceClassification =
/// BartForSequenceClassification::new(&p.root() / "bart", &config);
/// BartForSequenceClassification::new(&p.root() / "bart", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &BartConfig) -> BartForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &BartConfig,
) -> Result<BartForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let base_model = BartModel::new(p / "model", config);
let classification_head = BartClassificationHead::new(p / "classification_head", config);
let classification_head = BartClassificationHead::new(p / "classification_head", config)?;
let eos_token_id = config.eos_token_id.unwrap_or(3);
BartForSequenceClassification {
Ok(BartForSequenceClassification {
base_model,
classification_head,
eos_token_id,
}
})
}
/// Forward pass through the model

View File

@ -682,9 +682,12 @@ impl BertForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config);
/// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &BertConfig,
) -> Result<BertForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -695,7 +698,11 @@ impl BertForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -704,11 +711,11 @@ impl BertForSequenceClassification {
Default::default(),
);
BertForSequenceClassification {
Ok(BertForSequenceClassification {
bert,
dropout,
classifier,
}
})
}
/// Forward pass through the model

View File

@ -732,9 +732,12 @@ impl DebertaForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DebertaConfig::from_file(config_path);
/// let model = DebertaForSequenceClassification::new(&p.root(), &config);
/// let model = DebertaForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &DebertaConfig) -> DebertaForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &DebertaConfig,
) -> Result<DebertaForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -751,7 +754,11 @@ impl DebertaForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
@ -761,12 +768,12 @@ impl DebertaForSequenceClassification {
Default::default(),
);
DebertaForSequenceClassification {
Ok(DebertaForSequenceClassification {
deberta,
pooler,
classifier,
dropout,
}
})
}
/// Forward pass through the model

View File

@ -594,9 +594,12 @@ impl DebertaV2ForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DebertaV2Config::from_file(config_path);
/// let model = DebertaV2ForSequenceClassification::new(&p.root(), &config);
/// let model = DebertaV2ForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2ForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &DebertaV2Config,
) -> Result<DebertaV2ForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -613,7 +616,11 @@ impl DebertaV2ForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
@ -623,12 +630,12 @@ impl DebertaV2ForSequenceClassification {
Default::default(),
);
DebertaV2ForSequenceClassification {
Ok(DebertaV2ForSequenceClassification {
deberta,
pooler,
classifier,
dropout,
}
})
}
/// Forward pass through the model

View File

@ -287,9 +287,12 @@ impl DistilBertModelClassifier {
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModelClassifier =
/// DistilBertModelClassifier::new(&p.root() / "distilbert", &config);
/// DistilBertModelClassifier::new(&p.root() / "distilbert", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelClassifier
pub fn new<'p, P>(
p: P,
config: &DistilBertConfig,
) -> Result<DistilBertModelClassifier, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -300,7 +303,11 @@ impl DistilBertModelClassifier {
let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let pre_classifier = nn::linear(
@ -312,12 +319,12 @@ impl DistilBertModelClassifier {
let classifier = nn::linear(p / "classifier", config.dim, num_labels, Default::default());
let dropout = Dropout::new(config.seq_classif_dropout);
DistilBertModelClassifier {
Ok(DistilBertModelClassifier {
distil_bert_model,
pre_classifier,
classifier,
dropout,
}
})
}
/// Forward pass through the model
@ -680,7 +687,8 @@ impl DistilBertForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForTokenClassification::new(&p.root() / "distilbert", &config).unwrap();
/// let distil_bert =
/// DistilBertForTokenClassification::new(&p.root() / "distilbert", &config).unwrap();
/// ```
pub fn new<'p, P>(
p: P,

View File

@ -499,9 +499,12 @@ impl FNetForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = FNetConfig::from_file(config_path);
/// let fnet = FNetForSequenceClassification::new(&p.root() / "fnet", &config);
/// let fnet = FNetForSequenceClassification::new(&p.root() / "fnet", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &FNetConfig,
) -> Result<FNetForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -512,7 +515,11 @@ impl FNetForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -521,11 +528,11 @@ impl FNetForSequenceClassification {
Default::default(),
);
FNetForSequenceClassification {
Ok(FNetForSequenceClassification {
fnet,
dropout,
classifier,
}
})
}
/// Forward pass through the model

View File

@ -791,7 +791,10 @@ pub struct LongformerClassificationHead {
}
impl LongformerClassificationHead {
pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerClassificationHead
pub fn new<'p, P>(
p: P,
config: &LongformerConfig,
) -> Result<LongformerClassificationHead, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -808,7 +811,11 @@ impl LongformerClassificationHead {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let out_proj = nn::linear(
p / "out_proj",
@ -817,11 +824,11 @@ impl LongformerClassificationHead {
Default::default(),
);
LongformerClassificationHead {
Ok(LongformerClassificationHead {
dense,
dropout,
out_proj,
}
})
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
@ -865,21 +872,24 @@ impl LongformerForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = LongformerConfig::from_file(config_path);
/// let longformer_model = LongformerForSequenceClassification::new(&p.root(), &config);
/// let longformer_model = LongformerForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &LongformerConfig,
) -> Result<LongformerForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let longformer = LongformerModel::new(p / "longformer", config, false);
let classifier = LongformerClassificationHead::new(p / "classifier", config);
let classifier = LongformerClassificationHead::new(p / "classifier", config)?;
LongformerForSequenceClassification {
Ok(LongformerForSequenceClassification {
longformer,
classifier,
}
})
}
/// Forward pass through the model

View File

@ -178,7 +178,7 @@ pub struct MBartClassificationHead {
}
impl MBartClassificationHead {
pub fn new<'p, P>(p: P, config: &MBartConfig) -> MBartClassificationHead
pub fn new<'p, P>(p: P, config: &MBartConfig) -> Result<MBartClassificationHead, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -194,9 +194,12 @@ impl MBartClassificationHead {
let num_labels = config
.id2label
.as_ref()
.expect("id2label not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let out_proj = nn::linear(
p / "out_proj",
config.d_model,
@ -206,11 +209,11 @@ impl MBartClassificationHead {
let dropout = Dropout::new(config.classifier_dropout.unwrap_or(0.0));
MBartClassificationHead {
Ok(MBartClassificationHead {
dense,
dropout,
out_proj,
}
})
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
@ -592,22 +595,25 @@ impl MBartForSequenceClassification {
/// let p = nn::VarStore::new(device);
/// let config = MBartConfig::from_file(config_path);
/// let mbart: MBartForSequenceClassification =
/// MBartForSequenceClassification::new(&p.root(), &config);
/// MBartForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &MBartConfig) -> MBartForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &MBartConfig,
) -> Result<MBartForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let base_model = MBartModel::new(p / "model", config);
let classification_head = MBartClassificationHead::new(p / "classification_head", config);
let classification_head = MBartClassificationHead::new(p / "classification_head", config)?;
let eos_token_id = config.eos_token_id.unwrap_or(3);
MBartForSequenceClassification {
Ok(MBartForSequenceClassification {
base_model,
classification_head,
eos_token_id,
}
})
}
/// Forward pass through the model

View File

@ -690,9 +690,12 @@ impl MobileBertForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = MobileBertConfig::from_file(config_path);
/// let mobilebert = MobileBertForSequenceClassification::new(&p.root(), &config);
/// let mobilebert = MobileBertForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &MobileBertConfig,
) -> Result<MobileBertForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -703,7 +706,11 @@ impl MobileBertForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -711,11 +718,11 @@ impl MobileBertForSequenceClassification {
num_labels,
Default::default(),
);
MobileBertForSequenceClassification {
Ok(MobileBertForSequenceClassification {
mobilebert,
dropout,
classifier,
}
})
}
/// Forward pass through the model

View File

@ -232,7 +232,7 @@ impl SequenceClassificationOption {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
Ok(SequenceClassificationOption::Bert(
BertForSequenceClassification::new(p, config),
BertForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -243,7 +243,7 @@ impl SequenceClassificationOption {
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
Ok(SequenceClassificationOption::Deberta(
DebertaForSequenceClassification::new(p, config),
DebertaForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -254,7 +254,7 @@ impl SequenceClassificationOption {
ModelType::DebertaV2 => {
if let ConfigOption::DebertaV2(config) = config {
Ok(SequenceClassificationOption::DebertaV2(
DebertaV2ForSequenceClassification::new(p, config),
DebertaV2ForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -265,7 +265,7 @@ impl SequenceClassificationOption {
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
Ok(SequenceClassificationOption::DistilBert(
DistilBertModelClassifier::new(p, config),
DistilBertModelClassifier::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -276,7 +276,7 @@ impl SequenceClassificationOption {
ModelType::MobileBert => {
if let ConfigOption::MobileBert(config) = config {
Ok(SequenceClassificationOption::MobileBert(
MobileBertForSequenceClassification::new(p, config),
MobileBertForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -287,7 +287,7 @@ impl SequenceClassificationOption {
ModelType::Roberta => {
if let ConfigOption::Roberta(config) = config {
Ok(SequenceClassificationOption::Roberta(
RobertaForSequenceClassification::new(p, config),
RobertaForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -298,7 +298,7 @@ impl SequenceClassificationOption {
ModelType::XLMRoberta => {
if let ConfigOption::Roberta(config) = config {
Ok(SequenceClassificationOption::XLMRoberta(
RobertaForSequenceClassification::new(p, config),
RobertaForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -309,7 +309,7 @@ impl SequenceClassificationOption {
ModelType::Albert => {
if let ConfigOption::Albert(config) = config {
Ok(SequenceClassificationOption::Albert(
AlbertForSequenceClassification::new(p, config),
AlbertForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -320,7 +320,7 @@ impl SequenceClassificationOption {
ModelType::XLNet => {
if let ConfigOption::XLNet(config) = config {
Ok(SequenceClassificationOption::XLNet(
XLNetForSequenceClassification::new(p, config).unwrap(),
XLNetForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -331,7 +331,7 @@ impl SequenceClassificationOption {
ModelType::Bart => {
if let ConfigOption::Bart(config) = config {
Ok(SequenceClassificationOption::Bart(
BartForSequenceClassification::new(p, config),
BartForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -353,7 +353,7 @@ impl SequenceClassificationOption {
ModelType::Longformer => {
if let ConfigOption::Longformer(config) = config {
Ok(SequenceClassificationOption::Longformer(
LongformerForSequenceClassification::new(p, config),
LongformerForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -364,7 +364,7 @@ impl SequenceClassificationOption {
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(SequenceClassificationOption::FNet(
FNetForSequenceClassification::new(p, config),
FNetForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(

View File

@ -262,7 +262,7 @@ impl ZeroShotClassificationOption {
ModelType::Bart => {
if let ConfigOption::Bart(config) = config {
Ok(ZeroShotClassificationOption::Bart(
BartForSequenceClassification::new(p, config),
BartForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -273,7 +273,7 @@ impl ZeroShotClassificationOption {
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
Ok(ZeroShotClassificationOption::Deberta(
DebertaForSequenceClassification::new(p, config),
DebertaForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -284,7 +284,7 @@ impl ZeroShotClassificationOption {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
Ok(ZeroShotClassificationOption::Bert(
BertForSequenceClassification::new(p, config),
BertForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -295,7 +295,7 @@ impl ZeroShotClassificationOption {
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
Ok(ZeroShotClassificationOption::DistilBert(
DistilBertModelClassifier::new(p, config),
DistilBertModelClassifier::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -306,7 +306,7 @@ impl ZeroShotClassificationOption {
ModelType::MobileBert => {
if let ConfigOption::MobileBert(config) = config {
Ok(ZeroShotClassificationOption::MobileBert(
MobileBertForSequenceClassification::new(p, config),
MobileBertForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -317,7 +317,7 @@ impl ZeroShotClassificationOption {
ModelType::Roberta => {
if let ConfigOption::Bert(config) = config {
Ok(ZeroShotClassificationOption::Roberta(
RobertaForSequenceClassification::new(p, config),
RobertaForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -328,7 +328,7 @@ impl ZeroShotClassificationOption {
ModelType::XLMRoberta => {
if let ConfigOption::Bert(config) = config {
Ok(ZeroShotClassificationOption::XLMRoberta(
RobertaForSequenceClassification::new(p, config),
RobertaForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -339,7 +339,7 @@ impl ZeroShotClassificationOption {
ModelType::Albert => {
if let ConfigOption::Albert(config) = config {
Ok(ZeroShotClassificationOption::Albert(
AlbertForSequenceClassification::new(p, config),
AlbertForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -350,7 +350,7 @@ impl ZeroShotClassificationOption {
ModelType::XLNet => {
if let ConfigOption::XLNet(config) = config {
Ok(ZeroShotClassificationOption::XLNet(
XLNetForSequenceClassification::new(p, config).unwrap(),
XLNetForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -361,7 +361,7 @@ impl ZeroShotClassificationOption {
ModelType::Longformer => {
if let ConfigOption::Longformer(config) = config {
Ok(ZeroShotClassificationOption::Longformer(
LongformerForSequenceClassification::new(p, config),
LongformerForSequenceClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(

View File

@ -709,14 +709,15 @@ impl ReformerClassificationHead {
config.hidden_size,
Default::default(),
);
let num_labels = match &config.id2label {
Some(value) => value.len() as i64,
None => {
return Err(RustBertError::InvalidConfigurationError(
"an id to label mapping must be provided for classification tasks".to_string(),
));
}
};
let num_labels = config
.id2label
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let out_proj = nn::linear(
p / "out_proj",
config.hidden_size,

View File

@ -381,7 +381,7 @@ pub struct RobertaClassificationHead {
}
impl RobertaClassificationHead {
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaClassificationHead
pub fn new<'p, P>(p: P, config: &BertConfig) -> Result<RobertaClassificationHead, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -395,7 +395,11 @@ impl RobertaClassificationHead {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let out_proj = nn::linear(
p / "out_proj",
@ -405,11 +409,11 @@ impl RobertaClassificationHead {
);
let dropout = Dropout::new(config.hidden_dropout_prob);
RobertaClassificationHead {
Ok(RobertaClassificationHead {
dense,
dropout,
out_proj,
}
})
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
@ -453,21 +457,24 @@ impl RobertaForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = RobertaConfig::from_file(config_path);
/// let roberta = RobertaForSequenceClassification::new(&p.root() / "roberta", &config);
/// let roberta = RobertaForSequenceClassification::new(&p.root() / "roberta", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &BertConfig,
) -> Result<RobertaForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let roberta =
BertModel::<RobertaEmbeddings>::new_with_optional_pooler(p / "roberta", config, false);
let classifier = RobertaClassificationHead::new(p / "classifier", config);
let classifier = RobertaClassificationHead::new(p / "classifier", config)?;
RobertaForSequenceClassification {
Ok(RobertaForSequenceClassification {
roberta,
classifier,
}
})
}
/// Forward pass through the model

View File

@ -919,7 +919,7 @@ impl XLNetForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = XLNetConfig::from_file(config_path);
/// let xlnet_model = XLNetForSequenceClassification::new(&p.root(), &config);
/// let xlnet_model = XLNetForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(
p: P,
@ -936,7 +936,11 @@ impl XLNetForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let logits_proj = nn::linear(

View File

@ -109,7 +109,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let albert_model = AlbertForSequenceClassification::new(vs.root(), &config);
let albert_model = AlbertForSequenceClassification::new(vs.root(), &config)?;
// Define input
let input = [

View File

@ -162,7 +162,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let bert_model = BertForSequenceClassification::new(vs.root(), &config);
let bert_model = BertForSequenceClassification::new(vs.root(), &config)?;
// Define input
let input = [

View File

@ -41,7 +41,7 @@ fn deberta_natural_language_inference() -> anyhow::Result<()> {
false,
)?;
let config = DebertaConfig::from_file(config_path);
let model = DebertaForSequenceClassification::new(vs.root(), &config);
let model = DebertaForSequenceClassification::new(vs.root(), &config)?;
vs.load(weights_path)?;
// Define input

View File

@ -88,7 +88,7 @@ fn deberta_v2_for_sequence_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(1, String::from("Neutral"));
dummy_label_mapping.insert(2, String::from("Negative"));
config.id2label = Some(dummy_label_mapping);
let model = DebertaV2ForSequenceClassification::new(vs.root(), &config);
let model = DebertaV2ForSequenceClassification::new(vs.root(), &config)?;
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];

View File

@ -197,7 +197,7 @@ fn longformer_for_sequence_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(1, String::from("Negative"));
dummy_label_mapping.insert(3, String::from("Neutral"));
config.id2label = Some(dummy_label_mapping);
let model = LongformerForSequenceClassification::new(vs.root(), &config);
let model = LongformerForSequenceClassification::new(vs.root(), &config)?;
// Define input
let input = ["Very positive sentence", "Second sentence input"];

View File

@ -130,7 +130,7 @@ fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(1, String::from("Negative"));
dummy_label_mapping.insert(3, String::from("Neutral"));
config.id2label = Some(dummy_label_mapping);
let model = MobileBertForSequenceClassification::new(vs.root(), &config);
let model = MobileBertForSequenceClassification::new(vs.root(), &config)?;
// Define input
let input = ["Very positive sentence", "Second sentence input"];

View File

@ -136,7 +136,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let roberta_model = RobertaForSequenceClassification::new(vs.root(), &config);
let roberta_model = RobertaForSequenceClassification::new(vs.root(), &config)?;
// Define input
let input = [