mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Tokenizer special token map update (#330)
* Updates for compatibility with tokenizers special token rework * Updated mask pipline methods * Bumped version * Fix clippy warnings
This commit is contained in:
parent
80e0197e2c
commit
84561ec82b
@ -2,6 +2,8 @@
|
|||||||
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
## Changed
|
||||||
|
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
|
||||||
|
|
||||||
## [0.20.0] - 2023-01-21
|
## [0.20.0] - 2023-01-21
|
||||||
## Added
|
## Added
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "rust-bert"
|
name = "rust-bert"
|
||||||
version = "0.20.0"
|
version = "0.20.1-alpha"
|
||||||
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
|
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
description = "Ready-to-use NLP pipelines and language models"
|
description = "Ready-to-use NLP pipelines and language models"
|
||||||
@ -69,7 +69,7 @@ remote = ["cached-path", "dirs", "lazy_static"]
|
|||||||
features = ["doc-only"]
|
features = ["doc-only"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
rust_tokenizers = "~7.0.2"
|
rust_tokenizers = "8.0.0"
|
||||||
tch = "~0.10.1"
|
tch = "~0.10.1"
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
@ -16,7 +16,7 @@ async fn main() -> Result<()> {
|
|||||||
"Classify this negative text".to_owned(),
|
"Classify this negative text".to_owned(),
|
||||||
];
|
];
|
||||||
let sentiments = classifier.predict(texts).await?;
|
let sentiments = classifier.predict(texts).await?;
|
||||||
println!("Results: {:?}", sentiments);
|
println!("Results: {sentiments:?}");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
// Run model
|
// Run model
|
||||||
let output = sequence_classification_model.predict(input);
|
let output = sequence_classification_model.predict(input);
|
||||||
for label in output {
|
for label in output {
|
||||||
println!("{:?}", label);
|
println!("{label:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Masked language model
|
// Masked language model
|
||||||
@ -78,7 +78,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
// Run model
|
// Run model
|
||||||
let output = mask_language_model.predict(input)?;
|
let output = mask_language_model.predict(input)?;
|
||||||
for sentence_output in output {
|
for sentence_output in output {
|
||||||
println!("{:?}", sentence_output);
|
println!("{sentence_output:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -31,7 +31,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||||
|
|
||||||
println!("{:?}", output);
|
println!("{output:?}");
|
||||||
|
|
||||||
let _ = conversation_manager
|
let _ = conversation_manager
|
||||||
.get(&conversation_1_id)
|
.get(&conversation_1_id)
|
||||||
@ -40,11 +40,11 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||||
|
|
||||||
println!("{:?}", output);
|
println!("{output:?}");
|
||||||
|
|
||||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||||
|
|
||||||
println!("{:?}", output);
|
println!("{output:?}");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let output = model.generate(&[input_context], None);
|
let output = model.generate(&[input_context], None);
|
||||||
|
|
||||||
for sentence in output {
|
for sentence in output {
|
||||||
println!("{:?}", sentence);
|
println!("{sentence:?}");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -60,7 +60,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let output = model.generate(&[input_context_1, input_context_2], None);
|
let output = model.generate(&[input_context_1, input_context_2], None);
|
||||||
|
|
||||||
for sentence in output {
|
for sentence in output {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let output = model.generate(&[input_context_1, input_context_2], None);
|
let output = model.generate(&[input_context_1, input_context_2], None);
|
||||||
|
|
||||||
for sentence in output {
|
for sentence in output {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let output = model.generate(&[input_context], None);
|
let output = model.generate(&[input_context], None);
|
||||||
|
|
||||||
for sentence in output {
|
for sentence in output {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
// Run model
|
// Run model
|
||||||
let output = mask_language_model.predict(input)?;
|
let output = mask_language_model.predict(input)?;
|
||||||
for sentence_output in output {
|
for sentence_output in output {
|
||||||
println!("{:?}", sentence_output);
|
println!("{sentence_output:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
// Run model
|
// Run model
|
||||||
let output = ner_model.predict_full_entities(&input);
|
let output = ner_model.predict_full_entities(&input);
|
||||||
for entity in output {
|
for entity in output {
|
||||||
println!("{:?}", entity);
|
println!("{entity:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -24,7 +24,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
// Run model
|
// Run model
|
||||||
let output = pos_model.predict(&input);
|
let output = pos_model.predict(&input);
|
||||||
for (pos, pos_tag) in output[0].iter().enumerate() {
|
for (pos, pos_tag) in output[0].iter().enumerate() {
|
||||||
println!("{} - {:?}", pos, pos_tag);
|
println!("{pos} - {pos_tag:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -34,6 +34,6 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Get answer
|
// Get answer
|
||||||
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
||||||
println!("{:?}", answers);
|
println!("{answers:?}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -50,6 +50,6 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Get answer
|
// Get answer
|
||||||
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
||||||
println!("{:?}", answers);
|
println!("{answers:?}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -55,6 +55,6 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Get answer
|
// Get answer
|
||||||
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
||||||
println!("{:?}", answers);
|
println!("{answers:?}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,6 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Generate Embeddings
|
// Generate Embeddings
|
||||||
let embeddings = model.encode(&sentences)?;
|
let embeddings = model.encode(&sentences)?;
|
||||||
println!("{:?}", embeddings);
|
println!("{embeddings:?}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,6 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Generate Embeddings
|
// Generate Embeddings
|
||||||
let embeddings = model.encode(&sentences)?;
|
let embeddings = model.encode(&sentences)?;
|
||||||
println!("{:?}", embeddings);
|
println!("{embeddings:?}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
// Run model
|
// Run model
|
||||||
let output = sentiment_classifier.predict(input);
|
let output = sentiment_classifier.predict(input);
|
||||||
for sentiment in output {
|
for sentiment in output {
|
||||||
println!("{:?}", sentiment);
|
println!("{sentiment:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -49,7 +49,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
// Run model
|
// Run model
|
||||||
let output = sentiment_classifier.predict(input);
|
let output = sentiment_classifier.predict(input);
|
||||||
for sentiment in output {
|
for sentiment in output {
|
||||||
println!("{:?}", sentiment);
|
println!("{sentiment:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
// Run model
|
// Run model
|
||||||
let output = sequence_classification_model.predict(input);
|
let output = sequence_classification_model.predict(input);
|
||||||
for label in output {
|
for label in output {
|
||||||
println!("{:?}", label);
|
println!("{label:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -29,7 +29,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let output = sequence_classification_model.predict_multilabel(&input, 0.05);
|
let output = sequence_classification_model.predict_multilabel(&input, 0.05);
|
||||||
if let Ok(labels) = output {
|
if let Ok(labels) = output {
|
||||||
for label in labels {
|
for label in labels {
|
||||||
println!("{:?}", label);
|
println!("{label:?}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ about exoplanets like K2-18b."];
|
|||||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||||
let _output = summarization_model.summarize(&input);
|
let _output = summarization_model.summarize(&input);
|
||||||
for sentence in _output {
|
for sentence in _output {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -68,7 +68,7 @@ about exoplanets like K2-18b."];
|
|||||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||||
let _output = summarization_model.summarize(&input);
|
let _output = summarization_model.summarize(&input);
|
||||||
for sentence in _output {
|
for sentence in _output {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -70,7 +70,7 @@ about exoplanets like K2-18b."];
|
|||||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||||
let _output = summarization_model.summarize(&input);
|
let _output = summarization_model.summarize(&input);
|
||||||
for sentence in _output {
|
for sentence in _output {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -56,7 +56,7 @@ about exoplanets like K2-18b."];
|
|||||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||||
let _output = summarization_model.summarize(&input);
|
let _output = summarization_model.summarize(&input);
|
||||||
for sentence in _output {
|
for sentence in _output {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let token_outputs = token_classification_model.predict(&input);
|
let token_outputs = token_classification_model.predict(&input);
|
||||||
|
|
||||||
for token in token_outputs {
|
for token in token_outputs {
|
||||||
println!("{:?}", token);
|
println!("{token:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -32,7 +32,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let output = model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?;
|
let output = model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?;
|
||||||
|
|
||||||
for sentence in output {
|
for sentence in output {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
||||||
|
|
||||||
for sentence in outputs {
|
for sentence in outputs {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -49,7 +49,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let output = model.translate(&[input_context_1, input_context_2], None, None)?;
|
let output = model.translate(&[input_context_1, input_context_2], None, None)?;
|
||||||
|
|
||||||
for sentence in output {
|
for sentence in output {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
||||||
|
|
||||||
for sentence in outputs {
|
for sentence in outputs {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
|
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
|
||||||
|
|
||||||
for sentence in outputs {
|
for sentence in outputs {
|
||||||
println!("{}", sentence);
|
println!("{sentence}");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -27,13 +27,13 @@ fn main() -> anyhow::Result<()> {
|
|||||||
[input_sentence, input_sequence_2],
|
[input_sentence, input_sequence_2],
|
||||||
candidate_labels,
|
candidate_labels,
|
||||||
Some(Box::new(|label: &str| {
|
Some(Box::new(|label: &str| {
|
||||||
format!("This example is about {}.", label)
|
format!("This example is about {label}.")
|
||||||
})),
|
})),
|
||||||
128,
|
128,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
println!("{:?}", output);
|
println!("{output:?}");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,7 @@ use crate::pipelines::generation_utils::{
|
|||||||
};
|
};
|
||||||
use crate::{Config, RustBertError};
|
use crate::{Config, RustBertError};
|
||||||
use rust_tokenizers::tokenizer::{RobertaTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{RobertaTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::{RobertaVocab, Vocab};
|
use rust_tokenizers::vocab::RobertaVocab;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::borrow::Borrow;
|
use std::borrow::Borrow;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -1263,9 +1263,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
|
|||||||
|
|
||||||
let pad_token = match pad_token_id {
|
let pad_token = match pad_token_id {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => self
|
None => self._get_tokenizer().get_unk_id(),
|
||||||
._get_tokenizer()
|
|
||||||
.convert_tokens_to_ids(&[RobertaVocab::unknown_value()])[0],
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_ids = token_ids
|
let token_ids = token_ids
|
||||||
|
@ -14,8 +14,7 @@ pub(crate) fn get_positive_infinity(kind: Kind) -> Result<Scalar, RustBertError>
|
|||||||
Kind::Double => Scalar::float(f64::INFINITY),
|
Kind::Double => Scalar::float(f64::INFINITY),
|
||||||
_ => {
|
_ => {
|
||||||
return Err(RustBertError::ValueError(format!(
|
return Err(RustBertError::ValueError(format!(
|
||||||
"Type not supported: attempted to get positive infinity for {:?}",
|
"Type not supported: attempted to get positive infinity for {kind:?}",
|
||||||
kind
|
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -34,8 +33,7 @@ pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError>
|
|||||||
Kind::Double => Scalar::float(f64::NEG_INFINITY),
|
Kind::Double => Scalar::float(f64::NEG_INFINITY),
|
||||||
_ => {
|
_ => {
|
||||||
return Err(RustBertError::ValueError(format!(
|
return Err(RustBertError::ValueError(format!(
|
||||||
"Type not supported: attempted to get negative infinity for {:?}",
|
"Type not supported: attempted to get negative infinity for {kind:?}",
|
||||||
kind
|
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -111,8 +111,7 @@ impl FromStr for PositionAttentionType {
|
|||||||
"c2p" => Ok(PositionAttentionType::c2p),
|
"c2p" => Ok(PositionAttentionType::c2p),
|
||||||
"p2p" => Ok(PositionAttentionType::p2p),
|
"p2p" => Ok(PositionAttentionType::p2p),
|
||||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
"Position attention type `{}` not in accepted variants (`p2c`, `c2p`, `p2p`)",
|
"Position attention type `{s}` not in accepted variants (`p2c`, `c2p`, `p2p`)",
|
||||||
s
|
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -118,8 +118,7 @@ impl FromStr for NormRelEmbedType {
|
|||||||
match s {
|
match s {
|
||||||
"layer_norm" => Ok(NormRelEmbedType::layer_norm),
|
"layer_norm" => Ok(NormRelEmbedType::layer_norm),
|
||||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
"Layer normalization type `{}` not in accepted variants (`layer_norm`)",
|
"Layer normalization type `{s}` not in accepted variants (`layer_norm`)",
|
||||||
s
|
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,7 +24,7 @@ use crate::pipelines::generation_utils::{
|
|||||||
use crate::pipelines::translation::Language;
|
use crate::pipelines::translation::Language;
|
||||||
use crate::{Config, RustBertError};
|
use crate::{Config, RustBertError};
|
||||||
use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::{M2M100Vocab, Vocab};
|
use rust_tokenizers::vocab::M2M100Vocab;
|
||||||
use std::borrow::Borrow;
|
use std::borrow::Borrow;
|
||||||
use tch::nn::{embedding, EmbeddingConfig};
|
use tch::nn::{embedding, EmbeddingConfig};
|
||||||
use tch::{nn, Kind, Tensor};
|
use tch::{nn, Kind, Tensor};
|
||||||
@ -804,9 +804,7 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
|
|||||||
|
|
||||||
let pad_token = match pad_token_id {
|
let pad_token = match pad_token_id {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => self
|
None => self._get_tokenizer().get_unk_id(),
|
||||||
._get_tokenizer()
|
|
||||||
.convert_tokens_to_ids(&[M2M100Vocab::unknown_value()])[0],
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_ids = token_ids
|
let token_ids = token_ids
|
||||||
|
@ -885,7 +885,9 @@ impl MarianGenerator {
|
|||||||
let vocab_size = config.vocab_size;
|
let vocab_size = config.vocab_size;
|
||||||
let is_encoder_decoder = true;
|
let is_encoder_decoder = true;
|
||||||
let decoder_start_id =
|
let decoder_start_id =
|
||||||
Some(tokenizer.convert_tokens_to_ids(&[MarianVocab::pad_value()])[0]);
|
Some(tokenizer.get_pad_id().ok_or(RustBertError::TokenizerError(
|
||||||
|
"The tokenizer must contain a pad token ID to be used as BOS".to_string(),
|
||||||
|
))?);
|
||||||
let max_position_embeddings = config.max_position_embeddings;
|
let max_position_embeddings = config.max_position_embeddings;
|
||||||
|
|
||||||
Ok(MarianGenerator {
|
Ok(MarianGenerator {
|
||||||
|
@ -25,7 +25,7 @@ use crate::pipelines::generation_utils::{
|
|||||||
use crate::pipelines::translation::Language;
|
use crate::pipelines::translation::Language;
|
||||||
use crate::{Activation, Config, RustBertError};
|
use crate::{Activation, Config, RustBertError};
|
||||||
use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::{MBart50Vocab, Vocab};
|
use rust_tokenizers::vocab::MBart50Vocab;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::borrow::Borrow;
|
use std::borrow::Borrow;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -1059,9 +1059,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
|
|||||||
|
|
||||||
let pad_token = match pad_token_id {
|
let pad_token = match pad_token_id {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => self
|
None => self._get_tokenizer().get_unk_id(),
|
||||||
._get_tokenizer()
|
|
||||||
.convert_tokens_to_ids(&[MBart50Vocab::unknown_value()])[0],
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_ids = token_ids
|
let token_ids = token_ids
|
||||||
|
@ -780,7 +780,8 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
|
|||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => self
|
None => self
|
||||||
._get_tokenizer()
|
._get_tokenizer()
|
||||||
.convert_tokens_to_ids(&[PegasusVocab::pad_value()])[0],
|
.get_pad_id()
|
||||||
|
.expect("A padding token must be provided to encode prompt texts."),
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_ids = token_ids
|
let token_ids = token_ids
|
||||||
|
@ -46,11 +46,7 @@ use rust_tokenizers::tokenizer::{
|
|||||||
OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer,
|
OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer,
|
||||||
T5Tokenizer, Tokenizer, TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
|
T5Tokenizer, Tokenizer, TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
|
||||||
};
|
};
|
||||||
use rust_tokenizers::vocab::{
|
use rust_tokenizers::vocab::Vocab;
|
||||||
AlbertVocab, BertVocab, DeBERTaV2Vocab, DeBERTaVocab, FNetVocab, Gpt2Vocab, M2M100Vocab,
|
|
||||||
MBart50Vocab, MarianVocab, OpenAiGptVocab, PegasusVocab, ProphetNetVocab, ReformerVocab,
|
|
||||||
RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab, XLNetVocab,
|
|
||||||
};
|
|
||||||
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
|
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -1275,174 +1271,144 @@ impl TokenizerOption {
|
|||||||
/// Interface method
|
/// Interface method
|
||||||
pub fn get_unk_id(&self) -> i64 {
|
pub fn get_unk_id(&self) -> i64 {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Bert(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::Bert(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(BertVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::Deberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::Deberta(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(DeBERTaVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::DebertaV2(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::DebertaV2(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(DeBERTaV2Vocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::Roberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::Roberta(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(RobertaVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::Bart(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::Bart(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(RobertaVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::XLMRoberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::XLMRoberta(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(XLMRobertaVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::Marian(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::Marian(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(MarianVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::T5(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::T5(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(T5Vocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::Albert(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::Albert(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(AlbertVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::XLNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::XLNet(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(XLNetVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::GPT2(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::GPT2(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(Gpt2Vocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::OpenAiGpt(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::OpenAiGpt(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(OpenAiGptVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::Reformer(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::Reformer(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(ReformerVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::ProphetNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::ProphetNet(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(ProphetNetVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::Pegasus(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::Pegasus(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(PegasusVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::MBart50(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::MBart50(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(MBart50Vocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::M2M100(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::M2M100(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(M2M100Vocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
Self::FNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
|
Self::FNet(ref tokenizer) => {
|
||||||
.special_values
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.get(FNetVocab::unknown_value())
|
vocab.token_to_id(vocab.get_unknown_value())
|
||||||
.expect("UNK token not found in vocabulary"),
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Interface method
|
/// Interface method
|
||||||
pub fn get_pad_id(&self) -> Option<i64> {
|
pub fn get_pad_id(&self) -> Option<i64> {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Bert(ref tokenizer) => Some(
|
Self::Bert(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
.get(BertVocab::pad_value())
|
}
|
||||||
.expect("PAD token not found in vocabulary"),
|
Self::Deberta(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::Deberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::DebertaV2(ref tokenizer) => {
|
||||||
.get(DeBERTaVocab::pad_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("PAD token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
),
|
}
|
||||||
Self::DebertaV2(ref tokenizer) => Some(
|
Self::Roberta(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
.get(DeBERTaV2Vocab::pad_value())
|
}
|
||||||
.expect("PAD token not found in vocabulary"),
|
Self::Bart(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::Roberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::XLMRoberta(ref tokenizer) => {
|
||||||
.get(RobertaVocab::pad_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("PAD token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
),
|
}
|
||||||
Self::Bart(ref tokenizer) => Some(
|
Self::Marian(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
.get(RobertaVocab::pad_value())
|
}
|
||||||
.unwrap_or(&1),
|
Self::T5(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::XLMRoberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::Albert(ref tokenizer) => {
|
||||||
.get(XLMRobertaVocab::pad_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("PAD token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
),
|
}
|
||||||
Self::Marian(ref tokenizer) => Some(
|
Self::XLNet(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
.get(MarianVocab::pad_value())
|
}
|
||||||
.expect("PAD token not found in vocabulary"),
|
Self::ProphetNet(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::T5(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::Pegasus(ref tokenizer) => {
|
||||||
.get(T5Vocab::pad_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("PAD token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
),
|
}
|
||||||
Self::Albert(ref tokenizer) => Some(
|
Self::MBart50(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
.get(AlbertVocab::pad_value())
|
}
|
||||||
.expect("PAD token not found in vocabulary"),
|
Self::M2M100(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::XLNet(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::FNet(ref tokenizer) => {
|
||||||
.get(XLNetVocab::pad_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("PAD token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_pad_value()))
|
||||||
),
|
}
|
||||||
Self::ProphetNet(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(ProphetNetVocab::pad_value())
|
|
||||||
.expect("PAD token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::Pegasus(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(PegasusVocab::pad_value())
|
|
||||||
.unwrap_or(&0),
|
|
||||||
),
|
|
||||||
Self::MBart50(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(MBart50Vocab::pad_value())
|
|
||||||
.expect("PAD token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::M2M100(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(M2M100Vocab::pad_value())
|
|
||||||
.unwrap_or(&1),
|
|
||||||
),
|
|
||||||
Self::FNet(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(FNetVocab::pad_value())
|
|
||||||
.expect("PAD token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::Reformer(_) => None,
|
Self::Reformer(_) => None,
|
||||||
Self::GPT2(_) => None,
|
Self::GPT2(_) => None,
|
||||||
Self::OpenAiGpt(_) => None,
|
Self::OpenAiGpt(_) => None,
|
||||||
@ -1452,78 +1418,54 @@ impl TokenizerOption {
|
|||||||
/// Interface method
|
/// Interface method
|
||||||
pub fn get_sep_id(&self) -> Option<i64> {
|
pub fn get_sep_id(&self) -> Option<i64> {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Bert(ref tokenizer) => Some(
|
Self::Bert(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
.get(BertVocab::sep_value())
|
}
|
||||||
.expect("SEP token not found in vocabulary"),
|
Self::Deberta(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::Deberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::DebertaV2(ref tokenizer) => {
|
||||||
.get(DeBERTaVocab::sep_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("SEP token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
),
|
}
|
||||||
Self::DebertaV2(ref tokenizer) => Some(
|
Self::Roberta(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
.get(DeBERTaV2Vocab::sep_value())
|
}
|
||||||
.expect("SEP token not found in vocabulary"),
|
Self::Bart(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::Roberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::XLMRoberta(ref tokenizer) => {
|
||||||
.get(RobertaVocab::sep_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("SEP token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
),
|
}
|
||||||
Self::Bart(ref tokenizer) => Some(
|
Self::Albert(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
.get(RobertaVocab::sep_value())
|
}
|
||||||
.expect("SEP token not found in vocabulary"),
|
Self::XLNet(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::XLMRoberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::ProphetNet(ref tokenizer) => {
|
||||||
.get(XLMRobertaVocab::sep_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("SEP token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
),
|
}
|
||||||
Self::Albert(ref tokenizer) => Some(
|
Self::MBart50(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
.get(AlbertVocab::sep_value())
|
}
|
||||||
.expect("SEP token not found in vocabulary"),
|
Self::M2M100(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::XLNet(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::FNet(ref tokenizer) => {
|
||||||
.get(XLNetVocab::sep_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("SEP token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_sep_value()))
|
||||||
),
|
}
|
||||||
Self::ProphetNet(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(ProphetNetVocab::sep_value())
|
|
||||||
.expect("SEP token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::MBart50(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(MBart50Vocab::sep_value())
|
|
||||||
.unwrap_or(&1),
|
|
||||||
),
|
|
||||||
Self::M2M100(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(M2M100Vocab::sep_value())
|
|
||||||
.expect("SEP token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::FNet(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(FNetVocab::sep_value())
|
|
||||||
.expect("SEP token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::Marian(_) => None,
|
Self::Marian(_) => None,
|
||||||
Self::T5(_) => None,
|
Self::T5(_) => None,
|
||||||
Self::GPT2(_) => None,
|
Self::GPT2(_) => None,
|
||||||
@ -1536,78 +1478,54 @@ impl TokenizerOption {
|
|||||||
/// Interface method
|
/// Interface method
|
||||||
pub fn get_mask_id(&self) -> Option<i64> {
|
pub fn get_mask_id(&self) -> Option<i64> {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Bert(ref tokenizer) => Some(
|
Self::Bert(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
.get(BertVocab::mask_value())
|
}
|
||||||
.expect("MASK token not found in vocabulary"),
|
Self::Deberta(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::Deberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::DebertaV2(ref tokenizer) => {
|
||||||
.get(DeBERTaVocab::mask_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("MASK token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
),
|
}
|
||||||
Self::DebertaV2(ref tokenizer) => Some(
|
Self::Roberta(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
.get(DeBERTaV2Vocab::mask_value())
|
}
|
||||||
.expect("MASK token not found in vocabulary"),
|
Self::Bart(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::Roberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::XLMRoberta(ref tokenizer) => {
|
||||||
.get(RobertaVocab::mask_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("MASK token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
),
|
}
|
||||||
Self::Bart(ref tokenizer) => Some(
|
Self::Albert(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
.get(RobertaVocab::mask_value())
|
}
|
||||||
.expect("MASK token not found in vocabulary"),
|
Self::XLNet(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::XLMRoberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::ProphetNet(ref tokenizer) => {
|
||||||
.get(XLMRobertaVocab::mask_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("MASK token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
),
|
}
|
||||||
Self::Albert(ref tokenizer) => Some(
|
Self::MBart50(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
.get(AlbertVocab::mask_value())
|
}
|
||||||
.expect("MASK token not found in vocabulary"),
|
Self::FNet(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::XLNet(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::Pegasus(ref tokenizer) => {
|
||||||
.get(XLNetVocab::mask_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("MASK token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_mask_value()))
|
||||||
),
|
}
|
||||||
Self::ProphetNet(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(ProphetNetVocab::mask_value())
|
|
||||||
.expect("MASK token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::MBart50(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(MBart50Vocab::mask_value())
|
|
||||||
.expect("MASK token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::FNet(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(FNetVocab::mask_value())
|
|
||||||
.expect("MASK token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::Pegasus(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(PegasusVocab::mask_value())
|
|
||||||
.expect("MASK token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::Marian(_) => None,
|
Self::Marian(_) => None,
|
||||||
Self::M2M100(_) => None,
|
Self::M2M100(_) => None,
|
||||||
Self::T5(_) => None,
|
Self::T5(_) => None,
|
||||||
@ -1620,84 +1538,90 @@ impl TokenizerOption {
|
|||||||
/// Interface method
|
/// Interface method
|
||||||
pub fn get_mask_value(&self) -> Option<&str> {
|
pub fn get_mask_value(&self) -> Option<&str> {
|
||||||
match self {
|
match self {
|
||||||
Self::Bert(_) => Some(BertVocab::mask_value()),
|
Self::Bert(ref tokenizer) => {
|
||||||
Self::Deberta(_) => Some(DeBERTaVocab::mask_value()),
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
Self::DebertaV2(_) => Some(DeBERTaV2Vocab::mask_value()),
|
}
|
||||||
Self::Roberta(_) => Some(RobertaVocab::mask_value()),
|
Self::Deberta(ref tokenizer) => {
|
||||||
Self::Bart(_) => Some(RobertaVocab::mask_value()),
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
Self::XLMRoberta(_) => Some(XLMRobertaVocab::mask_value()),
|
}
|
||||||
Self::Albert(_) => Some(AlbertVocab::mask_value()),
|
Self::DebertaV2(ref tokenizer) => {
|
||||||
Self::XLNet(_) => Some(XLNetVocab::mask_value()),
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
Self::ProphetNet(_) => Some(ProphetNetVocab::mask_value()),
|
}
|
||||||
Self::MBart50(_) => Some(MBart50Vocab::mask_value()),
|
Self::Roberta(ref tokenizer) => {
|
||||||
Self::FNet(_er) => Some(FNetVocab::mask_value()),
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
|
}
|
||||||
|
Self::Bart(ref tokenizer) => {
|
||||||
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
|
}
|
||||||
|
Self::XLMRoberta(ref tokenizer) => {
|
||||||
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
|
}
|
||||||
|
Self::Albert(ref tokenizer) => {
|
||||||
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
|
}
|
||||||
|
Self::XLNet(ref tokenizer) => {
|
||||||
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
|
}
|
||||||
|
Self::ProphetNet(ref tokenizer) => {
|
||||||
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
|
}
|
||||||
|
Self::MBart50(ref tokenizer) => {
|
||||||
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
|
}
|
||||||
|
Self::FNet(ref tokenizer) => {
|
||||||
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
|
}
|
||||||
|
Self::Pegasus(ref tokenizer) => {
|
||||||
|
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
|
||||||
|
}
|
||||||
Self::M2M100(_) => None,
|
Self::M2M100(_) => None,
|
||||||
Self::Marian(_) => None,
|
Self::Marian(_) => None,
|
||||||
Self::T5(_) => None,
|
Self::T5(_) => None,
|
||||||
Self::GPT2(_) => None,
|
Self::GPT2(_) => None,
|
||||||
Self::OpenAiGpt(_) => None,
|
Self::OpenAiGpt(_) => None,
|
||||||
Self::Reformer(_) => None,
|
Self::Reformer(_) => None,
|
||||||
Self::Pegasus(_) => None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Interface method
|
/// Interface method
|
||||||
pub fn get_bos_id(&self) -> Option<i64> {
|
pub fn get_bos_id(&self) -> Option<i64> {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Roberta(ref tokenizer) => Some(
|
Self::Roberta(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||||
.get(RobertaVocab::bos_value())
|
}
|
||||||
.expect("BOS token not found in vocabulary"),
|
Self::Bart(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::Bart(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::DebertaV2(ref tokenizer) => {
|
||||||
.get(RobertaVocab::bos_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.unwrap_or(&0),
|
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||||
),
|
}
|
||||||
Self::DebertaV2(ref tokenizer) => Some(
|
Self::XLMRoberta(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||||
.get(DeBERTaV2Vocab::bos_value())
|
}
|
||||||
.expect("BOS token not found in vocabulary"),
|
Self::Albert(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::XLMRoberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::XLNet(ref tokenizer) => {
|
||||||
.get(XLMRobertaVocab::bos_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("BOS token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||||
),
|
}
|
||||||
Self::Albert(ref tokenizer) => Some(
|
Self::M2M100(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||||
.get(AlbertVocab::bos_value())
|
}
|
||||||
.expect("BOS token not found in vocabulary"),
|
Self::GPT2(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::XLNet(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::Deberta(ref tokenizer) => {
|
||||||
.get(XLNetVocab::bos_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("BOS token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_bos_value()))
|
||||||
),
|
}
|
||||||
Self::M2M100(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(M2M100Vocab::bos_value())
|
|
||||||
.unwrap_or(&0),
|
|
||||||
),
|
|
||||||
Self::GPT2(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(Gpt2Vocab::bos_value())
|
|
||||||
.expect("BOS token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::Deberta(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(DeBERTaVocab::bos_value())
|
|
||||||
.expect("BOS token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::MBart50(_) => Some(0),
|
Self::MBart50(_) => Some(0),
|
||||||
Self::FNet(_) => None,
|
Self::FNet(_) => None,
|
||||||
Self::Bert(_) => None,
|
Self::Bert(_) => None,
|
||||||
@ -1713,90 +1637,62 @@ impl TokenizerOption {
|
|||||||
/// Interface method
|
/// Interface method
|
||||||
pub fn get_eos_id(&self) -> Option<i64> {
|
pub fn get_eos_id(&self) -> Option<i64> {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Roberta(ref tokenizer) => Some(
|
Self::Roberta(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
.get(RobertaVocab::eos_value())
|
}
|
||||||
.expect("EOS token not found in vocabulary"),
|
Self::Bart(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::Bart(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::DebertaV2(ref tokenizer) => {
|
||||||
.get(RobertaVocab::eos_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.unwrap_or(&2),
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
),
|
}
|
||||||
Self::DebertaV2(ref tokenizer) => Some(
|
Self::XLMRoberta(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
.get(DeBERTaV2Vocab::eos_value())
|
}
|
||||||
.expect("EOS token not found in vocabulary"),
|
Self::Albert(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::XLMRoberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::XLNet(ref tokenizer) => {
|
||||||
.get(XLMRobertaVocab::eos_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("EOS token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
),
|
}
|
||||||
Self::Albert(ref tokenizer) => Some(
|
Self::MBart50(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
.get(AlbertVocab::eos_value())
|
}
|
||||||
.expect("EOS token not found in vocabulary"),
|
Self::M2M100(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::XLNet(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::GPT2(ref tokenizer) => {
|
||||||
.get(XLNetVocab::eos_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("EOS token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
),
|
}
|
||||||
Self::MBart50(ref tokenizer) => Some(
|
Self::Deberta(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
.get(MBart50Vocab::eos_value())
|
}
|
||||||
.unwrap_or(&2),
|
Self::Marian(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::M2M100(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
Self::T5(ref tokenizer) => {
|
||||||
.get(M2M100Vocab::eos_value())
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.expect("EOS token not found in vocabulary"),
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
),
|
}
|
||||||
Self::GPT2(ref tokenizer) => Some(
|
Self::Reformer(ref tokenizer) => {
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
.special_values
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
.get(Gpt2Vocab::eos_value())
|
}
|
||||||
.unwrap_or(&2),
|
Self::Pegasus(ref tokenizer) => {
|
||||||
),
|
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
|
||||||
Self::Deberta(ref tokenizer) => Some(
|
Some(vocab.token_to_id(vocab.get_eos_value()))
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
}
|
||||||
.special_values
|
|
||||||
.get(DeBERTaVocab::eos_value())
|
|
||||||
.expect("EOS token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::Marian(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(MarianVocab::eos_value())
|
|
||||||
.expect("EOS token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::T5(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(T5Vocab::eos_value())
|
|
||||||
.expect("EOS token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::Reformer(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(ReformerVocab::eos_value())
|
|
||||||
.expect("EOS token not found in vocabulary"),
|
|
||||||
),
|
|
||||||
Self::Pegasus(ref tokenizer) => Some(
|
|
||||||
*MultiThreadedTokenizer::vocab(tokenizer)
|
|
||||||
.special_values
|
|
||||||
.get(PegasusVocab::eos_value())
|
|
||||||
.unwrap_or(&1),
|
|
||||||
),
|
|
||||||
Self::FNet(_) => None,
|
Self::FNet(_) => None,
|
||||||
Self::Bert(_) => None,
|
Self::Bert(_) => None,
|
||||||
Self::ProphetNet(_) => None,
|
Self::ProphetNet(_) => None,
|
||||||
|
@ -257,8 +257,7 @@ impl MaskedLanguageOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
"Masked Language is not implemented for {:?}!",
|
"Masked Language is not implemented for {model_type:?}!",
|
||||||
model_type
|
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -456,8 +456,7 @@ impl QuestionAnsweringOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
"QuestionAnswering not implemented for {:?}!",
|
"QuestionAnswering not implemented for {model_type:?}!",
|
||||||
model_type
|
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,8 +88,7 @@ impl SentenceEmbeddingsBuilder<Local> {
|
|||||||
ModelType::T5 => (model_dir.join("spiece.model"), None),
|
ModelType::T5 => (model_dir.join("spiece.model"), None),
|
||||||
_ => {
|
_ => {
|
||||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
"Unsupported transformer model {:?} for Sentence Embeddings",
|
"Unsupported transformer model {transformer_type:?} for Sentence Embeddings",
|
||||||
transformer_type
|
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -408,7 +408,7 @@ mod serde_sentence_embeddings_module_type {
|
|||||||
where
|
where
|
||||||
S: Serializer,
|
S: Serializer,
|
||||||
{
|
{
|
||||||
serializer.serialize_str(&format!("sentence_transformers.models.{:?}", module_type))
|
serializer.serialize_str(&format!("sentence_transformers.models.{module_type:?}"))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<SentenceEmbeddingsModuleType, D::Error>
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<SentenceEmbeddingsModuleType, D::Error>
|
||||||
@ -430,7 +430,7 @@ mod serde_sentence_embeddings_module_type {
|
|||||||
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_string())))
|
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_string())))
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(de::Error::custom)?
|
.map_err(de::Error::custom)?
|
||||||
.ok_or_else(|| format!("Invalid SentenceEmbeddingsModuleType: {}", s))
|
.ok_or_else(|| format!("Invalid SentenceEmbeddingsModuleType: {s}"))
|
||||||
.map_err(de::Error::custom)
|
.map_err(de::Error::custom)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -102,7 +102,7 @@ where
|
|||||||
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_lowercase())))
|
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_lowercase())))
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(de::Error::custom)?
|
.map_err(de::Error::custom)?
|
||||||
.ok_or_else(|| format!("Invalid Activation: {}", activation))
|
.ok_or_else(|| format!("Invalid Activation: {activation}"))
|
||||||
.map_err(de::Error::custom)
|
.map_err(de::Error::custom)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,8 +65,7 @@ impl SentenceEmbeddingsOption {
|
|||||||
ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))),
|
ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))),
|
||||||
_ => {
|
_ => {
|
||||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
"Unsupported transformer model {:?} for Sentence Embeddings",
|
"Unsupported transformer model {transformer_type:?} for Sentence Embeddings"
|
||||||
transformer_type
|
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -373,8 +373,7 @@ impl SequenceClassificationOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
"Sequence Classification not implemented for {:?}!",
|
"Sequence Classification not implemented for {model_type:?}!",
|
||||||
model_type
|
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -484,8 +484,7 @@ impl TokenClassificationOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
"Token classification not implemented for {:?}!",
|
"Token classification not implemented for {model_type:?}!"
|
||||||
model_type
|
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -372,8 +372,7 @@ impl TranslationModelBuilder {
|
|||||||
}
|
}
|
||||||
(Some(model_type), _, _) => {
|
(Some(model_type), _, _) => {
|
||||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
"Automated translation model builder not implemented for {:?}",
|
"Automated translation model builder not implemented for {model_type:?}"
|
||||||
model_type
|
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -459,91 +458,92 @@ mod model_fetchers {
|
|||||||
source_languages: Option<&Vec<Language>>,
|
source_languages: Option<&Vec<Language>>,
|
||||||
target_languages: Option<&Vec<Language>>,
|
target_languages: Option<&Vec<Language>>,
|
||||||
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
|
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
|
||||||
let (resources, source_languages, target_languages) =
|
let (resources, source_languages, target_languages) = if let (
|
||||||
if let (Some(source_languages), Some(target_languages)) =
|
Some(source_languages),
|
||||||
(source_languages, target_languages)
|
Some(target_languages),
|
||||||
{
|
) =
|
||||||
match (source_languages.as_slice(), target_languages.as_slice()) {
|
(source_languages, target_languages)
|
||||||
([Language::English], [Language::German]) => {
|
{
|
||||||
get_marian_resources!(ENGLISH2GERMAN)
|
match (source_languages.as_slice(), target_languages.as_slice()) {
|
||||||
}
|
([Language::English], [Language::German]) => {
|
||||||
([Language::English], [Language::Russian]) => {
|
get_marian_resources!(ENGLISH2GERMAN)
|
||||||
get_marian_resources!(ENGLISH2RUSSIAN)
|
|
||||||
}
|
|
||||||
([Language::English], [Language::Dutch]) => {
|
|
||||||
get_marian_resources!(ENGLISH2DUTCH)
|
|
||||||
}
|
|
||||||
([Language::English], [Language::ChineseMandarin]) => {
|
|
||||||
get_marian_resources!(ENGLISH2CHINESE)
|
|
||||||
}
|
|
||||||
([Language::English], [Language::Swedish]) => {
|
|
||||||
get_marian_resources!(ENGLISH2SWEDISH)
|
|
||||||
}
|
|
||||||
([Language::English], [Language::Arabic]) => {
|
|
||||||
get_marian_resources!(ENGLISH2ARABIC)
|
|
||||||
}
|
|
||||||
([Language::English], [Language::Hindi]) => {
|
|
||||||
get_marian_resources!(ENGLISH2HINDI)
|
|
||||||
}
|
|
||||||
([Language::English], [Language::Hebrew]) => {
|
|
||||||
get_marian_resources!(ENGLISH2HEBREW)
|
|
||||||
}
|
|
||||||
([Language::German], [Language::English]) => {
|
|
||||||
get_marian_resources!(GERMAN2ENGLISH)
|
|
||||||
}
|
|
||||||
([Language::German], [Language::French]) => {
|
|
||||||
get_marian_resources!(GERMAN2FRENCH)
|
|
||||||
}
|
|
||||||
([Language::French], [Language::German]) => {
|
|
||||||
get_marian_resources!(FRENCH2GERMAN)
|
|
||||||
}
|
|
||||||
([Language::Russian], [Language::English]) => {
|
|
||||||
get_marian_resources!(RUSSIAN2ENGLISH)
|
|
||||||
}
|
|
||||||
([Language::Dutch], [Language::English]) => {
|
|
||||||
get_marian_resources!(DUTCH2ENGLISH)
|
|
||||||
}
|
|
||||||
([Language::ChineseMandarin], [Language::English]) => {
|
|
||||||
get_marian_resources!(CHINESE2ENGLISH)
|
|
||||||
}
|
|
||||||
([Language::Swedish], [Language::English]) => {
|
|
||||||
get_marian_resources!(SWEDISH2ENGLISH)
|
|
||||||
}
|
|
||||||
([Language::Arabic], [Language::English]) => {
|
|
||||||
get_marian_resources!(ARABIC2ENGLISH)
|
|
||||||
}
|
|
||||||
([Language::Hindi], [Language::English]) => {
|
|
||||||
get_marian_resources!(HINDI2ENGLISH)
|
|
||||||
}
|
|
||||||
([Language::Hebrew], [Language::English]) => {
|
|
||||||
get_marian_resources!(HEBREW2ENGLISH)
|
|
||||||
}
|
|
||||||
([Language::English], languages)
|
|
||||||
if languages
|
|
||||||
.iter()
|
|
||||||
.all(|lang| MarianTargetLanguages::ENGLISH2ROMANCE.contains(lang)) =>
|
|
||||||
{
|
|
||||||
get_marian_resources!(ENGLISH2ROMANCE)
|
|
||||||
}
|
|
||||||
(languages, [Language::English])
|
|
||||||
if languages
|
|
||||||
.iter()
|
|
||||||
.all(|lang| MarianSourceLanguages::ROMANCE2ENGLISH.contains(lang)) =>
|
|
||||||
{
|
|
||||||
get_marian_resources!(ROMANCE2ENGLISH)
|
|
||||||
}
|
|
||||||
(_, _) => {
|
|
||||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
|
||||||
"No Pretrained Marian configuration found for {:?} to {:?} translation",
|
|
||||||
source_languages, target_languages
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
([Language::English], [Language::Russian]) => {
|
||||||
return Err(RustBertError::InvalidConfigurationError(
|
get_marian_resources!(ENGLISH2RUSSIAN)
|
||||||
"Source and target languages must be provided for Marian models".to_string(),
|
}
|
||||||
));
|
([Language::English], [Language::Dutch]) => {
|
||||||
};
|
get_marian_resources!(ENGLISH2DUTCH)
|
||||||
|
}
|
||||||
|
([Language::English], [Language::ChineseMandarin]) => {
|
||||||
|
get_marian_resources!(ENGLISH2CHINESE)
|
||||||
|
}
|
||||||
|
([Language::English], [Language::Swedish]) => {
|
||||||
|
get_marian_resources!(ENGLISH2SWEDISH)
|
||||||
|
}
|
||||||
|
([Language::English], [Language::Arabic]) => {
|
||||||
|
get_marian_resources!(ENGLISH2ARABIC)
|
||||||
|
}
|
||||||
|
([Language::English], [Language::Hindi]) => {
|
||||||
|
get_marian_resources!(ENGLISH2HINDI)
|
||||||
|
}
|
||||||
|
([Language::English], [Language::Hebrew]) => {
|
||||||
|
get_marian_resources!(ENGLISH2HEBREW)
|
||||||
|
}
|
||||||
|
([Language::German], [Language::English]) => {
|
||||||
|
get_marian_resources!(GERMAN2ENGLISH)
|
||||||
|
}
|
||||||
|
([Language::German], [Language::French]) => {
|
||||||
|
get_marian_resources!(GERMAN2FRENCH)
|
||||||
|
}
|
||||||
|
([Language::French], [Language::German]) => {
|
||||||
|
get_marian_resources!(FRENCH2GERMAN)
|
||||||
|
}
|
||||||
|
([Language::Russian], [Language::English]) => {
|
||||||
|
get_marian_resources!(RUSSIAN2ENGLISH)
|
||||||
|
}
|
||||||
|
([Language::Dutch], [Language::English]) => {
|
||||||
|
get_marian_resources!(DUTCH2ENGLISH)
|
||||||
|
}
|
||||||
|
([Language::ChineseMandarin], [Language::English]) => {
|
||||||
|
get_marian_resources!(CHINESE2ENGLISH)
|
||||||
|
}
|
||||||
|
([Language::Swedish], [Language::English]) => {
|
||||||
|
get_marian_resources!(SWEDISH2ENGLISH)
|
||||||
|
}
|
||||||
|
([Language::Arabic], [Language::English]) => {
|
||||||
|
get_marian_resources!(ARABIC2ENGLISH)
|
||||||
|
}
|
||||||
|
([Language::Hindi], [Language::English]) => {
|
||||||
|
get_marian_resources!(HINDI2ENGLISH)
|
||||||
|
}
|
||||||
|
([Language::Hebrew], [Language::English]) => {
|
||||||
|
get_marian_resources!(HEBREW2ENGLISH)
|
||||||
|
}
|
||||||
|
([Language::English], languages)
|
||||||
|
if languages
|
||||||
|
.iter()
|
||||||
|
.all(|lang| MarianTargetLanguages::ENGLISH2ROMANCE.contains(lang)) =>
|
||||||
|
{
|
||||||
|
get_marian_resources!(ENGLISH2ROMANCE)
|
||||||
|
}
|
||||||
|
(languages, [Language::English])
|
||||||
|
if languages
|
||||||
|
.iter()
|
||||||
|
.all(|lang| MarianSourceLanguages::ROMANCE2ENGLISH.contains(lang)) =>
|
||||||
|
{
|
||||||
|
get_marian_resources!(ROMANCE2ENGLISH)
|
||||||
|
}
|
||||||
|
(_, _) => {
|
||||||
|
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
|
"No Pretrained Marian configuration found for {source_languages:?} to {target_languages:?} translation",
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(RustBertError::InvalidConfigurationError(
|
||||||
|
"Source and target languages must be provided for Marian models".to_string(),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
Ok(TranslationResources {
|
Ok(TranslationResources {
|
||||||
model_type: ModelType::Marian,
|
model_type: ModelType::Marian,
|
||||||
|
@ -135,7 +135,7 @@ pub enum Language {
|
|||||||
impl Display for Language {
|
impl Display for Language {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
write!(f, "{}", {
|
write!(f, "{}", {
|
||||||
let input_string = format!("{:?}", self);
|
let input_string = format!("{self:?}");
|
||||||
let mut output: Vec<&str> = Vec::new();
|
let mut output: Vec<&str> = Vec::new();
|
||||||
let mut start: usize = 0;
|
let mut start: usize = 0;
|
||||||
|
|
||||||
@ -584,8 +584,7 @@ impl TranslationOption {
|
|||||||
if let Some(source_language) = source_language {
|
if let Some(source_language) = source_language {
|
||||||
if !supported_source_languages.contains(source_language) {
|
if !supported_source_languages.contains(source_language) {
|
||||||
return Err(RustBertError::ValueError(format!(
|
return Err(RustBertError::ValueError(format!(
|
||||||
"{} not in list of supported languages: {:?}",
|
"{source_language} not in list of supported languages: {supported_source_languages:?}",
|
||||||
source_language, supported_source_languages
|
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -593,8 +592,7 @@ impl TranslationOption {
|
|||||||
if let Some(target_language) = target_language {
|
if let Some(target_language) = target_language {
|
||||||
if !supported_target_languages.contains(target_language) {
|
if !supported_target_languages.contains(target_language) {
|
||||||
return Err(RustBertError::ValueError(format!(
|
return Err(RustBertError::ValueError(format!(
|
||||||
"{} not in list of supported languages: {:?}",
|
"{target_language} not in list of supported languages: {supported_target_languages:?}"
|
||||||
target_language, supported_target_languages
|
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -610,9 +608,8 @@ impl TranslationOption {
|
|||||||
None => {
|
None => {
|
||||||
return Err(RustBertError::ValueError(format!(
|
return Err(RustBertError::ValueError(format!(
|
||||||
"Missing target language for Marian \
|
"Missing target language for Marian \
|
||||||
(multiple languages supported by model: {:?}, \
|
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||||
need to specify target language)",
|
need to specify target language)",
|
||||||
supported_target_languages
|
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -653,9 +650,8 @@ impl TranslationOption {
|
|||||||
None => {
|
None => {
|
||||||
return Err(RustBertError::ValueError(format!(
|
return Err(RustBertError::ValueError(format!(
|
||||||
"Missing source language for MBart\
|
"Missing source language for MBart\
|
||||||
(multiple languages supported by model: {:?}, \
|
(multiple languages supported by model: {supported_source_languages:?}, \
|
||||||
need to specify target language)",
|
need to specify target language)"
|
||||||
supported_source_languages
|
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -670,9 +666,8 @@ impl TranslationOption {
|
|||||||
} else {
|
} else {
|
||||||
return Err(RustBertError::ValueError(format!(
|
return Err(RustBertError::ValueError(format!(
|
||||||
"Missing target language for MBart\
|
"Missing target language for MBart\
|
||||||
(multiple languages supported by model: {:?}, \
|
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||||
need to specify target language)",
|
need to specify target language)"
|
||||||
supported_target_languages
|
|
||||||
)));
|
)));
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@ -681,8 +676,8 @@ impl TranslationOption {
|
|||||||
Some(value) => {
|
Some(value) => {
|
||||||
let language_code = value.get_iso_639_1_code();
|
let language_code = value.get_iso_639_1_code();
|
||||||
match language_code.len() {
|
match language_code.len() {
|
||||||
2 => format!(">>{}.<< ", language_code),
|
2 => format!(">>{language_code}.<< "),
|
||||||
3 => format!(">>{}<< ", language_code),
|
3 => format!(">>{language_code}<< "),
|
||||||
_ => {
|
_ => {
|
||||||
return Err(RustBertError::ValueError(
|
return Err(RustBertError::ValueError(
|
||||||
"Invalid ISO 639-3 code".to_string(),
|
"Invalid ISO 639-3 code".to_string(),
|
||||||
@ -693,9 +688,8 @@ impl TranslationOption {
|
|||||||
None => {
|
None => {
|
||||||
return Err(RustBertError::ValueError(format!(
|
return Err(RustBertError::ValueError(format!(
|
||||||
"Missing source language for M2M100 \
|
"Missing source language for M2M100 \
|
||||||
(multiple languages supported by model: {:?}, \
|
(multiple languages supported by model: {supported_source_languages:?}, \
|
||||||
need to specify target language)",
|
need to specify target language)"
|
||||||
supported_source_languages
|
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
@ -704,8 +698,8 @@ impl TranslationOption {
|
|||||||
Some(
|
Some(
|
||||||
model._get_tokenizer().convert_tokens_to_ids(&[
|
model._get_tokenizer().convert_tokens_to_ids(&[
|
||||||
match language_code.len() {
|
match language_code.len() {
|
||||||
2 => format!(">>{}.<<", language_code),
|
2 => format!(">>{language_code}.<<"),
|
||||||
3 => format!(">>{}<<", language_code),
|
3 => format!(">>{language_code}<<"),
|
||||||
_ => {
|
_ => {
|
||||||
return Err(RustBertError::ValueError(
|
return Err(RustBertError::ValueError(
|
||||||
"Invalid ISO 639-3 code".to_string(),
|
"Invalid ISO 639-3 code".to_string(),
|
||||||
@ -717,9 +711,8 @@ impl TranslationOption {
|
|||||||
} else {
|
} else {
|
||||||
return Err(RustBertError::ValueError(format!(
|
return Err(RustBertError::ValueError(format!(
|
||||||
"Missing target language for M2M100 \
|
"Missing target language for M2M100 \
|
||||||
(multiple languages supported by model: {:?}, \
|
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||||
need to specify target language)",
|
need to specify target language)",
|
||||||
supported_target_languages
|
|
||||||
)));
|
)));
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
@ -370,8 +370,7 @@ impl ZeroShotClassificationOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
"Zero shot classification not implemented for {:?}!",
|
"Zero shot classification not implemented for {model_type:?}!",
|
||||||
model_type
|
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -604,7 +603,7 @@ impl ZeroShotClassificationModel {
|
|||||||
None => labels
|
None => labels
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.iter()
|
.iter()
|
||||||
.map(|label| format!("This example is about {}.", label))
|
.map(|label| format!("This example is about {label}."))
|
||||||
.collect(),
|
.collect(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ use std::borrow::Borrow;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use rust_tokenizers::tokenizer::{ProphetNetTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{ProphetNetTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::{ProphetNetVocab, Vocab};
|
use rust_tokenizers::vocab::ProphetNetVocab;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tch::{nn, Kind, Tensor};
|
use tch::{nn, Kind, Tensor};
|
||||||
|
|
||||||
@ -1098,9 +1098,7 @@ impl
|
|||||||
|
|
||||||
let pad_token = match pad_token_id {
|
let pad_token = match pad_token_id {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => self
|
None => self._get_tokenizer().get_unk_id(),
|
||||||
._get_tokenizer()
|
|
||||||
.convert_tokens_to_ids(&[ProphetNetVocab::unknown_value()])[0],
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_ids = token_ids
|
let token_ids = token_ids
|
||||||
|
@ -548,7 +548,7 @@ impl ReformerModelWithLMHead {
|
|||||||
if let Some(lsh_num_chunks_after) = config.lsh_num_chunks_after {
|
if let Some(lsh_num_chunks_after) = config.lsh_num_chunks_after {
|
||||||
if config.attn_layers.contains(&AttentionType::lsh) & (lsh_num_chunks_after != 0) {
|
if config.attn_layers.contains(&AttentionType::lsh) & (lsh_num_chunks_after != 0) {
|
||||||
return Err(RustBertError::InvalidConfigurationError(
|
return Err(RustBertError::InvalidConfigurationError(
|
||||||
format!("For text generation using LSH attention ensure `config.lsh_num_chunks_after` is set to 0 (currently {})", lsh_num_chunks_after),
|
format!("For text generation using LSH attention ensure `config.lsh_num_chunks_after` is set to 0 (currently {lsh_num_chunks_after})"),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -556,7 +556,7 @@ impl ReformerModelWithLMHead {
|
|||||||
if let Some(local_num_chunks_after) = config.local_num_chunks_after {
|
if let Some(local_num_chunks_after) = config.local_num_chunks_after {
|
||||||
if config.attn_layers.contains(&AttentionType::local) & (local_num_chunks_after != 0) {
|
if config.attn_layers.contains(&AttentionType::local) & (local_num_chunks_after != 0) {
|
||||||
return Err(RustBertError::InvalidConfigurationError(
|
return Err(RustBertError::InvalidConfigurationError(
|
||||||
format!("For text generation using local attention ensure `config.local_num_chunks_after` is set to 0 (currently {})", local_num_chunks_after),
|
format!("For text generation using local attention ensure `config.local_num_chunks_after` is set to 0 (currently {local_num_chunks_after})"),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -215,7 +215,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
|
|||||||
[input_sentence, input_sequence_2],
|
[input_sentence, input_sequence_2],
|
||||||
candidate_labels,
|
candidate_labels,
|
||||||
Some(Box::new(|label: &str| {
|
Some(Box::new(|label: &str| {
|
||||||
format!("This example is about {}.", label)
|
format!("This example is about {label}.")
|
||||||
})),
|
})),
|
||||||
128,
|
128,
|
||||||
)?;
|
)?;
|
||||||
@ -244,7 +244,7 @@ fn bart_zero_shot_classification_try_error() -> anyhow::Result<()> {
|
|||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
Some(Box::new(|label: &str| {
|
Some(Box::new(|label: &str| {
|
||||||
format!("This example is about {}.", label)
|
format!("This example is about {label}.")
|
||||||
})),
|
})),
|
||||||
128,
|
128,
|
||||||
);
|
);
|
||||||
@ -276,7 +276,7 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
|
|||||||
[input_sentence, input_sequence_2],
|
[input_sentence, input_sequence_2],
|
||||||
candidate_labels,
|
candidate_labels,
|
||||||
Some(Box::new(|label: &str| {
|
Some(Box::new(|label: &str| {
|
||||||
format!("This example is about {}.", label)
|
format!("This example is about {label}.")
|
||||||
})),
|
})),
|
||||||
128,
|
128,
|
||||||
)?;
|
)?;
|
||||||
@ -319,7 +319,7 @@ fn bart_zero_shot_classification_multilabel_try_error() -> anyhow::Result<()> {
|
|||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
Some(Box::new(|label: &str| {
|
Some(Box::new(|label: &str| {
|
||||||
format!("This example is about {}.", label)
|
format!("This example is about {label}.")
|
||||||
})),
|
})),
|
||||||
128,
|
128,
|
||||||
);
|
);
|
||||||
|
@ -14,7 +14,7 @@ use rust_bert::pipelines::question_answering::{
|
|||||||
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::{RobertaVocab, Vocab};
|
use rust_tokenizers::vocab::Vocab;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tch::{nn, no_grad, Device, Tensor};
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
@ -67,7 +67,9 @@ fn longformer_masked_lm() -> anyhow::Result<()> {
|
|||||||
.map(|input| input.token_ids.clone())
|
.map(|input| input.token_ids.clone())
|
||||||
.map(|mut input| {
|
.map(|mut input| {
|
||||||
input.extend(vec![
|
input.extend(vec![
|
||||||
tokenizer.vocab().token_to_id(RobertaVocab::pad_value());
|
tokenizer
|
||||||
|
.vocab()
|
||||||
|
.token_to_id(tokenizer.vocab().get_pad_value());
|
||||||
max_len - input.len()
|
max_len - input.len()
|
||||||
]);
|
]);
|
||||||
input
|
input
|
||||||
|
@ -63,8 +63,8 @@ about exoplanets like K2-18b."];
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
output[0],
|
output[0],
|
||||||
"scientists have confirmed the presence of water in the atmosphere of k2 - 18b. \
|
"scientists have confirmed the presence of water in the atmosphere of k2 - 18b. \
|
||||||
[X_SEP] this is the first such discovery in a planet in its star's habitable zone. \
|
this is the first such discovery in a planet in its star's habitable zone. \
|
||||||
[X_SEP] the planet is 110 light - years from earth and has a star in the constellation leo."
|
the planet is 110 light - years from earth and has a star in the constellation leo."
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
Loading…
Reference in New Issue
Block a user