Addition of BART as an option for sequence classification pipelines, updated return type of pipelines to be a Result

This commit is contained in:
Guillaume B 2020-09-01 17:52:49 +02:00
parent 919cbff3c9
commit 71121e56d8
11 changed files with 105 additions and 74 deletions

View File

@ -27,7 +27,7 @@ fn main() -> anyhow::Result<()> {
let qa_inputs = squad_processor(squad_path);
// Get answer
let answers = qa_model.predict(&qa_inputs, 1, 64);
let answers = qa_model.predict(&qa_inputs, 1, 64)?;
println!("Sample answer: {:?}", answers.first().unwrap());
println!("{}", answers.len());
Ok(())

View File

@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
.map(|v| v.as_str())
.collect::<Vec<&str>>()
.as_slice(),
),
)?,
);
}
let mut flat_outputs = vec![];

View File

@ -188,9 +188,10 @@ impl NERModel {
/// # Ok(())
/// # }
/// ```
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
self.token_classification_model
.predict(input, true, false)
pub fn predict(&self, input: &[&str]) -> Result<Vec<Entity>, RustBertError> {
Ok(self
.token_classification_model
.predict(input, true, false)?
.into_iter()
.filter(|token| token.label != "O")
.map(|token| Entity {
@ -198,6 +199,6 @@ impl NERModel {
score: token.score,
label: token.label,
})
.collect()
.collect())
}
}

View File

@ -328,6 +328,9 @@ impl QuestionAnsweringOption {
ModelType::T5 => {
panic!("QuestionAnswering not implemented for T5!");
}
ModelType::Bart => {
panic!("QuestionAnswering not implemented for BART!");
}
}
}
@ -503,7 +506,7 @@ impl QuestionAnsweringModel {
qa_inputs: &[QaInput],
top_k: i64,
batch_size: usize,
) -> Vec<Vec<Answer>> {
) -> Result<Vec<Vec<Answer>>, RustBertError> {
let examples: Vec<QaExample> = qa_inputs
.iter()
.map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context))
@ -617,7 +620,7 @@ impl QuestionAnsweringModel {
all_answers.push(vec![]);
}
}
all_answers
Ok(all_answers)
}
fn decode(&self, start: &Tensor, end: &Tensor, top_k: i64) -> (Vec<i64>, Vec<i64>, Vec<f64>) {

View File

@ -137,8 +137,8 @@ impl SentimentModel {
/// # Ok(())
/// # }
/// ```
pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
let labels = self.sequence_classification_model.predict(input);
pub fn predict(&self, input: &[&str]) -> Result<Vec<Sentiment>, RustBertError> {
let labels = self.sequence_classification_model.predict(input)?;
let mut sentiments = Vec::with_capacity(labels.len());
for label in labels {
let polarity = if label.id == 1 {
@ -151,7 +151,7 @@ impl SentimentModel {
score: label.score,
})
}
sentiments
Ok(sentiments)
}
}

View File

@ -58,6 +58,7 @@
//! # ;
//! ```
use crate::albert::AlbertForSequenceClassification;
use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification;
use crate::common::error::RustBertError;
use crate::common::resources::{download_resource, RemoteResource, Resource};
@ -183,6 +184,8 @@ pub enum SequenceClassificationOption {
XLMRoberta(RobertaForSequenceClassification),
/// Albert for Sequence Classification
Albert(AlbertForSequenceClassification),
/// Bart for Sequence Classification
Bart(BartForSequenceClassification),
}
impl SequenceClassificationOption {
@ -244,6 +247,15 @@ impl SequenceClassificationOption {
panic!("You can only supply an AlbertConfig for Albert!");
}
}
// BART: build the sequence classification model from a BART configuration.
// Any other configuration variant is a caller error and panics, like the
// sibling arms do for their own model types.
ModelType::Bart => {
    if let ConfigOption::Bart(config) = config {
        SequenceClassificationOption::Bart(BartForSequenceClassification::new(
            p, config,
        ))
    } else {
        // The message names BART: the previous text was copied from the BERT arm
        // and reported the wrong model type.
        panic!("You can only supply a BartConfig for Bart!");
    }
}
ModelType::Electra => {
panic!("SequenceClassification not implemented for Electra!");
}
@ -264,6 +276,7 @@ impl SequenceClassificationOption {
Self::XLMRoberta(_) => ModelType::Roberta,
Self::DistilBert(_) => ModelType::DistilBert,
Self::Albert(_) => ModelType::Albert,
Self::Bart(_) => ModelType::Bart,
}
}
@ -276,50 +289,54 @@ impl SequenceClassificationOption {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Tensor {
) -> Result<Tensor, RustBertError> {
match *self {
Self::Bert(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.0
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t")
.0
}
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.0
}
Self::Albert(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.0
}
// BART is called with a borrowed `input_ids` tensor and a borrowed optional mask.
// `token_type_ids`, `position_ids` and `input_embeds` are not forwarded to it.
// A missing `input_ids` is reported as a `RustBertError::ValueError`.
Self::Bart(ref model) => match input_ids {
    Some(input_ids) => Ok(model
        .forward_t(&input_ids, mask.as_ref(), None, None, None, train)
        .0),
    // The arm's value is already the function's result, so no explicit `return` is needed.
    None => Err(RustBertError::ValueError(
        "`input_ids` must be provided when using a BART model".to_string(),
    )),
},
Self::Bert(ref model) => Ok(model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.0),
Self::DistilBert(ref model) => Ok(model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t")
.0),
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => Ok(model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.0),
Self::Albert(ref model) => Ok(model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.0),
}
}
}
@ -431,19 +448,19 @@ impl SequenceClassificationModel {
/// # Ok(())
/// # }
/// ```
pub fn predict(&self, input: &[&str]) -> Vec<Label> {
pub fn predict(&self, input: &[&str]) -> Result<Vec<Label>, RustBertError> {
let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| {
let output = self.sequence_classifier.forward_t(
self.sequence_classifier.forward_t(
Some(input_tensor.copy()),
None,
None,
None,
None,
false,
);
output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
});
)
})?;
let output = output.softmax(-1, Kind::Float).detach().to(Device::Cpu);
let label_indices = output.as_ref().argmax(-1, true).squeeze1(1);
let scores = output
.gather(1, &label_indices.unsqueeze(-1), false)
@ -466,7 +483,7 @@ impl SequenceClassificationModel {
};
labels.push(label)
}
labels
Ok(labels)
}
/// Multi-label classification of texts
@ -496,19 +513,23 @@ impl SequenceClassificationModel {
/// # Ok(())
/// # }
/// ```
pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> {
pub fn predict_multilabel(
&self,
input: &[&str],
threshold: f64,
) -> Result<Vec<Vec<Label>>, RustBertError> {
let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| {
let output = self.sequence_classifier.forward_t(
self.sequence_classifier.forward_t(
Some(input_tensor.copy()),
None,
None,
None,
None,
false,
);
output.sigmoid().detach().to(Device::Cpu)
});
)
})?;
let output = output.sigmoid().detach().to(Device::Cpu);
let label_indices = output.as_ref().ge(threshold).nonzero();
let mut labels: Vec<Vec<Label>> = vec![];
@ -538,6 +559,6 @@ impl SequenceClassificationModel {
if sequence_labels.len() > 0 {
labels.push(sequence_labels);
}
labels
Ok(labels)
}
}

View File

@ -370,6 +370,9 @@ impl TokenClassificationOption {
ModelType::T5 => {
panic!("TokenClassification not implemented for T5!");
}
ModelType::Bart => {
panic!("TokenClassification not implemented for BART!");
}
}
}
@ -572,7 +575,7 @@ impl TokenClassificationModel {
input: &[&str],
consolidate_sub_tokens: bool,
return_special: bool,
) -> Vec<Token> {
) -> Result<Vec<Token>, RustBertError> {
let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
let output = no_grad(|| {
self.token_sequence_classifier.forward_t(
@ -619,7 +622,7 @@ impl TokenClassificationModel {
if consolidate_sub_tokens {
self.consolidate_tokens(&mut tokens, &self.label_aggregation_function);
}
tokens
Ok(tokens)
}
fn decode_token(

View File

@ -650,6 +650,9 @@ impl TranslationOption {
ModelType::Albert => {
panic!("Translation not implemented for Albert!");
}
ModelType::Bart => {
panic!("Translation not implemented for BART!");
}
}
}

View File

@ -361,7 +361,7 @@ fn bert_pre_trained_ner() -> anyhow::Result<()> {
];
// Run model
let output = ner_model.predict(&input);
let output = ner_model.predict(&input)?;
assert_eq!(output.len(), 4);
@ -407,7 +407,7 @@ fn bert_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32);
let answers = qa_model.predict(&vec![qa_input], 1, 32)?;
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);

View File

@ -27,7 +27,7 @@ fn distilbert_sentiment_classifier() -> anyhow::Result<()> {
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
let output = sentiment_classifier.predict(&input);
let output = sentiment_classifier.predict(&input)?;
assert_eq!(output.len(), 3 as usize);
assert_eq!(output[0].polarity, SentimentPolarity::Positive);
@ -249,7 +249,7 @@ fn distilbert_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32);
let answers = qa_model.predict(&vec![qa_input], 1, 32)?;
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);

View File

@ -357,7 +357,7 @@ fn roberta_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32);
let answers = qa_model.predict(&vec![qa_input], 1, 32)?;
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);
@ -396,7 +396,7 @@ fn xlm_roberta_german_ner() -> anyhow::Result<()> {
"Chongqing ist eine Stadt in China.",
];
let output = ner_model.predict(&input);
let output = ner_model.predict(&input)?;
assert_eq!(output.len(), 4);