Tokenizer special token map update (#330)

* Updates for compatibility with tokenizers special token rework

* Updated mask pipeline methods

* Bumped version

* Fix clippy warnings
This commit is contained in:
guillaume-be 2023-01-30 17:53:18 +00:00 committed by GitHub
parent 80e0197e2c
commit 84561ec82b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
57 changed files with 536 additions and 659 deletions

View File

@ -2,6 +2,8 @@
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased]
## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
## [0.20.0] - 2023-01-21
## Added

View File

@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.20.0"
version = "0.20.1-alpha"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and language models"
@ -69,7 +69,7 @@ remote = ["cached-path", "dirs", "lazy_static"]
features = ["doc-only"]
[dependencies]
rust_tokenizers = "~7.0.2"
rust_tokenizers = "8.0.0"
tch = "~0.10.1"
serde_json = "1"
serde = { version = "1", features = ["derive"] }

View File

@ -16,7 +16,7 @@ async fn main() -> Result<()> {
"Classify this negative text".to_owned(),
];
let sentiments = classifier.predict(texts).await?;
println!("Results: {:?}", sentiments);
println!("Results: {sentiments:?}");
Ok(())
}

View File

@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
// Run model
let output = sequence_classification_model.predict(input);
for label in output {
println!("{:?}", label);
println!("{label:?}");
}
// Masked language model
@ -78,7 +78,7 @@ fn main() -> anyhow::Result<()> {
// Run model
let output = mask_language_model.predict(input)?;
for sentence_output in output {
println!("{:?}", sentence_output);
println!("{sentence_output:?}");
}
Ok(())

View File

@ -31,7 +31,7 @@ fn main() -> anyhow::Result<()> {
let output = conversation_model.generate_responses(&mut conversation_manager);
println!("{:?}", output);
println!("{output:?}");
let _ = conversation_manager
.get(&conversation_1_id)
@ -40,11 +40,11 @@ fn main() -> anyhow::Result<()> {
let output = conversation_model.generate_responses(&mut conversation_manager);
println!("{:?}", output);
println!("{output:?}");
let output = conversation_model.generate_responses(&mut conversation_manager);
println!("{:?}", output);
println!("{output:?}");
Ok(())
}

View File

@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
let output = model.generate(&[input_context], None);
for sentence in output {
println!("{:?}", sentence);
println!("{sentence:?}");
}
Ok(())
}

View File

@ -60,7 +60,7 @@ fn main() -> anyhow::Result<()> {
let output = model.generate(&[input_context_1, input_context_2], None);
for sentence in output {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())
}

View File

@ -55,7 +55,7 @@ fn main() -> anyhow::Result<()> {
let output = model.generate(&[input_context_1, input_context_2], None);
for sentence in output {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())
}

View File

@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
let output = model.generate(&[input_context], None);
for sentence in output {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())
}

View File

@ -39,7 +39,7 @@ fn main() -> anyhow::Result<()> {
// Run model
let output = mask_language_model.predict(input)?;
for sentence_output in output {
println!("{:?}", sentence_output);
println!("{sentence_output:?}");
}
Ok(())

View File

@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> {
// Run model
let output = ner_model.predict_full_entities(&input);
for entity in output {
println!("{:?}", entity);
println!("{entity:?}");
}
Ok(())

View File

@ -24,7 +24,7 @@ fn main() -> anyhow::Result<()> {
// Run model
let output = pos_model.predict(&input);
for (pos, pos_tag) in output[0].iter().enumerate() {
println!("{} - {:?}", pos, pos_tag);
println!("{pos} - {pos_tag:?}");
}
Ok(())

View File

@ -34,6 +34,6 @@ fn main() -> anyhow::Result<()> {
// Get answer
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers);
println!("{answers:?}");
Ok(())
}

View File

@ -50,6 +50,6 @@ fn main() -> anyhow::Result<()> {
// Get answer
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers);
println!("{answers:?}");
Ok(())
}

View File

@ -55,6 +55,6 @@ fn main() -> anyhow::Result<()> {
// Get answer
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers);
println!("{answers:?}");
Ok(())
}

View File

@ -12,6 +12,6 @@ fn main() -> anyhow::Result<()> {
// Generate Embeddings
let embeddings = model.encode(&sentences)?;
println!("{:?}", embeddings);
println!("{embeddings:?}");
Ok(())
}

View File

@ -32,6 +32,6 @@ fn main() -> anyhow::Result<()> {
// Generate Embeddings
let embeddings = model.encode(&sentences)?;
println!("{:?}", embeddings);
println!("{embeddings:?}");
Ok(())
}

View File

@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> {
// Run model
let output = sentiment_classifier.predict(input);
for sentiment in output {
println!("{:?}", sentiment);
println!("{sentiment:?}");
}
Ok(())

View File

@ -49,7 +49,7 @@ fn main() -> anyhow::Result<()> {
// Run model
let output = sentiment_classifier.predict(input);
for sentiment in output {
println!("{:?}", sentiment);
println!("{sentiment:?}");
}
Ok(())

View File

@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> {
// Run model
let output = sequence_classification_model.predict(input);
for label in output {
println!("{:?}", label);
println!("{label:?}");
}
Ok(())

View File

@ -29,7 +29,7 @@ fn main() -> anyhow::Result<()> {
let output = sequence_classification_model.predict_multilabel(&input, 0.05);
if let Ok(labels) = output {
for label in labels {
println!("{:?}", label);
println!("{label:?}");
}
}

View File

@ -73,7 +73,7 @@ about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())

View File

@ -68,7 +68,7 @@ about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())

View File

@ -70,7 +70,7 @@ about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())

View File

@ -56,7 +56,7 @@ about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())

View File

@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
let token_outputs = token_classification_model.predict(&input);
for token in token_outputs {
println!("{:?}", token);
println!("{token:?}");
}
Ok(())

View File

@ -32,7 +32,7 @@ fn main() -> anyhow::Result<()> {
let output = model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?;
for sentence in output {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())
}

View File

@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
for sentence in outputs {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())
}

View File

@ -49,7 +49,7 @@ fn main() -> anyhow::Result<()> {
let output = model.translate(&[input_context_1, input_context_2], None, None)?;
for sentence in output {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())
}

View File

@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
for sentence in outputs {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())
}

View File

@ -56,7 +56,7 @@ fn main() -> anyhow::Result<()> {
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
for sentence in outputs {
println!("{}", sentence);
println!("{sentence}");
}
Ok(())
}

View File

@ -27,13 +27,13 @@ fn main() -> anyhow::Result<()> {
[input_sentence, input_sequence_2],
candidate_labels,
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)
format!("This example is about {label}.")
})),
128,
)
.unwrap();
println!("{:?}", output);
println!("{output:?}");
Ok(())
}

View File

@ -26,7 +26,7 @@ use crate::pipelines::generation_utils::{
};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{RobertaTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{RobertaVocab, Vocab};
use rust_tokenizers::vocab::RobertaVocab;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
@ -1263,9 +1263,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
let pad_token = match pad_token_id {
Some(value) => value,
None => self
._get_tokenizer()
.convert_tokens_to_ids(&[RobertaVocab::unknown_value()])[0],
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids

View File

@ -14,8 +14,7 @@ pub(crate) fn get_positive_infinity(kind: Kind) -> Result<Scalar, RustBertError>
Kind::Double => Scalar::float(f64::INFINITY),
_ => {
return Err(RustBertError::ValueError(format!(
"Type not supported: attempted to get positive infinity for {:?}",
kind
"Type not supported: attempted to get positive infinity for {kind:?}",
)))
}
})
@ -34,8 +33,7 @@ pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError>
Kind::Double => Scalar::float(f64::NEG_INFINITY),
_ => {
return Err(RustBertError::ValueError(format!(
"Type not supported: attempted to get negative infinity for {:?}",
kind
"Type not supported: attempted to get negative infinity for {kind:?}",
)))
}
})

View File

@ -111,8 +111,7 @@ impl FromStr for PositionAttentionType {
"c2p" => Ok(PositionAttentionType::c2p),
"p2p" => Ok(PositionAttentionType::p2p),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Position attention type `{}` not in accepted variants (`p2c`, `c2p`, `p2p`)",
s
"Position attention type `{s}` not in accepted variants (`p2c`, `c2p`, `p2p`)",
))),
}
}

View File

@ -118,8 +118,7 @@ impl FromStr for NormRelEmbedType {
match s {
"layer_norm" => Ok(NormRelEmbedType::layer_norm),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Layer normalization type `{}` not in accepted variants (`layer_norm`)",
s
"Layer normalization type `{s}` not in accepted variants (`layer_norm`)",
))),
}
}

View File

@ -24,7 +24,7 @@ use crate::pipelines::generation_utils::{
use crate::pipelines::translation::Language;
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{M2M100Vocab, Vocab};
use rust_tokenizers::vocab::M2M100Vocab;
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
@ -804,9 +804,7 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
let pad_token = match pad_token_id {
Some(value) => value,
None => self
._get_tokenizer()
.convert_tokens_to_ids(&[M2M100Vocab::unknown_value()])[0],
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids

View File

@ -885,7 +885,9 @@ impl MarianGenerator {
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id =
Some(tokenizer.convert_tokens_to_ids(&[MarianVocab::pad_value()])[0]);
Some(tokenizer.get_pad_id().ok_or(RustBertError::TokenizerError(
"The tokenizer must contain a pad token ID to be used as BOS".to_string(),
))?);
let max_position_embeddings = config.max_position_embeddings;
Ok(MarianGenerator {

View File

@ -25,7 +25,7 @@ use crate::pipelines::generation_utils::{
use crate::pipelines::translation::Language;
use crate::{Activation, Config, RustBertError};
use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{MBart50Vocab, Vocab};
use rust_tokenizers::vocab::MBart50Vocab;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
@ -1059,9 +1059,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
let pad_token = match pad_token_id {
Some(value) => value,
None => self
._get_tokenizer()
.convert_tokens_to_ids(&[MBart50Vocab::unknown_value()])[0],
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids

View File

@ -780,7 +780,8 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
Some(value) => value,
None => self
._get_tokenizer()
.convert_tokens_to_ids(&[PegasusVocab::pad_value()])[0],
.get_pad_id()
.expect("A padding token must be provided to encode prompt texts."),
};
let token_ids = token_ids

View File

@ -46,11 +46,7 @@ use rust_tokenizers::tokenizer::{
OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer,
T5Tokenizer, Tokenizer, TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
};
use rust_tokenizers::vocab::{
AlbertVocab, BertVocab, DeBERTaV2Vocab, DeBERTaVocab, FNetVocab, Gpt2Vocab, M2M100Vocab,
MBart50Vocab, MarianVocab, OpenAiGptVocab, PegasusVocab, ProphetNetVocab, ReformerVocab,
RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab, XLNetVocab,
};
use rust_tokenizers::vocab::Vocab;
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@ -1275,174 +1271,144 @@ impl TokenizerOption {
/// Interface method
pub fn get_unk_id(&self) -> i64 {
match *self {
Self::Bert(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Deberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::DebertaV2(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaV2Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Roberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Bart(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::XLMRoberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Marian(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MarianVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::T5(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(T5Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Albert(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::XLNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::GPT2(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(Gpt2Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::OpenAiGpt(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(OpenAiGptVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Reformer(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ReformerVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::ProphetNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ProphetNetVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Pegasus(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(PegasusVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::MBart50(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MBart50Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::M2M100(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(M2M100Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::FNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Bert(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::Deberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::DebertaV2(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::Roberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::Bart(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::XLMRoberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::Marian(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::T5(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::Albert(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::XLNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::GPT2(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::OpenAiGpt(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::Reformer(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::ProphetNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::Pegasus(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::MBart50(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::M2M100(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
Self::FNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
}
}
/// Interface method
pub fn get_pad_id(&self) -> Option<i64> {
match *self {
Self::Bert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Deberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::DebertaV2(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaV2Vocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Bart(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::pad_value())
.unwrap_or(&1),
),
Self::XLMRoberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Marian(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MarianVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::T5(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(T5Vocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::XLNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::ProphetNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ProphetNetVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Pegasus(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(PegasusVocab::pad_value())
.unwrap_or(&0),
),
Self::MBart50(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MBart50Vocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::M2M100(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(M2M100Vocab::pad_value())
.unwrap_or(&1),
),
Self::FNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Bert(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::Deberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::DebertaV2(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::Roberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::Bart(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::XLMRoberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::Marian(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::T5(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::Albert(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::XLNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::ProphetNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::Pegasus(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::MBart50(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::M2M100(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::FNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
Self::Reformer(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
@ -1452,78 +1418,54 @@ impl TokenizerOption {
/// Interface method
pub fn get_sep_id(&self) -> Option<i64> {
match *self {
Self::Bert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Deberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::DebertaV2(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaV2Vocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Bart(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::XLMRoberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::XLNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::ProphetNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ProphetNetVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::MBart50(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MBart50Vocab::sep_value())
.unwrap_or(&1),
),
Self::M2M100(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(M2M100Vocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::FNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Bert(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::Deberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::DebertaV2(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::Roberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::Bart(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::XLMRoberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::Albert(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::XLNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::ProphetNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::MBart50(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::M2M100(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::FNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
Self::Marian(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,
@ -1536,78 +1478,54 @@ impl TokenizerOption {
/// Interface method
pub fn get_mask_id(&self) -> Option<i64> {
match *self {
Self::Bert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Deberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::DebertaV2(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaV2Vocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Bart(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::XLMRoberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::XLNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::ProphetNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ProphetNetVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::MBart50(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MBart50Vocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::FNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Pegasus(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(PegasusVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Bert(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::Deberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::DebertaV2(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::Roberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::Bart(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::XLMRoberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::Albert(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::XLNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::ProphetNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::MBart50(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::FNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::Pegasus(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
Self::Marian(_) => None,
Self::M2M100(_) => None,
Self::T5(_) => None,
@ -1620,84 +1538,90 @@ impl TokenizerOption {
/// Interface method
pub fn get_mask_value(&self) -> Option<&str> {
match self {
Self::Bert(_) => Some(BertVocab::mask_value()),
Self::Deberta(_) => Some(DeBERTaVocab::mask_value()),
Self::DebertaV2(_) => Some(DeBERTaV2Vocab::mask_value()),
Self::Roberta(_) => Some(RobertaVocab::mask_value()),
Self::Bart(_) => Some(RobertaVocab::mask_value()),
Self::XLMRoberta(_) => Some(XLMRobertaVocab::mask_value()),
Self::Albert(_) => Some(AlbertVocab::mask_value()),
Self::XLNet(_) => Some(XLNetVocab::mask_value()),
Self::ProphetNet(_) => Some(ProphetNetVocab::mask_value()),
Self::MBart50(_) => Some(MBart50Vocab::mask_value()),
Self::FNet(_er) => Some(FNetVocab::mask_value()),
Self::Bert(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::Deberta(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::DebertaV2(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::Roberta(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::Bart(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::XLMRoberta(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::Albert(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::XLNet(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::ProphetNet(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::MBart50(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::FNet(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::Pegasus(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
Self::M2M100(_) => None,
Self::Marian(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
Self::Reformer(_) => None,
Self::Pegasus(_) => None,
}
}
/// Interface method
pub fn get_bos_id(&self) -> Option<i64> {
match *self {
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::bos_value())
.expect("BOS token not found in vocabulary"),
),
Self::Bart(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::bos_value())
.unwrap_or(&0),
),
Self::DebertaV2(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaV2Vocab::bos_value())
.expect("BOS token not found in vocabulary"),
),
Self::XLMRoberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::bos_value())
.expect("BOS token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::bos_value())
.expect("BOS token not found in vocabulary"),
),
Self::XLNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::bos_value())
.expect("BOS token not found in vocabulary"),
),
Self::M2M100(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(M2M100Vocab::bos_value())
.unwrap_or(&0),
),
Self::GPT2(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(Gpt2Vocab::bos_value())
.expect("BOS token not found in vocabulary"),
),
Self::Deberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaVocab::bos_value())
.expect("BOS token not found in vocabulary"),
),
Self::Roberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_bos_value()))
}
Self::Bart(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_bos_value()))
}
Self::DebertaV2(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_bos_value()))
}
Self::XLMRoberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_bos_value()))
}
Self::Albert(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_bos_value()))
}
Self::XLNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_bos_value()))
}
Self::M2M100(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_bos_value()))
}
Self::GPT2(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_bos_value()))
}
Self::Deberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_bos_value()))
}
Self::MBart50(_) => Some(0),
Self::FNet(_) => None,
Self::Bert(_) => None,
@ -1713,90 +1637,62 @@ impl TokenizerOption {
/// Interface method
pub fn get_eos_id(&self) -> Option<i64> {
match *self {
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::eos_value())
.expect("EOS token not found in vocabulary"),
),
Self::Bart(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::eos_value())
.unwrap_or(&2),
),
Self::DebertaV2(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaV2Vocab::eos_value())
.expect("EOS token not found in vocabulary"),
),
Self::XLMRoberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::eos_value())
.expect("EOS token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::eos_value())
.expect("EOS token not found in vocabulary"),
),
Self::XLNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::eos_value())
.expect("EOS token not found in vocabulary"),
),
Self::MBart50(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MBart50Vocab::eos_value())
.unwrap_or(&2),
),
Self::M2M100(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(M2M100Vocab::eos_value())
.expect("EOS token not found in vocabulary"),
),
Self::GPT2(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(Gpt2Vocab::eos_value())
.unwrap_or(&2),
),
Self::Deberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaVocab::eos_value())
.expect("EOS token not found in vocabulary"),
),
Self::Marian(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MarianVocab::eos_value())
.expect("EOS token not found in vocabulary"),
),
Self::T5(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(T5Vocab::eos_value())
.expect("EOS token not found in vocabulary"),
),
Self::Reformer(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ReformerVocab::eos_value())
.expect("EOS token not found in vocabulary"),
),
Self::Pegasus(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(PegasusVocab::eos_value())
.unwrap_or(&1),
),
Self::Roberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::Bart(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::DebertaV2(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::XLMRoberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::Albert(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::XLNet(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::MBart50(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::M2M100(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::GPT2(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::Deberta(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::Marian(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::T5(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::Reformer(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::Pegasus(ref tokenizer) => {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
Self::FNet(_) => None,
Self::Bert(_) => None,
Self::ProphetNet(_) => None,

View File

@ -257,8 +257,7 @@ impl MaskedLanguageOption {
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Masked Language is not implemented for {:?}!",
model_type
"Masked Language is not implemented for {model_type:?}!",
))),
}
}

View File

@ -456,8 +456,7 @@ impl QuestionAnsweringOption {
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"QuestionAnswering not implemented for {:?}!",
model_type
"QuestionAnswering not implemented for {model_type:?}!",
))),
}
}

View File

@ -88,8 +88,7 @@ impl SentenceEmbeddingsBuilder<Local> {
ModelType::T5 => (model_dir.join("spiece.model"), None),
_ => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Unsupported transformer model {:?} for Sentence Embeddings",
transformer_type
"Unsupported transformer model {transformer_type:?} for Sentence Embeddings",
)));
}
};

View File

@ -408,7 +408,7 @@ mod serde_sentence_embeddings_module_type {
where
S: Serializer,
{
serializer.serialize_str(&format!("sentence_transformers.models.{:?}", module_type))
serializer.serialize_str(&format!("sentence_transformers.models.{module_type:?}"))
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<SentenceEmbeddingsModuleType, D::Error>
@ -430,7 +430,7 @@ mod serde_sentence_embeddings_module_type {
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_string())))
.transpose()
.map_err(de::Error::custom)?
.ok_or_else(|| format!("Invalid SentenceEmbeddingsModuleType: {}", s))
.ok_or_else(|| format!("Invalid SentenceEmbeddingsModuleType: {s}"))
.map_err(de::Error::custom)
}
}

View File

@ -102,7 +102,7 @@ where
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_lowercase())))
.transpose()
.map_err(de::Error::custom)?
.ok_or_else(|| format!("Invalid Activation: {}", activation))
.ok_or_else(|| format!("Invalid Activation: {activation}"))
.map_err(de::Error::custom)
}

View File

@ -65,8 +65,7 @@ impl SentenceEmbeddingsOption {
ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))),
_ => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Unsupported transformer model {:?} for Sentence Embeddings",
transformer_type
"Unsupported transformer model {transformer_type:?} for Sentence Embeddings"
)));
}
};

View File

@ -373,8 +373,7 @@ impl SequenceClassificationOption {
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Sequence Classification not implemented for {:?}!",
model_type
"Sequence Classification not implemented for {model_type:?}!",
))),
}
}

View File

@ -484,8 +484,7 @@ impl TokenClassificationOption {
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Token classification not implemented for {:?}!",
model_type
"Token classification not implemented for {model_type:?}!"
))),
}
}

View File

@ -372,8 +372,7 @@ impl TranslationModelBuilder {
}
(Some(model_type), _, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Automated translation model builder not implemented for {:?}",
model_type
"Automated translation model builder not implemented for {model_type:?}"
)));
}
};
@ -459,91 +458,92 @@ mod model_fetchers {
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
let (resources, source_languages, target_languages) =
if let (Some(source_languages), Some(target_languages)) =
(source_languages, target_languages)
{
match (source_languages.as_slice(), target_languages.as_slice()) {
([Language::English], [Language::German]) => {
get_marian_resources!(ENGLISH2GERMAN)
}
([Language::English], [Language::Russian]) => {
get_marian_resources!(ENGLISH2RUSSIAN)
}
([Language::English], [Language::Dutch]) => {
get_marian_resources!(ENGLISH2DUTCH)
}
([Language::English], [Language::ChineseMandarin]) => {
get_marian_resources!(ENGLISH2CHINESE)
}
([Language::English], [Language::Swedish]) => {
get_marian_resources!(ENGLISH2SWEDISH)
}
([Language::English], [Language::Arabic]) => {
get_marian_resources!(ENGLISH2ARABIC)
}
([Language::English], [Language::Hindi]) => {
get_marian_resources!(ENGLISH2HINDI)
}
([Language::English], [Language::Hebrew]) => {
get_marian_resources!(ENGLISH2HEBREW)
}
([Language::German], [Language::English]) => {
get_marian_resources!(GERMAN2ENGLISH)
}
([Language::German], [Language::French]) => {
get_marian_resources!(GERMAN2FRENCH)
}
([Language::French], [Language::German]) => {
get_marian_resources!(FRENCH2GERMAN)
}
([Language::Russian], [Language::English]) => {
get_marian_resources!(RUSSIAN2ENGLISH)
}
([Language::Dutch], [Language::English]) => {
get_marian_resources!(DUTCH2ENGLISH)
}
([Language::ChineseMandarin], [Language::English]) => {
get_marian_resources!(CHINESE2ENGLISH)
}
([Language::Swedish], [Language::English]) => {
get_marian_resources!(SWEDISH2ENGLISH)
}
([Language::Arabic], [Language::English]) => {
get_marian_resources!(ARABIC2ENGLISH)
}
([Language::Hindi], [Language::English]) => {
get_marian_resources!(HINDI2ENGLISH)
}
([Language::Hebrew], [Language::English]) => {
get_marian_resources!(HEBREW2ENGLISH)
}
([Language::English], languages)
if languages
.iter()
.all(|lang| MarianTargetLanguages::ENGLISH2ROMANCE.contains(lang)) =>
{
get_marian_resources!(ENGLISH2ROMANCE)
}
(languages, [Language::English])
if languages
.iter()
.all(|lang| MarianSourceLanguages::ROMANCE2ENGLISH.contains(lang)) =>
{
get_marian_resources!(ROMANCE2ENGLISH)
}
(_, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"No Pretrained Marian configuration found for {:?} to {:?} translation",
source_languages, target_languages
)));
}
let (resources, source_languages, target_languages) = if let (
Some(source_languages),
Some(target_languages),
) =
(source_languages, target_languages)
{
match (source_languages.as_slice(), target_languages.as_slice()) {
([Language::English], [Language::German]) => {
get_marian_resources!(ENGLISH2GERMAN)
}
} else {
return Err(RustBertError::InvalidConfigurationError(
"Source and target languages must be provided for Marian models".to_string(),
));
};
([Language::English], [Language::Russian]) => {
get_marian_resources!(ENGLISH2RUSSIAN)
}
([Language::English], [Language::Dutch]) => {
get_marian_resources!(ENGLISH2DUTCH)
}
([Language::English], [Language::ChineseMandarin]) => {
get_marian_resources!(ENGLISH2CHINESE)
}
([Language::English], [Language::Swedish]) => {
get_marian_resources!(ENGLISH2SWEDISH)
}
([Language::English], [Language::Arabic]) => {
get_marian_resources!(ENGLISH2ARABIC)
}
([Language::English], [Language::Hindi]) => {
get_marian_resources!(ENGLISH2HINDI)
}
([Language::English], [Language::Hebrew]) => {
get_marian_resources!(ENGLISH2HEBREW)
}
([Language::German], [Language::English]) => {
get_marian_resources!(GERMAN2ENGLISH)
}
([Language::German], [Language::French]) => {
get_marian_resources!(GERMAN2FRENCH)
}
([Language::French], [Language::German]) => {
get_marian_resources!(FRENCH2GERMAN)
}
([Language::Russian], [Language::English]) => {
get_marian_resources!(RUSSIAN2ENGLISH)
}
([Language::Dutch], [Language::English]) => {
get_marian_resources!(DUTCH2ENGLISH)
}
([Language::ChineseMandarin], [Language::English]) => {
get_marian_resources!(CHINESE2ENGLISH)
}
([Language::Swedish], [Language::English]) => {
get_marian_resources!(SWEDISH2ENGLISH)
}
([Language::Arabic], [Language::English]) => {
get_marian_resources!(ARABIC2ENGLISH)
}
([Language::Hindi], [Language::English]) => {
get_marian_resources!(HINDI2ENGLISH)
}
([Language::Hebrew], [Language::English]) => {
get_marian_resources!(HEBREW2ENGLISH)
}
([Language::English], languages)
if languages
.iter()
.all(|lang| MarianTargetLanguages::ENGLISH2ROMANCE.contains(lang)) =>
{
get_marian_resources!(ENGLISH2ROMANCE)
}
(languages, [Language::English])
if languages
.iter()
.all(|lang| MarianSourceLanguages::ROMANCE2ENGLISH.contains(lang)) =>
{
get_marian_resources!(ROMANCE2ENGLISH)
}
(_, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"No Pretrained Marian configuration found for {source_languages:?} to {target_languages:?} translation",
)));
}
}
} else {
return Err(RustBertError::InvalidConfigurationError(
"Source and target languages must be provided for Marian models".to_string(),
));
};
Ok(TranslationResources {
model_type: ModelType::Marian,

View File

@ -135,7 +135,7 @@ pub enum Language {
impl Display for Language {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", {
let input_string = format!("{:?}", self);
let input_string = format!("{self:?}");
let mut output: Vec<&str> = Vec::new();
let mut start: usize = 0;
@ -584,8 +584,7 @@ impl TranslationOption {
if let Some(source_language) = source_language {
if !supported_source_languages.contains(source_language) {
return Err(RustBertError::ValueError(format!(
"{} not in list of supported languages: {:?}",
source_language, supported_source_languages
"{source_language} not in list of supported languages: {supported_source_languages:?}",
)));
}
}
@ -593,8 +592,7 @@ impl TranslationOption {
if let Some(target_language) = target_language {
if !supported_target_languages.contains(target_language) {
return Err(RustBertError::ValueError(format!(
"{} not in list of supported languages: {:?}",
target_language, supported_target_languages
"{target_language} not in list of supported languages: {supported_target_languages:?}"
)));
}
}
@ -610,9 +608,8 @@ impl TranslationOption {
None => {
return Err(RustBertError::ValueError(format!(
"Missing target language for Marian \
(multiple languages supported by model: {:?}, \
(multiple languages supported by model: {supported_target_languages:?}, \
need to specify target language)",
supported_target_languages
)));
}
}
@ -653,9 +650,8 @@ impl TranslationOption {
None => {
return Err(RustBertError::ValueError(format!(
"Missing source language for MBart\
(multiple languages supported by model: {:?}, \
need to specify target language)",
supported_source_languages
(multiple languages supported by model: {supported_source_languages:?}, \
need to specify target language)"
)));
}
}
@ -670,9 +666,8 @@ impl TranslationOption {
} else {
return Err(RustBertError::ValueError(format!(
"Missing target language for MBart\
(multiple languages supported by model: {:?}, \
need to specify target language)",
supported_target_languages
(multiple languages supported by model: {supported_target_languages:?}, \
need to specify target language)"
)));
},
),
@ -681,8 +676,8 @@ impl TranslationOption {
Some(value) => {
let language_code = value.get_iso_639_1_code();
match language_code.len() {
2 => format!(">>{}.<< ", language_code),
3 => format!(">>{}<< ", language_code),
2 => format!(">>{language_code}.<< "),
3 => format!(">>{language_code}<< "),
_ => {
return Err(RustBertError::ValueError(
"Invalid ISO 639-3 code".to_string(),
@ -693,9 +688,8 @@ impl TranslationOption {
None => {
return Err(RustBertError::ValueError(format!(
"Missing source language for M2M100 \
(multiple languages supported by model: {:?}, \
need to specify target language)",
supported_source_languages
(multiple languages supported by model: {supported_source_languages:?}, \
need to specify target language)"
)));
}
}),
@ -704,8 +698,8 @@ impl TranslationOption {
Some(
model._get_tokenizer().convert_tokens_to_ids(&[
match language_code.len() {
2 => format!(">>{}.<<", language_code),
3 => format!(">>{}<<", language_code),
2 => format!(">>{language_code}.<<"),
3 => format!(">>{language_code}<<"),
_ => {
return Err(RustBertError::ValueError(
"Invalid ISO 639-3 code".to_string(),
@ -717,9 +711,8 @@ impl TranslationOption {
} else {
return Err(RustBertError::ValueError(format!(
"Missing target language for M2M100 \
(multiple languages supported by model: {:?}, \
(multiple languages supported by model: {supported_target_languages:?}, \
need to specify target language)",
supported_target_languages
)));
},
),

View File

@ -370,8 +370,7 @@ impl ZeroShotClassificationOption {
}
}
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Zero shot classification not implemented for {:?}!",
model_type
"Zero shot classification not implemented for {model_type:?}!",
))),
}
}
@ -604,7 +603,7 @@ impl ZeroShotClassificationModel {
None => labels
.as_ref()
.iter()
.map(|label| format!("This example is about {}.", label))
.map(|label| format!("This example is about {label}."))
.collect(),
};

View File

@ -14,7 +14,7 @@ use std::borrow::Borrow;
use std::collections::HashMap;
use rust_tokenizers::tokenizer::{ProphetNetTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{ProphetNetVocab, Vocab};
use rust_tokenizers::vocab::ProphetNetVocab;
use serde::{Deserialize, Serialize};
use tch::{nn, Kind, Tensor};
@ -1098,9 +1098,7 @@ impl
let pad_token = match pad_token_id {
Some(value) => value,
None => self
._get_tokenizer()
.convert_tokens_to_ids(&[ProphetNetVocab::unknown_value()])[0],
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids

View File

@ -548,7 +548,7 @@ impl ReformerModelWithLMHead {
if let Some(lsh_num_chunks_after) = config.lsh_num_chunks_after {
if config.attn_layers.contains(&AttentionType::lsh) & (lsh_num_chunks_after != 0) {
return Err(RustBertError::InvalidConfigurationError(
format!("For text generation using LSH attention ensure `config.lsh_num_chunks_after` is set to 0 (currently {})", lsh_num_chunks_after),
format!("For text generation using LSH attention ensure `config.lsh_num_chunks_after` is set to 0 (currently {lsh_num_chunks_after})"),
));
}
}
@ -556,7 +556,7 @@ impl ReformerModelWithLMHead {
if let Some(local_num_chunks_after) = config.local_num_chunks_after {
if config.attn_layers.contains(&AttentionType::local) & (local_num_chunks_after != 0) {
return Err(RustBertError::InvalidConfigurationError(
format!("For text generation using local attention ensure `config.local_num_chunks_after` is set to 0 (currently {})", local_num_chunks_after),
format!("For text generation using local attention ensure `config.local_num_chunks_after` is set to 0 (currently {local_num_chunks_after})"),
));
}
}

View File

@ -215,7 +215,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
[input_sentence, input_sequence_2],
candidate_labels,
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)
format!("This example is about {label}.")
})),
128,
)?;
@ -244,7 +244,7 @@ fn bart_zero_shot_classification_try_error() -> anyhow::Result<()> {
[],
[],
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)
format!("This example is about {label}.")
})),
128,
);
@ -276,7 +276,7 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
[input_sentence, input_sequence_2],
candidate_labels,
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)
format!("This example is about {label}.")
})),
128,
)?;
@ -319,7 +319,7 @@ fn bart_zero_shot_classification_multilabel_try_error() -> anyhow::Result<()> {
[],
[],
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)
format!("This example is about {label}.")
})),
128,
);

View File

@ -14,7 +14,7 @@ use rust_bert::pipelines::question_answering::{
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{RobertaVocab, Vocab};
use rust_tokenizers::vocab::Vocab;
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};
@ -67,7 +67,9 @@ fn longformer_masked_lm() -> anyhow::Result<()> {
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![
tokenizer.vocab().token_to_id(RobertaVocab::pad_value());
tokenizer
.vocab()
.token_to_id(tokenizer.vocab().get_pad_value());
max_len - input.len()
]);
input

View File

@ -63,8 +63,8 @@ about exoplanets like K2-18b."];
assert_eq!(
output[0],
"scientists have confirmed the presence of water in the atmosphere of k2 - 18b. \
[X_SEP] this is the first such discovery in a planet in its star's habitable zone. \
[X_SEP] the planet is 110 light - years from earth and has a star in the constellation leo."
this is the first such discovery in a planet in its star's habitable zone. \
the planet is 110 light - years from earth and has a star in the constellation leo."
);
Ok(())