Merge pull request #199 from guillaume-be/fnet_implementation

Fnet implementation
This commit is contained in:
guillaume-be 2021-11-19 18:47:51 +01:00 committed by GitHub
commit b0d9f50d58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 2044 additions and 12 deletions

View File

@ -71,6 +71,7 @@ jobs:
--test electra
--test gpt2
--test marian
--test fnet
test-batch-1:
name: Integration tests (batch 1)

View File

@ -10,6 +10,7 @@ All notable changes to this project will be documented in this file. The format
- (BREAKING) Support for `bad_word_ids` generation, allowing to ban a set of word ids for all model supporting text generation
- Support for half-precision mode for all models (reducing memory footprint). A model can be converted to half-precision by calling the `half()` method on the `VarStore` is it currently stored in. Half-precision Torch kernels are not available for CPU (limited to CUDA devices)
- (BREAKING) Extension of the generation options that can be provided at runtime (after a model has been instantiated with a `GenerateConfig`), allowing to update the generation options from one text generation to another with the same model. This feature is implemented at the `LanguageGenerator` trait level, the high-level `TextGeneration` pipeline API remains unchanged.
- Addition of the FNet language model and support for sequence, token and multiple choice classification, question answering
## [0.16.0] - 2021-08-24
## Added

View File

@ -41,6 +41,7 @@ The tasks currently supported include:
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:
DistilBERT|✅|✅|✅| | | |✅|
MobileBERT|✅|✅|✅| | | |✅|
FNet|✅|✅|✅| | | |✅|
BERT|✅|✅|✅| | | |✅|
RoBERTa|✅|✅|✅| | | |✅|
GPT| | | |✅ | | | |

View File

@ -0,0 +1,56 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> anyhow::Result<()> {
// Set-up classifier
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
FNetConfigResources::BASE_SST2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
FNetVocabResources::BASE_SST2,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
FNetModelResources::BASE_SST2,
));
let sentiment_config = SentimentConfig {
model_type: ModelType::FNet,
model_resource,
config_resource,
vocab_resource,
..Default::default()
};
let sentiment_classifier = SentimentModel::new(sentiment_config)?;
// Define input
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
// Run model
let output = sentiment_classifier.predict(&input);
for sentiment in output {
println!("{:?}", sentiment);
}
Ok(())
}

View File

@ -1233,7 +1233,6 @@ mod test {
let vs = tch::nn::VarStore::new(device);
let config = BertConfig::from_file(config_path);
let b: BertModel<BertEmbeddings> = BertModel::new(&vs.root(), &config);
let _: Box<dyn Send> = Box::new(b);
let _: Box<dyn Send> = Box::new(BertModel::<BertEmbeddings>::new(&vs.root(), &config));
}
}

156
src/fnet/attention.rs Normal file
View File

@ -0,0 +1,156 @@
// Copyright 2021 Google Research
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::activations::TensorFunction;
use crate::common::dropout::Dropout;
use crate::fnet::FNetConfig;
use std::borrow::Borrow;
use tch::nn::LayerNormConfig;
use tch::{nn, Tensor};
pub struct FNetFourierTransform {
layer_norm: nn::LayerNorm,
}
impl FNetFourierTransform {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetFourierTransform
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let layer_norm_config = LayerNormConfig {
eps: config.layer_norm_eps.unwrap_or(1e-12),
..Default::default()
};
let layer_norm = nn::layer_norm(
p.sub("output").sub("LayerNorm"),
vec![config.hidden_size],
layer_norm_config,
);
FNetFourierTransform { layer_norm }
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
let self_outputs = hidden_states.fft_fft2(None, &[1, 2], "backward").real();
(self_outputs + hidden_states).apply(&self.layer_norm)
}
}
pub struct FNetIntermediate {
dense: nn::Linear,
intermediate_activation_function: TensorFunction,
}
impl FNetIntermediate {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetIntermediate
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.intermediate_size,
Default::default(),
);
let intermediate_activation_function = config.hidden_act.get_function();
FNetIntermediate {
dense,
intermediate_activation_function,
}
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
self.intermediate_activation_function.get_fn()(&hidden_states.apply(&self.dense))
}
}
pub struct FNetOutput {
dense: nn::Linear,
layer_norm: nn::LayerNorm,
dropout: Dropout,
}
impl FNetOutput {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetOutput
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let dense = nn::linear(
p / "dense",
config.intermediate_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = LayerNormConfig {
eps: config.layer_norm_eps.unwrap_or(1e-12),
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob);
FNetOutput {
dense,
layer_norm,
dropout,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
let hidden_states = hidden_states
.apply(&self.dense)
.apply_t(&self.dropout, train);
(input_tensor + hidden_states).apply(&self.layer_norm)
}
}
pub struct FNetLayer {
fourier: FNetFourierTransform,
intermediate: FNetIntermediate,
output: FNetOutput,
}
impl FNetLayer {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetLayer
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let fourier = FNetFourierTransform::new(p / "fourier", config);
let intermediate = FNetIntermediate::new(p / "intermediate", config);
let output = FNetOutput::new(p / "output", config);
FNetLayer {
fourier,
intermediate,
output,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
let fourier_outputs = self.fourier.forward(hidden_states);
let intermediate_output = self.intermediate.forward(&fourier_outputs);
self.output
.forward_t(&intermediate_output, &fourier_outputs, train)
}
}

131
src/fnet/embeddings.rs Normal file
View File

@ -0,0 +1,131 @@
// Copyright 2021 Google Research
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::fnet::FNetConfig;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::{EmbeddingConfig, LayerNormConfig};
use tch::{nn, Kind, Tensor};
pub struct FNetEmbeddings {
word_embeddings: nn::Embedding,
position_embeddings: nn::Embedding,
token_type_embeddings: nn::Embedding,
projection: nn::Linear,
layer_norm: nn::LayerNorm,
dropout: Dropout,
}
impl FNetEmbeddings {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetEmbeddings
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let word_embeddings_config = EmbeddingConfig {
padding_idx: config.pad_token_id.unwrap_or(3),
..Default::default()
};
let word_embeddings = nn::embedding(
p / "word_embeddings",
config.vocab_size,
config.hidden_size,
word_embeddings_config,
);
let position_embeddings = nn::embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.hidden_size,
Default::default(),
);
let token_type_embeddings = nn::embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = LayerNormConfig {
eps: config.layer_norm_eps.unwrap_or(1e-12),
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let projection = nn::linear(
p / "projection",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.hidden_dropout_prob);
FNetEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
projection,
layer_norm,
dropout,
}
}
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeddings: Option<&Tensor>,
train: bool,
) -> Result<Tensor, RustBertError> {
let (calc_input_embeddings, input_shape, _) =
process_ids_embeddings_pair(input_ids, input_embeddings, &self.word_embeddings)?;
let input_embeddings =
input_embeddings.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let calc_token_type_ids = if token_type_ids.is_none() {
Some(Tensor::zeros(
input_shape.as_slice(),
(Kind::Int64, input_embeddings.device()),
))
} else {
None
};
let token_type_embeddings = token_type_ids
.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap())
.apply(&self.token_type_embeddings);
let calc_position_ids = if position_ids.is_none() {
Some(Tensor::arange(
input_shape[1],
(Kind::Int64, input_embeddings.device()),
))
} else {
None
};
let position_embeddings = position_ids
.unwrap_or_else(|| calc_position_ids.as_ref().unwrap())
.apply(&self.position_embeddings);
let embeddings = input_embeddings + token_type_embeddings + position_embeddings;
Ok(embeddings
.apply(&self.layer_norm)
.apply(&self.projection)
.apply_t(&self.dropout, train))
}
}

81
src/fnet/encoder.rs Normal file
View File

@ -0,0 +1,81 @@
// Copyright 2021 Google Research
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::fnet::attention::FNetLayer;
use crate::fnet::FNetConfig;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};
pub struct FNetEncoder {
layers: Vec<FNetLayer>,
output_hidden_states: bool,
}
impl FNetEncoder {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetEncoder
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let p_layers = p / "layer";
let mut layers: Vec<FNetLayer> = Vec::with_capacity(config.num_hidden_layers as usize);
for layer_index in 0..config.num_hidden_layers {
layers.push(FNetLayer::new(&p_layers / layer_index, config));
}
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
FNetEncoder {
layers,
output_hidden_states,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> FNetEncoderOutput {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut x: Option<Tensor> = None;
for layer in &self.layers {
let temp = if let Some(x_value) = &x {
layer.forward_t(x_value, train)
} else {
layer.forward_t(hidden_states, train)
};
x = Some(temp);
if let Some(all_hidden_states) = all_hidden_states.borrow_mut() {
all_hidden_states.push(x.as_ref().unwrap().copy());
};
}
FNetEncoderOutput {
hidden_states: x.unwrap(),
all_hidden_states,
}
}
}
/// Container for the FNet encoder output.
pub struct FNetEncoderOutput {
/// Last hidden states from the model
pub hidden_states: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
}

1052
src/fnet/fnet_model.rs Normal file

File diff suppressed because it is too large Load Diff

64
src/fnet/mod.rs Normal file
View File

@ -0,0 +1,64 @@
//! # FNet, Mixing Tokens with Fourier Transforms (Lee-Thorp et al.)
//!
//! Implementation of the FNet language model ([https://arxiv.org/abs/2105.03824](https://arxiv.org/abs/2105.03824) Lee-Thorp, Ainslie, Eckstein, Ontanon, 2021).
//! The base model is implemented in the `fnet_model::FNetModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `fnet_model::FNetForMaskedLM`
//! - Question answering: `fnet_model::FNetForQuestionAnswering`
//! - Sequence classification: `fnet_model::FNetForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `fnet_model::FNetForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! The example below illustrate a FNet Masked language model example, the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `FNetTokenizer` using a `spiece.model` SentencePiece (BPE) model file
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
//! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::{BertTokenizer, FNetTokenizer};
//!
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: FNetTokenizer =
//! FNetTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
//! let config = FNetConfig::from_file(config_path);
//! let bert_model = FNetForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
//! # }
//! ```
mod attention;
mod embeddings;
mod encoder;
mod fnet_model;
pub use fnet_model::{
FNetConfig, FNetConfigResources, FNetForMaskedLM, FNetForMultipleChoice,
FNetForQuestionAnswering, FNetForSequenceClassification, FNetForTokenClassification,
FNetMaskedLMOutput, FNetModel, FNetModelOutput, FNetModelResources,
FNetQuestionAnsweringOutput, FNetSequenceClassificationOutput, FNetTokenClassificationOutput,
FNetVocabResources,
};

View File

@ -51,6 +51,7 @@
//! :-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:
//! DistilBERT|✅|✅|✅| | | |✅|
//! MobileBERT|✅|✅|✅| | | |✅|
//! FNet|✅|✅|✅| | | |✅|
//! BERT|✅|✅|✅| | | |✅|
//! RoBERTa|✅|✅|✅| | | |✅|
//! GPT| | | |✅ | | | |
@ -583,6 +584,7 @@ pub mod bert;
mod common;
pub mod distilbert;
pub mod electra;
pub mod fnet;
pub mod gpt2;
pub mod gpt_neo;
pub mod longformer;

View File

@ -883,3 +883,31 @@ impl LanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M100Tokeni
for M2M100Generator
{
}
#[cfg(test)]
mod test {
use tch::Device;
use crate::{
resources::{RemoteResource, Resource},
Config,
};
use super::*;
#[test]
#[ignore] // compilation is enough, no need to run
fn mbart_model_send() {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
));
let config_path = config_resource.get_local_path().expect("");
// Set-up masked LM model
let device = Device::cuda_if_available();
let vs = tch::nn::VarStore::new(device);
let config = M2M100Config::from_file(config_path);
let _: Box<dyn Send> = Box::new(M2M100Model::new(&vs.root(), &config));
}
}

View File

@ -1093,3 +1093,31 @@ impl LanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart50Token
for MBartGenerator
{
}
#[cfg(test)]
mod test {
use tch::Device;
use crate::{
resources::{RemoteResource, Resource},
Config,
};
use super::*;
#[test]
#[ignore] // compilation is enough, no need to run
fn mbart_model_send() {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
));
let config_path = config_resource.get_local_path().expect("");
// Set-up masked LM model
let device = Device::cuda_if_available();
let vs = tch::nn::VarStore::new(device);
let config = MBartConfig::from_file(config_path);
let _: Box<dyn Send> = Box::new(MBartModel::new(&vs.root(), &config));
}
}

View File

@ -22,6 +22,7 @@ use crate::bert::BertConfig;
use crate::common::error::RustBertError;
use crate::distilbert::DistilBertConfig;
use crate::electra::ElectraConfig;
use crate::fnet::FNetConfig;
use crate::gpt2::Gpt2Config;
use crate::gpt_neo::GptNeoConfig;
use crate::longformer::LongformerConfig;
@ -35,15 +36,15 @@ use crate::t5::T5Config;
use crate::xlnet::XLNetConfig;
use crate::Config;
use rust_tokenizers::tokenizer::{
AlbertTokenizer, BertTokenizer, Gpt2Tokenizer, M2M100Tokenizer, MBart50Tokenizer,
MarianTokenizer, MultiThreadedTokenizer, OpenAiGptTokenizer, PegasusTokenizer,
ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer, Tokenizer,
TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
AlbertTokenizer, BertTokenizer, FNetTokenizer, Gpt2Tokenizer, M2M100Tokenizer,
MBart50Tokenizer, MarianTokenizer, MultiThreadedTokenizer, OpenAiGptTokenizer,
PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer,
Tokenizer, TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
};
use rust_tokenizers::vocab::{
AlbertVocab, BertVocab, Gpt2Vocab, M2M100Vocab, MBart50Vocab, MarianVocab, OpenAiGptVocab,
PegasusVocab, ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab,
XLNetVocab,
AlbertVocab, BertVocab, FNetVocab, Gpt2Vocab, M2M100Vocab, MBart50Vocab, MarianVocab,
OpenAiGptVocab, PegasusVocab, ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab,
XLMRobertaVocab, XLNetVocab,
};
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
use serde::{Deserialize, Serialize};
@ -73,6 +74,7 @@ pub enum ModelType {
GPTNeo,
MBart,
M2M100,
FNet,
}
/// # Abstraction that holds a model configuration, can be of any of the supported models
@ -111,6 +113,8 @@ pub enum ConfigOption {
MBart(MBartConfig),
/// M2M100 configuration
M2M100(M2M100Config),
/// FNet configuration
FNet(FNetConfig),
}
/// # Abstraction that holds a particular tokenizer, can be of any of the supported models
@ -143,6 +147,8 @@ pub enum TokenizerOption {
MBart50(MBart50Tokenizer),
/// M2M100 Tokenizer
M2M100(M2M100Tokenizer),
/// FNet Tokenizer
FNet(FNetTokenizer),
}
impl ConfigOption {
@ -168,6 +174,7 @@ impl ConfigOption {
ModelType::Pegasus => ConfigOption::Pegasus(PegasusConfig::from_file(path)),
ModelType::MBart => ConfigOption::MBart(MBartConfig::from_file(path)),
ModelType::M2M100 => ConfigOption::M2M100(M2M100Config::from_file(path)),
ModelType::FNet => ConfigOption::FNet(FNetConfig::from_file(path)),
}
}
@ -225,6 +232,10 @@ impl ConfigOption {
.id2label
.as_ref()
.expect("No label dictionary (id2label) provided in configuration file"),
Self::FNet(config) => config
.id2label
.as_ref()
.expect("No label dictionary (id2label) provided in configuration file"),
Self::T5(_) => panic!("T5 does not use a label mapping"),
Self::GPT2(_) => panic!("GPT2 does not use a label mapping"),
Self::GPTNeo(_) => panic!("GPT-Neo does not use a label mapping"),
@ -251,6 +262,7 @@ impl ConfigOption {
Self::GPTNeo(config) => Some(config.max_position_embeddings),
Self::MBart(config) => Some(config.max_position_embeddings),
Self::M2M100(config) => Some(config.max_position_embeddings),
Self::FNet(config) => Some(config.max_position_embeddings),
}
}
}
@ -469,6 +481,11 @@ impl TokenizerOption {
lower_case,
)?)
}
ModelType::FNet => TokenizerOption::FNet(FNetTokenizer::from_file(
vocab_path,
lower_case,
strip_accents.unwrap_or(false),
)?),
};
Ok(tokenizer)
}
@ -490,6 +507,7 @@ impl TokenizerOption {
Self::Pegasus(_) => ModelType::Pegasus,
Self::MBart50(_) => ModelType::MBart,
Self::M2M100(_) => ModelType::M2M100,
Self::FNet(_) => ModelType::FNet,
}
}
@ -603,6 +621,13 @@ impl TokenizerOption {
truncation_strategy,
stride,
),
Self::FNet(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
}
}
@ -713,6 +738,13 @@ impl TokenizerOption {
truncation_strategy,
stride,
),
Self::FNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
}
}
@ -768,6 +800,9 @@ impl TokenizerOption {
Self::M2M100(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::FNet(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
}
}
@ -788,6 +823,7 @@ impl TokenizerOption {
Self::Pegasus(ref tokenizer) => tokenizer.tokenize(text),
Self::MBart50(ref tokenizer) => tokenizer.tokenize(text),
Self::M2M100(ref tokenizer) => tokenizer.tokenize(text),
Self::FNet(ref tokenizer) => tokenizer.tokenize(text),
}
}
@ -808,6 +844,7 @@ impl TokenizerOption {
Self::Pegasus(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::MBart50(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::M2M100(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::FNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
}
}
@ -837,6 +874,7 @@ impl TokenizerOption {
Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::FNet(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
}
}
@ -890,6 +928,9 @@ impl TokenizerOption {
Self::M2M100(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::FNet(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
}
}
@ -956,6 +997,10 @@ impl TokenizerOption {
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::FNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
};
TokenizedInput {
token_ids: token_ids_with_special_tokens.token_ids,
@ -989,6 +1034,7 @@ impl TokenizerOption {
Self::Pegasus(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::MBart50(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::M2M100(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::FNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
}
}
@ -1051,6 +1097,10 @@ impl TokenizerOption {
.special_values
.get(M2M100Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::FNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
}
}
@ -1123,6 +1173,12 @@ impl TokenizerOption {
.get(M2M100Vocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::FNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Reformer(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
@ -1180,6 +1236,12 @@ impl TokenizerOption {
.get(M2M100Vocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::FNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Marian(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,

View File

@ -51,6 +51,7 @@ use crate::distilbert::{
DistilBertConfigResources, DistilBertForQuestionAnswering, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::fnet::FNetForQuestionAnswering;
use crate::longformer::LongformerForQuestionAnswering;
use crate::mobilebert::MobileBertForQuestionAnswering;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
@ -279,6 +280,8 @@ pub enum QuestionAnsweringOption {
Reformer(ReformerForQuestionAnswering),
/// Longformer for Question Answering
Longformer(LongformerForQuestionAnswering),
/// FNet for Question Answering
FNet(FNetForQuestionAnswering),
}
impl QuestionAnsweringOption {
@ -398,6 +401,17 @@ impl QuestionAnsweringOption {
))
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(QuestionAnsweringOption::FNet(
FNetForQuestionAnswering::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a FNetConfig for FNet!".to_string(),
))
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"QuestionAnswering not implemented for {:?}!",
model_type
@ -417,6 +431,7 @@ impl QuestionAnsweringOption {
Self::XLNet(_) => ModelType::XLNet,
Self::Reformer(_) => ModelType::Reformer,
Self::Longformer(_) => ModelType::Longformer,
Self::FNet(_) => ModelType::FNet,
}
}
@ -470,6 +485,12 @@ impl QuestionAnsweringOption {
.expect("Error in reformer forward pass");
(outputs.start_logits, outputs.end_logits)
}
Self::FNet(ref model) => {
let outputs = model
.forward_t(input_ids, None, None, None, train)
.expect("Error in fnet forward pass");
(outputs.start_logits, outputs.end_logits)
}
}
}
}

View File

@ -76,7 +76,7 @@ pub struct Sentiment {
pub score: f64,
}
type SentimentConfig = SequenceClassificationConfig;
pub type SentimentConfig = SequenceClassificationConfig;
/// # SentimentClassifier to perform sentiment analysis
pub struct SentimentModel {

View File

@ -66,6 +66,7 @@ use crate::distilbert::{
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::fnet::FNetForSequenceClassification;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
@ -197,6 +198,8 @@ pub enum SequenceClassificationOption {
Reformer(ReformerForSequenceClassification),
/// Longformer for Sequence Classification
Longformer(LongformerForSequenceClassification),
/// FNet for Sequence Classification
FNet(FNetForSequenceClassification),
}
impl SequenceClassificationOption {
@ -327,6 +330,17 @@ impl SequenceClassificationOption {
))
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(SequenceClassificationOption::FNet(
FNetForSequenceClassification::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a FNetConfig for FNet!".to_string(),
))
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Sequence Classification not implemented for {:?}!",
model_type
@ -347,6 +361,7 @@ impl SequenceClassificationOption {
Self::Bart(_) => ModelType::Bart,
Self::Reformer(_) => ModelType::Reformer,
Self::Longformer(_) => ModelType::Longformer,
Self::FNet(_) => ModelType::FNet,
}
}
@ -455,6 +470,12 @@ impl SequenceClassificationOption {
.expect("Error in Longformer forward pass.")
.logits
}
Self::FNet(ref model) => {
model
.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
.expect("Error in FNet forward pass.")
.logits
}
}
}
}

View File

@ -118,6 +118,7 @@ use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::distilbert::DistilBertForTokenClassification;
use crate::electra::ElectraForTokenClassification;
use crate::fnet::FNetForTokenClassification;
use crate::longformer::LongformerForTokenClassification;
use crate::mobilebert::MobileBertForTokenClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
@ -317,6 +318,8 @@ pub enum TokenClassificationOption {
XLNet(XLNetForTokenClassification),
/// Longformer for Token Classification
Longformer(LongformerForTokenClassification),
/// FNet for Token Classification
FNet(FNetForTokenClassification),
}
impl TokenClassificationOption {
@ -436,6 +439,17 @@ impl TokenClassificationOption {
))
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(TokenClassificationOption::FNet(
FNetForTokenClassification::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply an FNetConfig for FNet!".to_string(),
))
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Token classification not implemented for {:?}!",
model_type
@ -455,6 +469,7 @@ impl TokenClassificationOption {
Self::Albert(_) => ModelType::Albert,
Self::XLNet(_) => ModelType::XLNet,
Self::Longformer(_) => ModelType::Longformer,
Self::FNet(_) => ModelType::FNet,
}
}
@ -556,6 +571,12 @@ impl TokenClassificationOption {
.expect("Error in longformer forward_t")
.logits
}
Self::FNet(ref model) => {
model
.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
.expect("Error in fnet forward_t")
.logits
}
}
}
}

View File

@ -242,7 +242,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let bert_model = AlbertForTokenClassification::new(&vs.root(), &config);
let albert_model = AlbertForTokenClassification::new(&vs.root(), &config);
// Define input
let input = [
@ -268,7 +268,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
// Forward pass
let model_output =
no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
assert_eq!(model_output.logits.size(), &[2, 12, 4]);
assert_eq!(

307
tests/fnet.rs Normal file
View File

@ -0,0 +1,307 @@
extern crate anyhow;
extern crate dirs;
use rust_bert::fnet::{
FNetConfig, FNetConfigResources, FNetForMaskedLM, FNetForMultipleChoice,
FNetForQuestionAnswering, FNetForTokenClassification, FNetModelResources, FNetVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel, SentimentPolarity};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{FNetTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};
#[test]
fn fnet_masked_lm() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetModelResources::BASE));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: FNetTokenizer =
FNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, false)?;
let config = FNetConfig::from_file(config_path);
let fnet_model = FNetForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![3; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| fnet_model.forward_t(Some(&input_tensor), None, None, None, false))?;
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
assert_eq!("▁one", word_1);
assert_eq!("▁the", word_2);
assert!((f64::from(model_output.prediction_scores.get(0).get(4).max()) - 13.1721).abs() < 1e-4);
Ok(())
}
#[test]
fn fnet_for_sequence_classification() -> anyhow::Result<()> {
// Set up classifier
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
FNetConfigResources::BASE_SST2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
FNetVocabResources::BASE_SST2,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
FNetModelResources::BASE_SST2,
));
let sentiment_config = SentimentConfig {
model_type: ModelType::FNet,
model_resource,
config_resource,
vocab_resource,
..Default::default()
};
let sentiment_classifier = SentimentModel::new(sentiment_config)?;
// Get sentiments
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
let output = sentiment_classifier.predict(&input);
assert_eq!(output.len(), 3usize);
assert_eq!(output[0].polarity, SentimentPolarity::Negative);
assert!((output[0].score - 0.9978).abs() < 1e-4);
assert_eq!(output[1].polarity, SentimentPolarity::Negative);
assert!((output[1].score - 0.9982).abs() < 1e-4);
assert_eq!(output[2].polarity, SentimentPolarity::Positive);
assert!((output[2].score - 0.7570).abs() < 1e-4);
Ok(())
}
//
#[test]
fn fnet_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: FNetTokenizer =
FNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, false)?;
let mut config = FNetConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let fnet_model = FNetForMultipleChoice::new(&vs.root(), &config);
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
.unsqueeze(0);
// Forward pass
let model_output = no_grad(|| {
fnet_model
.forward_t(Some(&input_tensor), None, None, None, false)
.unwrap()
});
assert_eq!(model_output.logits.size(), &[1, 2]);
assert_eq!(
config.num_hidden_layers as usize,
model_output.all_hidden_states.unwrap().len()
);
Ok(())
}
#[test]
fn fnet_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: FNetTokenizer =
FNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, false)?;
let mut config = FNetConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("O"));
dummy_label_mapping.insert(1, String::from("LOC"));
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
config.output_hidden_states = Some(true);
let fnet_model = FNetForTokenClassification::new(&vs.root(), &config);
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = no_grad(|| {
fnet_model
.forward_t(Some(&input_tensor), None, None, None, false)
.unwrap()
});
assert_eq!(model_output.logits.size(), &[2, 11, 4]);
assert_eq!(
config.num_hidden_layers as usize,
model_output.all_hidden_states.unwrap().len()
);
Ok(())
}
#[test]
fn fnet_for_question_answering() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: FNetTokenizer =
FNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, false)?;
let mut config = FNetConfig::from_file(config_path);
config.output_hidden_states = Some(true);
let fnet_model = FNetForQuestionAnswering::new(&vs.root(), &config);
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = no_grad(|| {
fnet_model
.forward_t(Some(&input_tensor), None, None, None, false)
.unwrap()
});
assert_eq!(model_output.start_logits.size(), &[2, 11]);
assert_eq!(model_output.end_logits.size(), &[2, 11]);
assert_eq!(
config.num_hidden_layers as usize,
model_output.all_hidden_states.unwrap().len()
);
Ok(())
}