Addition of FNet model resource for sentiment analysis and registration in pipelines

This commit is contained in:
Guillaume Becquin 2021-11-13 09:39:57 +01:00
parent 69a935009a
commit 61e5d2d563
7 changed files with 205 additions and 9 deletions

View File

@ -0,0 +1,56 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
use rust_bert::resources::{RemoteResource, Resource};
/// Example entry point: builds a `SentimentModel` backed by the FNet `BASE_SST2`
/// resources, classifies three sample movie reviews and prints each returned
/// prediction in its debug representation.
///
/// Returns an error if `SentimentModel::new` fails.
fn main() -> anyhow::Result<()> {
    // Wraps a pretrained resource descriptor into a remote `Resource`.
    let remote = |descriptor: (&'static str, &'static str)| {
        Resource::Remote(RemoteResource::from_pretrained(descriptor))
    };

    // Assemble the classifier; every configuration field not listed keeps its default.
    let classifier = SentimentModel::new(SentimentConfig {
        model_type: ModelType::FNet,
        model_resource: remote(FNetModelResources::BASE_SST2),
        config_resource: remote(FNetConfigResources::BASE_SST2),
        vocab_resource: remote(FNetVocabResources::BASE_SST2),
        ..Default::default()
    })?;

    // Sample reviews to classify
    let reviews = [
        "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
        "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
        "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
    ];

    // Classify the whole batch, then print every result
    for prediction in classifier.predict(&reviews) {
        println!("{:?}", prediction);
    }

    Ok(())
}

View File

@ -38,6 +38,11 @@ impl FNetModelResources {
"fnet-base/model",
"https://huggingface.co/google/fnet-base/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format.
///
/// Pair of string literals describing the model weights published in the
/// `fnet-base-finetuned-sst2` repository: the identifier `fnet-base-sst2/model`
/// followed by the URL of the `rust_model.ot` file.
// NOTE(review): the first element is presumably the local cache location used by
// `RemoteResource::from_pretrained` — confirm against that function.
pub const BASE_SST2: (&'static str, &'static str) = (
"fnet-base-sst2/model",
"https://huggingface.co/gchhablani/fnet-base-finetuned-sst2/resolve/main/rust_model.ot",
);
}
impl FNetConfigResources {
@ -46,6 +51,11 @@ impl FNetConfigResources {
"fnet-base/config",
"https://huggingface.co/google/fnet-base/resolve/main/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format.
///
/// Pair of string literals describing the model configuration published in the
/// `fnet-base-finetuned-sst2` repository: the identifier `fnet-base-sst2/config`
/// followed by the URL of the `config.json` file.
// NOTE(review): the first element is presumably the local cache location used by
// `RemoteResource::from_pretrained` — confirm against that function.
pub const BASE_SST2: (&'static str, &'static str) = (
"fnet-base-sst2/config",
"https://huggingface.co/gchhablani/fnet-base-finetuned-sst2/resolve/main/config.json",
);
}
impl FNetVocabResources {
@ -54,6 +64,11 @@ impl FNetVocabResources {
"fnet-base/spiece",
"https://huggingface.co/google/fnet-base/resolve/main/spiece.model",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format.
///
/// Pair of string literals describing the SentencePiece vocabulary for the SST-2
/// resources: the identifier `fnet-base-sst2/spiece` followed by the URL of a
/// `spiece.model` file.
// NOTE(review): the URL below points at the `google/fnet-base` repository, not at the
// `gchhablani/fnet-base-finetuned-sst2` repository named in the attribution above.
// Presumably the fine-tuned model reuses the base vocabulary — confirm, and adjust the
// attribution line if so.
pub const BASE_SST2: (&'static str, &'static str) = (
"fnet-base-sst2/spiece",
"https://huggingface.co/google/fnet-base/resolve/main/spiece.model",
);
}
#[derive(Debug, Serialize, Deserialize)]
@ -227,7 +242,7 @@ impl FNetModel {
None
};
Ok(FNetModelOutput {
hidden_states,
hidden_states: encoder_output.hidden_states,
pooled_output,
all_hidden_states: encoder_output.all_hidden_states,
})

View File

@ -22,6 +22,7 @@ use crate::bert::BertConfig;
use crate::common::error::RustBertError;
use crate::distilbert::DistilBertConfig;
use crate::electra::ElectraConfig;
use crate::fnet::FNetConfig;
use crate::gpt2::Gpt2Config;
use crate::gpt_neo::GptNeoConfig;
use crate::longformer::LongformerConfig;
@ -35,15 +36,15 @@ use crate::t5::T5Config;
use crate::xlnet::XLNetConfig;
use crate::Config;
use rust_tokenizers::tokenizer::{
AlbertTokenizer, BertTokenizer, Gpt2Tokenizer, M2M100Tokenizer, MBart50Tokenizer,
MarianTokenizer, MultiThreadedTokenizer, OpenAiGptTokenizer, PegasusTokenizer,
ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer, Tokenizer,
TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
AlbertTokenizer, BertTokenizer, FNetTokenizer, Gpt2Tokenizer, M2M100Tokenizer,
MBart50Tokenizer, MarianTokenizer, MultiThreadedTokenizer, OpenAiGptTokenizer,
PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer,
Tokenizer, TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
};
use rust_tokenizers::vocab::{
AlbertVocab, BertVocab, Gpt2Vocab, M2M100Vocab, MBart50Vocab, MarianVocab, OpenAiGptVocab,
PegasusVocab, ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab,
XLNetVocab,
AlbertVocab, BertVocab, FNetVocab, Gpt2Vocab, M2M100Vocab, MBart50Vocab, MarianVocab,
OpenAiGptVocab, PegasusVocab, ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab,
XLMRobertaVocab, XLNetVocab,
};
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
use serde::{Deserialize, Serialize};
@ -73,6 +74,7 @@ pub enum ModelType {
GPTNeo,
MBart,
M2M100,
FNet,
}
/// # Abstraction that holds a model configuration, can be of any of the supported models
@ -111,6 +113,8 @@ pub enum ConfigOption {
MBart(MBartConfig),
/// M2M100 configuration
M2M100(M2M100Config),
/// FNet configuration
FNet(FNetConfig),
}
/// # Abstraction that holds a particular tokenizer, can be of any of the supported models
@ -143,6 +147,8 @@ pub enum TokenizerOption {
MBart50(MBart50Tokenizer),
/// M2M100 Tokenizer
M2M100(M2M100Tokenizer),
/// FNet Tokenizer
FNet(FNetTokenizer),
}
impl ConfigOption {
@ -168,6 +174,7 @@ impl ConfigOption {
ModelType::Pegasus => ConfigOption::Pegasus(PegasusConfig::from_file(path)),
ModelType::MBart => ConfigOption::MBart(MBartConfig::from_file(path)),
ModelType::M2M100 => ConfigOption::M2M100(M2M100Config::from_file(path)),
ModelType::FNet => ConfigOption::FNet(FNetConfig::from_file(path)),
}
}
@ -225,6 +232,10 @@ impl ConfigOption {
.id2label
.as_ref()
.expect("No label dictionary (id2label) provided in configuration file"),
Self::FNet(config) => config
.id2label
.as_ref()
.expect("No label dictionary (id2label) provided in configuration file"),
Self::T5(_) => panic!("T5 does not use a label mapping"),
Self::GPT2(_) => panic!("GPT2 does not use a label mapping"),
Self::GPTNeo(_) => panic!("GPT-Neo does not use a label mapping"),
@ -251,6 +262,7 @@ impl ConfigOption {
Self::GPTNeo(config) => Some(config.max_position_embeddings),
Self::MBart(config) => Some(config.max_position_embeddings),
Self::M2M100(config) => Some(config.max_position_embeddings),
Self::FNet(config) => Some(config.max_position_embeddings),
}
}
}
@ -469,6 +481,11 @@ impl TokenizerOption {
lower_case,
)?)
}
ModelType::FNet => TokenizerOption::FNet(FNetTokenizer::from_file(
vocab_path,
lower_case,
strip_accents.unwrap_or(false),
)?),
};
Ok(tokenizer)
}
@ -490,6 +507,7 @@ impl TokenizerOption {
Self::Pegasus(_) => ModelType::Pegasus,
Self::MBart50(_) => ModelType::MBart,
Self::M2M100(_) => ModelType::M2M100,
Self::FNet(_) => ModelType::FNet,
}
}
@ -603,6 +621,13 @@ impl TokenizerOption {
truncation_strategy,
stride,
),
Self::FNet(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
}
}
@ -713,6 +738,13 @@ impl TokenizerOption {
truncation_strategy,
stride,
),
Self::FNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
}
}
@ -768,6 +800,9 @@ impl TokenizerOption {
Self::M2M100(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::FNet(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
}
}
@ -788,6 +823,7 @@ impl TokenizerOption {
Self::Pegasus(ref tokenizer) => tokenizer.tokenize(text),
Self::MBart50(ref tokenizer) => tokenizer.tokenize(text),
Self::M2M100(ref tokenizer) => tokenizer.tokenize(text),
Self::FNet(ref tokenizer) => tokenizer.tokenize(text),
}
}
@ -808,6 +844,7 @@ impl TokenizerOption {
Self::Pegasus(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::MBart50(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::M2M100(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::FNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
}
}
@ -837,6 +874,7 @@ impl TokenizerOption {
Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::FNet(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
}
}
@ -890,6 +928,9 @@ impl TokenizerOption {
Self::M2M100(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::FNet(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
}
}
@ -956,6 +997,10 @@ impl TokenizerOption {
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::FNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
};
TokenizedInput {
token_ids: token_ids_with_special_tokens.token_ids,
@ -989,6 +1034,7 @@ impl TokenizerOption {
Self::Pegasus(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::MBart50(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::M2M100(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::FNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
}
}
@ -1051,6 +1097,10 @@ impl TokenizerOption {
.special_values
.get(M2M100Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::FNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
}
}
@ -1123,6 +1173,12 @@ impl TokenizerOption {
.get(M2M100Vocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::FNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Reformer(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
@ -1180,6 +1236,12 @@ impl TokenizerOption {
.get(M2M100Vocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::FNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Marian(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,

View File

@ -51,6 +51,7 @@ use crate::distilbert::{
DistilBertConfigResources, DistilBertForQuestionAnswering, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::fnet::FNetForQuestionAnswering;
use crate::longformer::LongformerForQuestionAnswering;
use crate::mobilebert::MobileBertForQuestionAnswering;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
@ -277,6 +278,8 @@ pub enum QuestionAnsweringOption {
Reformer(ReformerForQuestionAnswering),
/// Longformer for Question Answering
Longformer(LongformerForQuestionAnswering),
/// FNet for Question Answering
FNet(FNetForQuestionAnswering),
}
impl QuestionAnsweringOption {
@ -396,6 +399,17 @@ impl QuestionAnsweringOption {
))
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(QuestionAnsweringOption::FNet(
FNetForQuestionAnswering::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a FNetConfig for FNet!".to_string(),
))
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"QuestionAnswering not implemented for {:?}!",
model_type
@ -415,6 +429,7 @@ impl QuestionAnsweringOption {
Self::XLNet(_) => ModelType::XLNet,
Self::Reformer(_) => ModelType::Reformer,
Self::Longformer(_) => ModelType::Longformer,
Self::FNet(_) => ModelType::FNet,
}
}
@ -468,6 +483,12 @@ impl QuestionAnsweringOption {
.expect("Error in reformer forward pass");
(outputs.start_logits, outputs.end_logits)
}
// FNet: run the forward pass and return the (start, end) logits; panics with the
// message below if the forward pass returns an error.
// NOTE(review): the three arguments between `input_ids` and `train` are hard-coded to
// `None`, whereas the FNet arms of the sequence- and token-classification pipelines pass
// `token_type_ids, position_ids, input_embeds` in those positions. If the enclosing
// function receives input embeddings, they are ignored here — confirm this is intended.
Self::FNet(ref model) => {
let outputs = model
.forward_t(input_ids, None, None, None, train)
.expect("Error in fnet forward pass");
(outputs.start_logits, outputs.end_logits)
}
}
}
}

View File

@ -75,7 +75,7 @@ pub struct Sentiment {
pub score: f64,
}
type SentimentConfig = SequenceClassificationConfig;
pub type SentimentConfig = SequenceClassificationConfig;
/// # SentimentClassifier to perform sentiment analysis
pub struct SentimentModel {

View File

@ -66,6 +66,7 @@ use crate::distilbert::{
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::fnet::FNetForSequenceClassification;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
@ -197,6 +198,8 @@ pub enum SequenceClassificationOption {
Reformer(ReformerForSequenceClassification),
/// Longformer for Sequence Classification
Longformer(LongformerForSequenceClassification),
/// FNet for Sequence Classification
FNet(FNetForSequenceClassification),
}
impl SequenceClassificationOption {
@ -327,6 +330,17 @@ impl SequenceClassificationOption {
))
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(SequenceClassificationOption::FNet(
FNetForSequenceClassification::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a FNetConfig for FNet!".to_string(),
))
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Sequence Classification not implemented for {:?}!",
model_type
@ -347,6 +361,7 @@ impl SequenceClassificationOption {
Self::Bart(_) => ModelType::Bart,
Self::Reformer(_) => ModelType::Reformer,
Self::Longformer(_) => ModelType::Longformer,
Self::FNet(_) => ModelType::FNet,
}
}
@ -455,6 +470,12 @@ impl SequenceClassificationOption {
.expect("Error in Longformer forward pass.")
.logits
}
Self::FNet(ref model) => {
model
.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
.expect("Error in FNet forward pass.")
.logits
}
}
}
}

View File

@ -118,6 +118,7 @@ use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::distilbert::DistilBertForTokenClassification;
use crate::electra::ElectraForTokenClassification;
use crate::fnet::FNetForTokenClassification;
use crate::longformer::LongformerForTokenClassification;
use crate::mobilebert::MobileBertForTokenClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
@ -317,6 +318,8 @@ pub enum TokenClassificationOption {
XLNet(XLNetForTokenClassification),
/// Longformer for Token Classification
Longformer(LongformerForTokenClassification),
/// FNet for Token Classification
FNet(FNetForTokenClassification),
}
impl TokenClassificationOption {
@ -436,6 +439,17 @@ impl TokenClassificationOption {
))
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(TokenClassificationOption::FNet(
FNetForTokenClassification::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply an FNetConfig for FNet!".to_string(),
))
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Token classification not implemented for {:?}!",
model_type
@ -455,6 +469,7 @@ impl TokenClassificationOption {
Self::Albert(_) => ModelType::Albert,
Self::XLNet(_) => ModelType::XLNet,
Self::Longformer(_) => ModelType::Longformer,
Self::FNet(_) => ModelType::FNet,
}
}
@ -556,6 +571,12 @@ impl TokenClassificationOption {
.expect("Error in longformer forward_t")
.logits
}
Self::FNet(ref model) => {
model
.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
.expect("Error in fnet forward_t")
.logits
}
}
}
}