diff --git a/examples/translation_builder.rs b/examples/translation_builder.rs index a869c70..34b7314 100644 --- a/examples/translation_builder.rs +++ b/examples/translation_builder.rs @@ -21,15 +21,16 @@ fn main() -> anyhow::Result<()> { let model = TranslationModelBuilder::new() .with_device(Device::cuda_if_available()) // .with_model_type(ModelType::Marian) + .with_model_type(ModelType::M2M100) .with_large_model() .with_source_languages(vec![Language::English]) .with_target_languages(vec![Language::French]) .create_model()?; - let input_context_1 = "The quick brown fox jumps over the lazy dog"; - let input_context_2 = "The dog did not wake up"; + // let input_context_1 = "The quick brown fox jumps over the lazy dog."; + let input_context_2 = "The dog did not wake up."; - let output = model.translate(&[input_context_1, input_context_2], None, Language::French)?; + let output = model.translate(&[input_context_2], Language::English, Language::French)?; for sentence in output { println!("{}", sentence); diff --git a/examples/translation_m2m100.rs b/examples/translation_m2m100.rs index 5edeb23..05f5d0a 100644 --- a/examples/translation_m2m100.rs +++ b/examples/translation_m2m100.rs @@ -21,7 +21,7 @@ use rust_bert::resources::{RemoteResource, Resource}; fn main() -> anyhow::Result<()> { let generate_config = GenerateConfig { - max_length: 142, + max_length: 512, min_length: 0, model_resource: Resource::Remote(RemoteResource::from_pretrained( M2M100ModelResources::M2M100_418M, @@ -38,14 +38,16 @@ fn main() -> anyhow::Result<()> { do_sample: false, early_stopping: true, num_beams: 3, + no_repeat_ngram_size: 0, ..Default::default() }; let model = M2M100Generator::new(generate_config)?; let input_context_1 = ">>en.<< The dog did not wake up."; - let target_language = model.get_tokenizer().convert_tokens_to_ids([">>es.<<"])[0]; + let target_language = model.get_tokenizer().convert_tokens_to_ids([">>fr.<<"])[0]; + println!("{:?} - {:?}", input_context_1, target_language); let output = model.generate( Some(&[input_context_1]), None, diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index ddc6c20..f047abe 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -50,7 +50,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::Path; -#[derive(Clone, Copy, Serialize, Deserialize, Debug)] +#[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq)] /// # Identifies the type of model pub enum ModelType { Bart, diff --git a/src/pipelines/translation.rs b/src/pipelines/translation.rs index da837a3..cc206a3 100644 --- a/src/pipelines/translation.rs +++ b/src/pipelines/translation.rs @@ -16,18 +16,19 @@ use tch::{Device, Tensor}; use crate::common::error::RustBertError; use crate::common::resources::Resource; use crate::m2m_100::{ - M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages, - M2M100TargetLanguages, M2M100VocabResources, + M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100ModelResources, + M2M100SourceLanguages, M2M100TargetLanguages, M2M100VocabResources, }; use crate::marian::{ MarianConfigResources, MarianGenerator, MarianModelResources, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages, MarianVocabResources, }; use crate::mbart::{ - MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages, - MBartVocabResources, + MBartConfigResources, MBartGenerator, MBartModelResources, MBartSourceLanguages, + MBartTargetLanguages, MBartVocabResources, }; use crate::pipelines::common::ModelType; +use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator; use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; use crate::resources::RemoteResource; use crate::t5::T5Generator; @@ -501,7 +502,7 @@ impl TranslationConfig { max_length: 512, do_sample: false, early_stopping: true, - num_beams: 4, + num_beams: 3, temperature: 1.0, top_k: 50, top_p: 1.0, @@ -547,6 +548,10 @@ pub enum TranslationOption { Marian(MarianGenerator), /// Translator based on T5 model T5(T5Generator), + /// Translator based on MBart50 model + MBart(MBartGenerator), + /// Translator based on M2M100 model + M2M100(M2M100Generator), } impl TranslationOption { @@ -556,6 +561,12 @@ impl TranslationOption { config.into(), )?)), ModelType::T5 => Ok(TranslationOption::T5(T5Generator::new(config.into())?)), + ModelType::MBart => Ok(TranslationOption::MBart(MBartGenerator::new( + config.into(), + )?)), + ModelType::M2M100 => Ok(TranslationOption::M2M100(M2M100Generator::new( + config.into(), + )?)), _ => Err(RustBertError::InvalidConfigurationError(format!( "Translation not implemented for {:?}!", config.model_type @@ -568,16 +579,18 @@ impl TranslationOption { match *self { Self::Marian(_) => ModelType::Marian, Self::T5(_) => ModelType::T5, + Self::MBart(_) => ModelType::MBart, + Self::M2M100(_) => ModelType::M2M100, } } - fn validate_and_get_prefix( + fn validate_and_get_prefix_and_forced_bos_id( &self, source_language: Option<&Language>, target_language: Option<&Language>, supported_source_languages: &HashSet, supported_target_languages: &HashSet, - ) -> Result, RustBertError> { + ) -> Result<(Option, Option), RustBertError> { if let Some(source_language) = source_language { if !supported_source_languages.contains(source_language) { return Err(RustBertError::ValueError(format!( @@ -601,30 +614,112 @@ impl TranslationOption { Ok(match *self { Self::Marian(_) => { if supported_target_languages.len() > 1 { - Some(format!( - ">>{}<< ", - match target_language { - Some(value) => value.get_iso_639_1_code(), - None => { + ( + Some(format!( + ">>{}<< ", + match target_language { + Some(value) => value.get_iso_639_1_code(), + None => { + return Err(RustBertError::ValueError( + "Missing target language for Marian".to_string(), + )); + } + } + )), + None, + ) + } else { + (None, None) + } + } + Self::T5(_) => ( + Some(format!( + "translate {} to {}:", + match source_language { + Some(value) => value.to_string(), + None => { + return Err(RustBertError::ValueError( + "Missing source language for T5".to_string(), + )); + } + }, + match target_language { + Some(value) => value.to_string(), + None => { + return Err(RustBertError::ValueError( + "Missing target language for T5".to_string(), + )); + } + } + )), + None, + ), + Self::MBart(ref model) => ( + Some(format!( + ">>{}<< ", + match source_language { + Some(value) => value.get_iso_639_1_code(), + None => { + return Err(RustBertError::ValueError( + "Missing source language for MBart".to_string(), + )); + } + } + )), + if let Some(target_language) = target_language { + Some( + model._get_tokenizer().convert_tokens_to_ids([format!( + ">>{}<<", + target_language.get_iso_639_1_code() + )])[0], + ) + } else { + return Err(RustBertError::ValueError( + "Missing target language for MBart".to_string(), + )); + }, + ), + Self::M2M100(ref model) => ( + Some(match source_language { + Some(value) => { + let language_code = value.get_iso_639_1_code(); + match language_code.len() { + 2 => format!(">>{}.<< ", language_code), + 3 => format!(">>{}<< ", language_code), + _ => { return Err(RustBertError::ValueError( - "Missing target language for Marian".to_string(), + "Invalid ISO 639-3 code".to_string(), )); } } - )) + } + None => { + return Err(RustBertError::ValueError( + "Missing source language for M2M100".to_string(), + )); + } + }), + if let Some(target_language) = target_language { + let language_code = target_language.get_iso_639_1_code(); + Some( + model + ._get_tokenizer() + .convert_tokens_to_ids([match language_code.len() { + 2 => format!(">>{}.<<", language_code), + 3 => format!(">>{}<<", language_code), + _ => { + return Err(RustBertError::ValueError( + "Invalid ISO 639-3 code".to_string(), + )); + } + }])[0], + ) } else { - None - } - } - Self::T5(_) => Some(format!( - "translate {} to {}:", - source_language - .expect("Missing source language for T5") - .to_string(), - target_language - .expect("Missing target language for T5") - .to_string() - )), + return Err(RustBertError::ValueError( + "Missing target language for MBart".to_string(), + )); + }, + ), }) } @@ -633,6 +728,7 @@ impl TranslationOption { &self, prompt_texts: Option, attention_mask: Option, + forced_bos_token_id: Option, ) -> Vec where S: AsRef<[&'a str]>, @@ -666,6 +762,34 @@ impl TranslationOption { .into_iter() .map(|output| output.text) .collect(), + Self::MBart(ref model) => model + .generate( + prompt_texts, + attention_mask, + None, + None, + None, + forced_bos_token_id, + None, + false, + ) + .into_iter() + .map(|output| output.text) + .collect(), + Self::M2M100(ref model) => model + .generate( + prompt_texts, + attention_mask, + None, + None, + None, + forced_bos_token_id, + None, + false, + ) + .into_iter() + .map(|output| output.text) + .collect(), } } } @@ -797,7 +921,7 @@ impl TranslationModel { where S: AsRef<[&'a str]>, { - let prefix = self.model.validate_and_get_prefix( + let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id( source_language.into().as_ref(), target_language.into().as_ref(), &self.supported_source_languages, @@ -814,9 +938,10 @@ impl TranslationModel { self.model.generate( Some(texts.iter().map(AsRef::as_ref).collect::>()), None, + forced_bos_token_id, ) } - None => self.model.generate(Some(texts), None), + None => self.model.generate(Some(texts), None, forced_bos_token_id), }) } } @@ -876,12 +1001,14 @@ where } pub fn with_medium_model(&mut self) -> &mut Self { - if self.model_type.is_some() { - eprintln!( - "Model selection overwritten: was {:?}, replaced by {:?}", - self.model_type.unwrap(), - ModelType::Marian - ); + if let Some(model_type) = self.model_type { + if model_type != ModelType::Marian { + eprintln!( + "Model selection overwritten: was {:?}, replaced by {:?}", + self.model_type.unwrap(), + ModelType::Marian + ); + } } self.model_type = Some(ModelType::Marian); self.model_size = Some(ModelSize::Medium); @@ -889,12 +1016,14 @@ where } pub fn with_large_model(&mut self) -> &mut Self { - if self.model_type.is_some() { - eprintln!( - "Model selection overwritten: was {:?}, replaced by {:?}", - self.model_type.unwrap(), - ModelType::M2M100 - ); + if let Some(model_type) = self.model_type { + if model_type != ModelType::M2M100 { + eprintln!( + "Model selection overwritten: was {:?}, replaced by {:?}", + self.model_type.unwrap(), + ModelType::M2M100 + ); + } } self.model_type = Some(ModelType::M2M100); self.model_size = Some(ModelSize::Large); @@ -902,12 +1031,14 @@ where } pub fn with_xlarge_model(&mut self) -> &mut Self { - if self.model_type.is_some() { - eprintln!( - "Model selection overwritten: was {:?}, replaced by {:?}", - self.model_type.unwrap(), - ModelType::M2M100 - ); + if let Some(model_type) = self.model_type { + if model_type != ModelType::M2M100 { + eprintln!( + "Model selection overwritten: was {:?}, replaced by {:?}", + self.model_type.unwrap(), + ModelType::M2M100 + ); + } } self.model_type = Some(ModelType::M2M100); self.model_size = Some(ModelSize::XLarge);