Add pipelines::masked_language and codebert support (#282)

* add support for loading a local model in SequenceClassificationConfig

* adjust config to match the SequenceClassificationConfig

* add pipelines::masked_language

* add support and example for codebert

* provide an optional mask_token String field for the masked_language pipeline

* update example for masked_language pipeline

* revert codebert support

* revert support for loading a local model

* resolve conflicts

* update MaskedLanguageConfig

* fix doctest error in zero_shot_classification.rs

* MaskedLM pipeline updates

* fix handling of multiple masked tokens, add test

* Updated changelog and docs

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
Vincent Xiao 2022-12-22 00:58:02 +08:00 committed by GitHub
parent a34cf9f8e4
commit dae899fea6
11 changed files with 839 additions and 105 deletions


@@ -6,6 +6,7 @@ All notable changes to this project will be documented in this file. The format
## Added
- Addition of All-MiniLM-L6-V2 model weights
- Addition of Keyword/Keyphrases extraction pipeline based on KeyBERT (https://github.com/MaartenGr/KeyBERT)
- Addition of Masked Language Model pipeline, allowing prediction of masked words
## Changed
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`).


@@ -8,7 +8,13 @@ repository = "https://github.com/guillaume-be/rust-bert"
documentation = "https://docs.rs/rust-bert"
license = "Apache-2.0"
readme = "README.md"
keywords = ["nlp", "deep-learning", "machine-learning", "transformers", "translation"]
keywords = [
"nlp",
"deep-learning",
"machine-learning",
"transformers",
"translation",
]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -57,7 +63,7 @@ opt-level = 3
default = ["remote"]
doc-only = ["tch/doc-only"]
all-tests = []
remote = [ "cached-path", "dirs", "lazy_static" ]
remote = ["cached-path", "dirs", "lazy_static"]
[package.metadata.docs.rs]
features = ["doc-only"]
@@ -82,6 +88,6 @@ anyhow = "1.0.58"
csv = "1.1.6"
criterion = "0.3.6"
tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "~0.9.0"
torch-sys = "0.9.0"
tempfile = "3.3.0"
itertools = "0.10.3"


@@ -33,6 +33,7 @@ The tasks currently supported include:
- Part of Speech tagging
- Question-Answering
- Language Generation
- Masked Language Model
- Sentence Embeddings
<details>
@@ -436,6 +437,33 @@ Output:
```
</details>
<details>
<summary> <b>12. Masked Language Model </b> </summary>
Predict masked words in input sentences.
```rust
let model = MaskedLanguageModel::new(Default::default())?;
let sentences = [
"Hello I am a <mask> student",
"Paris is the <mask> of France. It is <mask> in Europe.",
];
let output = model.predict(&sentences);
```
Output:
```
[
[MaskedToken { text: "college", id: 2267, score: 8.091}],
[
MaskedToken { text: "capital", id: 3007, score: 16.7249},
MaskedToken { text: "located", id: 2284, score: 9.0452}
]
]
```
</details>
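The `score` reported for each token is the raw logit of the top prediction at that masked position (no softmax is applied), so scores are comparable within a model but are not probabilities. A minimal sketch for consuming the output, reusing the `model` and `sentences` from the snippet above:
```rust
for (i, tokens) in model.predict(&sentences)?.iter().enumerate() {
    for token in tokens {
        println!(
            "sentence {}: \"{}\" (id {}, logit {:.2})",
            i, token.text, token.id, token.score
        );
    }
}
```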
## Benchmarks
For simple pipelines (sequence classification, tokens classification, question answering) the performance between Python and Rust is expected to be comparable. This is because the most expensive part of these pipelines is the language model itself, which shares a common implementation in the Torch backend. The [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/) paper provides a benchmarks section covering all pipelines.


@@ -0,0 +1,46 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
// Set-up model
let config = MaskedLanguageConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT),
RemoteResource::from_pretrained(BertConfigResources::BERT),
RemoteResource::from_pretrained(BertVocabResources::BERT),
None,
true,
None,
None,
Some(String::from("<mask>")),
);
let mask_language_model = MaskedLanguageModel::new(config)?;
// Define input
let input = [
"Hello I am a <mask> student",
"Paris is the <mask> of France. It is <mask> in Europe.",
];
// Run model
let output = mask_language_model.predict(input)?;
for sentence_output in output {
println!("{:?}", sentence_output);
}
Ok(())
}
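Per the integration test added at the bottom of this commit, running this example against the pretrained BERT weights prints output along these lines:
```
[MaskedToken { text: "college", id: 2267, score: 8.0919 }]
[MaskedToken { text: "capital", id: 3007, score: 16.7249 }, MaskedToken { text: "located", id: 2284, score: 9.0452 }]
```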


@@ -1,96 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::bert::{
BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let weights_resource = RemoteResource::from_pretrained(BertModelResources::BERT);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = BertConfig::from_file(config_path);
let bert_model = BertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = no_grad(|| {
bert_model.forward_t(
Some(&input_tensor),
None,
None,
None,
None,
None,
None,
false,
)
});
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
println!("{}", word_2); // Outputs "pear" : "It was a very nice and [pleasant] day"
Ok(())
}
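The manual forward pass removed above is superseded by the new pipeline. A minimal sketch of the equivalent using the pipeline API; the default configuration loads the pretrained BERT masked LM, whose native mask token is `[MASK]`, so no custom `mask_token` is needed:
```rust
use rust_bert::pipelines::masked_language::MaskedLanguageModel;

fn main() -> anyhow::Result<()> {
    // Default config: pretrained BERT masked LM with its native "[MASK]" token.
    let model = MaskedLanguageModel::new(Default::default())?;
    let output = model.predict([
        "Looks like one [MASK] is missing",
        "It was a very nice and [MASK] day",
    ])?;
    for sentence_output in output {
        println!("{:?}", sentence_output);
    }
    Ok(())
}
```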


@@ -41,6 +41,7 @@
//! - Question-Answering
//! - Language Generation
//! - Sentence Embeddings
//! - Masked Language Model
//!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust
@@ -613,6 +614,38 @@
//! # ;
//! ```
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>12. Masked Language Model </b> </summary>
//!
//! Predict masked words in input sentences.
//!```no_run
//! # use rust_bert::pipelines::masked_language::MaskedLanguageModel;
//! # fn main() -> anyhow::Result<()> {
//! let model = MaskedLanguageModel::new(Default::default())?;
//!
//! let sentences = [
//! "Hello I am a <mask> student",
//! "Paris is the <mask> of France. It is <mask> in Europe.",
//! ];
//!
//! let output = model.predict(&sentences);
//! # Ok(())
//! # }
//! ```
//! Output:
//!```no_run
//! # use rust_bert::pipelines::masked_language::MaskedToken;
//! let output = vec![
//! vec![MaskedToken { text: String::from("college"), id: 2267, score: 8.091}],
//! vec![
//! MaskedToken { text: String::from("capital"), id: 3007, score: 16.7249},
//! MaskedToken { text: String::from("located"), id: 2284, score: 9.0452}
//! ]
//! ]
//! # ;
//! ```
//! </details>
//!
//! ## Benchmarks
//!


@@ -1533,6 +1533,114 @@ impl TokenizerOption {
}
}
/// Interface method to access the mask token id, for tokenizers that define one
pub fn get_mask_id(&self) -> Option<i64> {
match *self {
Self::Bert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Deberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::DebertaV2(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaV2Vocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Bart(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::XLMRoberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::XLNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::ProphetNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ProphetNetVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::MBart50(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MBart50Vocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::FNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Pegasus(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(PegasusVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Marian(_) => None,
Self::M2M100(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
Self::Reformer(_) => None,
}
}
/// Interface method to access the string representation of the mask token, for tokenizers that define one
pub fn get_mask_value(&self) -> Option<&str> {
match self {
Self::Bert(_) => Some(BertVocab::mask_value()),
Self::Deberta(_) => Some(DeBERTaVocab::mask_value()),
Self::DebertaV2(_) => Some(DeBERTaV2Vocab::mask_value()),
Self::Roberta(_) => Some(RobertaVocab::mask_value()),
Self::Bart(_) => Some(RobertaVocab::mask_value()),
Self::XLMRoberta(_) => Some(XLMRobertaVocab::mask_value()),
Self::Albert(_) => Some(AlbertVocab::mask_value()),
Self::XLNet(_) => Some(XLNetVocab::mask_value()),
Self::ProphetNet(_) => Some(ProphetNetVocab::mask_value()),
Self::MBart50(_) => Some(MBart50Vocab::mask_value()),
Self::FNet(_) => Some(FNetVocab::mask_value()),
Self::M2M100(_) => None,
Self::Marian(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
Self::Reformer(_) => None,
Self::Pegasus(_) => None,
}
}
/// Interface method
pub fn get_bos_id(&self) -> Option<i64> {
match *self {

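These accessors feed the masked-language pipeline added below: `get_mask_value` drives the substitution of a user-supplied `mask_token` by the model's native token, while `get_mask_id` locates masked positions in the tokenized input. A small sketch of the first half of that flow (`normalize_mask` is a hypothetical helper, not part of this commit):
```rust
use rust_bert::pipelines::common::TokenizerOption;

/// Rewrite a user-facing mask placeholder to the tokenizer's native mask token.
/// Returns None for tokenizers without a mask token (e.g. GPT2, T5, Marian).
fn normalize_mask(tokenizer: &TokenizerOption, text: &str, user_mask: &str) -> Option<String> {
    tokenizer
        .get_mask_value()
        .map(|native_mask| text.replace(user_mask, native_mask))
}
```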

@@ -0,0 +1,566 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019-2020 Guillaume Becquin
// Copyright 2020 Maarten van Gompel
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Masked language pipeline (e.g. Fill Mask)
//! Fill in the missing / masked words in input sequences. The pattern used to specify
//! a masked word can be set in the `MaskedLanguageConfig` (`mask_token`), and multiple
//! masked tokens per input sequence are supported.
//!
//! ```no_run
//! use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
//! use rust_bert::resources::RemoteResource;
//! fn main() -> anyhow::Result<()> {
//!
//! let config = MaskedLanguageConfig::new(
//! ModelType::Bert,
//! RemoteResource::from_pretrained(BertModelResources::BERT),
//! RemoteResource::from_pretrained(BertConfigResources::BERT),
//! RemoteResource::from_pretrained(BertVocabResources::BERT),
//! None,
//! true,
//! None,
//! None,
//! Some(String::from("<mask>")),
//! );
//!
//! let mask_language_model = MaskedLanguageModel::new(config)?;
//! let input = [
//! "Hello I am a <mask> student",
//! "Paris is the <mask> of France. It is <mask> in Europe.",
//! ];
//!
//! let output = mask_language_model.predict(input)?;
//! Ok(())
//! }
//! ```
//!
use crate::bert::BertForMaskedLM;
use crate::common::error::RustBertError;
use crate::deberta::DebertaForMaskedLM;
use crate::deberta_v2::DebertaV2ForMaskedLM;
use crate::fnet::FNetForMaskedLM;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForMaskedLM;
#[cfg(feature = "remote")]
use crate::{
bert::{BertConfigResources, BertModelResources, BertVocabResources},
resources::RemoteResource,
};
use rust_tokenizers::tokenizer::TruncationStrategy;
use rust_tokenizers::TokenizedInput;
use std::borrow::Borrow;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Tensor};
#[derive(Debug, Clone)]
/// Output container for masked language model pipeline.
pub struct MaskedToken {
/// String representation of the masked word
pub text: String,
/// Vocabulary index for the masked word
pub id: i64,
/// Score for the masked word
pub score: f64,
}
/// # Configuration for MaskedLanguageModel
/// Contains information regarding the model to load and device to place the model on.
pub struct MaskedLanguageConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model)
pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BERT model)
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model)
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: None)
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
pub lower_case: bool,
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
pub strip_accents: Option<bool>,
/// Flag indicating if the tokenizer should add a white space before each tokenized input (needed for some Roberta models)
pub add_prefix_space: Option<bool>,
/// Token used for masking words. This is the token which the model will try to predict.
pub mask_token: Option<String>,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
}
impl MaskedLanguageConfig {
/// Instantiate a new masked language configuration of the supplied type.
///
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges - An optional `ResourceProvider` pointing to the tokenizer's merges file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
/// * mask_token - An optional custom token used for masking words in the input; it is replaced by the model's native mask token before prediction.
pub fn new<RM, RC, RV>(
model_type: ModelType,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
mask_token: impl Into<Option<String>>,
) -> MaskedLanguageConfig
where
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
MaskedLanguageConfig {
model_type,
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case,
strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(),
mask_token: mask_token.into(),
device: Device::cuda_if_available(),
}
}
}
#[cfg(feature = "remote")]
impl Default for MaskedLanguageConfig {
/// Provides a default pretrained BERT masked language model
fn default() -> MaskedLanguageConfig {
MaskedLanguageConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT),
RemoteResource::from_pretrained(BertConfigResources::BERT),
RemoteResource::from_pretrained(BertVocabResources::BERT),
None,
true,
None,
None,
None,
)
}
}
#[allow(clippy::large_enum_variant)]
/// # Abstraction that holds one particular masked language model, for any of the supported models
pub enum MaskedLanguageOption {
/// Bert for Masked Language
Bert(BertForMaskedLM),
/// DeBERTa for Masked Language
Deberta(DebertaForMaskedLM),
/// DeBERTa V2 for Masked Language
DebertaV2(DebertaV2ForMaskedLM),
/// Roberta for Masked Language
Roberta(RobertaForMaskedLM),
/// XLMRoberta for Masked Language
XLMRoberta(RobertaForMaskedLM),
/// FNet for Masked Language
FNet(FNetForMaskedLM),
}
impl MaskedLanguageOption {
/// Instantiate a new masked language model of the supplied type.
///
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
pub fn new<'p, P>(
model_type: ModelType,
p: P,
config: &ConfigOption,
) -> Result<Self, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
match model_type {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
Ok(MaskedLanguageOption::Bert(BertForMaskedLM::new(p, config)))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a BertConfig for Bert!".to_string(),
))
}
}
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
Ok(MaskedLanguageOption::Deberta(DebertaForMaskedLM::new(
p, config,
)))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a DebertaConfig for DeBERTa!".to_string(),
))
}
}
ModelType::DebertaV2 => {
if let ConfigOption::DebertaV2(config) = config {
Ok(MaskedLanguageOption::DebertaV2(DebertaV2ForMaskedLM::new(
p, config,
)))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a DebertaV2Config for DeBERTa V2!".to_string(),
))
}
}
ModelType::Roberta => {
if let ConfigOption::Roberta(config) = config {
Ok(MaskedLanguageOption::Roberta(RobertaForMaskedLM::new(
p, config,
)))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a RobertaConfig for Roberta!".to_string(),
))
}
}
ModelType::XLMRoberta => {
if let ConfigOption::Bert(config) = config {
Ok(MaskedLanguageOption::XLMRoberta(RobertaForMaskedLM::new(
p, config,
)))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a BertConfig for XLMRoberta!".to_string(),
))
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(MaskedLanguageOption::FNet(FNetForMaskedLM::new(p, config)))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a FNetConfig for FNet!".to_string(),
))
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Masked Language is not implemented for {:?}!",
model_type
))),
}
}
/// Returns the `ModelType` for this MaskedLanguageOption
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Deberta(_) => ModelType::Deberta,
Self::DebertaV2(_) => ModelType::DebertaV2,
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::Roberta,
Self::FNet(_) => ModelType::FNet,
}
}
/// Interface method to forward_t() of the particular models.
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_mask: Option<&Tensor>,
train: bool,
) -> Tensor {
match *self {
Self::Bert(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
encoder_hidden_states,
encoder_mask,
train,
)
.prediction_scores
}
Self::Deberta(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.expect("Error in Deberta forward_t")
.logits
}
Self::DebertaV2(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.expect("Error in Deberta V2 forward_t")
.logits
}
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
encoder_hidden_states,
encoder_mask,
train,
)
.prediction_scores
}
Self::FNet(ref model) => {
model
.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
.expect("Error in FNet forward pass.")
.prediction_scores
}
}
}
}
/// # MaskedLanguageModel for Masked Language (e.g. Fill Mask)
pub struct MaskedLanguageModel {
tokenizer: TokenizerOption,
language_encode: MaskedLanguageOption,
mask_token: Option<String>,
var_store: VarStore,
max_length: usize,
}
impl MaskedLanguageModel {
/// Build a new `MaskedLanguageModel`
///
/// # Arguments
///
/// * `config` - `MaskedLanguageConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::pipelines::masked_language::MaskedLanguageModel;
///
/// let model = MaskedLanguageModel::new(Default::default())?;
/// # Ok(())
/// # }
/// ```
pub fn new(config: MaskedLanguageConfig) -> Result<MaskedLanguageModel, RustBertError> {
let config_path = config.config_resource.get_local_path()?;
let vocab_path = config.vocab_resource.get_local_path()?;
let weights_path = config.model_resource.get_local_path()?;
let merges_path = if let Some(merges_resource) = &config.merges_resource {
Some(merges_resource.get_local_path()?)
} else {
None
};
let device = config.device;
let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.as_deref().map(|path| path.to_str().unwrap()),
config.lower_case,
config.strip_accents,
config.add_prefix_space,
)?;
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let max_length = model_config
.get_max_len()
.map(|v| v as usize)
.unwrap_or(usize::MAX);
let language_encode =
MaskedLanguageOption::new(config.model_type, &var_store.root(), &model_config)?;
var_store.load(weights_path)?;
let mask_token = config.mask_token;
Ok(MaskedLanguageModel {
tokenizer,
language_encode,
mask_token,
var_store,
max_length,
})
}
/// Replace a custom user-provided mask token with the language model's mask token.
fn replace_mask_token<'a, S>(
&self,
input: S,
mask_token: &str,
) -> Result<Vec<String>, RustBertError>
where
S: AsRef<[&'a str]>,
{
let model_mask_token = self.tokenizer.get_mask_value().ok_or_else(||
RustBertError::InvalidConfigurationError("Tokenizer does not have a default mask token to substitute for the configured `mask_token`. \
Please use a model whose tokenizer defines a mask token.".into()))?;
let output = input
.as_ref()
.iter()
.map(|&x| x.replace(mask_token, model_mask_token))
.collect::<Vec<_>>();
Ok(output)
}
fn prepare_for_model<'a, S>(&self, input: S) -> Tensor
where
S: AsRef<[&'a str]>,
{
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(
input.as_ref(),
self.max_length,
&TruncationStrategy::LongestFirst,
0,
);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input_tensors = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
}
/// Predict masked words in input texts
///
/// # Arguments
///
/// * `input` - `&[&str]` Array of texts containing masked tokens to predict.
///
/// # Returns
///
/// * `Vec<Vec<MaskedToken>>` containing the predicted tokens for each masked position in the input texts
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::pipelines::masked_language::MaskedLanguageModel;
/// // Set-up model
/// let mask_language_model = MaskedLanguageModel::new(Default::default())?;
///
/// // Define input
/// let input = [
/// "Looks like one [MASK] is missing",
/// "It was a very nice and [MASK] day",
/// ];
///
/// // Run model
/// let output = mask_language_model.predict(&input);
/// for word in output {
/// println!("{:?}", word);
/// }
/// # Ok(())
/// # }
/// ```
pub fn predict<'a, S>(&self, input: S) -> Result<Vec<Vec<MaskedToken>>, RustBertError>
where
S: AsRef<[&'a str]>,
{
let input_tensor = if let Some(mask_token) = &self.mask_token {
let input_with_replaced_mask = self.replace_mask_token(input.as_ref(), mask_token)?;
self.prepare_for_model(
input_with_replaced_mask
.iter()
.map(|w| w.as_str())
.collect::<Vec<&str>>(),
)
} else {
self.prepare_for_model(input.as_ref())
};
let output = no_grad(|| {
self.language_encode.forward_t(
Some(&input_tensor),
None,
None,
None,
None,
None,
None,
false,
)
});
// get the positions of the mask token in the input texts
let mask_token_id =
self.tokenizer
.get_mask_id()
.ok_or_else(|| RustBertError::InvalidConfigurationError(
"Tokenizer does not have a mask token id. Please use a tokenizer/model with a mask token.".into(),
))?;
let mask_token_mask = input_tensor.eq(mask_token_id);
let mut output_tokens = Vec::with_capacity(input.as_ref().len());
for input_id in 0..input.as_ref().len() as i64 {
let mut sequence_tokens = vec![];
let sequence_mask = mask_token_mask.get(input_id);
if bool::from(sequence_mask.any()) {
let mask_scores = output
.get(input_id)
.index_select(0, &sequence_mask.argwhere().squeeze_dim(1));
let (token_scores, token_ids) = mask_scores.max_dim(1, false);
for (id, score) in token_ids.iter::<i64>()?.zip(token_scores.iter::<f64>()?) {
let text = self.tokenizer.decode(&[id], false, true);
sequence_tokens.push(MaskedToken { text, id, score });
}
}
output_tokens.push(sequence_tokens);
}
Ok(output_tokens)
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
#[ignore] // no need to run, compilation is enough to verify it is Send
fn test() {
let config = MaskedLanguageConfig::default();
let _: Box<dyn Send> = Box::new(MaskedLanguageModel::new(config));
}
}
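Since RoBERTa's native mask token is already `<mask>`, a RoBERTa-based configuration needs no `mask_token` override, but it does require a merges resource. A sketch, assuming the `ROBERTA` resource constants exposed by `rust_bert::roberta`:
```rust
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::masked_language::MaskedLanguageConfig;
use rust_bert::resources::RemoteResource;
use rust_bert::roberta::{
    RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};

// RoBERTa tokenizers require a merges file in addition to the vocabulary.
let config = MaskedLanguageConfig::new(
    ModelType::Roberta,
    RemoteResource::from_pretrained(RobertaModelResources::ROBERTA),
    RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA),
    RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA),
    Some(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA)),
    false,
    None,
    None,
    None,
);
```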


@@ -478,6 +478,7 @@ pub mod common;
pub mod conversation;
pub mod generation_utils;
pub mod keywords_extraction;
pub mod masked_language;
pub mod ner;
pub mod pos_tagging;
pub mod question_answering;


@@ -680,7 +680,7 @@ impl ZeroShotClassificationModel {
/// let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
/// let candidate_labels = &["politics", "public health", "economics", "sports"];
///
/// let output = sequence_classification_model.try_predict(
/// let output = sequence_classification_model.predict(
/// &[input_sentence, input_sequence_2],
/// candidate_labels,
/// None,
@@ -693,7 +693,7 @@ impl ZeroShotClassificationModel {
/// outputs:
/// ```no_run
/// # use rust_bert::pipelines::sequence_classification::Label;
/// let output = Ok([
/// let output = [
/// Label {
/// text: "politics".to_string(),
/// score: 0.959,
@@ -707,7 +707,7 @@ impl ZeroShotClassificationModel {
/// sentence: 1,
/// },
/// ]
/// .to_vec());
/// .to_vec();
/// ```
pub fn predict<'a, S, T>(
&self,
@@ -783,7 +783,7 @@ impl ZeroShotClassificationModel {
/// let input_sequence_2 = "The central bank is meeting today to discuss monetary policy.";
/// let candidate_labels = &["politics", "public health", "economics", "sports"];
///
/// let output = sequence_classification_model.try_predict_multilabel(
/// let output = sequence_classification_model.predict_multilabel(
/// &[input_sentence, input_sequence_2],
/// candidate_labels,
/// None,
@@ -795,7 +795,7 @@ impl ZeroShotClassificationModel {
/// outputs:
/// ```no_run
/// # use rust_bert::pipelines::sequence_classification::Label;
/// let output = Ok([
/// let output = [
/// [
/// Label {
/// text: "politics".to_string(),
@@ -849,7 +849,7 @@ impl ZeroShotClassificationModel {
/// },
/// ],
/// ]
/// .to_vec());
/// .to_vec();
/// ```
pub fn predict_multilabel<'a, S, T>(
&self,


@@ -7,6 +7,7 @@ use rust_bert::bert::{
BertModelResources, BertVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
@@ -100,6 +101,46 @@ fn bert_masked_lm() -> anyhow::Result<()> {
Ok(())
}
#[test]
fn bert_masked_lm_pipeline() -> anyhow::Result<()> {
// Set-up model
let config = MaskedLanguageConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT),
RemoteResource::from_pretrained(BertConfigResources::BERT),
RemoteResource::from_pretrained(BertVocabResources::BERT),
None,
true,
None,
None,
Some(String::from("<mask>")),
);
let mask_language_model = MaskedLanguageModel::new(config)?;
// Define input
let input = [
"Hello I am a <mask> student",
"Paris is the <mask> of France. It is <mask> in Europe.",
];
// Run model
let output = mask_language_model.predict(input)?;
assert_eq!(output.len(), 2);
assert_eq!(output[0].len(), 1);
assert_eq!(output[0][0].id, 2267);
assert_eq!(output[0][0].text, "college");
assert!((output[0][0].score - 8.0919).abs() < 1e-4);
assert_eq!(output[1].len(), 2);
assert_eq!(output[1][0].id, 3007);
assert_eq!(output[1][0].text, "capital");
assert!((output[1][0].score - 16.7249).abs() < 1e-4);
assert_eq!(output[1][1].id, 2284);
assert_eq!(output[1][1].text, "located");
assert!((output[1][1].score - 9.0452).abs() < 1e-4);
Ok(())
}
#[test]
fn bert_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths