Added MBart50 and M2M100 to supported translation models

This commit is contained in:
Guillaume B 2021-07-10 11:34:51 +02:00
parent 3b72b7cc9b
commit 5dc7f33c39
4 changed files with 186 additions and 52 deletions

View File

@ -21,15 +21,16 @@ fn main() -> anyhow::Result<()> {
let model = TranslationModelBuilder::new()
.with_device(Device::cuda_if_available())
// .with_model_type(ModelType::Marian)
.with_model_type(ModelType::M2M100)
.with_large_model()
.with_source_languages(vec![Language::English])
.with_target_languages(vec![Language::French])
.create_model()?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";
let input_context_2 = "The dog did not wake up";
// let input_context_1 = "The quick brown fox jumps over the lazy dog.";
let input_context_2 = "The dog did not wake up.";
let output = model.translate(&[input_context_1, input_context_2], None, Language::French)?;
let output = model.translate(&[input_context_2], Language::English, Language::French)?;
for sentence in output {
println!("{}", sentence);

View File

@ -21,7 +21,7 @@ use rust_bert::resources::{RemoteResource, Resource};
fn main() -> anyhow::Result<()> {
let generate_config = GenerateConfig {
max_length: 142,
max_length: 512,
min_length: 0,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
@ -38,14 +38,16 @@ fn main() -> anyhow::Result<()> {
do_sample: false,
early_stopping: true,
num_beams: 3,
no_repeat_ngram_size: 0,
..Default::default()
};
let model = M2M100Generator::new(generate_config)?;
let input_context_1 = ">>en.<< The dog did not wake up.";
let target_language = model.get_tokenizer().convert_tokens_to_ids([">>es.<<"])[0];
let target_language = model.get_tokenizer().convert_tokens_to_ids([">>fr.<<"])[0];
println!("{:?} - {:?}", input_context_1, target_language);
let output = model.generate(
Some(&[input_context_1]),
None,

View File

@ -50,7 +50,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
#[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq)]
/// # Identifies the type of model
pub enum ModelType {
Bart,

View File

@ -16,18 +16,19 @@ use tch::{Device, Tensor};
use crate::common::error::RustBertError;
use crate::common::resources::Resource;
use crate::m2m_100::{
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
M2M100TargetLanguages, M2M100VocabResources,
M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100ModelResources,
M2M100SourceLanguages, M2M100TargetLanguages, M2M100VocabResources,
};
use crate::marian::{
MarianConfigResources, MarianGenerator, MarianModelResources, MarianSourceLanguages,
MarianSpmResources, MarianTargetLanguages, MarianVocabResources,
};
use crate::mbart::{
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
MBartVocabResources,
MBartConfigResources, MBartGenerator, MBartModelResources, MBartSourceLanguages,
MBartTargetLanguages, MBartVocabResources,
};
use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::resources::RemoteResource;
use crate::t5::T5Generator;
@ -501,7 +502,7 @@ impl TranslationConfig {
max_length: 512,
do_sample: false,
early_stopping: true,
num_beams: 4,
num_beams: 3,
temperature: 1.0,
top_k: 50,
top_p: 1.0,
@ -547,6 +548,10 @@ pub enum TranslationOption {
Marian(MarianGenerator),
/// Translator based on T5 model
T5(T5Generator),
/// Translator based on MBart50 model
MBart(MBartGenerator),
/// Translator based on M2M100 model
M2M100(M2M100Generator),
}
impl TranslationOption {
@ -556,6 +561,12 @@ impl TranslationOption {
config.into(),
)?)),
ModelType::T5 => Ok(TranslationOption::T5(T5Generator::new(config.into())?)),
ModelType::MBart => Ok(TranslationOption::MBart(MBartGenerator::new(
config.into(),
)?)),
ModelType::M2M100 => Ok(TranslationOption::M2M100(M2M100Generator::new(
config.into(),
)?)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Translation not implemented for {:?}!",
config.model_type
@ -568,16 +579,18 @@ impl TranslationOption {
match *self {
Self::Marian(_) => ModelType::Marian,
Self::T5(_) => ModelType::T5,
Self::MBart(_) => ModelType::MBart,
Self::M2M100(_) => ModelType::M2M100,
}
}
fn validate_and_get_prefix(
fn validate_and_get_prefix_and_forced_bos_id(
&self,
source_language: Option<&Language>,
target_language: Option<&Language>,
supported_source_languages: &HashSet<Language>,
supported_target_languages: &HashSet<Language>,
) -> Result<Option<String>, RustBertError> {
) -> Result<(Option<String>, Option<i64>), RustBertError> {
if let Some(source_language) = source_language {
if !supported_source_languages.contains(source_language) {
return Err(RustBertError::ValueError(format!(
@ -601,30 +614,112 @@ impl TranslationOption {
Ok(match *self {
Self::Marian(_) => {
if supported_target_languages.len() > 1 {
Some(format!(
">>{}<< ",
match target_language {
Some(value) => value.get_iso_639_1_code(),
None => {
(
Some(format!(
">>{}<< ",
match target_language {
Some(value) => value.get_iso_639_1_code(),
None => {
return Err(RustBertError::ValueError(
"Missing target language for Marian".to_string(),
));
}
}
)),
None,
)
} else {
(None, None)
}
}
Self::T5(_) => (
Some(format!(
"translate {} to {}:",
match source_language {
Some(value) => value.to_string(),
None => {
return Err(RustBertError::ValueError(
"Missing source language for T5".to_string(),
));
}
},
match target_language {
Some(value) => value.to_string(),
None => {
return Err(RustBertError::ValueError(
"Missing target language for T5".to_string(),
));
}
}
)),
None,
),
Self::MBart(ref model) => (
Some(format!(
">>{}<< ",
match source_language {
Some(value) => value.get_iso_639_1_code(),
None => {
return Err(RustBertError::ValueError(
"Missing source language for MBart".to_string(),
));
}
}
)),
if let Some(target_language) = target_language {
Some(
model._get_tokenizer().convert_tokens_to_ids([format!(
">>{}<<",
target_language.get_iso_639_1_code()
)])[0],
)
} else {
return Err(RustBertError::ValueError(
"Missing target language for MBart".to_string(),
));
},
),
Self::M2M100(ref model) => (
Some(match source_language {
Some(value) => {
let language_code = value.get_iso_639_1_code();
match language_code.len() {
2 => format!(">>{}.<< ", language_code),
3 => format!(">>{}<< ", language_code),
_ => {
return Err(RustBertError::ValueError(
"Missing target language for Marian".to_string(),
"Invalid ISO 639-3 code".to_string(),
));
}
}
))
}
None => {
return Err(RustBertError::ValueError(
"Missing source language for M2M100".to_string(),
));
}
}),
if let Some(target_language) = target_language {
let language_code = target_language.get_iso_639_1_code();
Some(
model
._get_tokenizer()
.convert_tokens_to_ids([match language_code.len() {
2 => format!(">>{}.<<", language_code),
3 => format!(">>{}<<", language_code),
_ => {
return Err(RustBertError::ValueError(
"Invalid ISO 639-3 code".to_string(),
));
}
}])[0],
)
} else {
None
}
}
Self::T5(_) => Some(format!(
"translate {} to {}:",
source_language
.expect("Missing source language for T5")
.to_string(),
target_language
.expect("Missing target language for T5")
.to_string()
)),
return Err(RustBertError::ValueError(
"Missing target language for MBart".to_string(),
));
},
),
})
}
@ -633,6 +728,7 @@ impl TranslationOption {
&self,
prompt_texts: Option<S>,
attention_mask: Option<Tensor>,
forced_bos_token_id: Option<i64>,
) -> Vec<String>
where
S: AsRef<[&'a str]>,
@ -666,6 +762,34 @@ impl TranslationOption {
.into_iter()
.map(|output| output.text)
.collect(),
Self::MBart(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
forced_bos_token_id,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::M2M100(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
forced_bos_token_id,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
}
}
}
@ -797,7 +921,7 @@ impl TranslationModel {
where
S: AsRef<[&'a str]>,
{
let prefix = self.model.validate_and_get_prefix(
let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id(
source_language.into().as_ref(),
target_language.into().as_ref(),
&self.supported_source_languages,
@ -814,9 +938,10 @@ impl TranslationModel {
self.model.generate(
Some(texts.iter().map(AsRef::as_ref).collect::<Vec<&str>>()),
None,
forced_bos_token_id,
)
}
None => self.model.generate(Some(texts), None),
None => self.model.generate(Some(texts), None, forced_bos_token_id),
})
}
}
@ -876,12 +1001,14 @@ where
}
pub fn with_medium_model(&mut self) -> &mut Self {
if self.model_type.is_some() {
eprintln!(
"Model selection overwritten: was {:?}, replaced by {:?}",
self.model_type.unwrap(),
ModelType::Marian
);
if let Some(model_type) = self.model_type {
if model_type != ModelType::Marian {
eprintln!(
"Model selection overwritten: was {:?}, replaced by {:?}",
self.model_type.unwrap(),
ModelType::Marian
);
}
}
self.model_type = Some(ModelType::Marian);
self.model_size = Some(ModelSize::Medium);
@ -889,12 +1016,14 @@ where
}
pub fn with_large_model(&mut self) -> &mut Self {
if self.model_type.is_some() {
eprintln!(
"Model selection overwritten: was {:?}, replaced by {:?}",
self.model_type.unwrap(),
ModelType::M2M100
);
if let Some(model_type) = self.model_type {
if model_type != ModelType::M2M100 {
eprintln!(
"Model selection overwritten: was {:?}, replaced by {:?}",
self.model_type.unwrap(),
ModelType::M2M100
);
}
}
self.model_type = Some(ModelType::M2M100);
self.model_size = Some(ModelSize::Large);
@ -902,12 +1031,14 @@ where
}
pub fn with_xlarge_model(&mut self) -> &mut Self {
if self.model_type.is_some() {
eprintln!(
"Model selection overwritten: was {:?}, replaced by {:?}",
self.model_type.unwrap(),
ModelType::M2M100
);
if let Some(model_type) = self.model_type {
if model_type != ModelType::M2M100 {
eprintln!(
"Model selection overwritten: was {:?}, replaced by {:?}",
self.model_type.unwrap(),
ModelType::M2M100
);
}
}
self.model_type = Some(ModelType::M2M100);
self.model_size = Some(ModelSize::XLarge);