From dae899fea6cb7a10f94564ad41cc775d7b192168 Mon Sep 17 00:00:00 2001 From: Vincent Xiao Date: Thu, 22 Dec 2022 00:58:02 +0800 Subject: [PATCH] Add pipelines::masked_language and codebert support (#282) * ad support for loading local moddel in SequenceClassificationConfig * adjust config to match the SequenceClassificationConfig * add piplines::masked_language * add support and example for codebert * provide an optional mask_token String field for asked_language pipline * update example for masked_language pipeline * codebert support revocation * revoke support for loading local moddel * solve conflicts * update MaskedLanguageConfig * fix doctest error in zero_shot_classification.rs * MaskedLM pipeline updates * fix multiple masked token, added test * Updated changelog and docs Co-authored-by: Guillaume Becquin --- CHANGELOG.md | 1 + Cargo.toml | 12 +- README.md | 28 ++ examples/masked_language.rs | 46 ++ examples/masked_language_model_bert.rs | 96 ---- src/lib.rs | 33 ++ src/pipelines/common.rs | 108 +++++ src/pipelines/masked_language.rs | 566 ++++++++++++++++++++++ src/pipelines/mod.rs | 1 + src/pipelines/zero_shot_classification.rs | 12 +- tests/bert.rs | 41 ++ 11 files changed, 839 insertions(+), 105 deletions(-) create mode 100644 examples/masked_language.rs delete mode 100644 examples/masked_language_model_bert.rs create mode 100644 src/pipelines/masked_language.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 841e4a3..1f7d2e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ All notable changes to this project will be documented in this file. The format ## Added - Addition of All-MiniLM-L6-V2 model weights - Addition of Keyword/Keyphrases extraction pipeline based on KeyBERT (https://github.com/MaartenGr/KeyBERT) +- Addition of Masked Language Model pipeline, allowing to predict masked words. ## Changed - Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`). diff --git a/Cargo.toml b/Cargo.toml index 17bcabf..fa85db3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,13 @@ repository = "https://github.com/guillaume-be/rust-bert" documentation = "https://docs.rs/rust-bert" license = "Apache-2.0" readme = "README.md" -keywords = ["nlp", "deep-learning", "machine-learning", "transformers", "translation"] +keywords = [ + "nlp", + "deep-learning", + "machine-learning", + "transformers", + "translation", +] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -57,7 +63,7 @@ opt-level = 3 default = ["remote"] doc-only = ["tch/doc-only"] all-tests = [] -remote = [ "cached-path", "dirs", "lazy_static" ] +remote = ["cached-path", "dirs", "lazy_static"] [package.metadata.docs.rs] features = ["doc-only"] @@ -82,6 +88,6 @@ anyhow = "1.0.58" csv = "1.1.6" criterion = "0.3.6" tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] } -torch-sys = "~0.9.0" +torch-sys = "0.9.0" tempfile = "3.3.0" itertools = "0.10.3" diff --git a/README.md b/README.md index e28f769..a860c02 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ The tasks currently supported include: - Part of Speech tagging - Question-Answering - Language Generation + - Masked Language Model - Sentence Embeddings
@@ -436,6 +437,33 @@ Output: ```
+ +
+ 12. Masked Language Model + +Predict masked words in input sentences. +```rust + let model = MaskedLanguageModel::new(Default::default())?; + + let sentences = [ + "Hello I am a student", + "Paris is the of France. It is in Europe.", + ]; + + let output = model.predict(&sentences); +``` +Output: +``` +[ + [MaskedToken { text: "college", id: 2267, score: 8.091}], + [ + MaskedToken { text: "capital", id: 3007, score: 16.7249}, + MaskedToken { text: "located", id: 2284, score: 9.0452} + ] +] +``` +
+ ## Benchmarks For simple pipelines (sequence classification, tokens classification, question answering) the performance between Python and Rust is expected to be comparable. This is because the most expensive part of these pipeline is the language model itself, sharing a common implementation in the Torch backend. The [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/) provides a benchmarks section covering all pipelines. diff --git a/examples/masked_language.rs b/examples/masked_language.rs new file mode 100644 index 0000000..bd3db54 --- /dev/null +++ b/examples/masked_language.rs @@ -0,0 +1,46 @@ +// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +// Copyright 2019 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +extern crate anyhow; +use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources}; +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel}; +use rust_bert::resources::RemoteResource; +fn main() -> anyhow::Result<()> { + // Set-up model + let config = MaskedLanguageConfig::new( + ModelType::Bert, + RemoteResource::from_pretrained(BertModelResources::BERT), + RemoteResource::from_pretrained(BertConfigResources::BERT), + RemoteResource::from_pretrained(BertVocabResources::BERT), + None, + true, + None, + None, + Some(String::from("")), + ); + + let mask_language_model = MaskedLanguageModel::new(config)?; + // Define input + let input = [ + "Hello I am a student", + "Paris is the of France. It is in Europe.", + ]; + + // Run model + let output = mask_language_model.predict(input)?; + for sentence_output in output { + println!("{:?}", sentence_output); + } + + Ok(()) +} diff --git a/examples/masked_language_model_bert.rs b/examples/masked_language_model_bert.rs deleted file mode 100644 index 4598f91..0000000 --- a/examples/masked_language_model_bert.rs +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. -// Copyright 2019 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -extern crate anyhow; - -use rust_bert::bert::{ - BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources, -}; -use rust_bert::resources::{RemoteResource, ResourceProvider}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::Vocab; -use tch::{nn, no_grad, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT); - let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT); - let weights_resource = RemoteResource::from_pretrained(BertModelResources::BERT); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::Cpu; - let mut vs = nn::VarStore::new(device); - let tokenizer: BertTokenizer = - BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?; - let config = BertConfig::from_file(config_path); - let bert_model = BertForMaskedLM::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = [ - "Looks like one [MASK] is missing", - "It was a very nice and [MASK] day", - ]; - let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let model_output = no_grad(|| { - bert_model.forward_t( - Some(&input_tensor), - None, - None, - None, - None, - None, - None, - false, - ) - }); - - // Print masked tokens - let index_1 = model_output - .prediction_scores - .get(0) - .get(4) - .argmax(0, false); - let index_2 = model_output - .prediction_scores - .get(1) - .get(7) - .argmax(0, false); - let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); - let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); - - println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing" - println!("{}", word_2); // Outputs "pear" : "It was a very nice and [pleasant] day" - - Ok(()) -} diff --git a/src/lib.rs b/src/lib.rs index 1c35f8f..c53a41c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,7 @@ //! - Question-Answering //! - Language Generation //! - Sentence Embeddings +//! - Masked Language Model //! //! More information on these can be found in the [`pipelines` module](./pipelines/index.html) //! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust @@ -613,6 +614,38 @@ //! # ; //! ``` //! +//!   +//!
+//! 12. Masked Language Model +//! +//! Predict masked words in input sentences. +//!```no_run +//! # use rust_bert::pipelines::masked_language::MaskedLanguageModel; +//! # fn main() -> anyhow::Result<()> { +//! let model = MaskedLanguageModel::new(Default::default())?; +//! +//! let sentences = [ +//! "Hello I am a student", +//! "Paris is the of France. It is in Europe.", +//! ]; +//! +//! let output = model.predict(&sentences); +//! # Ok(()) +//! # } +//! ``` +//! Output: +//!```no_run +//! # use rust_bert::pipelines::masked_language::MaskedToken; +//! let output = vec![ +//! vec![MaskedToken { text: String::from("college"), id: 2267, score: 8.091}], +//! vec![ +//! MaskedToken { text: String::from("capital"), id: 3007, score: 16.7249}, +//! MaskedToken { text: String::from("located"), id: 2284, score: 9.0452} +//! ] +//! ] +//! # ; +//! ``` +//!
//! //! ## Benchmarks //! diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index 89710dc..8ecedcd 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -1533,6 +1533,114 @@ impl TokenizerOption { } } + /// Interface method + pub fn get_mask_id(&self) -> Option { + match *self { + Self::Bert(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(BertVocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::Deberta(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(DeBERTaVocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::DebertaV2(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(DeBERTaV2Vocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::Roberta(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(RobertaVocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::Bart(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(RobertaVocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::XLMRoberta(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(XLMRobertaVocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::Albert(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(AlbertVocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::XLNet(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(XLNetVocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::ProphetNet(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(ProphetNetVocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::MBart50(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(MBart50Vocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::FNet(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(FNetVocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::Pegasus(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(PegasusVocab::mask_value()) + .expect("MASK token not found in vocabulary"), + ), + Self::Marian(_) => None, + Self::M2M100(_) => None, + Self::T5(_) => None, + Self::GPT2(_) => None, + Self::OpenAiGpt(_) => None, + Self::Reformer(_) => None, + } + } + + /// Interface method + pub fn get_mask_value(&self) -> Option<&str> { + match self { + Self::Bert(_) => Some(BertVocab::mask_value()), + Self::Deberta(_) => Some(DeBERTaVocab::mask_value()), + Self::DebertaV2(_) => Some(DeBERTaV2Vocab::mask_value()), + Self::Roberta(_) => Some(RobertaVocab::mask_value()), + Self::Bart(_) => Some(RobertaVocab::mask_value()), + Self::XLMRoberta(_) => Some(XLMRobertaVocab::mask_value()), + Self::Albert(_) => Some(AlbertVocab::mask_value()), + Self::XLNet(_) => Some(XLNetVocab::mask_value()), + Self::ProphetNet(_) => Some(ProphetNetVocab::mask_value()), + Self::MBart50(_) => Some(MBart50Vocab::mask_value()), + Self::FNet(_er) => Some(FNetVocab::mask_value()), + Self::M2M100(_) => None, + Self::Marian(_) => None, + Self::T5(_) => None, + Self::GPT2(_) => None, + Self::OpenAiGpt(_) => None, + Self::Reformer(_) => None, + Self::Pegasus(_) => None, + } + } + /// Interface method pub fn get_bos_id(&self) -> Option { match *self { diff --git a/src/pipelines/masked_language.rs b/src/pipelines/masked_language.rs new file mode 100644 index 0000000..2ffad11 --- /dev/null +++ b/src/pipelines/masked_language.rs @@ -0,0 +1,566 @@ +// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +// Copyright 2019-2020 Guillaume Becquin +// Copyright 2020 Maarten van Gompel +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//! # Masked language pipeline (e.g. Fill Mask) +//! Fill in the missing / masked words in input sequences. The pattern to use to specify +//! a masked word can be specified in the `MaskedLanguageConfig` (`mask_token`). and allows +//! multiple masked tokens per input sequence. +//! +//! ```no_run +//!use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources}; +//!use rust_bert::pipelines::common::ModelType; +//!use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel}; +//!use rust_bert::resources::RemoteResource; +//! fn main() -> anyhow::Result<()> { +//! +//! let config = MaskedLanguageConfig::new( +//! ModelType::Bert, +//! RemoteResource::from_pretrained(BertModelResources::BERT), +//! RemoteResource::from_pretrained(BertConfigResources::BERT), +//! RemoteResource::from_pretrained(BertVocabResources::BERT), +//! None, +//! true, +//! None, +//! None, +//! Some(String::from("")), +//! ); +//! +//! let mask_language_model = MaskedLanguageModel::new(config)?; +//! let input = [ +//! "Hello I am a student", +//! "Paris is the of France. It is in Europe.", +//! ]; +//! +//! let output = mask_language_model.predict(input)?; +//! Ok(()) +//! } +//! ``` +//! +use crate::bert::BertForMaskedLM; +use crate::common::error::RustBertError; +use crate::deberta::DebertaForMaskedLM; +use crate::deberta_v2::DebertaV2ForMaskedLM; +use crate::fnet::FNetForMaskedLM; +use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption}; +use crate::resources::ResourceProvider; +use crate::roberta::RobertaForMaskedLM; +#[cfg(feature = "remote")] +use crate::{ + bert::{BertConfigResources, BertModelResources, BertVocabResources}, + resources::RemoteResource, +}; +use rust_tokenizers::tokenizer::TruncationStrategy; +use rust_tokenizers::TokenizedInput; +use std::borrow::Borrow; +use tch::nn::VarStore; +use tch::{nn, no_grad, Device, Tensor}; + +#[derive(Debug, Clone)] +/// Output container for masked language model pipeline. +pub struct MaskedToken { + /// String representation of the masked word + pub text: String, + /// Vocabulary index for the masked word + pub id: i64, + /// Score for the masked word + pub score: f64, +} + +/// # Configuration for MaskedLanguageModel +/// Contains information regarding the model to load and device to place the model on. +pub struct MaskedLanguageConfig { + /// Model type + pub model_type: ModelType, + /// Model weights resource (default: pretrained BERT model on CoNLL) + pub model_resource: Box, + /// Config resource (default: pretrained BERT model on CoNLL) + pub config_resource: Box, + /// Vocab resource (default: pretrained BERT model on CoNLL) + pub vocab_resource: Box, + /// Merges resource (default: None) + pub merges_resource: Option>, + /// Automatically lower case all input upon tokenization (assumes a lower-cased model) + pub lower_case: bool, + /// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models + pub strip_accents: Option, + /// Flag indicating if the tokenizer should add a white space before each tokenized input (needed for some Roberta models) + pub add_prefix_space: Option, + /// Token used for masking words. This is the token which the model will try to predict. + pub mask_token: Option, + /// Device to place the model on (default: CUDA/GPU when available) + pub device: Device, +} + +impl MaskedLanguageConfig { + /// Instantiate a new masked language configuration of the supplied type. + /// + /// # Arguments + /// + /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!) + /// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot) + /// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json) + /// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) + /// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. + /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model) + /// * mask_token - A token used for model to predict masking words.. + pub fn new( + model_type: ModelType, + model_resource: RM, + config_resource: RC, + vocab_resource: RV, + merges_resource: Option, + lower_case: bool, + strip_accents: impl Into>, + add_prefix_space: impl Into>, + mask_token: impl Into>, + ) -> MaskedLanguageConfig + where + RM: ResourceProvider + Send + 'static, + RC: ResourceProvider + Send + 'static, + RV: ResourceProvider + Send + 'static, + { + MaskedLanguageConfig { + model_type, + model_resource: Box::new(model_resource), + config_resource: Box::new(config_resource), + vocab_resource: Box::new(vocab_resource), + merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>), + lower_case, + strip_accents: strip_accents.into(), + add_prefix_space: add_prefix_space.into(), + mask_token: mask_token.into(), + device: Device::cuda_if_available(), + } + } +} +#[cfg(feature = "remote")] +impl Default for MaskedLanguageConfig { + /// Provides a BERT language model + fn default() -> MaskedLanguageConfig { + MaskedLanguageConfig::new( + ModelType::Bert, + RemoteResource::from_pretrained(BertModelResources::BERT), + RemoteResource::from_pretrained(BertConfigResources::BERT), + RemoteResource::from_pretrained(BertVocabResources::BERT), + None, + true, + None, + None, + None, + ) + } +} + +#[allow(clippy::large_enum_variant)] +/// # Abstraction that holds one particular masked language model, for any of the supported models +pub enum MaskedLanguageOption { + /// Bert for Masked Language + Bert(BertForMaskedLM), + /// DeBERTa for Masked Language + Deberta(DebertaForMaskedLM), + /// DeBERTa V2 for Masked Language + DebertaV2(DebertaV2ForMaskedLM), + /// Roberta for Masked Language + Roberta(RobertaForMaskedLM), + /// XLMRoberta for Masked Language + XLMRoberta(RobertaForMaskedLM), + /// FNet for Masked Language + FNet(FNetForMaskedLM), +} +impl MaskedLanguageOption { + /// Instantiate a new masked language model of the supplied type. + /// + /// # Arguments + /// + /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded) + /// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot) + /// * `config` - A configuration (the model type of the configuration must be compatible with the value for + /// `model_type`) + pub fn new<'p, P>( + model_type: ModelType, + p: P, + config: &ConfigOption, + ) -> Result + where + P: Borrow>, + { + match model_type { + ModelType::Bert => { + if let ConfigOption::Bert(config) = config { + Ok(MaskedLanguageOption::Bert(BertForMaskedLM::new(p, config))) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply a BertConfig for Bert!".to_string(), + )) + } + } + ModelType::Deberta => { + if let ConfigOption::Deberta(config) = config { + Ok(MaskedLanguageOption::Deberta(DebertaForMaskedLM::new( + p, config, + ))) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply a DebertaConfig for DeBERTa!".to_string(), + )) + } + } + ModelType::DebertaV2 => { + if let ConfigOption::DebertaV2(config) = config { + Ok(MaskedLanguageOption::DebertaV2(DebertaV2ForMaskedLM::new( + p, config, + ))) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply a DebertaV2Config for DeBERTa V2!".to_string(), + )) + } + } + ModelType::Roberta => { + if let ConfigOption::Roberta(config) = config { + Ok(MaskedLanguageOption::Roberta(RobertaForMaskedLM::new( + p, config, + ))) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply a BertConfig for Roberta!".to_string(), + )) + } + } + ModelType::XLMRoberta => { + if let ConfigOption::Bert(config) = config { + Ok(MaskedLanguageOption::XLMRoberta(RobertaForMaskedLM::new( + p, config, + ))) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply a BertConfig for Roberta!".to_string(), + )) + } + } + ModelType::FNet => { + if let ConfigOption::FNet(config) = config { + Ok(MaskedLanguageOption::FNet(FNetForMaskedLM::new(p, config))) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply a FNetConfig for FNet!".to_string(), + )) + } + } + _ => Err(RustBertError::InvalidConfigurationError(format!( + "Masked Language is not implemented for {:?}!", + model_type + ))), + } + } + + /// Returns the `ModelType` for this MaskedLanguageOption + pub fn model_type(&self) -> ModelType { + match *self { + Self::Bert(_) => ModelType::Bert, + Self::Deberta(_) => ModelType::Deberta, + Self::DebertaV2(_) => ModelType::DebertaV2, + Self::Roberta(_) => ModelType::Roberta, + Self::XLMRoberta(_) => ModelType::Roberta, + Self::FNet(_) => ModelType::FNet, + } + } + + /// Interface method to forward_t() of the particular models. + pub fn forward_t( + &self, + input_ids: Option<&Tensor>, + mask: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor>, + input_embeds: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + encoder_mask: Option<&Tensor>, + train: bool, + ) -> Tensor { + match *self { + Self::Bert(ref model) => { + model + .forward_t( + input_ids, + mask, + token_type_ids, + position_ids, + input_embeds, + encoder_hidden_states, + encoder_mask, + train, + ) + .prediction_scores + } + + Self::Deberta(ref model) => { + model + .forward_t( + input_ids, + mask, + token_type_ids, + position_ids, + input_embeds, + train, + ) + .expect("Error in Deberta forward_t") + .logits + } + Self::DebertaV2(ref model) => { + model + .forward_t( + input_ids, + mask, + token_type_ids, + position_ids, + input_embeds, + train, + ) + .expect("Error in Deberta V2 forward_t") + .logits + } + + Self::Roberta(ref model) | Self::XLMRoberta(ref model) => { + model + .forward_t( + input_ids, + mask, + token_type_ids, + position_ids, + input_embeds, + encoder_hidden_states, + encoder_mask, + train, + ) + .prediction_scores + } + Self::FNet(ref model) => { + model + .forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) + .expect("Error in FNet forward pass.") + .prediction_scores + } + } + } +} + +/// # MaskedLanguageModel for Masked Language (e.g. Fill Mask) +pub struct MaskedLanguageModel { + tokenizer: TokenizerOption, + language_encode: MaskedLanguageOption, + mask_token: Option, + var_store: VarStore, + max_length: usize, +} + +impl MaskedLanguageModel { + /// Build a new `MaskedLanguageModel` + /// + /// # Arguments + /// + /// * `config` - `MaskedLanguageConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU) + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// use rust_bert::pipelines::masked_language::MaskedLanguageModel; + /// + /// let model = MaskedLanguageModel::new(Default::default())?; + /// # Ok(()) + /// # } + /// ``` + pub fn new(config: MaskedLanguageConfig) -> Result { + let config_path = config.config_resource.get_local_path()?; + let vocab_path = config.vocab_resource.get_local_path()?; + let weights_path = config.model_resource.get_local_path()?; + let merges_path = if let Some(merges_resource) = &config.merges_resource { + Some(merges_resource.get_local_path()?) + } else { + None + }; + let device = config.device; + + let tokenizer = TokenizerOption::from_file( + config.model_type, + vocab_path.to_str().unwrap(), + merges_path.as_deref().map(|path| path.to_str().unwrap()), + config.lower_case, + config.strip_accents, + config.add_prefix_space, + )?; + let mut var_store = VarStore::new(device); + let model_config = ConfigOption::from_file(config.model_type, config_path); + let max_length = model_config + .get_max_len() + .map(|v| v as usize) + .unwrap_or(usize::MAX); + + let language_encode = + MaskedLanguageOption::new(config.model_type, &var_store.root(), &model_config)?; + var_store.load(weights_path)?; + let mask_token = config.mask_token; + Ok(MaskedLanguageModel { + tokenizer, + language_encode, + mask_token, + var_store, + max_length, + }) + } + + /// Replace custom user-provided mask token by language model mask token. + fn replace_mask_token<'a, S>( + &self, + input: S, + mask_token: &str, + ) -> Result, RustBertError> + where + S: AsRef<[&'a str]>, + { + let model_mask_token = self.tokenizer.get_mask_value().ok_or_else(|| + RustBertError::InvalidConfigurationError("Tokenizer does ot have a default mask token and no mask token provided in configuration. \ + Please provide a `mask_token` in the configuration.".into()))?; + let output = input + .as_ref() + .iter() + .map(|&x| x.replace(mask_token, model_mask_token)) + .collect::>(); + Ok(output) + } + + fn prepare_for_model<'a, S>(&self, input: S) -> Tensor + where + S: AsRef<[&'a str]>, + { + let tokenized_input: Vec = self.tokenizer.encode_list( + input.as_ref(), + self.max_length, + &TruncationStrategy::LongestFirst, + 0, + ); + let max_len = tokenized_input + .iter() + .map(|input| input.token_ids.len()) + .max() + .unwrap(); + let tokenized_input_tensors = tokenized_input + .iter() + .map(|input| input.token_ids.clone()) + .map(|mut input| { + input.extend(vec![0; max_len - input.len()]); + input + }) + .map(|input| Tensor::of_slice(&(input))) + .collect::>(); + Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()) + } + + /// Mask texts + /// + /// # Arguments + /// + /// * `input` - `&[&str]` Array of texts to mask. + /// + /// # Returns + /// + /// * `Vec` containing masked words for input texts + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// use rust_bert::pipelines::masked_language::MaskedLanguageModel; + /// // Set-up model + /// let mask_language_model = MaskedLanguageModel::new(Default::default())?; + /// + /// // Define input + /// let input = [ + /// "Looks like one [MASK] is missing", + /// "It was a very nice and [MASK] day", + /// ]; + /// + /// // Run model + /// let output = mask_language_model.predict(&input); + /// for word in output { + /// println!("{:?}", word); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn predict<'a, S>(&self, input: S) -> Result>, RustBertError> + where + S: AsRef<[&'a str]>, + { + let input_tensor = if let Some(mask_token) = &self.mask_token { + let input_with_replaced_mask = self.replace_mask_token(input.as_ref(), mask_token)?; + self.prepare_for_model( + input_with_replaced_mask + .iter() + .map(|w| w.as_str()) + .collect::>(), + ) + } else { + self.prepare_for_model(input.as_ref()) + }; + + let output = no_grad(|| { + self.language_encode.forward_t( + Some(&input_tensor), + None, + None, + None, + None, + None, + None, + false, + ) + }); + // get the position of mask_token in input texts + let mask_token_id = + self.tokenizer + .get_mask_id() + .ok_or_else(|| RustBertError::InvalidConfigurationError( + "Tokenizer does not have a mask token id, Please use a tokenizer/model with a mask token.".into(), + ))?; + let mask_token_mask = input_tensor.eq(mask_token_id); + let mut output_tokens = Vec::with_capacity(input.as_ref().len()); + for input_id in 0..input.as_ref().len() as i64 { + let mut sequence_tokens = vec![]; + let sequence_mask = mask_token_mask.get(input_id); + if bool::from(sequence_mask.any()) { + let mask_scores = output + .get(input_id) + .index_select(0, &sequence_mask.argwhere().squeeze_dim(1)); + let (token_scores, token_ids) = mask_scores.max_dim(1, false); + for (id, score) in token_ids.iter::()?.zip(token_scores.iter::()?) { + let text = self.tokenizer.decode(&[id], false, true); + sequence_tokens.push(MaskedToken { text, id, score }); + } + } + output_tokens.push(sequence_tokens); + } + Ok(output_tokens) + } +} +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[ignore] // no need to run, compilation is enough to verify it is Send + fn test() { + let config = MaskedLanguageConfig::default(); + let _: Box = Box::new(MaskedLanguageModel::new(config)); + } +} diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index f810399..2f55cef 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -478,6 +478,7 @@ pub mod common; pub mod conversation; pub mod generation_utils; pub mod keywords_extraction; +pub mod masked_language; pub mod ner; pub mod pos_tagging; pub mod question_answering; diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index 998cece..c538324 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -680,7 +680,7 @@ impl ZeroShotClassificationModel { /// let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; /// let candidate_labels = &["politics", "public health", "economics", "sports"]; /// - /// let output = sequence_classification_model.try_predict( + /// let output = sequence_classification_model.predict( /// &[input_sentence, input_sequence_2], /// candidate_labels, /// None, @@ -693,7 +693,7 @@ impl ZeroShotClassificationModel { /// outputs: /// ```no_run /// # use rust_bert::pipelines::sequence_classification::Label; - /// let output = Ok([ + /// let output = [ /// Label { /// text: "politics".to_string(), /// score: 0.959, @@ -707,7 +707,7 @@ impl ZeroShotClassificationModel { /// sentence: 1, /// }, /// ] - /// .to_vec()); + /// .to_vec(); /// ``` pub fn predict<'a, S, T>( &self, @@ -783,7 +783,7 @@ impl ZeroShotClassificationModel { /// let input_sequence_2 = "The central bank is meeting today to discuss monetary policy."; /// let candidate_labels = &["politics", "public health", "economics", "sports"]; /// - /// let output = sequence_classification_model.try_predict_multilabel( + /// let output = sequence_classification_model.predict_multilabel( /// &[input_sentence, input_sequence_2], /// candidate_labels, /// None, @@ -795,7 +795,7 @@ impl ZeroShotClassificationModel { /// outputs: /// ```no_run /// # use rust_bert::pipelines::sequence_classification::Label; - /// let output = Ok([ + /// let output = [ /// [ /// Label { /// text: "politics".to_string(), @@ -849,7 +849,7 @@ impl ZeroShotClassificationModel { /// }, /// ], /// ] - /// .to_vec()); + /// .to_vec(); /// ``` pub fn predict_multilabel<'a, S, T>( &self, diff --git a/tests/bert.rs b/tests/bert.rs index 2a20249..d790922 100644 --- a/tests/bert.rs +++ b/tests/bert.rs @@ -7,6 +7,7 @@ use rust_bert::bert::{ BertModelResources, BertVocabResources, }; use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel}; use rust_bert::pipelines::ner::NERModel; use rust_bert::pipelines::question_answering::{ QaInput, QuestionAnsweringConfig, QuestionAnsweringModel, @@ -100,6 +101,46 @@ fn bert_masked_lm() -> anyhow::Result<()> { Ok(()) } +#[test] +fn bert_masked_lm_pipeline() -> anyhow::Result<()> { + // Set-up model + let config = MaskedLanguageConfig::new( + ModelType::Bert, + RemoteResource::from_pretrained(BertModelResources::BERT), + RemoteResource::from_pretrained(BertConfigResources::BERT), + RemoteResource::from_pretrained(BertVocabResources::BERT), + None, + true, + None, + None, + Some(String::from("")), + ); + + let mask_language_model = MaskedLanguageModel::new(config)?; + // Define input + let input = [ + "Hello I am a student", + "Paris is the of France. It is in Europe.", + ]; + + // Run model + let output = mask_language_model.predict(input)?; + + assert_eq!(output.len(), 2); + assert_eq!(output[0].len(), 1); + assert_eq!(output[0][0].id, 2267); + assert_eq!(output[0][0].text, "college"); + assert!((output[0][0].score - 8.0919).abs() < 1e-4); + assert_eq!(output[1].len(), 2); + assert_eq!(output[1][0].id, 3007); + assert_eq!(output[1][0].text, "capital"); + assert!((output[1][0].score - 16.7249).abs() < 1e-4); + assert_eq!(output[1][1].id, 2284); + assert_eq!(output[1][1].text, "located"); + assert!((output[1][1].score - 9.0452).abs() < 1e-4); + Ok(()) +} + #[test] fn bert_for_sequence_classification() -> anyhow::Result<()> { // Resources paths