Use of new language enum in TranslationModel

This commit is contained in:
Guillaume B 2021-07-04 12:56:34 +02:00
parent edb080f3a0
commit 1c375a817e
10 changed files with 527 additions and 655 deletions

View File

@ -4,13 +4,14 @@ extern crate criterion;
use criterion::{black_box, Criterion};
// use rust_bert::pipelines::common::ModelType;
// use rust_bert::pipelines::translation::TranslationOption::{Marian, T5};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
// use rust_bert::resources::{LocalResource, Resource};
use std::time::{Duration, Instant};
use tch::Device;
fn create_translation_model() -> TranslationModel {
let config = TranslationConfig::new(Language::EnglishToFrenchV2, Device::cuda_if_available());
let config =
TranslationConfig::new(OldLanguage::EnglishToFrenchV2, Device::cuda_if_available());
// let config = TranslationConfig::new_from_resources(
// Resource::Local(LocalResource {
// local_path: "E:/Coding/cache/rustbert/marian-mt-en-es/model.ot".into(),
@ -46,7 +47,7 @@ fn translation_load_model(iters: u64) -> Duration {
for _i in 0..iters {
let start = Instant::now();
let config =
TranslationConfig::new(Language::EnglishToFrenchV2, Device::cuda_if_available());
TranslationConfig::new(OldLanguage::EnglishToFrenchV2, Device::cuda_if_available());
// let config = TranslationConfig::new_from_resources(
// Resource::Local(LocalResource {
// local_path: "E:/Coding/cache/rustbert/marian-mt-en-es/model.ot".into(),

View File

@ -13,12 +13,12 @@
extern crate anyhow;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
use tch::Device;
fn main() -> anyhow::Result<()> {
let translation_config =
TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
TranslationConfig::new(OldLanguage::EnglishToGerman, Device::cuda_if_available());
let model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";

View File

@ -167,10 +167,10 @@
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! # use rust_bert::pipelines::generation_utils::LanguageGenerator;
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
//! use tch::Device;
//! let translation_config =
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! TranslationConfig::new(OldLanguage::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//!
//! let input = ["This is a sentence to be translated"];

View File

@ -19,6 +19,7 @@ use crate::pipelines::generation_utils::private_generation_utils::{
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::pipelines::translation::Language;
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{MarianTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::MarianVocab;
@ -41,6 +42,12 @@ pub struct MarianSpmResources;
/// # Marian optional prefixes
pub struct MarianPrefix;
/// # Marian source languages pre-sets
///
/// Marker type used purely as a namespace: its associated constants are fixed-size
/// arrays of `Language` values, one constant per Marian model pre-set.
pub struct MarianSourceLanguages;
/// # Marian target languages pre-sets
///
/// Marker type used purely as a namespace: its associated constants are fixed-size
/// arrays of `Language` values, using the same pre-set names as `MarianSourceLanguages`.
pub struct MarianTargetLanguages;
impl MarianModelResources {
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
@ -487,6 +494,68 @@ impl MarianPrefix {
pub const HEBREW2ENGLISH: Option<&'static str> = None;
}
impl MarianSourceLanguages {
pub const ENGLISH2ROMANCE: [Language; 1] = [Language::English];
pub const ENGLISH2GERMAN: [Language; 1] = [Language::English];
pub const ENGLISH2RUSSIAN: [Language; 1] = [Language::English];
pub const ENGLISH2DUTCH: [Language; 1] = [Language::English];
pub const ENGLISH2CHINESE: [Language; 1] = [Language::English];
pub const ENGLISH2SWEDISH: [Language; 1] = [Language::English];
pub const ENGLISH2ARABIC: [Language; 1] = [Language::English];
pub const ENGLISH2HINDI: [Language; 1] = [Language::English];
pub const ENGLISH2HEBREW: [Language; 1] = [Language::English];
pub const ROMANCE2ENGLISH: [Language; 7] = [
Language::French,
Language::Spanish,
Language::Italian,
Language::Catalan,
Language::Romanian,
Language::Portuguese,
Language::Occitan,
];
pub const GERMAN2ENGLISH: [Language; 1] = [Language::German];
pub const RUSSIAN2ENGLISH: [Language; 1] = [Language::Russian];
pub const DUTCH2ENGLISH: [Language; 1] = [Language::Dutch];
pub const CHINESE2ENGLISH: [Language; 1] = [Language::ChineseMandarin];
pub const SWEDISH2ENGLISH: [Language; 1] = [Language::Swedish];
pub const ARABIC2ENGLISH: [Language; 1] = [Language::Arabic];
pub const HINDI2ENGLISH: [Language; 1] = [Language::Hindi];
pub const HEBREW2ENGLISH: [Language; 1] = [Language::Hebrew];
pub const FRENCH2GERMAN: [Language; 1] = [Language::French];
pub const GERMAN2FRENCH: [Language; 1] = [Language::German];
}
impl MarianTargetLanguages {
pub const ENGLISH2ROMANCE: [Language; 7] = [
Language::French,
Language::Spanish,
Language::Italian,
Language::Catalan,
Language::Romanian,
Language::Portuguese,
Language::Occitan,
];
pub const ENGLISH2GERMAN: [Language; 1] = [Language::German];
pub const ENGLISH2RUSSIAN: [Language; 1] = [Language::Russian];
pub const ENGLISH2DUTCH: [Language; 1] = [Language::Dutch];
pub const ENGLISH2CHINESE: [Language; 1] = [Language::ChineseMandarin];
pub const ENGLISH2SWEDISH: [Language; 1] = [Language::Swedish];
pub const ENGLISH2ARABIC: [Language; 1] = [Language::Arabic];
pub const ENGLISH2HINDI: [Language; 1] = [Language::Hindi];
pub const ENGLISH2HEBREW: [Language; 1] = [Language::Hebrew];
pub const ROMANCE2ENGLISH: [Language; 1] = [Language::English];
pub const GERMAN2ENGLISH: [Language; 1] = [Language::English];
pub const RUSSIAN2ENGLISH: [Language; 1] = [Language::English];
pub const DUTCH2ENGLISH: [Language; 1] = [Language::English];
pub const CHINESE2ENGLISH: [Language; 1] = [Language::English];
pub const SWEDISH2ENGLISH: [Language; 1] = [Language::English];
pub const ARABIC2ENGLISH: [Language; 1] = [Language::English];
pub const HINDI2ENGLISH: [Language; 1] = [Language::English];
pub const HEBREW2ENGLISH: [Language; 1] = [Language::English];
pub const FRENCH2GERMAN: [Language; 1] = [Language::German];
pub const GERMAN2FRENCH: [Language; 1] = [Language::French];
}
/// # Marian Model for conditional generation
/// Marian model with a vocabulary decoding head
/// It is made of the following blocks:

View File

@ -61,5 +61,6 @@ mod marian_model;
pub use marian_model::{
MarianConfigResources, MarianForConditionalGeneration, MarianGenerator, MarianModelResources,
MarianPrefix, MarianSpmResources, MarianVocabResources,
MarianPrefix, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages,
MarianVocabResources,
};

View File

@ -56,10 +56,10 @@
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! # use rust_bert::pipelines::generation_utils::LanguageGenerator;
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
//! use tch::Device;
//! let translation_config =
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! TranslationConfig::new(OldLanguage::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//!
//! let input = ["This is a sentence to be translated"];

File diff suppressed because it is too large Load Diff

View File

@ -56,5 +56,5 @@ mod t5_model;
pub use attention::LayerState;
pub use t5_model::{
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Generator, T5Model, T5ModelOutput,
T5ModelResources, T5Prefix, T5VocabResources,
T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages, T5VocabResources,
};

View File

@ -27,6 +27,7 @@ use crate::pipelines::generation_utils::private_generation_utils::{
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::pipelines::translation::Language;
use crate::t5::attention::LayerState;
use crate::t5::encoder::T5Stack;
@ -42,6 +43,12 @@ pub struct T5VocabResources;
/// # T5 optional prefixes
pub struct T5Prefix;
/// # T5 source languages pre-sets
///
/// Marker type used purely as a namespace for associated constants holding
/// fixed-size arrays of `Language` values, one constant per T5 model pre-set.
pub struct T5SourceLanguages;
/// # T5 target languages pre-sets
///
/// Type alias of `T5SourceLanguages`: the target pre-sets are the very same constants
/// as the source pre-sets, so both names resolve to identical language lists.
pub type T5TargetLanguages = T5SourceLanguages;
impl T5ModelResources {
/// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format.
pub const T5_SMALL: (&'static str, &'static str) = (
@ -81,6 +88,13 @@ impl T5VocabResources {
);
}
// Language list shared by every T5 pre-set below: English, French and German.
const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German];
impl T5SourceLanguages {
/// Languages for the `T5_SMALL` pre-set (the shared `T5LANGUAGES` list).
pub const T5_SMALL: [Language; 3] = T5LANGUAGES;
/// Languages for the `T5_BASE` pre-set (the shared `T5LANGUAGES` list).
pub const T5_BASE: [Language; 3] = T5LANGUAGES;
}
impl T5Prefix {
pub const ENGLISH2FRENCH: Option<&'static str> = Some("translate English to French:");
pub const ENGLISH2GERMAN: Option<&'static str> = Some("translate English to German:");

View File

@ -1,11 +1,11 @@
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
use tch::Device;
#[test]
// #[cfg_attr(not(feature = "all-tests"), ignore)]
fn test_translation() -> anyhow::Result<()> {
// Set-up translation model
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
let translation_config = TranslationConfig::new(OldLanguage::EnglishToFrench, Device::Cpu);
let model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";