mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
Addition of multi-label classification prediction method for sequence classification pipeline
Version update
This commit is contained in:
parent
33a623e54d
commit
d0fc3ff40d
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "rust-bert"
|
||||
version = "0.7.3"
|
||||
version = "0.7.4"
|
||||
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
|
||||
edition = "2018"
|
||||
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
|
||||
|
35
examples/sequence_classification_multilabel.rs
Normal file
35
examples/sequence_classification_multilabel.rs
Normal file
@ -0,0 +1,35 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
||||
|
||||
fn main() -> failure::Fallible<()> {
    // Build the sequence classification pipeline from its default configuration.
    let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;

    // Texts to classify.
    let input = [
        "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
        "This is a neutral sentence.",
        "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
    ];

    // Multi-label prediction: every label whose score exceeds the threshold is kept.
    let predictions = sequence_classification_model.predict_multilabel(&input, 0.05);

    // Each item is the set of labels retained for one input sentence.
    for sentence_labels in predictions {
        println!("{:?}", sentence_labels);
    }

    Ok(())
}
|
@ -184,7 +184,7 @@ impl SequenceClassificationOption {
|
||||
panic!("You can only supply a BertConfig for Roberta!");
|
||||
}
|
||||
}
|
||||
ModelType::Electra => {panic!("SequenceClassification not implemented for Electra!");}
|
||||
ModelType::Electra => { panic!("SequenceClassification not implemented for Electra!"); }
|
||||
}
|
||||
}
|
||||
|
||||
@ -335,4 +335,74 @@ impl SequenceClassificationModel {
|
||||
}
|
||||
labels
|
||||
}
|
||||
|
||||
/// Multi-label classification of texts
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input` - `&[&str]` Array of texts to classify.
|
||||
/// * `threshold` - `f64` threshold above which a label will be considered true by the classifier
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Vec<Vec<Label>>` containing a vector of true labels for each input text
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
///# fn main() -> failure::Fallible<()> {
|
||||
///# use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
||||
///
|
||||
/// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
|
||||
/// let input = [
|
||||
/// "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
||||
/// "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
|
||||
/// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||
/// ];
|
||||
/// let output = sequence_classification_model.predict_multilabel(&input, 0.5);
|
||||
///# Ok(())
|
||||
///# }
|
||||
/// ```
|
||||
pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> {
|
||||
let input_tensor = self.prepare_for_model(input.to_vec());
|
||||
let output = no_grad(|| {
|
||||
let (output, _, _) = self.sequence_classifier
|
||||
.forward_t(Some(input_tensor.copy()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false);
|
||||
output.sigmoid().detach().to(Device::Cpu)
|
||||
});
|
||||
let label_indices = output.as_ref().ge(threshold).nonzero();
|
||||
|
||||
let mut labels: Vec<Vec<Label>> = vec!();
|
||||
let mut sequence_labels: Vec<Label> = vec!();
|
||||
|
||||
for sentence_idx in 0..label_indices.size()[0] {
|
||||
|
||||
let label_index_tensor = label_indices.get(sentence_idx);
|
||||
let sentence_label = label_index_tensor.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
||||
let (sentence, id) = (sentence_label[0], sentence_label[1]);
|
||||
if sentence as usize > labels.len() {
|
||||
labels.push(sequence_labels);
|
||||
sequence_labels = vec!();
|
||||
}
|
||||
let score = output.double_value(sentence_label.as_slice());
|
||||
let label_string = self.label_mapping.get(&id).unwrap().to_owned();
|
||||
let label = Label {
|
||||
text: label_string,
|
||||
score,
|
||||
id,
|
||||
sentence: sentence as usize,
|
||||
};
|
||||
sequence_labels.push(label);
|
||||
|
||||
}
|
||||
if sequence_labels.len() > 0 {
|
||||
labels.push(sequence_labels);
|
||||
}
|
||||
labels
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user