Updated examples and integration tests

Guillaume B 2021-07-11 11:13:00 +02:00
parent 89b3a327fa
commit ce90d8901d
10 changed files with 1311 additions and 240 deletions

View File

@ -23,17 +23,13 @@ fn main() -> anyhow::Result<()> {
.with_model_type(ModelType::Marian)
// .with_large_model()
.with_source_languages(vec![Language::English])
.with_target_languages(vec![Language::Hebrew])
.with_target_languages(vec![Language::Spanish])
.create_model()?;
let input_context_1 = "The quick brown fox jumps over the lazy dog.";
let input_context_2 = "The dog did not wake up.";
let output = model.translate(
&[input_context_1, input_context_2],
Language::English,
Language::Hebrew,
)?;
let output = model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?;
for sentence in output {
println!("{}", sentence);

View File

@ -13,54 +13,52 @@
extern crate anyhow;
use rust_bert::m2m_100::{
M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100ModelResources,
M2M100VocabResources,
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
M2M100TargetLanguages, M2M100VocabResources,
};
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
fn main() -> anyhow::Result<()> {
let generate_config = GenerateConfig {
max_length: 512,
min_length: 0,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
)),
do_sample: false,
early_stopping: true,
num_beams: 3,
no_repeat_ngram_size: 0,
..Default::default()
};
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
));
let model = M2M100Generator::new(generate_config)?;
let source_languages = M2M100SourceLanguages::M2M100_418M;
let target_languages = M2M100TargetLanguages::M2M100_418M;
let input_context_1 = ">>en.<< The dog did not wake up.";
let target_language = model.get_tokenizer().convert_tokens_to_ids([">>es.<<"])[0];
println!("{:?} - {:?}", input_context_1, target_language);
let output = model.generate(
Some(&[input_context_1]),
None,
None,
None,
None,
target_language,
None,
false,
let translation_config = TranslationConfig::new(
ModelType::M2M100,
model_resource,
config_resource,
vocab_resource,
merges_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
let model = TranslationModel::new(translation_config)?;
for sentence in output {
println!("{:?}", sentence);
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
for sentence in outputs {
println!("{}", sentence);
}
Ok(())
}
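The three `translate` calls above differ only in the target language. A minimal sketch of an equivalent helper is shown below; the `translate_to_all` name and signature are illustrative, not part of this commit:

use rust_bert::pipelines::translation::{Language, TranslationModel};

// Illustrative helper: translate one sentence into several target languages
// using a model such as the M2M100 one built in the example above.
fn translate_to_all(
    model: &TranslationModel,
    sentence: &str,
    targets: &[Language],
) -> anyhow::Result<Vec<String>> {
    let mut outputs = Vec::new();
    for &target in targets {
        // The source language is passed explicitly since M2M100 is many-to-many.
        outputs.extend(model.translate([sentence], Language::English, target)?);
    }
    Ok(outputs)
}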

View File

@ -13,48 +13,52 @@
extern crate anyhow;
use rust_bert::mbart::{
MBartConfigResources, MBartGenerator, MBartModelResources, MBartVocabResources,
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
MBartVocabResources,
};
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
fn main() -> anyhow::Result<()> {
let generate_config = GenerateConfig {
max_length: 56,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
do_sample: false,
num_beams: 1,
..Default::default()
};
let model = MBartGenerator::new(generate_config)?;
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
));
let input_context_1 = "en_XX The quick brown fox jumps over the lazy dog.";
let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0];
let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY;
let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY;
let output = model.generate(
Some(&[input_context_1]),
None,
None,
None,
None,
target_language,
None,
false,
let translation_config = TranslationConfig::new(
ModelType::MBart,
model_resource,
config_resource,
vocab_resource,
merges_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
let model = TranslationModel::new(translation_config)?;
for sentence in output {
println!("{:?}", sentence.text);
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
for sentence in outputs {
println!("{}", sentence);
}
Ok(())
}

View File

@ -12,46 +12,56 @@
extern crate anyhow;
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::t5::{T5ConfigResources, T5Generator, T5ModelResources, T5VocabResources};
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
use tch::Device;
fn main() -> anyhow::Result<()> {
// Resources paths
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
let generate_config = GenerateConfig {
model_resource: weights_resource,
vocab_resource,
let source_languages = [
Language::English,
Language::French,
Language::German,
Language::Romanian,
];
let target_languages = [
Language::English,
Language::French,
Language::German,
Language::Romanian,
];
let translation_config = TranslationConfig::new(
ModelType::T5,
model_resource,
config_resource,
max_length: 40,
do_sample: false,
num_beams: 4,
..Default::default()
};
// Set-up model
let t5_model = T5Generator::new(generate_config)?;
// Define input
let input = ["translate English to German: This sentence will get translated to German"];
let output = t5_model.generate(
Some(input.to_vec()),
None,
None,
None,
None,
None,
None,
false,
vocab_resource,
merges_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
println!("{:?}", output);
let model = TranslationModel::new(translation_config)?;
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::German)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?);
for sentence in outputs {
println!("{}", sentence);
}
Ok(())
}

View File

@ -34,14 +34,10 @@ enum ModelSize {
XLarge,
}
pub struct TranslationModelBuilder<S, T>
where
S: AsRef<[Language]> + Debug,
T: AsRef<[Language]> + Debug,
{
pub struct TranslationModelBuilder {
model_type: Option<ModelType>,
source_languages: Option<S>,
target_languages: Option<T>,
source_languages: Option<Vec<Language>>,
target_languages: Option<Vec<Language>>,
device: Option<Device>,
model_size: Option<ModelSize>,
}
@ -61,12 +57,8 @@ macro_rules! get_marian_resources {
};
}
impl<S, T> TranslationModelBuilder<S, T>
where
S: AsRef<[Language]> + Debug,
T: AsRef<[Language]> + Debug,
{
pub fn new() -> TranslationModelBuilder<S, T> {
impl TranslationModelBuilder {
pub fn new() -> TranslationModelBuilder {
TranslationModelBuilder {
model_type: None,
source_languages: None,
@ -131,20 +123,26 @@ where
self
}
pub fn with_source_languages(&mut self, source_languages: S) -> &mut Self {
self.source_languages = Some(source_languages);
pub fn with_source_languages<S>(&mut self, source_languages: S) -> &mut Self
where
S: AsRef<[Language]> + Debug,
{
self.source_languages = Some(source_languages.as_ref().to_vec());
self
}
pub fn with_target_languages(&mut self, target_languages: T) -> &mut Self {
self.target_languages = Some(target_languages);
pub fn with_target_languages<T>(&mut self, target_languages: T) -> &mut Self
where
T: AsRef<[Language]> + Debug,
{
self.target_languages = Some(target_languages.as_ref().to_vec());
self
}
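Moving the `AsRef<[Language]>` bounds from the struct onto these two methods lets a single builder accept any mix of arrays, slices, or `Vec`s. A minimal usage sketch, modeled on the builder test added in this commit (the array argument is the only variation):

use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};

fn main() -> anyhow::Result<()> {
    let model = TranslationModelBuilder::new()
        .with_model_type(ModelType::Marian)
        .with_source_languages([Language::English]) // an array works...
        .with_target_languages(vec![Language::French]) // ...and so does a Vec
        .create_model()?;
    let output = model.translate(&["The dog did not wake up."], None, Language::French)?;
    println!("{}", output[0]);
    Ok(())
}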
fn get_default_model(
&self,
source_languages: Option<&S>,
target_languages: Option<&T>,
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> {
Ok(
match self.get_marian_model(source_languages, target_languages) {
@ -161,14 +159,14 @@ where
fn get_marian_model(
&self,
source_languages: Option<&S>,
target_languages: Option<&T>,
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> {
let (resources, source_languages, target_languages) =
if let (Some(source_languages), Some(target_languages)) =
(source_languages, target_languages)
{
match (source_languages.as_ref(), target_languages.as_ref()) {
match (source_languages.as_slice(), target_languages.as_slice()) {
([Language::English], [Language::German]) => {
get_marian_resources!(ENGLISH2GERMAN)
}
@ -257,18 +255,17 @@ where
fn get_mbart50_resources(
&self,
source_languages: Option<&S>,
target_languages: Option<&T>,
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> {
if let Some(source_languages) = source_languages {
if !source_languages
.as_ref()
.iter()
.all(|lang| MBartSourceLanguages::MBART50_MANY_TO_MANY.contains(lang))
{
return Err(RustBertError::ValueError(format!(
"{:?} not in list of supported languages: {:?}",
source_languages.as_ref(),
source_languages,
MBartSourceLanguages::MBART50_MANY_TO_MANY
)));
}
@ -276,7 +273,6 @@ where
if let Some(target_languages) = target_languages {
if !target_languages
.as_ref()
.iter()
.all(|lang| MBartTargetLanguages::MBART50_MANY_TO_MANY.contains(lang))
{
@ -315,18 +311,17 @@ where
fn get_m2m100_large_resources(
&self,
source_languages: Option<&S>,
target_languages: Option<&T>,
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> {
if let Some(source_languages) = source_languages {
if !source_languages
.as_ref()
.iter()
.all(|lang| M2M100SourceLanguages::M2M100_418M.contains(lang))
{
return Err(RustBertError::ValueError(format!(
"{:?} not in list of supported languages: {:?}",
source_languages.as_ref(),
source_languages,
M2M100SourceLanguages::M2M100_418M
)));
}
@ -334,7 +329,6 @@ where
if let Some(target_languages) = target_languages {
if !target_languages
.as_ref()
.iter()
.all(|lang| M2M100TargetLanguages::M2M100_418M.contains(lang))
{
@ -367,18 +361,17 @@ where
fn get_m2m100_xlarge_resources(
&self,
source_languages: Option<&S>,
target_languages: Option<&T>,
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> {
if let Some(source_languages) = source_languages {
if !source_languages
.as_ref()
.iter()
.all(|lang| M2M100SourceLanguages::M2M100_1_2B.contains(lang))
{
return Err(RustBertError::ValueError(format!(
"{:?} not in list of supported languages: {:?}",
source_languages.as_ref(),
source_languages,
M2M100SourceLanguages::M2M100_1_2B
)));
}
@ -386,7 +379,6 @@ where
if let Some(target_languages) = target_languages {
if !target_languages
.as_ref()
.iter()
.all(|lang| M2M100TargetLanguages::M2M100_1_2B.contains(lang))
{

View File

@ -0,0 +1,991 @@
// Copyright 2018-2020 The HuggingFace Inc. team.
// Copyright 2020 Marian Team Authors
// Copyright 2019-2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{Device, Tensor};
use crate::common::error::RustBertError;
use crate::common::resources::Resource;
use crate::m2m_100::M2M100Generator;
use crate::marian::MarianGenerator;
use crate::mbart::MBartGenerator;
use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::t5::T5Generator;
use std::collections::HashSet;
use std::fmt;
use std::fmt::{Debug, Display};
/// Language
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum Language {
Afrikaans,
Danish,
Dutch,
German,
English,
Icelandic,
Luxembourgish,
Norwegian,
Swedish,
WesternFrisian,
Yiddish,
Asturian,
Catalan,
French,
Galician,
Italian,
Occitan,
Portuguese,
Romanian,
Spanish,
Belarusian,
Bosnian,
Bulgarian,
Croatian,
Czech,
Macedonian,
Polish,
Russian,
Serbian,
Slovak,
Slovenian,
Ukrainian,
Estonian,
Finnish,
Hungarian,
Latvian,
Lithuanian,
Albanian,
Armenian,
Georgian,
Greek,
Breton,
Irish,
ScottishGaelic,
Welsh,
Azerbaijani,
Bashkir,
Kazakh,
Turkish,
Uzbek,
Japanese,
Korean,
Vietnamese,
ChineseMandarin,
Bengali,
Gujarati,
Hindi,
Kannada,
Marathi,
Nepali,
Oriya,
Panjabi,
Sindhi,
Sinhala,
Urdu,
Tamil,
Cebuano,
Iloko,
Indonesian,
Javanese,
Malagasy,
Malay,
Malayalam,
Sundanese,
Tagalog,
Burmese,
CentralKhmer,
Lao,
Thai,
Mongolian,
Arabic,
Hebrew,
Pashto,
Farsi,
Amharic,
Fulah,
Hausa,
Igbo,
Lingala,
Luganda,
NorthernSotho,
Somali,
Swahili,
Swati,
Tswana,
Wolof,
Xhosa,
Yoruba,
Zulu,
HaitianCreole,
}
impl Display for Language {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", {
let input_string = format!("{:?}", self);
let mut output: Vec<&str> = Vec::new();
let mut start: usize = 0;
for (c_pos, c) in input_string.char_indices() {
if c.is_uppercase() {
if start < c_pos {
output.push(&input_string[start..c_pos]);
}
start = c_pos;
}
}
if start < input_string.len() {
output.push(&input_string[start..]);
}
output.join(" ")
})
}
}
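A quick illustration of what the `Display` implementation above produces; the loop simply splits the variant name before each uppercase letter (assertions are illustrative):

use rust_bert::pipelines::translation::Language;

fn main() {
    assert_eq!(Language::French.to_string(), "French");
    assert_eq!(Language::WesternFrisian.to_string(), "Western Frisian");
    assert_eq!(Language::ChineseMandarin.to_string(), "Chinese Mandarin");
}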
impl Language {
pub fn get_iso_639_1_code(&self) -> &'static str {
match self {
Language::Afrikaans => "af",
Language::Danish => "da",
Language::Dutch => "nl",
Language::German => "de",
Language::English => "en",
Language::Icelandic => "is",
Language::Luxembourgish => "lb",
Language::Norwegian => "no",
Language::Swedish => "sv",
Language::WesternFrisian => "fy",
Language::Yiddish => "yi",
Language::Asturian => "ast",
Language::Catalan => "ca",
Language::French => "fr",
Language::Galician => "gl",
Language::Italian => "it",
Language::Occitan => "oc",
Language::Portuguese => "pt",
Language::Romanian => "ro",
Language::Spanish => "es",
Language::Belarusian => "be",
Language::Bosnian => "bs",
Language::Bulgarian => "bg",
Language::Croatian => "hr",
Language::Czech => "cs",
Language::Macedonian => "mk",
Language::Polish => "pl",
Language::Russian => "ru",
Language::Serbian => "sr",
Language::Slovak => "sk",
Language::Slovenian => "sl",
Language::Ukrainian => "uk",
Language::Estonian => "et",
Language::Finnish => "fi",
Language::Hungarian => "hu",
Language::Latvian => "lv",
Language::Lithuanian => "lt",
Language::Albanian => "sq",
Language::Armenian => "hy",
Language::Georgian => "ka",
Language::Greek => "el",
Language::Breton => "br",
Language::Irish => "ga",
Language::ScottishGaelic => "gd",
Language::Welsh => "cy",
Language::Azerbaijani => "az",
Language::Bashkir => "ba",
Language::Kazakh => "kk",
Language::Turkish => "tr",
Language::Uzbek => "uz",
Language::Japanese => "ja",
Language::Korean => "ko",
Language::Vietnamese => "vi",
Language::ChineseMandarin => "zh",
Language::Bengali => "bn",
Language::Gujarati => "gu",
Language::Hindi => "hi",
Language::Kannada => "kn",
Language::Marathi => "mr",
Language::Nepali => "ne",
Language::Oriya => "or",
Language::Panjabi => "pa",
Language::Sindhi => "sd",
Language::Sinhala => "si",
Language::Urdu => "ur",
Language::Tamil => "ta",
Language::Cebuano => "ceb",
Language::Iloko => "ilo",
Language::Indonesian => "id",
Language::Javanese => "jv",
Language::Malagasy => "mg",
Language::Malay => "ms",
Language::Malayalam => "ml",
Language::Sundanese => "su",
Language::Tagalog => "tl",
Language::Burmese => "my",
Language::CentralKhmer => "km",
Language::Lao => "lo",
Language::Thai => "th",
Language::Mongolian => "mn",
Language::Arabic => "ar",
Language::Hebrew => "he",
Language::Pashto => "ps",
Language::Farsi => "fa",
Language::Amharic => "am",
Language::Fulah => "ff",
Language::Hausa => "ha",
Language::Igbo => "ig",
Language::Lingala => "ln",
Language::Luganda => "lg",
Language::NorthernSotho => "nso",
Language::Somali => "so",
Language::Swahili => "sw",
Language::Swati => "ss",
Language::Tswana => "tn",
Language::Wolof => "wo",
Language::Xhosa => "xh",
Language::Yoruba => "yo",
Language::Zulu => "zu",
Language::HaitianCreole => "ht",
}
}
pub fn get_iso_639_3_code(&self) -> &'static str {
match self {
Language::Afrikaans => "afr",
Language::Danish => "dan",
Language::Dutch => "nld",
Language::German => "deu",
Language::English => "eng",
Language::Icelandic => "isl",
Language::Luxembourgish => "ltz",
Language::Norwegian => "nor",
Language::Swedish => "swe",
Language::WesternFrisian => "fry",
Language::Yiddish => "yid",
Language::Asturian => "ast",
Language::Catalan => "cat",
Language::French => "fra",
Language::Galician => "glg",
Language::Italian => "ita",
Language::Occitan => "oci",
Language::Portuguese => "por",
Language::Romanian => "ron",
Language::Spanish => "spa",
Language::Belarusian => "bel",
Language::Bosnian => "bos",
Language::Bulgarian => "bul",
Language::Croatian => "hrv",
Language::Czech => "ces",
Language::Macedonian => "mkd",
Language::Polish => "pol",
Language::Russian => "rus",
Language::Serbian => "srp",
Language::Slovak => "slk",
Language::Slovenian => "slv",
Language::Ukrainian => "ukr",
Language::Estonian => "est",
Language::Finnish => "fin",
Language::Hungarian => "hun",
Language::Latvian => "lav",
Language::Lithuanian => "lit",
Language::Albanian => "sqi",
Language::Armenian => "hye",
Language::Georgian => "kat",
Language::Greek => "ell",
Language::Breton => "bre",
Language::Irish => "gle",
Language::ScottishGaelic => "gla",
Language::Welsh => "cym",
Language::Azerbaijani => "aze",
Language::Bashkir => "bak",
Language::Kazakh => "kaz",
Language::Turkish => "tur",
Language::Uzbek => "uzb",
Language::Japanese => "jpn",
Language::Korean => "kor",
Language::Vietnamese => "vie",
Language::ChineseMandarin => "cmn",
Language::Bengali => "ben",
Language::Gujarati => "guj",
Language::Hindi => "hin",
Language::Kannada => "kan",
Language::Marathi => "mar",
Language::Nepali => "nep",
Language::Oriya => "ori",
Language::Panjabi => "pan",
Language::Sindhi => "snd",
Language::Sinhala => "sin",
Language::Urdu => "urd",
Language::Tamil => "tam",
Language::Cebuano => "ceb",
Language::Iloko => "ilo",
Language::Indonesian => "ind",
Language::Javanese => "jav",
Language::Malagasy => "mlg",
Language::Malay => "msa",
Language::Malayalam => "mal",
Language::Sundanese => "sun",
Language::Tagalog => "tgl",
Language::Burmese => "mya",
Language::CentralKhmer => "khm",
Language::Lao => "lao",
Language::Thai => "tha",
Language::Mongolian => "mon",
Language::Arabic => "ara",
Language::Hebrew => "heb",
Language::Pashto => "pus",
Language::Farsi => "fas",
Language::Amharic => "amh",
Language::Fulah => "ful",
Language::Hausa => "hau",
Language::Igbo => "ibo",
Language::Lingala => "lin",
Language::Luganda => "lug",
Language::NorthernSotho => "nso",
Language::Somali => "som",
Language::Swahili => "swa",
Language::Swati => "ssw",
Language::Tswana => "tsn",
Language::Wolof => "wol",
Language::Xhosa => "xho",
Language::Yoruba => "yor",
Language::Zulu => "zul",
Language::HaitianCreole => "hat",
}
}
}
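The two getters above are plain lookup tables over the enum; for example (illustrative assertions):

use rust_bert::pipelines::translation::Language;

fn main() {
    assert_eq!(Language::Spanish.get_iso_639_1_code(), "es");
    assert_eq!(Language::Spanish.get_iso_639_3_code(), "spa");
    // Some entries in the table above map to longer codes (e.g. Northern Sotho).
    assert_eq!(Language::NorthernSotho.get_iso_639_1_code(), "nso");
}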
/// # Configuration for text translation
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
/// different set of default parameters and sets the device to place the model on.
pub struct TranslationConfig {
/// Model type used for translation
pub model_type: ModelType,
/// Model weights resource
pub model_resource: Resource,
/// Config resource
pub config_resource: Resource,
/// Vocab resource
pub vocab_resource: Resource,
/// Merges resource
pub merges_resource: Resource,
/// Supported source languages
pub source_languages: HashSet<Language>,
/// Supported target languages
pub target_languages: HashSet<Language>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 512)
pub max_length: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: false)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: true)
pub early_stopping: bool,
/// Number of beams for beam search (default: 3)
pub num_beams: i64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 50)
pub top_k: i64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 1.0)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
pub repetition_penalty: f64,
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 0)
pub no_repeat_ngram_size: i64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: i64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation.
pub num_beam_groups: Option<i64>,
/// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)
pub diversity_penalty: Option<f64>,
}
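All of the generation fields above are public, so the translation defaults set by `new` can be adjusted before the model is built. A minimal sketch (the helper name is illustrative):

use rust_bert::pipelines::translation::TranslationConfig;

// Illustrative helper: tighten decoding settings on an existing configuration.
fn tune_for_short_outputs(mut config: TranslationConfig) -> TranslationConfig {
    config.max_length = 128; // `new` sets 512
    config.num_beams = 5; // `new` sets 3
    config.early_stopping = true;
    config
}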
impl TranslationConfig {
/// Create a new `TranslationConfig` from the given model resources and supported languages.
///
/// # Arguments
///
/// * `model_type` - `ModelType` of the translation model (e.g. `ModelType::Marian`)
/// * `model_resource` - `Resource` pointing to the model weights
/// * `config_resource` - `Resource` pointing to the model configuration
/// * `vocab_resource` - `Resource` pointing to the vocabulary
/// * `merges_resource` - `Resource` pointing to the merges or SentencePiece model
/// * `source_languages` - Collection of supported source `Language` values
/// * `target_languages` - Collection of supported target `Language` values
/// * `device` - `Device` to place the model on (CPU/GPU)
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::marian::{
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
/// MarianVocabResources,
/// };
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::TranslationConfig;
/// use rust_bert::resources::{RemoteResource, Resource};
/// use tch::Device;
///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianModelResources::ROMANCE2ENGLISH,
/// ));
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianConfigResources::ROMANCE2ENGLISH,
/// ));
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianVocabResources::ROMANCE2ENGLISH,
/// ));
///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
///
/// let translation_config = TranslationConfig::new(
/// ModelType::Marian,
/// model_resource,
/// config_resource,
/// vocab_resource.clone(),
/// vocab_resource,
/// source_languages,
/// target_languages,
/// Device::cuda_if_available(),
/// );
/// # Ok(())
/// # }
/// ```
pub fn new<S, T>(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Resource,
source_languages: S,
target_languages: T,
device: impl Into<Option<Device>>,
) -> TranslationConfig
where
S: AsRef<[Language]>,
T: AsRef<[Language]>,
{
let device = device.into().unwrap_or_else(|| Device::cuda_if_available());
TranslationConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
source_languages: source_languages.as_ref().iter().cloned().collect(),
target_languages: target_languages.as_ref().iter().cloned().collect(),
device,
min_length: 0,
max_length: 512,
do_sample: false,
early_stopping: true,
num_beams: 3,
temperature: 1.0,
top_k: 50,
top_p: 1.0,
repetition_penalty: 1.0,
length_penalty: 1.0,
no_repeat_ngram_size: 0,
num_return_sequences: 1,
num_beam_groups: None,
diversity_penalty: None,
}
}
}
impl From<TranslationConfig> for GenerateConfig {
fn from(config: TranslationConfig) -> GenerateConfig {
GenerateConfig {
model_resource: config.model_resource,
config_resource: config.config_resource,
merges_resource: config.merges_resource,
vocab_resource: config.vocab_resource,
min_length: config.min_length,
max_length: config.max_length,
do_sample: config.do_sample,
early_stopping: config.early_stopping,
num_beams: config.num_beams,
temperature: config.temperature,
top_k: config.top_k,
top_p: config.top_p,
repetition_penalty: config.repetition_penalty,
length_penalty: config.length_penalty,
no_repeat_ngram_size: config.no_repeat_ngram_size,
num_return_sequences: config.num_return_sequences,
num_beam_groups: config.num_beam_groups,
diversity_penalty: config.diversity_penalty,
device: config.device,
}
}
}
/// # Abstraction that holds one particular translation model, for any of the supported models
pub enum TranslationOption {
/// Translator based on Marian model
Marian(MarianGenerator),
/// Translator based on T5 model
T5(T5Generator),
/// Translator based on MBart50 model
MBart(MBartGenerator),
/// Translator based on M2M100 model
M2M100(M2M100Generator),
}
impl TranslationOption {
pub fn new(config: TranslationConfig) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::Marian => Ok(TranslationOption::Marian(MarianGenerator::new(
config.into(),
)?)),
ModelType::T5 => Ok(TranslationOption::T5(T5Generator::new(config.into())?)),
ModelType::MBart => Ok(TranslationOption::MBart(MBartGenerator::new(
config.into(),
)?)),
ModelType::M2M100 => Ok(TranslationOption::M2M100(M2M100Generator::new(
config.into(),
)?)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Translation not implemented for {:?}!",
config.model_type
))),
}
}
/// Returns the `ModelType` for this TranslationOption
pub fn model_type(&self) -> ModelType {
match *self {
Self::Marian(_) => ModelType::Marian,
Self::T5(_) => ModelType::T5,
Self::MBart(_) => ModelType::MBart,
Self::M2M100(_) => ModelType::M2M100,
}
}
fn validate_and_get_prefix_and_forced_bos_id(
&self,
source_language: Option<&Language>,
target_language: Option<&Language>,
supported_source_languages: &HashSet<Language>,
supported_target_languages: &HashSet<Language>,
) -> Result<(Option<String>, Option<i64>), RustBertError> {
if let Some(source_language) = source_language {
if !supported_source_languages.contains(source_language) {
return Err(RustBertError::ValueError(format!(
"{} not in list of supported languages: {:?}",
source_language.to_string(),
supported_source_languages
)));
}
}
if let Some(target_language) = target_language {
if !supported_target_languages.contains(target_language) {
return Err(RustBertError::ValueError(format!(
"{} not in list of supported languages: {:?}",
target_language.to_string(),
supported_target_languages
)));
}
}
Ok(match *self {
Self::Marian(_) => {
if supported_target_languages.len() > 1 {
(
Some(format!(
">>{}<< ",
match target_language {
Some(value) => value.get_iso_639_1_code(),
None => {
return Err(RustBertError::ValueError(format!(
"Missing target language for Marian \
(multiple languages supported by model: {:?}, \
need to specify target language)",
supported_target_languages
)));
}
}
)),
None,
)
} else {
(None, None)
}
}
Self::T5(_) => (
Some(format!(
"translate {} to {}:",
match source_language {
Some(value) => value.to_string(),
None => {
return Err(RustBertError::ValueError(
"Missing source language for T5".to_string(),
));
}
},
match target_language {
Some(value) => value.to_string(),
None => {
return Err(RustBertError::ValueError(
"Missing target language for T5".to_string(),
));
}
}
)),
None,
),
Self::MBart(ref model) => (
Some(format!(
">>{}<< ",
match source_language {
Some(value) => value.get_iso_639_1_code(),
None => {
return Err(RustBertError::ValueError(format!(
"Missing source language for MBart\
(multiple languages supported by model: {:?}, \
need to specify target language)",
supported_source_languages
)));
}
}
)),
if let Some(target_language) = target_language {
Some(
model._get_tokenizer().convert_tokens_to_ids([format!(
">>{}<<",
target_language.get_iso_639_1_code()
)])[0],
)
} else {
return Err(RustBertError::ValueError(format!(
"Missing target language for MBart\
(multiple languages supported by model: {:?}, \
need to specify target language)",
supported_target_languages
)));
},
),
Self::M2M100(ref model) => (
Some(match source_language {
Some(value) => {
let language_code = value.get_iso_639_1_code();
match language_code.len() {
2 => format!(">>{}.<< ", language_code),
3 => format!(">>{}<< ", language_code),
_ => {
return Err(RustBertError::ValueError(
"Invalid ISO 639-3 code".to_string(),
));
}
}
}
None => {
return Err(RustBertError::ValueError(format!(
"Missing source language for M2M100 \
(multiple languages supported by model: {:?}, \
need to specify target language)",
supported_source_languages
)));
}
}),
if let Some(target_language) = target_language {
let language_code = target_language.get_iso_639_1_code();
Some(
model
._get_tokenizer()
.convert_tokens_to_ids([match language_code.len() {
2 => format!(">>{}.<<", language_code),
3 => format!(">>{}<<", language_code),
_ => {
return Err(RustBertError::ValueError(
"Invalid ISO 639-3 code".to_string(),
));
}
}])[0],
)
} else {
return Err(RustBertError::ValueError(format!(
"Missing target language for M2M100 \
(multiple languages supported by model: {:?}, \
need to specify target language)",
supported_target_languages
)));
},
),
})
}
/// Interface method to generate() of the particular models.
pub fn generate<'a, S>(
&self,
prompt_texts: Option<S>,
attention_mask: Option<Tensor>,
forced_bos_token_id: Option<i64>,
) -> Vec<String>
where
S: AsRef<[&'a str]>,
{
match *self {
Self::Marian(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::T5(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::MBart(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
forced_bos_token_id,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::M2M100(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
forced_bos_token_id,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
}
}
}
/// # TranslationModel to perform translation
pub struct TranslationModel {
model: TranslationOption,
supported_source_languages: HashSet<Language>,
supported_target_languages: HashSet<Language>,
}
impl TranslationModel {
/// Build a new `TranslationModel`
///
/// # Arguments
///
/// * `translation_config` - `TranslationConfig` object containing the resource references (model, vocabulary, configuration), translation options and device placement (CPU/GPU)
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
/// use tch::Device;
/// use rust_bert::resources::{Resource, RemoteResource};
/// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages};
/// use rust_bert::pipelines::common::ModelType;
///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianModelResources::ROMANCE2ENGLISH,
/// ));
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianConfigResources::ROMANCE2ENGLISH,
/// ));
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianVocabResources::ROMANCE2ENGLISH,
/// ));
///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
///
/// let translation_config = TranslationConfig::new(
/// ModelType::Marian,
/// model_resource,
/// config_resource,
/// vocab_resource.clone(),
/// vocab_resource,
/// source_languages,
/// target_languages,
/// Device::cuda_if_available(),
/// );
/// let model = TranslationModel::new(translation_config)?;
/// # Ok(())
/// # }
/// ```
pub fn new(translation_config: TranslationConfig) -> Result<TranslationModel, RustBertError> {
let supported_source_languages = translation_config.source_languages.clone();
let supported_target_languages = translation_config.target_languages.clone();
let model = TranslationOption::new(translation_config)?;
Ok(TranslationModel {
model,
supported_source_languages,
supported_target_languages,
})
}
/// Translates texts provided
///
/// # Arguments
///
/// * `texts` - `&[&str]` Array of texts to translate.
/// * `source_language` - Optional source `Language` of the input texts (required when the model supports several source languages)
/// * `target_language` - Optional target `Language` (required when the model supports several target languages)
///
/// # Returns
/// * `Vec<String>` Translated texts
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
/// use tch::Device;
/// use rust_bert::resources::{Resource, RemoteResource};
/// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages, MarianSpmResources};
/// use rust_bert::pipelines::common::ModelType;
///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianModelResources::ENGLISH2ROMANCE,
/// ));
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianConfigResources::ENGLISH2ROMANCE,
/// ));
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianVocabResources::ENGLISH2ROMANCE,
/// ));
/// let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianSpmResources::ENGLISH2ROMANCE,
/// ));
/// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
/// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
///
/// let translation_config = TranslationConfig::new(
/// ModelType::Marian,
/// model_resource,
/// config_resource,
/// vocab_resource,
/// merges_resource,
/// source_languages,
/// target_languages,
/// Device::cuda_if_available(),
/// );
/// let model = TranslationModel::new(translation_config)?;
///
/// let input = ["This is a sentence to be translated"];
/// let source_language = None;
/// let target_language = Language::French;
///
/// let output = model.translate(&input, source_language, target_language);
/// # Ok(())
/// # }
/// ```
pub fn translate<'a, S>(
&self,
texts: S,
source_language: impl Into<Option<Language>>,
target_language: impl Into<Option<Language>>,
) -> Result<Vec<String>, RustBertError>
where
S: AsRef<[&'a str]>,
{
let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id(
source_language.into().as_ref(),
target_language.into().as_ref(),
&self.supported_source_languages,
&self.supported_target_languages,
)?;
Ok(match prefix {
Some(value) => {
let texts = texts
.as_ref()
.iter()
.map(|&v| format!("{}{}", value, v))
.collect::<Vec<String>>();
self.model.generate(
Some(texts.iter().map(AsRef::as_ref).collect::<Vec<&str>>()),
None,
forced_bos_token_id,
)
}
None => self.model.generate(Some(texts), None, forced_bos_token_id),
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
MarianVocabResources,
};
use crate::resources::RemoteResource;
#[test]
#[ignore] // no need to run, compilation is enough to verify it is Send
fn test() {
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianModelResources::ROMANCE2ENGLISH,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianConfigResources::ROMANCE2ENGLISH,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianVocabResources::ROMANCE2ENGLISH,
));
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
let translation_config = TranslationConfig::new(
ModelType::Marian,
model_resource,
config_resource,
vocab_resource.clone(),
vocab_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
let _: Box<dyn Send> = Box::new(TranslationModel::new(translation_config));
}
}

View File

@ -1,8 +1,9 @@
use rust_bert::m2m_100::{
M2M100Config, M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100Model,
M2M100ModelResources, M2M100VocabResources,
M2M100Config, M2M100ConfigResources, M2M100MergesResources, M2M100Model, M2M100ModelResources,
M2M100SourceLanguages, M2M100TargetLanguages, M2M100VocabResources,
};
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{M2M100Tokenizer, Tokenizer, TruncationStrategy};
@ -75,43 +76,48 @@ fn m2m100_lm_model() -> anyhow::Result<()> {
#[test]
fn m2m100_translation() -> anyhow::Result<()> {
// Resources paths
let generate_config = GenerateConfig {
max_length: 56,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
)),
do_sample: false,
num_beams: 3,
..Default::default()
};
let model = M2M100Generator::new(generate_config)?;
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
));
let input_context = ">>en.<< The dog did not wake up.";
let target_language = model.get_tokenizer().convert_tokens_to_ids([">>es.<<"])[0];
let source_languages = M2M100SourceLanguages::M2M100_418M;
let target_languages = M2M100TargetLanguages::M2M100_418M;
let output = model.generate(
Some(&[input_context]),
None,
None,
None,
None,
target_language,
None,
false,
let translation_config = TranslationConfig::new(
ModelType::M2M100,
model_resource,
config_resource,
vocab_resource,
merges_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
let model = TranslationModel::new(translation_config)?;
assert_eq!(output.len(), 1);
assert_eq!(output[0].text, ">>es.<< El perro no se despertó.");
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
assert_eq!(outputs.len(), 3);
assert_eq!(
outputs[0],
" Cette phrase sera traduite en plusieurs langues."
);
assert_eq!(outputs[1], " Esta frase se traducirá en varios idiomas.");
assert_eq!(outputs[2], " यह वाक्यांश कई भाषाओं में अनुवादित किया जाएगा।");
Ok(())
}

View File

@ -1,24 +1,82 @@
use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
use rust_bert::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{
Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
#[test]
// #[cfg_attr(not(feature = "all-tests"), ignore)]
fn test_translation() -> anyhow::Result<()> {
// Set-up translation model
let translation_config = TranslationConfig::new(OldLanguage::EnglishToFrench, Device::Cpu);
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianModelResources::ENGLISH2ROMANCE,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianConfigResources::ENGLISH2ROMANCE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianVocabResources::ENGLISH2ROMANCE,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianSpmResources::ENGLISH2ROMANCE,
));
let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
let translation_config = TranslationConfig::new(
ModelType::Marian,
model_resource,
config_resource,
vocab_resource,
merges_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
let model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";
let input_context_2 = "The dog did not wake up";
let output = model.translate(&[input_context_1, input_context_2]);
let outputs = model.translate(&[input_context_1, input_context_2], None, Language::French)?;
assert_eq!(output.len(), 2);
assert_eq!(outputs.len(), 2);
assert_eq!(
output[0],
outputs[0],
" Le rapide renard brun saute sur le chien paresseux"
);
assert_eq!(output[1], " Le chien ne s'est pas réveillé");
assert_eq!(outputs[1], " Le chien ne s'est pas réveillé");
Ok(())
}
#[test]
// #[cfg_attr(not(feature = "all-tests"), ignore)]
fn test_translation_builder() -> anyhow::Result<()> {
let model = TranslationModelBuilder::new()
.with_device(Device::cuda_if_available())
.with_model_type(ModelType::Marian)
.with_source_languages(vec![Language::English])
.with_target_languages(vec![Language::French])
.create_model()?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";
let input_context_2 = "The dog did not wake up";
let outputs = model.translate(&[input_context_1, input_context_2], None, Language::French)?;
assert_eq!(outputs.len(), 2);
assert_eq!(
outputs[0],
" Le rapide renard brun saute sur le chien paresseux"
);
assert_eq!(outputs[1], " Le chien ne s'est pas réveillé");
Ok(())
}

View File

@ -1,8 +1,8 @@
use rust_bert::mbart::{
MBartConfig, MBartConfigResources, MBartGenerator, MBartModel, MBartModelResources,
MBartVocabResources,
MBartConfig, MBartConfigResources, MBartModel, MBartModelResources, MBartVocabResources,
};
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy};
@ -65,46 +65,28 @@ fn mbart_lm_model() -> anyhow::Result<()> {
#[test]
fn mbart_translation() -> anyhow::Result<()> {
// Resources paths
let generate_config = GenerateConfig {
max_length: 56,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
do_sample: false,
num_beams: 3,
..Default::default()
};
let model = MBartGenerator::new(generate_config)?;
let model = TranslationModelBuilder::new()
.with_device(Device::cuda_if_available())
.with_model_type(ModelType::MBart)
.create_model()?;
let input_context = "en_XX The quick brown fox jumps over the lazy dog.";
let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0];
let source_sentence = "This sentence will be translated in multiple languages.";
let output = model.generate(
Some(&[input_context]),
None,
None,
None,
None,
target_language,
None,
false,
);
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
assert_eq!(output.len(), 1);
assert_eq!(outputs.len(), 3);
assert_eq!(
output[0].text,
"de_DE Der schnelle braune Fuchs springt über den faulen Hund."
outputs[0],
" Cette phrase sera traduite en plusieurs langues."
);
assert_eq!(
outputs[1],
" Esta frase será traducida en múltiples idiomas."
);
assert_eq!(outputs[2], " यह वाक्य कई भाषाओं में अनुवाद किया जाएगा.");
Ok(())
}

View File

@ -1,31 +1,65 @@
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
use tch::Device;
#[test]
fn test_translation_t5() -> anyhow::Result<()> {
// Set-up translation model
let translation_config = TranslationConfig::new_from_resources(
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)),
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)),
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
Some("translate English to French:".to_string()),
Device::cuda_if_available(),
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let source_languages = [
Language::English,
Language::French,
Language::German,
Language::Romanian,
];
let target_languages = [
Language::English,
Language::French,
Language::German,
Language::Romanian,
];
let translation_config = TranslationConfig::new(
ModelType::T5,
model_resource,
config_resource,
vocab_resource,
merges_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
let model = TranslationModel::new(translation_config)?;
let input_context = "The quick brown fox jumps over the lazy dog.";
let source_sentence = "This sentence will be translated in multiple languages.";
let output = model.translate(&[input_context]);
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::German)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?);
assert_eq!(outputs.len(), 3);
assert_eq!(
output[0],
" Le renard brun rapide saute au-dessus du chien paresseux."
outputs[0],
" Cette phrase sera traduite dans plusieurs langues."
);
assert_eq!(
outputs[1],
" Dieser Satz wird in mehreren Sprachen übersetzt."
);
assert_eq!(
outputs[2],
" Această frază va fi tradusă în mai multe limbi."
);
Ok(())