From 61e5d2d56316d263e95e7b9476682c954451534c Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sat, 13 Nov 2021 09:39:57 +0100 Subject: [PATCH] Addition of FNet model resource for sentiment analysis and registration in pipelines --- examples/sentiment_analysis_fnet.rs | 56 +++++++++++++++++ src/fnet/fnet_model.rs | 17 +++++- src/pipelines/common.rs | 76 +++++++++++++++++++++--- src/pipelines/question_answering.rs | 21 +++++++ src/pipelines/sentiment.rs | 2 +- src/pipelines/sequence_classification.rs | 21 +++++++ src/pipelines/token_classification.rs | 21 +++++++ 7 files changed, 205 insertions(+), 9 deletions(-) create mode 100644 examples/sentiment_analysis_fnet.rs diff --git a/examples/sentiment_analysis_fnet.rs b/examples/sentiment_analysis_fnet.rs new file mode 100644 index 0000000..0d2d475 --- /dev/null +++ b/examples/sentiment_analysis_fnet.rs @@ -0,0 +1,56 @@ +// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +// Copyright 2019 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +extern crate anyhow; + +use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources}; +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel}; +use rust_bert::resources::{RemoteResource, Resource}; + +fn main() -> anyhow::Result<()> { + // Set-up classifier + let config_resource = Resource::Remote(RemoteResource::from_pretrained( + FNetConfigResources::BASE_SST2, + )); + let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + FNetVocabResources::BASE_SST2, + )); + let model_resource = Resource::Remote(RemoteResource::from_pretrained( + FNetModelResources::BASE_SST2, + )); + + let sentiment_config = SentimentConfig { + model_type: ModelType::FNet, + model_resource, + config_resource, + vocab_resource, + ..Default::default() + }; + + let sentiment_classifier = SentimentModel::new(sentiment_config)?; + + // Define input + let input = [ + "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.", + "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...", + "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", + ]; + + // Run model + let output = sentiment_classifier.predict(&input); + for sentiment in output { + println!("{:?}", sentiment); + } + + Ok(()) +} diff --git a/src/fnet/fnet_model.rs b/src/fnet/fnet_model.rs index 13e1eff..0537b8c 100644 --- a/src/fnet/fnet_model.rs +++ b/src/fnet/fnet_model.rs @@ -38,6 +38,11 @@ impl FNetModelResources { "fnet-base/model", "https://huggingface.co/google/fnet-base/resolve/main/rust_model.ot", ); + /// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format. 
+ pub const BASE_SST2: (&'static str, &'static str) = ( + "fnet-base-sst2/model", + "https://huggingface.co/gchhablani/fnet-base-finetuned-sst2/resolve/main/rust_model.ot", + ); } impl FNetConfigResources { @@ -46,6 +51,11 @@ impl FNetConfigResources { "fnet-base/config", "https://huggingface.co/google/fnet-base/resolve/main/config.json", ); + /// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format. + pub const BASE_SST2: (&'static str, &'static str) = ( + "fnet-base-sst2/config", + "https://huggingface.co/gchhablani/fnet-base-finetuned-sst2/resolve/main/config.json", + ); } impl FNetVocabResources { @@ -54,6 +64,11 @@ impl FNetVocabResources { "fnet-base/spiece", "https://huggingface.co/google/fnet-base/resolve/main/spiece.model", ); + /// Shared under Apache 2.0 license at <https://huggingface.co/google/fnet-base>. Modified with conversion to C-array format. + pub const BASE_SST2: (&'static str, &'static str) = ( + "fnet-base-sst2/spiece", + "https://huggingface.co/google/fnet-base/resolve/main/spiece.model", + ); } #[derive(Debug, Serialize, Deserialize)] @@ -227,7 +242,7 @@ impl FNetModel { None }; Ok(FNetModelOutput { - hidden_states, + hidden_states: encoder_output.hidden_states, pooled_output, all_hidden_states: encoder_output.all_hidden_states, }) diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index b2dd0c8..cf2be49 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -22,6 +22,7 @@ use crate::bert::BertConfig; use crate::common::error::RustBertError; use crate::distilbert::DistilBertConfig; use crate::electra::ElectraConfig; +use crate::fnet::FNetConfig; use crate::gpt2::Gpt2Config; use crate::gpt_neo::GptNeoConfig; use crate::longformer::LongformerConfig; @@ -35,15 +36,15 @@ use crate::t5::T5Config; use crate::xlnet::XLNetConfig; use crate::Config; use rust_tokenizers::tokenizer::{ - AlbertTokenizer, BertTokenizer, Gpt2Tokenizer, M2M100Tokenizer, MBart50Tokenizer, - MarianTokenizer, MultiThreadedTokenizer, OpenAiGptTokenizer, 
PegasusTokenizer, - ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer, Tokenizer, - TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer, + AlbertTokenizer, BertTokenizer, FNetTokenizer, Gpt2Tokenizer, M2M100Tokenizer, + MBart50Tokenizer, MarianTokenizer, MultiThreadedTokenizer, OpenAiGptTokenizer, + PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer, + Tokenizer, TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer, }; use rust_tokenizers::vocab::{ - AlbertVocab, BertVocab, Gpt2Vocab, M2M100Vocab, MBart50Vocab, MarianVocab, OpenAiGptVocab, - PegasusVocab, ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab, - XLNetVocab, + AlbertVocab, BertVocab, FNetVocab, Gpt2Vocab, M2M100Vocab, MBart50Vocab, MarianVocab, + OpenAiGptVocab, PegasusVocab, ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab, + XLMRobertaVocab, XLNetVocab, }; use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets}; use serde::{Deserialize, Serialize}; @@ -73,6 +74,7 @@ pub enum ModelType { GPTNeo, MBart, M2M100, + FNet, } /// # Abstraction that holds a model configuration, can be of any of the supported models @@ -111,6 +113,8 @@ pub enum ConfigOption { MBart(MBartConfig), /// M2M100 configuration M2M100(M2M100Config), + /// FNet configuration + FNet(FNetConfig), } /// # Abstraction that holds a particular tokenizer, can be of any of the supported models @@ -143,6 +147,8 @@ pub enum TokenizerOption { MBart50(MBart50Tokenizer), /// M2M100 Tokenizer M2M100(M2M100Tokenizer), + /// FNet Tokenizer + FNet(FNetTokenizer), } impl ConfigOption { @@ -168,6 +174,7 @@ impl ConfigOption { ModelType::Pegasus => ConfigOption::Pegasus(PegasusConfig::from_file(path)), ModelType::MBart => ConfigOption::MBart(MBartConfig::from_file(path)), ModelType::M2M100 => ConfigOption::M2M100(M2M100Config::from_file(path)), + ModelType::FNet => ConfigOption::FNet(FNetConfig::from_file(path)), } } @@ -225,6 
+232,10 @@ impl ConfigOption { .id2label .as_ref() .expect("No label dictionary (id2label) provided in configuration file"), + Self::FNet(config) => config + .id2label + .as_ref() + .expect("No label dictionary (id2label) provided in configuration file"), Self::T5(_) => panic!("T5 does not use a label mapping"), Self::GPT2(_) => panic!("GPT2 does not use a label mapping"), Self::GPTNeo(_) => panic!("GPT-Neo does not use a label mapping"), @@ -251,6 +262,7 @@ impl ConfigOption { Self::GPTNeo(config) => Some(config.max_position_embeddings), Self::MBart(config) => Some(config.max_position_embeddings), Self::M2M100(config) => Some(config.max_position_embeddings), + Self::FNet(config) => Some(config.max_position_embeddings), } } } @@ -469,6 +481,11 @@ impl TokenizerOption { lower_case, )?) } + ModelType::FNet => TokenizerOption::FNet(FNetTokenizer::from_file( + vocab_path, + lower_case, + strip_accents.unwrap_or(false), + )?), }; Ok(tokenizer) } @@ -490,6 +507,7 @@ impl TokenizerOption { Self::Pegasus(_) => ModelType::Pegasus, Self::MBart50(_) => ModelType::MBart, Self::M2M100(_) => ModelType::M2M100, + Self::FNet(_) => ModelType::FNet, } } @@ -603,6 +621,13 @@ impl TokenizerOption { truncation_strategy, stride, ), + Self::FNet(ref tokenizer) => MultiThreadedTokenizer::encode_list( + tokenizer, + text_list, + max_len, + truncation_strategy, + stride, + ), } } @@ -713,6 +738,13 @@ impl TokenizerOption { truncation_strategy, stride, ), + Self::FNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list( + tokenizer, + text_pair_list, + max_len, + truncation_strategy, + stride, + ), } } @@ -768,6 +800,9 @@ impl TokenizerOption { Self::M2M100(ref tokenizer) => { tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride) } + Self::FNet(ref tokenizer) => { + tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride) + } } } @@ -788,6 +823,7 @@ impl TokenizerOption { Self::Pegasus(ref tokenizer) => tokenizer.tokenize(text), Self::MBart50(ref 
tokenizer) => tokenizer.tokenize(text), Self::M2M100(ref tokenizer) => tokenizer.tokenize(text), + Self::FNet(ref tokenizer) => tokenizer.tokenize(text), } } @@ -808,6 +844,7 @@ impl TokenizerOption { Self::Pegasus(ref tokenizer) => tokenizer.tokenize_with_offsets(text), Self::MBart50(ref tokenizer) => tokenizer.tokenize_with_offsets(text), Self::M2M100(ref tokenizer) => tokenizer.tokenize_with_offsets(text), + Self::FNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text), } } @@ -837,6 +874,7 @@ impl TokenizerOption { Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), + Self::FNet(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), } } @@ -890,6 +928,9 @@ impl TokenizerOption { Self::M2M100(ref tokenizer) => { tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces) } + Self::FNet(ref tokenizer) => { + tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces) + } } } @@ -956,6 +997,10 @@ impl TokenizerOption { token_ids_with_offsets_1, token_ids_with_offsets_2, ), + Self::FNet(ref tokenizer) => tokenizer.build_input_with_special_tokens( + token_ids_with_offsets_1, + token_ids_with_offsets_2, + ), }; TokenizedInput { token_ids: token_ids_with_special_tokens.token_ids, @@ -989,6 +1034,7 @@ impl TokenizerOption { Self::Pegasus(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), Self::MBart50(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), Self::M2M100(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), + Self::FNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), } } @@ -1051,6 +1097,10 @@ impl TokenizerOption { .special_values .get(M2M100Vocab::unknown_value()) .expect("UNK token not found in vocabulary"), + Self::FNet(ref tokenizer) => 
*MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(FNetVocab::unknown_value()) + .expect("UNK token not found in vocabulary"), } } @@ -1123,6 +1173,12 @@ impl TokenizerOption { .get(M2M100Vocab::pad_value()) .expect("PAD token not found in vocabulary"), ), + Self::FNet(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(FNetVocab::pad_value()) + .expect("PAD token not found in vocabulary"), + ), Self::Reformer(_) => None, Self::GPT2(_) => None, Self::OpenAiGpt(_) => None, @@ -1180,6 +1236,12 @@ impl TokenizerOption { .get(M2M100Vocab::sep_value()) .expect("SEP token not found in vocabulary"), ), + Self::FNet(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(FNetVocab::sep_value()) + .expect("SEP token not found in vocabulary"), + ), Self::Marian(_) => None, Self::T5(_) => None, Self::GPT2(_) => None, diff --git a/src/pipelines/question_answering.rs b/src/pipelines/question_answering.rs index 1df53d7..93e9cc3 100644 --- a/src/pipelines/question_answering.rs +++ b/src/pipelines/question_answering.rs @@ -51,6 +51,7 @@ use crate::distilbert::{ DistilBertConfigResources, DistilBertForQuestionAnswering, DistilBertModelResources, DistilBertVocabResources, }; +use crate::fnet::FNetForQuestionAnswering; use crate::longformer::LongformerForQuestionAnswering; use crate::mobilebert::MobileBertForQuestionAnswering; use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption}; @@ -277,6 +278,8 @@ pub enum QuestionAnsweringOption { Reformer(ReformerForQuestionAnswering), /// Longformer for Question Answering Longformer(LongformerForQuestionAnswering), + /// FNet for Question Answering + FNet(FNetForQuestionAnswering), } impl QuestionAnsweringOption { @@ -396,6 +399,17 @@ impl QuestionAnsweringOption { )) } } + ModelType::FNet => { + if let ConfigOption::FNet(config) = config { + Ok(QuestionAnsweringOption::FNet( + FNetForQuestionAnswering::new(p, config), + )) + } else 
{ + Err(RustBertError::InvalidConfigurationError( + "You can only supply a FNetConfig for FNet!".to_string(), + )) + } + } _ => Err(RustBertError::InvalidConfigurationError(format!( "QuestionAnswering not implemented for {:?}!", model_type @@ -415,6 +429,7 @@ impl QuestionAnsweringOption { Self::XLNet(_) => ModelType::XLNet, Self::Reformer(_) => ModelType::Reformer, Self::Longformer(_) => ModelType::Longformer, + Self::FNet(_) => ModelType::FNet, } } @@ -468,6 +483,12 @@ impl QuestionAnsweringOption { .expect("Error in reformer forward pass"); (outputs.start_logits, outputs.end_logits) } + Self::FNet(ref model) => { + let outputs = model + .forward_t(input_ids, None, None, None, train) + .expect("Error in fnet forward pass"); + (outputs.start_logits, outputs.end_logits) + } } } } diff --git a/src/pipelines/sentiment.rs b/src/pipelines/sentiment.rs index 3b08557..0fd22d8 100644 --- a/src/pipelines/sentiment.rs +++ b/src/pipelines/sentiment.rs @@ -75,7 +75,7 @@ pub struct Sentiment { pub score: f64, } -type SentimentConfig = SequenceClassificationConfig; +pub type SentimentConfig = SequenceClassificationConfig; /// # SentimentClassifier to perform sentiment analysis pub struct SentimentModel { diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs index 685f9d3..7e6ffd2 100644 --- a/src/pipelines/sequence_classification.rs +++ b/src/pipelines/sequence_classification.rs @@ -66,6 +66,7 @@ use crate::distilbert::{ DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources, DistilBertVocabResources, }; +use crate::fnet::FNetForSequenceClassification; use crate::longformer::LongformerForSequenceClassification; use crate::mobilebert::MobileBertForSequenceClassification; use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption}; @@ -197,6 +198,8 @@ pub enum SequenceClassificationOption { Reformer(ReformerForSequenceClassification), /// Longformer for Sequence Classification 
Longformer(LongformerForSequenceClassification), + /// FNet for Sequence Classification + FNet(FNetForSequenceClassification), } impl SequenceClassificationOption { @@ -327,6 +330,17 @@ impl SequenceClassificationOption { )) } } + ModelType::FNet => { + if let ConfigOption::FNet(config) = config { + Ok(SequenceClassificationOption::FNet( + FNetForSequenceClassification::new(p, config), + )) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply a FNetConfig for FNet!".to_string(), + )) + } + } _ => Err(RustBertError::InvalidConfigurationError(format!( "Sequence Classification not implemented for {:?}!", model_type @@ -347,6 +361,7 @@ impl SequenceClassificationOption { Self::Bart(_) => ModelType::Bart, Self::Reformer(_) => ModelType::Reformer, Self::Longformer(_) => ModelType::Longformer, + Self::FNet(_) => ModelType::FNet, } } @@ -455,6 +470,12 @@ impl SequenceClassificationOption { .expect("Error in Longformer forward pass.") .logits } + Self::FNet(ref model) => { + model + .forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) + .expect("Error in FNet forward pass.") + .logits + } } } } diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs index 1cf0fcd..4c884d5 100644 --- a/src/pipelines/token_classification.rs +++ b/src/pipelines/token_classification.rs @@ -118,6 +118,7 @@ use crate::common::error::RustBertError; use crate::common::resources::{RemoteResource, Resource}; use crate::distilbert::DistilBertForTokenClassification; use crate::electra::ElectraForTokenClassification; +use crate::fnet::FNetForTokenClassification; use crate::longformer::LongformerForTokenClassification; use crate::mobilebert::MobileBertForTokenClassification; use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption}; @@ -317,6 +318,8 @@ pub enum TokenClassificationOption { XLNet(XLNetForTokenClassification), /// Longformer for Token Classification 
Longformer(LongformerForTokenClassification), + /// FNet for Token Classification + FNet(FNetForTokenClassification), } impl TokenClassificationOption { @@ -436,6 +439,17 @@ impl TokenClassificationOption { )) } } + ModelType::FNet => { + if let ConfigOption::FNet(config) = config { + Ok(TokenClassificationOption::FNet( + FNetForTokenClassification::new(p, config), + )) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply an FNetConfig for FNet!".to_string(), + )) + } + } _ => Err(RustBertError::InvalidConfigurationError(format!( "Token classification not implemented for {:?}!", model_type @@ -455,6 +469,7 @@ impl TokenClassificationOption { Self::Albert(_) => ModelType::Albert, Self::XLNet(_) => ModelType::XLNet, Self::Longformer(_) => ModelType::Longformer, + Self::FNet(_) => ModelType::FNet, } } @@ -556,6 +571,12 @@ impl TokenClassificationOption { .expect("Error in longformer forward_t") .logits } + Self::FNet(ref model) => { + model + .forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) + .expect("Error in fnet forward_t") + .logits + } } } }