diff --git a/README.md b/README.md index 36c5a18..099aaa1 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,32 @@ Example output: ] ``` -#### 6. Sentiment analysis +#### 6. Zero-shot classification +Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference. +```rust + let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; + + let input_sentence = "Who are you voting for in 2020?"; + let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; + let candidate_labels = &["politics", "public health", "economics", "sports"]; + + let output = sequence_classification_model.predict_multilabel( + &[input_sentence, input_sequence_2], + candidate_labels, + None, + 128, + ); +``` + +Output: +``` +[ + [ Label { "politics", score: 0.972 }, Label { "public health", score: 0.032 }, Label {"economics", score: 0.006 }, Label {"sports", score: 0.004 } ], + [ Label { "politics", score: 0.975 }, Label { "public health", score: 0.0818 }, Label {"economics", score: 0.852 }, Label {"sports", score: 0.001 } ], +] +``` + +#### 7. Sentiment analysis Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2. ```rust let sentiment_classifier = SentimentModel::new(Default::default())?; @@ -185,7 +210,7 @@ Output: ] ``` -#### 7. Named Entity Recognition +#### 8. Named Entity Recognition Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz). Models are currently available for English, German, Spanish and Dutch. ```rust diff --git a/src/lib.rs b/src/lib.rs index ed32b50..89155f9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ //! - Translation //! - Summarization //! - Multi-turn dialogue +//! 
- Zero-shot classification //! - Sentiment Analysis //! - Named Entity Recognition //! - Question-Answering diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index cca048d..d721141 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -176,7 +176,86 @@ //! # ; //! ``` //! -//! #### 6. Sentiment analysis +//! #### 6. Zero-shot classification +//! Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference. +//! ```no_run +//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel; +//! # fn main() -> anyhow::Result<()> { +//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; +//! let input_sentence = "Who are you voting for in 2020?"; +//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; +//! let candidate_labels = &["politics", "public health", "economics", "sports"]; +//! let output = sequence_classification_model.predict_multilabel( +//! &[input_sentence, input_sequence_2], +//! candidate_labels, +//! None, +//! 128, +//! ); +//! # Ok(()) +//! # } +//! ``` +//! +//! outputs: +//! ```no_run +//! # use rust_bert::pipelines::sequence_classification::Label; +//! let output = [ +//! [ +//! Label { +//! text: "politics".to_string(), +//! score: 0.972, +//! id: 0, +//! sentence: 0, +//! }, +//! Label { +//! text: "public health".to_string(), +//! score: 0.032, +//! id: 1, +//! sentence: 0, +//! }, +//! Label { +//! text: "economics".to_string(), +//! score: 0.006, +//! id: 2, +//! sentence: 0, +//! }, +//! Label { +//! text: "sports".to_string(), +//! score: 0.004, +//! id: 3, +//! sentence: 0, +//! }, +//! ], +//! [ +//! Label { +//! text: "politics".to_string(), +//! score: 0.975, +//! id: 0, +//! sentence: 1, +//! }, +//! Label { +//! text: "economics".to_string(), +//! score: 0.852, +//! id: 2, +//! sentence: 1, +//! }, +//! Label { +//! 
text: "public health".to_string(), +//! score: 0.0818, +//! id: 1, +//! sentence: 1, +//! }, +//! Label { +//! text: "sports".to_string(), +//! score: 0.001, +//! id: 3, +//! sentence: 1, +//! }, +//! ], +//! ] +//! .to_vec(); +//! ``` +//! +//! #### 7. Sentiment analysis //! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2. //! ```no_run //! use rust_bert::pipelines::sentiment::SentimentModel; @@ -215,7 +294,7 @@ //! # ; //! ``` //! -//! #### 7. Named Entity Recognition +//! #### 8. Named Entity Recognition //! Extracts entities (Person, Location, Organization, Miscellaneous) from text. The default NER mode is an English BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz) //! Additional pre-trained models are available for English, German, Spanish and Dutch. //! ```no_run diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index 79f619d..3fc7803 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -11,6 +11,93 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! # Zero-shot classification pipeline +//! Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference. +//! The default model is a BART model fine-tuned on MNLI. From a list of input sequences to classify and a list of target labels, +//! single-class or multi-label classification is performed, translating the classification task to an inference task. +//! The default template for translation to an inference task is `This example is about {}.`. This template can be updated to a more specific +//! value that may better match the use case, for example `This review is about a {product_class}`. +//! +//!
- `predict` performs single-class classification (one and exactly one label must be true for each provided input) +//! - `predict_multilabel` performs multi-label classification (zero, one or more labels may be true for each provided input) +//! +//! ```no_run +//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel; +//! # fn main() -> anyhow::Result<()> { +//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; +//! let input_sentence = "Who are you voting for in 2020?"; +//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; +//! let candidate_labels = &["politics", "public health", "economics", "sports"]; +//! let output = sequence_classification_model.predict_multilabel( +//! &[input_sentence, input_sequence_2], +//! candidate_labels, +//! None, +//! 128, +//! ); +//! # Ok(()) +//! # } +//! ``` +//! +//! outputs: +//! ```no_run +//! # use rust_bert::pipelines::sequence_classification::Label; +//! let output = [ +//! [ +//! Label { +//! text: "politics".to_string(), +//! score: 0.972, +//! id: 0, +//! sentence: 0, +//! }, +//! Label { +//! text: "public health".to_string(), +//! score: 0.032, +//! id: 1, +//! sentence: 0, +//! }, +//! Label { +//! text: "economics".to_string(), +//! score: 0.006, +//! id: 2, +//! sentence: 0, +//! }, +//! Label { +//! text: "sports".to_string(), +//! score: 0.004, +//! id: 3, +//! sentence: 0, +//! }, +//! ], +//! [ +//! Label { +//! text: "politics".to_string(), +//! score: 0.975, +//! id: 0, +//! sentence: 1, +//! }, +//! Label { +//! text: "economics".to_string(), +//! score: 0.852, +//! id: 2, +//! sentence: 1, +//! }, +//! Label { +//! text: "public health".to_string(), +//! score: 0.0818, +//! id: 1, +//! sentence: 1, +//! }, +//! Label { +//! text: "sports".to_string(), +//! score: 0.001, +//! id: 3, +//! sentence: 1, +//! }, +//! ], +//! ] +//! .to_vec(); +//! 
``` + use crate::albert::AlbertForSequenceClassification; use crate::bart::{ BartConfigResources, BartForSequenceClassification, BartMergesResources, BartModelResources, @@ -511,7 +598,7 @@ impl ZeroShotClassificationModel { /// /// * `input` - `&[&str]` Array of texts to classify. /// * `labels` - `&[&str]` Possible labels for the inputs. - /// * `template` - `Option<Box<dyn Fn(&str) -> String>>` closure to build label propositions. If None, will default to `"This example is {}."`. + /// * `template` - `Option<Box<dyn Fn(&str) -> String>>` closure to build label propositions. If None, will default to `"This example is about {}."`. /// * `max_length` -`usize` Maximum sequence length for the inputs. If needed, the input sequence will be truncated before the label template. /// /// # Returns