diff --git a/benches/translation_benchmark.rs b/benches/translation_benchmark.rs index c1d20ed..0ff342b 100644 --- a/benches/translation_benchmark.rs +++ b/benches/translation_benchmark.rs @@ -4,13 +4,14 @@ extern crate criterion; use criterion::{black_box, Criterion}; // use rust_bert::pipelines::common::ModelType; // use rust_bert::pipelines::translation::TranslationOption::{Marian, T5}; -use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; +use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; // use rust_bert::resources::{LocalResource, Resource}; use std::time::{Duration, Instant}; use tch::Device; fn create_translation_model() -> TranslationModel { - let config = TranslationConfig::new(Language::EnglishToFrenchV2, Device::cuda_if_available()); + let config = + TranslationConfig::new(OldLanguage::EnglishToFrenchV2, Device::cuda_if_available()); // let config = TranslationConfig::new_from_resources( // Resource::Local(LocalResource { // local_path: "E:/Coding/cache/rustbert/marian-mt-en-es/model.ot".into(), @@ -46,7 +47,7 @@ fn translation_load_model(iters: u64) -> Duration { for _i in 0..iters { let start = Instant::now(); let config = - TranslationConfig::new(Language::EnglishToFrenchV2, Device::cuda_if_available()); + TranslationConfig::new(OldLanguage::EnglishToFrenchV2, Device::cuda_if_available()); // let config = TranslationConfig::new_from_resources( // Resource::Local(LocalResource { // local_path: "E:/Coding/cache/rustbert/marian-mt-en-es/model.ot".into(), diff --git a/examples/translation_marian.rs b/examples/translation_marian.rs index a6ea00f..18bc68c 100644 --- a/examples/translation_marian.rs +++ b/examples/translation_marian.rs @@ -13,12 +13,12 @@ extern crate anyhow; -use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; +use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; use tch::Device; fn main() -> anyhow::Result<()> { let translation_config = - TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available()); + TranslationConfig::new(OldLanguage::EnglishToGerman, Device::cuda_if_available()); let model = TranslationModel::new(translation_config)?; let input_context_1 = "The quick brown fox jumps over the lazy dog"; diff --git a/src/lib.rs b/src/lib.rs index 01f6dad..0d4303d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -167,10 +167,10 @@ //! ```no_run //! # fn main() -> anyhow::Result<()> { //! # use rust_bert::pipelines::generation_utils::LanguageGenerator; -//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; +//! use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; //! use tch::Device; //! let translation_config = -//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available()); +//! TranslationConfig::new(OldLanguage::EnglishToFrench, Device::cuda_if_available()); //! let mut model = TranslationModel::new(translation_config)?; //! //! let input = ["This is a sentence to be translated"]; diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs index 4a72b42..50b8525 100644 --- a/src/marian/marian_model.rs +++ b/src/marian/marian_model.rs @@ -19,6 +19,7 @@ use crate::pipelines::generation_utils::private_generation_utils::{ use crate::pipelines::generation_utils::{ Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, }; +use crate::pipelines::translation::Language; use crate::{Config, RustBertError}; use rust_tokenizers::tokenizer::{MarianTokenizer, TruncationStrategy}; use rust_tokenizers::vocab::MarianVocab; @@ -41,6 +42,12 @@ pub struct MarianSpmResources; /// # Marian optional prefixes pub struct MarianPrefix; +/// # Marian source languages pre-sets +pub struct MarianSourceLanguages; + +/// # Marian target languages pre-sets +pub struct MarianTargetLanguages; + impl MarianModelResources { /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ( @@ -487,6 +494,68 @@ impl MarianPrefix { pub const HEBREW2ENGLISH: Option<&'static str> = None; } +impl MarianSourceLanguages { + pub const ENGLISH2ROMANCE: [Language; 1] = [Language::English]; + pub const ENGLISH2GERMAN: [Language; 1] = [Language::English]; + pub const ENGLISH2RUSSIAN: [Language; 1] = [Language::English]; + pub const ENGLISH2DUTCH: [Language; 1] = [Language::English]; + pub const ENGLISH2CHINESE: [Language; 1] = [Language::English]; + pub const ENGLISH2SWEDISH: [Language; 1] = [Language::English]; + pub const ENGLISH2ARABIC: [Language; 1] = [Language::English]; + pub const ENGLISH2HINDI: [Language; 1] = [Language::English]; + pub const ENGLISH2HEBREW: [Language; 1] = [Language::English]; + pub const ROMANCE2ENGLISH: [Language; 7] = [ + Language::French, + Language::Spanish, + Language::Italian, + Language::Catalan, + Language::Romanian, + Language::Portuguese, + Language::Occitan, + ]; + pub const GERMAN2ENGLISH: [Language; 1] = [Language::German]; + pub const RUSSIAN2ENGLISH: [Language; 1] = [Language::Russian]; + pub const DUTCH2ENGLISH: [Language; 1] = [Language::Dutch]; + pub const CHINESE2ENGLISH: [Language; 1] = [Language::ChineseMandarin]; + pub const SWEDISH2ENGLISH: [Language; 1] = [Language::Swedish]; + pub const ARABIC2ENGLISH: [Language; 1] = [Language::Arabic]; + pub const HINDI2ENGLISH: [Language; 1] = [Language::Hindi]; + pub const HEBREW2ENGLISH: [Language; 1] = [Language::Hebrew]; + pub const FRENCH2GERMAN: [Language; 1] = [Language::French]; + pub const GERMAN2FRENCH: [Language; 1] = [Language::German]; +} + +impl MarianTargetLanguages { + pub const ENGLISH2ROMANCE: [Language; 7] = [ + Language::French, + Language::Spanish, + Language::Italian, + Language::Catalan, + Language::Romanian, + Language::Portuguese, + Language::Occitan, + ]; + pub const ENGLISH2GERMAN: [Language; 1] = [Language::German]; + pub const ENGLISH2RUSSIAN: [Language; 1] = [Language::Russian]; + pub const ENGLISH2DUTCH: [Language; 1] = [Language::Dutch]; + pub const ENGLISH2CHINESE: [Language; 1] = [Language::ChineseMandarin]; + pub const ENGLISH2SWEDISH: [Language; 1] = [Language::Swedish]; + pub const ENGLISH2ARABIC: [Language; 1] = [Language::Arabic]; + pub const ENGLISH2HINDI: [Language; 1] = [Language::Hindi]; + pub const ENGLISH2HEBREW: [Language; 1] = [Language::Hebrew]; + pub const ROMANCE2ENGLISH: [Language; 1] = [Language::English]; + pub const GERMAN2ENGLISH: [Language; 1] = [Language::English]; + pub const RUSSIAN2ENGLISH: [Language; 1] = [Language::English]; + pub const DUTCH2ENGLISH: [Language; 1] = [Language::English]; + pub const CHINESE2ENGLISH: [Language; 1] = [Language::English]; + pub const SWEDISH2ENGLISH: [Language; 1] = [Language::English]; + pub const ARABIC2ENGLISH: [Language; 1] = [Language::English]; + pub const HINDI2ENGLISH: [Language; 1] = [Language::English]; + pub const HEBREW2ENGLISH: [Language; 1] = [Language::English]; + pub const FRENCH2GERMAN: [Language; 1] = [Language::German]; + pub const GERMAN2FRENCH: [Language; 1] = [Language::French]; +} + /// # Marian Model for conditional generation /// Marian model with a vocabulary decoding head /// It is made of the following blocks: diff --git a/src/marian/mod.rs b/src/marian/mod.rs index b385d7f..3e66438 100644 --- a/src/marian/mod.rs +++ b/src/marian/mod.rs @@ -61,5 +61,6 @@ mod marian_model; pub use marian_model::{ MarianConfigResources, MarianForConditionalGeneration, MarianGenerator, MarianModelResources, - MarianPrefix, MarianSpmResources, MarianVocabResources, + MarianPrefix, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages, + MarianVocabResources, }; diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index 4bae02c..4523a43 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -56,10 +56,10 @@ //! ```no_run //! # fn main() -> anyhow::Result<()> { //! # use rust_bert::pipelines::generation_utils::LanguageGenerator; -//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; +//! use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; //! use tch::Device; //! let translation_config = -//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available()); +//! TranslationConfig::new(OldLanguage::EnglishToFrench, Device::cuda_if_available()); //! let mut model = TranslationModel::new(translation_config)?; //! //! let input = ["This is a sentence to be translated"]; diff --git a/src/pipelines/translation.rs b/src/pipelines/translation.rs index 3fb2a15..4a55c8b 100644 --- a/src/pipelines/translation.rs +++ b/src/pipelines/translation.rs @@ -14,18 +14,16 @@ use tch::{Device, Tensor}; use crate::common::error::RustBertError; -use crate::common::resources::{RemoteResource, Resource}; -use crate::marian::{ - MarianConfigResources, MarianGenerator, MarianModelResources, MarianPrefix, MarianSpmResources, - MarianVocabResources, -}; +use crate::common::resources::Resource; +use crate::marian::MarianGenerator; use crate::pipelines::common::ModelType; use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; -use crate::t5::{T5ConfigResources, T5Generator, T5ModelResources, T5Prefix, T5VocabResources}; +use crate::t5::T5Generator; +use std::collections::HashSet; use std::fmt; /// Pretrained languages available for direct use -pub enum Language { +pub enum OldLanguage { FrenchToEnglish, CatalanToEnglish, SpanishToEnglish, @@ -62,8 +60,8 @@ pub enum Language { } /// Language -#[derive(Debug)] -pub enum NewLanguage { +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +pub enum Language { Afrikaans, Danish, Dutch, @@ -166,7 +164,7 @@ pub enum NewLanguage { HaitianCreole, } -impl fmt::Display for NewLanguage { +impl fmt::Display for Language { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", { let input_string = format!("{:?}", self); @@ -189,506 +187,232 @@ impl fmt::Display for NewLanguage { } } -impl NewLanguage { +impl Language { pub fn get_iso_639_1_code(&self) -> &'static str { match self { - NewLanguage::Afrikaans => "af", - NewLanguage::Danish => "da", - NewLanguage::Dutch => "nl", - NewLanguage::German => "de", - NewLanguage::English => "en", - NewLanguage::Icelandic => "is", - NewLanguage::Luxembourgish => "lb", - NewLanguage::Norwegian => "no", - NewLanguage::Swedish => "sv", - NewLanguage::WesternFrisian => "fy", - NewLanguage::Yiddish => "yi", - NewLanguage::Asturian => "ast", - NewLanguage::Catalan => "ca", - NewLanguage::French => "fr", - NewLanguage::Galician => "gl", - NewLanguage::Italian => "it", - NewLanguage::Occitan => "oc", - NewLanguage::Portuguese => "pt", - NewLanguage::Romanian => "ro", - NewLanguage::Spanish => "es", - NewLanguage::Belarusian => "be", - NewLanguage::Bosnian => "bs", - NewLanguage::Bulgarian => "bg", - NewLanguage::Croatian => "hr", - NewLanguage::Czech => "cs", - NewLanguage::Macedonian => "mk", - NewLanguage::Polish => "pl", - NewLanguage::Russian => "ru", - NewLanguage::Serbian => "sr", - NewLanguage::Slovak => "sk", - NewLanguage::Slovenian => "sl", - NewLanguage::Ukrainian => "uk", - NewLanguage::Estonian => "et", - NewLanguage::Finnish => "fi", - NewLanguage::Hungarian => "hu", - NewLanguage::Latvian => "lv", - NewLanguage::Lithuanian => "lt", - NewLanguage::Albanian => "sq", - NewLanguage::Armenian => "hy", - NewLanguage::Georgian => "ka", - NewLanguage::Greek => "el", - NewLanguage::Breton => "br", - NewLanguage::Irish => "ga", - NewLanguage::ScottishGaelic => "gd", - NewLanguage::Welsh => "cy", - NewLanguage::Azerbaijani => "az", - NewLanguage::Bashkir => "ba", - NewLanguage::Kazakh => "kk", - NewLanguage::Turkish => "tr", - NewLanguage::Uzbek => "uz", - NewLanguage::Japanese => "ja", - NewLanguage::Korean => "ko", - NewLanguage::Vietnamese => "vi", - NewLanguage::ChineseMandarin => "zh", - NewLanguage::Bengali => "bn", - NewLanguage::Gujarati => "gu", - NewLanguage::Hindi => "hi", - NewLanguage::Kannada => "kn", - NewLanguage::Marathi => "mr", - NewLanguage::Nepali => "ne", - NewLanguage::Oriya => "or", - NewLanguage::Panjabi => "pa", - NewLanguage::Sindhi => "sd", - NewLanguage::Sinhala => "si", - NewLanguage::Urdu => "ur", - NewLanguage::Tamil => "ta", - NewLanguage::Cebuano => "ceb", - NewLanguage::Iloko => "ilo", - NewLanguage::Indonesian => "id", - NewLanguage::Javanese => "jv", - NewLanguage::Malagasy => "mg", - NewLanguage::Malay => "ms", - NewLanguage::Malayalam => "ml", - NewLanguage::Sundanese => "su", - NewLanguage::Tagalog => "tl", - NewLanguage::Burmese => "my", - NewLanguage::CentralKhmer => "km", - NewLanguage::Lao => "lo", - NewLanguage::Thai => "th", - NewLanguage::Mongolian => "mn", - NewLanguage::Arabic => "ar", - NewLanguage::Hebrew => "he", - NewLanguage::Pashto => "ps", - NewLanguage::Farsi => "fa", - NewLanguage::Amharic => "am", - NewLanguage::Fulah => "ff", - NewLanguage::Hausa => "ha", - NewLanguage::Igbo => "ig", - NewLanguage::Lingala => "ln", - NewLanguage::Luganda => "lg", - NewLanguage::NorthernSotho => "nso", - NewLanguage::Somali => "so", - NewLanguage::Swahili => "sw", - NewLanguage::Swati => "ss", - NewLanguage::Tswana => "tn", - NewLanguage::Wolof => "wo", - NewLanguage::Xhosa => "xh", - NewLanguage::Yoruba => "yo", - NewLanguage::Zulu => "zu", - NewLanguage::HaitianCreole => "ht", + Language::Afrikaans => "af", + Language::Danish => "da", + Language::Dutch => "nl", + Language::German => "de", + Language::English => "en", + Language::Icelandic => "is", + Language::Luxembourgish => "lb", + Language::Norwegian => "no", + Language::Swedish => "sv", + Language::WesternFrisian => "fy", + Language::Yiddish => "yi", + Language::Asturian => "ast", + Language::Catalan => "ca", + Language::French => "fr", + Language::Galician => "gl", + Language::Italian => "it", + Language::Occitan => "oc", + Language::Portuguese => "pt", + Language::Romanian => "ro", + Language::Spanish => "es", + Language::Belarusian => "be", + Language::Bosnian => "bs", + Language::Bulgarian => "bg", + Language::Croatian => "hr", + Language::Czech => "cs", + Language::Macedonian => "mk", + Language::Polish => "pl", + Language::Russian => "ru", + Language::Serbian => "sr", + Language::Slovak => "sk", + Language::Slovenian => "sl", + Language::Ukrainian => "uk", + Language::Estonian => "et", + Language::Finnish => "fi", + Language::Hungarian => "hu", + Language::Latvian => "lv", + Language::Lithuanian => "lt", + Language::Albanian => "sq", + Language::Armenian => "hy", + Language::Georgian => "ka", + Language::Greek => "el", + Language::Breton => "br", + Language::Irish => "ga", + Language::ScottishGaelic => "gd", + Language::Welsh => "cy", + Language::Azerbaijani => "az", + Language::Bashkir => "ba", + Language::Kazakh => "kk", + Language::Turkish => "tr", + Language::Uzbek => "uz", + Language::Japanese => "ja", + Language::Korean => "ko", + Language::Vietnamese => "vi", + Language::ChineseMandarin => "zh", + Language::Bengali => "bn", + Language::Gujarati => "gu", + Language::Hindi => "hi", + Language::Kannada => "kn", + Language::Marathi => "mr", + Language::Nepali => "ne", + Language::Oriya => "or", + Language::Panjabi => "pa", + Language::Sindhi => "sd", + Language::Sinhala => "si", + Language::Urdu => "ur", + Language::Tamil => "ta", + Language::Cebuano => "ceb", + Language::Iloko => "ilo", + Language::Indonesian => "id", + Language::Javanese => "jv", + Language::Malagasy => "mg", + Language::Malay => "ms", + Language::Malayalam => "ml", + Language::Sundanese => "su", + Language::Tagalog => "tl", + Language::Burmese => "my", + Language::CentralKhmer => "km", + Language::Lao => "lo", + Language::Thai => "th", + Language::Mongolian => "mn", + Language::Arabic => "ar", + Language::Hebrew => "he", + Language::Pashto => "ps", + Language::Farsi => "fa", + Language::Amharic => "am", + Language::Fulah => "ff", + Language::Hausa => "ha", + Language::Igbo => "ig", + Language::Lingala => "ln", + Language::Luganda => "lg", + Language::NorthernSotho => "nso", + Language::Somali => "so", + Language::Swahili => "sw", + Language::Swati => "ss", + Language::Tswana => "tn", + Language::Wolof => "wo", + Language::Xhosa => "xh", + Language::Yoruba => "yo", + Language::Zulu => "zu", + Language::HaitianCreole => "ht", } } pub fn get_iso_639_3_code(&self) -> &'static str { match self { - NewLanguage::Afrikaans => "afr", - NewLanguage::Danish => "dan", - NewLanguage::Dutch => "nld", - NewLanguage::German => "deu", - NewLanguage::English => "eng", - NewLanguage::Icelandic => "isl", - NewLanguage::Luxembourgish => "ltz", - NewLanguage::Norwegian => "nor", - NewLanguage::Swedish => "swe", - NewLanguage::WesternFrisian => "fry", - NewLanguage::Yiddish => "yid", - NewLanguage::Asturian => "ast", - NewLanguage::Catalan => "cat", - NewLanguage::French => "fra", - NewLanguage::Galician => "glg", - NewLanguage::Italian => "ita", - NewLanguage::Occitan => "oci", - NewLanguage::Portuguese => "por", - NewLanguage::Romanian => "ron", - NewLanguage::Spanish => "spa", - NewLanguage::Belarusian => "bel", - NewLanguage::Bosnian => "bos", - NewLanguage::Bulgarian => "bul", - NewLanguage::Croatian => "hrv", - NewLanguage::Czech => "ces", - NewLanguage::Macedonian => "mkd", - NewLanguage::Polish => "pol", - NewLanguage::Russian => "rus", - NewLanguage::Serbian => "srp", - NewLanguage::Slovak => "slk", - NewLanguage::Slovenian => "slv", - NewLanguage::Ukrainian => "ukr", - NewLanguage::Estonian => "est", - NewLanguage::Finnish => "fin", - NewLanguage::Hungarian => "hun", - NewLanguage::Latvian => "lav", - NewLanguage::Lithuanian => "lit", - NewLanguage::Albanian => "sqi", - NewLanguage::Armenian => "hye", - NewLanguage::Georgian => "kat", - NewLanguage::Greek => "ell", - NewLanguage::Breton => "bre", - NewLanguage::Irish => "gle", - NewLanguage::ScottishGaelic => "gla", - NewLanguage::Welsh => "cym", - NewLanguage::Azerbaijani => "aze", - NewLanguage::Bashkir => "bak", - NewLanguage::Kazakh => "kaz", - NewLanguage::Turkish => "tur", - NewLanguage::Uzbek => "uzb", - NewLanguage::Japanese => "jpn", - NewLanguage::Korean => "kor", - NewLanguage::Vietnamese => "vie", - NewLanguage::ChineseMandarin => "cmn", - NewLanguage::Bengali => "ben", - NewLanguage::Gujarati => "guj", - NewLanguage::Hindi => "hin", - NewLanguage::Kannada => "kan", - NewLanguage::Marathi => "mar", - NewLanguage::Nepali => "nep", - NewLanguage::Oriya => "ori", - NewLanguage::Panjabi => "pan", - NewLanguage::Sindhi => "snd", - NewLanguage::Sinhala => "sin", - NewLanguage::Urdu => "urd", - NewLanguage::Tamil => "tam", - NewLanguage::Cebuano => "ceb", - NewLanguage::Iloko => "ilo", - NewLanguage::Indonesian => "ind", - NewLanguage::Javanese => "jav", - NewLanguage::Malagasy => "mlg", - NewLanguage::Malay => "msa", - NewLanguage::Malayalam => "mal", - NewLanguage::Sundanese => "sun", - NewLanguage::Tagalog => "tgl", - NewLanguage::Burmese => "mya", - NewLanguage::CentralKhmer => "khm", - NewLanguage::Lao => "lao", - NewLanguage::Thai => "tha", - NewLanguage::Mongolian => "mon", - NewLanguage::Arabic => "ara", - NewLanguage::Hebrew => "heb", - NewLanguage::Pashto => "pus", - NewLanguage::Farsi => "fas", - NewLanguage::Amharic => "amh", - NewLanguage::Fulah => "ful", - NewLanguage::Hausa => "hau", - NewLanguage::Igbo => "ibo", - NewLanguage::Lingala => "lin", - NewLanguage::Luganda => "lug", - NewLanguage::NorthernSotho => "nso", - NewLanguage::Somali => "som", - NewLanguage::Swahili => "swa", - NewLanguage::Swati => "ssw", - NewLanguage::Tswana => "tsn", - NewLanguage::Wolof => "wol", - NewLanguage::Xhosa => "xho", - NewLanguage::Yoruba => "yor", - NewLanguage::Zulu => "zul", - NewLanguage::HaitianCreole => "hat", + Language::Afrikaans => "afr", + Language::Danish => "dan", + Language::Dutch => "nld", + Language::German => "deu", + Language::English => "eng", + Language::Icelandic => "isl", + Language::Luxembourgish => "ltz", + Language::Norwegian => "nor", + Language::Swedish => "swe", + Language::WesternFrisian => "fry", + Language::Yiddish => "yid", + Language::Asturian => "ast", + Language::Catalan => "cat", + Language::French => "fra", + Language::Galician => "glg", + Language::Italian => "ita", + Language::Occitan => "oci", + Language::Portuguese => "por", + Language::Romanian => "ron", + Language::Spanish => "spa", + Language::Belarusian => "bel", + Language::Bosnian => "bos", + Language::Bulgarian => "bul", + Language::Croatian => "hrv", + Language::Czech => "ces", + Language::Macedonian => "mkd", + Language::Polish => "pol", + Language::Russian => "rus", + Language::Serbian => "srp", + Language::Slovak => "slk", + Language::Slovenian => "slv", + Language::Ukrainian => "ukr", + Language::Estonian => "est", + Language::Finnish => "fin", + Language::Hungarian => "hun", + Language::Latvian => "lav", + Language::Lithuanian => "lit", + Language::Albanian => "sqi", + Language::Armenian => "hye", + Language::Georgian => "kat", + Language::Greek => "ell", + Language::Breton => "bre", + Language::Irish => "gle", + Language::ScottishGaelic => "gla", + Language::Welsh => "cym", + Language::Azerbaijani => "aze", + Language::Bashkir => "bak", + Language::Kazakh => "kaz", + Language::Turkish => "tur", + Language::Uzbek => "uzb", + Language::Japanese => "jpn", + Language::Korean => "kor", + Language::Vietnamese => "vie", + Language::ChineseMandarin => "cmn", + Language::Bengali => "ben", + Language::Gujarati => "guj", + Language::Hindi => "hin", + Language::Kannada => "kan", + Language::Marathi => "mar", + Language::Nepali => "nep", + Language::Oriya => "ori", + Language::Panjabi => "pan", + Language::Sindhi => "snd", + Language::Sinhala => "sin", + Language::Urdu => "urd", + Language::Tamil => "tam", + Language::Cebuano => "ceb", + Language::Iloko => "ilo", + Language::Indonesian => "ind", + Language::Javanese => "jav", + Language::Malagasy => "mlg", + Language::Malay => "msa", + Language::Malayalam => "mal", + Language::Sundanese => "sun", + Language::Tagalog => "tgl", + Language::Burmese => "mya", + Language::CentralKhmer => "khm", + Language::Lao => "lao", + Language::Thai => "tha", + Language::Mongolian => "mon", + Language::Arabic => "ara", + Language::Hebrew => "heb", + Language::Pashto => "pus", + Language::Farsi => "fas", + Language::Amharic => "amh", + Language::Fulah => "ful", + Language::Hausa => "hau", + Language::Igbo => "ibo", + Language::Lingala => "lin", + Language::Luganda => "lug", + Language::NorthernSotho => "nso", + Language::Somali => "som", + Language::Swahili => "swa", + Language::Swati => "ssw", + Language::Tswana => "tsn", + Language::Wolof => "wol", + Language::Xhosa => "xho", + Language::Yoruba => "yor", + Language::Zulu => "zul", + Language::HaitianCreole => "hat", } } pub fn get_marian_code(&self) -> &'static str { match self { - NewLanguage::ChineseMandarin => "cmn_Hans", - NewLanguage::Arabic => self.get_iso_639_3_code(), + Language::ChineseMandarin => "cmn_Hans", + Language::Arabic => self.get_iso_639_3_code(), _ => self.get_iso_639_1_code(), } } } -struct RemoteTranslationResources { - model_resource: (&'static str, &'static str), - config_resource: (&'static str, &'static str), - vocab_resource: (&'static str, &'static str), - merges_resource: (&'static str, &'static str), - prefix: Option<&'static str>, - model_type: ModelType, -} - -impl RemoteTranslationResources { - pub const ENGLISH2FRENCH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2ROMANCE, - config_resource: MarianConfigResources::ENGLISH2ROMANCE, - vocab_resource: MarianVocabResources::ENGLISH2ROMANCE, - merges_resource: MarianSpmResources::ENGLISH2ROMANCE, - prefix: MarianPrefix::ENGLISH2FRENCH, - model_type: ModelType::Marian, - }; - pub const ENGLISH2FRENCH_V2: RemoteTranslationResources = Self { - model_resource: T5ModelResources::T5_BASE, - config_resource: T5ConfigResources::T5_BASE, - vocab_resource: T5VocabResources::T5_BASE, - merges_resource: T5VocabResources::T5_BASE, - prefix: T5Prefix::ENGLISH2FRENCH, - model_type: ModelType::T5, - }; - pub const ENGLISH2GERMAN_V2: RemoteTranslationResources = Self { - model_resource: T5ModelResources::T5_BASE, - config_resource: T5ConfigResources::T5_BASE, - vocab_resource: T5VocabResources::T5_BASE, - merges_resource: T5VocabResources::T5_BASE, - prefix: T5Prefix::ENGLISH2GERMAN, - model_type: ModelType::T5, - }; - pub const ENGLISH2CATALAN: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2ROMANCE, - config_resource: MarianConfigResources::ENGLISH2ROMANCE, - vocab_resource: MarianVocabResources::ENGLISH2ROMANCE, - merges_resource: MarianSpmResources::ENGLISH2ROMANCE, - prefix: MarianPrefix::ENGLISH2CATALAN, - model_type: ModelType::Marian, - }; - pub const ENGLISH2SPANISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2ROMANCE, - config_resource: MarianConfigResources::ENGLISH2ROMANCE, - vocab_resource: MarianVocabResources::ENGLISH2ROMANCE, - merges_resource: MarianSpmResources::ENGLISH2ROMANCE, - prefix: MarianPrefix::ENGLISH2SPANISH, - model_type: ModelType::Marian, - }; - pub const ENGLISH2PORTUGUESE: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2ROMANCE, - config_resource: MarianConfigResources::ENGLISH2ROMANCE, - vocab_resource: MarianVocabResources::ENGLISH2ROMANCE, - merges_resource: MarianSpmResources::ENGLISH2ROMANCE, - prefix: MarianPrefix::ENGLISH2PORTUGUESE, - model_type: ModelType::Marian, - }; - pub const ENGLISH2ITALIAN: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2ROMANCE, - config_resource: MarianConfigResources::ENGLISH2ROMANCE, - vocab_resource: MarianVocabResources::ENGLISH2ROMANCE, - merges_resource: MarianSpmResources::ENGLISH2ROMANCE, - prefix: MarianPrefix::ENGLISH2ITALIAN, - model_type: ModelType::Marian, - }; - pub const ENGLISH2ROMANIAN: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2ROMANCE, - config_resource: MarianConfigResources::ENGLISH2ROMANCE, - vocab_resource: MarianVocabResources::ENGLISH2ROMANCE, - merges_resource: MarianSpmResources::ENGLISH2ROMANCE, - prefix: MarianPrefix::ENGLISH2ROMANIAN, - model_type: ModelType::Marian, - }; - pub const ENGLISH2GERMAN: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2GERMAN, - config_resource: MarianConfigResources::ENGLISH2GERMAN, - vocab_resource: MarianVocabResources::ENGLISH2GERMAN, - merges_resource: MarianSpmResources::ENGLISH2GERMAN, - prefix: MarianPrefix::ENGLISH2GERMAN, - model_type: ModelType::Marian, - }; - pub const ENGLISH2RUSSIAN: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2RUSSIAN, - config_resource: MarianConfigResources::ENGLISH2RUSSIAN, - vocab_resource: MarianVocabResources::ENGLISH2RUSSIAN, - merges_resource: MarianSpmResources::ENGLISH2RUSSIAN, - prefix: MarianPrefix::ENGLISH2RUSSIAN, - model_type: ModelType::Marian, - }; - pub const FRENCH2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ROMANCE2ENGLISH, - config_resource: MarianConfigResources::ROMANCE2ENGLISH, - vocab_resource: MarianVocabResources::ROMANCE2ENGLISH, - merges_resource: MarianSpmResources::ROMANCE2ENGLISH, - prefix: MarianPrefix::FRENCH2ENGLISH, - model_type: ModelType::Marian, - }; - pub const CATALAN2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ROMANCE2ENGLISH, - config_resource: MarianConfigResources::ROMANCE2ENGLISH, - vocab_resource: MarianVocabResources::ROMANCE2ENGLISH, - merges_resource: MarianSpmResources::ROMANCE2ENGLISH, - prefix: MarianPrefix::CATALAN2ENGLISH, - model_type: ModelType::Marian, - }; - pub const SPANISH2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ROMANCE2ENGLISH, - config_resource: MarianConfigResources::ROMANCE2ENGLISH, - vocab_resource: MarianVocabResources::ROMANCE2ENGLISH, - merges_resource: MarianSpmResources::ROMANCE2ENGLISH, - prefix: MarianPrefix::SPANISH2ENGLISH, - model_type: ModelType::Marian, - }; - pub const PORTUGUESE2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ROMANCE2ENGLISH, - config_resource: MarianConfigResources::ROMANCE2ENGLISH, - vocab_resource: MarianVocabResources::ROMANCE2ENGLISH, - merges_resource: MarianSpmResources::ROMANCE2ENGLISH, - prefix: MarianPrefix::PORTUGUESE2ENGLISH, - model_type: ModelType::Marian, - }; - pub const ITALIAN2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ROMANCE2ENGLISH, - config_resource: MarianConfigResources::ROMANCE2ENGLISH, - vocab_resource: MarianVocabResources::ROMANCE2ENGLISH, - merges_resource: MarianSpmResources::ROMANCE2ENGLISH, - prefix: MarianPrefix::ITALIAN2ENGLISH, - model_type: ModelType::Marian, - }; - pub const ROMANIAN2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ROMANCE2ENGLISH, - config_resource: MarianConfigResources::ROMANCE2ENGLISH, - vocab_resource: MarianVocabResources::ROMANCE2ENGLISH, - merges_resource: MarianSpmResources::ROMANCE2ENGLISH, - prefix: MarianPrefix::ROMANIAN2ENGLISH, - model_type: ModelType::Marian, - }; - pub const GERMAN2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::GERMAN2ENGLISH, - config_resource: MarianConfigResources::GERMAN2ENGLISH, - vocab_resource: MarianVocabResources::GERMAN2ENGLISH, - merges_resource: MarianSpmResources::GERMAN2ENGLISH, - prefix: MarianPrefix::GERMAN2ENGLISH, - model_type: ModelType::Marian, - }; - pub const RUSSIAN2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::RUSSIAN2ENGLISH, - config_resource: MarianConfigResources::RUSSIAN2ENGLISH, - vocab_resource: MarianVocabResources::RUSSIAN2ENGLISH, - merges_resource: MarianSpmResources::RUSSIAN2ENGLISH, - prefix: MarianPrefix::RUSSIAN2ENGLISH, - model_type: ModelType::Marian, - }; - pub const FRENCH2GERMAN: RemoteTranslationResources = Self { - model_resource: MarianModelResources::FRENCH2GERMAN, - config_resource: MarianConfigResources::FRENCH2GERMAN, - vocab_resource: MarianVocabResources::FRENCH2GERMAN, - merges_resource: MarianSpmResources::FRENCH2GERMAN, - prefix: MarianPrefix::FRENCH2GERMAN, - model_type: ModelType::Marian, - }; - pub const GERMAN2FRENCH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::GERMAN2FRENCH, - config_resource: MarianConfigResources::GERMAN2FRENCH, - vocab_resource: MarianVocabResources::GERMAN2FRENCH, - merges_resource: MarianSpmResources::GERMAN2FRENCH, - prefix: MarianPrefix::GERMAN2FRENCH, - model_type: ModelType::Marian, - }; - pub const ENGLISH2DUTCH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2DUTCH, - config_resource: MarianConfigResources::ENGLISH2DUTCH, - vocab_resource: MarianVocabResources::ENGLISH2DUTCH, - merges_resource: MarianSpmResources::ENGLISH2DUTCH, - prefix: MarianPrefix::ENGLISH2DUTCH, - model_type: ModelType::Marian, - }; - pub const DUTCH2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::DUTCH2ENGLISH, - config_resource: MarianConfigResources::DUTCH2ENGLISH, - vocab_resource: MarianVocabResources::DUTCH2ENGLISH, - merges_resource: MarianSpmResources::DUTCH2ENGLISH, - prefix: MarianPrefix::DUTCH2ENGLISH, - model_type: ModelType::Marian, - }; - pub const CHINESE2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::CHINESE2ENGLISH, - config_resource: MarianConfigResources::CHINESE2ENGLISH, - vocab_resource: MarianVocabResources::CHINESE2ENGLISH, - merges_resource: MarianSpmResources::CHINESE2ENGLISH, - prefix: MarianPrefix::CHINESE2ENGLISH, - model_type: ModelType::Marian, - }; - pub const ENGLISH2CHINESE_SIMPLIFIED: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2CHINESE, - config_resource: MarianConfigResources::ENGLISH2CHINESE, - vocab_resource: MarianVocabResources::ENGLISH2CHINESE, - merges_resource: MarianSpmResources::ENGLISH2CHINESE, - prefix: MarianPrefix::ENGLISH2CHINESE_SIMPLIFIED, - model_type: ModelType::Marian, - }; - pub const ENGLISH2CHINESE_TRADITIONAL: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2CHINESE, - config_resource: MarianConfigResources::ENGLISH2CHINESE, - vocab_resource: MarianVocabResources::ENGLISH2CHINESE, - merges_resource: MarianSpmResources::ENGLISH2CHINESE, - prefix: MarianPrefix::ENGLISH2CHINESE_TRADITIONAL, - model_type: ModelType::Marian, - }; - pub const ENGLISH2SWEDISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2SWEDISH, - config_resource: MarianConfigResources::ENGLISH2SWEDISH, - vocab_resource: MarianVocabResources::ENGLISH2SWEDISH, - merges_resource: MarianSpmResources::ENGLISH2SWEDISH, - prefix: MarianPrefix::ENGLISH2SWEDISH, - model_type: ModelType::Marian, - }; - pub const SWEDISH2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::SWEDISH2ENGLISH, - config_resource: MarianConfigResources::SWEDISH2ENGLISH, - vocab_resource: MarianVocabResources::SWEDISH2ENGLISH, - merges_resource: MarianSpmResources::SWEDISH2ENGLISH, - prefix: MarianPrefix::SWEDISH2ENGLISH, - model_type: ModelType::Marian, - }; - pub const ENGLISH2ARABIC: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2ARABIC, - config_resource: MarianConfigResources::ENGLISH2ARABIC, - vocab_resource: MarianVocabResources::ENGLISH2ARABIC, - merges_resource: MarianSpmResources::ENGLISH2ARABIC, - prefix: MarianPrefix::ENGLISH2ARABIC, - model_type: ModelType::Marian, - }; - pub const ARABIC2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ARABIC2ENGLISH, - config_resource: MarianConfigResources::ARABIC2ENGLISH, - vocab_resource: MarianVocabResources::ARABIC2ENGLISH, - merges_resource: MarianSpmResources::ARABIC2ENGLISH, - prefix: MarianPrefix::ARABIC2ENGLISH, - model_type: ModelType::Marian, - }; - pub const ENGLISH2HINDI: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2HINDI, - config_resource: MarianConfigResources::ENGLISH2HINDI, - vocab_resource: MarianVocabResources::ENGLISH2HINDI, - merges_resource: MarianSpmResources::ENGLISH2HINDI, - prefix: MarianPrefix::ENGLISH2HINDI, - model_type: ModelType::Marian, - }; - pub const HINDI2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::HINDI2ENGLISH, - config_resource: MarianConfigResources::HINDI2ENGLISH, - vocab_resource: MarianVocabResources::HINDI2ENGLISH, - merges_resource: MarianSpmResources::HINDI2ENGLISH, - prefix: MarianPrefix::HINDI2ENGLISH, - model_type: ModelType::Marian, - }; - pub const ENGLISH2HEBREW: RemoteTranslationResources = Self { - model_resource: MarianModelResources::ENGLISH2HEBREW, - config_resource: MarianConfigResources::ENGLISH2HEBREW, - vocab_resource: MarianVocabResources::ENGLISH2HEBREW, - merges_resource: MarianSpmResources::ENGLISH2HEBREW, - prefix: MarianPrefix::ENGLISH2HEBREW, - model_type: ModelType::Marian, - }; - pub const HEBREW2ENGLISH: RemoteTranslationResources = Self { - model_resource: MarianModelResources::HEBREW2ENGLISH, - config_resource: MarianConfigResources::HEBREW2ENGLISH, - vocab_resource: MarianVocabResources::HEBREW2ENGLISH, - merges_resource: MarianSpmResources::HEBREW2ENGLISH, - prefix: MarianPrefix::HEBREW2ENGLISH, - model_type: ModelType::Marian, - }; -} - /// # Configuration for text translation /// Contains information regarding the model to load, mirrors the GenerationConfig, with a /// different set of default parameters and sets the device to place the model on. pub struct TranslationConfig { + /// Model type used for translation + pub model_type: ModelType, /// Model weights resource (default: pretrained BART model on CNN-DM) pub model_resource: Resource, /// Config resource (default: pretrained BART model on CNN-DM) @@ -697,6 +421,10 @@ pub struct TranslationConfig { pub vocab_resource: Resource, /// Merges resource (default: pretrained BART model on CNN-DM) pub merges_resource: Resource, + /// Supported source languages + pub source_languages: HashSet, + /// Supported target languages + pub target_languages: HashSet, /// Minimum sequence length (default: 0) pub min_length: i64, /// Maximum sequence length (default: 20) @@ -723,18 +451,14 @@ pub struct TranslationConfig { pub num_return_sequences: i64, /// Device to place the model on (default: CUDA/GPU when available) pub device: Device, - /// Prefix to append translation inputs with - pub prefix: Option, /// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation. pub num_beam_groups: Option, /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5) pub diversity_penalty: Option, - /// Model type used for translation - pub model_type: ModelType, } impl TranslationConfig { - /// Create a new `TranslationCondiguration` from an available language. + /// Create a new `TranslationConfiguration` from an available language. /// /// # Arguments /// @@ -745,159 +469,71 @@ impl TranslationConfig { /// /// ```no_run /// # fn main() -> anyhow::Result<()> { - /// use rust_bert::pipelines::translation::{Language, TranslationConfig}; - /// use tch::Device; - /// - /// let translation_config = - /// TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available()); - /// # Ok(()) - /// # } - /// ``` - pub fn new(language: Language, device: Device) -> TranslationConfig { - let translation_resource = match language { - Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH, - Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN, - Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH, - Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE, - Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN, - Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN, - Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN, - Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN, - Language::EnglishToDutch => RemoteTranslationResources::ENGLISH2DUTCH, - Language::EnglishToChineseSimplified => { - RemoteTranslationResources::ENGLISH2CHINESE_SIMPLIFIED - } - Language::EnglishToChineseTraditional => { - RemoteTranslationResources::ENGLISH2CHINESE_TRADITIONAL - } - Language::EnglishToSwedish => RemoteTranslationResources::ENGLISH2SWEDISH, - Language::EnglishToArabic => RemoteTranslationResources::ENGLISH2ARABIC, - Language::EnglishToHindi => RemoteTranslationResources::ENGLISH2HINDI, - Language::EnglishToHebrew => RemoteTranslationResources::ENGLISH2HEBREW, - - Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH, - Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH, - Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH, - Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH, - Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH, - Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH, - Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH, - Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH, - Language::DutchToEnglish => RemoteTranslationResources::DUTCH2ENGLISH, - Language::ChineseToEnglish => RemoteTranslationResources::CHINESE2ENGLISH, - Language::SwedishToEnglish => RemoteTranslationResources::SWEDISH2ENGLISH, - Language::ArabicToEnglish => RemoteTranslationResources::ARABIC2ENGLISH, - Language::HindiToEnglish => RemoteTranslationResources::HINDI2ENGLISH, - Language::HebrewToEnglish => RemoteTranslationResources::HEBREW2ENGLISH, - - Language::EnglishToFrenchV2 => RemoteTranslationResources::ENGLISH2FRENCH_V2, - Language::EnglishToGermanV2 => RemoteTranslationResources::ENGLISH2GERMAN_V2, - - Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN, - Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH, - }; - let model_resource = Resource::Remote(RemoteResource::from_pretrained( - translation_resource.model_resource, - )); - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - translation_resource.config_resource, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - translation_resource.vocab_resource, - )); - let merges_resource = Resource::Remote(RemoteResource::from_pretrained( - translation_resource.merges_resource, - )); - let prefix = translation_resource.prefix.map(|value| value.to_string()); - TranslationConfig { - model_resource, - config_resource, - vocab_resource, - merges_resource, - min_length: 0, - max_length: 512, - do_sample: false, - early_stopping: true, - num_beams: 6, - temperature: 1.0, - top_k: 50, - top_p: 1.0, - repetition_penalty: 1.0, - length_penalty: 1.0, - no_repeat_ngram_size: 0, - num_return_sequences: 1, - device, - prefix, - num_beam_groups: None, - diversity_penalty: None, - model_type: translation_resource.model_type, - } - } - - /// Create a new `TranslationConfiguration` from custom (e.g. local) resources. - /// - /// # Arguments - /// - /// * `model_resource` - `Resource` pointing to the model - /// * `config_resource` - `Resource` pointing to the configuration - /// * `vocab_resource` - `Resource` pointing to the vocabulary - /// * `sentence_piece_resource` - `Resource` pointing to the sentence piece model of the source language - /// * `device` - `Device` to place the model on (CPU/GPU) - /// - /// # Example - /// - /// ```no_run - /// # fn main() -> anyhow::Result<()> { + /// use rust_bert::marian::{ + /// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages, + /// MarianVocabResources, + /// }; /// use rust_bert::pipelines::common::ModelType; - /// use rust_bert::pipelines::translation::TranslationConfig; - /// use rust_bert::resources::{LocalResource, Resource}; - /// use std::path::PathBuf; + /// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig}; + /// use rust_bert::resources::{RemoteResource, Resource}; /// use tch::Device; /// - /// let config_resource = Resource::Local(LocalResource { - /// local_path: PathBuf::from("path/to/config.json"), - /// }); - /// let model_resource = Resource::Local(LocalResource { - /// local_path: PathBuf::from("path/to/model.ot"), - /// }); - /// let vocab_resource = Resource::Local(LocalResource { - /// local_path: PathBuf::from("path/to/vocab.json"), - /// }); - /// let sentence_piece_resource = Resource::Local(LocalResource { - /// local_path: PathBuf::from("path/to/spiece.model"), - /// }); + /// let model_resource = Resource::Remote(RemoteResource::from_pretrained( + /// MarianModelResources::ROMANCE2ENGLISH, + /// )); + /// let config_resource = Resource::Remote(RemoteResource::from_pretrained( + /// MarianConfigResources::ROMANCE2ENGLISH, + /// )); + /// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + /// MarianVocabResources::ROMANCE2ENGLISH, + /// )); /// - /// let translation_config = TranslationConfig::new_from_resources( + /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; + /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; + /// + /// let translation_config = TranslationConfig::new( + /// ModelType::Marian, /// model_resource, /// config_resource, + /// vocab_resource.clone(), /// vocab_resource, - /// sentence_piece_resource, - /// Some(">>fr<<".to_string()), - /// Device::cuda_if_available(), - /// ModelType::Marian, + /// source_languages, + /// target_languages, + /// device: Device::cuda_if_available(), /// ); /// # Ok(()) /// # } /// ``` - pub fn new_from_resources( + pub fn new( + model_type: ModelType, model_resource: Resource, config_resource: Resource, vocab_resource: Resource, - sentence_piece_resource: Resource, - prefix: Option, - device: Device, - model_type: ModelType, - ) -> TranslationConfig { + merges_resource: Resource, + source_languages: S, + target_languages: T, + device: Option, + ) -> TranslationConfig + where + S: AsRef<[Language]>, + T: AsRef<[Language]>, + { + let device = device.unwrap_or_else(|| Device::cuda_if_available()); + TranslationConfig { + model_type, model_resource, config_resource, vocab_resource, - merges_resource: sentence_piece_resource, + merges_resource, + source_languages: source_languages.as_ref().iter().cloned().collect(), + target_languages: target_languages.as_ref().iter().cloned().collect(), + device, min_length: 0, max_length: 512, do_sample: false, early_stopping: true, - num_beams: 6, + num_beams: 4, temperature: 1.0, top_k: 50, top_p: 1.0, @@ -905,11 +541,8 @@ impl TranslationConfig { length_penalty: 1.0, no_repeat_ngram_size: 0, num_return_sequences: 1, - device, - prefix, num_beam_groups: None, diversity_penalty: None, - model_type, } } } @@ -970,6 +603,58 @@ impl TranslationOption { } } + fn validate_and_get_prefix( + &self, + source_language: Option<&Language>, + target_language: Option<&Language>, + supported_source_languages: &HashSet, + supported_target_languages: &HashSet, + ) -> Result, RustBertError> { + if let Some(source_language) = source_language { + if supported_source_languages.contains(source_language) { + return Err(RustBertError::ValueError(format!( + "{} not in list of supported languages: {:?}", + source_language.to_string(), + supported_source_languages + ))); + } + } + + if let Some(target_language) = target_language { + if supported_target_languages.contains(target_language) { + return Err(RustBertError::ValueError(format!( + "{} not in list of supported languages: {:?}", + target_language.to_string(), + supported_target_languages + ))); + } + } + + Ok(match *self { + Self::Marian(_) => { + if supported_target_languages.len() > 1 { + Some(format!( + ">>{}<< ", + target_language + .expect("Missing target language for Marian") + .get_marian_code() + )) + } else { + None + } + } + Self::T5(_) => Some(format!( + "translate {} to {}:", + source_language + .expect("Missing source language for T5") + .to_string(), + target_language + .expect("Missing target language for T5") + .to_string() + )), + }) + } + /// Interface method to generate() of the particular models. pub fn generate<'a, S>( &self, @@ -1015,7 +700,8 @@ impl TranslationOption { /// # TranslationModel to perform translation pub struct TranslationModel { model: TranslationOption, - prefix: Option, + supported_source_languages: HashSet, + supported_target_languages: HashSet, } impl TranslationModel { @@ -1029,20 +715,50 @@ impl TranslationModel { /// /// ```no_run /// # fn main() -> anyhow::Result<()> { - /// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; + /// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; /// use tch::Device; + /// use rust_bert::resources::{Resource, RemoteResource}; + /// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages}; + /// use rust_bert::pipelines::common::ModelType; /// - /// let translation_config = - /// TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available()); + /// let model_resource = Resource::Remote(RemoteResource::from_pretrained( + /// MarianModelResources::ROMANCE2ENGLISH, + /// )); + /// let config_resource = Resource::Remote(RemoteResource::from_pretrained( + /// MarianConfigResources::ROMANCE2ENGLISH, + /// )); + /// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + /// MarianVocabResources::ROMANCE2ENGLISH, + /// )); + /// + /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; + /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; + /// + /// let translation_config = TranslationConfig::new( + /// ModelType::Marian, + /// model_resource, + /// config_resource, + /// vocab_resource.clone(), + /// vocab_resource, + /// source_languages, + /// target_languages, + /// device: Device::cuda_if_available(), + /// ); /// let mut summarization_model = TranslationModel::new(translation_config)?; /// # Ok(()) /// # } /// ``` pub fn new(translation_config: TranslationConfig) -> Result { - let prefix = translation_config.prefix.clone(); + let supported_source_languages = translation_config.source_languages.clone(); + let supported_target_languages = translation_config.target_languages.clone(); + let model = TranslationOption::new(translation_config)?; - Ok(TranslationModel { model, prefix }) + Ok(TranslationModel { + model, + supported_source_languages, + supported_target_languages, + }) } /// Translates texts provided @@ -1058,12 +774,35 @@ impl TranslationModel { /// /// ```no_run /// # fn main() -> anyhow::Result<()> { - /// use rust_bert::pipelines::generation_utils::LanguageGenerator; - /// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; + /// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; /// use tch::Device; + /// use rust_bert::resources::{Resource, RemoteResource}; + /// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages}; + /// use rust_bert::pipelines::common::ModelType; /// - /// let translation_config = - /// TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available()); + /// let model_resource = Resource::Remote(RemoteResource::from_pretrained( + /// MarianModelResources::ROMANCE2ENGLISH, + /// )); + /// let config_resource = Resource::Remote(RemoteResource::from_pretrained( + /// MarianConfigResources::ROMANCE2ENGLISH, + /// )); + /// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + /// MarianVocabResources::ROMANCE2ENGLISH, + /// )); + /// + /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; + /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; + /// + /// let translation_config = TranslationConfig::new( + /// ModelType::Marian, + /// model_resource, + /// config_resource, + /// vocab_resource.clone(), + /// vocab_resource, + /// source_languages, + /// target_languages, + /// device: Device::cuda_if_available(), + /// ); /// let model = TranslationModel::new(translation_config)?; /// /// let input = ["This is a sentence to be translated"]; @@ -1072,11 +811,23 @@ impl TranslationModel { /// # Ok(()) /// # } /// ``` - pub fn translate<'a, S>(&self, texts: S) -> Vec + pub fn translate<'a, S, L>( + &self, + texts: S, + source_language: impl Into>, + target_language: impl Into>, + ) -> Result, RustBertError> where S: AsRef<[&'a str]>, { - match &self.prefix { + let prefix = self.model.validate_and_get_prefix( + source_language.into().as_ref(), + target_language.into().as_ref(), + &self.supported_source_languages, + &self.supported_target_languages, + )?; + + Ok(match prefix { Some(value) => { let texts = texts .as_ref() @@ -1089,18 +840,54 @@ impl TranslationModel { ) } None => self.model.generate(Some(texts), None), - } + }) } } #[cfg(test)] mod test { use super::*; + use crate::marian::{ + MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages, + MarianVocabResources, + }; + use crate::resources::RemoteResource; #[test] #[ignore] // no need to run, compilation is enough to verify it is Send fn test() { - let config = TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available()); + use rust_bert::marian::{ + MarianConfigResources, MarianModelResources, MarianSourceLanguages, + MarianTargetLanguages, MarianVocabResources, + }; + use rust_bert::pipelines::common::ModelType; + use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig}; + use rust_bert::resources::{RemoteResource, Resource}; + use tch::Device; + + let model_resource = Resource::Remote(RemoteResource::from_pretrained( + MarianModelResources::ROMANCE2ENGLISH, + )); + let config_resource = Resource::Remote(RemoteResource::from_pretrained( + MarianConfigResources::ROMANCE2ENGLISH, + )); + let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + MarianVocabResources::ROMANCE2ENGLISH, + )); + + let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; + let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; + + let translation_config = TranslationConfig::new( + ModelType::Marian, + model_resource, + config_resource, + vocab_resource.clone(), + vocab_resource, + source_languages, + target_languages, + device: Device::cuda_if_available(), + ); let _: Box = Box::new(TranslationModel::new(config)); } } diff --git a/src/t5/mod.rs b/src/t5/mod.rs index 99f4b46..a55e464 100644 --- a/src/t5/mod.rs +++ b/src/t5/mod.rs @@ -56,5 +56,5 @@ mod t5_model; pub use attention::LayerState; pub use t5_model::{ T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Generator, T5Model, T5ModelOutput, - T5ModelResources, T5Prefix, T5VocabResources, + T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages, T5VocabResources, }; diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 7c06807..943d424 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -27,6 +27,7 @@ use crate::pipelines::generation_utils::private_generation_utils::{ use crate::pipelines::generation_utils::{ Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, }; +use crate::pipelines::translation::Language; use crate::t5::attention::LayerState; use crate::t5::encoder::T5Stack; @@ -42,6 +43,12 @@ pub struct T5VocabResources; /// # T5 optional prefixes pub struct T5Prefix; +/// # T5 source languages pre-sets +pub struct T5SourceLanguages; + +/// # T5 target languages pre-sets +pub type T5TargetLanguages = T5SourceLanguages; + impl T5ModelResources { /// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format. pub const T5_SMALL: (&'static str, &'static str) = ( @@ -81,6 +88,13 @@ impl T5VocabResources { ); } +const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German]; + +impl T5SourceLanguages { + pub const T5_SMALL: [Language; 3] = T5LANGUAGES; + pub const T5_BASE: [Language; 3] = T5LANGUAGES; +} + impl T5Prefix { pub const ENGLISH2FRENCH: Option<&'static str> = Some("translate English to French:"); pub const ENGLISH2GERMAN: Option<&'static str> = Some("translate English to German:"); diff --git a/tests/marian.rs b/tests/marian.rs index 5740cb0..063d9e0 100644 --- a/tests/marian.rs +++ b/tests/marian.rs @@ -1,11 +1,11 @@ -use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; +use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; use tch::Device; #[test] // #[cfg_attr(not(feature = "all-tests"), ignore)] fn test_translation() -> anyhow::Result<()> { // Set-up translation model - let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu); + let translation_config = TranslationConfig::new(OldLanguage::EnglishToFrench, Device::Cpu); let model = TranslationModel::new(translation_config)?; let input_context_1 = "The quick brown fox jumps over the lazy dog";