Documentation for zero-shot classification

This commit is contained in:
Guillaume B 2020-09-05 15:20:13 +02:00
parent b52b0cb005
commit 6e8f79ff67
4 changed files with 197 additions and 5 deletions

View File

@ -161,7 +161,32 @@ Example output:
]
```
#### 6. Sentiment analysis
#### 6. Zero-shot classification
Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference.
```rust
let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
let input_sentence = "Who are you voting for in 2020?";
let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
let candidate_labels = &["politics", "public health", "economics", "sports"];
let output = sequence_classification_model.predict_multilabel(
&[input_sentence, input_sequence_2],
candidate_labels,
None,
128,
);
```
Output:
```
[
[ Label { "politics", score: 0.972 }, Label { "public health", score: 0.032 }, Label {"economics", score: 0.006 }, Label {"sports", score: 0.004 } ],
[ Label { "politics", score: 0.975 }, Label { "public health", score: 0.0818 }, Label {"economics", score: 0.852 }, Label {"sports", score: 0.001 } ],
]
```
#### 7. Sentiment analysis
Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
```rust
let sentiment_classifier = SentimentModel::new(Default::default())?;
@ -185,7 +210,7 @@ Output:
]
```
#### 7. Named Entity Recognition
#### 8. Named Entity Recognition
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
Models are currently available for English, German, Spanish and Dutch.
```rust

View File

@ -10,6 +10,7 @@
//! - Translation
//! - Summarization
//! - Multi-turn dialogue
//! - Zero-shot classification
//! - Sentiment Analysis
//! - Named Entity Recognition
//! - Question-Answering

View File

@ -176,7 +176,86 @@
//! # ;
//! ```
//!
//! #### 6. Sentiment analysis
//! #### 6. Zero-shot classification
//! Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference.
//! ```no_run
//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
//! # fn main() -> anyhow::Result<()> {
//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! # Ok(())
//! # }
//! ```
//!
//! outputs:
//! ```no_run
//! # use rust_bert::pipelines::sequence_classification::Label;
//! let output = [
//! [
//! Label {
//! text: "politics".to_string(),
//! score: 0.972,
//! id: 0,
//! sentence: 0,
//! },
//! Label {
//! text: "public health".to_string(),
//! score: 0.032,
//! id: 1,
//! sentence: 0,
//! },
//! Label {
//! text: "economics".to_string(),
//! score: 0.006,
//! id: 2,
//! sentence: 0,
//! },
//! Label {
//! text: "sports".to_string(),
//! score: 0.004,
//! id: 3,
//! sentence: 0,
//! },
//! ],
//! [
//! Label {
//! text: "politics".to_string(),
//! score: 0.975,
//! id: 0,
//! sentence: 1,
//! },
//! Label {
//! text: "economics".to_string(),
//! score: 0.852,
//! id: 2,
//! sentence: 1,
//! },
//! Label {
//! text: "public health".to_string(),
//! score: 0.0818,
//! id: 1,
//! sentence: 1,
//! },
//! Label {
//! text: "sports".to_string(),
//! score: 0.001,
//! id: 3,
//! sentence: 1,
//! },
//! ],
//! ]
//! .to_vec();
//! ```
//!
//! #### 7. Sentiment analysis
//! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
//! ```no_run
//! use rust_bert::pipelines::sentiment::SentimentModel;
@ -215,7 +294,7 @@
//! # ;
//! ```
//!
//! #### 7. Named Entity Recognition
//! #### 8. Named Entity Recognition
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. The default NER mode is an English BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! Additional pre-trained models are available for English, German, Spanish and Dutch.
//! ```no_run

View File

@ -11,6 +11,93 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Zero-shot classification pipeline
//! Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference.
//! The default model is a BART model fine-tuned on a MNLI. From a list of input sequences to classify and a list of target labels,
//! single-class or multi-label classification is performed, translating the classification task to an inference task.
//! The default template for translation to inference task is `This example is about {}.`. This template can be updated to a more specific
//! value that may match better the use case, for example `This review is about a {product_class}`.
//!
//! - `predict` performs single-class classification (one and exactly one label must be true for each provided input)
//! - `predict_multilabel` performs multi-label classification (zero, one or more labels may be true for each provided input)
//!
//! ```no_run
//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
//! # fn main() -> anyhow::Result<()> {
//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! # Ok(())
//! # }
//! ```
//!
//! outputs:
//! ```no_run
//! # use rust_bert::pipelines::sequence_classification::Label;
//! let output = [
//! [
//! Label {
//! text: "politics".to_string(),
//! score: 0.972,
//! id: 0,
//! sentence: 0,
//! },
//! Label {
//! text: "public health".to_string(),
//! score: 0.032,
//! id: 1,
//! sentence: 0,
//! },
//! Label {
//! text: "economics".to_string(),
//! score: 0.006,
//! id: 2,
//! sentence: 0,
//! },
//! Label {
//! text: "sports".to_string(),
//! score: 0.004,
//! id: 3,
//! sentence: 0,
//! },
//! ],
//! [
//! Label {
//! text: "politics".to_string(),
//! score: 0.975,
//! id: 0,
//! sentence: 1,
//! },
//! Label {
//! text: "economics".to_string(),
//! score: 0.852,
//! id: 2,
//! sentence: 1,
//! },
//! Label {
//! text: "public health".to_string(),
//! score: 0.0818,
//! id: 1,
//! sentence: 1,
//! },
//! Label {
//! text: "sports".to_string(),
//! score: 0.001,
//! id: 3,
//! sentence: 1,
//! },
//! ],
//! ]
//! .to_vec();
//! ```
use crate::albert::AlbertForSequenceClassification;
use crate::bart::{
BartConfigResources, BartForSequenceClassification, BartMergesResources, BartModelResources,
@ -511,7 +598,7 @@ impl ZeroShotClassificationModel {
///
/// * `input` - `&[&str]` Array of texts to classify.
/// * `labels` - `&[&str]` Possible labels for the inputs.
/// * `template` - `Option<Box<dyn Fn(&str) -> String>>` closure to build label propositions. If None, will default to `"This example is {}."`.
/// * `template` - `Option<Box<dyn Fn(&str) -> String>>` closure to build label propositions. If None, will default to `"This example is about {}."`.
/// * `max_length` -`usize` Maximum sequence length for the inputs. If needed, the input sequence will be truncated before the label template.
///
/// # Returns