Implemented automatic model selection for builder

This commit is contained in:
Guillaume B 2021-07-09 19:12:05 +02:00
parent 169fe4cf2b
commit 22c10f4521
6 changed files with 456 additions and 28 deletions

View File

@ -64,7 +64,8 @@ mod m2m_100_model;
pub use m2m_100_model::{
M2M100Config, M2M100ConfigResources, M2M100ForConditionalGeneration, M2M100Generator,
M2M100MergesResources, M2M100Model, M2M100ModelResources, M2M100VocabResources,
M2M100MergesResources, M2M100Model, M2M100ModelResources, M2M100SourceLanguages,
M2M100TargetLanguages, M2M100VocabResources,
};
pub use attention::LayerState;

View File

@ -39,15 +39,15 @@ pub struct MarianVocabResources;
/// # Marian Pretrained sentence piece model files
pub struct MarianSpmResources;
/// # Marian optional prefixes
pub struct MarianPrefix;
/// # Marian source languages pre-sets
pub struct MarianSourceLanguages;
/// # Marian target languages pre-sets
pub struct MarianTargetLanguages;
/// # Marian translation model pre-sets
pub struct MarianModelPreset;
impl MarianModelResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (

View File

@ -60,7 +60,7 @@
mod marian_model;
pub use marian_model::{
MarianConfigResources, MarianForConditionalGeneration, MarianGenerator, MarianModelResources,
MarianPrefix, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages,
MarianConfigResources, MarianForConditionalGeneration, MarianGenerator, MarianModelPreset,
MarianModelResources, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages,
MarianVocabResources,
};

View File

@ -0,0 +1,34 @@
// use crate::marian::marian_model::MarianModelPreset;
// use crate::marian::{
// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
// MarianTargetLanguages, MarianVocabResources,
// };
// use crate::pipelines::common::ModelType;
// use crate::pipelines::translation::{Language, TranslationModelConfig};
// use crate::resources::{RemoteResource, Resource};
// use std::borrow::Cow;
//
// impl MarianModelPreset {
// pub const ENGLISH2GERMAN: TranslationModelConfig<[Language; 1], [Language; 1]> =
// TranslationModelConfig {
// model_type: ModelType::Marian,
// model_resource: Resource::Remote(RemoteResource {
// url: Cow::Borrowed(MarianModelResources::ENGLISH2GERMAN.0),
// cache_subdir: Cow::Borrowed(MarianModelResources::ENGLISH2GERMAN.1),
// }),
// config_resource: Resource::Remote(RemoteResource {
// url: Cow::Borrowed(MarianConfigResources::ENGLISH2GERMAN.0),
// cache_subdir: Cow::Borrowed(MarianConfigResources::ENGLISH2GERMAN.1),
// }),
// vocab_resource: Resource::Remote(RemoteResource {
// url: Cow::Borrowed(MarianVocabResources::ENGLISH2GERMAN.0),
// cache_subdir: Cow::Borrowed(MarianVocabResources::ENGLISH2GERMAN.1),
// }),
// merges_resource: Resource::Remote(RemoteResource {
// url: Cow::Borrowed(MarianSpmResources::ENGLISH2GERMAN.0),
// cache_subdir: Cow::Borrowed(MarianSpmResources::ENGLISH2GERMAN.1),
// }),
// source_languages: MarianSourceLanguages::ENGLISH2GERMAN,
// target_languages: MarianTargetLanguages::ENGLISH2GERMAN,
// };
// }

View File

@ -57,7 +57,7 @@ mod mbart_model;
pub use mbart_model::{
MBartConfig, MBartConfigResources, MBartForConditionalGeneration,
MBartForSequenceClassification, MBartGenerator, MBartModel, MBartModelOutput,
MBartModelResources, MBartVocabResources,
MBartModelResources, MBartSourceLanguages, MBartTargetLanguages, MBartVocabResources,
};
pub use attention::LayerState;

View File

@ -16,19 +16,24 @@ use tch::{Device, Tensor};
use crate::common::error::RustBertError;
use crate::common::resources::Resource;
use crate::m2m_100::{
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100VocabResources,
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
M2M100TargetLanguages, M2M100VocabResources,
};
use crate::marian::{
MarianConfigResources, MarianGenerator, MarianModelResources, MarianSourceLanguages,
MarianSpmResources, MarianTargetLanguages, MarianVocabResources,
};
use crate::mbart::{MBartConfigResources, MBartModelResources, MBartVocabResources};
use crate::mbart::{
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
MBartVocabResources,
};
use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::resources::RemoteResource;
use crate::t5::T5Generator;
use std::collections::HashSet;
use std::fmt;
use std::fmt::{Debug, Display};
/// Language
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
@ -135,7 +140,7 @@ pub enum Language {
HaitianCreole,
}
impl fmt::Display for Language {
impl Display for Language {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", {
let input_string = format!("{:?}", self);
@ -370,6 +375,54 @@ impl Language {
}
}
// ToDo: remove
// pub struct TranslationModelConfig<S, T>
// where
// S: AsRef<[Language]>,
// T: AsRef<[Language]>,
// {
// /// Model type used for translation
// pub model_type: ModelType,
// /// Model weights resource
// pub model_resource: Resource,
// /// Config resource
// pub config_resource: Resource,
// /// Vocab resource
// pub vocab_resource: Resource,
// /// Merges resource
// pub merges_resource: Resource,
// /// Supported source languages
// pub source_languages: S,
// /// Supported target languages
// pub target_languages: T,
// }
//
// impl<S, T> TranslationModelConfig<S, T>
// where
// S: AsRef<[Language]>,
// T: AsRef<[Language]>,
// {
// fn supports_source_languages<L>(&self, source_languages: L) -> bool
// where
// L: AsRef<[Language]> + Display,
// {
// source_languages
// .as_ref()
// .iter()
// .all(|language| self.source_languages.as_ref().contains(language))
// }
//
// fn supports_target_languages<L>(&self, target_languages: L) -> bool
// where
// L: AsRef<[Language]> + Display,
// {
// target_languages
// .as_ref()
// .iter()
// .all(|language| self.target_languages.as_ref().contains(language))
// }
// }
/// # Configuration for text translation
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
/// different set of default parameters and sets the device to place the model on.
@ -812,10 +865,13 @@ impl TranslationModel {
}
struct TranslationResources {
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Resource,
source_languages: HashSet<Language>,
target_languages: HashSet<Language>,
}
#[derive(Clone, Copy, PartialEq)]
@ -827,8 +883,8 @@ enum ModelSize {
pub struct TranslationModelBuilder<S, T>
where
S: AsRef<[Language]>,
T: AsRef<[Language]>,
S: AsRef<[Language]> + Display,
T: AsRef<[Language]> + Display,
{
model_type: Option<ModelType>,
model_resource: Option<Resource>,
@ -843,8 +899,8 @@ where
impl<S, T> TranslationModelBuilder<S, T>
where
S: AsRef<[Language]>,
T: AsRef<[Language]>,
S: AsRef<[Language]> + Display,
T: AsRef<[Language]> + Display,
{
pub fn new() -> TranslationModelBuilder<S, T> {
TranslationModelBuilder {
@ -933,7 +989,13 @@ where
source_languages: &S,
target_languages: &T,
) -> TranslationResources {
unimplemented!()
match self.get_marian_model(source_languages, target_languages) {
Ok(marian_resources) => marian_resources,
Err(_) => match self.model_size {
Some(value) if value == ModelSize::XLarge => self.get_m2m100_xlarge_resources(),
_ => self.get_m2m100_large_resources(),
},
}
}
fn get_marian_model(
@ -941,11 +1003,330 @@ where
source_languages: &S,
target_languages: &T,
) -> Result<TranslationResources, RustBertError> {
unimplemented!()
let (resources, source_languages, target_languages) =
match (source_languages.as_ref(), target_languages.as_ref()) {
([Language::English], [Language::German]) => (
(
MarianModelResources::ENGLISH2GERMAN,
MarianConfigResources::ENGLISH2GERMAN,
MarianVocabResources::ENGLISH2GERMAN,
MarianSpmResources::ENGLISH2GERMAN,
),
MarianSourceLanguages::ENGLISH2GERMAN
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ENGLISH2GERMAN
.iter()
.cloned()
.collect(),
),
([Language::English], [Language::Russian]) => (
(
MarianModelResources::ENGLISH2RUSSIAN,
MarianConfigResources::ENGLISH2RUSSIAN,
MarianVocabResources::ENGLISH2RUSSIAN,
MarianSpmResources::ENGLISH2RUSSIAN,
),
MarianSourceLanguages::ENGLISH2RUSSIAN
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ENGLISH2RUSSIAN
.iter()
.cloned()
.collect(),
),
([Language::English], [Language::Dutch]) => (
(
MarianModelResources::ENGLISH2DUTCH,
MarianConfigResources::ENGLISH2DUTCH,
MarianVocabResources::ENGLISH2DUTCH,
MarianSpmResources::ENGLISH2DUTCH,
),
MarianSourceLanguages::ENGLISH2DUTCH
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ENGLISH2DUTCH
.iter()
.cloned()
.collect(),
),
([Language::English], [Language::ChineseMandarin]) => (
(
MarianModelResources::ENGLISH2CHINESE,
MarianConfigResources::ENGLISH2CHINESE,
MarianVocabResources::ENGLISH2CHINESE,
MarianSpmResources::ENGLISH2CHINESE,
),
MarianSourceLanguages::ENGLISH2CHINESE
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ENGLISH2CHINESE
.iter()
.cloned()
.collect(),
),
([Language::English], [Language::Swedish]) => (
(
MarianModelResources::ENGLISH2SWEDISH,
MarianConfigResources::ENGLISH2SWEDISH,
MarianVocabResources::ENGLISH2SWEDISH,
MarianSpmResources::ENGLISH2SWEDISH,
),
MarianSourceLanguages::ENGLISH2SWEDISH
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ENGLISH2SWEDISH
.iter()
.cloned()
.collect(),
),
([Language::English], [Language::Arabic]) => (
(
MarianModelResources::ENGLISH2ARABIC,
MarianConfigResources::ENGLISH2ARABIC,
MarianVocabResources::ENGLISH2ARABIC,
MarianSpmResources::ENGLISH2ARABIC,
),
MarianSourceLanguages::ENGLISH2ARABIC
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ENGLISH2ARABIC
.iter()
.cloned()
.collect(),
),
([Language::English], [Language::Hindi]) => (
(
MarianModelResources::ENGLISH2HINDI,
MarianConfigResources::ENGLISH2HINDI,
MarianVocabResources::ENGLISH2HINDI,
MarianSpmResources::ENGLISH2HINDI,
),
MarianSourceLanguages::ENGLISH2HINDI
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ENGLISH2HINDI
.iter()
.cloned()
.collect(),
),
([Language::English], [Language::Hebrew]) => (
(
MarianModelResources::ENGLISH2HEBREW,
MarianConfigResources::ENGLISH2HEBREW,
MarianVocabResources::ENGLISH2HEBREW,
MarianSpmResources::ENGLISH2HEBREW,
),
MarianSourceLanguages::ENGLISH2HEBREW
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ENGLISH2HEBREW
.iter()
.cloned()
.collect(),
),
([Language::German], [Language::English]) => (
(
MarianModelResources::GERMAN2ENGLISH,
MarianConfigResources::GERMAN2ENGLISH,
MarianVocabResources::GERMAN2ENGLISH,
MarianSpmResources::GERMAN2ENGLISH,
),
MarianSourceLanguages::GERMAN2ENGLISH
.iter()
.cloned()
.collect(),
MarianTargetLanguages::GERMAN2ENGLISH
.iter()
.cloned()
.collect(),
),
([Language::Russian], [Language::English]) => (
(
MarianModelResources::RUSSIAN2ENGLISH,
MarianConfigResources::RUSSIAN2ENGLISH,
MarianVocabResources::RUSSIAN2ENGLISH,
MarianSpmResources::RUSSIAN2ENGLISH,
),
MarianSourceLanguages::RUSSIAN2ENGLISH
.iter()
.cloned()
.collect(),
MarianTargetLanguages::RUSSIAN2ENGLISH
.iter()
.cloned()
.collect(),
),
([Language::Dutch], [Language::English]) => (
(
MarianModelResources::DUTCH2ENGLISH,
MarianConfigResources::DUTCH2ENGLISH,
MarianVocabResources::DUTCH2ENGLISH,
MarianSpmResources::DUTCH2ENGLISH,
),
MarianSourceLanguages::DUTCH2ENGLISH
.iter()
.cloned()
.collect(),
MarianTargetLanguages::DUTCH2ENGLISH
.iter()
.cloned()
.collect(),
),
([Language::ChineseMandarin], [Language::English]) => (
(
MarianModelResources::CHINESE2ENGLISH,
MarianConfigResources::CHINESE2ENGLISH,
MarianVocabResources::CHINESE2ENGLISH,
MarianSpmResources::CHINESE2ENGLISH,
),
MarianSourceLanguages::CHINESE2ENGLISH
.iter()
.cloned()
.collect(),
MarianTargetLanguages::CHINESE2ENGLISH
.iter()
.cloned()
.collect(),
),
([Language::Swedish], [Language::English]) => (
(
MarianModelResources::SWEDISH2ENGLISH,
MarianConfigResources::SWEDISH2ENGLISH,
MarianVocabResources::SWEDISH2ENGLISH,
MarianSpmResources::SWEDISH2ENGLISH,
),
MarianSourceLanguages::SWEDISH2ENGLISH
.iter()
.cloned()
.collect(),
MarianTargetLanguages::SWEDISH2ENGLISH
.iter()
.cloned()
.collect(),
),
([Language::Arabic], [Language::English]) => (
(
MarianModelResources::ARABIC2ENGLISH,
MarianConfigResources::ARABIC2ENGLISH,
MarianVocabResources::ARABIC2ENGLISH,
MarianSpmResources::ARABIC2ENGLISH,
),
MarianSourceLanguages::ARABIC2ENGLISH
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ARABIC2ENGLISH
.iter()
.cloned()
.collect(),
),
([Language::Hindi], [Language::English]) => (
(
MarianModelResources::HINDI2ENGLISH,
MarianConfigResources::HINDI2ENGLISH,
MarianVocabResources::HINDI2ENGLISH,
MarianSpmResources::HINDI2ENGLISH,
),
MarianSourceLanguages::HINDI2ENGLISH
.iter()
.cloned()
.collect(),
MarianTargetLanguages::HINDI2ENGLISH
.iter()
.cloned()
.collect(),
),
([Language::Hebrew], [Language::English]) => (
(
MarianModelResources::HEBREW2ENGLISH,
MarianConfigResources::HEBREW2ENGLISH,
MarianVocabResources::HEBREW2ENGLISH,
MarianSpmResources::HEBREW2ENGLISH,
),
MarianSourceLanguages::HEBREW2ENGLISH
.iter()
.cloned()
.collect(),
MarianTargetLanguages::HEBREW2ENGLISH
.iter()
.cloned()
.collect(),
),
([Language::English], languages)
if languages
.iter()
.all(|lang| MarianTargetLanguages::ENGLISH2ROMANCE.contains(lang)) =>
{
(
(
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
),
MarianSourceLanguages::ENGLISH2ROMANCE
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ENGLISH2ROMANCE
.iter()
.cloned()
.collect(),
)
}
(languages, [Language::English])
if languages
.iter()
.all(|lang| MarianSourceLanguages::ROMANCE2ENGLISH.contains(lang)) =>
{
(
(
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
),
MarianSourceLanguages::ENGLISH2ROMANCE
.iter()
.cloned()
.collect(),
MarianTargetLanguages::ENGLISH2ROMANCE
.iter()
.cloned()
.collect(),
)
}
(_, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"No Pretrained Marian configuration found for {} to {} translation",
source_languages, target_languages
)));
}
};
Ok(TranslationResources {
model_type: ModelType::Marian,
model_resource: Resource::Remote(RemoteResource::from_pretrained(resources.0)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(resources.1)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(resources.2)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(resources.3)),
source_languages,
target_languages,
})
}
fn get_bart50_resources(&self) -> TranslationResources {
fn get_mbart50_resources(&self) -> TranslationResources {
TranslationResources {
model_type: ModelType::MBart,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
)),
@ -958,11 +1339,20 @@ where
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
source_languages: MBartSourceLanguages::MBART50_MANY_TO_MANY
.iter()
.cloned()
.collect(),
target_languages: MBartTargetLanguages::MBART50_MANY_TO_MANY
.iter()
.cloned()
.collect(),
}
}
fn get_m2m100_large_resources(&self) -> TranslationResources {
TranslationResources {
model_type: ModelType::M2M100,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
)),
@ -975,11 +1365,14 @@ where
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
)),
source_languages: M2M100SourceLanguages::M2M100_418M.iter().cloned().collect(),
target_languages: M2M100TargetLanguages::M2M100_418M.iter().cloned().collect(),
}
}
fn get_m2m100_xlarge_resources(&self) -> TranslationResources {
TranslationResources {
model_type: ModelType::M2M100,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_1_2B,
)),
@ -992,6 +1385,8 @@ where
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_1_2B,
)),
source_languages: M2M100SourceLanguages::M2M100_1_2B.iter().cloned().collect(),
target_languages: M2M100TargetLanguages::M2M100_1_2B.iter().cloned().collect(),
}
}
@ -1004,30 +1399,28 @@ where
&self.target_languages,
) {
(Some(ModelType::M2M100), None, None) | (None, None, None) => match self.model_size {
Some(value) if ((value == ModelSize::Large) | (value == ModelSize::Medium)) => {
self.get_m2m100_large_resources()
}
_ => self.get_m2m100_xlarge_resources(),
Some(value) if value == ModelSize::XLarge => self.get_m2m100_xlarge_resources(),
_ => self.get_m2m100_large_resources(),
},
(Some(ModelType::MBart), None, None) => self.get_bart50_resources(),
(Some(ModelType::MBart), None, None) => self.get_mbart50_resources(),
(Some(ModelType::Marian), Some(source_languages), Some(target_languages)) => {
self.get_marian_model(source_languages, target_languages)?
}
(None, Some(source_languages), Some(target_languages)) => {
self.get_default_model(source_languages, target_languages)
}
(Some(model_type), _, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Automated translation model builder not implemented for {:?}",
model_type
)));
}
(_, None, None) | (_, _, None) | (_, None, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Source and target languages must be specified for {:?}",
self.model_type.unwrap()
)));
}
(Some(model_type), _, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Automated translation model builder not implemented for {:?}",
model_type
)));
}
};
let model_resource = Resource::Remote(RemoteResource::from_pretrained(