Model resources update, documentation

This commit is contained in:
Guillaume B 2020-05-25 22:01:53 +02:00
parent ff312a7b9e
commit 829965d68b
10 changed files with 272 additions and 57 deletions

View File

@ -30,7 +30,7 @@ all-tests = []
features = [ "doc-only" ]
[dependencies]
rust_tokenizers = {version = "~3.1.2", path = "E:/Coding/backup-rust/rust-tokenizers/main"}
rust_tokenizers = "~3.1.2"
tch = "~0.1.7"
serde_json = "1.0.51"
serde = {version = "1.0.106", features = ["derive"]}

View File

@ -10,16 +10,17 @@ This repository exposes the model base architecture, task-specific heads (see be
The following models are currently implemented:
| |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:
Masked LM|✅ |✅ |✅ | | | |✅|
Sequence classification|✅ |✅ |✅| | | | |
Token classification|✅ |✅ | ✅| | | |✅|
Question answering|✅ |✅ |✅| | | | |
Multiple choices| |✅ |✅| | | | |
Next token prediction| | | |✅|✅|✅| |
Natural Language Generation| | | |✅|✅|✅| |
Summarization | | | | | |✅| |
| |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:
Masked LM|✅ |✅ |✅ | | | |✅| |
Sequence classification|✅ |✅ |✅| | | | | |
Token classification|✅ |✅ | ✅| | | |✅| |
Question answering|✅ |✅ |✅| | | | | |
Multiple choices| |✅ |✅| | | | | |
Next token prediction| | | |✅|✅|✅| | |
Natural Language Generation| | | |✅|✅|✅| | |
Summarization | | | | | |✅| | |
Translation | | | | | |✅| |✅ |
## Ready-to-use pipelines
@ -41,7 +42,31 @@ Output:
[Answer { score: 0.9976814985275269, start: 13, end: 21, answer: "Amsterdam" }]
```
#### 2. Summarization
#### 2. Translation
Translation using the MarianMT architecture and pre-trained models from the Opus-MT team from Language Technology at the University of Helsinki.
Currently supported languages are :
- English <-> French
- English <-> Spanish
- English <-> Portuguese
- English <-> Italian
- English <-> Catalan
- English <-> German
- French <-> German
```rust
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
let mut model = TranslationModel::new(translation_config)?;
let input = ["This is a sentence to be translated"];
let output = model.translate(&input);
```
Output:
```
Il s'agit d'une phrase à traduire
```
#### 3. Summarization
Abstractive summarization using a pretrained BART model.
```rust
@ -80,7 +105,7 @@ This is the first such discovery in a planet in its star's habitable zone.
The planet is not too hot and not too cold for liquid water to exist."
```
#### 3. Natural Language Generation
#### 4. Natural Language Generation
Generate language based on a prompt. GPT2 and GPT available as base models.
Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
@ -107,7 +132,7 @@ Example output:
]
```
#### 4. Sentiment analysis
#### 5. Sentiment analysis
Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
```rust
let sentiment_classifier = SentimentModel::new(Default::default())?;
@ -131,7 +156,7 @@ Output:
]
```
#### 5. Named Entity Recognition
#### 6. Named Entity Recognition
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
```rust
let ner_model = NERModel::new(default::default())?;

View File

@ -13,24 +13,18 @@
extern crate failure;
use rust_bert::resources::{Resource, LocalResource};
use std::path::PathBuf;
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel, Language};
use tch::Device;
fn main() -> failure::Fallible<()> {
let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("E:/Coding/cache/rustbert/marian-mt-en-fr/config.json") });
let model_resource = Resource::Local(LocalResource { local_path: PathBuf::from("E:/Coding/cache/rustbert/marian-mt-en-fr/model.ot") });
let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("E:/Coding/cache/rustbert/marian-mt-en-fr/vocab.json") });
let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("E:/Coding/cache/rustbert/marian-mt-en-fr/spiece.model") });
let translation_config = TranslationConfig::new_from_resources(model_resource,
config_resource, vocab_resource, merges_resource, Device::cuda_if_available());
let translation_config = TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
let mut model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";
let input_context_2 = "The dog did not wake up";
let output = model.translate(&[input_context_1, input_context_2]);
for sentence in output {

View File

@ -37,7 +37,7 @@
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
//! let config = BartConfig::from_file(config_path);
//! let gpt2_model = BartModel::new(&vs.root(), &config, false);
//! let bart_model = BartModel::new(&vs.root(), &config, false);
//! vs.load(weights_path)?;
//!
//!# Ok(())

View File

@ -7,6 +7,7 @@
//!
//! This crate can be used in two different ways:
//! - Ready-to-use NLP pipelines for:
//! - Translation
//! - Summarization
//! - Sentiment Analysis
//! - Named Entity Recognition
@ -28,16 +29,17 @@
//! ```
//! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust
//!
//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**
//! :-----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:
//! Masked LM|✅ |✅ |✅ | | | |✅|
//! Sequence classification|✅ |✅ |✅| | | | |
//! Token classification|✅ |✅ | ✅| | | |✅|
//! Question answering|✅ |✅ |✅| | | | |
//! Multiple choices| |✅ |✅| | | | |
//! Next token prediction| | | |✅|✅| | |
//! Natural Language Generation| | | |✅|✅| | |
//! Summarization| | | |✅|✅|✅| |
//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**
//! :-----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:
//! Masked LM|✅ |✅ |✅ | | | |✅| |
//! Sequence classification|✅ |✅ |✅| | | | | |
//! Token classification|✅ |✅ | ✅| | | |✅| |
//! Question answering|✅ |✅ |✅| | | | | |
//! Multiple choices| |✅ |✅| | | | | |
//! Next token prediction| | | |✅|✅| | | |
//! Natural Language Generation| | | |✅|✅| | | |
//! Summarization| | | | | |✅| | |
//! Translation| | | | | | | |✅|
//!
//! # Loading pre-trained models
//!

View File

@ -34,35 +34,81 @@ pub struct MarianPrefix;
impl MarianModelResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2FRENCH: (&'static str, &'static str) = ("marian-mt-en-fr/model.ot", "https://cdn.huggingface.co/facebook/bart-large/rust_model.ot");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const FRENCH2ENGLISH: (&'static str, &'static str) = ("marian-mt-fr-en/model.ot", "https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot");
}
impl MarianConfigResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2FRENCH: (&'static str, &'static str) = ("marian-mt-en-fr/config.json", "https://cdn.huggingface.co/facebook/bart-large/config.json");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2ENGLISH: (&'static str, &'static str) = ("marian-mt-fr-en/config.json", "https://cdn.huggingface.co/facebook/bart-large-cnn/config.json");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json");
}
impl MarianVocabResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2FRENCH: (&'static str, &'static str) = ("marian-mt-en-fr/vocab.json", "https://cdn.huggingface.co/roberta-large-vocab.json");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2ENGLISH: (&'static str, &'static str) = ("marian-mt-fr-en/vocab.json", "https://cdn.huggingface.co/roberta-large-vocab.json");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json");
}
impl MarianSpmResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2FRENCH: (&'static str, &'static str) = ("marian-mt-en-fr/spiece.model", "https://cdn.huggingface.co/roberta-large-merges.txt");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2ENGLISH: (&'static str, &'static str) = ("marian-mt-fr-en/spiece.model", "https://cdn.huggingface.co/roberta-large-merges.txt");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm");
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm");
}
impl MarianPrefix {
pub const ENGLISH2FRENCH: Option<&'static str> = Some(">>fr<<");
pub const ENGLISH2CATALAN: Option<&'static str> = Some(">>ca<<");
pub const ENGLISH2SPANISH: Option<&'static str> = Some(">>es<<");
pub const ENGLISH2PORTUGUESE: Option<&'static str> = Some(">>pt<<");
pub const ENGLISH2ITALIAN: Option<&'static str> = Some(">>it<<");
pub const ENGLISH2ROMANIAN: Option<&'static str> = Some(">>ro<<");
pub const ENGLISH2GERMAN: Option<&'static str> = None;
pub const FRENCH2ENGLISH: Option<&'static str> = None;
pub const CATALAN2ENGLISH: Option<&'static str> = None;
pub const SPANISH2ENGLISH: Option<&'static str> = None;
pub const PORTUGUESE2ENGLISH: Option<&'static str> = None;
pub const ITALIAN2ENGLISH: Option<&'static str> = None;
pub const ROMANIAN2ENGLISH: Option<&'static str> = None;
pub const GERMAN2ENGLISH: Option<&'static str> = None;
pub const FRENCH2GERMAN: Option<&'static str> = None;
pub const GERMAN2FRENCH: Option<&'static str> = None;
}
/// # Marian Model for conditional generation
@ -279,7 +325,7 @@ impl LMHeadModel for MarianForConditionalGeneration {
train);
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None) + &self.final_logits_bias;
Ok((lm_logits, Some(encoder_hidden_states), None, None, None))
}
}

View File

@ -1,3 +1,50 @@
//! # Marian
//!
//! Implementation of the Marian language model ([Marian: Fast Neural Machine Translation in {C++}](http://www.aclweb.org/anthology/P18-4020) Junczys-Dowmunt, Grundkiewicz, Dwojak, Hoang, Heafield, Neckermann, Seide, Germann, Fikri Aji, Bogoychev, Martins, Birch, 2018).
//! The base model is implemented in the `bart::BartModel` struct. This model includes a language model head: `marian::MarianForConditionalGeneration`
//! implementing the common `generation::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/translation.rs`, run with `cargo run --example translation`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `MarianTokenizer` using a `vocab.json` vocabulary and `spiece.model` sentence piece model
//!
//! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. These are shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
//! use rust_bert::marian::MarianForConditionalGeneration;
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json")});
//! let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let spiece_path = download_resource(&sentence_piece_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer = MarianTokenizer::from_files(vocab_path.to_str().unwrap(), spiece_path.to_str().unwrap(), true);
//! let config = BartConfig::from_file(config_path);
//! let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! ```
mod marian;
pub use marian::{MarianForConditionalGeneration, MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix};

View File

@ -34,7 +34,39 @@
//!# ;
//! ```
//!
//! #### 2. Summarization
//! #### 2. Translation
//! Translation using the MarianMT architecture and pre-trained models from the Opus-MT team from Language Technology at the University of Helsinki.
//! Currently supported languages are :
//! - English <-> French
//! - English <-> Spanish
//! - English <-> Portuguese
//! - English <-> Italian
//! - English <-> Catalan
//! - English <-> German
//! - French <-> German
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
//! use tch::Device;
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//!
//! let input = ["This is a sentence to be translated"];
//!
//! let output = model.translate(&input);
//!# Ok(())
//!# }
//! ```
//!
//! Output: \
//! ```no_run
//!# let output =
//! "Il s'agit d'une phrase à traduire"
//!# ;
//!```
//!
//! #### 3. Summarization
//! Abstractive summarization of texts based on the BART encoder-decoder architecture
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//!
@ -83,7 +115,7 @@
//!```
//!
//!
//! #### 3. Natural Language Generation
//! #### 4. Natural Language Generation
//! Generate language based on a prompt. GPT2 and GPT available as base models.
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//! Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
@ -114,7 +146,7 @@
//!# ;
//!```
//!
//! #### 4. Sentiment analysis
//! #### 5. Sentiment analysis
//! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
//! ```no_run
//! use rust_bert::pipelines::sentiment::SentimentModel;
@ -144,7 +176,7 @@
//!# ;
//! ```
//!
//! #### 5. Named Entity Recognition
//! #### 6. Named Entity Recognition
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;

View File

@ -16,8 +16,17 @@
//! Translation based on the Marian encoder-decoder architecture
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//! Pre-trained and ready-to-use models are available by creating a configuration from the `Language` enum.
//! These models have been trained by the Opus-MT team from Language Technology at the University of Helsinki (Jörg Tiedemann, jorg.tiedemann@helsinki.fi)
//! The Rust model files are hosted by Hugging Face Inc (https://huggingface.co).
//! These models have been trained by the [Opus-MT team from Language Technology at the University of Helsinki](https://github.com/Helsinki-NLP/Opus-MT).
//! The Rust model files are hosted by [Hugging Face Inc](https://huggingface.co).
//! Currently supported languages are :
//! - English <-> French
//! - English <-> Spanish
//! - English <-> Portuguese
//! - English <-> Italian
//! - English <-> Catalan
//! - English <-> German
//! - French <-> German
//!
//! Customized Translation models can be loaded by creating a configuration from local files.
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/{translation-model-name}
//!
@ -27,7 +36,7 @@
//!# use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
//! use tch::Device;
//! let mut translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//!
//! let input = ["This is a sentence to be translated"];
@ -51,18 +60,61 @@ use crate::marian::{MarianModelResources, MarianConfigResources, MarianVocabReso
/// Pretrained languages available for direct use
pub enum Language {
EnglishToFrench,
FrenchToEnglish,
CatalanToEnglish,
SpanishToEnglish,
PortugueseToEnglish,
ItalianToEnglish,
RomanianToEnglish,
GermanToEnglish,
EnglishToFrench,
EnglishToCatalan,
EnglishToSpanish,
EnglishToPortuguese,
EnglishToItalian,
EnglishToRomanian,
EnglishToGerman,
FrenchToGerman,
GermanToFrench,
}
struct RemoteTranslationResources;
impl RemoteTranslationResources {
pub const ENGLISH2FRENCH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2FRENCH, MarianConfigResources::ENGLISH2FRENCH, MarianVocabResources::ENGLISH2FRENCH, MarianSpmResources::ENGLISH2FRENCH, MarianPrefix::ENGLISH2FRENCH);
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2FRENCH);
pub const ENGLISH2CATALAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2CATALAN);
pub const ENGLISH2SPANISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2SPANISH);
pub const ENGLISH2PORTUGUESE: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2PORTUGUESE);
pub const ENGLISH2ITALIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2ITALIAN);
pub const ENGLISH2ROMANIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2ROMANIAN);
pub const ENGLISH2GERMAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2GERMAN, MarianConfigResources::ENGLISH2GERMAN, MarianVocabResources::ENGLISH2GERMAN, MarianSpmResources::ENGLISH2GERMAN, MarianPrefix::ENGLISH2GERMAN);
pub const FRENCH2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::FRENCH2ENGLISH, MarianConfigResources::FRENCH2ENGLISH, MarianVocabResources::FRENCH2ENGLISH, MarianSpmResources::FRENCH2ENGLISH, MarianPrefix::FRENCH2ENGLISH);
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::FRENCH2ENGLISH);
pub const CATALAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::CATALAN2ENGLISH);
pub const SPANISH2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::SPANISH2ENGLISH);
pub const PORTUGUESE2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::PORTUGUESE2ENGLISH);
pub const ITALIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::ITALIAN2ENGLISH);
pub const ROMANIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::ROMANIAN2ENGLISH);
pub const GERMAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::GERMAN2ENGLISH, MarianConfigResources::GERMAN2ENGLISH, MarianVocabResources::GERMAN2ENGLISH, MarianSpmResources::GERMAN2ENGLISH, MarianPrefix::GERMAN2ENGLISH);
pub const FRENCH2GERMAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::FRENCH2GERMAN, MarianConfigResources::FRENCH2GERMAN, MarianVocabResources::FRENCH2GERMAN, MarianSpmResources::FRENCH2GERMAN, MarianPrefix::FRENCH2GERMAN);
pub const GERMAN2FRENCH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::GERMAN2FRENCH, MarianConfigResources::GERMAN2FRENCH, MarianVocabResources::GERMAN2FRENCH, MarianSpmResources::GERMAN2FRENCH, MarianPrefix::GERMAN2FRENCH);
}
@ -131,7 +183,23 @@ impl TranslationConfig {
pub fn new(language: Language, device: Device) -> TranslationConfig {
let (model_resource, config_resource, vocab_resource, merges_resource, prefix) = match language {
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
};
let model_resource = Resource::Remote(RemoteResource::from_pretrained(model_resource));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(config_resource));
@ -191,6 +259,7 @@ impl TranslationConfig {
/// config_resource,
/// vocab_resource,
/// sentence_piece_resource,
/// Some(">>fr<<".to_string()),
/// Device::cuda_if_available());
///# Ok(())
///# }

View File

@ -15,8 +15,8 @@ fn test_translation() -> failure::Fallible<()> {
let output = model.translate(&[input_context_1, input_context_2]);
assert_eq!(output.len(), 2);
assert_eq!(output[0], " Le renard brun rapide saute sur le chien paresseux");
assert_eq!(output[1], " Le chien ne s'est pas réveillé.");
assert_eq!(output[0], " Le rapide renard brun saute sur le chien paresseux");
assert_eq!(output[1], " Le chien ne s'est pas réveillé");
Ok(())
}