From ce90d8901dd41ae188bcf45fb255211126a341e6 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Sun, 11 Jul 2021 11:13:00 +0200 Subject: [PATCH] Updated examples and integration tests --- examples/translation_builder.rs | 8 +- examples/translation_m2m100.rs | 78 +- examples/translation_mbart.rs | 72 +- examples/translation_t5.rs | 70 +- .../translation/translation_builder.rs | 66 +- .../translation/translation_pipeline.rs | 991 ++++++++++++++++++ tests/m2m100.rs | 78 +- tests/marian.rs | 70 +- tests/mbart.rs | 58 +- tests/t5.rs | 60 +- 10 files changed, 1311 insertions(+), 240 deletions(-) create mode 100644 src/pipelines/translation/translation_pipeline.rs diff --git a/examples/translation_builder.rs b/examples/translation_builder.rs index f85968f..0234e28 100644 --- a/examples/translation_builder.rs +++ b/examples/translation_builder.rs @@ -23,17 +23,13 @@ fn main() -> anyhow::Result<()> { .with_model_type(ModelType::Marian) // .with_large_model() .with_source_languages(vec![Language::English]) - .with_target_languages(vec![Language::Hebrew]) + .with_target_languages(vec![Language::Spanish]) .create_model()?; let input_context_1 = "The quick brown fox jumps over the lazy dog."; let input_context_2 = "The dog did not wake up."; - let output = model.translate( - &[input_context_1, input_context_2], - Language::English, - Language::Hebrew, - )?; + let output = model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?; for sentence in output { println!("{}", sentence); diff --git a/examples/translation_m2m100.rs b/examples/translation_m2m100.rs index a8d3b62..74c0a65 100644 --- a/examples/translation_m2m100.rs +++ b/examples/translation_m2m100.rs @@ -13,54 +13,52 @@ extern crate anyhow; use rust_bert::m2m_100::{ - M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100ModelResources, - M2M100VocabResources, + M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages, + M2M100TargetLanguages, M2M100VocabResources, }; -use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::resources::{RemoteResource, Resource}; +use tch::Device; fn main() -> anyhow::Result<()> { - let generate_config = GenerateConfig { - max_length: 512, - min_length: 0, - model_resource: Resource::Remote(RemoteResource::from_pretrained( - M2M100ModelResources::M2M100_418M, - )), - config_resource: Resource::Remote(RemoteResource::from_pretrained( - M2M100ConfigResources::M2M100_418M, - )), - vocab_resource: Resource::Remote(RemoteResource::from_pretrained( - M2M100VocabResources::M2M100_418M, - )), - merges_resource: Resource::Remote(RemoteResource::from_pretrained( - M2M100MergesResources::M2M100_418M, - )), - do_sample: false, - early_stopping: true, - num_beams: 3, - no_repeat_ngram_size: 0, - ..Default::default() - }; + let model_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100ModelResources::M2M100_418M, + )); + let config_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100ConfigResources::M2M100_418M, + )); + let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100VocabResources::M2M100_418M, + )); + let merges_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100MergesResources::M2M100_418M, + )); - let model = M2M100Generator::new(generate_config)?; + let source_languages = 
M2M100SourceLanguages::M2M100_418M; + let target_languages = M2M100TargetLanguages::M2M100_418M; - let input_context_1 = ">>en.<< The dog did not wake up."; - let target_language = model.get_tokenizer().convert_tokens_to_ids([">>es.<<"])[0]; - - println!("{:?} - {:?}", input_context_1, target_language); - let output = model.generate( - Some(&[input_context_1]), - None, - None, - None, - None, - target_language, - None, - false, + let translation_config = TranslationConfig::new( + ModelType::M2M100, + model_resource, + config_resource, + vocab_resource, + merges_resource, + source_languages, + target_languages, + Device::cuda_if_available(), ); + let model = TranslationModel::new(translation_config)?; - for sentence in output { - println!("{:?}", sentence); + let source_sentence = "This sentence will be translated in multiple languages."; + + let mut outputs = Vec::new(); + outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?); + + for sentence in outputs { + println!("{}", sentence); } Ok(()) } diff --git a/examples/translation_mbart.rs b/examples/translation_mbart.rs index 64c60c0..aea5a11 100644 --- a/examples/translation_mbart.rs +++ b/examples/translation_mbart.rs @@ -13,48 +13,52 @@ extern crate anyhow; use rust_bert::mbart::{ - MBartConfigResources, MBartGenerator, MBartModelResources, MBartVocabResources, + MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages, + MBartVocabResources, }; -use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::resources::{RemoteResource, Resource}; +use tch::Device; fn main() -> anyhow::Result<()> { - let generate_config = GenerateConfig { - max_length: 56, - model_resource: Resource::Remote(RemoteResource::from_pretrained( - MBartModelResources::MBART50_MANY_TO_MANY, - )), - config_resource: Resource::Remote(RemoteResource::from_pretrained( - MBartConfigResources::MBART50_MANY_TO_MANY, - )), - vocab_resource: Resource::Remote(RemoteResource::from_pretrained( - MBartVocabResources::MBART50_MANY_TO_MANY, - )), - merges_resource: Resource::Remote(RemoteResource::from_pretrained( - MBartVocabResources::MBART50_MANY_TO_MANY, - )), - do_sample: false, - num_beams: 1, - ..Default::default() - }; - let model = MBartGenerator::new(generate_config)?; + let model_resource = Resource::Remote(RemoteResource::from_pretrained( + MBartModelResources::MBART50_MANY_TO_MANY, + )); + let config_resource = Resource::Remote(RemoteResource::from_pretrained( + MBartConfigResources::MBART50_MANY_TO_MANY, + )); + let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + MBartVocabResources::MBART50_MANY_TO_MANY, + )); + let merges_resource = Resource::Remote(RemoteResource::from_pretrained( + MBartVocabResources::MBART50_MANY_TO_MANY, + )); - let input_context_1 = "en_XX The quick brown fox jumps over the lazy dog."; - let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0]; + let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY; + let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY; - let output = model.generate( - Some(&[input_context_1]), - None, - None, - None, - None, - target_language, 
- None, - false, + let translation_config = TranslationConfig::new( + ModelType::MBart, + model_resource, + config_resource, + vocab_resource, + merges_resource, + source_languages, + target_languages, + Device::cuda_if_available(), ); + let model = TranslationModel::new(translation_config)?; - for sentence in output { - println!("{:?}", sentence.text); + let source_sentence = "This sentence will be translated in multiple languages."; + + let mut outputs = Vec::new(); + outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?); + + for sentence in outputs { + println!("{}", sentence); } Ok(()) } diff --git a/examples/translation_t5.rs b/examples/translation_t5.rs index 58280b6..b2efeb8 100644 --- a/examples/translation_t5.rs +++ b/examples/translation_t5.rs @@ -12,46 +12,56 @@ extern crate anyhow; -use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::t5::{T5ConfigResources, T5Generator, T5ModelResources, T5VocabResources}; +use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources}; +use tch::Device; fn main() -> anyhow::Result<()> { - // Resources paths + let model_resource = + Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE)); let config_resource = Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_BASE)); let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE)); - let weights_resource = - Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE)); + let merges_resource = + Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE)); - let generate_config = GenerateConfig { - model_resource: weights_resource, - vocab_resource, + let source_languages = [ + Language::English, + Language::French, + Language::German, + Language::Romanian, + ]; + let target_languages = [ + Language::English, + Language::French, + Language::German, + Language::Romanian, + ]; + + let translation_config = TranslationConfig::new( + ModelType::T5, + model_resource, config_resource, - max_length: 40, - do_sample: false, - num_beams: 4, - ..Default::default() - }; - - // Set-up model - let t5_model = T5Generator::new(generate_config)?; - - // Define input - let input = ["translate English to German: This sentence will get translated to German"]; - - let output = t5_model.generate( - Some(input.to_vec()), - None, - None, - None, - None, - None, - None, - false, + vocab_resource, + merges_resource, + source_languages, + target_languages, + Device::cuda_if_available(), ); - println!("{:?}", output); + let model = TranslationModel::new(translation_config)?; + let source_sentence = "This sentence will be translated in multiple languages."; + + let mut outputs = Vec::new(); + outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::German)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?); + + for sentence in outputs { + println!("{}", sentence); + } Ok(()) } diff --git 
a/src/pipelines/translation/translation_builder.rs b/src/pipelines/translation/translation_builder.rs index 277ef0b..f54fba7 100644 --- a/src/pipelines/translation/translation_builder.rs +++ b/src/pipelines/translation/translation_builder.rs @@ -34,14 +34,10 @@ enum ModelSize { XLarge, } -pub struct TranslationModelBuilder -where - S: AsRef<[Language]> + Debug, - T: AsRef<[Language]> + Debug, -{ +pub struct TranslationModelBuilder { model_type: Option, - source_languages: Option, - target_languages: Option, + source_languages: Option>, + target_languages: Option>, device: Option, model_size: Option, } @@ -61,12 +57,8 @@ macro_rules! get_marian_resources { }; } -impl TranslationModelBuilder -where - S: AsRef<[Language]> + Debug, - T: AsRef<[Language]> + Debug, -{ - pub fn new() -> TranslationModelBuilder { +impl TranslationModelBuilder { + pub fn new() -> TranslationModelBuilder { TranslationModelBuilder { model_type: None, source_languages: None, @@ -131,20 +123,26 @@ where self } - pub fn with_source_languages(&mut self, source_languages: S) -> &mut Self { - self.source_languages = Some(source_languages); + pub fn with_source_languages(&mut self, source_languages: S) -> &mut Self + where + S: AsRef<[Language]> + Debug, + { + self.source_languages = Some(source_languages.as_ref().to_vec()); self } - pub fn with_target_languages(&mut self, target_languages: T) -> &mut Self { - self.target_languages = Some(target_languages); + pub fn with_target_languages(&mut self, target_languages: T) -> &mut Self + where + T: AsRef<[Language]> + Debug, + { + self.target_languages = Some(target_languages.as_ref().to_vec()); self } fn get_default_model( &self, - source_languages: Option<&S>, - target_languages: Option<&T>, + source_languages: Option<&Vec>, + target_languages: Option<&Vec>, ) -> Result { Ok( match self.get_marian_model(source_languages, target_languages) { @@ -161,14 +159,14 @@ where fn get_marian_model( &self, - source_languages: Option<&S>, - target_languages: Option<&T>, + source_languages: Option<&Vec>, + target_languages: Option<&Vec>, ) -> Result { let (resources, source_languages, target_languages) = if let (Some(source_languages), Some(target_languages)) = (source_languages, target_languages) { - match (source_languages.as_ref(), target_languages.as_ref()) { + match (source_languages.as_slice(), target_languages.as_slice()) { ([Language::English], [Language::German]) => { get_marian_resources!(ENGLISH2RUSSIAN) } @@ -257,18 +255,17 @@ where fn get_mbart50_resources( &self, - source_languages: Option<&S>, - target_languages: Option<&T>, + source_languages: Option<&Vec>, + target_languages: Option<&Vec>, ) -> Result { if let Some(source_languages) = source_languages { if !source_languages - .as_ref() .iter() .all(|lang| MBartSourceLanguages::MBART50_MANY_TO_MANY.contains(lang)) { return Err(RustBertError::ValueError(format!( "{:?} not in list of supported languages: {:?}", - source_languages.as_ref(), + source_languages, MBartSourceLanguages::MBART50_MANY_TO_MANY ))); } @@ -276,7 +273,6 @@ where if let Some(target_languages) = target_languages { if !target_languages - .as_ref() .iter() .all(|lang| MBartTargetLanguages::MBART50_MANY_TO_MANY.contains(lang)) { @@ -315,18 +311,17 @@ where fn get_m2m100_large_resources( &self, - source_languages: Option<&S>, - target_languages: Option<&T>, + source_languages: Option<&Vec>, + target_languages: Option<&Vec>, ) -> Result { if let Some(source_languages) = source_languages { if !source_languages - .as_ref() .iter() .all(|lang| 
M2M100SourceLanguages::M2M100_418M.contains(lang)) { return Err(RustBertError::ValueError(format!( "{:?} not in list of supported languages: {:?}", - source_languages.as_ref(), + source_languages, M2M100SourceLanguages::M2M100_418M ))); } @@ -334,7 +329,6 @@ where if let Some(target_languages) = target_languages { if !target_languages - .as_ref() .iter() .all(|lang| M2M100TargetLanguages::M2M100_418M.contains(lang)) { @@ -367,18 +361,17 @@ where fn get_m2m100_xlarge_resources( &self, - source_languages: Option<&S>, - target_languages: Option<&T>, + source_languages: Option<&Vec>, + target_languages: Option<&Vec>, ) -> Result { if let Some(source_languages) = source_languages { if !source_languages - .as_ref() .iter() .all(|lang| M2M100SourceLanguages::M2M100_1_2B.contains(lang)) { return Err(RustBertError::ValueError(format!( "{:?} not in list of supported languages: {:?}", - source_languages.as_ref(), + source_languages, M2M100SourceLanguages::M2M100_1_2B ))); } @@ -386,7 +379,6 @@ where if let Some(target_languages) = target_languages { if !target_languages - .as_ref() .iter() .all(|lang| M2M100TargetLanguages::M2M100_1_2B.contains(lang)) { diff --git a/src/pipelines/translation/translation_pipeline.rs b/src/pipelines/translation/translation_pipeline.rs new file mode 100644 index 0000000..8b98d18 --- /dev/null +++ b/src/pipelines/translation/translation_pipeline.rs @@ -0,0 +1,991 @@ +// Copyright 2018-2020 The HuggingFace Inc. team. +// Copyright 2020 Marian Team Authors +// Copyright 2019-2020 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use tch::{Device, Tensor}; + +use crate::common::error::RustBertError; +use crate::common::resources::Resource; +use crate::m2m_100::M2M100Generator; +use crate::marian::MarianGenerator; +use crate::mbart::MBartGenerator; +use crate::pipelines::common::ModelType; +use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator; +use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; +use crate::t5::T5Generator; +use std::collections::HashSet; +use std::fmt; +use std::fmt::{Debug, Display}; + +/// Language +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +pub enum Language { + Afrikaans, + Danish, + Dutch, + German, + English, + Icelandic, + Luxembourgish, + Norwegian, + Swedish, + WesternFrisian, + Yiddish, + Asturian, + Catalan, + French, + Galician, + Italian, + Occitan, + Portuguese, + Romanian, + Spanish, + Belarusian, + Bosnian, + Bulgarian, + Croatian, + Czech, + Macedonian, + Polish, + Russian, + Serbian, + Slovak, + Slovenian, + Ukrainian, + Estonian, + Finnish, + Hungarian, + Latvian, + Lithuanian, + Albanian, + Armenian, + Georgian, + Greek, + Breton, + Irish, + ScottishGaelic, + Welsh, + Azerbaijani, + Bashkir, + Kazakh, + Turkish, + Uzbek, + Japanese, + Korean, + Vietnamese, + ChineseMandarin, + Bengali, + Gujarati, + Hindi, + Kannada, + Marathi, + Nepali, + Oriya, + Panjabi, + Sindhi, + Sinhala, + Urdu, + Tamil, + Cebuano, + Iloko, + Indonesian, + Javanese, + Malagasy, + Malay, + Malayalam, + Sundanese, + Tagalog, + Burmese, + CentralKhmer, + Lao, + Thai, + Mongolian, + Arabic, + Hebrew, + Pashto, + Farsi, + Amharic, + Fulah, + Hausa, + Igbo, + Lingala, + Luganda, + NorthernSotho, + Somali, + Swahili, + Swati, + Tswana, + Wolof, + Xhosa, + Yoruba, + Zulu, + HaitianCreole, +} + +impl Display for Language { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", { + let input_string = format!("{:?}", self); + let mut output: Vec<&str> = Vec::new(); + let mut start: usize = 0; + + for (c_pos, c) in input_string.char_indices() { + if c.is_uppercase() { + if start < c_pos { + output.push(&input_string[start..c_pos]); + } + start = c_pos; + } + } + if start < input_string.len() { + output.push(&input_string[start..]); + } + output.join(" ") + }) + } +} + +impl Language { + pub fn get_iso_639_1_code(&self) -> &'static str { + match self { + Language::Afrikaans => "af", + Language::Danish => "da", + Language::Dutch => "nl", + Language::German => "de", + Language::English => "en", + Language::Icelandic => "is", + Language::Luxembourgish => "lb", + Language::Norwegian => "no", + Language::Swedish => "sv", + Language::WesternFrisian => "fy", + Language::Yiddish => "yi", + Language::Asturian => "ast", + Language::Catalan => "ca", + Language::French => "fr", + Language::Galician => "gl", + Language::Italian => "it", + Language::Occitan => "oc", + Language::Portuguese => "pt", + Language::Romanian => "ro", + Language::Spanish => "es", + Language::Belarusian => "be", + Language::Bosnian => "bs", + Language::Bulgarian => "bg", + Language::Croatian => "hr", + Language::Czech => "cs", + Language::Macedonian => "mk", + Language::Polish => "pl", + Language::Russian => "ru", + Language::Serbian => "sr", + Language::Slovak => "sk", + Language::Slovenian => "sl", + Language::Ukrainian => "uk", + Language::Estonian => "et", + Language::Finnish => "fi", + Language::Hungarian => "hu", + Language::Latvian => "lv", + Language::Lithuanian => "lt", + Language::Albanian => "sq", + Language::Armenian => "hy", + Language::Georgian => 
"ka", + Language::Greek => "el", + Language::Breton => "br", + Language::Irish => "ga", + Language::ScottishGaelic => "gd", + Language::Welsh => "cy", + Language::Azerbaijani => "az", + Language::Bashkir => "ba", + Language::Kazakh => "kk", + Language::Turkish => "tr", + Language::Uzbek => "uz", + Language::Japanese => "ja", + Language::Korean => "ko", + Language::Vietnamese => "vi", + Language::ChineseMandarin => "zh", + Language::Bengali => "bn", + Language::Gujarati => "gu", + Language::Hindi => "hi", + Language::Kannada => "kn", + Language::Marathi => "mr", + Language::Nepali => "ne", + Language::Oriya => "or", + Language::Panjabi => "pa", + Language::Sindhi => "sd", + Language::Sinhala => "si", + Language::Urdu => "ur", + Language::Tamil => "ta", + Language::Cebuano => "ceb", + Language::Iloko => "ilo", + Language::Indonesian => "id", + Language::Javanese => "jv", + Language::Malagasy => "mg", + Language::Malay => "ms", + Language::Malayalam => "ml", + Language::Sundanese => "su", + Language::Tagalog => "tl", + Language::Burmese => "my", + Language::CentralKhmer => "km", + Language::Lao => "lo", + Language::Thai => "th", + Language::Mongolian => "mn", + Language::Arabic => "ar", + Language::Hebrew => "he", + Language::Pashto => "ps", + Language::Farsi => "fa", + Language::Amharic => "am", + Language::Fulah => "ff", + Language::Hausa => "ha", + Language::Igbo => "ig", + Language::Lingala => "ln", + Language::Luganda => "lg", + Language::NorthernSotho => "nso", + Language::Somali => "so", + Language::Swahili => "sw", + Language::Swati => "ss", + Language::Tswana => "tn", + Language::Wolof => "wo", + Language::Xhosa => "xh", + Language::Yoruba => "yo", + Language::Zulu => "zu", + Language::HaitianCreole => "ht", + } + } + + pub fn get_iso_639_3_code(&self) -> &'static str { + match self { + Language::Afrikaans => "afr", + Language::Danish => "dan", + Language::Dutch => "nld", + Language::German => "deu", + Language::English => "eng", + Language::Icelandic => "isl", + Language::Luxembourgish => "ltz", + Language::Norwegian => "nor", + Language::Swedish => "swe", + Language::WesternFrisian => "fry", + Language::Yiddish => "yid", + Language::Asturian => "ast", + Language::Catalan => "cat", + Language::French => "fra", + Language::Galician => "glg", + Language::Italian => "ita", + Language::Occitan => "oci", + Language::Portuguese => "por", + Language::Romanian => "ron", + Language::Spanish => "spa", + Language::Belarusian => "bel", + Language::Bosnian => "bos", + Language::Bulgarian => "bul", + Language::Croatian => "hrv", + Language::Czech => "ces", + Language::Macedonian => "mkd", + Language::Polish => "pol", + Language::Russian => "rus", + Language::Serbian => "srp", + Language::Slovak => "slk", + Language::Slovenian => "slv", + Language::Ukrainian => "ukr", + Language::Estonian => "est", + Language::Finnish => "fin", + Language::Hungarian => "hun", + Language::Latvian => "lav", + Language::Lithuanian => "lit", + Language::Albanian => "sqi", + Language::Armenian => "hye", + Language::Georgian => "kat", + Language::Greek => "ell", + Language::Breton => "bre", + Language::Irish => "gle", + Language::ScottishGaelic => "gla", + Language::Welsh => "cym", + Language::Azerbaijani => "aze", + Language::Bashkir => "bak", + Language::Kazakh => "kaz", + Language::Turkish => "tur", + Language::Uzbek => "uzb", + Language::Japanese => "jpn", + Language::Korean => "kor", + Language::Vietnamese => "vie", + Language::ChineseMandarin => "cmn", + Language::Bengali => "ben", + Language::Gujarati => "guj", + 
Language::Hindi => "hin", + Language::Kannada => "kan", + Language::Marathi => "mar", + Language::Nepali => "nep", + Language::Oriya => "ori", + Language::Panjabi => "pan", + Language::Sindhi => "snd", + Language::Sinhala => "sin", + Language::Urdu => "urd", + Language::Tamil => "tam", + Language::Cebuano => "ceb", + Language::Iloko => "ilo", + Language::Indonesian => "ind", + Language::Javanese => "jav", + Language::Malagasy => "mlg", + Language::Malay => "msa", + Language::Malayalam => "mal", + Language::Sundanese => "sun", + Language::Tagalog => "tgl", + Language::Burmese => "mya", + Language::CentralKhmer => "khm", + Language::Lao => "lao", + Language::Thai => "tha", + Language::Mongolian => "mon", + Language::Arabic => "ara", + Language::Hebrew => "heb", + Language::Pashto => "pus", + Language::Farsi => "fas", + Language::Amharic => "amh", + Language::Fulah => "ful", + Language::Hausa => "hau", + Language::Igbo => "ibo", + Language::Lingala => "lin", + Language::Luganda => "lug", + Language::NorthernSotho => "nso", + Language::Somali => "som", + Language::Swahili => "swa", + Language::Swati => "ssw", + Language::Tswana => "tsn", + Language::Wolof => "wol", + Language::Xhosa => "xho", + Language::Yoruba => "yor", + Language::Zulu => "zul", + Language::HaitianCreole => "hat", + } + } +} + +/// # Configuration for text translation +/// Contains information regarding the model to load, mirrors the GenerationConfig, with a +/// different set of default parameters and sets the device to place the model on. +pub struct TranslationConfig { + /// Model type used for translation + pub model_type: ModelType, + /// Model weights resource + pub model_resource: Resource, + /// Config resource + pub config_resource: Resource, + /// Vocab resource + pub vocab_resource: Resource, + /// Merges resource + pub merges_resource: Resource, + /// Supported source languages + pub source_languages: HashSet, + /// Supported target languages + pub target_languages: HashSet, + /// Minimum sequence length (default: 0) + pub min_length: i64, + /// Maximum sequence length (default: 20) + pub max_length: i64, + /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true) + pub do_sample: bool, + /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false) + pub early_stopping: bool, + /// Number of beams for beam search (default: 5) + pub num_beams: i64, + /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0) + pub temperature: f64, + /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0) + pub top_k: i64, + /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9) + pub top_p: f64, + /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0) + pub repetition_penalty: f64, + /// Exponential penalty based on the length of the hypotheses generated (default: 1.0) + pub length_penalty: f64, + /// Number of allowed repetitions of n-grams. 
Values higher than 0 turn on this feature (default: 3)
+    pub no_repeat_ngram_size: i64,
+    /// Number of sequences to return for each prompt text (default: 1)
+    pub num_return_sequences: i64,
+    /// Device to place the model on (default: CUDA/GPU when available)
+    pub device: Device,
+    /// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation.
+    pub num_beam_groups: Option<i64>,
+    /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)
+    pub diversity_penalty: Option<f64>,
+}
+
+impl TranslationConfig {
+    /// Create a new `TranslationConfig` from the model resources, supported languages and device placement.
+    ///
+    /// # Arguments
+    ///
+    /// * `model_type` - `ModelType` of the model (e.g. `ModelType::Marian`)
+    /// * `model_resource` - `Resource` pointing to the model weights
+    /// * `config_resource` - `Resource` pointing to the model configuration
+    /// * `vocab_resource` - `Resource` pointing to the vocabulary file
+    /// * `merges_resource` - `Resource` pointing to the merges or sentencepiece model (models without merges re-use the vocabulary resource)
+    /// * `source_languages` - Collection of supported source `Language`s
+    /// * `target_languages` - Collection of supported target `Language`s
+    /// * `device` - `Device` to place the model on; pass `None` to default to CUDA when available
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # fn main() -> anyhow::Result<()> {
+    /// use rust_bert::marian::{
+    ///     MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
+    ///     MarianVocabResources,
+    /// };
+    /// use rust_bert::pipelines::common::ModelType;
+    /// use rust_bert::pipelines::translation::TranslationConfig;
+    /// use rust_bert::resources::{RemoteResource, Resource};
+    /// use tch::Device;
+    ///
+    /// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
+    ///     MarianModelResources::ROMANCE2ENGLISH,
+    /// ));
+    /// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
+    ///     MarianConfigResources::ROMANCE2ENGLISH,
+    /// ));
+    /// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
+    ///     MarianVocabResources::ROMANCE2ENGLISH,
+    /// ));
+    ///
+    /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
+    /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
+    ///
+    /// let translation_config = TranslationConfig::new(
+    ///     ModelType::Marian,
+    ///     model_resource,
+    ///     config_resource,
+    ///     vocab_resource.clone(),
+    ///     vocab_resource,
+    ///     source_languages,
+    ///     target_languages,
+    ///     Device::cuda_if_available(),
+    /// );
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn new<S, T>(
+        model_type: ModelType,
+        model_resource: Resource,
+        config_resource: Resource,
+        vocab_resource: Resource,
+        merges_resource: Resource,
+        source_languages: S,
+        target_languages: T,
+        device: impl Into<Option<Device>>,
+    ) -> TranslationConfig
+    where
+        S: AsRef<[Language]>,
+        T: AsRef<[Language]>,
+    {
+        let device = device.into().unwrap_or_else(|| Device::cuda_if_available());
+
+        TranslationConfig {
+            model_type,
+            model_resource,
+            config_resource,
+            vocab_resource,
+            merges_resource,
+            source_languages: source_languages.as_ref().iter().cloned().collect(),
+            target_languages: target_languages.as_ref().iter().cloned().collect(),
+            device,
+            min_length: 0,
+            max_length: 512,
+            do_sample: false,
+            early_stopping: true,
+            num_beams: 3,
+            temperature: 1.0,
+            top_k: 50,
+            top_p: 1.0,
+            repetition_penalty: 1.0,
+            length_penalty: 1.0,
+            no_repeat_ngram_size: 0,
+            num_return_sequences: 1,
+            num_beam_groups: None,
+            diversity_penalty: None,
+        }
+    }
+}
+
+impl From<TranslationConfig> for GenerateConfig {
+    fn from(config: TranslationConfig) -> GenerateConfig {
+        GenerateConfig {
+            model_resource: config.model_resource,
+            config_resource: config.config_resource,
+            merges_resource: config.merges_resource,
+            vocab_resource: config.vocab_resource,
+            min_length: config.min_length,
+
max_length: config.max_length, + do_sample: config.do_sample, + early_stopping: config.early_stopping, + num_beams: config.num_beams, + temperature: config.temperature, + top_k: config.top_k, + top_p: config.top_p, + repetition_penalty: config.repetition_penalty, + length_penalty: config.length_penalty, + no_repeat_ngram_size: config.no_repeat_ngram_size, + num_return_sequences: config.num_return_sequences, + num_beam_groups: config.num_beam_groups, + diversity_penalty: config.diversity_penalty, + device: config.device, + } + } +} + +/// # Abstraction that holds one particular translation model, for any of the supported models +pub enum TranslationOption { + /// Translator based on Marian model + Marian(MarianGenerator), + /// Translator based on T5 model + T5(T5Generator), + /// Translator based on MBart50 model + MBart(MBartGenerator), + /// Translator based on M2M100 model + M2M100(M2M100Generator), +} + +impl TranslationOption { + pub fn new(config: TranslationConfig) -> Result { + match config.model_type { + ModelType::Marian => Ok(TranslationOption::Marian(MarianGenerator::new( + config.into(), + )?)), + ModelType::T5 => Ok(TranslationOption::T5(T5Generator::new(config.into())?)), + ModelType::MBart => Ok(TranslationOption::MBart(MBartGenerator::new( + config.into(), + )?)), + ModelType::M2M100 => Ok(TranslationOption::M2M100(M2M100Generator::new( + config.into(), + )?)), + _ => Err(RustBertError::InvalidConfigurationError(format!( + "Translation not implemented for {:?}!", + config.model_type + ))), + } + } + + /// Returns the `ModelType` for this TranslationOption + pub fn model_type(&self) -> ModelType { + match *self { + Self::Marian(_) => ModelType::Marian, + Self::T5(_) => ModelType::T5, + Self::MBart(_) => ModelType::MBart, + Self::M2M100(_) => ModelType::M2M100, + } + } + + fn validate_and_get_prefix_and_forced_bos_id( + &self, + source_language: Option<&Language>, + target_language: Option<&Language>, + supported_source_languages: &HashSet, + supported_target_languages: &HashSet, + ) -> Result<(Option, Option), RustBertError> { + if let Some(source_language) = source_language { + if !supported_source_languages.contains(source_language) { + return Err(RustBertError::ValueError(format!( + "{} not in list of supported languages: {:?}", + source_language.to_string(), + supported_source_languages + ))); + } + } + + if let Some(target_language) = target_language { + if !supported_target_languages.contains(target_language) { + return Err(RustBertError::ValueError(format!( + "{} not in list of supported languages: {:?}", + target_language.to_string(), + supported_target_languages + ))); + } + } + + Ok(match *self { + Self::Marian(_) => { + if supported_target_languages.len() > 1 { + ( + Some(format!( + ">>{}<< ", + match target_language { + Some(value) => value.get_iso_639_1_code(), + None => { + return Err(RustBertError::ValueError(format!( + "Missing target language for Marian \ + (multiple languages supported by model: {:?}, \ + need to specify target language)", + supported_target_languages + ))); + } + } + )), + None, + ) + } else { + (None, None) + } + } + Self::T5(_) => ( + Some(format!( + "translate {} to {}:", + match source_language { + Some(value) => value.to_string(), + None => { + return Err(RustBertError::ValueError( + "Missing source language for T5".to_string(), + )); + } + }, + match target_language { + Some(value) => value.to_string(), + None => { + return Err(RustBertError::ValueError( + "Missing target language for T5".to_string(), + )); + } + } + )), + None, + 
), + Self::MBart(ref model) => ( + Some(format!( + ">>{}<< ", + match source_language { + Some(value) => value.get_iso_639_1_code(), + None => { + return Err(RustBertError::ValueError(format!( + "Missing source language for MBart\ + (multiple languages supported by model: {:?}, \ + need to specify target language)", + supported_source_languages + ))); + } + } + )), + if let Some(target_language) = target_language { + Some( + model._get_tokenizer().convert_tokens_to_ids([format!( + ">>{}<<", + target_language.get_iso_639_1_code() + )])[0], + ) + } else { + return Err(RustBertError::ValueError(format!( + "Missing target language for MBart\ + (multiple languages supported by model: {:?}, \ + need to specify target language)", + supported_target_languages + ))); + }, + ), + Self::M2M100(ref model) => ( + Some(match source_language { + Some(value) => { + let language_code = value.get_iso_639_1_code(); + match language_code.len() { + 2 => format!(">>{}.<< ", language_code), + 3 => format!(">>{}<< ", language_code), + _ => { + return Err(RustBertError::ValueError( + "Invalid ISO 639-3 code".to_string(), + )); + } + } + } + None => { + return Err(RustBertError::ValueError(format!( + "Missing source language for M2M100 \ + (multiple languages supported by model: {:?}, \ + need to specify target language)", + supported_source_languages + ))); + } + }), + if let Some(target_language) = target_language { + let language_code = target_language.get_iso_639_1_code(); + Some( + model + ._get_tokenizer() + .convert_tokens_to_ids([match language_code.len() { + 2 => format!(">>{}.<<", language_code), + 3 => format!(">>{}<<", language_code), + _ => { + return Err(RustBertError::ValueError( + "Invalid ISO 639-3 code".to_string(), + )); + } + }])[0], + ) + } else { + return Err(RustBertError::ValueError(format!( + "Missing target language for M2M100 \ + (multiple languages supported by model: {:?}, \ + need to specify target language)", + supported_target_languages + ))); + }, + ), + }) + } + + /// Interface method to generate() of the particular models. 
+    pub fn generate<'a, S>(
+        &self,
+        prompt_texts: Option<S>,
+        attention_mask: Option<Tensor>,
+        forced_bos_token_id: Option<i64>,
+    ) -> Vec<String>
+    where
+        S: AsRef<[&'a str]>,
+    {
+        match *self {
+            Self::Marian(ref model) => model
+                .generate(
+                    prompt_texts,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.text)
+                .collect(),
+            Self::T5(ref model) => model
+                .generate(
+                    prompt_texts,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.text)
+                .collect(),
+            Self::MBart(ref model) => model
+                .generate(
+                    prompt_texts,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    forced_bos_token_id,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.text)
+                .collect(),
+            Self::M2M100(ref model) => model
+                .generate(
+                    prompt_texts,
+                    attention_mask,
+                    None,
+                    None,
+                    None,
+                    forced_bos_token_id,
+                    None,
+                    false,
+                )
+                .into_iter()
+                .map(|output| output.text)
+                .collect(),
+        }
+    }
+}
+
+/// # TranslationModel to perform translation
+pub struct TranslationModel {
+    model: TranslationOption,
+    supported_source_languages: HashSet<Language>,
+    supported_target_languages: HashSet<Language>,
+}
+
+impl TranslationModel {
+    /// Build a new `TranslationModel`
+    ///
+    /// # Arguments
+    ///
+    /// * `translation_config` - `TranslationConfig` object containing the resource references (model, vocabulary, configuration), translation options and device placement (CPU/GPU)
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # fn main() -> anyhow::Result<()> {
+    /// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
+    /// use tch::Device;
+    /// use rust_bert::resources::{Resource, RemoteResource};
+    /// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages};
+    /// use rust_bert::pipelines::common::ModelType;
+    ///
+    /// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
+    ///     MarianModelResources::ROMANCE2ENGLISH,
+    /// ));
+    /// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
+    ///     MarianConfigResources::ROMANCE2ENGLISH,
+    /// ));
+    /// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
+    ///     MarianVocabResources::ROMANCE2ENGLISH,
+    /// ));
+    ///
+    /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
+    /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
+    ///
+    /// let translation_config = TranslationConfig::new(
+    ///     ModelType::Marian,
+    ///     model_resource,
+    ///     config_resource,
+    ///     vocab_resource.clone(),
+    ///     vocab_resource,
+    ///     source_languages,
+    ///     target_languages,
+    ///     Device::cuda_if_available(),
+    /// );
+    /// let translation_model = TranslationModel::new(translation_config)?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn new(translation_config: TranslationConfig) -> Result<TranslationModel, RustBertError> {
+        let supported_source_languages = translation_config.source_languages.clone();
+        let supported_target_languages = translation_config.target_languages.clone();
+
+        let model = TranslationOption::new(translation_config)?;
+
+        Ok(TranslationModel {
+            model,
+            supported_source_languages,
+            supported_target_languages,
+        })
+    }
+
+    /// Translates the texts provided
+    ///
+    /// # Arguments
+    ///
+    /// * `input` - `&[&str]` Array of texts to translate.
+    ///
+    /// # Returns
+    /// * `Vec<String>` Translated texts
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # fn main() -> anyhow::Result<()> {
+    /// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
+    /// use tch::Device;
+    /// use rust_bert::resources::{Resource, RemoteResource};
+    /// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages, MarianSpmResources};
+    /// use rust_bert::pipelines::common::ModelType;
+    ///
+    /// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
+    ///     MarianModelResources::ENGLISH2ROMANCE,
+    /// ));
+    /// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
+    ///     MarianConfigResources::ENGLISH2ROMANCE,
+    /// ));
+    /// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
+    ///     MarianVocabResources::ENGLISH2ROMANCE,
+    /// ));
+    /// let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
+    ///     MarianSpmResources::ENGLISH2ROMANCE,
+    /// ));
+    /// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
+    /// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
+    ///
+    /// let translation_config = TranslationConfig::new(
+    ///     ModelType::Marian,
+    ///     model_resource,
+    ///     config_resource,
+    ///     vocab_resource,
+    ///     merges_resource,
+    ///     source_languages,
+    ///     target_languages,
+    ///     Device::cuda_if_available(),
+    /// );
+    /// let model = TranslationModel::new(translation_config)?;
+    ///
+    /// let input = ["This is a sentence to be translated"];
+    /// let source_language = None;
+    /// let target_language = Language::French;
+    ///
+    /// let output = model.translate(&input, source_language, target_language)?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn translate<'a, S>(
+        &self,
+        texts: S,
+        source_language: impl Into<Option<Language>>,
+        target_language: impl Into<Option<Language>>,
+    ) -> Result<Vec<String>, RustBertError>
+    where
+        S: AsRef<[&'a str]>,
+    {
+        let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id(
+            source_language.into().as_ref(),
+            target_language.into().as_ref(),
+            &self.supported_source_languages,
+            &self.supported_target_languages,
+        )?;
+
+        Ok(match prefix {
+            Some(value) => {
+                let texts = texts
+                    .as_ref()
+                    .iter()
+                    .map(|&v| format!("{}{}", value, v))
+                    .collect::<Vec<String>>();
+                self.model.generate(
+                    Some(texts.iter().map(AsRef::as_ref).collect::<Vec<&str>>()),
+                    None,
+                    forced_bos_token_id,
+                )
+            }
+            None => self.model.generate(Some(texts), None, forced_bos_token_id),
+        })
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::marian::{
+        MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
+        MarianVocabResources,
+    };
+    use crate::resources::RemoteResource;
+
+    #[test]
+    #[ignore] // no need to run, compilation is enough to verify it is Send
+    fn test() {
+        let model_resource = Resource::Remote(RemoteResource::from_pretrained(
+            MarianModelResources::ROMANCE2ENGLISH,
+        ));
+        let config_resource = Resource::Remote(RemoteResource::from_pretrained(
+            MarianConfigResources::ROMANCE2ENGLISH,
+        ));
+        let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
+            MarianVocabResources::ROMANCE2ENGLISH,
+        ));
+
+        let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
+        let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
+
+        let translation_config = TranslationConfig::new(
+            ModelType::Marian,
+            model_resource,
+            config_resource,
+
vocab_resource.clone(), + vocab_resource, + source_languages, + target_languages, + Device::cuda_if_available(), + ); + let _: Box = Box::new(TranslationModel::new(translation_config)); + } +} diff --git a/tests/m2m100.rs b/tests/m2m100.rs index 60c2cc4..db01b6d 100644 --- a/tests/m2m100.rs +++ b/tests/m2m100.rs @@ -1,8 +1,9 @@ use rust_bert::m2m_100::{ - M2M100Config, M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100Model, - M2M100ModelResources, M2M100VocabResources, + M2M100Config, M2M100ConfigResources, M2M100MergesResources, M2M100Model, M2M100ModelResources, + M2M100SourceLanguages, M2M100TargetLanguages, M2M100VocabResources, }; -use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::Config; use rust_tokenizers::tokenizer::{M2M100Tokenizer, Tokenizer, TruncationStrategy}; @@ -75,43 +76,48 @@ fn m2m100_lm_model() -> anyhow::Result<()> { #[test] fn m2m100_translation() -> anyhow::Result<()> { - // Resources paths - let generate_config = GenerateConfig { - max_length: 56, - model_resource: Resource::Remote(RemoteResource::from_pretrained( - M2M100ModelResources::M2M100_418M, - )), - config_resource: Resource::Remote(RemoteResource::from_pretrained( - M2M100ConfigResources::M2M100_418M, - )), - vocab_resource: Resource::Remote(RemoteResource::from_pretrained( - M2M100VocabResources::M2M100_418M, - )), - merges_resource: Resource::Remote(RemoteResource::from_pretrained( - M2M100MergesResources::M2M100_418M, - )), - do_sample: false, - num_beams: 3, - ..Default::default() - }; - let model = M2M100Generator::new(generate_config)?; + let model_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100ModelResources::M2M100_418M, + )); + let config_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100ConfigResources::M2M100_418M, + )); + let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100VocabResources::M2M100_418M, + )); + let merges_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100MergesResources::M2M100_418M, + )); - let input_context = ">>en.<< The dog did not wake up."; - let target_language = model.get_tokenizer().convert_tokens_to_ids([">>es.<<"])[0]; + let source_languages = M2M100SourceLanguages::M2M100_418M; + let target_languages = M2M100TargetLanguages::M2M100_418M; - let output = model.generate( - Some(&[input_context]), - None, - None, - None, - None, - target_language, - None, - false, + let translation_config = TranslationConfig::new( + ModelType::M2M100, + model_resource, + config_resource, + vocab_resource, + merges_resource, + source_languages, + target_languages, + Device::cuda_if_available(), ); + let model = TranslationModel::new(translation_config)?; - assert_eq!(output.len(), 1); - assert_eq!(output[0].text, ">>es.<< El perro no se despertó."); + let source_sentence = "This sentence will be translated in multiple languages."; + + let mut outputs = Vec::new(); + outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?); + + assert_eq!(outputs.len(), 3); + assert_eq!( + outputs[0], + " Cette phrase sera traduite en plusieurs langues." 
+ ); + assert_eq!(outputs[1], " Esta frase se traducirá en varios idiomas."); + assert_eq!(outputs[2], " यह वाक्यांश कई भाषाओं में अनुवादित किया जाएगा।"); Ok(()) } diff --git a/tests/marian.rs b/tests/marian.rs index 063d9e0..699dce0 100644 --- a/tests/marian.rs +++ b/tests/marian.rs @@ -1,24 +1,82 @@ -use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; +use rust_bert::marian::{ + MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources, + MarianTargetLanguages, MarianVocabResources, +}; +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::translation::{ + Language, TranslationConfig, TranslationModel, TranslationModelBuilder, +}; +use rust_bert::resources::{RemoteResource, Resource}; use tch::Device; #[test] // #[cfg_attr(not(feature = "all-tests"), ignore)] fn test_translation() -> anyhow::Result<()> { // Set-up translation model - let translation_config = TranslationConfig::new(OldLanguage::EnglishToFrench, Device::Cpu); + let model_resource = Resource::Remote(RemoteResource::from_pretrained( + MarianModelResources::ENGLISH2ROMANCE, + )); + let config_resource = Resource::Remote(RemoteResource::from_pretrained( + MarianConfigResources::ENGLISH2ROMANCE, + )); + let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + MarianVocabResources::ENGLISH2ROMANCE, + )); + let merges_resource = Resource::Remote(RemoteResource::from_pretrained( + MarianSpmResources::ENGLISH2ROMANCE, + )); + + let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE; + let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE; + + let translation_config = TranslationConfig::new( + ModelType::Marian, + model_resource, + config_resource, + vocab_resource, + merges_resource, + source_languages, + target_languages, + Device::cuda_if_available(), + ); let model = TranslationModel::new(translation_config)?; let input_context_1 = "The quick brown fox jumps over the lazy dog"; let input_context_2 = "The dog did not wake up"; - let output = model.translate(&[input_context_1, input_context_2]); + let outputs = model.translate(&[input_context_1, input_context_2], None, Language::French)?; - assert_eq!(output.len(), 2); + assert_eq!(outputs.len(), 2); assert_eq!( - output[0], + outputs[0], " Le rapide renard brun saute sur le chien paresseux" ); - assert_eq!(output[1], " Le chien ne s'est pas réveillé"); + assert_eq!(outputs[1], " Le chien ne s'est pas réveillé"); + + Ok(()) +} + +#[test] +// #[cfg_attr(not(feature = "all-tests"), ignore)] +fn test_translation_builder() -> anyhow::Result<()> { + let model = TranslationModelBuilder::new() + .with_device(Device::cuda_if_available()) + .with_model_type(ModelType::Marian) + .with_source_languages(vec![Language::English]) + .with_target_languages(vec![Language::French]) + .create_model()?; + + let input_context_1 = "The quick brown fox jumps over the lazy dog"; + let input_context_2 = "The dog did not wake up"; + + let outputs = model.translate(&[input_context_1, input_context_2], None, Language::French)?; + + assert_eq!(outputs.len(), 2); + assert_eq!( + outputs[0], + " Le rapide renard brun saute sur le chien paresseux" + ); + assert_eq!(outputs[1], " Le chien ne s'est pas réveillé"); Ok(()) } diff --git a/tests/mbart.rs b/tests/mbart.rs index 552a525..f851f7e 100644 --- a/tests/mbart.rs +++ b/tests/mbart.rs @@ -1,8 +1,8 @@ use rust_bert::mbart::{ - MBartConfig, MBartConfigResources, MBartGenerator, MBartModel, MBartModelResources, - MBartVocabResources, + MBartConfig, 
MBartConfigResources, MBartModel, MBartModelResources, MBartVocabResources, }; -use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::translation::{Language, TranslationModelBuilder}; use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::Config; use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy}; @@ -65,46 +65,28 @@ fn mbart_lm_model() -> anyhow::Result<()> { #[test] fn mbart_translation() -> anyhow::Result<()> { - // Resources paths - let generate_config = GenerateConfig { - max_length: 56, - model_resource: Resource::Remote(RemoteResource::from_pretrained( - MBartModelResources::MBART50_MANY_TO_MANY, - )), - config_resource: Resource::Remote(RemoteResource::from_pretrained( - MBartConfigResources::MBART50_MANY_TO_MANY, - )), - vocab_resource: Resource::Remote(RemoteResource::from_pretrained( - MBartVocabResources::MBART50_MANY_TO_MANY, - )), - merges_resource: Resource::Remote(RemoteResource::from_pretrained( - MBartVocabResources::MBART50_MANY_TO_MANY, - )), - do_sample: false, - num_beams: 3, - ..Default::default() - }; - let model = MBartGenerator::new(generate_config)?; + let model = TranslationModelBuilder::new() + .with_device(Device::cuda_if_available()) + .with_model_type(ModelType::MBart) + .create_model()?; - let input_context = "en_XX The quick brown fox jumps over the lazy dog."; - let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0]; + let source_sentence = "This sentence will be translated in multiple languages."; - let output = model.generate( - Some(&[input_context]), - None, - None, - None, - None, - target_language, - None, - false, - ); + let mut outputs = Vec::new(); + outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?); - assert_eq!(output.len(), 1); + assert_eq!(outputs.len(), 3); assert_eq!( - output[0].text, - "de_DE Der schnelle braune Fuchs springt über den faulen Hund." + outputs[0], + " Cette phrase sera traduite en plusieurs langues." ); + assert_eq!( + outputs[1], + " Esta frase será traducida en múltiples idiomas." 
+ ); + assert_eq!(outputs[2], " यह वाक्य कई भाषाओं में अनुवाद किया जाएगा."); Ok(()) } diff --git a/tests/t5.rs b/tests/t5.rs index 587572e..c370543 100644 --- a/tests/t5.rs +++ b/tests/t5.rs @@ -1,31 +1,65 @@ use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; -use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel}; +use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources}; use tch::Device; #[test] fn test_translation_t5() -> anyhow::Result<()> { - // Set-up translation model - let translation_config = TranslationConfig::new_from_resources( - Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)), - Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)), - Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)), - Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)), - Some("translate English to French:".to_string()), - Device::cuda_if_available(), + let model_resource = + Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)); + let config_resource = + Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)); + let vocab_resource = + Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)); + let merges_resource = + Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)); + + let source_languages = [ + Language::English, + Language::French, + Language::German, + Language::Romanian, + ]; + let target_languages = [ + Language::English, + Language::French, + Language::German, + Language::Romanian, + ]; + + let translation_config = TranslationConfig::new( ModelType::T5, + model_resource, + config_resource, + vocab_resource, + merges_resource, + source_languages, + target_languages, + Device::cuda_if_available(), ); let model = TranslationModel::new(translation_config)?; - let input_context = "The quick brown fox jumps over the lazy dog."; + let source_sentence = "This sentence will be translated in multiple languages."; - let output = model.translate(&[input_context]); + let mut outputs = Vec::new(); + outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::German)?); + outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?); + assert_eq!(outputs.len(), 3); assert_eq!( - output[0], - " Le renard brun rapide saute au-dessus du chien paresseux." + outputs[0], + " Cette phrase sera traduite dans plusieurs langues." + ); + assert_eq!( + outputs[1], + " Dieser Satz wird in mehreren Sprachen übersetzt." + ); + assert_eq!( + outputs[2], + " Această frază va fi tradusă în mai multe limbi." ); Ok(())
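
Usage illustration (not part of the patch): the updated pipeline validates the requested language pair against the source and target languages registered in the `TranslationConfig`, so an unsupported pair surfaces as a `RustBertError` instead of an unknown language prefix being sent to the generator. The sketch below exercises that path with the builder API introduced above; the Marian English-to-French default and the Japanese target are assumptions chosen for the example, mirroring the builder test in tests/marian.rs.

extern crate anyhow;

use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
use tch::Device;

fn main() -> anyhow::Result<()> {
    // Same defaults as test_translation_builder above: Marian, English -> French.
    let model = TranslationModelBuilder::new()
        .with_device(Device::cuda_if_available())
        .with_model_type(ModelType::Marian)
        .with_source_languages(vec![Language::English])
        .with_target_languages(vec![Language::French])
        .create_model()?;

    // Supported pair: returns the translated sentences.
    let output = model.translate(&["The dog did not wake up"], None, Language::French)?;
    println!("{}", output[0]);

    // Unsupported pair: the request is rejected during validation and returned as an error.
    if let Err(e) = model.translate(&["The dog did not wake up"], None, Language::Japanese) {
        println!("expected validation error: {}", e);
    }
    Ok(())
}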