Registration of DeBERTa to eligible pipelines

Guillaume Becquin 2021-12-10 08:52:34 +01:00
parent 867a76774b
commit 71f216598f
6 changed files with 175 additions and 4 deletions
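
In short, this commit does two things: it renames the second argument of DebertaForQuestionAnswering::forward_t from `mask` to `attention_mask` for consistency with the other model APIs, and it registers DeBERTa throughout the pipeline plumbing: new `Deberta` variants in the ModelType, ConfigOption and TokenizerOption registries, plus constructor and dispatch arms in the question answering, sequence classification, token classification and zero-shot classification pipelines.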

View File

@@ -1092,7 +1092,7 @@ impl DebertaForQuestionAnswering {
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
-mask: Option<&Tensor>,
+attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
@@ -1100,7 +1100,7 @@ impl DebertaForQuestionAnswering {
) -> Result<DebertaQuestionAnsweringOutput, RustBertError> {
let base_model_output = self.deberta.forward_t(
input_ids,
-mask,
+attention_mask,
token_type_ids,
position_ids,
input_embeds,
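
These two hunks touch the DeBERTa question answering model itself: the second parameter of `forward_t` is renamed from `mask` to `attention_mask`, matching the naming used by the other models in the crate. The parameter is positional, so existing call sites are unaffected. A minimal call sketch, assuming `model` is an already constructed DebertaForQuestionAnswering and `input_ids` / `attention_mask` are tch::Tensor values prepared by the caller (none of these bindings are part of the commit):

// Sketch only: `model`, `input_ids` and `attention_mask` are assumed bindings.
let output = model
    .forward_t(
        Some(&input_ids),
        Some(&attention_mask), // formerly named `mask`
        None,                  // token_type_ids
        None,                  // position_ids
        None,                  // input_embeds
        false,                 // train
    )
    .expect("DeBERTa forward pass failed");
let (start_logits, end_logits) = (output.start_logits, output.end_logits);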

View File

@@ -20,6 +20,7 @@ use crate::albert::AlbertConfig;
use crate::bart::BartConfig;
use crate::bert::BertConfig;
use crate::common::error::RustBertError;
use crate::deberta::DebertaConfig;
use crate::distilbert::DistilBertConfig;
use crate::electra::ElectraConfig;
use crate::fnet::FNetConfig;
@@ -36,8 +37,8 @@ use crate::t5::T5Config;
use crate::xlnet::XLNetConfig;
use crate::Config;
use rust_tokenizers::tokenizer::{
-AlbertTokenizer, BertTokenizer, FNetTokenizer, Gpt2Tokenizer, M2M100Tokenizer,
-MBart50Tokenizer, MarianTokenizer, MultiThreadedTokenizer, OpenAiGptTokenizer,
+AlbertTokenizer, BertTokenizer, DeBERTaTokenizer, FNetTokenizer, Gpt2Tokenizer,
+M2M100Tokenizer, MBart50Tokenizer, MarianTokenizer, MultiThreadedTokenizer, OpenAiGptTokenizer,
PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer,
Tokenizer, TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
};
@@ -57,6 +58,7 @@ pub enum ModelType {
Bart,
Bert,
DistilBert,
Deberta,
Roberta,
XLMRoberta,
Electra,
@@ -85,6 +87,8 @@ pub enum ConfigOption {
Bert(BertConfig),
/// DistilBert configuration
DistilBert(DistilBertConfig),
/// DeBERTa configuration
Deberta(DebertaConfig),
/// Electra configuration
Electra(ElectraConfig),
/// Marian configuration
@@ -121,6 +125,8 @@ pub enum ConfigOption {
pub enum TokenizerOption {
/// Bert Tokenizer
Bert(BertTokenizer),
/// DeBERTa Tokenizer
Deberta(DeBERTaTokenizer),
/// Roberta Tokenizer
Roberta(RobertaTokenizer),
/// XLMRoberta Tokenizer
@@ -159,6 +165,7 @@ impl ConfigOption {
ModelType::Bert | ModelType::Roberta | ModelType::XLMRoberta => {
ConfigOption::Bert(BertConfig::from_file(path))
}
ModelType::Deberta => ConfigOption::Deberta(DebertaConfig::from_file(path)),
ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
ModelType::Marian => ConfigOption::Marian(BartConfig::from_file(path)),
@@ -188,6 +195,10 @@ impl ConfigOption {
.id2label
.as_ref()
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Deberta(config) => config
.id2label
.as_ref()
.expect("No label dictionary (id2label) provided in configuration file"),
Self::DistilBert(config) => config
.id2label
.as_ref()
@@ -247,6 +258,7 @@ impl ConfigOption {
match self {
Self::Bart(config) => Some(config.max_position_embeddings),
Self::Bert(config) => Some(config.max_position_embeddings),
Self::Deberta(config) => Some(config.max_position_embeddings),
Self::DistilBert(config) => Some(config.max_position_embeddings),
Self::Electra(config) => Some(config.max_position_embeddings),
Self::Marian(config) => Some(config.max_position_embeddings),
@@ -297,6 +309,26 @@ impl TokenizerOption {
strip_accents.unwrap_or(lower_case),
)?)
}
ModelType::Deberta => {
if strip_accents.is_some() {
return Err(RustBertError::InvalidConfigurationError(format!(
"Optional input `strip_accents` set to value {} but cannot be used by {:?}",
strip_accents.unwrap(),
model_type
)));
}
if add_prefix_space.is_some() {
return Err(RustBertError::InvalidConfigurationError(
format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
add_prefix_space.unwrap(),
model_type)));
}
TokenizerOption::Deberta(DeBERTaTokenizer::from_file(
vocab_path,
merges_path.expect("No merges specified!"),
lower_case,
)?)
}
ModelType::Roberta | ModelType::Bart | ModelType::Longformer => {
if strip_accents.is_some() {
return Err(RustBertError::InvalidConfigurationError(format!(
@@ -494,6 +526,7 @@ impl TokenizerOption {
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Deberta(_) => ModelType::Deberta,
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::XLMRoberta,
Self::Marian(_) => ModelType::Marian,
@@ -530,6 +563,13 @@ impl TokenizerOption {
truncation_strategy,
stride,
),
Self::Deberta(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
@@ -647,6 +687,13 @@ impl TokenizerOption {
truncation_strategy,
stride,
),
Self::Deberta(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
@@ -761,6 +808,9 @@ impl TokenizerOption {
Self::Bert(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::Deberta(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::Roberta(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
@@ -810,6 +860,7 @@ impl TokenizerOption {
pub fn tokenize(&self, text: &str) -> Vec<String> {
match *self {
Self::Bert(ref tokenizer) => tokenizer.tokenize(text),
Self::Deberta(ref tokenizer) => tokenizer.tokenize(text),
Self::Roberta(ref tokenizer) => tokenizer.tokenize(text),
Self::Marian(ref tokenizer) => tokenizer.tokenize(text),
Self::T5(ref tokenizer) => tokenizer.tokenize(text),
@@ -831,6 +882,7 @@ impl TokenizerOption {
pub fn tokenize_with_offsets(&self, text: &str) -> TokensWithOffsets {
match *self {
Self::Bert(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::Deberta(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::Roberta(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::Marian(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::T5(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
@@ -855,6 +907,7 @@ impl TokenizerOption {
{
match *self {
Self::Bert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::Deberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::Marian(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::T5(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
@@ -889,6 +942,9 @@ impl TokenizerOption {
Self::Bert(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::Deberta(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::Roberta(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
@@ -945,6 +1001,10 @@ impl TokenizerOption {
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::Deberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
@@ -1021,6 +1081,7 @@ impl TokenizerOption {
{
match *self {
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Deberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
@@ -1045,6 +1106,10 @@ impl TokenizerOption {
.special_values
.get(BertVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Deberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Roberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::unknown_value())
@@ -1113,6 +1178,12 @@ impl TokenizerOption {
.get(BertVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Deberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
@@ -1194,6 +1265,12 @@ impl TokenizerOption {
.get(BertVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Deberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
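
The hunks above wire DeBERTa into the shared pipeline registry: a `Deberta` variant is added to ModelType, ConfigOption and TokenizerOption, and every tokenizer dispatch method (encode, decode, tokenize, special-token lookups, and so on) gains a matching arm. Two constraints are encoded in the tokenizer loader arm: DeBERTa uses a BPE tokenizer, so a merges file is mandatory (the arm panics with "No merges specified!" otherwise), while `strip_accents` and `add_prefix_space` are rejected with an InvalidConfigurationError. A minimal sketch of loading the tokenizer through the registry's `TokenizerOption::from_file`; the file paths and the sample sentence are hypothetical:

use rust_bert::pipelines::common::{ModelType, TokenizerOption};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = TokenizerOption::from_file(
        ModelType::Deberta,
        "deberta/vocab.json",       // hypothetical vocabulary path
        Some("deberta/merges.txt"), // required for DeBERTa's BPE tokenizer
        false,                      // lower_case
        None,                       // strip_accents: must stay None for DeBERTa
        None,                       // add_prefix_space: must stay None for DeBERTa
    )?;
    let tokens = tokenizer.tokenize("DeBERTa is now registered.");
    println!("{:?}", tokens);
    Ok(())
}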

View File

@@ -47,6 +47,7 @@ use crate::albert::AlbertForQuestionAnswering;
use crate::bert::BertForQuestionAnswering;
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::deberta::DebertaForQuestionAnswering;
use crate::distilbert::{
DistilBertConfigResources, DistilBertForQuestionAnswering, DistilBertModelResources,
DistilBertVocabResources,
@@ -264,6 +265,8 @@ impl Default for QuestionAnsweringConfig {
pub enum QuestionAnsweringOption {
/// Bert for Question Answering
Bert(BertForQuestionAnswering),
/// DeBERTa for Question Answering
Deberta(DebertaForQuestionAnswering),
/// DistilBert for Question Answering
DistilBert(DistilBertForQuestionAnswering),
/// MobileBert for Question Answering
@@ -313,6 +316,17 @@ impl QuestionAnsweringOption {
))
}
}
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
Ok(QuestionAnsweringOption::Deberta(
DebertaForQuestionAnswering::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a DebertaConfig for DeBERTa!".to_string(),
))
}
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
Ok(QuestionAnsweringOption::DistilBert(
@@ -423,6 +437,7 @@ impl QuestionAnsweringOption {
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Deberta(_) => ModelType::Deberta,
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::XLMRoberta,
Self::DistilBert(_) => ModelType::DistilBert,
@@ -448,6 +463,12 @@ impl QuestionAnsweringOption {
let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);
(outputs.start_logits, outputs.end_logits)
}
Self::Deberta(ref model) => {
let outputs = model
.forward_t(input_ids, mask, None, None, input_embeds, train)
.expect("Error in Deberta forward_t");
(outputs.start_logits, outputs.end_logits)
}
Self::DistilBert(ref model) => {
let outputs = model
.forward_t(input_ids, mask, input_embeds, train)
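
With the registry arms in place, the question answering pipeline can dispatch to DebertaForQuestionAnswering. Note that the DeBERTa forward pass returns a Result where the other backends are infallible, hence the `expect` in the dispatch arm. A usage sketch, assuming a DeBERTa checkpoint fine-tuned for extractive QA is available locally (all paths and the sample input are hypothetical):

use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
    QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{LocalResource, Resource};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let config = QuestionAnsweringConfig::new(
        ModelType::Deberta,
        Resource::Local(LocalResource { local_path: "deberta-qa/rust_model.ot".into() }),
        Resource::Local(LocalResource { local_path: "deberta-qa/config.json".into() }),
        Resource::Local(LocalResource { local_path: "deberta-qa/vocab.json".into() }),
        Some(Resource::Local(LocalResource { local_path: "deberta-qa/merges.txt".into() })),
        false, // lower_case
        None,  // strip_accents
        None,  // add_prefix_space
    );
    let qa_model = QuestionAnsweringModel::new(config)?;
    let answers = qa_model.predict(
        &[QaInput {
            question: String::from("Where does Amy live?"),
            context: String::from("Amy lives in Amsterdam."),
        }],
        1,  // top_k answers per question
        32, // batch size
    );
    println!("{:?}", answers);
    Ok(())
}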

View File

@@ -62,6 +62,7 @@ use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification;
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::deberta::DebertaForSequenceClassification;
use crate::distilbert::{
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
DistilBertVocabResources,
@@ -180,6 +181,8 @@ impl Default for SequenceClassificationConfig {
pub enum SequenceClassificationOption {
/// Bert for Sequence Classification
Bert(BertForSequenceClassification),
/// DeBERTa for Sequence Classification
Deberta(DebertaForSequenceClassification),
/// DistilBert for Sequence Classification
DistilBert(DistilBertModelClassifier),
/// MobileBert for Sequence Classification
@@ -231,6 +234,17 @@ impl SequenceClassificationOption {
))
}
}
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
Ok(SequenceClassificationOption::Deberta(
DebertaForSequenceClassification::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a DebertaConfig for DeBERTa!".to_string(),
))
}
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
Ok(SequenceClassificationOption::DistilBert(
@@ -352,6 +366,7 @@ impl SequenceClassificationOption {
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Deberta(_) => ModelType::Deberta,
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::Roberta,
Self::DistilBert(_) => ModelType::DistilBert,
@@ -400,6 +415,19 @@ impl SequenceClassificationOption {
)
.logits
}
Self::Deberta(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.expect("Error in Deberta forward_t")
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
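
Sequence classification follows the same pattern: a `Deberta` variant, a constructor arm that insists on a DebertaConfig, and a forward dispatch that unwraps the model's Result. The checkpoint's config.json must carry an `id2label` mapping, as the expect() added to the registry shows. A configuration sketch mirroring the question answering example; paths are hypothetical and the constructor argument order is assumed to match the other pipelines of this era of the crate:

use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::sequence_classification::{
    SequenceClassificationConfig, SequenceClassificationModel,
};
use rust_bert::resources::{LocalResource, Resource};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let config = SequenceClassificationConfig::new(
        ModelType::Deberta,
        Resource::Local(LocalResource { local_path: "deberta-cls/rust_model.ot".into() }),
        Resource::Local(LocalResource { local_path: "deberta-cls/config.json".into() }),
        Resource::Local(LocalResource { local_path: "deberta-cls/vocab.json".into() }),
        Some(Resource::Local(LocalResource { local_path: "deberta-cls/merges.txt".into() })),
        false, // lower_case
        None,  // strip_accents
        None,  // add_prefix_space
    );
    let model = SequenceClassificationModel::new(config)?;
    let labels = model.predict(&["DeBERTa slots neatly into this pipeline."]);
    println!("{:?}", labels);
    Ok(())
}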

View File

@@ -116,6 +116,7 @@ use crate::bert::{
};
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::deberta::DebertaForTokenClassification;
use crate::distilbert::DistilBertForTokenClassification;
use crate::electra::ElectraForTokenClassification;
use crate::fnet::FNetForTokenClassification;
@@ -302,6 +303,8 @@ impl Default for TokenClassificationConfig {
pub enum TokenClassificationOption {
/// Bert for Token Classification
Bert(BertForTokenClassification),
/// DeBERTa for Token Classification
Deberta(DebertaForTokenClassification),
/// DistilBert for Token Classification
DistilBert(DistilBertForTokenClassification),
/// MobileBert for Token Classification
@@ -461,6 +464,7 @@ impl TokenClassificationOption {
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Deberta(_) => ModelType::Deberta,
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::XLMRoberta,
Self::DistilBert(_) => ModelType::DistilBert,
@@ -495,6 +499,19 @@ impl TokenClassificationOption {
)
.logits
}
Self::Deberta(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.expect("Error in DeBERTa forward_t")
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
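
Token classification adds the same trio of variant, constructor arm and forward dispatch. One pipeline-specific detail: the configuration additionally takes a label aggregation strategy deciding how sub-token predictions are merged into word-level labels. A sketch under the same assumptions; the paths and input are hypothetical, and the trailing predict flags (consolidate sub-tokens, return special tokens) reflect this era of the crate as understood, not something shown in the diff:

use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::token_classification::{
    LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
};
use rust_bert::resources::{LocalResource, Resource};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let config = TokenClassificationConfig::new(
        ModelType::Deberta,
        Resource::Local(LocalResource { local_path: "deberta-ner/rust_model.ot".into() }),
        Resource::Local(LocalResource { local_path: "deberta-ner/config.json".into() }),
        Resource::Local(LocalResource { local_path: "deberta-ner/vocab.json".into() }),
        Some(Resource::Local(LocalResource { local_path: "deberta-ner/merges.txt".into() })),
        false,                        // lower_case
        None,                         // strip_accents
        None,                         // add_prefix_space
        LabelAggregationOption::Mode, // majority vote over sub-token labels
    );
    let model = TokenClassificationModel::new(config)?;
    let tokens = model.predict(&["Amy lives in Amsterdam."], true, false);
    println!("{:?}", tokens);
    Ok(())
}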

View File

@@ -104,6 +104,7 @@ use crate::bart::{
BartVocabResources,
};
use crate::bert::BertForSequenceClassification;
use crate::deberta::DebertaForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
@@ -211,6 +212,8 @@ impl Default for ZeroShotClassificationConfig {
pub enum ZeroShotClassificationOption {
/// Bart for Sequence Classification
Bart(BartForSequenceClassification),
/// DeBERTa for Sequence Classification
Deberta(DebertaForSequenceClassification),
/// Bert for Sequence Classification
Bert(BertForSequenceClassification),
/// DistilBert for Sequence Classification
@@ -258,6 +261,17 @@ impl ZeroShotClassificationOption {
))
}
}
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
Ok(ZeroShotClassificationOption::Deberta(
DebertaForSequenceClassification::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a DebertaConfig for DeBERTa!".to_string(),
))
}
}
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
Ok(ZeroShotClassificationOption::Bert(
@@ -357,6 +371,7 @@ impl ZeroShotClassificationOption {
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bart(_) => ModelType::Bart,
Self::Deberta(_) => ModelType::Deberta,
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::Roberta,
@@ -403,6 +418,19 @@ impl ZeroShotClassificationOption {
)
.logits
}
Self::Deberta(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.expect("Error in DeBERTa forward_t")
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
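
Finally, zero-shot classification reuses DebertaForSequenceClassification, so the supplied checkpoint should be fine-tuned on an NLI task (MNLI, for instance) for the entailment-based labelling scheme to be meaningful. A closing sketch in the same spirit; paths, inputs and candidate labels are hypothetical, and the constructor argument order is again assumed to mirror the other pipelines:

use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::zero_shot_classification::{
    ZeroShotClassificationConfig, ZeroShotClassificationModel,
};
use rust_bert::resources::{LocalResource, Resource};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let config = ZeroShotClassificationConfig::new(
        ModelType::Deberta,
        Resource::Local(LocalResource { local_path: "deberta-mnli/rust_model.ot".into() }),
        Resource::Local(LocalResource { local_path: "deberta-mnli/config.json".into() }),
        Resource::Local(LocalResource { local_path: "deberta-mnli/vocab.json".into() }),
        Some(Resource::Local(LocalResource { local_path: "deberta-mnli/merges.txt".into() })),
        false, // lower_case
        None,  // strip_accents
        None,  // add_prefix_space
    );
    let model = ZeroShotClassificationModel::new(config)?;
    let labels = model.predict(
        &["Who are you voting for in 2024?"],
        &["politics", "sports", "economics"],
        None, // use the default hypothesis template
        128,  // maximum sequence length
    );
    println!("{:?}", labels);
    Ok(())
}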