mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Tokenizer special token map update (#330)
* Updates for compatibility with tokenizers special token rework * Updated mask pipline methods * Bumped version * Fix clippy warnings
This commit is contained in:
parent
80e0197e2c
commit
84561ec82b
@ -2,6 +2,8 @@
|
||||
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
||||
|
||||
## [Unreleased]
|
||||
## Changed
|
||||
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
|
||||
|
||||
## [0.20.0] - 2023-01-21
|
||||
## Added
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "rust-bert"
|
||||
version = "0.20.0"
|
||||
version = "0.20.1-alpha"
|
||||
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
|
||||
edition = "2018"
|
||||
description = "Ready-to-use NLP pipelines and language models"
|
||||
@ -69,7 +69,7 @@ remote = ["cached-path", "dirs", "lazy_static"]
|
||||
features = ["doc-only"]
|
||||
|
||||
[dependencies]
|
||||
rust_tokenizers = "~7.0.2"
|
||||
rust_tokenizers = "8.0.0"
|
||||
tch = "~0.10.1"
|
||||
serde_json = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
|
@ -16,7 +16,7 @@ async fn main() -> Result<()> {
|
||||
"Classify this negative text".to_owned(),
|
||||
];
|
||||
let sentiments = classifier.predict(texts).await?;
|
||||
println!("Results: {:?}", sentiments);
|
||||
println!("Results: {sentiments:?}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
|
||||
// Run model
|
||||
let output = sequence_classification_model.predict(input);
|
||||
for label in output {
|
||||
println!("{:?}", label);
|
||||
println!("{label:?}");
|
||||
}
|
||||
|
||||
// Masked language model
|
||||
@ -78,7 +78,7 @@ fn main() -> anyhow::Result<()> {
|
||||
// Run model
|
||||
let output = mask_language_model.predict(input)?;
|
||||
for sentence_output in output {
|
||||
println!("{:?}", sentence_output);
|
||||
println!("{sentence_output:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -31,7 +31,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||
|
||||
println!("{:?}", output);
|
||||
println!("{output:?}");
|
||||
|
||||
let _ = conversation_manager
|
||||
.get(&conversation_1_id)
|
||||
@ -40,11 +40,11 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||
|
||||
println!("{:?}", output);
|
||||
println!("{output:?}");
|
||||
|
||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||
|
||||
println!("{:?}", output);
|
||||
println!("{output:?}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let output = model.generate(&[input_context], None);
|
||||
|
||||
for sentence in output {
|
||||
println!("{:?}", sentence);
|
||||
println!("{sentence:?}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let output = model.generate(&[input_context_1, input_context_2], None);
|
||||
|
||||
for sentence in output {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let output = model.generate(&[input_context_1, input_context_2], None);
|
||||
|
||||
for sentence in output {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let output = model.generate(&[input_context], None);
|
||||
|
||||
for sentence in output {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ fn main() -> anyhow::Result<()> {
|
||||
// Run model
|
||||
let output = mask_language_model.predict(input)?;
|
||||
for sentence_output in output {
|
||||
println!("{:?}", sentence_output);
|
||||
println!("{sentence_output:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> {
|
||||
// Run model
|
||||
let output = ner_model.predict_full_entities(&input);
|
||||
for entity in output {
|
||||
println!("{:?}", entity);
|
||||
println!("{entity:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -24,7 +24,7 @@ fn main() -> anyhow::Result<()> {
|
||||
// Run model
|
||||
let output = pos_model.predict(&input);
|
||||
for (pos, pos_tag) in output[0].iter().enumerate() {
|
||||
println!("{} - {:?}", pos, pos_tag);
|
||||
println!("{pos} - {pos_tag:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -34,6 +34,6 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Get answer
|
||||
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
||||
println!("{:?}", answers);
|
||||
println!("{answers:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -50,6 +50,6 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Get answer
|
||||
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
||||
println!("{:?}", answers);
|
||||
println!("{answers:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -55,6 +55,6 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Get answer
|
||||
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
||||
println!("{:?}", answers);
|
||||
println!("{answers:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -12,6 +12,6 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Generate Embeddings
|
||||
let embeddings = model.encode(&sentences)?;
|
||||
println!("{:?}", embeddings);
|
||||
println!("{embeddings:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -32,6 +32,6 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Generate Embeddings
|
||||
let embeddings = model.encode(&sentences)?;
|
||||
println!("{:?}", embeddings);
|
||||
println!("{embeddings:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> {
|
||||
// Run model
|
||||
let output = sentiment_classifier.predict(input);
|
||||
for sentiment in output {
|
||||
println!("{:?}", sentiment);
|
||||
println!("{sentiment:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -49,7 +49,7 @@ fn main() -> anyhow::Result<()> {
|
||||
// Run model
|
||||
let output = sentiment_classifier.predict(input);
|
||||
for sentiment in output {
|
||||
println!("{:?}", sentiment);
|
||||
println!("{sentiment:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> {
|
||||
// Run model
|
||||
let output = sequence_classification_model.predict(input);
|
||||
for label in output {
|
||||
println!("{:?}", label);
|
||||
println!("{label:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -29,7 +29,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let output = sequence_classification_model.predict_multilabel(&input, 0.05);
|
||||
if let Ok(labels) = output {
|
||||
for label in labels {
|
||||
println!("{:?}", label);
|
||||
println!("{label:?}");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -73,7 +73,7 @@ about exoplanets like K2-18b."];
|
||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||
let _output = summarization_model.summarize(&input);
|
||||
for sentence in _output {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -68,7 +68,7 @@ about exoplanets like K2-18b."];
|
||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||
let _output = summarization_model.summarize(&input);
|
||||
for sentence in _output {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -70,7 +70,7 @@ about exoplanets like K2-18b."];
|
||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||
let _output = summarization_model.summarize(&input);
|
||||
for sentence in _output {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -56,7 +56,7 @@ about exoplanets like K2-18b."];
|
||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||
let _output = summarization_model.summarize(&input);
|
||||
for sentence in _output {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let token_outputs = token_classification_model.predict(&input);
|
||||
|
||||
for token in token_outputs {
|
||||
println!("{:?}", token);
|
||||
println!("{token:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -32,7 +32,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let output = model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?;
|
||||
|
||||
for sentence in output {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
||||
|
||||
for sentence in outputs {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let output = model.translate(&[input_context_1, input_context_2], None, None)?;
|
||||
|
||||
for sentence in output {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
||||
|
||||
for sentence in outputs {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ fn main() -> anyhow::Result<()> {
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
|
||||
|
||||
for sentence in outputs {
|
||||
println!("{}", sentence);
|
||||
println!("{sentence}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -27,13 +27,13 @@ fn main() -> anyhow::Result<()> {
|
||||
[input_sentence, input_sequence_2],
|
||||
candidate_labels,
|
||||
Some(Box::new(|label: &str| {
|
||||
format!("This example is about {}.", label)
|
||||
format!("This example is about {label}.")
|
||||
})),
|
||||
128,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
println!("{:?}", output);
|
||||
println!("{output:?}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ use crate::pipelines::generation_utils::{
|
||||
};
|
||||
use crate::{Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::{RobertaTokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::{RobertaVocab, Vocab};
|
||||
use rust_tokenizers::vocab::RobertaVocab;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Borrow;
|
||||
use std::collections::HashMap;
|
||||
@ -1263,9 +1263,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids(&[RobertaVocab::unknown_value()])[0],
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
|
@ -14,8 +14,7 @@ pub(crate) fn get_positive_infinity(kind: Kind) -> Result<Scalar, RustBertError>
|
||||
Kind::Double => Scalar::float(f64::INFINITY),
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Type not supported: attempted to get positive infinity for {:?}",
|
||||
kind
|
||||
"Type not supported: attempted to get positive infinity for {kind:?}",
|
||||
)))
|
||||
}
|
||||
})
|
||||
@ -34,8 +33,7 @@ pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError>
|
||||
Kind::Double => Scalar::float(f64::NEG_INFINITY),
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Type not supported: attempted to get negative infinity for {:?}",
|
||||
kind
|
||||
"Type not supported: attempted to get negative infinity for {kind:?}",
|
||||
)))
|
||||
}
|
||||
})
|
||||
|
@ -111,8 +111,7 @@ impl FromStr for PositionAttentionType {
|
||||
"c2p" => Ok(PositionAttentionType::c2p),
|
||||
"p2p" => Ok(PositionAttentionType::p2p),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Position attention type `{}` not in accepted variants (`p2c`, `c2p`, `p2p`)",
|
||||
s
|
||||
"Position attention type `{s}` not in accepted variants (`p2c`, `c2p`, `p2p`)",
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
@ -118,8 +118,7 @@ impl FromStr for NormRelEmbedType {
|
||||
match s {
|
||||
"layer_norm" => Ok(NormRelEmbedType::layer_norm),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Layer normalization type `{}` not in accepted variants (`layer_norm`)",
|
||||
s
|
||||
"Layer normalization type `{s}` not in accepted variants (`layer_norm`)",
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ use crate::pipelines::generation_utils::{
|
||||
use crate::pipelines::translation::Language;
|
||||
use crate::{Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::{M2M100Vocab, Vocab};
|
||||
use rust_tokenizers::vocab::M2M100Vocab;
|
||||
use std::borrow::Borrow;
|
||||
use tch::nn::{embedding, EmbeddingConfig};
|
||||
use tch::{nn, Kind, Tensor};
|
||||
@ -804,9 +804,7 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids(&[M2M100Vocab::unknown_value()])[0],
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
|
@ -885,7 +885,9 @@ impl MarianGenerator {
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = true;
|
||||
let decoder_start_id =
|
||||
Some(tokenizer.convert_tokens_to_ids(&[MarianVocab::pad_value()])[0]);
|
||||
Some(tokenizer.get_pad_id().ok_or(RustBertError::TokenizerError(
|
||||
"The tokenizer must contain a pad token ID to be used as BOS".to_string(),
|
||||
))?);
|
||||
let max_position_embeddings = config.max_position_embeddings;
|
||||
|
||||
Ok(MarianGenerator {
|
||||
|
@ -25,7 +25,7 @@ use crate::pipelines::generation_utils::{
|
||||
use crate::pipelines::translation::Language;
|
||||
use crate::{Activation, Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::{MBart50Vocab, Vocab};
|
||||
use rust_tokenizers::vocab::MBart50Vocab;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Borrow;
|
||||
use std::collections::HashMap;
|
||||
@ -1059,9 +1059,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids(&[MBart50Vocab::unknown_value()])[0],
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
|
@ -780,7 +780,8 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
|
||||
Some(value) => value,
|
||||
None => self
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids(&[PegasusVocab::pad_value()])[0],
|
||||
.get_pad_id()
|
||||
.expect("A padding token must be provided to encode prompt texts."),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
|
@ -46,11 +46,7 @@ use rust_tokenizers::tokenizer::{
|
||||
OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer,
|
||||
T5Tokenizer, Tokenizer, TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
|
||||
};
|
||||
use rust_tokenizers::vocab::{
|
||||
AlbertVocab, BertVocab, DeBERTaV2Vocab, DeBERTaVocab, FNetVocab, Gpt2Vocab, M2M100Vocab,
|
||||
MBart50Vocab, MarianVocab, OpenAiGptVocab, PegasusVocab, ProphetNetVocab, ReformerVocab,
|
||||
RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab, XLNetVocab,
|
||||
};
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
@ -1275,174 +1271,144 @@ impl TokenizerOption {
|
||||
/// Interface method
|
||||
pub fn get_unk_id(&self) -> i64 {
|
||||
match *self {
|
||||
Self::Bert(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(BertVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::Deberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::DebertaV2(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaV2Vocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::Roberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::Bart(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::XLMRoberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLMRobertaVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::Marian(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(MarianVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::T5(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(T5Vocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::Albert(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(AlbertVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::XLNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLNetVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::GPT2(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(Gpt2Vocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::OpenAiGpt(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(OpenAiGptVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::Reformer(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(ReformerVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::ProphetNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(ProphetNetVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::Pegasus(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(PegasusVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::MBart50(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(MBart50Vocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::M2M100(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(M2M100Vocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::FNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(FNetVocab::unknown_value())
|
||||
.expect("UNK token not found in vocabulary"),
|
||||
Self::Bert(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::Deberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::DebertaV2(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::Roberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::Bart(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::XLMRoberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::Marian(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::T5(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::Albert(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::XLNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::GPT2(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::OpenAiGpt(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::Reformer(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::ProphetNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::Pegasus(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::MBart50(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::M2M100(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
Self::FNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
vocab.token_to_id(vocab.get_unknown_value())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method
|
||||
pub fn get_pad_id(&self) -> Option<i64> {
|
||||
match *self {
|
||||
Self::Bert(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(BertVocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::Deberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaVocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::DebertaV2(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaV2Vocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::Roberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::Bart(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::pad_value())
|
||||
.unwrap_or(&1),
|
||||
),
|
||||
Self::XLMRoberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLMRobertaVocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::Marian(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(MarianVocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::T5(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(T5Vocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::Albert(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(AlbertVocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::XLNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLNetVocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::ProphetNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(ProphetNetVocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::Pegasus(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(PegasusVocab::pad_value())
|
||||
.unwrap_or(&0),
|
||||
),
|
||||
Self::MBart50(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(MBart50Vocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::M2M100(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(M2M100Vocab::pad_value())
|
||||
.unwrap_or(&1),
|
||||
),
|
||||
Self::FNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(FNetVocab::pad_value())
|
||||
.expect("PAD token not found in vocabulary"),
|
||||
),
|
||||
Self::Bert(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::Deberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::DebertaV2(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::Roberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::Bart(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::XLMRoberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::Marian(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::T5(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::Albert(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::XLNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::ProphetNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::Pegasus(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::MBart50(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::M2M100(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::FNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||
}
|
||||
Self::Reformer(_) => None,
|
||||
Self::GPT2(_) => None,
|
||||
Self::OpenAiGpt(_) => None,
|
||||
@ -1452,78 +1418,54 @@ impl TokenizerOption {
|
||||
/// Interface method
|
||||
pub fn get_sep_id(&self) -> Option<i64> {
|
||||
match *self {
|
||||
Self::Bert(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(BertVocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::Deberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaVocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::DebertaV2(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaV2Vocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::Roberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::Bart(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::XLMRoberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLMRobertaVocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::Albert(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(AlbertVocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::XLNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLNetVocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::ProphetNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(ProphetNetVocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::MBart50(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(MBart50Vocab::sep_value())
|
||||
.unwrap_or(&1),
|
||||
),
|
||||
Self::M2M100(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(M2M100Vocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::FNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(FNetVocab::sep_value())
|
||||
.expect("SEP token not found in vocabulary"),
|
||||
),
|
||||
Self::Bert(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::Deberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::DebertaV2(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::Roberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::Bart(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::XLMRoberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::Albert(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::XLNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::ProphetNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::MBart50(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::M2M100(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::FNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||
}
|
||||
Self::Marian(_) => None,
|
||||
Self::T5(_) => None,
|
||||
Self::GPT2(_) => None,
|
||||
@ -1536,78 +1478,54 @@ impl TokenizerOption {
|
||||
/// Interface method
|
||||
pub fn get_mask_id(&self) -> Option<i64> {
|
||||
match *self {
|
||||
Self::Bert(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(BertVocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::Deberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaVocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::DebertaV2(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaV2Vocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::Roberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::Bart(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::XLMRoberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLMRobertaVocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::Albert(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(AlbertVocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::XLNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLNetVocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::ProphetNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(ProphetNetVocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::MBart50(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(MBart50Vocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::FNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(FNetVocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::Pegasus(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(PegasusVocab::mask_value())
|
||||
.expect("MASK token not found in vocabulary"),
|
||||
),
|
||||
Self::Bert(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::Deberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::DebertaV2(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::Roberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::Bart(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::XLMRoberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::Albert(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::XLNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::ProphetNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::MBart50(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::FNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::Pegasus(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||
}
|
||||
Self::Marian(_) => None,
|
||||
Self::M2M100(_) => None,
|
||||
Self::T5(_) => None,
|
||||
@ -1620,84 +1538,90 @@ impl TokenizerOption {
|
||||
/// Interface method
|
||||
pub fn get_mask_value(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Bert(_) => Some(BertVocab::mask_value()),
|
||||
Self::Deberta(_) => Some(DeBERTaVocab::mask_value()),
|
||||
Self::DebertaV2(_) => Some(DeBERTaV2Vocab::mask_value()),
|
||||
Self::Roberta(_) => Some(RobertaVocab::mask_value()),
|
||||
Self::Bart(_) => Some(RobertaVocab::mask_value()),
|
||||
Self::XLMRoberta(_) => Some(XLMRobertaVocab::mask_value()),
|
||||
Self::Albert(_) => Some(AlbertVocab::mask_value()),
|
||||
Self::XLNet(_) => Some(XLNetVocab::mask_value()),
|
||||
Self::ProphetNet(_) => Some(ProphetNetVocab::mask_value()),
|
||||
Self::MBart50(_) => Some(MBart50Vocab::mask_value()),
|
||||
Self::FNet(_er) => Some(FNetVocab::mask_value()),
|
||||
Self::Bert(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::Deberta(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::DebertaV2(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::Roberta(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::Bart(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::XLMRoberta(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::Albert(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::XLNet(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::ProphetNet(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::MBart50(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::FNet(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::Pegasus(ref tokenizer) => {
|
||||
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||
}
|
||||
Self::M2M100(_) => None,
|
||||
Self::Marian(_) => None,
|
||||
Self::T5(_) => None,
|
||||
Self::GPT2(_) => None,
|
||||
Self::OpenAiGpt(_) => None,
|
||||
Self::Reformer(_) => None,
|
||||
Self::Pegasus(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method
|
||||
pub fn get_bos_id(&self) -> Option<i64> {
|
||||
match *self {
|
||||
Self::Roberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::bos_value())
|
||||
.expect("BOS token not found in vocabulary"),
|
||||
),
|
||||
Self::Bart(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::bos_value())
|
||||
.unwrap_or(&0),
|
||||
),
|
||||
Self::DebertaV2(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaV2Vocab::bos_value())
|
||||
.expect("BOS token not found in vocabulary"),
|
||||
),
|
||||
Self::XLMRoberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLMRobertaVocab::bos_value())
|
||||
.expect("BOS token not found in vocabulary"),
|
||||
),
|
||||
Self::Albert(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(AlbertVocab::bos_value())
|
||||
.expect("BOS token not found in vocabulary"),
|
||||
),
|
||||
Self::XLNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLNetVocab::bos_value())
|
||||
.expect("BOS token not found in vocabulary"),
|
||||
),
|
||||
Self::M2M100(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(M2M100Vocab::bos_value())
|
||||
.unwrap_or(&0),
|
||||
),
|
||||
Self::GPT2(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(Gpt2Vocab::bos_value())
|
||||
.expect("BOS token not found in vocabulary"),
|
||||
),
|
||||
Self::Deberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaVocab::bos_value())
|
||||
.expect("BOS token not found in vocabulary"),
|
||||
),
|
||||
Self::Roberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||
}
|
||||
Self::Bart(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||
}
|
||||
Self::DebertaV2(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||
}
|
||||
Self::XLMRoberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||
}
|
||||
Self::Albert(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||
}
|
||||
Self::XLNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||
}
|
||||
Self::M2M100(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||
}
|
||||
Self::GPT2(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||
}
|
||||
Self::Deberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||
}
|
||||
Self::MBart50(_) => Some(0),
|
||||
Self::FNet(_) => None,
|
||||
Self::Bert(_) => None,
|
||||
@ -1713,90 +1637,62 @@ impl TokenizerOption {
|
||||
/// Interface method
|
||||
pub fn get_eos_id(&self) -> Option<i64> {
|
||||
match *self {
|
||||
Self::Roberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::eos_value())
|
||||
.expect("EOS token not found in vocabulary"),
|
||||
),
|
||||
Self::Bart(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(RobertaVocab::eos_value())
|
||||
.unwrap_or(&2),
|
||||
),
|
||||
Self::DebertaV2(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaV2Vocab::eos_value())
|
||||
.expect("EOS token not found in vocabulary"),
|
||||
),
|
||||
Self::XLMRoberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLMRobertaVocab::eos_value())
|
||||
.expect("EOS token not found in vocabulary"),
|
||||
),
|
||||
Self::Albert(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(AlbertVocab::eos_value())
|
||||
.expect("EOS token not found in vocabulary"),
|
||||
),
|
||||
Self::XLNet(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(XLNetVocab::eos_value())
|
||||
.expect("EOS token not found in vocabulary"),
|
||||
),
|
||||
Self::MBart50(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(MBart50Vocab::eos_value())
|
||||
.unwrap_or(&2),
|
||||
),
|
||||
Self::M2M100(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(M2M100Vocab::eos_value())
|
||||
.expect("EOS token not found in vocabulary"),
|
||||
),
|
||||
Self::GPT2(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(Gpt2Vocab::eos_value())
|
||||
.unwrap_or(&2),
|
||||
),
|
||||
Self::Deberta(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(DeBERTaVocab::eos_value())
|
||||
.expect("EOS token not found in vocabulary"),
|
||||
),
|
||||
Self::Marian(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(MarianVocab::eos_value())
|
||||
.expect("EOS token not found in vocabulary"),
|
||||
),
|
||||
Self::T5(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(T5Vocab::eos_value())
|
||||
.expect("EOS token not found in vocabulary"),
|
||||
),
|
||||
Self::Reformer(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(ReformerVocab::eos_value())
|
||||
.expect("EOS token not found in vocabulary"),
|
||||
),
|
||||
Self::Pegasus(ref tokenizer) => Some(
|
||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
||||
.special_values
|
||||
.get(PegasusVocab::eos_value())
|
||||
.unwrap_or(&1),
|
||||
),
|
||||
Self::Roberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::Bart(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::DebertaV2(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::XLMRoberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::Albert(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::XLNet(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::MBart50(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::M2M100(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::GPT2(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::Deberta(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::Marian(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::T5(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::Reformer(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::Pegasus(ref tokenizer) => {
|
||||
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||
}
|
||||
Self::FNet(_) => None,
|
||||
Self::Bert(_) => None,
|
||||
Self::ProphetNet(_) => None,
|
||||
|
@ -257,8 +257,7 @@ impl MaskedLanguageOption {
|
||||
}
|
||||
}
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Masked Language is not implemented for {:?}!",
|
||||
model_type
|
||||
"Masked Language is not implemented for {model_type:?}!",
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
@ -456,8 +456,7 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"QuestionAnswering not implemented for {:?}!",
|
||||
model_type
|
||||
"QuestionAnswering not implemented for {model_type:?}!",
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
@ -88,8 +88,7 @@ impl SentenceEmbeddingsBuilder<Local> {
|
||||
ModelType::T5 => (model_dir.join("spiece.model"), None),
|
||||
_ => {
|
||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Unsupported transformer model {:?} for Sentence Embeddings",
|
||||
transformer_type
|
||||
"Unsupported transformer model {transformer_type:?} for Sentence Embeddings",
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
@ -408,7 +408,7 @@ mod serde_sentence_embeddings_module_type {
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(&format!("sentence_transformers.models.{:?}", module_type))
|
||||
serializer.serialize_str(&format!("sentence_transformers.models.{module_type:?}"))
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<SentenceEmbeddingsModuleType, D::Error>
|
||||
@ -430,7 +430,7 @@ mod serde_sentence_embeddings_module_type {
|
||||
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_string())))
|
||||
.transpose()
|
||||
.map_err(de::Error::custom)?
|
||||
.ok_or_else(|| format!("Invalid SentenceEmbeddingsModuleType: {}", s))
|
||||
.ok_or_else(|| format!("Invalid SentenceEmbeddingsModuleType: {s}"))
|
||||
.map_err(de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ where
|
||||
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_lowercase())))
|
||||
.transpose()
|
||||
.map_err(de::Error::custom)?
|
||||
.ok_or_else(|| format!("Invalid Activation: {}", activation))
|
||||
.ok_or_else(|| format!("Invalid Activation: {activation}"))
|
||||
.map_err(de::Error::custom)
|
||||
}
|
||||
|
||||
|
@ -65,8 +65,7 @@ impl SentenceEmbeddingsOption {
|
||||
ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))),
|
||||
_ => {
|
||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Unsupported transformer model {:?} for Sentence Embeddings",
|
||||
transformer_type
|
||||
"Unsupported transformer model {transformer_type:?} for Sentence Embeddings"
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
@ -373,8 +373,7 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Sequence Classification not implemented for {:?}!",
|
||||
model_type
|
||||
"Sequence Classification not implemented for {model_type:?}!",
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
@ -484,8 +484,7 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Token classification not implemented for {:?}!",
|
||||
model_type
|
||||
"Token classification not implemented for {model_type:?}!"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
@ -372,8 +372,7 @@ impl TranslationModelBuilder {
|
||||
}
|
||||
(Some(model_type), _, _) => {
|
||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Automated translation model builder not implemented for {:?}",
|
||||
model_type
|
||||
"Automated translation model builder not implemented for {model_type:?}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
@ -459,8 +458,10 @@ mod model_fetchers {
|
||||
source_languages: Option<&Vec<Language>>,
|
||||
target_languages: Option<&Vec<Language>>,
|
||||
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
|
||||
let (resources, source_languages, target_languages) =
|
||||
if let (Some(source_languages), Some(target_languages)) =
|
||||
let (resources, source_languages, target_languages) = if let (
|
||||
Some(source_languages),
|
||||
Some(target_languages),
|
||||
) =
|
||||
(source_languages, target_languages)
|
||||
{
|
||||
match (source_languages.as_slice(), target_languages.as_slice()) {
|
||||
@ -534,8 +535,7 @@ mod model_fetchers {
|
||||
}
|
||||
(_, _) => {
|
||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"No Pretrained Marian configuration found for {:?} to {:?} translation",
|
||||
source_languages, target_languages
|
||||
"No Pretrained Marian configuration found for {source_languages:?} to {target_languages:?} translation",
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
@ -135,7 +135,7 @@ pub enum Language {
|
||||
impl Display for Language {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", {
|
||||
let input_string = format!("{:?}", self);
|
||||
let input_string = format!("{self:?}");
|
||||
let mut output: Vec<&str> = Vec::new();
|
||||
let mut start: usize = 0;
|
||||
|
||||
@ -584,8 +584,7 @@ impl TranslationOption {
|
||||
if let Some(source_language) = source_language {
|
||||
if !supported_source_languages.contains(source_language) {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{} not in list of supported languages: {:?}",
|
||||
source_language, supported_source_languages
|
||||
"{source_language} not in list of supported languages: {supported_source_languages:?}",
|
||||
)));
|
||||
}
|
||||
}
|
||||
@ -593,8 +592,7 @@ impl TranslationOption {
|
||||
if let Some(target_language) = target_language {
|
||||
if !supported_target_languages.contains(target_language) {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{} not in list of supported languages: {:?}",
|
||||
target_language, supported_target_languages
|
||||
"{target_language} not in list of supported languages: {supported_target_languages:?}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
@ -610,9 +608,8 @@ impl TranslationOption {
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing target language for Marian \
|
||||
(multiple languages supported by model: {:?}, \
|
||||
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||
need to specify target language)",
|
||||
supported_target_languages
|
||||
)));
|
||||
}
|
||||
}
|
||||
@ -653,9 +650,8 @@ impl TranslationOption {
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing source language for MBart\
|
||||
(multiple languages supported by model: {:?}, \
|
||||
need to specify target language)",
|
||||
supported_source_languages
|
||||
(multiple languages supported by model: {supported_source_languages:?}, \
|
||||
need to specify target language)"
|
||||
)));
|
||||
}
|
||||
}
|
||||
@ -670,9 +666,8 @@ impl TranslationOption {
|
||||
} else {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing target language for MBart\
|
||||
(multiple languages supported by model: {:?}, \
|
||||
need to specify target language)",
|
||||
supported_target_languages
|
||||
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||
need to specify target language)"
|
||||
)));
|
||||
},
|
||||
),
|
||||
@ -681,8 +676,8 @@ impl TranslationOption {
|
||||
Some(value) => {
|
||||
let language_code = value.get_iso_639_1_code();
|
||||
match language_code.len() {
|
||||
2 => format!(">>{}.<< ", language_code),
|
||||
3 => format!(">>{}<< ", language_code),
|
||||
2 => format!(">>{language_code}.<< "),
|
||||
3 => format!(">>{language_code}<< "),
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Invalid ISO 639-3 code".to_string(),
|
||||
@ -693,9 +688,8 @@ impl TranslationOption {
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing source language for M2M100 \
|
||||
(multiple languages supported by model: {:?}, \
|
||||
need to specify target language)",
|
||||
supported_source_languages
|
||||
(multiple languages supported by model: {supported_source_languages:?}, \
|
||||
need to specify target language)"
|
||||
)));
|
||||
}
|
||||
}),
|
||||
@ -704,8 +698,8 @@ impl TranslationOption {
|
||||
Some(
|
||||
model._get_tokenizer().convert_tokens_to_ids(&[
|
||||
match language_code.len() {
|
||||
2 => format!(">>{}.<<", language_code),
|
||||
3 => format!(">>{}<<", language_code),
|
||||
2 => format!(">>{language_code}.<<"),
|
||||
3 => format!(">>{language_code}<<"),
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Invalid ISO 639-3 code".to_string(),
|
||||
@ -717,9 +711,8 @@ impl TranslationOption {
|
||||
} else {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing target language for M2M100 \
|
||||
(multiple languages supported by model: {:?}, \
|
||||
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||
need to specify target language)",
|
||||
supported_target_languages
|
||||
)));
|
||||
},
|
||||
),
|
||||
|
@ -370,8 +370,7 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Zero shot classification not implemented for {:?}!",
|
||||
model_type
|
||||
"Zero shot classification not implemented for {model_type:?}!",
|
||||
))),
|
||||
}
|
||||
}
|
||||
@ -604,7 +603,7 @@ impl ZeroShotClassificationModel {
|
||||
None => labels
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|label| format!("This example is about {}.", label))
|
||||
.map(|label| format!("This example is about {label}."))
|
||||
.collect(),
|
||||
};
|
||||
|
||||
|
@ -14,7 +14,7 @@ use std::borrow::Borrow;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rust_tokenizers::tokenizer::{ProphetNetTokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::{ProphetNetVocab, Vocab};
|
||||
use rust_tokenizers::vocab::ProphetNetVocab;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tch::{nn, Kind, Tensor};
|
||||
|
||||
@ -1098,9 +1098,7 @@ impl
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids(&[ProphetNetVocab::unknown_value()])[0],
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
|
@ -548,7 +548,7 @@ impl ReformerModelWithLMHead {
|
||||
if let Some(lsh_num_chunks_after) = config.lsh_num_chunks_after {
|
||||
if config.attn_layers.contains(&AttentionType::lsh) & (lsh_num_chunks_after != 0) {
|
||||
return Err(RustBertError::InvalidConfigurationError(
|
||||
format!("For text generation using LSH attention ensure `config.lsh_num_chunks_after` is set to 0 (currently {})", lsh_num_chunks_after),
|
||||
format!("For text generation using LSH attention ensure `config.lsh_num_chunks_after` is set to 0 (currently {lsh_num_chunks_after})"),
|
||||
));
|
||||
}
|
||||
}
|
||||
@ -556,7 +556,7 @@ impl ReformerModelWithLMHead {
|
||||
if let Some(local_num_chunks_after) = config.local_num_chunks_after {
|
||||
if config.attn_layers.contains(&AttentionType::local) & (local_num_chunks_after != 0) {
|
||||
return Err(RustBertError::InvalidConfigurationError(
|
||||
format!("For text generation using local attention ensure `config.local_num_chunks_after` is set to 0 (currently {})", local_num_chunks_after),
|
||||
format!("For text generation using local attention ensure `config.local_num_chunks_after` is set to 0 (currently {local_num_chunks_after})"),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
@ -215,7 +215,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
|
||||
[input_sentence, input_sequence_2],
|
||||
candidate_labels,
|
||||
Some(Box::new(|label: &str| {
|
||||
format!("This example is about {}.", label)
|
||||
format!("This example is about {label}.")
|
||||
})),
|
||||
128,
|
||||
)?;
|
||||
@ -244,7 +244,7 @@ fn bart_zero_shot_classification_try_error() -> anyhow::Result<()> {
|
||||
[],
|
||||
[],
|
||||
Some(Box::new(|label: &str| {
|
||||
format!("This example is about {}.", label)
|
||||
format!("This example is about {label}.")
|
||||
})),
|
||||
128,
|
||||
);
|
||||
@ -276,7 +276,7 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
|
||||
[input_sentence, input_sequence_2],
|
||||
candidate_labels,
|
||||
Some(Box::new(|label: &str| {
|
||||
format!("This example is about {}.", label)
|
||||
format!("This example is about {label}.")
|
||||
})),
|
||||
128,
|
||||
)?;
|
||||
@ -319,7 +319,7 @@ fn bart_zero_shot_classification_multilabel_try_error() -> anyhow::Result<()> {
|
||||
[],
|
||||
[],
|
||||
Some(Box::new(|label: &str| {
|
||||
format!("This example is about {}.", label)
|
||||
format!("This example is about {label}.")
|
||||
})),
|
||||
128,
|
||||
);
|
||||
|
@ -14,7 +14,7 @@ use rust_bert::pipelines::question_answering::{
|
||||
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::{RobertaVocab, Vocab};
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use std::collections::HashMap;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
@ -67,7 +67,9 @@ fn longformer_masked_lm() -> anyhow::Result<()> {
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![
|
||||
tokenizer.vocab().token_to_id(RobertaVocab::pad_value());
|
||||
tokenizer
|
||||
.vocab()
|
||||
.token_to_id(tokenizer.vocab().get_pad_value());
|
||||
max_len - input.len()
|
||||
]);
|
||||
input
|
||||
|
@ -63,8 +63,8 @@ about exoplanets like K2-18b."];
|
||||
assert_eq!(
|
||||
output[0],
|
||||
"scientists have confirmed the presence of water in the atmosphere of k2 - 18b. \
|
||||
[X_SEP] this is the first such discovery in a planet in its star's habitable zone. \
|
||||
[X_SEP] the planet is 110 light - years from earth and has a star in the constellation leo."
|
||||
this is the first such discovery in a planet in its star's habitable zone. \
|
||||
the planet is 110 light - years from earth and has a star in the constellation leo."
|
||||
);
|
||||
|
||||
Ok(())
|
||||
|
Loading…
Reference in New Issue
Block a user