mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Updated examples and integration tests
This commit is contained in:
parent
89b3a327fa
commit
ce90d8901d
@ -23,17 +23,13 @@ fn main() -> anyhow::Result<()> {
|
||||
.with_model_type(ModelType::Marian)
|
||||
// .with_large_model()
|
||||
.with_source_languages(vec![Language::English])
|
||||
.with_target_languages(vec![Language::Hebrew])
|
||||
.with_target_languages(vec![Language::Spanish])
|
||||
.create_model()?;
|
||||
|
||||
let input_context_1 = "The quick brown fox jumps over the lazy dog.";
|
||||
let input_context_2 = "The dog did not wake up.";
|
||||
|
||||
let output = model.translate(
|
||||
&[input_context_1, input_context_2],
|
||||
Language::English,
|
||||
Language::Hebrew,
|
||||
)?;
|
||||
let output = model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?;
|
||||
|
||||
for sentence in output {
|
||||
println!("{}", sentence);
|
||||
|
@ -13,54 +13,52 @@
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::m2m_100::{
|
||||
M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100ModelResources,
|
||||
M2M100VocabResources,
|
||||
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
|
||||
M2M100TargetLanguages, M2M100VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use tch::Device;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: 512,
|
||||
min_length: 0,
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100ModelResources::M2M100_418M,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100ConfigResources::M2M100_418M,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100VocabResources::M2M100_418M,
|
||||
)),
|
||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100MergesResources::M2M100_418M,
|
||||
)),
|
||||
do_sample: false,
|
||||
early_stopping: true,
|
||||
num_beams: 3,
|
||||
no_repeat_ngram_size: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100ModelResources::M2M100_418M,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100ConfigResources::M2M100_418M,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100VocabResources::M2M100_418M,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100MergesResources::M2M100_418M,
|
||||
));
|
||||
|
||||
let model = M2M100Generator::new(generate_config)?;
|
||||
let source_languages = M2M100SourceLanguages::M2M100_418M;
|
||||
let target_languages = M2M100TargetLanguages::M2M100_418M;
|
||||
|
||||
let input_context_1 = ">>en.<< The dog did not wake up.";
|
||||
let target_language = model.get_tokenizer().convert_tokens_to_ids([">>es.<<"])[0];
|
||||
|
||||
println!("{:?} - {:?}", input_context_1, target_language);
|
||||
let output = model.generate(
|
||||
Some(&[input_context_1]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
target_language,
|
||||
None,
|
||||
false,
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::M2M100,
|
||||
model_resource,
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource,
|
||||
source_languages,
|
||||
target_languages,
|
||||
Device::cuda_if_available(),
|
||||
);
|
||||
let model = TranslationModel::new(translation_config)?;
|
||||
|
||||
for sentence in output {
|
||||
println!("{:?}", sentence);
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
|
||||
|
||||
for sentence in outputs {
|
||||
println!("{}", sentence);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -13,48 +13,52 @@
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::mbart::{
|
||||
MBartConfigResources, MBartGenerator, MBartModelResources, MBartVocabResources,
|
||||
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
|
||||
MBartVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use tch::Device;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: 56,
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartModelResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
do_sample: false,
|
||||
num_beams: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let model = MBartGenerator::new(generate_config)?;
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartModelResources::MBART50_MANY_TO_MANY,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
));
|
||||
|
||||
let input_context_1 = "en_XX The quick brown fox jumps over the lazy dog.";
|
||||
let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0];
|
||||
let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY;
|
||||
let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY;
|
||||
|
||||
let output = model.generate(
|
||||
Some(&[input_context_1]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
target_language,
|
||||
None,
|
||||
false,
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::MBart,
|
||||
model_resource,
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource,
|
||||
source_languages,
|
||||
target_languages,
|
||||
Device::cuda_if_available(),
|
||||
);
|
||||
let model = TranslationModel::new(translation_config)?;
|
||||
|
||||
for sentence in output {
|
||||
println!("{:?}", sentence.text);
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
|
||||
|
||||
for sentence in outputs {
|
||||
println!("{}", sentence);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -12,46 +12,56 @@
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::t5::{T5ConfigResources, T5Generator, T5ModelResources, T5VocabResources};
|
||||
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
||||
use tch::Device;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let model_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_BASE));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
|
||||
|
||||
let generate_config = GenerateConfig {
|
||||
model_resource: weights_resource,
|
||||
vocab_resource,
|
||||
let source_languages = [
|
||||
Language::English,
|
||||
Language::French,
|
||||
Language::German,
|
||||
Language::Romanian,
|
||||
];
|
||||
let target_languages = [
|
||||
Language::English,
|
||||
Language::French,
|
||||
Language::German,
|
||||
Language::Romanian,
|
||||
];
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::T5,
|
||||
model_resource,
|
||||
config_resource,
|
||||
max_length: 40,
|
||||
do_sample: false,
|
||||
num_beams: 4,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Set-up model
|
||||
let t5_model = T5Generator::new(generate_config)?;
|
||||
|
||||
// Define input
|
||||
let input = ["translate English to German: This sentence will get translated to German"];
|
||||
|
||||
let output = t5_model.generate(
|
||||
Some(input.to_vec()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
vocab_resource,
|
||||
merges_resource,
|
||||
source_languages,
|
||||
target_languages,
|
||||
Device::cuda_if_available(),
|
||||
);
|
||||
println!("{:?}", output);
|
||||
let model = TranslationModel::new(translation_config)?;
|
||||
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::German)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?);
|
||||
|
||||
for sentence in outputs {
|
||||
println!("{}", sentence);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -34,14 +34,10 @@ enum ModelSize {
|
||||
XLarge,
|
||||
}
|
||||
|
||||
pub struct TranslationModelBuilder<S, T>
|
||||
where
|
||||
S: AsRef<[Language]> + Debug,
|
||||
T: AsRef<[Language]> + Debug,
|
||||
{
|
||||
pub struct TranslationModelBuilder {
|
||||
model_type: Option<ModelType>,
|
||||
source_languages: Option<S>,
|
||||
target_languages: Option<T>,
|
||||
source_languages: Option<Vec<Language>>,
|
||||
target_languages: Option<Vec<Language>>,
|
||||
device: Option<Device>,
|
||||
model_size: Option<ModelSize>,
|
||||
}
|
||||
@ -61,12 +57,8 @@ macro_rules! get_marian_resources {
|
||||
};
|
||||
}
|
||||
|
||||
impl<S, T> TranslationModelBuilder<S, T>
|
||||
where
|
||||
S: AsRef<[Language]> + Debug,
|
||||
T: AsRef<[Language]> + Debug,
|
||||
{
|
||||
pub fn new() -> TranslationModelBuilder<S, T> {
|
||||
impl TranslationModelBuilder {
|
||||
pub fn new() -> TranslationModelBuilder {
|
||||
TranslationModelBuilder {
|
||||
model_type: None,
|
||||
source_languages: None,
|
||||
@ -131,20 +123,26 @@ where
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_source_languages(&mut self, source_languages: S) -> &mut Self {
|
||||
self.source_languages = Some(source_languages);
|
||||
pub fn with_source_languages<S>(&mut self, source_languages: S) -> &mut Self
|
||||
where
|
||||
S: AsRef<[Language]> + Debug,
|
||||
{
|
||||
self.source_languages = Some(source_languages.as_ref().to_vec());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_target_languages(&mut self, target_languages: T) -> &mut Self {
|
||||
self.target_languages = Some(target_languages);
|
||||
pub fn with_target_languages<T>(&mut self, target_languages: T) -> &mut Self
|
||||
where
|
||||
T: AsRef<[Language]> + Debug,
|
||||
{
|
||||
self.target_languages = Some(target_languages.as_ref().to_vec());
|
||||
self
|
||||
}
|
||||
|
||||
fn get_default_model(
|
||||
&self,
|
||||
source_languages: Option<&S>,
|
||||
target_languages: Option<&T>,
|
||||
source_languages: Option<&Vec<Language>>,
|
||||
target_languages: Option<&Vec<Language>>,
|
||||
) -> Result<TranslationResources, RustBertError> {
|
||||
Ok(
|
||||
match self.get_marian_model(source_languages, target_languages) {
|
||||
@ -161,14 +159,14 @@ where
|
||||
|
||||
fn get_marian_model(
|
||||
&self,
|
||||
source_languages: Option<&S>,
|
||||
target_languages: Option<&T>,
|
||||
source_languages: Option<&Vec<Language>>,
|
||||
target_languages: Option<&Vec<Language>>,
|
||||
) -> Result<TranslationResources, RustBertError> {
|
||||
let (resources, source_languages, target_languages) =
|
||||
if let (Some(source_languages), Some(target_languages)) =
|
||||
(source_languages, target_languages)
|
||||
{
|
||||
match (source_languages.as_ref(), target_languages.as_ref()) {
|
||||
match (source_languages.as_slice(), target_languages.as_slice()) {
|
||||
([Language::English], [Language::German]) => {
|
||||
get_marian_resources!(ENGLISH2RUSSIAN)
|
||||
}
|
||||
@ -257,18 +255,17 @@ where
|
||||
|
||||
fn get_mbart50_resources(
|
||||
&self,
|
||||
source_languages: Option<&S>,
|
||||
target_languages: Option<&T>,
|
||||
source_languages: Option<&Vec<Language>>,
|
||||
target_languages: Option<&Vec<Language>>,
|
||||
) -> Result<TranslationResources, RustBertError> {
|
||||
if let Some(source_languages) = source_languages {
|
||||
if !source_languages
|
||||
.as_ref()
|
||||
.iter()
|
||||
.all(|lang| MBartSourceLanguages::MBART50_MANY_TO_MANY.contains(lang))
|
||||
{
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{:?} not in list of supported languages: {:?}",
|
||||
source_languages.as_ref(),
|
||||
source_languages,
|
||||
MBartSourceLanguages::MBART50_MANY_TO_MANY
|
||||
)));
|
||||
}
|
||||
@ -276,7 +273,6 @@ where
|
||||
|
||||
if let Some(target_languages) = target_languages {
|
||||
if !target_languages
|
||||
.as_ref()
|
||||
.iter()
|
||||
.all(|lang| MBartTargetLanguages::MBART50_MANY_TO_MANY.contains(lang))
|
||||
{
|
||||
@ -315,18 +311,17 @@ where
|
||||
|
||||
fn get_m2m100_large_resources(
|
||||
&self,
|
||||
source_languages: Option<&S>,
|
||||
target_languages: Option<&T>,
|
||||
source_languages: Option<&Vec<Language>>,
|
||||
target_languages: Option<&Vec<Language>>,
|
||||
) -> Result<TranslationResources, RustBertError> {
|
||||
if let Some(source_languages) = source_languages {
|
||||
if !source_languages
|
||||
.as_ref()
|
||||
.iter()
|
||||
.all(|lang| M2M100SourceLanguages::M2M100_418M.contains(lang))
|
||||
{
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{:?} not in list of supported languages: {:?}",
|
||||
source_languages.as_ref(),
|
||||
source_languages,
|
||||
M2M100SourceLanguages::M2M100_418M
|
||||
)));
|
||||
}
|
||||
@ -334,7 +329,6 @@ where
|
||||
|
||||
if let Some(target_languages) = target_languages {
|
||||
if !target_languages
|
||||
.as_ref()
|
||||
.iter()
|
||||
.all(|lang| M2M100TargetLanguages::M2M100_418M.contains(lang))
|
||||
{
|
||||
@ -367,18 +361,17 @@ where
|
||||
|
||||
fn get_m2m100_xlarge_resources(
|
||||
&self,
|
||||
source_languages: Option<&S>,
|
||||
target_languages: Option<&T>,
|
||||
source_languages: Option<&Vec<Language>>,
|
||||
target_languages: Option<&Vec<Language>>,
|
||||
) -> Result<TranslationResources, RustBertError> {
|
||||
if let Some(source_languages) = source_languages {
|
||||
if !source_languages
|
||||
.as_ref()
|
||||
.iter()
|
||||
.all(|lang| M2M100SourceLanguages::M2M100_1_2B.contains(lang))
|
||||
{
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{:?} not in list of supported languages: {:?}",
|
||||
source_languages.as_ref(),
|
||||
source_languages,
|
||||
M2M100SourceLanguages::M2M100_1_2B
|
||||
)));
|
||||
}
|
||||
@ -386,7 +379,6 @@ where
|
||||
|
||||
if let Some(target_languages) = target_languages {
|
||||
if !target_languages
|
||||
.as_ref()
|
||||
.iter()
|
||||
.all(|lang| M2M100TargetLanguages::M2M100_1_2B.contains(lang))
|
||||
{
|
||||
|
991
src/pipelines/translation/translation_pipeline.rs
Normal file
991
src/pipelines/translation/translation_pipeline.rs
Normal file
@ -0,0 +1,991 @@
|
||||
// Copyright 2018-2020 The HuggingFace Inc. team.
|
||||
// Copyright 2020 Marian Team Authors
|
||||
// Copyright 2019-2020 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{Device, Tensor};
|
||||
|
||||
use crate::common::error::RustBertError;
|
||||
use crate::common::resources::Resource;
|
||||
use crate::m2m_100::M2M100Generator;
|
||||
use crate::marian::MarianGenerator;
|
||||
use crate::mbart::MBartGenerator;
|
||||
use crate::pipelines::common::ModelType;
|
||||
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
||||
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use crate::t5::T5Generator;
|
||||
use std::collections::HashSet;
|
||||
use std::fmt;
|
||||
use std::fmt::{Debug, Display};
|
||||
|
||||
/// Language
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
|
||||
pub enum Language {
|
||||
Afrikaans,
|
||||
Danish,
|
||||
Dutch,
|
||||
German,
|
||||
English,
|
||||
Icelandic,
|
||||
Luxembourgish,
|
||||
Norwegian,
|
||||
Swedish,
|
||||
WesternFrisian,
|
||||
Yiddish,
|
||||
Asturian,
|
||||
Catalan,
|
||||
French,
|
||||
Galician,
|
||||
Italian,
|
||||
Occitan,
|
||||
Portuguese,
|
||||
Romanian,
|
||||
Spanish,
|
||||
Belarusian,
|
||||
Bosnian,
|
||||
Bulgarian,
|
||||
Croatian,
|
||||
Czech,
|
||||
Macedonian,
|
||||
Polish,
|
||||
Russian,
|
||||
Serbian,
|
||||
Slovak,
|
||||
Slovenian,
|
||||
Ukrainian,
|
||||
Estonian,
|
||||
Finnish,
|
||||
Hungarian,
|
||||
Latvian,
|
||||
Lithuanian,
|
||||
Albanian,
|
||||
Armenian,
|
||||
Georgian,
|
||||
Greek,
|
||||
Breton,
|
||||
Irish,
|
||||
ScottishGaelic,
|
||||
Welsh,
|
||||
Azerbaijani,
|
||||
Bashkir,
|
||||
Kazakh,
|
||||
Turkish,
|
||||
Uzbek,
|
||||
Japanese,
|
||||
Korean,
|
||||
Vietnamese,
|
||||
ChineseMandarin,
|
||||
Bengali,
|
||||
Gujarati,
|
||||
Hindi,
|
||||
Kannada,
|
||||
Marathi,
|
||||
Nepali,
|
||||
Oriya,
|
||||
Panjabi,
|
||||
Sindhi,
|
||||
Sinhala,
|
||||
Urdu,
|
||||
Tamil,
|
||||
Cebuano,
|
||||
Iloko,
|
||||
Indonesian,
|
||||
Javanese,
|
||||
Malagasy,
|
||||
Malay,
|
||||
Malayalam,
|
||||
Sundanese,
|
||||
Tagalog,
|
||||
Burmese,
|
||||
CentralKhmer,
|
||||
Lao,
|
||||
Thai,
|
||||
Mongolian,
|
||||
Arabic,
|
||||
Hebrew,
|
||||
Pashto,
|
||||
Farsi,
|
||||
Amharic,
|
||||
Fulah,
|
||||
Hausa,
|
||||
Igbo,
|
||||
Lingala,
|
||||
Luganda,
|
||||
NorthernSotho,
|
||||
Somali,
|
||||
Swahili,
|
||||
Swati,
|
||||
Tswana,
|
||||
Wolof,
|
||||
Xhosa,
|
||||
Yoruba,
|
||||
Zulu,
|
||||
HaitianCreole,
|
||||
}
|
||||
|
||||
impl Display for Language {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", {
|
||||
let input_string = format!("{:?}", self);
|
||||
let mut output: Vec<&str> = Vec::new();
|
||||
let mut start: usize = 0;
|
||||
|
||||
for (c_pos, c) in input_string.char_indices() {
|
||||
if c.is_uppercase() {
|
||||
if start < c_pos {
|
||||
output.push(&input_string[start..c_pos]);
|
||||
}
|
||||
start = c_pos;
|
||||
}
|
||||
}
|
||||
if start < input_string.len() {
|
||||
output.push(&input_string[start..]);
|
||||
}
|
||||
output.join(" ")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Language {
|
||||
pub fn get_iso_639_1_code(&self) -> &'static str {
|
||||
match self {
|
||||
Language::Afrikaans => "af",
|
||||
Language::Danish => "da",
|
||||
Language::Dutch => "nl",
|
||||
Language::German => "de",
|
||||
Language::English => "en",
|
||||
Language::Icelandic => "is",
|
||||
Language::Luxembourgish => "lb",
|
||||
Language::Norwegian => "no",
|
||||
Language::Swedish => "sv",
|
||||
Language::WesternFrisian => "fy",
|
||||
Language::Yiddish => "yi",
|
||||
Language::Asturian => "ast",
|
||||
Language::Catalan => "ca",
|
||||
Language::French => "fr",
|
||||
Language::Galician => "gl",
|
||||
Language::Italian => "it",
|
||||
Language::Occitan => "oc",
|
||||
Language::Portuguese => "pt",
|
||||
Language::Romanian => "ro",
|
||||
Language::Spanish => "es",
|
||||
Language::Belarusian => "be",
|
||||
Language::Bosnian => "bs",
|
||||
Language::Bulgarian => "bg",
|
||||
Language::Croatian => "hr",
|
||||
Language::Czech => "cs",
|
||||
Language::Macedonian => "mk",
|
||||
Language::Polish => "pl",
|
||||
Language::Russian => "ru",
|
||||
Language::Serbian => "sr",
|
||||
Language::Slovak => "sk",
|
||||
Language::Slovenian => "sl",
|
||||
Language::Ukrainian => "uk",
|
||||
Language::Estonian => "et",
|
||||
Language::Finnish => "fi",
|
||||
Language::Hungarian => "hu",
|
||||
Language::Latvian => "lv",
|
||||
Language::Lithuanian => "lt",
|
||||
Language::Albanian => "sq",
|
||||
Language::Armenian => "hy",
|
||||
Language::Georgian => "ka",
|
||||
Language::Greek => "el",
|
||||
Language::Breton => "br",
|
||||
Language::Irish => "ga",
|
||||
Language::ScottishGaelic => "gd",
|
||||
Language::Welsh => "cy",
|
||||
Language::Azerbaijani => "az",
|
||||
Language::Bashkir => "ba",
|
||||
Language::Kazakh => "kk",
|
||||
Language::Turkish => "tr",
|
||||
Language::Uzbek => "uz",
|
||||
Language::Japanese => "ja",
|
||||
Language::Korean => "ko",
|
||||
Language::Vietnamese => "vi",
|
||||
Language::ChineseMandarin => "zh",
|
||||
Language::Bengali => "bn",
|
||||
Language::Gujarati => "gu",
|
||||
Language::Hindi => "hi",
|
||||
Language::Kannada => "kn",
|
||||
Language::Marathi => "mr",
|
||||
Language::Nepali => "ne",
|
||||
Language::Oriya => "or",
|
||||
Language::Panjabi => "pa",
|
||||
Language::Sindhi => "sd",
|
||||
Language::Sinhala => "si",
|
||||
Language::Urdu => "ur",
|
||||
Language::Tamil => "ta",
|
||||
Language::Cebuano => "ceb",
|
||||
Language::Iloko => "ilo",
|
||||
Language::Indonesian => "id",
|
||||
Language::Javanese => "jv",
|
||||
Language::Malagasy => "mg",
|
||||
Language::Malay => "ms",
|
||||
Language::Malayalam => "ml",
|
||||
Language::Sundanese => "su",
|
||||
Language::Tagalog => "tl",
|
||||
Language::Burmese => "my",
|
||||
Language::CentralKhmer => "km",
|
||||
Language::Lao => "lo",
|
||||
Language::Thai => "th",
|
||||
Language::Mongolian => "mn",
|
||||
Language::Arabic => "ar",
|
||||
Language::Hebrew => "he",
|
||||
Language::Pashto => "ps",
|
||||
Language::Farsi => "fa",
|
||||
Language::Amharic => "am",
|
||||
Language::Fulah => "ff",
|
||||
Language::Hausa => "ha",
|
||||
Language::Igbo => "ig",
|
||||
Language::Lingala => "ln",
|
||||
Language::Luganda => "lg",
|
||||
Language::NorthernSotho => "nso",
|
||||
Language::Somali => "so",
|
||||
Language::Swahili => "sw",
|
||||
Language::Swati => "ss",
|
||||
Language::Tswana => "tn",
|
||||
Language::Wolof => "wo",
|
||||
Language::Xhosa => "xh",
|
||||
Language::Yoruba => "yo",
|
||||
Language::Zulu => "zu",
|
||||
Language::HaitianCreole => "ht",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_iso_639_3_code(&self) -> &'static str {
|
||||
match self {
|
||||
Language::Afrikaans => "afr",
|
||||
Language::Danish => "dan",
|
||||
Language::Dutch => "nld",
|
||||
Language::German => "deu",
|
||||
Language::English => "eng",
|
||||
Language::Icelandic => "isl",
|
||||
Language::Luxembourgish => "ltz",
|
||||
Language::Norwegian => "nor",
|
||||
Language::Swedish => "swe",
|
||||
Language::WesternFrisian => "fry",
|
||||
Language::Yiddish => "yid",
|
||||
Language::Asturian => "ast",
|
||||
Language::Catalan => "cat",
|
||||
Language::French => "fra",
|
||||
Language::Galician => "glg",
|
||||
Language::Italian => "ita",
|
||||
Language::Occitan => "oci",
|
||||
Language::Portuguese => "por",
|
||||
Language::Romanian => "ron",
|
||||
Language::Spanish => "spa",
|
||||
Language::Belarusian => "bel",
|
||||
Language::Bosnian => "bos",
|
||||
Language::Bulgarian => "bul",
|
||||
Language::Croatian => "hrv",
|
||||
Language::Czech => "ces",
|
||||
Language::Macedonian => "mkd",
|
||||
Language::Polish => "pol",
|
||||
Language::Russian => "rus",
|
||||
Language::Serbian => "srp",
|
||||
Language::Slovak => "slk",
|
||||
Language::Slovenian => "slv",
|
||||
Language::Ukrainian => "ukr",
|
||||
Language::Estonian => "est",
|
||||
Language::Finnish => "fin",
|
||||
Language::Hungarian => "hun",
|
||||
Language::Latvian => "lav",
|
||||
Language::Lithuanian => "lit",
|
||||
Language::Albanian => "sqi",
|
||||
Language::Armenian => "hye",
|
||||
Language::Georgian => "kat",
|
||||
Language::Greek => "ell",
|
||||
Language::Breton => "bre",
|
||||
Language::Irish => "gle",
|
||||
Language::ScottishGaelic => "gla",
|
||||
Language::Welsh => "cym",
|
||||
Language::Azerbaijani => "aze",
|
||||
Language::Bashkir => "bak",
|
||||
Language::Kazakh => "kaz",
|
||||
Language::Turkish => "tur",
|
||||
Language::Uzbek => "uzb",
|
||||
Language::Japanese => "jpn",
|
||||
Language::Korean => "kor",
|
||||
Language::Vietnamese => "vie",
|
||||
Language::ChineseMandarin => "cmn",
|
||||
Language::Bengali => "ben",
|
||||
Language::Gujarati => "guj",
|
||||
Language::Hindi => "hin",
|
||||
Language::Kannada => "kan",
|
||||
Language::Marathi => "mar",
|
||||
Language::Nepali => "nep",
|
||||
Language::Oriya => "ori",
|
||||
Language::Panjabi => "pan",
|
||||
Language::Sindhi => "snd",
|
||||
Language::Sinhala => "sin",
|
||||
Language::Urdu => "urd",
|
||||
Language::Tamil => "tam",
|
||||
Language::Cebuano => "ceb",
|
||||
Language::Iloko => "ilo",
|
||||
Language::Indonesian => "ind",
|
||||
Language::Javanese => "jav",
|
||||
Language::Malagasy => "mlg",
|
||||
Language::Malay => "msa",
|
||||
Language::Malayalam => "mal",
|
||||
Language::Sundanese => "sun",
|
||||
Language::Tagalog => "tgl",
|
||||
Language::Burmese => "mya",
|
||||
Language::CentralKhmer => "khm",
|
||||
Language::Lao => "lao",
|
||||
Language::Thai => "tha",
|
||||
Language::Mongolian => "mon",
|
||||
Language::Arabic => "ara",
|
||||
Language::Hebrew => "heb",
|
||||
Language::Pashto => "pus",
|
||||
Language::Farsi => "fas",
|
||||
Language::Amharic => "amh",
|
||||
Language::Fulah => "ful",
|
||||
Language::Hausa => "hau",
|
||||
Language::Igbo => "ibo",
|
||||
Language::Lingala => "lin",
|
||||
Language::Luganda => "lug",
|
||||
Language::NorthernSotho => "nso",
|
||||
Language::Somali => "som",
|
||||
Language::Swahili => "swa",
|
||||
Language::Swati => "ssw",
|
||||
Language::Tswana => "tsn",
|
||||
Language::Wolof => "wol",
|
||||
Language::Xhosa => "xho",
|
||||
Language::Yoruba => "yor",
|
||||
Language::Zulu => "zul",
|
||||
Language::HaitianCreole => "hat",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Configuration for text translation
|
||||
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
||||
/// different set of default parameters and sets the device to place the model on.
|
||||
pub struct TranslationConfig {
|
||||
/// Model type used for translation
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource
|
||||
pub model_resource: Resource,
|
||||
/// Config resource
|
||||
pub config_resource: Resource,
|
||||
/// Vocab resource
|
||||
pub vocab_resource: Resource,
|
||||
/// Merges resource
|
||||
pub merges_resource: Resource,
|
||||
/// Supported source languages
|
||||
pub source_languages: HashSet<Language>,
|
||||
/// Supported target languages
|
||||
pub target_languages: HashSet<Language>,
|
||||
/// Minimum sequence length (default: 0)
|
||||
pub min_length: i64,
|
||||
/// Maximum sequence length (default: 20)
|
||||
pub max_length: i64,
|
||||
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
|
||||
pub do_sample: bool,
|
||||
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
|
||||
pub early_stopping: bool,
|
||||
/// Number of beams for beam search (default: 5)
|
||||
pub num_beams: i64,
|
||||
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
|
||||
pub temperature: f64,
|
||||
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
|
||||
pub top_k: i64,
|
||||
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
|
||||
pub top_p: f64,
|
||||
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
|
||||
pub repetition_penalty: f64,
|
||||
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
|
||||
pub length_penalty: f64,
|
||||
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
|
||||
pub no_repeat_ngram_size: i64,
|
||||
/// Number of sequences to return for each prompt text (default: 1)
|
||||
pub num_return_sequences: i64,
|
||||
/// Device to place the model on (default: CUDA/GPU when available)
|
||||
pub device: Device,
|
||||
/// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation.
|
||||
pub num_beam_groups: Option<i64>,
|
||||
/// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)
|
||||
pub diversity_penalty: Option<f64>,
|
||||
}
|
||||
|
||||
impl TranslationConfig {
|
||||
/// Create a new `TranslationConfiguration` from an available language.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `language` - `Language` enum value (e.g. `Language::EnglishToFrench`)
|
||||
/// * `device` - `Device` to place the model on (CPU/GPU)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// # fn main() -> anyhow::Result<()> {
|
||||
/// use rust_bert::marian::{
|
||||
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
|
||||
/// MarianVocabResources,
|
||||
/// };
|
||||
/// use rust_bert::pipelines::common::ModelType;
|
||||
/// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig};
|
||||
/// use rust_bert::resources::{RemoteResource, Resource};
|
||||
/// use tch::Device;
|
||||
///
|
||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// MarianModelResources::ROMANCE2ENGLISH,
|
||||
/// ));
|
||||
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// MarianConfigResources::ROMANCE2ENGLISH,
|
||||
/// ));
|
||||
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// MarianVocabResources::ROMANCE2ENGLISH,
|
||||
/// ));
|
||||
///
|
||||
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH.iter().collect();
|
||||
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH.iter().collect();
|
||||
///
|
||||
/// let translation_config = TranslationConfig::new(
|
||||
/// ModelType::Marian,
|
||||
/// model_resource,
|
||||
/// config_resource,
|
||||
/// vocab_resource.clone(),
|
||||
/// vocab_resource,
|
||||
/// source_languages,
|
||||
/// target_languages,
|
||||
/// device: Device::cuda_if_available(),
|
||||
/// );
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new<S, T>(
|
||||
model_type: ModelType,
|
||||
model_resource: Resource,
|
||||
config_resource: Resource,
|
||||
vocab_resource: Resource,
|
||||
merges_resource: Resource,
|
||||
source_languages: S,
|
||||
target_languages: T,
|
||||
device: impl Into<Option<Device>>,
|
||||
) -> TranslationConfig
|
||||
where
|
||||
S: AsRef<[Language]>,
|
||||
T: AsRef<[Language]>,
|
||||
{
|
||||
let device = device.into().unwrap_or_else(|| Device::cuda_if_available());
|
||||
|
||||
TranslationConfig {
|
||||
model_type,
|
||||
model_resource,
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource,
|
||||
source_languages: source_languages.as_ref().iter().cloned().collect(),
|
||||
target_languages: target_languages.as_ref().iter().cloned().collect(),
|
||||
device,
|
||||
min_length: 0,
|
||||
max_length: 512,
|
||||
do_sample: false,
|
||||
early_stopping: true,
|
||||
num_beams: 3,
|
||||
temperature: 1.0,
|
||||
top_k: 50,
|
||||
top_p: 1.0,
|
||||
repetition_penalty: 1.0,
|
||||
length_penalty: 1.0,
|
||||
no_repeat_ngram_size: 0,
|
||||
num_return_sequences: 1,
|
||||
num_beam_groups: None,
|
||||
diversity_penalty: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TranslationConfig> for GenerateConfig {
|
||||
fn from(config: TranslationConfig) -> GenerateConfig {
|
||||
GenerateConfig {
|
||||
model_resource: config.model_resource,
|
||||
config_resource: config.config_resource,
|
||||
merges_resource: config.merges_resource,
|
||||
vocab_resource: config.vocab_resource,
|
||||
min_length: config.min_length,
|
||||
max_length: config.max_length,
|
||||
do_sample: config.do_sample,
|
||||
early_stopping: config.early_stopping,
|
||||
num_beams: config.num_beams,
|
||||
temperature: config.temperature,
|
||||
top_k: config.top_k,
|
||||
top_p: config.top_p,
|
||||
repetition_penalty: config.repetition_penalty,
|
||||
length_penalty: config.length_penalty,
|
||||
no_repeat_ngram_size: config.no_repeat_ngram_size,
|
||||
num_return_sequences: config.num_return_sequences,
|
||||
num_beam_groups: config.num_beam_groups,
|
||||
diversity_penalty: config.diversity_penalty,
|
||||
device: config.device,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Abstraction that holds one particular translation model, for any of the supported models
|
||||
pub enum TranslationOption {
|
||||
/// Translator based on Marian model
|
||||
Marian(MarianGenerator),
|
||||
/// Translator based on T5 model
|
||||
T5(T5Generator),
|
||||
/// Translator based on MBart50 model
|
||||
MBart(MBartGenerator),
|
||||
/// Translator based on M2M100 model
|
||||
M2M100(M2M100Generator),
|
||||
}
|
||||
|
||||
impl TranslationOption {
|
||||
pub fn new(config: TranslationConfig) -> Result<Self, RustBertError> {
|
||||
match config.model_type {
|
||||
ModelType::Marian => Ok(TranslationOption::Marian(MarianGenerator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::T5 => Ok(TranslationOption::T5(T5Generator::new(config.into())?)),
|
||||
ModelType::MBart => Ok(TranslationOption::MBart(MBartGenerator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::M2M100 => Ok(TranslationOption::M2M100(M2M100Generator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Translation not implemented for {:?}!",
|
||||
config.model_type
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the `ModelType` for this TranslationOption
|
||||
pub fn model_type(&self) -> ModelType {
|
||||
match *self {
|
||||
Self::Marian(_) => ModelType::Marian,
|
||||
Self::T5(_) => ModelType::T5,
|
||||
Self::MBart(_) => ModelType::MBart,
|
||||
Self::M2M100(_) => ModelType::M2M100,
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_and_get_prefix_and_forced_bos_id(
|
||||
&self,
|
||||
source_language: Option<&Language>,
|
||||
target_language: Option<&Language>,
|
||||
supported_source_languages: &HashSet<Language>,
|
||||
supported_target_languages: &HashSet<Language>,
|
||||
) -> Result<(Option<String>, Option<i64>), RustBertError> {
|
||||
if let Some(source_language) = source_language {
|
||||
if !supported_source_languages.contains(source_language) {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{} not in list of supported languages: {:?}",
|
||||
source_language.to_string(),
|
||||
supported_source_languages
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(target_language) = target_language {
|
||||
if !supported_target_languages.contains(target_language) {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{} not in list of supported languages: {:?}",
|
||||
target_language.to_string(),
|
||||
supported_target_languages
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(match *self {
|
||||
Self::Marian(_) => {
|
||||
if supported_target_languages.len() > 1 {
|
||||
(
|
||||
Some(format!(
|
||||
">>{}<< ",
|
||||
match target_language {
|
||||
Some(value) => value.get_iso_639_1_code(),
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing target language for Marian \
|
||||
(multiple languages supported by model: {:?}, \
|
||||
need to specify target language)",
|
||||
supported_target_languages
|
||||
)));
|
||||
}
|
||||
}
|
||||
)),
|
||||
None,
|
||||
)
|
||||
} else {
|
||||
(None, None)
|
||||
}
|
||||
}
|
||||
Self::T5(_) => (
|
||||
Some(format!(
|
||||
"translate {} to {}:",
|
||||
match source_language {
|
||||
Some(value) => value.to_string(),
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Missing source language for T5".to_string(),
|
||||
));
|
||||
}
|
||||
},
|
||||
match target_language {
|
||||
Some(value) => value.to_string(),
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Missing target language for T5".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
)),
|
||||
None,
|
||||
),
|
||||
Self::MBart(ref model) => (
|
||||
Some(format!(
|
||||
">>{}<< ",
|
||||
match source_language {
|
||||
Some(value) => value.get_iso_639_1_code(),
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing source language for MBart\
|
||||
(multiple languages supported by model: {:?}, \
|
||||
need to specify target language)",
|
||||
supported_source_languages
|
||||
)));
|
||||
}
|
||||
}
|
||||
)),
|
||||
if let Some(target_language) = target_language {
|
||||
Some(
|
||||
model._get_tokenizer().convert_tokens_to_ids([format!(
|
||||
">>{}<<",
|
||||
target_language.get_iso_639_1_code()
|
||||
)])[0],
|
||||
)
|
||||
} else {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing target language for MBart\
|
||||
(multiple languages supported by model: {:?}, \
|
||||
need to specify target language)",
|
||||
supported_target_languages
|
||||
)));
|
||||
},
|
||||
),
|
||||
Self::M2M100(ref model) => (
|
||||
Some(match source_language {
|
||||
Some(value) => {
|
||||
let language_code = value.get_iso_639_1_code();
|
||||
match language_code.len() {
|
||||
2 => format!(">>{}.<< ", language_code),
|
||||
3 => format!(">>{}<< ", language_code),
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Invalid ISO 639-3 code".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing source language for M2M100 \
|
||||
(multiple languages supported by model: {:?}, \
|
||||
need to specify target language)",
|
||||
supported_source_languages
|
||||
)));
|
||||
}
|
||||
}),
|
||||
if let Some(target_language) = target_language {
|
||||
let language_code = target_language.get_iso_639_1_code();
|
||||
Some(
|
||||
model
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids([match language_code.len() {
|
||||
2 => format!(">>{}.<<", language_code),
|
||||
3 => format!(">>{}<<", language_code),
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Invalid ISO 639-3 code".to_string(),
|
||||
));
|
||||
}
|
||||
}])[0],
|
||||
)
|
||||
} else {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing target language for M2M100 \
|
||||
(multiple languages supported by model: {:?}, \
|
||||
need to specify target language)",
|
||||
supported_target_languages
|
||||
)));
|
||||
},
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
/// Interface method to generate() of the particular models.
|
||||
pub fn generate<'a, S>(
|
||||
&self,
|
||||
prompt_texts: Option<S>,
|
||||
attention_mask: Option<Tensor>,
|
||||
forced_bos_token_id: Option<i64>,
|
||||
) -> Vec<String>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
match *self {
|
||||
Self::Marian(ref model) => model
|
||||
.generate(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
Self::T5(ref model) => model
|
||||
.generate(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
Self::MBart(ref model) => model
|
||||
.generate(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
forced_bos_token_id,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
Self::M2M100(ref model) => model
|
||||
.generate(
|
||||
prompt_texts,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
forced_bos_token_id,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # TranslationModel to perform translation
|
||||
pub struct TranslationModel {
|
||||
model: TranslationOption,
|
||||
supported_source_languages: HashSet<Language>,
|
||||
supported_target_languages: HashSet<Language>,
|
||||
}
|
||||
|
||||
impl TranslationModel {
|
||||
/// Build a new `TranslationModel`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `translation_config` - `TranslationConfig` object containing the resource references (model, vocabulary, configuration), translation options and device placement (CPU/GPU)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// # fn main() -> anyhow::Result<()> {
|
||||
/// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
|
||||
/// use tch::Device;
|
||||
/// use rust_bert::resources::{Resource, RemoteResource};
|
||||
/// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages};
|
||||
/// use rust_bert::pipelines::common::ModelType;
|
||||
///
|
||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// MarianModelResources::ROMANCE2ENGLISH,
|
||||
/// ));
|
||||
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// MarianConfigResources::ROMANCE2ENGLISH,
|
||||
/// ));
|
||||
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// MarianVocabResources::ROMANCE2ENGLISH,
|
||||
/// ));
|
||||
///
|
||||
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH.iter().collect();
|
||||
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH.iter().collect();
|
||||
///
|
||||
/// let translation_config = TranslationConfig::new(
|
||||
/// ModelType::Marian,
|
||||
/// model_resource,
|
||||
/// config_resource,
|
||||
/// vocab_resource.clone(),
|
||||
/// vocab_resource,
|
||||
/// source_languages,
|
||||
/// target_languages,
|
||||
/// device: Device::cuda_if_available(),
|
||||
/// );
|
||||
/// let mut summarization_model = TranslationModel::new(translation_config)?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(translation_config: TranslationConfig) -> Result<TranslationModel, RustBertError> {
|
||||
let supported_source_languages = translation_config.source_languages.clone();
|
||||
let supported_target_languages = translation_config.target_languages.clone();
|
||||
|
||||
let model = TranslationOption::new(translation_config)?;
|
||||
|
||||
Ok(TranslationModel {
|
||||
model,
|
||||
supported_source_languages,
|
||||
supported_target_languages,
|
||||
})
|
||||
}
|
||||
|
||||
/// Translates texts provided
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input` - `&[&str]` Array of texts to summarize.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Vec<String>` Translated texts
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// # fn main() -> anyhow::Result<()> {
|
||||
/// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel, Language};
|
||||
/// use tch::Device;
|
||||
/// use rust_bert::resources::{Resource, RemoteResource};
|
||||
/// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages, MarianSpmResources};
|
||||
/// use rust_bert::pipelines::common::ModelType;
|
||||
///
|
||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// MarianModelResources::ENGLISH2ROMANCE,
|
||||
/// ));
|
||||
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// MarianConfigResources::ENGLISH2ROMANCE,
|
||||
/// ));
|
||||
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// MarianVocabResources::ENGLISH2ROMANCE,
|
||||
/// ));
|
||||
/// let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// MarianSpmResources::ENGLISH2ROMANCE,
|
||||
/// ));
|
||||
/// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE.iter().collect();
|
||||
/// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE.iter().collect();
|
||||
///
|
||||
/// let translation_config = TranslationConfig::new(
|
||||
/// ModelType::Marian,
|
||||
/// model_resource,
|
||||
/// config_resource,
|
||||
/// vocab_resource,
|
||||
/// merges_resource,
|
||||
/// source_languages,
|
||||
/// target_languages,
|
||||
/// device: Device::cuda_if_available(),
|
||||
/// );
|
||||
/// let model = TranslationModel::new(translation_config)?;
|
||||
///
|
||||
/// let input = ["This is a sentence to be translated"];
|
||||
/// let source_language = None;
|
||||
/// let target_language = Language::French;
|
||||
///
|
||||
/// let output = model.translate(&input, source_language, target_language);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn translate<'a, S>(
|
||||
&self,
|
||||
texts: S,
|
||||
source_language: impl Into<Option<Language>>,
|
||||
target_language: impl Into<Option<Language>>,
|
||||
) -> Result<Vec<String>, RustBertError>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id(
|
||||
source_language.into().as_ref(),
|
||||
target_language.into().as_ref(),
|
||||
&self.supported_source_languages,
|
||||
&self.supported_target_languages,
|
||||
)?;
|
||||
|
||||
Ok(match prefix {
|
||||
Some(value) => {
|
||||
let texts = texts
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|&v| format!("{}{}", value, v))
|
||||
.collect::<Vec<String>>();
|
||||
self.model.generate(
|
||||
Some(texts.iter().map(AsRef::as_ref).collect::<Vec<&str>>()),
|
||||
None,
|
||||
forced_bos_token_id,
|
||||
)
|
||||
}
|
||||
None => self.model.generate(Some(texts), None, forced_bos_token_id),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::marian::{
|
||||
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
|
||||
MarianVocabResources,
|
||||
};
|
||||
use crate::resources::RemoteResource;
|
||||
|
||||
#[test]
|
||||
#[ignore] // no need to run, compilation is enough to verify it is Send
|
||||
fn test() {
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
));
|
||||
|
||||
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
|
||||
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::Marian,
|
||||
model_resource,
|
||||
config_resource,
|
||||
vocab_resource.clone(),
|
||||
vocab_resource,
|
||||
source_languages,
|
||||
target_languages,
|
||||
Device::cuda_if_available(),
|
||||
);
|
||||
let _: Box<dyn Send> = Box::new(TranslationModel::new(translation_config));
|
||||
}
|
||||
}
|
@ -1,8 +1,9 @@
|
||||
use rust_bert::m2m_100::{
|
||||
M2M100Config, M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100Model,
|
||||
M2M100ModelResources, M2M100VocabResources,
|
||||
M2M100Config, M2M100ConfigResources, M2M100MergesResources, M2M100Model, M2M100ModelResources,
|
||||
M2M100SourceLanguages, M2M100TargetLanguages, M2M100VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{M2M100Tokenizer, Tokenizer, TruncationStrategy};
|
||||
@ -75,43 +76,48 @@ fn m2m100_lm_model() -> anyhow::Result<()> {
|
||||
|
||||
#[test]
|
||||
fn m2m100_translation() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: 56,
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100ModelResources::M2M100_418M,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100ConfigResources::M2M100_418M,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100VocabResources::M2M100_418M,
|
||||
)),
|
||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100MergesResources::M2M100_418M,
|
||||
)),
|
||||
do_sample: false,
|
||||
num_beams: 3,
|
||||
..Default::default()
|
||||
};
|
||||
let model = M2M100Generator::new(generate_config)?;
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100ModelResources::M2M100_418M,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100ConfigResources::M2M100_418M,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100VocabResources::M2M100_418M,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100MergesResources::M2M100_418M,
|
||||
));
|
||||
|
||||
let input_context = ">>en.<< The dog did not wake up.";
|
||||
let target_language = model.get_tokenizer().convert_tokens_to_ids([">>es.<<"])[0];
|
||||
let source_languages = M2M100SourceLanguages::M2M100_418M;
|
||||
let target_languages = M2M100TargetLanguages::M2M100_418M;
|
||||
|
||||
let output = model.generate(
|
||||
Some(&[input_context]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
target_language,
|
||||
None,
|
||||
false,
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::M2M100,
|
||||
model_resource,
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource,
|
||||
source_languages,
|
||||
target_languages,
|
||||
Device::cuda_if_available(),
|
||||
);
|
||||
let model = TranslationModel::new(translation_config)?;
|
||||
|
||||
assert_eq!(output.len(), 1);
|
||||
assert_eq!(output[0].text, ">>es.<< El perro no se despertó.");
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
|
||||
|
||||
assert_eq!(outputs.len(), 3);
|
||||
assert_eq!(
|
||||
outputs[0],
|
||||
" Cette phrase sera traduite en plusieurs langues."
|
||||
);
|
||||
assert_eq!(outputs[1], " Esta frase se traducirá en varios idiomas.");
|
||||
assert_eq!(outputs[2], " यह वाक्यांश कई भाषाओं में अनुवादित किया जाएगा।");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,24 +1,82 @@
|
||||
use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
|
||||
use rust_bert::marian::{
|
||||
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
|
||||
MarianTargetLanguages, MarianVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::translation::{
|
||||
Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use tch::Device;
|
||||
|
||||
#[test]
|
||||
// #[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn test_translation() -> anyhow::Result<()> {
|
||||
// Set-up translation model
|
||||
let translation_config = TranslationConfig::new(OldLanguage::EnglishToFrench, Device::Cpu);
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
));
|
||||
|
||||
let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
|
||||
let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::Marian,
|
||||
model_resource,
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource,
|
||||
source_languages,
|
||||
target_languages,
|
||||
Device::cuda_if_available(),
|
||||
);
|
||||
let model = TranslationModel::new(translation_config)?;
|
||||
|
||||
let input_context_1 = "The quick brown fox jumps over the lazy dog";
|
||||
let input_context_2 = "The dog did not wake up";
|
||||
|
||||
let output = model.translate(&[input_context_1, input_context_2]);
|
||||
let outputs = model.translate(&[input_context_1, input_context_2], None, Language::French)?;
|
||||
|
||||
assert_eq!(output.len(), 2);
|
||||
assert_eq!(outputs.len(), 2);
|
||||
assert_eq!(
|
||||
output[0],
|
||||
outputs[0],
|
||||
" Le rapide renard brun saute sur le chien paresseux"
|
||||
);
|
||||
assert_eq!(output[1], " Le chien ne s'est pas réveillé");
|
||||
assert_eq!(outputs[1], " Le chien ne s'est pas réveillé");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
// #[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn test_translation_builder() -> anyhow::Result<()> {
|
||||
let model = TranslationModelBuilder::new()
|
||||
.with_device(Device::cuda_if_available())
|
||||
.with_model_type(ModelType::Marian)
|
||||
.with_source_languages(vec![Language::English])
|
||||
.with_target_languages(vec![Language::French])
|
||||
.create_model()?;
|
||||
|
||||
let input_context_1 = "The quick brown fox jumps over the lazy dog";
|
||||
let input_context_2 = "The dog did not wake up";
|
||||
|
||||
let outputs = model.translate(&[input_context_1, input_context_2], None, Language::French)?;
|
||||
|
||||
assert_eq!(outputs.len(), 2);
|
||||
assert_eq!(
|
||||
outputs[0],
|
||||
" Le rapide renard brun saute sur le chien paresseux"
|
||||
);
|
||||
assert_eq!(outputs[1], " Le chien ne s'est pas réveillé");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
use rust_bert::mbart::{
|
||||
MBartConfig, MBartConfigResources, MBartGenerator, MBartModel, MBartModelResources,
|
||||
MBartVocabResources,
|
||||
MBartConfig, MBartConfigResources, MBartModel, MBartModelResources, MBartVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy};
|
||||
@ -65,46 +65,28 @@ fn mbart_lm_model() -> anyhow::Result<()> {
|
||||
|
||||
#[test]
|
||||
fn mbart_translation() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: 56,
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartModelResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
do_sample: false,
|
||||
num_beams: 3,
|
||||
..Default::default()
|
||||
};
|
||||
let model = MBartGenerator::new(generate_config)?;
|
||||
let model = TranslationModelBuilder::new()
|
||||
.with_device(Device::cuda_if_available())
|
||||
.with_model_type(ModelType::MBart)
|
||||
.create_model()?;
|
||||
|
||||
let input_context = "en_XX The quick brown fox jumps over the lazy dog.";
|
||||
let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0];
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let output = model.generate(
|
||||
Some(&[input_context]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
target_language,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
|
||||
|
||||
assert_eq!(output.len(), 1);
|
||||
assert_eq!(outputs.len(), 3);
|
||||
assert_eq!(
|
||||
output[0].text,
|
||||
"de_DE Der schnelle braune Fuchs springt über den faulen Hund."
|
||||
outputs[0],
|
||||
" Cette phrase sera traduite en plusieurs langues."
|
||||
);
|
||||
assert_eq!(
|
||||
outputs[1],
|
||||
" Esta frase será traducida en múltiples idiomas."
|
||||
);
|
||||
assert_eq!(outputs[2], " यह वाक्य कई भाषाओं में अनुवाद किया जाएगा.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
60
tests/t5.rs
60
tests/t5.rs
@ -1,31 +1,65 @@
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
||||
use tch::Device;
|
||||
|
||||
#[test]
|
||||
fn test_translation_t5() -> anyhow::Result<()> {
|
||||
// Set-up translation model
|
||||
let translation_config = TranslationConfig::new_from_resources(
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)),
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)),
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
|
||||
Some("translate English to French:".to_string()),
|
||||
Device::cuda_if_available(),
|
||||
let model_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
|
||||
|
||||
let source_languages = [
|
||||
Language::English,
|
||||
Language::French,
|
||||
Language::German,
|
||||
Language::Romanian,
|
||||
];
|
||||
let target_languages = [
|
||||
Language::English,
|
||||
Language::French,
|
||||
Language::German,
|
||||
Language::Romanian,
|
||||
];
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::T5,
|
||||
model_resource,
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource,
|
||||
source_languages,
|
||||
target_languages,
|
||||
Device::cuda_if_available(),
|
||||
);
|
||||
let model = TranslationModel::new(translation_config)?;
|
||||
|
||||
let input_context = "The quick brown fox jumps over the lazy dog.";
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let output = model.translate(&[input_context]);
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::German)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?);
|
||||
|
||||
assert_eq!(outputs.len(), 3);
|
||||
assert_eq!(
|
||||
output[0],
|
||||
" Le renard brun rapide saute au-dessus du chien paresseux."
|
||||
outputs[0],
|
||||
" Cette phrase sera traduite dans plusieurs langues."
|
||||
);
|
||||
assert_eq!(
|
||||
outputs[1],
|
||||
" Dieser Satz wird in mehreren Sprachen übersetzt."
|
||||
);
|
||||
assert_eq!(
|
||||
outputs[2],
|
||||
" Această frază va fi tradusă în mai multe limbi."
|
||||
);
|
||||
|
||||
Ok(())
|
||||
|
Loading…
Reference in New Issue
Block a user