mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Use of new language enum in TranslationModel
This commit is contained in:
parent
edb080f3a0
commit
1c375a817e
@ -4,13 +4,14 @@ extern crate criterion;
|
||||
use criterion::{black_box, Criterion};
|
||||
// use rust_bert::pipelines::common::ModelType;
|
||||
// use rust_bert::pipelines::translation::TranslationOption::{Marian, T5};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
|
||||
// use rust_bert::resources::{LocalResource, Resource};
|
||||
use std::time::{Duration, Instant};
|
||||
use tch::Device;
|
||||
|
||||
fn create_translation_model() -> TranslationModel {
|
||||
let config = TranslationConfig::new(Language::EnglishToFrenchV2, Device::cuda_if_available());
|
||||
let config =
|
||||
TranslationConfig::new(OldLanguage::EnglishToFrenchV2, Device::cuda_if_available());
|
||||
// let config = TranslationConfig::new_from_resources(
|
||||
// Resource::Local(LocalResource {
|
||||
// local_path: "E:/Coding/cache/rustbert/marian-mt-en-es/model.ot".into(),
|
||||
@ -46,7 +47,7 @@ fn translation_load_model(iters: u64) -> Duration {
|
||||
for _i in 0..iters {
|
||||
let start = Instant::now();
|
||||
let config =
|
||||
TranslationConfig::new(Language::EnglishToFrenchV2, Device::cuda_if_available());
|
||||
TranslationConfig::new(OldLanguage::EnglishToFrenchV2, Device::cuda_if_available());
|
||||
// let config = TranslationConfig::new_from_resources(
|
||||
// Resource::Local(LocalResource {
|
||||
// local_path: "E:/Coding/cache/rustbert/marian-mt-en-es/model.ot".into(),
|
||||
|
@ -13,12 +13,12 @@
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
|
||||
use tch::Device;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let translation_config =
|
||||
TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
|
||||
TranslationConfig::new(OldLanguage::EnglishToGerman, Device::cuda_if_available());
|
||||
let model = TranslationModel::new(translation_config)?;
|
||||
|
||||
let input_context_1 = "The quick brown fox jumps over the lazy dog";
|
||||
|
@ -167,10 +167,10 @@
|
||||
//! ```no_run
|
||||
//! # fn main() -> anyhow::Result<()> {
|
||||
//! # use rust_bert::pipelines::generation_utils::LanguageGenerator;
|
||||
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
//! use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
|
||||
//! use tch::Device;
|
||||
//! let translation_config =
|
||||
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||
//! TranslationConfig::new(OldLanguage::EnglishToFrench, Device::cuda_if_available());
|
||||
//! let mut model = TranslationModel::new(translation_config)?;
|
||||
//!
|
||||
//! let input = ["This is a sentence to be translated"];
|
||||
|
@ -19,6 +19,7 @@ use crate::pipelines::generation_utils::private_generation_utils::{
|
||||
use crate::pipelines::generation_utils::{
|
||||
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
|
||||
};
|
||||
use crate::pipelines::translation::Language;
|
||||
use crate::{Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::{MarianTokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::MarianVocab;
|
||||
@ -41,6 +42,12 @@ pub struct MarianSpmResources;
|
||||
/// # Marian optional prefixes
|
||||
pub struct MarianPrefix;
|
||||
|
||||
/// # Marian source languages pre-sets
|
||||
pub struct MarianSourceLanguages;
|
||||
|
||||
/// # Marian target languages pre-sets
|
||||
pub struct MarianTargetLanguages;
|
||||
|
||||
impl MarianModelResources {
|
||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
|
||||
@ -487,6 +494,68 @@ impl MarianPrefix {
|
||||
pub const HEBREW2ENGLISH: Option<&'static str> = None;
|
||||
}
|
||||
|
||||
impl MarianSourceLanguages {
|
||||
pub const ENGLISH2ROMANCE: [Language; 1] = [Language::English];
|
||||
pub const ENGLISH2GERMAN: [Language; 1] = [Language::English];
|
||||
pub const ENGLISH2RUSSIAN: [Language; 1] = [Language::English];
|
||||
pub const ENGLISH2DUTCH: [Language; 1] = [Language::English];
|
||||
pub const ENGLISH2CHINESE: [Language; 1] = [Language::English];
|
||||
pub const ENGLISH2SWEDISH: [Language; 1] = [Language::English];
|
||||
pub const ENGLISH2ARABIC: [Language; 1] = [Language::English];
|
||||
pub const ENGLISH2HINDI: [Language; 1] = [Language::English];
|
||||
pub const ENGLISH2HEBREW: [Language; 1] = [Language::English];
|
||||
pub const ROMANCE2ENGLISH: [Language; 7] = [
|
||||
Language::French,
|
||||
Language::Spanish,
|
||||
Language::Italian,
|
||||
Language::Catalan,
|
||||
Language::Romanian,
|
||||
Language::Portuguese,
|
||||
Language::Occitan,
|
||||
];
|
||||
pub const GERMAN2ENGLISH: [Language; 1] = [Language::German];
|
||||
pub const RUSSIAN2ENGLISH: [Language; 1] = [Language::Russian];
|
||||
pub const DUTCH2ENGLISH: [Language; 1] = [Language::Dutch];
|
||||
pub const CHINESE2ENGLISH: [Language; 1] = [Language::ChineseMandarin];
|
||||
pub const SWEDISH2ENGLISH: [Language; 1] = [Language::Swedish];
|
||||
pub const ARABIC2ENGLISH: [Language; 1] = [Language::Arabic];
|
||||
pub const HINDI2ENGLISH: [Language; 1] = [Language::Hindi];
|
||||
pub const HEBREW2ENGLISH: [Language; 1] = [Language::Hebrew];
|
||||
pub const FRENCH2GERMAN: [Language; 1] = [Language::French];
|
||||
pub const GERMAN2FRENCH: [Language; 1] = [Language::German];
|
||||
}
|
||||
|
||||
impl MarianTargetLanguages {
|
||||
pub const ENGLISH2ROMANCE: [Language; 7] = [
|
||||
Language::French,
|
||||
Language::Spanish,
|
||||
Language::Italian,
|
||||
Language::Catalan,
|
||||
Language::Romanian,
|
||||
Language::Portuguese,
|
||||
Language::Occitan,
|
||||
];
|
||||
pub const ENGLISH2GERMAN: [Language; 1] = [Language::German];
|
||||
pub const ENGLISH2RUSSIAN: [Language; 1] = [Language::Russian];
|
||||
pub const ENGLISH2DUTCH: [Language; 1] = [Language::Dutch];
|
||||
pub const ENGLISH2CHINESE: [Language; 1] = [Language::ChineseMandarin];
|
||||
pub const ENGLISH2SWEDISH: [Language; 1] = [Language::Swedish];
|
||||
pub const ENGLISH2ARABIC: [Language; 1] = [Language::Arabic];
|
||||
pub const ENGLISH2HINDI: [Language; 1] = [Language::Hindi];
|
||||
pub const ENGLISH2HEBREW: [Language; 1] = [Language::Hebrew];
|
||||
pub const ROMANCE2ENGLISH: [Language; 1] = [Language::English];
|
||||
pub const GERMAN2ENGLISH: [Language; 1] = [Language::English];
|
||||
pub const RUSSIAN2ENGLISH: [Language; 1] = [Language::English];
|
||||
pub const DUTCH2ENGLISH: [Language; 1] = [Language::English];
|
||||
pub const CHINESE2ENGLISH: [Language; 1] = [Language::English];
|
||||
pub const SWEDISH2ENGLISH: [Language; 1] = [Language::English];
|
||||
pub const ARABIC2ENGLISH: [Language; 1] = [Language::English];
|
||||
pub const HINDI2ENGLISH: [Language; 1] = [Language::English];
|
||||
pub const HEBREW2ENGLISH: [Language; 1] = [Language::English];
|
||||
pub const FRENCH2GERMAN: [Language; 1] = [Language::German];
|
||||
pub const GERMAN2FRENCH: [Language; 1] = [Language::French];
|
||||
}
|
||||
|
||||
/// # Marian Model for conditional generation
|
||||
/// Marian model with a vocabulary decoding head
|
||||
/// It is made of the following blocks:
|
||||
|
@ -61,5 +61,6 @@ mod marian_model;
|
||||
|
||||
pub use marian_model::{
|
||||
MarianConfigResources, MarianForConditionalGeneration, MarianGenerator, MarianModelResources,
|
||||
MarianPrefix, MarianSpmResources, MarianVocabResources,
|
||||
MarianPrefix, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages,
|
||||
MarianVocabResources,
|
||||
};
|
||||
|
@ -56,10 +56,10 @@
|
||||
//! ```no_run
|
||||
//! # fn main() -> anyhow::Result<()> {
|
||||
//! # use rust_bert::pipelines::generation_utils::LanguageGenerator;
|
||||
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
//! use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
|
||||
//! use tch::Device;
|
||||
//! let translation_config =
|
||||
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||
//! TranslationConfig::new(OldLanguage::EnglishToFrench, Device::cuda_if_available());
|
||||
//! let mut model = TranslationModel::new(translation_config)?;
|
||||
//!
|
||||
//! let input = ["This is a sentence to be translated"];
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -56,5 +56,5 @@ mod t5_model;
|
||||
pub use attention::LayerState;
|
||||
pub use t5_model::{
|
||||
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Generator, T5Model, T5ModelOutput,
|
||||
T5ModelResources, T5Prefix, T5VocabResources,
|
||||
T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages, T5VocabResources,
|
||||
};
|
||||
|
@ -27,6 +27,7 @@ use crate::pipelines::generation_utils::private_generation_utils::{
|
||||
use crate::pipelines::generation_utils::{
|
||||
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
|
||||
};
|
||||
use crate::pipelines::translation::Language;
|
||||
use crate::t5::attention::LayerState;
|
||||
use crate::t5::encoder::T5Stack;
|
||||
|
||||
@ -42,6 +43,12 @@ pub struct T5VocabResources;
|
||||
/// # T5 optional prefixes
|
||||
pub struct T5Prefix;
|
||||
|
||||
/// # T5 source languages pre-sets
|
||||
pub struct T5SourceLanguages;
|
||||
|
||||
/// # T5 target languages pre-sets
|
||||
pub type T5TargetLanguages = T5SourceLanguages;
|
||||
|
||||
impl T5ModelResources {
|
||||
/// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format.
|
||||
pub const T5_SMALL: (&'static str, &'static str) = (
|
||||
@ -81,6 +88,13 @@ impl T5VocabResources {
|
||||
);
|
||||
}
|
||||
|
||||
const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German];
|
||||
|
||||
impl T5SourceLanguages {
|
||||
pub const T5_SMALL: [Language; 3] = T5LANGUAGES;
|
||||
pub const T5_BASE: [Language; 3] = T5LANGUAGES;
|
||||
}
|
||||
|
||||
impl T5Prefix {
|
||||
pub const ENGLISH2FRENCH: Option<&'static str> = Some("translate English to French:");
|
||||
pub const ENGLISH2GERMAN: Option<&'static str> = Some("translate English to German:");
|
||||
|
@ -1,11 +1,11 @@
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
|
||||
use tch::Device;
|
||||
|
||||
#[test]
|
||||
// #[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn test_translation() -> anyhow::Result<()> {
|
||||
// Set-up translation model
|
||||
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
|
||||
let translation_config = TranslationConfig::new(OldLanguage::EnglishToFrench, Device::Cpu);
|
||||
let model = TranslationModel::new(translation_config)?;
|
||||
|
||||
let input_context_1 = "The quick brown fox jumps over the lazy dog";
|
||||
|
Loading…
Reference in New Issue
Block a user