Merge pull request #41 from guillaume-be/multilabel_classification

Addition of multi-label classification prediction method for sequence…
This commit is contained in:
guillaume-be 2020-05-20 16:41:11 +00:00 committed by GitHub
commit 60bffe6e8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 107 additions and 2 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.7.3"
version = "0.7.4"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"

View File

@ -0,0 +1,35 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
fn main() -> failure::Fallible<()> {
// Set-up model
let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
// Define input
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This is a neutral sentence.",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
// Run model
let output = sequence_classification_model.predict_multilabel(&input, 0.05);
for label in output {
println!("{:?}", label);
}
Ok(())
}

View File

@ -184,7 +184,7 @@ impl SequenceClassificationOption {
panic!("You can only supply a BertConfig for Roberta!");
}
}
ModelType::Electra => {panic!("SequenceClassification not implemented for Electra!");}
ModelType::Electra => { panic!("SequenceClassification not implemented for Electra!"); }
}
}
@ -335,4 +335,74 @@ impl SequenceClassificationModel {
}
labels
}
/// Multi-label classification of texts
///
/// # Arguments
///
/// * `input` - `&[&str]` Array of texts to classify.
/// * `threshold` - `f64` threshold above which a label will be considered true by the classifier
///
/// # Returns
///
/// * `Vec<Vec<Label>>` containing a vector of true labels for each input text
///
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
///# use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
///
/// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
/// let input = [
/// "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
/// "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
/// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
/// ];
/// let output = sequence_classification_model.predict_multilabel(&input, 0.5);
///# Ok(())
///# }
/// ```
pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> {
let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| {
let (output, _, _) = self.sequence_classifier
.forward_t(Some(input_tensor.copy()),
None,
None,
None,
None,
false);
output.sigmoid().detach().to(Device::Cpu)
});
let label_indices = output.as_ref().ge(threshold).nonzero();
let mut labels: Vec<Vec<Label>> = vec!();
let mut sequence_labels: Vec<Label> = vec!();
for sentence_idx in 0..label_indices.size()[0] {
let label_index_tensor = label_indices.get(sentence_idx);
let sentence_label = label_index_tensor.iter::<i64>().unwrap().collect::<Vec<i64>>();
let (sentence, id) = (sentence_label[0], sentence_label[1]);
if sentence as usize > labels.len() {
labels.push(sequence_labels);
sequence_labels = vec!();
}
let score = output.double_value(sentence_label.as_slice());
let label_string = self.label_mapping.get(&id).unwrap().to_owned();
let label = Label {
text: label_string,
score,
id,
sentence: sentence as usize,
};
sequence_labels.push(label);
}
if sequence_labels.len() > 0 {
labels.push(sequence_labels);
}
labels
}
}