Permits TokenClassificationOption / DistilBertForTokenClassification to fail gracefully for an invalid configuration. (#320)

* Properly handle config errors when creating classification models

Instead of panicking, we now return a proper RustBertError so that
an invalid model or configuration does not crash the whole system (see the sketch after this list).

* Properly formats newly added code.

* Fixes an example within the documentation

* Properly unwraps the newly created results in unit tests.

* Fixes some code formatting issues.

* Uses proper/idiomatic error handling in unit tests.

* Moves the "?" to the correct position.
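For illustration, here is a minimal, self-contained sketch (not rust-bert code) of the pattern this change applies to each `*ForTokenClassification` constructor: the panicking `expect` on the optional `id2label` map is replaced by `ok_or_else`, so a missing label mapping becomes an `InvalidConfigurationError` that callers propagate with `?`. The `ToyConfig` and `ToyClassifier` types below are hypothetical stand-ins, not part of the crate.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for rust_bert's RustBertError::InvalidConfigurationError variant.
#[derive(Debug)]
enum RustBertError {
    InvalidConfigurationError(String),
}

/// Hypothetical stand-in for a model configuration with an optional label map.
struct ToyConfig {
    id2label: Option<HashMap<i64, String>>,
}

/// Hypothetical stand-in for a token classification head.
struct ToyClassifier {
    num_labels: i64,
}

impl ToyClassifier {
    // Before: `.expect("num_labels not provided in configuration")`, which panics.
    // After: `.ok_or_else(...)?`, which returns an error value instead.
    fn new(config: &ToyConfig) -> Result<ToyClassifier, RustBertError> {
        let num_labels = config
            .id2label
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                )
            })?
            .len() as i64;
        Ok(ToyClassifier { num_labels })
    }
}

// Callers propagate the error with `?` instead of unwrapping, so an invalid
// configuration surfaces as an Err value rather than a panic.
fn build_model(config: &ToyConfig) -> Result<ToyClassifier, RustBertError> {
    let model = ToyClassifier::new(config)?;
    Ok(model)
}

fn main() {
    let bad_config = ToyConfig { id2label: None };
    assert!(build_model(&bad_config).is_err());
    println!("invalid configuration reported as an error, not a panic");
}
```

In the crate's unit tests the same pattern shows up as `let model = BertForTokenClassification::new(vs.root(), &config)?;` inside test functions returning `anyhow::Result<()>`, as seen in the test diffs further down.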
Authored by Andreas Haufler on 2023-01-15 11:32:57 +01:00; committed by GitHub
parent 2c4a79524d
commit 445b76fe7b
21 changed files with 163 additions and 84 deletions


@ -648,9 +648,12 @@ impl AlbertForTokenClassification {
/// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForTokenClassification =
/// AlbertForTokenClassification::new(&p.root(), &config);
/// AlbertForTokenClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertForTokenClassification
pub fn new<'p, P>(
p: P,
config: &AlbertConfig,
) -> Result<AlbertForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -661,7 +664,11 @@ impl AlbertForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -670,11 +677,11 @@ impl AlbertForTokenClassification {
Default::default(),
);
AlbertForTokenClassification {
Ok(AlbertForTokenClassification {
albert,
dropout,
classifier,
}
})
}
/// Forward pass through the model
@ -707,7 +714,7 @@ impl AlbertForTokenClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = AlbertConfig::from_file(config_path);
/// # let albert_model: AlbertForTokenClassification = AlbertForTokenClassification::new(&vs.root(), &config);
/// # let albert_model: AlbertForTokenClassification = AlbertForTokenClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));


@ -971,9 +971,12 @@ impl BertForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForTokenClassification::new(&p.root() / "bert", &config);
/// let bert = BertForTokenClassification::new(&p.root() / "bert", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForTokenClassification
pub fn new<'p, P>(
p: P,
config: &BertConfig,
) -> Result<BertForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -984,7 +987,11 @@ impl BertForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -993,11 +1000,11 @@ impl BertForTokenClassification {
Default::default(),
);
BertForTokenClassification {
Ok(BertForTokenClassification {
bert,
dropout,
classifier,
}
})
}
/// Forward pass through the model
@ -1030,7 +1037,7 @@ impl BertForTokenClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let bert_model = BertForTokenClassification::new(&vs.root(), &config);
/// # let bert_model = BertForTokenClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));


@ -882,9 +882,12 @@ impl DebertaForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DebertaConfig::from_file(config_path);
/// let model = DebertaForTokenClassification::new(&p.root(), &config);
/// let model = DebertaForTokenClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &DebertaConfig) -> DebertaForTokenClassification
pub fn new<'p, P>(
p: P,
config: &DebertaConfig,
) -> Result<DebertaForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -895,7 +898,11 @@ impl DebertaForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -904,11 +911,11 @@ impl DebertaForTokenClassification {
Default::default(),
);
DebertaForTokenClassification {
Ok(DebertaForTokenClassification {
deberta,
dropout,
classifier,
}
})
}
/// Forward pass through the model
@ -940,7 +947,7 @@ impl DebertaForTokenClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = DebertaConfig::from_file(config_path);
/// # let model = DebertaForTokenClassification::new(&vs.root(), &config);
/// # let model = DebertaForTokenClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));


@ -746,7 +746,10 @@ impl DebertaV2ForTokenClassification {
/// let config = DebertaV2Config::from_file(config_path);
/// let model = DebertaV2ForTokenClassification::new(&p.root(), &config);
/// ```
pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2ForTokenClassification
pub fn new<'p, P>(
p: P,
config: &DebertaV2Config,
) -> Result<DebertaV2ForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -757,7 +760,11 @@ impl DebertaV2ForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -766,11 +773,11 @@ impl DebertaV2ForTokenClassification {
Default::default(),
);
DebertaV2ForTokenClassification {
Ok(DebertaV2ForTokenClassification {
deberta,
dropout,
classifier,
}
})
}
/// Forward pass through the model
@ -802,7 +809,7 @@ impl DebertaV2ForTokenClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = DebertaV2Config::from_file(config_path);
/// # let model = DebertaV2ForTokenClassification::new(&vs.root(), &config);
/// # let model = DebertaV2ForTokenClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));


@ -680,9 +680,12 @@ impl DistilBertForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForTokenClassification::new(&p.root() / "distilbert", &config);
/// let distil_bert = DistilBertForTokenClassification::new(&p.root() / "distilbert", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForTokenClassification
pub fn new<'p, P>(
p: P,
config: &DistilBertConfig,
) -> Result<DistilBertForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -693,17 +696,21 @@ impl DistilBertForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"id2label must be provided for classifiers".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(p / "classifier", config.dim, num_labels, Default::default());
let dropout = Dropout::new(config.seq_classif_dropout);
DistilBertForTokenClassification {
Ok(DistilBertForTokenClassification {
distil_bert_model,
classifier,
dropout,
}
})
}
/// Forward pass through the model
@ -735,7 +742,7 @@ impl DistilBertForTokenClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = DistilBertConfig::from_file(config_path);
/// # let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
/// # let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -793,6 +800,7 @@ pub struct DistilBertSequenceClassificationOutput {
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the DistilBERT token classification model output
pub struct DistilBertTokenClassificationOutput {
/// Logits for each sequence item (token) for each target class
@ -802,6 +810,7 @@ pub struct DistilBertTokenClassificationOutput {
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the DistilBERT question answering model output
pub struct DistilBertQuestionAnsweringOutput {
/// Logits for the start position for token of each input sequence


@ -806,9 +806,12 @@ impl ElectraForTokenClassification {
/// let p = nn::VarStore::new(device);
/// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraForTokenClassification =
/// ElectraForTokenClassification::new(&p.root(), &config);
/// ElectraForTokenClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraForTokenClassification
pub fn new<'p, P>(
p: P,
config: &ElectraConfig,
) -> Result<ElectraForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -819,7 +822,11 @@ impl ElectraForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"id2label must be provided for classifiers".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -828,11 +835,11 @@ impl ElectraForTokenClassification {
Default::default(),
);
ElectraForTokenClassification {
Ok(ElectraForTokenClassification {
electra,
dropout,
classifier,
}
})
}
/// Forward pass through the model
@ -865,7 +872,7 @@ impl ElectraForTokenClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = ElectraConfig::from_file(config_path);
/// # let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&vs.root(), &config);
/// # let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));


@ -779,9 +779,12 @@ impl FNetForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = FNetConfig::from_file(config_path);
/// let fnet = FNetForTokenClassification::new(&p.root() / "fnet", &config);
/// let fnet = FNetForTokenClassification::new(&p.root() / "fnet", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForTokenClassification
pub fn new<'p, P>(
p: P,
config: &FNetConfig,
) -> Result<FNetForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -792,7 +795,11 @@ impl FNetForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -801,11 +808,11 @@ impl FNetForTokenClassification {
Default::default(),
);
FNetForTokenClassification {
Ok(FNetForTokenClassification {
fnet,
dropout,
classifier,
}
})
}
/// Forward pass through the model
@ -836,7 +843,7 @@ impl FNetForTokenClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = FNetConfig::from_file(config_path);
/// let model = FNetForTokenClassification::new(&vs.root(), &config);
/// let model = FNetForTokenClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -953,12 +960,12 @@ impl FNetForQuestionAnswering {
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::fnet::{FNetConfig, FNetForTokenClassification};
/// use rust_bert::fnet::{FNetConfig, FNetForQuestionAnswering};
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = FNetConfig::from_file(config_path);
/// let model = FNetForTokenClassification::new(&vs.root(), &config);
/// let model = FNetForQuestionAnswering::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -1067,7 +1074,7 @@ mod test {
// Set-up masked LM model
let device = Device::cuda_if_available();
let vs = tch::nn::VarStore::new(device);
let vs = nn::VarStore::new(device);
let config = FNetConfig::from_file(config_path);
let _: Box<dyn Send> = Box::new(FNetModel::new(vs.root(), &config, true));


@ -1196,9 +1196,12 @@ impl LongformerForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = LongformerConfig::from_file(config_path);
/// let longformer_model = LongformerForTokenClassification::new(&p.root(), &config);
/// let longformer_model = LongformerForTokenClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerForTokenClassification
pub fn new<'p, P>(
p: P,
config: &LongformerConfig,
) -> Result<LongformerForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -1210,7 +1213,11 @@ impl LongformerForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
@ -1220,11 +1227,11 @@ impl LongformerForTokenClassification {
Default::default(),
);
LongformerForTokenClassification {
Ok(LongformerForTokenClassification {
longformer,
dropout,
classifier,
}
})
}
/// Forward pass through the model
@ -1260,7 +1267,7 @@ impl LongformerForTokenClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = LongformerConfig::from_file(config_path);
/// let longformer_model = LongformerForTokenClassification::new(&vs.root(), &config);
/// let longformer_model = LongformerForTokenClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));


@ -1127,9 +1127,12 @@ impl MobileBertForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = MobileBertConfig::from_file(config_path);
/// let mobilebert = MobileBertForTokenClassification::new(&p.root(), &config);
/// let mobilebert = MobileBertForTokenClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForTokenClassification
pub fn new<'p, P>(
p: P,
config: &MobileBertConfig,
) -> Result<MobileBertForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -1140,7 +1143,11 @@ impl MobileBertForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -1148,11 +1155,12 @@ impl MobileBertForTokenClassification {
num_labels,
Default::default(),
);
MobileBertForTokenClassification {
Ok(MobileBertForTokenClassification {
mobilebert,
dropout,
classifier,
}
})
}
/// Forward pass through the model
@ -1185,7 +1193,7 @@ impl MobileBertForTokenClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = MobileBertConfig::from_file(config_path);
/// let model = MobileBertForTokenClassification::new(&vs.root(), &config);
/// let model = MobileBertForTokenClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));


@ -354,7 +354,7 @@ impl TokenClassificationOption {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
Ok(TokenClassificationOption::Bert(
BertForTokenClassification::new(p, config),
BertForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -365,7 +365,7 @@ impl TokenClassificationOption {
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
Ok(TokenClassificationOption::Deberta(
DebertaForTokenClassification::new(p, config),
DebertaForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -376,7 +376,7 @@ impl TokenClassificationOption {
ModelType::DebertaV2 => {
if let ConfigOption::DebertaV2(config) = config {
Ok(TokenClassificationOption::DebertaV2(
DebertaV2ForTokenClassification::new(p, config),
DebertaV2ForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -387,7 +387,7 @@ impl TokenClassificationOption {
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
Ok(TokenClassificationOption::DistilBert(
DistilBertForTokenClassification::new(p, config),
DistilBertForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -398,7 +398,7 @@ impl TokenClassificationOption {
ModelType::MobileBert => {
if let ConfigOption::MobileBert(config) = config {
Ok(TokenClassificationOption::MobileBert(
MobileBertForTokenClassification::new(p, config),
MobileBertForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -409,7 +409,7 @@ impl TokenClassificationOption {
ModelType::Roberta => {
if let ConfigOption::Roberta(config) = config {
Ok(TokenClassificationOption::Roberta(
RobertaForTokenClassification::new(p, config),
RobertaForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -420,7 +420,7 @@ impl TokenClassificationOption {
ModelType::XLMRoberta => {
if let ConfigOption::Roberta(config) = config {
Ok(TokenClassificationOption::XLMRoberta(
RobertaForTokenClassification::new(p, config),
RobertaForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -431,7 +431,7 @@ impl TokenClassificationOption {
ModelType::Electra => {
if let ConfigOption::Electra(config) = config {
Ok(TokenClassificationOption::Electra(
ElectraForTokenClassification::new(p, config),
ElectraForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -442,7 +442,7 @@ impl TokenClassificationOption {
ModelType::Albert => {
if let ConfigOption::Albert(config) = config {
Ok(TokenClassificationOption::Albert(
AlbertForTokenClassification::new(p, config),
AlbertForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -453,7 +453,7 @@ impl TokenClassificationOption {
ModelType::XLNet => {
if let ConfigOption::XLNet(config) = config {
Ok(TokenClassificationOption::XLNet(
XLNetForTokenClassification::new(p, config).unwrap(),
XLNetForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -464,7 +464,7 @@ impl TokenClassificationOption {
ModelType::Longformer => {
if let ConfigOption::Longformer(config) = config {
Ok(TokenClassificationOption::Longformer(
LongformerForTokenClassification::new(p, config),
LongformerForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -475,7 +475,7 @@ impl TokenClassificationOption {
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(TokenClassificationOption::FNet(
FNetForTokenClassification::new(p, config),
FNetForTokenClassification::new(p, config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(


@ -16,6 +16,7 @@ use crate::common::activations::_gelu;
use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::roberta::embeddings::RobertaEmbeddings;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Tensor};
@ -733,7 +734,10 @@ impl RobertaForTokenClassification {
/// let config = RobertaConfig::from_file(config_path);
/// let roberta = RobertaForMultipleChoice::new(&p.root() / "roberta", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForTokenClassification
pub fn new<'p, P>(
p: P,
config: &BertConfig,
) -> Result<RobertaForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
@ -744,7 +748,11 @@ impl RobertaForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
@ -753,11 +761,11 @@ impl RobertaForTokenClassification {
Default::default(),
);
RobertaForTokenClassification {
Ok(RobertaForTokenClassification {
roberta,
dropout,
classifier,
}
})
}
/// Forward pass through the model
@ -792,7 +800,7 @@ impl RobertaForTokenClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
/// # let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));


@ -1080,7 +1080,7 @@ impl XLNetForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = XLNetConfig::from_file(config_path);
/// let xlnet_model = XLNetForTokenClassification::new(&p.root(), &config);
/// let xlnet_model = XLNetForTokenClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(
p: P,
@ -1095,7 +1095,11 @@ impl XLNetForTokenClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(


@ -242,7 +242,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let albert_model = AlbertForTokenClassification::new(vs.root(), &config);
let albert_model = AlbertForTokenClassification::new(vs.root(), &config)?;
// Define input
let input = [


@ -283,7 +283,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let bert_model = BertForTokenClassification::new(vs.root(), &config);
let bert_model = BertForTokenClassification::new(vs.root(), &config)?;
// Define input
let input = [


@ -170,7 +170,7 @@ fn deberta_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let model = DebertaForTokenClassification::new(vs.root(), &config);
let model = DebertaForTokenClassification::new(vs.root(), &config)?;
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];


@ -142,7 +142,7 @@ fn deberta_v2_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let model = DebertaV2ForTokenClassification::new(vs.root(), &config);
let model = DebertaV2ForTokenClassification::new(vs.root(), &config)?;
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];


@ -211,7 +211,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let distil_bert_model = DistilBertForTokenClassification::new(vs.root(), &config);
let distil_bert_model = DistilBertForTokenClassification::new(vs.root(), &config)?;
// Define input
let input = [


@ -121,6 +121,7 @@ fn fnet_for_sequence_classification() -> anyhow::Result<()> {
Ok(())
}
//
#[test]
fn fnet_for_multiple_choice() -> anyhow::Result<()> {
@ -201,7 +202,7 @@ fn fnet_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
config.output_hidden_states = Some(true);
let fnet_model = FNetForTokenClassification::new(vs.root(), &config);
let fnet_model = FNetForTokenClassification::new(vs.root(), &config)?;
// Define input
let input = [


@ -337,7 +337,7 @@ fn longformer_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let model = LongformerForTokenClassification::new(vs.root(), &config);
let model = LongformerForTokenClassification::new(vs.root(), &config)?;
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];


@ -240,7 +240,7 @@ fn mobilebert_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let model = MobileBertForTokenClassification::new(vs.root(), &config);
let model = MobileBertForTokenClassification::new(vs.root(), &config)?;
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];


@ -273,7 +273,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let roberta_model = RobertaForTokenClassification::new(vs.root(), &config);
let roberta_model = RobertaForTokenClassification::new(vs.root(), &config)?;
// Define input
let input = [