Added integration tests for T5 and dependencies download

This commit is contained in:
Guillaume B 2020-07-07 19:29:45 +02:00
parent 1e865ef6eb
commit 154fe2a4d0
8 changed files with 267 additions and 30 deletions

View File

@ -20,6 +20,7 @@ use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
/// This example downloads and caches all dependencies used in model tests. This allows for safe
/// multi-threaded testing (two tests using the same resource would otherwise download the file to
@ -297,6 +298,20 @@ fn _download_dialogpt() -> failure::Fallible<()> {
Ok(())
}
/// Downloads and caches the T5-small configuration, vocabulary and weights (in that order).
///
/// Returns the first download error encountered, if any.
fn download_t5_small() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
    let resources = [
        Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)),
        Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
        Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)),
    ];
    // Only the caching side effect matters here; the returned path is discarded.
    for resource in resources.iter() {
        download_resource(resource)?;
    }
    Ok(())
}
fn main() -> failure::Fallible<()> {
let _ = download_distil_gpt2();
let _ = download_distilbert_sst2();
@ -312,6 +327,7 @@ fn main() -> failure::Fallible<()> {
let _ = download_electra_generator();
let _ = download_electra_discriminator();
let _ = download_albert_base_v2();
let _ = download_t5_small();
Ok(())
}

View File

@ -17,11 +17,15 @@
//! pre-processing, forward pass and postprocessing differs between pipelines while basic config and
//! tokenization objects don't.
//!
use crate::bart::BartConfig;
use crate::bert::BertConfig;
use crate::distilbert::DistilBertConfig;
use crate::electra::ElectraConfig;
use crate::t5::T5Config;
use crate::Config;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
use rust_tokenizers::preprocessing::tokenizer::t5_tokenizer::T5Tokenizer;
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@ -34,6 +38,8 @@ pub enum ModelType {
DistilBert,
Roberta,
Electra,
Marian,
T5,
}
/// # Abstraction that holds a model configuration, can be of any of the supported models
@ -44,6 +50,10 @@ pub enum ConfigOption {
DistilBert(DistilBertConfig),
/// Electra configuration
Electra(ElectraConfig),
/// Marian configuration
Marian(BartConfig),
/// T5 configuration
T5(T5Config),
}
/// # Abstraction that holds a particular tokenizer, can be of any of the supported models
@ -52,6 +62,10 @@ pub enum TokenizerOption {
Bert(BertTokenizer),
/// Roberta Tokenizer
Roberta(RobertaTokenizer),
/// Marian Tokenizer
Marian(MarianTokenizer),
/// T5 Tokenizer
T5(T5Tokenizer),
}
impl ConfigOption {
@ -61,6 +75,8 @@ impl ConfigOption {
ModelType::Bert | ModelType::Roberta => ConfigOption::Bert(BertConfig::from_file(path)),
ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
ModelType::Marian => ConfigOption::Marian(BartConfig::from_file(path)),
ModelType::T5 => ConfigOption::T5(T5Config::from_file(path)),
}
}
@ -75,6 +91,10 @@ impl ConfigOption {
Self::Electra(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Marian(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::T5(_) => panic!("T5 does not use a label mapping"),
}
}
}
@ -96,6 +116,12 @@ impl TokenizerOption {
merges_path.expect("No merges specified!"),
lower_case,
)),
ModelType::Marian => TokenizerOption::Marian(MarianTokenizer::from_files(
vocab_path,
merges_path.expect("No merges specified!"),
lower_case,
)),
ModelType::T5 => TokenizerOption::T5(T5Tokenizer::from_file(vocab_path, lower_case)),
}
}
@ -104,6 +130,8 @@ impl TokenizerOption {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::Marian(_) => ModelType::Marian,
Self::T5(_) => ModelType::T5,
}
}
@ -122,6 +150,12 @@ impl TokenizerOption {
Self::Roberta(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
Self::Marian(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
Self::T5(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
}
}
}

View File

@ -212,6 +212,12 @@ impl SequenceClassificationOption {
ModelType::Electra => {
panic!("SequenceClassification not implemented for Electra!");
}
ModelType::Marian => {
panic!("SequenceClassification not implemented for Marian!");
}
ModelType::T5 => {
panic!("SequenceClassification not implemented for T5!");
}
}
}

View File

@ -330,6 +330,12 @@ impl TokenClassificationOption {
panic!("You can only supply a BertConfig for Roberta!");
}
}
ModelType::Marian => {
panic!("TokenClassification not implemented for Marian!");
}
ModelType::T5 => {
panic!("TokenClassification not implemented for T5!");
}
}
}
@ -573,6 +579,12 @@ impl TokenClassificationModel {
TokenizerOption::Roberta(ref tokenizer) => {
Tokenizer::decode(tokenizer, vec![token_id], false, false)
}
TokenizerOption::Marian(_) => {
panic!("TokenClassification not implemented for Marian!");
}
TokenizerOption::T5(_) => {
panic!("TokenClassification not implemented for T5!");
}
},
Some(offsets) => {
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);

View File

@ -59,8 +59,12 @@ use crate::marian::{
MarianConfigResources, MarianModelResources, MarianPrefix, MarianSpmResources,
MarianVocabResources,
};
use crate::pipelines::generation::{GenerateConfig, LanguageGenerator, MarianGenerator};
use tch::Device;
use crate::pipelines::common::ModelType;
use crate::pipelines::generation::{
GenerateConfig, LanguageGenerator, MarianGenerator, T5Generator,
};
use crate::t5::{T5ConfigResources, T5ModelResources, T5Prefix, T5VocabResources};
use tch::{Device, Tensor};
/// Pretrained languages available for direct use
pub enum Language {
@ -80,6 +84,8 @@ pub enum Language {
EnglishToRomanian,
EnglishToGerman,
EnglishToRussian,
EnglishToFrenchV2,
EnglishToGermanV2,
FrenchToGerman,
GermanToFrench,
}
@ -93,12 +99,44 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2FRENCH,
ModelType::Marian,
);
/// English to French translation resources backed by the T5-base model.
///
/// Tuple layout, in the order destructured by `TranslationConfig::new`:
/// model weights, configuration, vocabulary, merges resource, optional input prefix, model type.
/// The T5 vocabulary resource is supplied for both the vocabulary and the merges slots.
pub const ENGLISH2FRENCH_V2: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
T5ModelResources::T5_BASE,
T5ConfigResources::T5_BASE,
T5VocabResources::T5_BASE,
// Same vocabulary resource reused for the merges slot
T5VocabResources::T5_BASE,
T5Prefix::ENGLISH2FRENCH,
ModelType::T5,
);
/// English to German translation resources backed by the T5-base model.
///
/// Tuple layout, in the order destructured by `TranslationConfig::new`:
/// model weights, configuration, vocabulary, merges resource, optional input prefix, model type.
/// The T5 vocabulary resource is supplied for both the vocabulary and the merges slots.
pub const ENGLISH2GERMAN_V2: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
T5ModelResources::T5_BASE,
T5ConfigResources::T5_BASE,
T5VocabResources::T5_BASE,
// Same vocabulary resource reused for the merges slot
T5VocabResources::T5_BASE,
T5Prefix::ENGLISH2GERMAN,
ModelType::T5,
);
pub const ENGLISH2CATALAN: (
(&'static str, &'static str),
@ -106,12 +144,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2CATALAN,
ModelType::Marian,
);
pub const ENGLISH2SPANISH: (
(&'static str, &'static str),
@ -119,12 +159,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2SPANISH,
ModelType::Marian,
);
pub const ENGLISH2PORTUGUESE: (
(&'static str, &'static str),
@ -132,12 +174,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2PORTUGUESE,
ModelType::Marian,
);
pub const ENGLISH2ITALIAN: (
(&'static str, &'static str),
@ -145,12 +189,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ITALIAN,
ModelType::Marian,
);
pub const ENGLISH2ROMANIAN: (
(&'static str, &'static str),
@ -158,12 +204,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ROMANIAN,
ModelType::Marian,
);
pub const ENGLISH2GERMAN: (
(&'static str, &'static str),
@ -171,12 +219,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2GERMAN,
MarianConfigResources::ENGLISH2GERMAN,
MarianVocabResources::ENGLISH2GERMAN,
MarianSpmResources::ENGLISH2GERMAN,
MarianPrefix::ENGLISH2GERMAN,
ModelType::Marian,
);
pub const ENGLISH2RUSSIAN: (
(&'static str, &'static str),
@ -184,12 +234,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2RUSSIAN,
MarianConfigResources::ENGLISH2RUSSIAN,
MarianVocabResources::ENGLISH2RUSSIAN,
MarianSpmResources::ENGLISH2RUSSIAN,
MarianPrefix::ENGLISH2RUSSIAN,
ModelType::Marian,
);
pub const FRENCH2ENGLISH: (
@ -198,12 +250,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::FRENCH2ENGLISH,
ModelType::Marian,
);
pub const CATALAN2ENGLISH: (
(&'static str, &'static str),
@ -211,12 +265,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::CATALAN2ENGLISH,
ModelType::Marian,
);
pub const SPANISH2ENGLISH: (
(&'static str, &'static str),
@ -224,12 +280,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::SPANISH2ENGLISH,
ModelType::Marian,
);
pub const PORTUGUESE2ENGLISH: (
(&'static str, &'static str),
@ -237,12 +295,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::PORTUGUESE2ENGLISH,
ModelType::Marian,
);
pub const ITALIAN2ENGLISH: (
(&'static str, &'static str),
@ -250,12 +310,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ITALIAN2ENGLISH,
ModelType::Marian,
);
pub const ROMANIAN2ENGLISH: (
(&'static str, &'static str),
@ -263,12 +325,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ROMANIAN2ENGLISH,
ModelType::Marian,
);
pub const GERMAN2ENGLISH: (
(&'static str, &'static str),
@ -276,12 +340,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::GERMAN2ENGLISH,
MarianConfigResources::GERMAN2ENGLISH,
MarianVocabResources::GERMAN2ENGLISH,
MarianSpmResources::GERMAN2ENGLISH,
MarianPrefix::GERMAN2ENGLISH,
ModelType::Marian,
);
pub const RUSSIAN2ENGLISH: (
(&'static str, &'static str),
@ -289,12 +355,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::RUSSIAN2ENGLISH,
MarianConfigResources::RUSSIAN2ENGLISH,
MarianVocabResources::RUSSIAN2ENGLISH,
MarianSpmResources::RUSSIAN2ENGLISH,
MarianPrefix::RUSSIAN2ENGLISH,
ModelType::Marian,
);
pub const FRENCH2GERMAN: (
@ -303,12 +371,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::FRENCH2GERMAN,
MarianConfigResources::FRENCH2GERMAN,
MarianVocabResources::FRENCH2GERMAN,
MarianSpmResources::FRENCH2GERMAN,
MarianPrefix::FRENCH2GERMAN,
ModelType::Marian,
);
pub const GERMAN2FRENCH: (
(&'static str, &'static str),
@ -316,12 +386,14 @@ impl RemoteTranslationResources {
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::GERMAN2FRENCH,
MarianConfigResources::GERMAN2FRENCH,
MarianVocabResources::GERMAN2FRENCH,
MarianSpmResources::GERMAN2FRENCH,
MarianPrefix::GERMAN2FRENCH,
ModelType::Marian,
);
}
@ -365,6 +437,8 @@ pub struct TranslationConfig {
pub device: Device,
/// Prefix to prepend to translation inputs
pub prefix: Option<String>,
/// Model type used for translation
pub model_type: ModelType,
}
impl TranslationConfig {
@ -388,7 +462,7 @@ impl TranslationConfig {
/// # }
/// ```
pub fn new(language: Language, device: Device) -> TranslationConfig {
let (model_resource, config_resource, vocab_resource, merges_resource, prefix) =
let (model_resource, config_resource, vocab_resource, merges_resource, prefix, model_type) =
match language {
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
@ -408,6 +482,9 @@ impl TranslationConfig {
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
Language::EnglishToFrenchV2 => RemoteTranslationResources::ENGLISH2FRENCH_V2,
Language::EnglishToGermanV2 => RemoteTranslationResources::ENGLISH2GERMAN_V2,
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
};
@ -438,6 +515,7 @@ impl TranslationConfig {
num_return_sequences: 1,
device,
prefix,
model_type,
}
}
@ -459,6 +537,7 @@ impl TranslationConfig {
/// use rust_bert::resources::{LocalResource, Resource};
/// use std::path::PathBuf;
/// use tch::Device;
/// use rust_bert::pipelines::common::ModelType;
///
/// let config_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
@ -480,6 +559,7 @@ impl TranslationConfig {
/// sentence_piece_resource,
/// Some(">>fr<<".to_string()),
/// Device::cuda_if_available(),
/// ModelType::Marian,
/// );
/// # Ok(())
/// # }
@ -491,6 +571,7 @@ impl TranslationConfig {
sentence_piece_resource: Resource,
prefix: Option<String>,
device: Device,
model_type: ModelType,
) -> TranslationConfig {
TranslationConfig {
model_resource,
@ -511,13 +592,84 @@ impl TranslationConfig {
num_return_sequences: 1,
device,
prefix,
model_type,
}
}
}
/// # Abstraction that holds one particular translation model, for any of the supported models
///
/// Each variant wraps the concrete generator for one model family, so that callers can
/// generate translations through a single type regardless of the underlying model.
pub enum TranslationOption {
/// Translator based on Marian model
Marian(MarianGenerator),
/// Translator based on T5 model
T5(T5Generator),
}
impl TranslationOption {
    /// Builds the translation generator selected by `config.model_type`.
    ///
    /// The resources, generation options and device of `config` are moved into a
    /// `GenerateConfig`, which is passed to the constructor of the matching generator.
    ///
    /// # Arguments
    ///
    /// * `config` - `TranslationConfig` holding the resources, generation options, device and
    ///   model type to use.
    ///
    /// # Panics
    ///
    /// * If `config.model_type` is neither `ModelType::Marian` nor `ModelType::T5`.
    /// * If the generator constructor returns an error (its result is unwrapped).
    pub fn new(config: TranslationConfig) -> Self {
        let generate_config = GenerateConfig {
            model_resource: config.model_resource,
            config_resource: config.config_resource,
            merges_resource: config.merges_resource,
            vocab_resource: config.vocab_resource,
            min_length: config.min_length,
            max_length: config.max_length,
            do_sample: config.do_sample,
            early_stopping: config.early_stopping,
            num_beams: config.num_beams,
            temperature: config.temperature,
            top_k: config.top_k,
            top_p: config.top_p,
            repetition_penalty: config.repetition_penalty,
            length_penalty: config.length_penalty,
            no_repeat_ngram_size: config.no_repeat_ngram_size,
            num_return_sequences: config.num_return_sequences,
            device: config.device,
        };

        // Every `ModelType` variant is listed explicitly (no wildcard arm) so that adding a
        // new variant forces this match to be revisited.
        match config.model_type {
            ModelType::Marian => {
                TranslationOption::Marian(MarianGenerator::new(generate_config).unwrap())
            }
            ModelType::T5 => TranslationOption::T5(T5Generator::new(generate_config).unwrap()),
            ModelType::Bert => {
                panic!("Translation not implemented for Bert!");
            }
            ModelType::DistilBert => {
                panic!("Translation not implemented for DistilBert!");
            }
            ModelType::Roberta => {
                panic!("Translation not implemented for Roberta!");
            }
            ModelType::Electra => {
                panic!("Translation not implemented for Electra!");
            }
        }
    }

    /// Returns the `ModelType` for this TranslationOption
    pub fn model_type(&self) -> ModelType {
        match self {
            Self::Marian(_) => ModelType::Marian,
            Self::T5(_) => ModelType::T5,
        }
    }

    /// Interface method to generate() of the particular models.
    ///
    /// Forwards `prompt_texts` and `attention_mask` unchanged to the wrapped generator and
    /// returns the generated strings.
    pub fn generate(
        &self,
        prompt_texts: Option<Vec<&str>>,
        attention_mask: Option<Tensor>,
    ) -> Vec<String> {
        match self {
            Self::Marian(model) => model.generate(prompt_texts, attention_mask),
            Self::T5(model) => model.generate(prompt_texts, attention_mask),
        }
    }
}
/// # TranslationModel to perform translation
pub struct TranslationModel {
model: MarianGenerator,
model: TranslationOption,
prefix: Option<String>,
}
@ -542,32 +694,10 @@ impl TranslationModel {
/// # }
/// ```
pub fn new(translation_config: TranslationConfig) -> failure::Fallible<TranslationModel> {
let generate_config = GenerateConfig {
model_resource: translation_config.model_resource,
config_resource: translation_config.config_resource,
merges_resource: translation_config.merges_resource,
vocab_resource: translation_config.vocab_resource,
min_length: translation_config.min_length,
max_length: translation_config.max_length,
do_sample: translation_config.do_sample,
early_stopping: translation_config.early_stopping,
num_beams: translation_config.num_beams,
temperature: translation_config.temperature,
top_k: translation_config.top_k,
top_p: translation_config.top_p,
repetition_penalty: translation_config.repetition_penalty,
length_penalty: translation_config.length_penalty,
no_repeat_ngram_size: translation_config.no_repeat_ngram_size,
num_return_sequences: translation_config.num_return_sequences,
device: translation_config.device,
};
let prefix = translation_config.prefix.clone();
let model = TranslationOption::new(translation_config);
let model = MarianGenerator::new(generate_config)?;
Ok(TranslationModel {
model,
prefix: translation_config.prefix,
})
Ok(TranslationModel { model, prefix })
}
/// Translates texts provided

View File

@ -5,6 +5,6 @@ mod t5;
pub use attention::LayerState;
pub use t5::{
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Model, T5ModelResources,
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Model, T5ModelResources, T5Prefix,
T5VocabResources,
};

View File

@ -27,6 +27,9 @@ pub struct T5ConfigResources;
/// # T5 Pretrained model vocab files
pub struct T5VocabResources;
/// # T5 optional prefixes
/// Namespace type whose associated constants hold the optional text prefixes used with T5
/// inputs (for example translation task prefixes).
pub struct T5Prefix;
impl T5ModelResources {
/// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format.
pub const T5_SMALL: (&'static str, &'static str) = (
@ -66,6 +69,11 @@ impl T5VocabResources {
);
}
impl T5Prefix {
/// Input prefix for the English to French translation task. Note the trailing space.
pub const ENGLISH2FRENCH: Option<&'static str> = Some("translate English to French: ");
/// Input prefix for the English to German translation task. Note the trailing space.
pub const ENGLISH2GERMAN: Option<&'static str> = Some("translate English to German: ");
}
#[derive(Debug, Serialize, Deserialize)]
/// # T5 model configuration
/// Defines the T5 model architecture (e.g. number of layers, hidden layer size, label mapping...)

31
tests/t5.rs Normal file
View File

@ -0,0 +1,31 @@
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
use tch::Device;
#[test]
fn test_translation_t5() -> failure::Fallible<()> {
    // Remote T5-small resources; the vocabulary is supplied twice, once for each of the
    // two vocabulary-related arguments of `new_from_resources`.
    let weights = Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
    let config = Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
    let vocab = Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
    let sentence_piece =
        Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));

    // Assemble the translation pipeline
    let translation_config = TranslationConfig::new_from_resources(
        weights,
        config,
        vocab,
        sentence_piece,
        Some(String::from("translate English to French: ")),
        Device::cuda_if_available(),
        ModelType::T5,
    );
    let translator = TranslationModel::new(translation_config)?;

    // Translate a single sentence and compare against the expected output
    let source_sentence = "The quick brown fox jumps over the lazy dog";
    let translations = translator.translate(&[source_sentence]);

    assert_eq!(
        translations[0],
        " Le renard brun rapide saute au-dessus du chien paresseux."
    );

    Ok(())
}