diff --git a/CHANGELOG.md b/CHANGELOG.md index d9eb037..00b0c8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] +## Changed +- Bumped the rust_tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer. ## [0.20.0] - 2023-01-21 ## Added diff --git a/Cargo.toml b/Cargo.toml index 797b319..6a3e38b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-bert" -version = "0.20.0" +version = "0.20.1-alpha" authors = ["Guillaume Becquin "] edition = "2018" description = "Ready-to-use NLP pipelines and language models" @@ -69,7 +69,7 @@ remote = ["cached-path", "dirs", "lazy_static"] features = ["doc-only"] [dependencies] -rust_tokenizers = "~7.0.2" +rust_tokenizers = "8.0.0" tch = "~0.10.1" serde_json = "1" serde = { version = "1", features = ["derive"] } diff --git a/examples/async-sentiment.rs b/examples/async-sentiment.rs index e994a8b..9232773 100644 --- a/examples/async-sentiment.rs +++ b/examples/async-sentiment.rs @@ -16,7 +16,7 @@ async fn main() -> Result<()> { "Classify this negative text".to_owned(), ]; let sentiments = classifier.predict(texts).await?; - println!("Results: {:?}", sentiments); + println!("Results: {sentiments:?}"); Ok(()) } diff --git a/examples/codebert.rs b/examples/codebert.rs index d683484..d544297 100644 --- a/examples/codebert.rs +++ b/examples/codebert.rs @@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> { // Run model let output = sequence_classification_model.predict(input); for label in output { - println!("{:?}", label); + println!("{label:?}"); } // Masked language model @@ -78,7 +78,7 @@ fn main() -> anyhow::Result<()> { // Run model let output = mask_language_model.predict(input)?; for sentence_output in output { - println!("{:?}", sentence_output); + 
println!("{sentence_output:?}"); } Ok(()) diff --git a/examples/conversation.rs b/examples/conversation.rs index 4b84516..b9976c3 100644 --- a/examples/conversation.rs +++ b/examples/conversation.rs @@ -31,7 +31,7 @@ fn main() -> anyhow::Result<()> { let output = conversation_model.generate_responses(&mut conversation_manager); - println!("{:?}", output); + println!("{output:?}"); let _ = conversation_manager .get(&conversation_1_id) @@ -40,11 +40,11 @@ fn main() -> anyhow::Result<()> { let output = conversation_model.generate_responses(&mut conversation_manager); - println!("{:?}", output); + println!("{output:?}"); let output = conversation_model.generate_responses(&mut conversation_manager); - println!("{:?}", output); + println!("{output:?}"); Ok(()) } diff --git a/examples/generation_gpt2.rs b/examples/generation_gpt2.rs index 49931c2..8c86e0e 100644 --- a/examples/generation_gpt2.rs +++ b/examples/generation_gpt2.rs @@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> { let output = model.generate(&[input_context], None); for sentence in output { - println!("{:?}", sentence); + println!("{sentence:?}"); } Ok(()) } diff --git a/examples/generation_gpt_neo.rs b/examples/generation_gpt_neo.rs index 4b884cd..19a644b 100644 --- a/examples/generation_gpt_neo.rs +++ b/examples/generation_gpt_neo.rs @@ -60,7 +60,7 @@ fn main() -> anyhow::Result<()> { let output = model.generate(&[input_context_1, input_context_2], None); for sentence in output { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) } diff --git a/examples/generation_reformer.rs b/examples/generation_reformer.rs index 6cc01ca..8c25077 100644 --- a/examples/generation_reformer.rs +++ b/examples/generation_reformer.rs @@ -55,7 +55,7 @@ fn main() -> anyhow::Result<()> { let output = model.generate(&[input_context_1, input_context_2], None); for sentence in output { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) } diff --git a/examples/generation_xlnet.rs 
b/examples/generation_xlnet.rs index 1c63313..ff75321 100644 --- a/examples/generation_xlnet.rs +++ b/examples/generation_xlnet.rs @@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> { let output = model.generate(&[input_context], None); for sentence in output { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) } diff --git a/examples/masked_language.rs b/examples/masked_language.rs index bd3db54..d69f0ee 100644 --- a/examples/masked_language.rs +++ b/examples/masked_language.rs @@ -39,7 +39,7 @@ fn main() -> anyhow::Result<()> { // Run model let output = mask_language_model.predict(input)?; for sentence_output in output { - println!("{:?}", sentence_output); + println!("{sentence_output:?}"); } Ok(()) diff --git a/examples/named_entities_recognition.rs b/examples/named_entities_recognition.rs index a25975b..5b28907 100644 --- a/examples/named_entities_recognition.rs +++ b/examples/named_entities_recognition.rs @@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> { // Run model let output = ner_model.predict_full_entities(&input); for entity in output { - println!("{:?}", entity); + println!("{entity:?}"); } Ok(()) diff --git a/examples/part_of_speech_tagging.rs b/examples/part_of_speech_tagging.rs index 8656650..1abc2c3 100644 --- a/examples/part_of_speech_tagging.rs +++ b/examples/part_of_speech_tagging.rs @@ -24,7 +24,7 @@ fn main() -> anyhow::Result<()> { // Run model let output = pos_model.predict(&input); for (pos, pos_tag) in output[0].iter().enumerate() { - println!("{} - {:?}", pos, pos_tag); + println!("{pos} - {pos_tag:?}"); } Ok(()) diff --git a/examples/question_answering.rs b/examples/question_answering.rs index 1b3b89c..26012d5 100644 --- a/examples/question_answering.rs +++ b/examples/question_answering.rs @@ -34,6 +34,6 @@ fn main() -> anyhow::Result<()> { // Get answer let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32); - println!("{:?}", answers); + println!("{answers:?}"); Ok(()) } diff --git 
a/examples/question_answering_bert.rs b/examples/question_answering_bert.rs index 86d1884..892dc4c 100644 --- a/examples/question_answering_bert.rs +++ b/examples/question_answering_bert.rs @@ -50,6 +50,6 @@ fn main() -> anyhow::Result<()> { // Get answer let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32); - println!("{:?}", answers); + println!("{answers:?}"); Ok(()) } diff --git a/examples/question_answering_longformer.rs b/examples/question_answering_longformer.rs index c1063a4..790a0c5 100644 --- a/examples/question_answering_longformer.rs +++ b/examples/question_answering_longformer.rs @@ -55,6 +55,6 @@ fn main() -> anyhow::Result<()> { // Get answer let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32); - println!("{:?}", answers); + println!("{answers:?}"); Ok(()) } diff --git a/examples/sentence_embeddings.rs b/examples/sentence_embeddings.rs index 9bf8304..f280e6c 100644 --- a/examples/sentence_embeddings.rs +++ b/examples/sentence_embeddings.rs @@ -12,6 +12,6 @@ fn main() -> anyhow::Result<()> { // Generate Embeddings let embeddings = model.encode(&sentences)?; - println!("{:?}", embeddings); + println!("{embeddings:?}"); Ok(()) } diff --git a/examples/sentence_embeddings_local.rs b/examples/sentence_embeddings_local.rs index f4d43d0..8ea7cf8 100644 --- a/examples/sentence_embeddings_local.rs +++ b/examples/sentence_embeddings_local.rs @@ -32,6 +32,6 @@ fn main() -> anyhow::Result<()> { // Generate Embeddings let embeddings = model.encode(&sentences)?; - println!("{:?}", embeddings); + println!("{embeddings:?}"); Ok(()) } diff --git a/examples/sentiment_analysis.rs b/examples/sentiment_analysis.rs index f339bbf..309d016 100644 --- a/examples/sentiment_analysis.rs +++ b/examples/sentiment_analysis.rs @@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> { // Run model let output = sentiment_classifier.predict(input); for sentiment in output { - println!("{:?}", sentiment); + println!("{sentiment:?}"); } Ok(()) diff --git 
a/examples/sentiment_analysis_fnet.rs b/examples/sentiment_analysis_fnet.rs index 7d942ce..1032ddf 100644 --- a/examples/sentiment_analysis_fnet.rs +++ b/examples/sentiment_analysis_fnet.rs @@ -49,7 +49,7 @@ fn main() -> anyhow::Result<()> { // Run model let output = sentiment_classifier.predict(input); for sentiment in output { - println!("{:?}", sentiment); + println!("{sentiment:?}"); } Ok(()) diff --git a/examples/sequence_classification.rs b/examples/sequence_classification.rs index 200d4e5..0f15bac 100644 --- a/examples/sequence_classification.rs +++ b/examples/sequence_classification.rs @@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> { // Run model let output = sequence_classification_model.predict(input); for label in output { - println!("{:?}", label); + println!("{label:?}"); } Ok(()) diff --git a/examples/sequence_classification_multilabel.rs b/examples/sequence_classification_multilabel.rs index af27a28..42f26bf 100644 --- a/examples/sequence_classification_multilabel.rs +++ b/examples/sequence_classification_multilabel.rs @@ -29,7 +29,7 @@ fn main() -> anyhow::Result<()> { let output = sequence_classification_model.predict_multilabel(&input, 0.05); if let Ok(labels) = output { for label in labels { - println!("{:?}", label); + println!("{label:?}"); } } diff --git a/examples/summarization_bart.rs b/examples/summarization_bart.rs index 29ecad2..3d9ce0a 100644 --- a/examples/summarization_bart.rs +++ b/examples/summarization_bart.rs @@ -73,7 +73,7 @@ about exoplanets like K2-18b."]; // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b) let _output = summarization_model.summarize(&input); for sentence in _output { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) diff --git a/examples/summarization_pegasus.rs b/examples/summarization_pegasus.rs index 48564aa..f68546d 100644 --- a/examples/summarization_pegasus.rs +++ b/examples/summarization_pegasus.rs @@ -68,7 
+68,7 @@ about exoplanets like K2-18b."]; // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b) let _output = summarization_model.summarize(&input); for sentence in _output { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) diff --git a/examples/summarization_prophetnet.rs b/examples/summarization_prophetnet.rs index d4d3a2a..3825228 100644 --- a/examples/summarization_prophetnet.rs +++ b/examples/summarization_prophetnet.rs @@ -70,7 +70,7 @@ about exoplanets like K2-18b."]; // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b) let _output = summarization_model.summarize(&input); for sentence in _output { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) diff --git a/examples/summarization_t5.rs b/examples/summarization_t5.rs index 643d8c6..7087184 100644 --- a/examples/summarization_t5.rs +++ b/examples/summarization_t5.rs @@ -56,7 +56,7 @@ about exoplanets like K2-18b."]; // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b) let _output = summarization_model.summarize(&input); for sentence in _output { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) diff --git a/examples/token_classification.rs b/examples/token_classification.rs index 5226358..dd5223b 100644 --- a/examples/token_classification.rs +++ b/examples/token_classification.rs @@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> { let token_outputs = token_classification_model.predict(&input); for token in token_outputs { - println!("{:?}", token); + println!("{token:?}"); } Ok(()) diff --git a/examples/translation_builder.rs b/examples/translation_builder.rs index b457306..f4086c5 100644 --- a/examples/translation_builder.rs +++ b/examples/translation_builder.rs @@ -32,7 +32,7 @@ fn main() -> anyhow::Result<()> { let output = 
model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?; for sentence in output { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) } diff --git a/examples/translation_m2m100.rs b/examples/translation_m2m100.rs index dade34c..5efc1f6 100644 --- a/examples/translation_m2m100.rs +++ b/examples/translation_m2m100.rs @@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> { outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?); for sentence in outputs { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) } diff --git a/examples/translation_marian.rs b/examples/translation_marian.rs index 9a8156a..e62a2e5 100644 --- a/examples/translation_marian.rs +++ b/examples/translation_marian.rs @@ -49,7 +49,7 @@ fn main() -> anyhow::Result<()> { let output = model.translate(&[input_context_1, input_context_2], None, None)?; for sentence in output { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) } diff --git a/examples/translation_mbart.rs b/examples/translation_mbart.rs index c7d89a1..102f774 100644 --- a/examples/translation_mbart.rs +++ b/examples/translation_mbart.rs @@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> { outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?); for sentence in outputs { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) } diff --git a/examples/translation_t5.rs b/examples/translation_t5.rs index aeca88d..bc05555 100644 --- a/examples/translation_t5.rs +++ b/examples/translation_t5.rs @@ -56,7 +56,7 @@ fn main() -> anyhow::Result<()> { outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?); for sentence in outputs { - println!("{}", sentence); + println!("{sentence}"); } Ok(()) } diff --git a/examples/zero_shot_classification.rs b/examples/zero_shot_classification.rs index 3391026..5d41b66 100644 --- a/examples/zero_shot_classification.rs +++ 
b/examples/zero_shot_classification.rs @@ -27,13 +27,13 @@ fn main() -> anyhow::Result<()> { [input_sentence, input_sequence_2], candidate_labels, Some(Box::new(|label: &str| { - format!("This example is about {}.", label) + format!("This example is about {label}.") })), 128, ) .unwrap(); - println!("{:?}", output); + println!("{output:?}"); Ok(()) } diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs index 6ec22e3..8bde2d2 100644 --- a/src/bart/bart_model.rs +++ b/src/bart/bart_model.rs @@ -26,7 +26,7 @@ use crate::pipelines::generation_utils::{ }; use crate::{Config, RustBertError}; use rust_tokenizers::tokenizer::{RobertaTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::{RobertaVocab, Vocab}; +use rust_tokenizers::vocab::RobertaVocab; use serde::{Deserialize, Serialize}; use std::borrow::Borrow; use std::collections::HashMap; @@ -1263,9 +1263,7 @@ impl PrivateLanguageGenerator value, - None => self - ._get_tokenizer() - .convert_tokens_to_ids(&[RobertaVocab::unknown_value()])[0], + None => self._get_tokenizer().get_unk_id(), }; let token_ids = token_ids diff --git a/src/common/kind.rs b/src/common/kind.rs index a3d4f9e..60b3dab 100644 --- a/src/common/kind.rs +++ b/src/common/kind.rs @@ -14,8 +14,7 @@ pub(crate) fn get_positive_infinity(kind: Kind) -> Result Kind::Double => Scalar::float(f64::INFINITY), _ => { return Err(RustBertError::ValueError(format!( - "Type not supported: attempted to get positive infinity for {:?}", - kind + "Type not supported: attempted to get positive infinity for {kind:?}", ))) } }) @@ -34,8 +33,7 @@ pub(crate) fn get_negative_infinity(kind: Kind) -> Result Kind::Double => Scalar::float(f64::NEG_INFINITY), _ => { return Err(RustBertError::ValueError(format!( - "Type not supported: attempted to get negative infinity for {:?}", - kind + "Type not supported: attempted to get negative infinity for {kind:?}", ))) } }) diff --git a/src/deberta/deberta_model.rs b/src/deberta/deberta_model.rs index 0196638..00f8a9c 100644 
--- a/src/deberta/deberta_model.rs +++ b/src/deberta/deberta_model.rs @@ -111,8 +111,7 @@ impl FromStr for PositionAttentionType { "c2p" => Ok(PositionAttentionType::c2p), "p2p" => Ok(PositionAttentionType::p2p), _ => Err(RustBertError::InvalidConfigurationError(format!( - "Position attention type `{}` not in accepted variants (`p2c`, `c2p`, `p2p`)", - s + "Position attention type `{s}` not in accepted variants (`p2c`, `c2p`, `p2p`)", ))), } } diff --git a/src/deberta_v2/deberta_v2_model.rs b/src/deberta_v2/deberta_v2_model.rs index 8ae4541..857b00c 100644 --- a/src/deberta_v2/deberta_v2_model.rs +++ b/src/deberta_v2/deberta_v2_model.rs @@ -118,8 +118,7 @@ impl FromStr for NormRelEmbedType { match s { "layer_norm" => Ok(NormRelEmbedType::layer_norm), _ => Err(RustBertError::InvalidConfigurationError(format!( - "Layer normalization type `{}` not in accepted variants (`layer_norm`)", - s + "Layer normalization type `{s}` not in accepted variants (`layer_norm`)", ))), } } diff --git a/src/m2m_100/m2m_100_model.rs b/src/m2m_100/m2m_100_model.rs index 0bffbe8..59e4b58 100644 --- a/src/m2m_100/m2m_100_model.rs +++ b/src/m2m_100/m2m_100_model.rs @@ -24,7 +24,7 @@ use crate::pipelines::generation_utils::{ use crate::pipelines::translation::Language; use crate::{Config, RustBertError}; use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::{M2M100Vocab, Vocab}; +use rust_tokenizers::vocab::M2M100Vocab; use std::borrow::Borrow; use tch::nn::{embedding, EmbeddingConfig}; use tch::{nn, Kind, Tensor}; @@ -804,9 +804,7 @@ impl PrivateLanguageGenerator value, - None => self - ._get_tokenizer() - .convert_tokens_to_ids(&[M2M100Vocab::unknown_value()])[0], + None => self._get_tokenizer().get_unk_id(), }; let token_ids = token_ids diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs index a2ea85c..0dde3e0 100644 --- a/src/marian/marian_model.rs +++ b/src/marian/marian_model.rs @@ -885,7 +885,9 @@ impl MarianGenerator { let 
vocab_size = config.vocab_size; let is_encoder_decoder = true; let decoder_start_id = - Some(tokenizer.convert_tokens_to_ids(&[MarianVocab::pad_value()])[0]); + Some(tokenizer.get_pad_id().ok_or(RustBertError::TokenizerError( + "The tokenizer must contain a pad token ID to be used as BOS".to_string(), + ))?); let max_position_embeddings = config.max_position_embeddings; Ok(MarianGenerator { diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs index 93ab313..471f562 100644 --- a/src/mbart/mbart_model.rs +++ b/src/mbart/mbart_model.rs @@ -25,7 +25,7 @@ use crate::pipelines::generation_utils::{ use crate::pipelines::translation::Language; use crate::{Activation, Config, RustBertError}; use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::{MBart50Vocab, Vocab}; +use rust_tokenizers::vocab::MBart50Vocab; use serde::{Deserialize, Serialize}; use std::borrow::Borrow; use std::collections::HashMap; @@ -1059,9 +1059,7 @@ impl PrivateLanguageGenerator value, - None => self - ._get_tokenizer() - .convert_tokens_to_ids(&[MBart50Vocab::unknown_value()])[0], + None => self._get_tokenizer().get_unk_id(), }; let token_ids = token_ids diff --git a/src/pegasus/pegasus_model.rs b/src/pegasus/pegasus_model.rs index bed3108..24569f5 100644 --- a/src/pegasus/pegasus_model.rs +++ b/src/pegasus/pegasus_model.rs @@ -780,7 +780,8 @@ impl PrivateLanguageGenerator value, None => self ._get_tokenizer() - .convert_tokens_to_ids(&[PegasusVocab::pad_value()])[0], + .get_pad_id() + .expect("A padding token must be provided to encode prompt texts."), }; let token_ids = token_ids diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index 8ecedcd..0b3ea05 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -46,11 +46,7 @@ use rust_tokenizers::tokenizer::{ OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer, Tokenizer, TruncationStrategy, XLMRobertaTokenizer, 
XLNetTokenizer, }; -use rust_tokenizers::vocab::{ - AlbertVocab, BertVocab, DeBERTaV2Vocab, DeBERTaVocab, FNetVocab, Gpt2Vocab, M2M100Vocab, - MBart50Vocab, MarianVocab, OpenAiGptVocab, PegasusVocab, ProphetNetVocab, ReformerVocab, - RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab, XLNetVocab, -}; +use rust_tokenizers::vocab::Vocab; use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -1275,174 +1271,144 @@ impl TokenizerOption { /// Interface method pub fn get_unk_id(&self) -> i64 { match *self { - Self::Bert(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(BertVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::Deberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::DebertaV2(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaV2Vocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::Roberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::Bart(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::XLMRoberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLMRobertaVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::Marian(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(MarianVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::T5(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(T5Vocab::unknown_value()) - .expect("UNK token 
not found in vocabulary"), - Self::Albert(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(AlbertVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::XLNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLNetVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::GPT2(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(Gpt2Vocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::OpenAiGpt(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(OpenAiGptVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::Reformer(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(ReformerVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::ProphetNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(ProphetNetVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::Pegasus(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(PegasusVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::MBart50(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(MBart50Vocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::M2M100(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(M2M100Vocab::unknown_value()) - .expect("UNK token not found in vocabulary"), - Self::FNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(FNetVocab::unknown_value()) - .expect("UNK token not found in vocabulary"), + Self::Bert(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::Deberta(ref tokenizer) => { + 
let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::DebertaV2(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::Roberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::Bart(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::XLMRoberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::Marian(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::T5(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::Albert(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::XLNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::GPT2(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::OpenAiGpt(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::Reformer(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::ProphetNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::Pegasus(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::MBart50(ref tokenizer) => { + 
let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::M2M100(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } + Self::FNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + vocab.token_to_id(vocab.get_unknown_value()) + } } } /// Interface method pub fn get_pad_id(&self) -> Option { match *self { - Self::Bert(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(BertVocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::Deberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaVocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::DebertaV2(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaV2Vocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::Roberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::Bart(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::pad_value()) - .unwrap_or(&1), - ), - Self::XLMRoberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLMRobertaVocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::Marian(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(MarianVocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::T5(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(T5Vocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::Albert(ref tokenizer) => Some( - 
*MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(AlbertVocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::XLNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLNetVocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::ProphetNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(ProphetNetVocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::Pegasus(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(PegasusVocab::pad_value()) - .unwrap_or(&0), - ), - Self::MBart50(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(MBart50Vocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), - Self::M2M100(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(M2M100Vocab::pad_value()) - .unwrap_or(&1), - ), - Self::FNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(FNetVocab::pad_value()) - .expect("PAD token not found in vocabulary"), - ), + Self::Bert(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::Deberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::DebertaV2(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::Roberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::Bart(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::XLMRoberta(ref tokenizer) => { + let vocab = 
MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::Marian(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::T5(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::Albert(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::XLNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::ProphetNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::Pegasus(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::MBart50(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::M2M100(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } + Self::FNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_pad_value())) + } Self::Reformer(_) => None, Self::GPT2(_) => None, Self::OpenAiGpt(_) => None, @@ -1452,78 +1418,54 @@ impl TokenizerOption { /// Interface method pub fn get_sep_id(&self) -> Option { match *self { - Self::Bert(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(BertVocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), - Self::Deberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaVocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), - Self::DebertaV2(ref tokenizer) => 
Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaV2Vocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), - Self::Roberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), - Self::Bart(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), - Self::XLMRoberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLMRobertaVocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), - Self::Albert(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(AlbertVocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), - Self::XLNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLNetVocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), - Self::ProphetNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(ProphetNetVocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), - Self::MBart50(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(MBart50Vocab::sep_value()) - .unwrap_or(&1), - ), - Self::M2M100(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(M2M100Vocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), - Self::FNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(FNetVocab::sep_value()) - .expect("SEP token not found in vocabulary"), - ), + Self::Bert(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::Deberta(ref tokenizer) => 
{ + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::DebertaV2(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::Roberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::Bart(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::XLMRoberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::Albert(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::XLNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::ProphetNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::MBart50(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::M2M100(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } + Self::FNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_sep_value())) + } Self::Marian(_) => None, Self::T5(_) => None, Self::GPT2(_) => None, @@ -1536,78 +1478,54 @@ impl TokenizerOption { /// Interface method pub fn get_mask_id(&self) -> Option { match *self { - Self::Bert(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(BertVocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::Deberta(ref tokenizer) => Some( - 
*MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaVocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::DebertaV2(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaV2Vocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::Roberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::Bart(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::XLMRoberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLMRobertaVocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::Albert(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(AlbertVocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::XLNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLNetVocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::ProphetNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(ProphetNetVocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::MBart50(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(MBart50Vocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::FNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(FNetVocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), - Self::Pegasus(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - 
.get(PegasusVocab::mask_value()) - .expect("MASK token not found in vocabulary"), - ), + Self::Bert(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::Deberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::DebertaV2(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::Roberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::Bart(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::XLMRoberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::Albert(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::XLNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::ProphetNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::MBart50(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::FNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } + Self::Pegasus(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_mask_value())) + } Self::Marian(_) => None, Self::M2M100(_) => None, Self::T5(_) => None, @@ -1620,84 +1538,90 @@ impl TokenizerOption { /// Interface method pub fn 
get_mask_value(&self) -> Option<&str> { match self { - Self::Bert(_) => Some(BertVocab::mask_value()), - Self::Deberta(_) => Some(DeBERTaVocab::mask_value()), - Self::DebertaV2(_) => Some(DeBERTaV2Vocab::mask_value()), - Self::Roberta(_) => Some(RobertaVocab::mask_value()), - Self::Bart(_) => Some(RobertaVocab::mask_value()), - Self::XLMRoberta(_) => Some(XLMRobertaVocab::mask_value()), - Self::Albert(_) => Some(AlbertVocab::mask_value()), - Self::XLNet(_) => Some(XLNetVocab::mask_value()), - Self::ProphetNet(_) => Some(ProphetNetVocab::mask_value()), - Self::MBart50(_) => Some(MBart50Vocab::mask_value()), - Self::FNet(_) => Some(FNetVocab::mask_value()), + Self::Bert(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::Deberta(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::DebertaV2(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::Roberta(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::Bart(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::XLMRoberta(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::Albert(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::XLNet(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::ProphetNet(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::MBart50(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::FNet(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } + Self::Pegasus(ref tokenizer) => { + Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) + } Self::M2M100(_) => None, Self::Marian(_) => None, Self::T5(_) =>
None, Self::GPT2(_) => None, Self::OpenAiGpt(_) => None, Self::Reformer(_) => None, - Self::Pegasus(_) => None, } } /// Interface method pub fn get_bos_id(&self) -> Option<i64> { match *self { - Self::Roberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::bos_value()) - .expect("BOS token not found in vocabulary"), - ), - Self::Bart(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::bos_value()) - .unwrap_or(&0), - ), - Self::DebertaV2(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaV2Vocab::bos_value()) - .expect("BOS token not found in vocabulary"), - ), - Self::XLMRoberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLMRobertaVocab::bos_value()) - .expect("BOS token not found in vocabulary"), - ), - Self::Albert(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(AlbertVocab::bos_value()) - .expect("BOS token not found in vocabulary"), - ), - Self::XLNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLNetVocab::bos_value()) - .expect("BOS token not found in vocabulary"), - ), - Self::M2M100(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(M2M100Vocab::bos_value()) - .unwrap_or(&0), - ), - Self::GPT2(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(Gpt2Vocab::bos_value()) - .expect("BOS token not found in vocabulary"), - ), - Self::Deberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaVocab::bos_value()) - .expect("BOS token not found in vocabulary"), - ), + Self::Roberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_bos_value())) + } + Self::Bart(ref tokenizer) => {
+ let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_bos_value())) + } + Self::DebertaV2(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_bos_value())) + } + Self::XLMRoberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_bos_value())) + } + Self::Albert(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_bos_value())) + } + Self::XLNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_bos_value())) + } + Self::M2M100(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_bos_value())) + } + Self::GPT2(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_bos_value())) + } + Self::Deberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_bos_value())) + } Self::MBart50(_) => Some(0), Self::FNet(_) => None, Self::Bert(_) => None, @@ -1713,90 +1637,62 @@ impl TokenizerOption { /// Interface method pub fn get_eos_id(&self) -> Option { match *self { - Self::Roberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::eos_value()) - .expect("EOS token not found in vocabulary"), - ), - Self::Bart(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(RobertaVocab::eos_value()) - .unwrap_or(&2), - ), - Self::DebertaV2(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaV2Vocab::eos_value()) - .expect("EOS token not found in vocabulary"), - ), - Self::XLMRoberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLMRobertaVocab::eos_value()) - 
.expect("EOS token not found in vocabulary"), - ), - Self::Albert(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(AlbertVocab::eos_value()) - .expect("EOS token not found in vocabulary"), - ), - Self::XLNet(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(XLNetVocab::eos_value()) - .expect("EOS token not found in vocabulary"), - ), - Self::MBart50(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(MBart50Vocab::eos_value()) - .unwrap_or(&2), - ), - Self::M2M100(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(M2M100Vocab::eos_value()) - .expect("EOS token not found in vocabulary"), - ), - Self::GPT2(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(Gpt2Vocab::eos_value()) - .unwrap_or(&2), - ), - Self::Deberta(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(DeBERTaVocab::eos_value()) - .expect("EOS token not found in vocabulary"), - ), - Self::Marian(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(MarianVocab::eos_value()) - .expect("EOS token not found in vocabulary"), - ), - Self::T5(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(T5Vocab::eos_value()) - .expect("EOS token not found in vocabulary"), - ), - Self::Reformer(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(ReformerVocab::eos_value()) - .expect("EOS token not found in vocabulary"), - ), - Self::Pegasus(ref tokenizer) => Some( - *MultiThreadedTokenizer::vocab(tokenizer) - .special_values - .get(PegasusVocab::eos_value()) - .unwrap_or(&1), - ), + Self::Roberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::Bart(ref tokenizer) 
=> { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::DebertaV2(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::XLMRoberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::Albert(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::XLNet(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::MBart50(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::M2M100(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::GPT2(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::Deberta(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::Marian(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::T5(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::Reformer(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } + Self::Pegasus(ref tokenizer) => { + let vocab = MultiThreadedTokenizer::vocab(tokenizer); + Some(vocab.token_to_id(vocab.get_eos_value())) + } Self::FNet(_) => None, Self::Bert(_) => None, Self::ProphetNet(_) => None, diff --git a/src/pipelines/masked_language.rs 
b/src/pipelines/masked_language.rs index 29ad5ec..c2c48c3 100644 --- a/src/pipelines/masked_language.rs +++ b/src/pipelines/masked_language.rs @@ -257,8 +257,7 @@ impl MaskedLanguageOption { } } _ => Err(RustBertError::InvalidConfigurationError(format!( - "Masked Language is not implemented for {:?}!", - model_type + "Masked Language is not implemented for {model_type:?}!", ))), } } diff --git a/src/pipelines/question_answering.rs b/src/pipelines/question_answering.rs index d6b6c30..8401189 100644 --- a/src/pipelines/question_answering.rs +++ b/src/pipelines/question_answering.rs @@ -456,8 +456,7 @@ impl QuestionAnsweringOption { } } _ => Err(RustBertError::InvalidConfigurationError(format!( - "QuestionAnswering not implemented for {:?}!", - model_type + "QuestionAnswering not implemented for {model_type:?}!", ))), } } diff --git a/src/pipelines/sentence_embeddings/builder.rs b/src/pipelines/sentence_embeddings/builder.rs index a35947a..ad26ea2 100644 --- a/src/pipelines/sentence_embeddings/builder.rs +++ b/src/pipelines/sentence_embeddings/builder.rs @@ -88,8 +88,7 @@ impl SentenceEmbeddingsBuilder { ModelType::T5 => (model_dir.join("spiece.model"), None), _ => { return Err(RustBertError::InvalidConfigurationError(format!( - "Unsupported transformer model {:?} for Sentence Embeddings", - transformer_type + "Unsupported transformer model {transformer_type:?} for Sentence Embeddings", ))); } }; diff --git a/src/pipelines/sentence_embeddings/config.rs b/src/pipelines/sentence_embeddings/config.rs index 81bf059..8901c3c 100644 --- a/src/pipelines/sentence_embeddings/config.rs +++ b/src/pipelines/sentence_embeddings/config.rs @@ -408,7 +408,7 @@ mod serde_sentence_embeddings_module_type { where S: Serializer, { - serializer.serialize_str(&format!("sentence_transformers.models.{:?}", module_type)) + serializer.serialize_str(&format!("sentence_transformers.models.{module_type:?}")) } pub fn deserialize<'de, D>(deserializer: D) -> Result @@ -430,7 +430,7 @@ mod 
serde_sentence_embeddings_module_type { .map(|s| serde_json::from_value(serde_json::Value::String(s.to_string()))) .transpose() .map_err(de::Error::custom)? - .ok_or_else(|| format!("Invalid SentenceEmbeddingsModuleType: {}", s)) + .ok_or_else(|| format!("Invalid SentenceEmbeddingsModuleType: {s}")) .map_err(de::Error::custom) } } diff --git a/src/pipelines/sentence_embeddings/layers.rs b/src/pipelines/sentence_embeddings/layers.rs index fa63282..e188500 100644 --- a/src/pipelines/sentence_embeddings/layers.rs +++ b/src/pipelines/sentence_embeddings/layers.rs @@ -102,7 +102,7 @@ where .map(|s| serde_json::from_value(serde_json::Value::String(s.to_lowercase()))) .transpose() .map_err(de::Error::custom)? - .ok_or_else(|| format!("Invalid Activation: {}", activation)) + .ok_or_else(|| format!("Invalid Activation: {activation}")) .map_err(de::Error::custom) } diff --git a/src/pipelines/sentence_embeddings/pipeline.rs b/src/pipelines/sentence_embeddings/pipeline.rs index fa4d35b..5ed9f09 100644 --- a/src/pipelines/sentence_embeddings/pipeline.rs +++ b/src/pipelines/sentence_embeddings/pipeline.rs @@ -65,8 +65,7 @@ impl SentenceEmbeddingsOption { ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))), _ => { return Err(RustBertError::InvalidConfigurationError(format!( - "Unsupported transformer model {:?} for Sentence Embeddings", - transformer_type + "Unsupported transformer model {transformer_type:?} for Sentence Embeddings" ))); } }; diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs index 667c236..dac067a 100644 --- a/src/pipelines/sequence_classification.rs +++ b/src/pipelines/sequence_classification.rs @@ -373,8 +373,7 @@ impl SequenceClassificationOption { } } _ => Err(RustBertError::InvalidConfigurationError(format!( - "Sequence Classification not implemented for {:?}!", - model_type + "Sequence Classification not implemented for {model_type:?}!", ))), } } diff --git 
a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs index ffb41c6..9ed4c64 100644 --- a/src/pipelines/token_classification.rs +++ b/src/pipelines/token_classification.rs @@ -484,8 +484,7 @@ impl TokenClassificationOption { } } _ => Err(RustBertError::InvalidConfigurationError(format!( - "Token classification not implemented for {:?}!", - model_type + "Token classification not implemented for {model_type:?}!" ))), } } diff --git a/src/pipelines/translation/translation_builder.rs b/src/pipelines/translation/translation_builder.rs index f02d461..ef286da 100644 --- a/src/pipelines/translation/translation_builder.rs +++ b/src/pipelines/translation/translation_builder.rs @@ -372,8 +372,7 @@ impl TranslationModelBuilder { } (Some(model_type), _, _) => { return Err(RustBertError::InvalidConfigurationError(format!( - "Automated translation model builder not implemented for {:?}", - model_type + "Automated translation model builder not implemented for {model_type:?}" ))); } }; @@ -459,91 +458,92 @@ mod model_fetchers { source_languages: Option<&Vec>, target_languages: Option<&Vec>, ) -> Result, RustBertError> { - let (resources, source_languages, target_languages) = - if let (Some(source_languages), Some(target_languages)) = - (source_languages, target_languages) - { - match (source_languages.as_slice(), target_languages.as_slice()) { - ([Language::English], [Language::German]) => { - get_marian_resources!(ENGLISH2GERMAN) - } - ([Language::English], [Language::Russian]) => { - get_marian_resources!(ENGLISH2RUSSIAN) - } - ([Language::English], [Language::Dutch]) => { - get_marian_resources!(ENGLISH2DUTCH) - } - ([Language::English], [Language::ChineseMandarin]) => { - get_marian_resources!(ENGLISH2CHINESE) - } - ([Language::English], [Language::Swedish]) => { - get_marian_resources!(ENGLISH2SWEDISH) - } - ([Language::English], [Language::Arabic]) => { - get_marian_resources!(ENGLISH2ARABIC) - } - ([Language::English], [Language::Hindi]) => { - 
get_marian_resources!(ENGLISH2HINDI) - } - ([Language::English], [Language::Hebrew]) => { - get_marian_resources!(ENGLISH2HEBREW) - } - ([Language::German], [Language::English]) => { - get_marian_resources!(GERMAN2ENGLISH) - } - ([Language::German], [Language::French]) => { - get_marian_resources!(GERMAN2FRENCH) - } - ([Language::French], [Language::German]) => { - get_marian_resources!(FRENCH2GERMAN) - } - ([Language::Russian], [Language::English]) => { - get_marian_resources!(RUSSIAN2ENGLISH) - } - ([Language::Dutch], [Language::English]) => { - get_marian_resources!(DUTCH2ENGLISH) - } - ([Language::ChineseMandarin], [Language::English]) => { - get_marian_resources!(CHINESE2ENGLISH) - } - ([Language::Swedish], [Language::English]) => { - get_marian_resources!(SWEDISH2ENGLISH) - } - ([Language::Arabic], [Language::English]) => { - get_marian_resources!(ARABIC2ENGLISH) - } - ([Language::Hindi], [Language::English]) => { - get_marian_resources!(HINDI2ENGLISH) - } - ([Language::Hebrew], [Language::English]) => { - get_marian_resources!(HEBREW2ENGLISH) - } - ([Language::English], languages) - if languages - .iter() - .all(|lang| MarianTargetLanguages::ENGLISH2ROMANCE.contains(lang)) => - { - get_marian_resources!(ENGLISH2ROMANCE) - } - (languages, [Language::English]) - if languages - .iter() - .all(|lang| MarianSourceLanguages::ROMANCE2ENGLISH.contains(lang)) => - { - get_marian_resources!(ROMANCE2ENGLISH) - } - (_, _) => { - return Err(RustBertError::InvalidConfigurationError(format!( - "No Pretrained Marian configuration found for {:?} to {:?} translation", - source_languages, target_languages - ))); - } + let (resources, source_languages, target_languages) = if let ( + Some(source_languages), + Some(target_languages), + ) = + (source_languages, target_languages) + { + match (source_languages.as_slice(), target_languages.as_slice()) { + ([Language::English], [Language::German]) => { + get_marian_resources!(ENGLISH2GERMAN) } - } else { - return 
Err(RustBertError::InvalidConfigurationError( - "Source and target languages must be provided for Marian models".to_string(), - )); - }; + ([Language::English], [Language::Russian]) => { + get_marian_resources!(ENGLISH2RUSSIAN) + } + ([Language::English], [Language::Dutch]) => { + get_marian_resources!(ENGLISH2DUTCH) + } + ([Language::English], [Language::ChineseMandarin]) => { + get_marian_resources!(ENGLISH2CHINESE) + } + ([Language::English], [Language::Swedish]) => { + get_marian_resources!(ENGLISH2SWEDISH) + } + ([Language::English], [Language::Arabic]) => { + get_marian_resources!(ENGLISH2ARABIC) + } + ([Language::English], [Language::Hindi]) => { + get_marian_resources!(ENGLISH2HINDI) + } + ([Language::English], [Language::Hebrew]) => { + get_marian_resources!(ENGLISH2HEBREW) + } + ([Language::German], [Language::English]) => { + get_marian_resources!(GERMAN2ENGLISH) + } + ([Language::German], [Language::French]) => { + get_marian_resources!(GERMAN2FRENCH) + } + ([Language::French], [Language::German]) => { + get_marian_resources!(FRENCH2GERMAN) + } + ([Language::Russian], [Language::English]) => { + get_marian_resources!(RUSSIAN2ENGLISH) + } + ([Language::Dutch], [Language::English]) => { + get_marian_resources!(DUTCH2ENGLISH) + } + ([Language::ChineseMandarin], [Language::English]) => { + get_marian_resources!(CHINESE2ENGLISH) + } + ([Language::Swedish], [Language::English]) => { + get_marian_resources!(SWEDISH2ENGLISH) + } + ([Language::Arabic], [Language::English]) => { + get_marian_resources!(ARABIC2ENGLISH) + } + ([Language::Hindi], [Language::English]) => { + get_marian_resources!(HINDI2ENGLISH) + } + ([Language::Hebrew], [Language::English]) => { + get_marian_resources!(HEBREW2ENGLISH) + } + ([Language::English], languages) + if languages + .iter() + .all(|lang| MarianTargetLanguages::ENGLISH2ROMANCE.contains(lang)) => + { + get_marian_resources!(ENGLISH2ROMANCE) + } + (languages, [Language::English]) + if languages + .iter() + .all(|lang| 
MarianSourceLanguages::ROMANCE2ENGLISH.contains(lang)) => + { + get_marian_resources!(ROMANCE2ENGLISH) + } + (_, _) => { + return Err(RustBertError::InvalidConfigurationError(format!( + "No Pretrained Marian configuration found for {source_languages:?} to {target_languages:?} translation", + ))); + } + } + } else { + return Err(RustBertError::InvalidConfigurationError( + "Source and target languages must be provided for Marian models".to_string(), + )); + }; Ok(TranslationResources { model_type: ModelType::Marian, diff --git a/src/pipelines/translation/translation_pipeline.rs b/src/pipelines/translation/translation_pipeline.rs index 5da28d5..f71ff4b 100644 --- a/src/pipelines/translation/translation_pipeline.rs +++ b/src/pipelines/translation/translation_pipeline.rs @@ -135,7 +135,7 @@ pub enum Language { impl Display for Language { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", { - let input_string = format!("{:?}", self); + let input_string = format!("{self:?}"); let mut output: Vec<&str> = Vec::new(); let mut start: usize = 0; @@ -584,8 +584,7 @@ impl TranslationOption { if let Some(source_language) = source_language { if !supported_source_languages.contains(source_language) { return Err(RustBertError::ValueError(format!( - "{} not in list of supported languages: {:?}", - source_language, supported_source_languages + "{source_language} not in list of supported languages: {supported_source_languages:?}", ))); } } @@ -593,8 +592,7 @@ impl TranslationOption { if let Some(target_language) = target_language { if !supported_target_languages.contains(target_language) { return Err(RustBertError::ValueError(format!( - "{} not in list of supported languages: {:?}", - target_language, supported_target_languages + "{target_language} not in list of supported languages: {supported_target_languages:?}" ))); } } @@ -610,9 +608,8 @@ impl TranslationOption { None => { return Err(RustBertError::ValueError(format!( "Missing target language for Marian \ - 
(multiple languages supported by model: {:?}, \ + (multiple languages supported by model: {supported_target_languages:?}, \ need to specify target language)", - supported_target_languages ))); } } @@ -653,9 +650,8 @@ impl TranslationOption { None => { return Err(RustBertError::ValueError(format!( "Missing source language for MBart\ - (multiple languages supported by model: {:?}, \ - need to specify target language)", - supported_source_languages + (multiple languages supported by model: {supported_source_languages:?}, \ + need to specify target language)" ))); } } @@ -670,9 +666,8 @@ impl TranslationOption { } else { return Err(RustBertError::ValueError(format!( "Missing target language for MBart\ - (multiple languages supported by model: {:?}, \ - need to specify target language)", - supported_target_languages + (multiple languages supported by model: {supported_target_languages:?}, \ + need to specify target language)" ))); }, ), @@ -681,8 +676,8 @@ impl TranslationOption { Some(value) => { let language_code = value.get_iso_639_1_code(); match language_code.len() { - 2 => format!(">>{}.<< ", language_code), - 3 => format!(">>{}<< ", language_code), + 2 => format!(">>{language_code}.<< "), + 3 => format!(">>{language_code}<< "), _ => { return Err(RustBertError::ValueError( "Invalid ISO 639-3 code".to_string(), @@ -693,9 +688,8 @@ impl TranslationOption { None => { return Err(RustBertError::ValueError(format!( "Missing source language for M2M100 \ - (multiple languages supported by model: {:?}, \ - need to specify target language)", - supported_source_languages + (multiple languages supported by model: {supported_source_languages:?}, \ + need to specify target language)" ))); } }), @@ -704,8 +698,8 @@ impl TranslationOption { Some( model._get_tokenizer().convert_tokens_to_ids(&[ match language_code.len() { - 2 => format!(">>{}.<<", language_code), - 3 => format!(">>{}<<", language_code), + 2 => format!(">>{language_code}.<<"), + 3 => format!(">>{language_code}<<"), 
_ => { return Err(RustBertError::ValueError( "Invalid ISO 639-3 code".to_string(), @@ -717,9 +711,8 @@ impl TranslationOption { } else { return Err(RustBertError::ValueError(format!( "Missing target language for M2M100 \ - (multiple languages supported by model: {:?}, \ + (multiple languages supported by model: {supported_target_languages:?}, \ need to specify target language)", - supported_target_languages ))); }, ), diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index e8d2870..32e374b 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -370,8 +370,7 @@ impl ZeroShotClassificationOption { } } _ => Err(RustBertError::InvalidConfigurationError(format!( - "Zero shot classification not implemented for {:?}!", - model_type + "Zero shot classification not implemented for {model_type:?}!", ))), } } @@ -604,7 +603,7 @@ impl ZeroShotClassificationModel { None => labels .as_ref() .iter() - .map(|label| format!("This example is about {}.", label)) + .map(|label| format!("This example is about {label}.")) .collect(), }; diff --git a/src/prophetnet/prophetnet_model.rs b/src/prophetnet/prophetnet_model.rs index cb3efd2..8e7e872 100644 --- a/src/prophetnet/prophetnet_model.rs +++ b/src/prophetnet/prophetnet_model.rs @@ -14,7 +14,7 @@ use std::borrow::Borrow; use std::collections::HashMap; use rust_tokenizers::tokenizer::{ProphetNetTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::{ProphetNetVocab, Vocab}; +use rust_tokenizers::vocab::ProphetNetVocab; use serde::{Deserialize, Serialize}; use tch::{nn, Kind, Tensor}; @@ -1098,9 +1098,7 @@ impl let pad_token = match pad_token_id { Some(value) => value, - None => self - ._get_tokenizer() - .convert_tokens_to_ids(&[ProphetNetVocab::unknown_value()])[0], + None => self._get_tokenizer().get_unk_id(), }; let token_ids = token_ids diff --git a/src/reformer/reformer_model.rs b/src/reformer/reformer_model.rs index 
8f3ae37..61afbf2 100644 --- a/src/reformer/reformer_model.rs +++ b/src/reformer/reformer_model.rs @@ -548,7 +548,7 @@ impl ReformerModelWithLMHead { if let Some(lsh_num_chunks_after) = config.lsh_num_chunks_after { if config.attn_layers.contains(&AttentionType::lsh) & (lsh_num_chunks_after != 0) { return Err(RustBertError::InvalidConfigurationError( - format!("For text generation using LSH attention ensure `config.lsh_num_chunks_after` is set to 0 (currently {})", lsh_num_chunks_after), + format!("For text generation using LSH attention ensure `config.lsh_num_chunks_after` is set to 0 (currently {lsh_num_chunks_after})"), )); } } @@ -556,7 +556,7 @@ impl ReformerModelWithLMHead { if let Some(local_num_chunks_after) = config.local_num_chunks_after { if config.attn_layers.contains(&AttentionType::local) & (local_num_chunks_after != 0) { return Err(RustBertError::InvalidConfigurationError( - format!("For text generation using local attention ensure `config.local_num_chunks_after` is set to 0 (currently {})", local_num_chunks_after), + format!("For text generation using local attention ensure `config.local_num_chunks_after` is set to 0 (currently {local_num_chunks_after})"), )); } } diff --git a/tests/bart.rs b/tests/bart.rs index f2bed2a..590a0ec 100644 --- a/tests/bart.rs +++ b/tests/bart.rs @@ -215,7 +215,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> { [input_sentence, input_sequence_2], candidate_labels, Some(Box::new(|label: &str| { - format!("This example is about {}.", label) + format!("This example is about {label}.") })), 128, )?; @@ -244,7 +244,7 @@ fn bart_zero_shot_classification_try_error() -> anyhow::Result<()> { [], [], Some(Box::new(|label: &str| { - format!("This example is about {}.", label) + format!("This example is about {label}.") })), 128, ); @@ -276,7 +276,7 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> { [input_sentence, input_sequence_2], candidate_labels, Some(Box::new(|label: &str| { - format!("This 
example is about {}.", label) + format!("This example is about {label}.") })), 128, )?; @@ -319,7 +319,7 @@ fn bart_zero_shot_classification_multilabel_try_error() -> anyhow::Result<()> { [], [], Some(Box::new(|label: &str| { - format!("This example is about {}.", label) + format!("This example is about {label}.") })), 128, ); diff --git a/tests/longformer.rs b/tests/longformer.rs index 2db9bc7..508aaba 100644 --- a/tests/longformer.rs +++ b/tests/longformer.rs @@ -14,7 +14,7 @@ use rust_bert::pipelines::question_answering::{ use rust_bert::resources::{RemoteResource, ResourceProvider}; use rust_bert::Config; use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::{RobertaVocab, Vocab}; +use rust_tokenizers::vocab::Vocab; use std::collections::HashMap; use tch::{nn, no_grad, Device, Tensor}; @@ -67,7 +67,9 @@ fn longformer_masked_lm() -> anyhow::Result<()> { .map(|input| input.token_ids.clone()) .map(|mut input| { input.extend(vec![ - tokenizer.vocab().token_to_id(RobertaVocab::pad_value()); + tokenizer + .vocab() + .token_to_id(tokenizer.vocab().get_pad_value()); max_len - input.len() ]); input diff --git a/tests/prophetnet.rs b/tests/prophetnet.rs index 9d24153..9f4803a 100644 --- a/tests/prophetnet.rs +++ b/tests/prophetnet.rs @@ -63,8 +63,8 @@ about exoplanets like K2-18b."]; assert_eq!( output[0], "scientists have confirmed the presence of water in the atmosphere of k2 - 18b. \ -[X_SEP] this is the first such discovery in a planet in its star's habitable zone. \ -[X_SEP] the planet is 110 light - years from earth and has a star in the constellation leo." +this is the first such discovery in a planet in its star's habitable zone. \ +the planet is 110 light - years from earth and has a star in the constellation leo." ); Ok(())