Merge branch 'm2m100_implementation' into translation_rework

This commit is contained in:
Guillaume B 2021-07-09 15:41:28 +02:00
commit 450fe0d533
37 changed files with 1954 additions and 288 deletions

View File

@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. The format
- (BREAKING) Support for `forced_bos_token_id` argument for generation, allowing users to force a given BOS token for generation (useful for MBart/M2M-class models)
- (BREAKING) Support for `output_scores` boolean argument for generation, allowing users to output the log-probability scores of generated sequences. Updated the return type of low-level generate API to `GeneratedTextOutput` and `GeneratedIndicesOutput` containing optional scores along with the generated output.
- Addition of the MBart Language model and support for text generation / direct translation between 50 languages
- Addition of the M2M100 Language model and support for text generation / direct translation between 100 languages
## Changed
- Updated GPT2 architecture to re-use embeddings for the output projection layer (resulting in smaller model weights files and memory footprint)

View File

@ -57,7 +57,7 @@ all-tests = []
features = ["doc-only"]
[dependencies]
rust_tokenizers = "~6.2.3"
rust_tokenizers = { version = "~6.2.4", path = "E:/Coding/backup-rust/rust-tokenizers/main" }
tch = "~0.5.0"
serde_json = "1.0.64"
serde = { version = "1.0.126", features = ["derive"] }

View File

@ -49,6 +49,7 @@ GPT-Neo| | | |✅ | | | |
BART|✅| | |✅ |✅| | |
Marian| | | | | |✅| |
MBart|✅| | |✅ | | | |
M2M100| | | |✅ | | | |
Electra | |✅| | | | |✅|
ALBERT |✅|✅|✅| | | |✅|
T5 | | | |✅ |✅|✅| |
@ -62,7 +63,7 @@ Pegasus| | | | |✅| | |
## Getting started
This library relies on the [tch](https://github.com/LaurentMazare/tch-rs) crate for bindings to the C++ Libtorch API.
The libtorch library is required can be downloaded either automatically or manually. The following provides a reference on how to set-up yoru environment
The libtorch library is required and can be downloaded either automatically or manually. The following provides a reference on how to set up your environment
to use these bindings; please refer to the [tch](https://github.com/LaurentMazare/tch-rs) crate for detailed information or support.
Furthermore, this library relies on a cache folder for downloading pre-trained models.

View File

@ -0,0 +1,64 @@
// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::m2m_100::{
M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100ModelResources,
M2M100VocabResources,
};
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::resources::{RemoteResource, Resource};
/// Example entry point: text generation with the M2M100 (418M) model.
///
/// Declares the remote model, configuration, vocabulary and merges resources,
/// builds an `M2M100Generator` configured for deterministic beam search
/// (3 beams, no sampling, early stopping) and prints every output generated for
/// a single English prompt.
///
/// # Errors
///
/// Propagates any error returned by `M2M100Generator::new`.
fn main() -> anyhow::Result<()> {
    // Resource handles for the pre-trained M2M100 418M checkpoint.
    let model_resource = Resource::Remote(RemoteResource::from_pretrained(
        M2M100ModelResources::M2M100_418M,
    ));
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        M2M100ConfigResources::M2M100_418M,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        M2M100VocabResources::M2M100_418M,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        M2M100MergesResources::M2M100_418M,
    ));

    let generation_settings = GenerateConfig {
        max_length: 142,
        min_length: 0,
        model_resource,
        config_resource,
        vocab_resource,
        merges_resource,
        do_sample: false,
        early_stopping: true,
        num_beams: 3,
        ..Default::default()
    };
    let generator = M2M100Generator::new(generation_settings)?;

    // The prompt carries its source-language marker as a prefix token.
    let prompts = [">>en.<< The dog did not wake up."];

    // Look up the vocabulary id of the target-language marker token.
    let language_ids = generator
        .get_tokenizer()
        .convert_tokens_to_ids([">>es.<<"]);
    let target_language = language_ids[0];

    // NOTE(review): the sixth positional argument is presumably the forced BOS
    // token id and the last one the `output_scores` flag — confirm against the
    // `generate` signature, which is not visible in this file.
    let generated = generator.generate(
        Some(&prompts),
        None,
        None,
        None,
        None,
        target_language,
        None,
        false,
    );

    for candidate in generated {
        println!("{:?}", candidate);
    }
    Ok(())
}

View File

@ -31,7 +31,7 @@ pub struct AlbertConfigResources;
pub struct AlbertVocabResources;
impl AlbertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/ALBERT>. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/model",
"https://huggingface.co/albert-base-v2/resolve/main/rust_model.ot",
@ -39,7 +39,7 @@ impl AlbertModelResources {
}
impl AlbertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/ALBERT>. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/config",
"https://huggingface.co/albert-base-v2/resolve/main/config.json",
@ -47,7 +47,7 @@ impl AlbertConfigResources {
}
impl AlbertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/ALBERT>. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/spiece",
"https://huggingface.co/albert-base-v2/resolve/main/spiece.model",
@ -88,7 +88,7 @@ pub struct AlbertConfig {
pub label2id: Option<HashMap<String, i64>>,
}
impl Config<AlbertConfig> for AlbertConfig {}
impl Config for AlbertConfig {}
/// # ALBERT Base model
/// Base architecture for ALBERT models. Task-specific models will be built from this common base model

View File

@ -50,32 +50,32 @@ pub struct BartVocabResources;
pub struct BartMergesResources;
impl BartModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = (
"bart/model",
"https://huggingface.co/facebook/bart-large/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/model",
"https://huggingface.co/facebook/bart-large-cnn/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/model",
"https://huggingface.co/facebook/bart-large-xsum/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_MNLI: (&'static str, &'static str) = (
"bart-large-mnli/model",
"https://huggingface.co/facebook/bart-large-mnli/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-6-6. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
"distilbart-cnn-6-6/model",
"https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-12-6. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
"distilbart-cnn-12-6/model",
"https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/rust_model.ot",
@ -83,32 +83,32 @@ impl BartModelResources {
}
impl BartConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = (
"bart/config",
"https://huggingface.co/facebook/bart-large/resolve/main/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/config",
"https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/config",
"https://huggingface.co/facebook/bart-large-xsum/resolve/main/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_MNLI: (&'static str, &'static str) = (
"bart-large-mnli/config",
"https://huggingface.co/facebook/bart-large-mnli/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-6-6. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
"distilbart-cnn-6-6/config",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/config.json",
);
/// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-12-6. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
"distilbart-cnn-12-6/config",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/config.json",
@ -116,32 +116,32 @@ impl BartConfigResources {
}
impl BartVocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = (
"bart/vocab",
"https://huggingface.co/roberta-large/resolve/main/vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/vocab",
"https://huggingface.co/roberta-large/resolve/main/vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/vocab",
"https://huggingface.co/roberta-large/resolve/main/vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_MNLI: (&'static str, &'static str) = (
"bart-large-mnli/vocab",
"https://huggingface.co/roberta-large/resolve/main/vocab.json",
);
/// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-6-6. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
"distilbart-cnn-6-6/vocab",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/vocab.json",
);
/// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-12-6. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
"distilbart-cnn-12-6/vocab",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/vocab.json",
@ -149,32 +149,32 @@ impl BartVocabResources {
}
impl BartMergesResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = (
"bart/merges",
"https://huggingface.co/roberta-large/resolve/main/merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/merges",
"https://huggingface.co/roberta-large/resolve/main/merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/merges",
"https://huggingface.co/roberta-large/resolve/main/merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const BART_MNLI: (&'static str, &'static str) = (
"bart-large-mnli/merges",
"https://huggingface.co/roberta-large/resolve/main/merges.txt",
);
/// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-6-6. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
"distilbart-cnn-6-6/merges",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/merges.txt",
);
/// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-12-6. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
"distilbart-cnn-12-6/merges",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/merges.txt",
@ -222,7 +222,7 @@ pub struct BartConfig {
pub vocab_size: i64,
}
impl Config<BartConfig> for BartConfig {}
impl Config for BartConfig {}
pub(crate) fn _make_causal_mask(
input_ids_shape: &[i64],
@ -391,7 +391,7 @@ impl BartModel {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
@ -554,7 +554,7 @@ impl BartForConditionalGeneration {
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
@ -748,7 +748,7 @@ impl BartForSequenceClassification {
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
@ -847,7 +847,7 @@ impl LMHeadModel for BartForConditionalGeneration {
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1
/// * `input_embeds` - Unused for BART
/// * `token_type_ids` - Unused for BART

View File

@ -70,7 +70,7 @@ pub use bart_model::{
};
pub(crate) use attention::BartAttention;
pub(crate) use bart_model::{_expand_mask, _prepare_decoder_attention_mask};
pub(crate) use bart_model::{_expand_mask, _make_causal_mask, _prepare_decoder_attention_mask};
pub(crate) use decoder::BartDecoderOutput;
pub(crate) use embeddings::LearnedPositionalEmbedding;
pub(crate) use encoder::BartEncoderOutput;

View File

@ -37,17 +37,17 @@ pub struct BertConfigResources;
pub struct BertVocabResources;
impl BertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = (
"bert/model",
"https://huggingface.co/bert-base-uncased/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/model",
"https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/model",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/rust_model.ot",
@ -55,17 +55,17 @@ impl BertModelResources {
}
impl BertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = (
"bert/config",
"https://huggingface.co/bert-base-uncased/resolve/main/config.json",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/config",
"https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/config",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
@ -73,17 +73,17 @@ impl BertConfigResources {
}
impl BertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = (
"bert/vocab",
"https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/vocab",
"https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/vocab",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
@ -112,7 +112,7 @@ pub struct BertConfig {
pub label2id: Option<HashMap<String, i64>>,
}
impl Config<BertConfig> for BertConfig {}
impl Config for BertConfig {}
/// # BERT Base model
/// Base architecture for BERT models. Task-specific models will be built from this common base model

View File

@ -15,9 +15,9 @@ use std::io::BufReader;
use std::path::Path;
/// # Utility to deserialize JSON config files
pub trait Config<T>
pub trait Config
where
for<'de> T: Deserialize<'de>,
for<'de> Self: Deserialize<'de>,
{
/// Loads a `Config` object from a JSON file. The format is expected to be aligned with the [Transformers library](https://github.com/huggingface/transformers) configuration files for each model.
/// The parsing will fail if non-optional keys expected by the model are missing.
@ -36,10 +36,10 @@ where
/// let config_path = Path::new("path/to/config.json");
/// let config = Gpt2Config::from_file(config_path);
/// ```
fn from_file<P: AsRef<Path>>(path: P) -> T {
fn from_file<P: AsRef<Path>>(path: P) -> Self {
let f = File::open(path).expect("Could not open configuration file.");
let br = BufReader::new(f);
let config: T = serde_json::from_reader(br).expect("could not parse configuration");
let config: Self = serde_json::from_reader(br).expect("could not parse configuration");
config
}
}

View File

@ -115,7 +115,7 @@ impl RemoteResource {
}
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
/// ~/.cache/.rusbert/model_name. Note that this does not download the resource (only declares
/// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares
/// the remote and local locations)
///
/// # Arguments

View File

@ -31,17 +31,17 @@ pub struct DistilBertConfigResources;
pub struct DistilBertVocabResources;
impl DistilBertModelResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/model",
"https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/model",
"https://huggingface.co/distilbert-base-uncased/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/model",
"https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/rust_model.ot",
@ -49,17 +49,17 @@ impl DistilBertModelResources {
}
impl DistilBertConfigResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/config",
"https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/config",
"https://huggingface.co/distilbert-base-uncased/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/config",
"https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json",
@ -67,17 +67,17 @@ impl DistilBertConfigResources {
}
impl DistilBertVocabResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/vocab",
"https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/vocab",
"https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/vocab",
"https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
@ -112,7 +112,7 @@ pub struct DistilBertConfig {
pub vocab_size: i64,
}
impl Config<DistilBertConfig> for DistilBertConfig {}
impl Config for DistilBertConfig {}
/// # DistilBERT Base model
/// Base architecture for DistilBERT models. Task-specific models will be built from this common base model

View File

@ -32,12 +32,12 @@ pub struct ElectraConfigResources;
pub struct ElectraVocabResources;
impl ElectraModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/model",
"https://huggingface.co/google/electra-base-generator/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/model",
"https://huggingface.co/google/electra-base-discriminator/resolve/main/rust_model.ot",
@ -45,12 +45,12 @@ impl ElectraModelResources {
}
impl ElectraConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/config",
"https://huggingface.co/google/electra-base-generator/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/config",
"https://huggingface.co/google/electra-base-discriminator/resolve/main/config.json",
@ -58,12 +58,12 @@ impl ElectraConfigResources {
}
impl ElectraVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/vocab",
"https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/vocab",
"https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt",
@ -95,7 +95,7 @@ pub struct ElectraConfig {
pub label2id: Option<HashMap<String, i64>>,
}
impl Config<ElectraConfig> for ElectraConfig {}
impl Config for ElectraConfig {}
/// # Electra Base model
/// Base architecture for Electra models.

View File

@ -44,32 +44,32 @@ pub struct Gpt2VocabResources;
pub struct Gpt2MergesResources;
impl Gpt2ModelResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/model",
"https://huggingface.co/gpt2/resolve/main/rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/model",
"https://huggingface.co/gpt2-medium/resolve/main/rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/model",
"https://huggingface.co/gpt2-large/resolve/main/rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/model",
"https://huggingface.co/gpt2-xl/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/model",
"https://huggingface.co/distilgpt2/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/model",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/rust_model.ot",
@ -77,32 +77,32 @@ impl Gpt2ModelResources {
}
impl Gpt2ConfigResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/config",
"https://huggingface.co/gpt2/resolve/main/config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/config",
"https://huggingface.co/gpt2-medium/resolve/main/config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/config",
"https://huggingface.co/gpt2-large/resolve/main/config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/config",
"https://huggingface.co/gpt2-xl/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/config",
"https://huggingface.co/distilgpt2/resolve/main/config.json",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/config",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/config.json",
@ -110,32 +110,32 @@ impl Gpt2ConfigResources {
}
impl Gpt2VocabResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/vocab",
"https://huggingface.co/gpt2/resolve/main/vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/vocab",
"https://huggingface.co/gpt2-medium/resolve/main/vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/vocab",
"https://huggingface.co/gpt2-large/resolve/main/vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/vocab",
"https://huggingface.co/gpt2-xl/resolve/main/vocab.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/vocab",
"https://huggingface.co/distilgpt2/resolve/main/vocab.json",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/vocab",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/vocab.json",
@ -143,32 +143,32 @@ impl Gpt2VocabResources {
}
impl Gpt2MergesResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/merges",
"https://huggingface.co/gpt2/resolve/main/merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/merges",
"https://huggingface.co/gpt2-medium/resolve/main/merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/merges",
"https://huggingface.co/gpt2-large/resolve/main/merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/merges",
"https://huggingface.co/gpt2-xl/resolve/main/merges.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/merges",
"https://huggingface.co/distilgpt2/resolve/main/merges.txt",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/merges",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/merges.txt",
@ -199,7 +199,7 @@ pub struct Gpt2Config {
pub vocab_size: i64,
}
impl Config<Gpt2Config> for Gpt2Config {}
impl Config for Gpt2Config {}
/// # GPT2 Base model
/// Base architecture for GPT2 model. Usually complemented with a task-specific head, such as a language model head.

View File

@ -41,17 +41,17 @@ pub struct GptNeoVocabResources;
pub struct GptNeoMergesResources;
impl GptNeoModelResources {
/// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_125M: (&'static str, &'static str) = (
"gpt-neo-125M/model",
"https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
"gpt-neo-1_3B/model",
"https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
"gpt-neo-2_7B/model",
"https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/rust_model.ot",
@ -59,17 +59,17 @@ impl GptNeoModelResources {
}
impl GptNeoConfigResources {
/// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_125M: (&'static str, &'static str) = (
"gpt-neo-125M/config",
"https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
"gpt-neo-1_3B/config",
"https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
"gpt-neo-2_7B/config",
"https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/config.json",
@ -77,17 +77,17 @@ impl GptNeoConfigResources {
}
impl GptNeoVocabResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT_NEO_125M: (&'static str, &'static str) = (
"gpt-neo-125M/vocab",
"https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
"gpt-neo-1_3B/vocab",
"https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
"gpt-neo-2_7B/vocab",
"https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/vocab.json",
@ -95,17 +95,17 @@ impl GptNeoVocabResources {
}
impl GptNeoMergesResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_125M: (&'static str, &'static str) = (
"gpt-neo-125M/merges",
"https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
"gpt-neo-1_3B/merges",
"https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
"gpt-neo-2_7B/merges",
"https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/merges.txt",
@ -146,7 +146,7 @@ pub struct GptNeoConfig {
pub resid_dropout: f64,
}
impl Config<GptNeoConfig> for GptNeoConfig {}
impl Config for GptNeoConfig {}
/// # GPT-Neo Base model
/// Base architecture for GPT-Neo models. Task-specific models will be built from this common base model
@ -263,7 +263,7 @@ impl GptNeoModel {
/// - `hidden_states` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
/// - `next_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past content for the attention layers
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer + 1* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* containign the attention weights for each layer
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* containing the attention weights for each layer
///
/// # Example
///
@ -504,7 +504,7 @@ impl GptNeoForCausalLM {
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `next_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past content for the attention layers
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer + 1* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* containign the attention weights for each layer
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* containing the attention weights for each layer
///
/// # Example
///

View File

@ -59,6 +59,7 @@
//! BART|✅| | |✅ |✅| | |
//! Marian| | | | | |✅| |
//! MBart|✅| | |✅ | | | |
//! M2M100| | | |✅ | | | |
//! Electra | |✅| | | | |✅|
//! ALBERT |✅|✅|✅| | | |✅|
//! T5 | | | |✅ |✅|✅| |
@ -80,7 +81,7 @@
//!
//! ### Manual installation (recommended)
//!
//! 1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.8.1`: if this version is no longer available on the "get started" page,
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.8.1`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.8.1%2Bcu111.zip` for a Linux version with CUDA11.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
@ -579,6 +580,7 @@ pub mod electra;
pub mod gpt2;
pub mod gpt_neo;
pub mod longformer;
pub mod m2m_100;
pub mod marian;
pub mod mbart;
pub mod mobilebert;

View File

@ -34,12 +34,12 @@ pub struct LongformerVocabResources;
pub struct LongformerMergesResources;
impl LongformerModelResources {
/// Shared under Apache 2.0 license by the AllenAI team at https://github.com/allenai/longformer. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
"longformer-base-4096/model",
"https://huggingface.co/allenai/longformer-base-4096/resolve/main/rust_model.ot",
);
/// Shared under MIT license at https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1. Modified with conversion to C-array format.
/// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
// Distinct local name: LONGFORMER_BASE_4096 already uses "longformer-base-4096/model",
// and this entry points to a different (SQuAD v1 fine-tuned) weights file.
"longformer-base-squad1/model",
"https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/rust_model.ot",
@ -47,12 +47,12 @@ impl LongformerModelResources {
}
impl LongformerConfigResources {
/// Shared under Apache 2.0 license by the AllenAI team at https://github.com/allenai/longformer. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
"longformer-base-4096/config",
"https://huggingface.co/allenai/longformer-base-4096/resolve/main/config.json",
);
/// Shared under MIT license at https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1. Modified with conversion to C-array format.
/// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
// Distinct local name: LONGFORMER_BASE_4096 already uses "longformer-base-4096/config",
// and this entry points to the configuration of a different (SQuAD v1 fine-tuned) checkpoint.
"longformer-base-squad1/config",
"https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/config.json",
@ -60,12 +60,12 @@ impl LongformerConfigResources {
}
impl LongformerVocabResources {
/// Shared under Apache 2.0 license by the AllenAI team at https://github.com/allenai/longformer. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
"longformer-base-4096/vocab",
"https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
);
/// Shared under MIT license at https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1. Modified with conversion to C-array format.
/// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
"longformer-base-4096/vocab",
"https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/vocab.json",
@ -73,12 +73,12 @@ impl LongformerVocabResources {
}
impl LongformerMergesResources {
/// Shared under Apache 2.0 license by the AllenAI team at https://github.com/allenai/longformer. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
"longformer-base-4096/merges",
"https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
);
/// Shared under MIT license at https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1. Modified with conversion to C-array format.
/// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
"longformer-base-4096/merges",
"https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/merges.txt",
@ -120,7 +120,7 @@ pub struct LongformerConfig {
pub label2id: Option<HashMap<String, i64>>,
}
impl Config<LongformerConfig> for LongformerConfig {}
impl Config for LongformerConfig {}
fn get_question_end_index(input_ids: &Tensor, sep_token_id: i64) -> Tensor {
input_ids

15
src/m2m_100/attention.rs Normal file
View File

@ -0,0 +1,15 @@
// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::LayerState as BartLayerState;
/// Layer state type exposed by the M2M100 module; a direct alias of the BART `LayerState`.
pub type LayerState = BartLayerState;

195
src/m2m_100/decoder.rs Normal file
View File

@ -0,0 +1,195 @@
// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::{BartDecoderOutput, _expand_mask, _make_causal_mask};
use crate::common::dropout::Dropout;
use crate::m2m_100::embeddings::SinusoidalPositionalEmbedding;
use crate::m2m_100::{LayerState, M2M100Config};
use crate::mbart::MBartDecoderLayer;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};
/// Decoder layer used by M2M100: a direct alias of `MBartDecoderLayer`.
pub type M2M100DecoderLayer = MBartDecoderLayer;
/// # M2M100 decoder
/// Stack of `M2M100DecoderLayer`s wrapped with sinusoidal positional embeddings,
/// an input dropout and a final layer normalization.
pub struct M2M100Decoder {
/// Dropout applied to the sum of token and position embeddings.
dropout: Dropout,
/// Layer normalization applied to the output of the last decoder layer.
layer_norm: nn::LayerNorm,
/// Decoder layers, applied in order.
layers: Vec<M2M100DecoderLayer>,
/// Sinusoidal position embeddings added to the (scaled) token embeddings.
embed_positions: SinusoidalPositionalEmbedding,
/// When true, the attention weights of every layer are collected in the output.
output_attentions: bool,
/// When true, the hidden state produced by every layer is collected in the output.
output_hidden_states: bool,
/// When true, the per-layer states are returned so they can be fed back on the next call.
output_past: bool,
/// Multiplier applied to the token embeddings (`sqrt(d_model)` or `1.0`, chosen in `new`).
scale_embedding: f64,
}
impl M2M100Decoder {
    /// Build a new `M2M100Decoder`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the decoder. The layer normalization is
    ///   registered under `layer_norm` and the decoder layers under `layers/<index>`.
    /// * `config` - `M2M100Config` object defining the model architecture
    ///
    /// Unset configuration flags default to `output_past = true`, `output_attentions = false`,
    /// `output_hidden_states = false`, `scale_embedding = false` and `pad_token_id = 1`.
    pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Decoder
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let output_past = config.output_past.unwrap_or(true);
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);
        // Token embeddings are scaled by sqrt(d_model) only when explicitly requested.
        let scale_embedding = if config.scale_embedding.unwrap_or(false) {
            (config.d_model as f64).sqrt()
        } else {
            1.0
        };

        let dropout = Dropout::new(config.dropout);
        let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default());
        let embed_positions = SinusoidalPositionalEmbedding::new(
            p / "embed_positions",
            config.max_position_embeddings,
            config.d_model,
            config.pad_token_id.unwrap_or(1),
        );

        let p_layers = p / "layers";
        let layers: Vec<M2M100DecoderLayer> = (0..config.decoder_layers)
            .map(|layer_index| M2M100DecoderLayer::new(&p_layers / layer_index, config))
            .collect();

        M2M100Decoder {
            dropout,
            layer_norm,
            layers,
            embed_positions,
            output_attentions,
            output_hidden_states,
            output_past,
            scale_embedding,
        }
    }

    /// Forward pass through the decoder
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Decoder token ids. The size of dimension 1 is used as the target sequence length.
    /// * `encoder_hidden_states` - Encoder output attended to by the decoder layers
    /// * `encoder_attention_mask` - Optional encoder padding mask, expanded to the target length before use
    /// * `decoder_attention_mask` - Optional decoder padding mask. It is combined with the causal mask when
    ///   more than one target position is decoded.
    /// * `embeddings` - Token embedding layer used to embed `input_ids`
    /// * `old_layer_states` - Optional per-layer states from a previous call. The past length is read from
    ///   dimension 2 of the first layer's self-attention `prev_key`.
    /// * `train` - boolean flag forwarded to the dropout and to the layers
    ///
    /// # Returns
    ///
    /// * `M2M100DecoderOutput` holding the layer-normalized last hidden state, the expanded encoder
    ///   attention mask, the updated per-layer states (if `output_past` is set) and the optional per-layer
    ///   hidden states / attention weights. Per-layer hidden states are collected after each layer and
    ///   before the final layer normalization.
    ///
    /// # Panics
    ///
    /// Panics if `output_attentions` is set and a layer does not return attention weights.
    pub fn forward_t(
        &self,
        input_ids: &Tensor,
        encoder_hidden_states: &Tensor,
        encoder_attention_mask: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        embeddings: &nn::Embedding,
        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        train: bool,
    ) -> M2M100DecoderOutput {
        // Number of positions already decoded, taken from the first layer's cached self-attention keys.
        let past_key_values_length = old_layer_states
            .as_ref()
            .and_then(|states| states.first())
            .and_then(|(self_attention_state, _)| self_attention_state.as_ref())
            .map_or(0, |state| state.prev_key.size()[2]);

        let input_shape = input_ids.size();
        let sequence_length = input_shape[1];

        let positions = self
            .embed_positions
            .forward(input_ids, past_key_values_length);
        let x: Tensor = input_ids.apply(embeddings) * self.scale_embedding + positions;

        // A causal mask is only required when several target positions are decoded at once.
        let causal_mask = if sequence_length > 1 {
            Some(_make_causal_mask(
                input_shape.as_slice(),
                x.kind(),
                x.device(),
                past_key_values_length,
            ))
        } else {
            None
        };

        // Combine the causal mask with the optional padding mask. The causal mask must be kept even
        // when no padding mask is given, otherwise target positions could attend to later positions.
        let decoder_attention_mask = match (causal_mask, decoder_attention_mask) {
            (Some(causal_mask), Some(attention_mask)) => {
                Some(causal_mask + _expand_mask(attention_mask, Some(sequence_length)))
            }
            (Some(causal_mask), None) => Some(causal_mask),
            (None, Some(attention_mask)) => {
                Some(_expand_mask(attention_mask, Some(sequence_length)))
            }
            (None, None) => None,
        };

        let encoder_attention_mask = encoder_attention_mask
            .map(|mask| _expand_mask(mask, Some(*input_shape.last().unwrap())));

        let mut hidden_state = x.apply_t(&self.dropout, train);

        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(Vec::with_capacity(self.layers.len()))
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(Vec::with_capacity(self.layers.len()))
        } else {
            None
        };
        // Start from the provided states when available, otherwise from one empty state per layer.
        let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
            if self.output_past {
                old_layer_states.or_else(|| Some(vec![(None, None); self.layers.len()]))
            } else {
                None
            };

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let layer_state = match &next_decoder_cache {
                Some(values) => values[layer_idx].to_owned(),
                None => (None, None),
            };
            let layer_output = layer.forward_t(
                &hidden_state,
                encoder_hidden_states,
                encoder_attention_mask.as_ref(),
                decoder_attention_mask.as_ref(),
                layer_state,
                train,
            );
            hidden_state = layer_output.0;
            let attention_weights: Option<Tensor> = layer_output.1;

            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(hidden_state.copy());
            };
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(attention_weights.as_ref().unwrap().copy());
            };
            if let Some(cache) = &mut next_decoder_cache {
                cache[layer_idx] = layer_output.2
            };
        }

        M2M100DecoderOutput {
            hidden_state: hidden_state.apply(&self.layer_norm),
            encoder_attention_mask,
            next_decoder_cache,
            all_hidden_states,
            all_attentions,
        }
    }
}
/// Container holding a M2M100 decoder output.
/// This is an alias of `BartDecoderOutput`; `M2M100Decoder::forward_t` fills its `hidden_state`,
/// `encoder_attention_mask`, `next_decoder_cache`, `all_hidden_states` and `all_attentions` fields.
pub type M2M100DecoderOutput = BartDecoderOutput;

122
src/m2m_100/embeddings.rs Normal file
View File

@ -0,0 +1,122 @@
// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Borrow;
use std::ops::Deref;
use std::sync::RwLock;
use tch::nn::embedding;
use tch::{nn, Device, Kind, Tensor};
/// Fixed (non-trained) sinusoidal positional embeddings.
/// Position ids are derived from the input token ids so that padding tokens always map to the
/// row at `padding_idx`, which is filled with zeros.
#[derive(Debug)]
pub struct SinusoidalPositionalEmbedding {
/// Lookup table. Kept behind a `RwLock` because `forward` takes `&self` but may replace the
/// weights with a larger table.
embedding: RwLock<nn::Embedding>,
/// Size of each position vector.
embedding_dim: i64,
/// Token id treated as padding; also the base added to every position id.
padding_idx: i64,
/// Number of extra rows allocated on top of the requested number of positions (set to 2 in `new`).
offset: i64,
}
impl SinusoidalPositionalEmbedding {
    /// Build a new `SinusoidalPositionalEmbedding`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path; only its device is used. The table itself lives in a private
    ///   variable store that is frozen after initialization.
    /// * `num_embeddings` - Number of positions to pre-compute (2 extra rows are added as an offset)
    /// * `embedding_dim` - Size of each position vector
    /// * `padding_idx` - Padding token id; the matching row of the table is zeroed
    pub fn new<'p, P>(
        p: P,
        num_embeddings: i64,
        embedding_dim: i64,
        padding_idx: i64,
    ) -> SinusoidalPositionalEmbedding
    where
        P: Borrow<nn::Path<'p>>,
    {
        let device = p.borrow().device();
        let mut local_varstore = nn::VarStore::new(device);
        let offset = 2;

        let mut table = embedding(
            local_varstore.root(),
            num_embeddings + offset,
            embedding_dim,
            Default::default(),
        );
        // Replace the randomly initialized weights by the deterministic sinusoidal table.
        table.ws = Self::build_positional_embeddings(
            num_embeddings + offset,
            embedding_dim,
            padding_idx,
            device,
        );
        local_varstore.freeze();

        SinusoidalPositionalEmbedding {
            embedding: RwLock::new(table),
            embedding_dim,
            padding_idx,
            offset,
        }
    }

    /// Compute a `[num_embeddings, embedding_dim]` float table: the first half of each row holds
    /// sines and the second half cosines. A zero column is appended when `embedding_dim` is odd,
    /// and the row at `padding_idx` is zeroed.
    fn build_positional_embeddings(
        num_embeddings: i64,
        embedding_dim: i64,
        padding_idx: i64,
        device: Device,
    ) -> Tensor {
        let half_dim = embedding_dim / 2;
        let log_step = -(10000f64.ln()) / ((half_dim - 1) as f64);
        let frequencies = (Tensor::arange(half_dim, (Kind::Float, device)) * log_step).exp();
        let angles = Tensor::arange(num_embeddings, (Kind::Float, device)).unsqueeze(1)
            * frequencies.unsqueeze(0);

        let mut sinusoidal_embedding =
            Tensor::cat(&[&angles.sin(), &angles.cos()], 1).view([num_embeddings, -1]);
        if embedding_dim % 2 == 1 {
            // Odd dimension: pad with a single zero column to reach `embedding_dim`.
            sinusoidal_embedding = Tensor::cat(
                &[
                    sinusoidal_embedding,
                    Tensor::zeros(&[num_embeddings, 1], (Kind::Float, device)),
                ],
                1,
            );
        }
        let _ = sinusoidal_embedding.select(0, padding_idx).fill_(0);
        let _ = sinusoidal_embedding.requires_grad_(false);
        sinusoidal_embedding
    }

    /// Derive position ids from token ids: non-padding tokens are numbered along dimension 1
    /// starting at `padding_idx + past_key_values_length + 1`, padding tokens get `padding_idx`.
    fn create_position_ids_from_input_ids(
        &self,
        input_ids: &Tensor,
        past_key_values_length: i64,
    ) -> Tensor {
        let mask = input_ids.ne(self.padding_idx).to_kind(Kind::Int64);
        let incremental_indices = (mask.cumsum(1, Kind::Int64) + past_key_values_length) * mask;
        incremental_indices + self.padding_idx
    }

    /// Look up the positional embeddings for `input_ids`
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Token ids; the size of dimension 1 is used as the sequence length
    /// * `past_key_values_length` - Number of positions already processed (shifts the position ids)
    ///
    /// The table is rebuilt with a larger size when the requested positions exceed its current size.
    pub fn forward(&self, input_ids: &Tensor, past_key_values_length: i64) -> Tensor {
        let position_ids =
            self.create_position_ids_from_input_ids(input_ids, past_key_values_length);

        let seq_length = input_ids.size()[1];
        // The largest position id is `padding_idx + past_key_values_length + seq_length`, so the
        // past length has to be part of the required table size. Leaving it out lets incremental
        // decoding index past the end of the table once the cached length grows large enough.
        let max_pos = self.padding_idx + 1 + seq_length + past_key_values_length;
        let current_rows = self.embedding.read().unwrap().ws.size()[0];
        if max_pos > current_rows {
            self.embedding.write().unwrap().ws = Self::build_positional_embeddings(
                max_pos + self.offset,
                self.embedding_dim,
                self.padding_idx,
                input_ids.device(),
            );
        }
        position_ids.apply(self.embedding.read().unwrap().deref())
    }
}

134
src/m2m_100/encoder.rs Normal file
View File

@ -0,0 +1,134 @@
// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::{BartEncoderOutput, _expand_mask};
use crate::common::dropout::Dropout;
use crate::m2m_100::embeddings::SinusoidalPositionalEmbedding;
use crate::m2m_100::M2M100Config;
use crate::mbart::MBartEncoderLayer;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};
/// Encoder layer used by M2M100: a direct alias of `MBartEncoderLayer`.
pub type M2M100EncoderLayer = MBartEncoderLayer;
/// # M2M100 encoder
/// Stack of `M2M100EncoderLayer`s wrapped with sinusoidal positional embeddings,
/// an input dropout and a final layer normalization.
pub struct M2M100Encoder {
/// Dropout applied to the sum of token and position embeddings.
dropout: Dropout,
/// Layer normalization applied to the output of the last encoder layer.
layer_norm: nn::LayerNorm,
/// Encoder layers, applied in order.
layers: Vec<M2M100EncoderLayer>,
/// Sinusoidal position embeddings added to the (scaled) token embeddings.
embed_positions: SinusoidalPositionalEmbedding,
/// When true, the attention weights of every layer are collected in the output.
output_attentions: bool,
/// When true, the hidden states are collected in the output.
output_hidden_states: bool,
/// Multiplier applied to the token embeddings (`sqrt(d_model)` or `1.0`, chosen in `new`).
scale_embedding: f64,
}
impl M2M100Encoder {
    /// Build a new `M2M100Encoder`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the encoder. The layer normalization is
    ///   registered under `layer_norm` and the encoder layers under `layers/<index>`.
    /// * `config` - `M2M100Config` object defining the model architecture
    ///
    /// Unset configuration flags default to `output_attentions = false`,
    /// `output_hidden_states = false`, `scale_embedding = false` and `pad_token_id = 1`.
    pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Encoder
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);
        // Scaling by sqrt(d_model) is opt-in; a missing flag behaves like `false`.
        let scale_embedding = match config.scale_embedding {
            Some(true) => (config.d_model as f64).sqrt(),
            _ => 1.0,
        };

        let dropout = Dropout::new(config.dropout);
        let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default());
        let embed_positions = SinusoidalPositionalEmbedding::new(
            p / "embed_positions",
            config.max_position_embeddings,
            config.d_model,
            config.pad_token_id.unwrap_or(1),
        );

        let p_layers = p / "layers";
        let layers: Vec<M2M100EncoderLayer> = (0..config.encoder_layers)
            .map(|layer_index| M2M100EncoderLayer::new(&p_layers / layer_index, config))
            .collect();

        M2M100Encoder {
            dropout,
            layer_norm,
            layers,
            embed_positions,
            output_attentions,
            output_hidden_states,
            scale_embedding,
        }
    }

    /// Forward pass through the encoder
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Source token ids
    /// * `attention_mask` - Optional padding mask, expanded before being handed to the layers
    /// * `embeddings` - Token embedding layer used to embed `input_ids`
    /// * `train` - boolean flag forwarded to the dropout and to the layers
    ///
    /// # Returns
    ///
    /// * `M2M100EncoderOutput` holding the layer-normalized last hidden state and, when enabled,
    ///   the hidden states (input of every layer plus the output of the last one, all taken
    ///   before the final layer normalization) and the attention weights of every layer.
    pub fn forward_t(
        &self,
        input_ids: &Tensor,
        attention_mask: Option<&Tensor>,
        embeddings: &nn::Embedding,
        train: bool,
    ) -> M2M100EncoderOutput {
        let expanded_mask = attention_mask.map(|mask| _expand_mask(mask, None));

        let token_embeddings = input_ids.apply(embeddings) * self.scale_embedding;
        let embedded_input = token_embeddings + &self.embed_positions.forward(input_ids, 0);
        let mut hidden_state = embedded_input.apply_t(&self.dropout, train);

        let mut all_hidden_states: Option<Vec<Tensor>> = match self.output_hidden_states {
            true => Some(Vec::new()),
            false => None,
        };
        let mut all_attentions: Option<Vec<Tensor>> = match self.output_attentions {
            true => Some(Vec::new()),
            false => None,
        };

        for layer in self.layers.iter() {
            // Record the input of the layer before it gets overwritten.
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(hidden_state.copy());
            };

            let layer_output = layer.forward_t(&hidden_state, expanded_mask.as_ref(), train);
            hidden_state = layer_output.0;
            let attention_weights: Option<Tensor> = layer_output.1;

            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(attention_weights.as_ref().unwrap().copy());
            };
        }

        // Output of the last layer, still before the final layer normalization.
        if let Some(hidden_states) = all_hidden_states.borrow_mut() {
            hidden_states.push(hidden_state.copy());
        };

        M2M100EncoderOutput {
            hidden_state: hidden_state.apply(&self.layer_norm),
            all_hidden_states,
            all_attentions,
        }
    }
}
/// Container holding a M2M100 encoder output.
/// This is an alias of `BartEncoderOutput`; `M2M100Encoder::forward_t` fills its `hidden_state`,
/// `all_hidden_states` and `all_attentions` fields.
pub type M2M100EncoderOutput = BartEncoderOutput;

View File

@ -0,0 +1,869 @@
// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::m2m_100::decoder::M2M100Decoder;
use crate::m2m_100::encoder::M2M100Encoder;
use crate::m2m_100::LayerState;
use crate::mbart::{MBartConfig, MBartModelOutput};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::resources::{RemoteResource, Resource};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{M2M100Vocab, Vocab};
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
/// # M2M100 Pretrained model weight files
pub struct M2M100ModelResources;

impl M2M100ModelResources {
    /// 418M-parameter checkpoint, as a (resource name, URL) pair.
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const M2M100_418M: (&'static str, &'static str) = (
        "m2m100-418m/model",
        "https://huggingface.co/facebook/m2m100_418M/resolve/main/rust_model.ot",
    );
    /// 1.2B-parameter checkpoint, as a (resource name, URL) pair.
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const M2M100_1_2B: (&'static str, &'static str) = (
        "m2m100-1_2b/model",
        "https://huggingface.co/facebook/m2m100_1.2B/resolve/main/rust_model.ot",
    );
}

/// # M2M100 Pretrained model config files
pub struct M2M100ConfigResources;

impl M2M100ConfigResources {
    /// Configuration of the 418M-parameter checkpoint, as a (resource name, URL) pair.
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const M2M100_418M: (&'static str, &'static str) = (
        "m2m100-418m/config",
        "https://huggingface.co/facebook/m2m100_418M/resolve/main/config.json",
    );
    /// Configuration of the 1.2B-parameter checkpoint, as a (resource name, URL) pair.
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const M2M100_1_2B: (&'static str, &'static str) = (
        "m2m100-1_2b/config",
        "https://huggingface.co/facebook/m2m100_1.2B/resolve/main/config.json",
    );
}

/// # M2M100 Pretrained model vocab files
pub struct M2M100VocabResources;

impl M2M100VocabResources {
    /// Vocabulary of the 418M-parameter checkpoint, as a (resource name, URL) pair.
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const M2M100_418M: (&'static str, &'static str) = (
        "m2m100-418m/vocab",
        "https://huggingface.co/facebook/m2m100_418M/resolve/main/vocab.json",
    );
    /// Vocabulary of the 1.2B-parameter checkpoint, as a (resource name, URL) pair.
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const M2M100_1_2B: (&'static str, &'static str) = (
        "m2m100-1_2b/vocab",
        "https://huggingface.co/facebook/m2m100_1.2B/resolve/main/vocab.json",
    );
}

/// # M2M100 Pretrained model merges files
pub struct M2M100MergesResources;

impl M2M100MergesResources {
    /// SentencePiece model of the 418M-parameter checkpoint, as a (resource name, URL) pair.
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const M2M100_418M: (&'static str, &'static str) = (
        "m2m100-418m/merges",
        "https://huggingface.co/facebook/m2m100_418M/resolve/main/sentencepiece.bpe.model",
    );
    /// SentencePiece model of the 1.2B-parameter checkpoint, as a (resource name, URL) pair.
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const M2M100_1_2B: (&'static str, &'static str) = (
        "m2m100-1_2b/merges",
        "https://huggingface.co/facebook/m2m100_1.2B/resolve/main/sentencepiece.bpe.model",
    );
}
/// Configuration of a M2M100 model: a direct alias of `MBartConfig`.
pub type M2M100Config = MBartConfig;
/// Build decoder inputs by shifting `input_ids` one position to the right along dimension 1.
///
/// Column 0 of the result is set to `decoder_start_token_id`, the last token of every row of
/// `input_ids` is dropped, and any `-100` value left in the result is replaced by `pad_token_id`.
/// The result is an `Int64` tensor on the same device as `input_ids`.
fn _shift_tokens_right(
    input_ids: &Tensor,
    pad_token_id: i64,
    decoder_start_token_id: i64,
) -> Tensor {
    let shifted = Tensor::zeros(
        input_ids.size().as_slice(),
        (Kind::Int64, input_ids.device()),
    );

    // First column: decoder start token.
    let _ = shifted.select(1, 0).fill_(decoder_start_token_id);

    // Remaining columns: the input without its last token.
    let target_end = *shifted.size().last().unwrap();
    let source_end = *input_ids.size().last().unwrap() - 1;
    let source_tokens = input_ids.slice(1, 0, source_end, 1);
    let _ = shifted.slice(1, 1, target_end, 1).copy_(&source_tokens);

    let ignored_positions = shifted.eq(-100);
    shifted.masked_fill(&ignored_positions, pad_token_id)
}
/// # M2M100 Base model
/// Base architecture for M2M100 model. Usually complemented with a task-specific head, such as a language model head.
/// It is made of the following blocks:
/// - `encoder`: `M2M100Encoder` (transformer) made of a vector of encoding layers
/// - `decoder`: `M2M100Decoder` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention.
/// caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values)
/// - `embeddings`: token embedding layer (registered under the `shared` path) passed to both the encoder and the decoder
/// - `pad_token_id`: padding token id (defaults to 1 when absent from the configuration)
/// - `decoder_start_token_id`: first token of the decoder inputs generated from `input_ids` when no decoder input is given (defaults to 2 when absent from the configuration)
pub struct M2M100Model {
pub(crate) encoder: M2M100Encoder,
decoder: M2M100Decoder,
pub(crate) embeddings: nn::Embedding,
pad_token_id: i64,
decoder_start_token_id: i64,
}
impl M2M100Model {
    /// Build a new `M2M100Model`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the M2M100 model
    /// * `config` - `M2M100Config` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::m2m_100::{M2M100Config, M2M100Model};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = M2M100Config::from_file(config_path);
    /// let m2m100: M2M100Model = M2M100Model::new(&p.root() / "m2m100", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Model
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let pad_token_id = config.pad_token_id.unwrap_or(1);
        let decoder_start_token_id = config.decoder_start_token_id.unwrap_or(2);

        // Token embeddings shared by the encoder and the decoder.
        let embedding_config = EmbeddingConfig {
            padding_idx: pad_token_id,
            ..Default::default()
        };
        let embeddings: nn::Embedding = embedding(
            p / "shared",
            config.vocab_size,
            config.d_model,
            embedding_config,
        );

        let encoder = M2M100Encoder::new(p / "encoder", config);
        let decoder = M2M100Decoder::new(p / "decoder", config);

        M2M100Model {
            encoder,
            decoder,
            embeddings,
            pad_token_id,
            decoder_start_token_id,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when either `decoder_input_ids` or `encoder_output` is `None`.
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token). When `None`, the decoder inputs are generated by shifting `input_ids` one position to the right.
    /// * `encoder_output` - Optional tensor holding a pre-computed encoder last hidden state, expected with shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder is not run again. Useful for generation tasks.
    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
    /// * `layer_states` - Optional per-layer decoder states returned by a previous call (the `cache` field of the output), one `(self attention, cross attention)` pair per decoder layer.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `M2M100ModelOutput` containing:
    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
    ///   - `encoder_hidden_state` - `Option<Tensor>` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None
    ///   - `cache` - `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` with one entry per decoder layer, holding the states of the self attention and of the encoder cross attention
    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of encoder hidden states, each of shape (*batch size*, *source_sequence_length*, *hidden_size*); `None` when the encoder was not run
    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of attention weights, one per encoder layer; `None` when the encoder was not run
    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of attention weights, one per decoder layer
    ///
    /// # Panics
    ///
    /// Panics if `input_ids` is `None` while `decoder_input_ids` or `encoder_output` is also `None`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::m2m_100::{M2M100Config, M2M100Model};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = M2M100Config::from_file(config_path);
    /// # let m2m100_model: M2M100Model = M2M100Model::new(&vs.root(), &config);
    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
    /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
    /// let encoder_attention_mask =
    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
    /// let decoder_attention_mask =
    ///     Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     m2m100_model.forward_t(
    ///         Some(&input_tensor),
    ///         Some(&encoder_attention_mask),
    ///         Some(&target_tensor),
    ///         None,
    ///         Some(&decoder_attention_mask),
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        encoder_output: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        train: bool,
    ) -> M2M100ModelOutput {
        // Without explicit decoder inputs, derive them from the source tokens.
        let calc_decoder_input_ids = match decoder_input_ids {
            Some(_) => None,
            None => Some(_shift_tokens_right(
                input_ids.expect("`input_ids` must be provided when `decoder_input_ids` is None"),
                self.pad_token_id,
                self.decoder_start_token_id,
            )),
        };
        let decoder_input_ids =
            decoder_input_ids.unwrap_or_else(|| calc_decoder_input_ids.as_ref().unwrap());

        // Only run the encoder when no pre-computed encoder output was handed in.
        let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
            match encoder_output {
                Some(_) => (None, None, None),
                None => {
                    let calc_encoder_output = self.encoder.forward_t(
                        input_ids
                            .expect("`input_ids` must be provided when `encoder_output` is None"),
                        attention_mask,
                        &self.embeddings,
                        train,
                    );
                    (
                        Some(calc_encoder_output.hidden_state),
                        calc_encoder_output.all_hidden_states,
                        calc_encoder_output.all_attentions,
                    )
                }
            };
        let encoder_output = encoder_output.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap());

        let decoder_output = self.decoder.forward_t(
            decoder_input_ids,
            encoder_output,
            attention_mask,
            decoder_attention_mask,
            &self.embeddings,
            layer_states,
            train,
        );

        M2M100ModelOutput {
            decoder_output: decoder_output.hidden_state,
            encoder_hidden_state: calc_hidden_states,
            cache: decoder_output.next_decoder_cache,
            all_decoder_hidden_states: decoder_output.all_hidden_states,
            all_decoder_attentions: decoder_output.all_attentions,
            all_encoder_hidden_states,
            all_encoder_attentions,
        }
    }
}
/// Container holding a M2M100 model output.
/// This is an alias of `MBartModelOutput`, returned by both `M2M100Model::forward_t` and
/// `M2M100ForConditionalGeneration::forward_t`.
pub type M2M100ModelOutput = MBartModelOutput;
/// # M2M100 Model for conditional generation
/// M2M100 model with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `M2M100Model` Base M2M100 model
///
/// There is no separate projection layer: `forward_t` computes the vocabulary logits by applying
/// the token embedding weights of the base model (`embeddings.ws`) as a linear layer without bias.
pub struct M2M100ForConditionalGeneration {
base_model: M2M100Model,
}
impl M2M100ForConditionalGeneration {
    /// Build a new `M2M100ForConditionalGeneration`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the M2M100 model. The base model is registered under the `model` sub-path.
    /// * `config` - `M2M100Config` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::m2m_100::{M2M100Config, M2M100ForConditionalGeneration};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = M2M100Config::from_file(config_path);
    /// let m2m100: M2M100ForConditionalGeneration =
    ///     M2M100ForConditionalGeneration::new(&p.root(), &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100ForConditionalGeneration
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root: &nn::Path = p.borrow();
        M2M100ForConditionalGeneration {
            base_model: M2M100Model::new(root / "model", config),
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
    /// * `encoder_output` - Optional tensor holding a pre-computed encoder last hidden state, expected with shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
    /// * `old_layer_states` - Optional per-layer decoder states returned by a previous call (the `cache` field of the output)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `M2M100ModelOutput`: the output of the base model where `decoder_output` has been replaced by the
    ///   logits for each vocabulary item and position, a `Tensor` of shape
    ///   (*batch size*, *target_sequence_length*, *vocab_size*). All other fields are passed through unchanged.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// # use rust_bert::m2m_100::{M2M100Config, M2M100ForConditionalGeneration};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = M2M100Config::from_file(config_path);
    /// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config);
    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
    /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
    /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
    /// let decoder_attention_mask = Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     m2m100_model
    ///         .forward_t(Some(&input_tensor),
    ///                    Some(&encoder_attention_mask),
    ///                    None,
    ///                    Some(&target_tensor),
    ///                    Some(&decoder_attention_mask),
    ///                    None,
    ///                    false)
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        encoder_output: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        train: bool,
    ) -> M2M100ModelOutput {
        // Note the argument order: the base model takes the decoder inputs before the encoder output.
        let mut model_output = self.base_model.forward_t(
            input_ids,
            attention_mask,
            decoder_input_ids,
            encoder_output,
            decoder_attention_mask,
            old_layer_states,
            train,
        );

        // Project the decoder hidden states on the vocabulary with the tied embedding weights.
        let tied_weights = &self.base_model.embeddings.ws;
        model_output.decoder_output = model_output
            .decoder_output
            .linear::<Tensor>(tied_weights, None);
        model_output
    }

    /// Run the encoder only, with dropout disabled (`train = false`), and return its last hidden state.
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Source token ids
    /// * `attention_mask` - Optional attention mask for the encoder positions
    pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
        let base_model = &self.base_model;
        let encoder_output =
            base_model
                .encoder
                .forward_t(input_ids, attention_mask, &base_model.embeddings, false);
        encoder_output.hidden_state
    }
}
impl LMHeadModel for M2M100ForConditionalGeneration {
    /// Forward pass through the model, returning vocabulary logits and the updated decoder cache
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*), forwarded to the base model.
    /// * `cache` - `Cache` holding the previously calculated key and value pairs for the decoder. Only `Cache::BARTCache` and `Cache::None` are accepted.
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
    /// * `input_embeds` - Unused for M2M100
    /// * `token_type_ids` - Unused for M2M100
    /// * `position_ids` - Unused for M2M100
    /// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `LMModelOutput` containing:
    ///   - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
    ///   - `cache` - `Cache::BARTCache` wrapping the cache returned by the base model (past keys and values of the decoder layers).
    ///
    /// # Errors
    ///
    /// Returns a `RustBertError::ValueError` when `cache` is neither `Cache::BARTCache` nor `Cache::None`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::pipelines::generation_utils::LMHeadModel;
    /// use rust_bert::m2m_100::{M2M100ForConditionalGeneration, M2M100Config};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = M2M100Config::from_file(config_path);
    /// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config);
    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
    /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
    /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
    /// let decoder_attention_mask = Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     m2m100_model
    ///         .forward_t(Some(&input_tensor),
    ///                    Some(&encoder_attention_mask),
    ///                    None,
    ///                    Some(&target_tensor),
    ///                    Some(&decoder_attention_mask),
    ///                    None,
    ///                    false)
    /// });
    /// ```
    // NOTE(review): the example above passes seven arguments while this trait method takes nine;
    // it presumably resolves to an inherent `forward_t` of the model -- confirm.
    fn forward_t(
        &self,
        input_ids: &Option<Tensor>,
        cache: Cache,
        attention_mask: &Option<Tensor>,
        _token_type_ids: &Option<Tensor>,
        _position_ids: &Option<Tensor>,
        _input_embeds: &Option<Tensor>,
        encoder_outputs: Option<&Tensor>,
        decoder_input_ids: &Option<Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        // Reject unsupported cache variants up front; `Cache::None` means "no past states".
        let layer_states = match cache {
            Cache::BARTCache(layer_states) => layer_states,
            Cache::None => None,
            _ => {
                return Err(RustBertError::ValueError(
                    "Cache not compatible with M2M100 Model".into(),
                ));
            }
        };

        let base_model_output = self.base_model.forward_t(
            input_ids.as_ref(),
            attention_mask.as_ref(),
            decoder_input_ids.as_ref(),
            encoder_outputs,
            None,
            layer_states,
            train,
        );

        // The output projection re-uses the embedding weights (no bias).
        let embedding_weights = &self.base_model.embeddings.ws;
        let lm_logits = base_model_output
            .decoder_output
            .linear::<Tensor>(embedding_weights, None);

        Ok(LMModelOutput {
            lm_logits,
            cache: Cache::BARTCache(base_model_output.cache),
        })
    }
}
/// # Language generation model based on the M2M100 architecture
///
/// Bundles the model, its tokenizer, the variable store and the generation
/// settings that the generation traits implemented for this type read back
/// through their accessor methods.
pub struct M2M100Generator {
/// M2M100 model with language-model head used for generation.
model: M2M100ForConditionalGeneration,
/// Tokenizer used to encode prompts.
tokenizer: TokenizerOption,
/// Variable store of the model; its device is also used when building input tensors.
var_store: nn::VarStore,
/// Generation options supplied by the caller at construction.
generate_config: GenerateConfig,
/// Beginning-of-sequence token id.
bos_token_id: Option<i64>,
/// End-of-sequence token ids; forced at the last generation step.
eos_token_ids: Option<Vec<i64>>,
/// Padding token id.
pad_token_id: Option<i64>,
/// Whether the model is an encoder-decoder architecture.
is_encoder_decoder: bool,
/// Size of the vocabulary (number of score entries per position).
vocab_size: i64,
/// Token id used to start the decoder sequence.
decoder_start_id: Option<i64>,
/// Maximum number of positions taken from the model configuration.
max_position_embeddings: i64,
}
impl M2M100Generator {
/// Build a new `M2M100Generator`
///
/// # Arguments
///
/// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `merges_path` - Path to the bpe merges, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `weights_path` - Path to the model weight files. These need to be converted form the `.bin` to `.ot` format using the utility script provided.
/// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)`
///
/// # Example
///
/// ```no_run
/// # use std::path::PathBuf;
/// # use tch::Device;
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::m2m_100::M2M100Generator;
/// use rust_bert::pipelines::generation_utils::GenerateConfig;
/// # let mut home: PathBuf = dirs::home_dir().unwrap();
/// # home.push("rustbert");
/// # home.push("openai-gpt");
/// # let config_path = &home.as_path().join("config.json");
/// # let vocab_path = &home.as_path().join("vocab.txt");
/// # let merges_path = &home.as_path().join("merges.txt");
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let m2m100_generator = M2M100Generator::new(generate_config)?;
/// # Ok(())
/// # }
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
))
} else {
generate_config.merges_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let tokenizer = TokenizerOption::from_file(
ModelType::M2M100,
vocab_path.to_str().unwrap(),
Some(merges_path.to_str().unwrap()),
false,
None,
None,
)?;
let config = M2M100Config::from_file(config_path);
let model = M2M100ForConditionalGeneration::new(&var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(0);
let eos_token_ids = Some(match config.eos_token_id {
Some(value) => vec![value],
None => vec![2],
});
let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(2);
let max_position_embeddings = config.max_position_embeddings;
Ok(M2M100Generator {
model,
tokenizer,
var_store,
generate_config,
bos_token_id,
eos_token_ids,
pad_token_id,
is_encoder_decoder,
vocab_size,
decoder_start_id,
max_position_embeddings,
})
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M100Tokenizer>
    for M2M100Generator
{
    /// Model used for generation.
    fn get_model(&self) -> &M2M100ForConditionalGeneration {
        &self.model
    }

    /// Tokenizer used to encode the prompts.
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }

    /// Variable store of the model.
    fn get_var_store(&self) -> &nn::VarStore {
        &self.var_store
    }

    /// Generation options this generator was built with.
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }

    /// Beginning-of-sequence token id.
    fn get_bos_id(&self) -> &Option<i64> {
        &self.bos_token_id
    }

    /// End-of-sequence token ids.
    fn get_eos_ids(&self) -> &Option<Vec<i64>> {
        &self.eos_token_ids
    }

    /// Padding token id.
    fn get_pad_id(&self) -> &Option<i64> {
        &self.pad_token_id
    }

    /// Whether the model is an encoder-decoder architecture.
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }

    /// Vocabulary size.
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }

    /// Token id used to start the decoder sequence.
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }

    /// Maximum number of positions supported by the model.
    fn get_max_positions_embeddings(&self) -> i64 {
        self.max_position_embeddings
    }

    /// Constrains the scores at specific generation steps: when `current_length` is 1 only the
    /// forced BOS token may be generated, and when it is `max_length - 1` only the EOS tokens may.
    fn prepare_scores_for_generation(
        &self,
        scores: &mut Tensor,
        current_length: i64,
        max_length: i64,
        forced_bos_token_id: Option<i64>,
    ) {
        if current_length == 1 {
            // NOTE(review): the fallback id 250004 is hard-coded rather than read from the
            // configuration or tokenizer -- confirm it is a valid M2M100 vocabulary id.
            let forced_token = forced_bos_token_id.unwrap_or(250004);
            self.force_token_id_generation(scores, &[forced_token]);
        } else if current_length == max_length - 1 {
            let eos_ids = self.get_eos_ids().as_ref().unwrap();
            self.force_token_id_generation(scores, eos_ids);
        }
    }

    /// Encodes `input_ids` with the encoder of the model; always returns `Some`.
    fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
        let hidden_state = self.get_model().encode(input_ids, attention_mask);
        Some(hidden_state)
    }

    /// Builds the inputs of the next decoding step.
    ///
    /// With a `Cache::BARTCache`, only the last position of `input_ids` is passed as decoder
    /// input and the cache is forwarded unchanged. With `Cache::None`, the whole `input_ids`
    /// tensor is passed and an empty `Cache::BARTCache` is created.
    ///
    /// # Panics
    ///
    /// Panics if `past` is any other cache variant.
    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        attention_mask: Tensor,
    ) -> PreparedInput<'a> {
        let (decoder_input, prepared_past) = match past {
            Cache::BARTCache(past) => (input_ids.narrow(1, -1, 1), Cache::BARTCache(past)),
            Cache::None => (input_ids, Cache::BARTCache(None)),
            _ => panic!("Cache type incompatible with M2M100"),
        };
        PreparedInput {
            prepared_input: None,
            prepared_attention_mask: Some(attention_mask),
            prepared_encoder_output: encoder_outputs,
            prepared_decoder_input: Some(decoder_input),
            prepared_position_ids: None,
            prepared_past,
        }
    }

    /// Tokenizes the prompts (truncated to `max_len` tokens with the longest-first strategy),
    /// right-pads them to the length of the longest encoded prompt and stacks them into a
    /// single tensor on the device of the variable store.
    ///
    /// When `pad_token_id` is `None`, the id of the tokenizer's unknown token is used as padding.
    ///
    /// # Panics
    ///
    /// Panics if `prompt_text` is empty.
    fn encode_prompt_text<'a, S>(
        &self,
        prompt_text: S,
        max_len: i64,
        pad_token_id: Option<i64>,
    ) -> Tensor
    where
        S: AsRef<[&'a str]>,
    {
        let tokenizer = self._get_tokenizer();
        let encoded_prompts = tokenizer.encode_list(
            prompt_text.as_ref(),
            max_len as usize,
            &TruncationStrategy::LongestFirst,
            0,
        );
        let token_ids = encoded_prompts
            .into_iter()
            .map(|encoded| encoded.token_ids)
            .collect::<Vec<Vec<i64>>>();

        let longest = token_ids.iter().map(|ids| ids.len()).max().unwrap();
        let pad_token = pad_token_id.unwrap_or_else(|| {
            tokenizer.convert_tokens_to_ids(&[M2M100Vocab::unknown_value()])[0]
        });

        let device = self.get_var_store().device();
        let padded_rows = token_ids
            .into_iter()
            .map(|mut ids| {
                // No row is longer than `longest`, so this only ever appends padding.
                ids.resize(longest, pad_token);
                Tensor::of_slice(&ids).to(device)
            })
            .collect::<Vec<Tensor>>();
        Tensor::stack(&padded_rows, 0)
    }

    /// Re-orders the encoder outputs and every cached layer state according to `beam_indices`,
    /// returning the re-ordered encoder outputs.
    ///
    /// # Panics
    ///
    /// Panics if `past` is neither `Cache::BARTCache` nor `Cache::None`.
    fn reorder_cache(
        &self,
        past: &mut Cache,
        encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
        match past {
            Cache::BARTCache(Some(layer_states)) => {
                for (self_layer_state, encoder_layer_state) in layer_states.iter_mut() {
                    if let Some(state) = self_layer_state {
                        state.reorder_cache(beam_indices);
                    }
                    if let Some(state) = encoder_layer_state {
                        state.reorder_cache(beam_indices);
                    }
                }
            }
            // Nothing is cached yet: there is nothing to re-order.
            Cache::BARTCache(None) | Cache::None => {}
            _ => {
                panic!("Invalid cache for M2M100 model");
            }
        }
        encoder_outputs
    }
}
// Marker implementation: the body is intentionally empty, so `M2M100Generator` relies
// entirely on the methods the `LanguageGenerator` trait itself provides.
impl LanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M100Tokenizer>
for M2M100Generator
{
}

70
src/m2m_100/mod.rs Normal file
View File

@ -0,0 +1,70 @@
//! # M2M-100 (Fan et al.)
//!
//! Implementation of the M2M-100 language model ([Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) Fan, Bhosale, Schwenk, Ma, El-Kishky, Goyal, Baines, Celebi, Wenzel, Chaudhary, Goyal, Birch, Liptchinsky, Edunov, Grave, Auli, Joulin, 2020).
//! The base model is implemented in the `m2m_100::M2M100Model` struct. The model also includes a language model head: `m2m_100::M2M100ForConditionalGeneration`
//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//! This model allows for direct translation between 100 languages.
//! The translation capabilities are illustrated in `examples/translation_m2m100`, run with `cargo run --example translation_m2m100`.
//!
//! # Model set-up and pre-trained weights loading
//!
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `M2M100Tokenizer` using a `config.json` vocabulary and a `spiece.model` SentencePiece BPE model
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::m2m_100::{M2M100Config, M2M100Model};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::M2M100Tokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: M2M100Tokenizer = M2M100Tokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! false,
//! )?;
//! let config = M2M100Config::from_file(config_path);
//! let m2m100_model = M2M100Model::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
//! # }
//! ```
mod attention;
mod decoder;
mod embeddings;
mod encoder;
mod m2m_100_model;
pub use m2m_100_model::{
M2M100Config, M2M100ConfigResources, M2M100ForConditionalGeneration, M2M100Generator,
M2M100MergesResources, M2M100Model, M2M100ModelResources, M2M100VocabResources,
};
pub use attention::LayerState;

View File

@ -49,102 +49,102 @@ pub struct MarianSourceLanguages;
pub struct MarianTargetLanguages;
impl MarianModelResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-de-en/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ru/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-ru-en/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-fr-de/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-de-fr/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ENGLISH2DUTCH: (&'static str, &'static str) = (
"marian-mt-en-nl/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-nl/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const DUTCH2ENGLISH: (&'static str, &'static str) = (
"marian-mt-nl-en/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-nl-en/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ENGLISH2CHINESE: (&'static str, &'static str) = (
"marian-mt-en-zh/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-zh/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const CHINESE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-zh-en/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-zh-en/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ENGLISH2SWEDISH: (&'static str, &'static str) = (
"marian-mt-en-sv/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-sv/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const SWEDISH2ENGLISH: (&'static str, &'static str) = (
"marian-mt-sv-en/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-sv-en/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ARABIC2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ar-en/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-ar-en/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ENGLISH2ARABIC: (&'static str, &'static str) = (
"marian-mt-en-ar/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ar/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const HINDI2ENGLISH: (&'static str, &'static str) = (
"marian-mt-hi-en/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-hi-en/resolve/main/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>. Modified with conversion to C-array format.
pub const ENGLISH2HINDI: (&'static str, &'static str) = (
"marian-mt-en-hi/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-hi/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 License license at https://huggingface.co/tiedeman/opus-mt-he-en. Modified with conversion to C-array format.
/// Shared under Apache 2.0 License license at <https://huggingface.co/tiedeman/opus-mt-he-en>. Modified with conversion to C-array format.
pub const HEBREW2ENGLISH: (&'static str, &'static str) = (
"marian-mt-he-en/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-he-en/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 License license at https://huggingface.co/tiedeman/opus-mt-en-he. Modified with conversion to C-array format.
/// Shared under Apache 2.0 License license at <https://huggingface.co/tiedeman/opus-mt-en-he>. Modified with conversion to C-array format.
pub const ENGLISH2HEBREW: (&'static str, &'static str) = (
"marian-mt-en-he/model",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-he/resolve/main/rust_model.ot",
@ -152,102 +152,102 @@ impl MarianModelResources {
}
impl MarianConfigResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/resolve/main/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/resolve/main/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-de-en/resolve/main/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ru/resolve/main/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-ru-en/resolve/main/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-fr-de/resolve/main/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-de-fr/resolve/main/config.json",
);
/// Resource name and remote URL of the `config.json` file for the English to Dutch Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2DUTCH: (&'static str, &'static str) = (
"marian-mt-en-nl/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-nl/resolve/main/config.json",
);
/// Resource name and remote URL of the `config.json` file for the Dutch to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const DUTCH2ENGLISH: (&'static str, &'static str) = (
"marian-mt-nl-en/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-nl-en/resolve/main/config.json",
);
/// Resource name and remote URL of the `config.json` file for the Chinese to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const CHINESE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-zh-en/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-zh-en/resolve/main/config.json",
);
/// Resource name and remote URL of the `config.json` file for the English to Chinese Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2CHINESE: (&'static str, &'static str) = (
"marian-mt-en-zh/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-zh/resolve/main/config.json",
);
/// Resource name and remote URL of the `config.json` file for the English to Swedish Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2SWEDISH: (&'static str, &'static str) = (
"marian-mt-en-sv/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-sv/resolve/main/config.json",
);
/// Resource name and remote URL of the `config.json` file for the Swedish to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const SWEDISH2ENGLISH: (&'static str, &'static str) = (
"marian-mt-sv-en/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-sv-en/resolve/main/config.json",
);
/// Resource name and remote URL of the `config.json` file for the Arabic to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ARABIC2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ar-en/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-ar-en/resolve/main/config.json",
);
/// Resource name and remote URL of the `config.json` file for the English to Arabic Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2ARABIC: (&'static str, &'static str) = (
"marian-mt-en-ar/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ar/resolve/main/config.json",
);
/// Resource name and remote URL of the `config.json` file for the Hindi to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const HINDI2ENGLISH: (&'static str, &'static str) = (
"marian-mt-hi-en/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-hi-en/resolve/main/config.json",
);
/// Resource name and remote URL of the `config.json` file for the English to Hindi Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2HINDI: (&'static str, &'static str) = (
"marian-mt-en-hi/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-hi/resolve/main/config.json",
);
// NOTE(review): the attribution link uses the `tiedeman` namespace while the download URL uses `Helsinki-NLP`; confirm which one is current.
/// Resource name and remote URL of the `config.json` file for the Hebrew to English Marian model.
///
/// Shared under the Apache 2.0 License at <https://huggingface.co/tiedeman/opus-mt-he-en>. Modified with conversion to C-array format.
pub const HEBREW2ENGLISH: (&'static str, &'static str) = (
"marian-mt-he-en/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-he-en/resolve/main/config.json",
);
/// Shared under the Apache 2.0 License at <https://huggingface.co/tiedeman/opus-mt-en-he>. Modified with conversion to C-array format.
pub const ENGLISH2HEBREW: (&'static str, &'static str) = (
"marian-mt-en-he/config",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-he/resolve/main/config.json",
@ -255,102 +255,102 @@ impl MarianConfigResources {
}
impl MarianVocabResources {
/// Resource name and remote URL of the `vocab.json` vocabulary file for the English to Romance languages Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the Romance languages to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the English to German Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the German to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-de-en/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the English to Russian Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ru/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the Russian to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-ru-en/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the French to German Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-fr-de/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the German to French Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-de-fr/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the English to Dutch Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2DUTCH: (&'static str, &'static str) = (
"marian-mt-en-nl/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-nl/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the Dutch to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const DUTCH2ENGLISH: (&'static str, &'static str) = (
"marian-mt-nl-en/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-nl-en/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the Chinese to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const CHINESE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-zh-en/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-zh-en/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the English to Chinese Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2CHINESE: (&'static str, &'static str) = (
"marian-mt-en-zh/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-zh/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the English to Swedish Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2SWEDISH: (&'static str, &'static str) = (
"marian-mt-en-sv/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-sv/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the Swedish to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const SWEDISH2ENGLISH: (&'static str, &'static str) = (
"marian-mt-sv-en/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-sv-en/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the Arabic to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ARABIC2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ar-en/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-ar-en/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the English to Arabic Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2ARABIC: (&'static str, &'static str) = (
"marian-mt-en-ar/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ar/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the Hindi to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const HINDI2ENGLISH: (&'static str, &'static str) = (
"marian-mt-hi-en/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-hi-en/resolve/main/vocab.json",
);
/// Resource name and remote URL of the `vocab.json` vocabulary file for the English to Hindi Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2HINDI: (&'static str, &'static str) = (
"marian-mt-en-hi/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-hi/resolve/main/vocab.json",
);
// NOTE(review): the attribution link uses the `tiedeman` namespace while the download URL uses `Helsinki-NLP`; confirm which one is current.
/// Resource name and remote URL of the `vocab.json` vocabulary file for the Hebrew to English Marian model.
///
/// Shared under the Apache 2.0 License at <https://huggingface.co/tiedeman/opus-mt-he-en>. Modified with conversion to C-array format.
pub const HEBREW2ENGLISH: (&'static str, &'static str) = (
"marian-mt-he-en/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-he-en/resolve/main/vocab.json",
);
/// Shared under the Apache 2.0 License at <https://huggingface.co/tiedeman/opus-mt-en-he>. Modified with conversion to C-array format.
pub const ENGLISH2HEBREW: (&'static str, &'static str) = (
"marian-mt-en-he/vocab",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-he/resolve/main/vocab.json",
@ -358,102 +358,102 @@ impl MarianVocabResources {
}
impl MarianSpmResources {
/// Resource name and remote URL of the `source.spm` SentencePiece model for the English to Romance languages Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the Romance languages to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the English to German Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the German to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-de-en/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the English to Russian Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ru/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the Russian to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-ru-en/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the French to German Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-fr-de/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the German to French Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-de-fr/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the English to Dutch Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2DUTCH: (&'static str, &'static str) = (
"marian-mt-en-nl/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-nl/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the Dutch to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const DUTCH2ENGLISH: (&'static str, &'static str) = (
"marian-mt-nl-en/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-nl-en/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the Chinese to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const CHINESE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-zh-en/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-zh-en/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the English to Chinese Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2CHINESE: (&'static str, &'static str) = (
"marian-mt-en-zh/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-zh/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the English to Swedish Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2SWEDISH: (&'static str, &'static str) = (
"marian-mt-en-sv/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-sv/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the Swedish to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const SWEDISH2ENGLISH: (&'static str, &'static str) = (
"marian-mt-sv-en/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-sv-en/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the Arabic to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ARABIC2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ar-en/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-ar-en/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the English to Arabic Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2ARABIC: (&'static str, &'static str) = (
"marian-mt-en-ar/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-ar/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the Hindi to English Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const HINDI2ENGLISH: (&'static str, &'static str) = (
"marian-mt-hi-en/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-hi-en/resolve/main/source.spm",
);
/// Resource name and remote URL of the `source.spm` SentencePiece model for the English to Hindi Marian model.
///
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
pub const ENGLISH2HINDI: (&'static str, &'static str) = (
"marian-mt-en-hi/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-hi/resolve/main/source.spm",
);
// NOTE(review): the attribution link uses the `tiedeman` namespace while the download URL uses `Helsinki-NLP`; confirm which one is current.
/// Resource name and remote URL of the `source.spm` SentencePiece model for the Hebrew to English Marian model.
///
/// Shared under the Apache 2.0 License at <https://huggingface.co/tiedeman/opus-mt-he-en>. Modified with conversion to C-array format.
pub const HEBREW2ENGLISH: (&'static str, &'static str) = (
"marian-mt-he-en/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-he-en/resolve/main/source.spm",
);
/// Shared under the Apache 2.0 License at <https://huggingface.co/tiedeman/opus-mt-en-he>. Modified with conversion to C-array format.
pub const ENGLISH2HEBREW: (&'static str, &'static str) = (
"marian-mt-en-he/spiece",
"https://huggingface.co/Helsinki-NLP/opus-mt-en-he/resolve/main/source.spm",
@ -615,7 +615,7 @@ impl MarianForConditionalGeneration {
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
@ -717,7 +717,7 @@ impl LMHeadModel for MarianForConditionalGeneration {
/// * `position_ids` - Unused for BART
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
///

View File

@ -12,7 +12,7 @@
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `MarianTokenizer` using a `vocab.json` vocabulary and `spiece.model` sentence piece model
//!
//! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. These are shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
//! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. These are shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at <https://github.com/Helsinki-NLP/Opus-MT>.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {

View File

@ -44,7 +44,7 @@ pub struct MBartConfigResources;
pub struct MBartVocabResources;
impl MBartModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const MBART50_MANY_TO_MANY: (&'static str, &'static str) = (
"mbart-50-many-to-many-mmt/model",
"https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/rust_model.ot",
@ -52,7 +52,7 @@ impl MBartModelResources {
}
impl MBartConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const MBART50_MANY_TO_MANY: (&'static str, &'static str) = (
"mbart-50-many-to-many-mmt/config",
"https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/config.json",
@ -60,7 +60,7 @@ impl MBartConfigResources {
}
impl MBartVocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const MBART50_MANY_TO_MANY: (&'static str, &'static str) = (
"mbart-50-many-to-many-mmt/vocab",
"https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model",
@ -107,7 +107,7 @@ pub struct MBartConfig {
pub static_position_embeddings: Option<bool>,
}
impl Config<MBartConfig> for MBartConfig {}
impl Config for MBartConfig {}
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let output = input_ids.masked_fill(&input_ids.eq(-100), pad_token_id);
@ -247,7 +247,7 @@ impl MBartModel {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
@ -423,7 +423,7 @@ impl MBartForConditionalGeneration {
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
@ -571,7 +571,7 @@ impl MBartForSequenceClassification {
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
@ -670,7 +670,7 @@ impl LMHeadModel for MBartForConditionalGeneration {
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1
/// * `input_embeds` - Unused for MBart
/// * `token_type_ids` - Unused for MBart

View File

@ -6,7 +6,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! The summarization capabilities are illustrated in `examples/translation_mbart`, run with `cargo run --example translation_mbart`.
//! The translation capabilities are illustrated in `examples/translation_mbart`, run with `cargo run --example translation_mbart`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
@ -41,7 +41,7 @@
//! let tokenizer: MBart50Tokenizer =
//! MBart50Tokenizer::from_file(vocab_path.to_str().unwrap(), false)?;
//! let config = MBartConfig::from_file(config_path);
//! let bart_model = MBartModel::new(&vs.root(), &config);
//! let mbart_model = MBartModel::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())

View File

@ -31,12 +31,12 @@ pub struct MobileBertConfigResources;
pub struct MobileBertVocabResources;
impl MobileBertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://huggingface.co/google/mobilebert-uncased. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://huggingface.co/google/mobilebert-uncased>. Modified with conversion to C-array format.
pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
"mobilebert-uncased/model",
"https://huggingface.co/google/mobilebert-uncased/resolve/main/rust_model.ot",
);
/// Shared under MIT license at https://huggingface.co/mrm8488/mobilebert-finetuned-pos. Modified with conversion to C-array format.
/// Shared under MIT license at <https://huggingface.co/mrm8488/mobilebert-finetuned-pos>. Modified with conversion to C-array format.
pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
"mobilebert-finetuned-pos/model",
"https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/rust_model.ot",
@ -44,12 +44,12 @@ impl MobileBertModelResources {
}
impl MobileBertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://huggingface.co/google/mobilebert-uncased. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://huggingface.co/google/mobilebert-uncased>. Modified with conversion to C-array format.
pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
"mobilebert-uncased/config",
"https://huggingface.co/google/mobilebert-uncased/resolve/main/config.json",
);
/// Shared under MIT license at https://huggingface.co/mrm8488/mobilebert-finetuned-pos. Modified with conversion to C-array format.
/// Shared under MIT license at <https://huggingface.co/mrm8488/mobilebert-finetuned-pos>. Modified with conversion to C-array format.
pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
"mobilebert-finetuned-pos/config",
"https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/config.json",
@ -57,12 +57,12 @@ impl MobileBertConfigResources {
}
impl MobileBertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://huggingface.co/google/mobilebert-uncased. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://huggingface.co/google/mobilebert-uncased>. Modified with conversion to C-array format.
pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
"mobilebert-uncased/vocab",
"https://huggingface.co/google/mobilebert-uncased/resolve/main/vocab.txt",
);
/// Shared under MIT license at https://huggingface.co/mrm8488/mobilebert-finetuned-pos. Modified with conversion to C-array format.
/// Shared under MIT license at <https://huggingface.co/mrm8488/mobilebert-finetuned-pos>. Modified with conversion to C-array format.
pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
"mobilebert-finetuned-pos/vocab",
"https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/vocab.txt",
@ -194,7 +194,7 @@ pub struct MobileBertConfig {
pub label2id: Option<HashMap<String, i64>>,
}
impl Config<MobileBertConfig> for MobileBertConfig {}
impl Config for MobileBertConfig {}
pub struct MobileBertPredictionHeadTransform {
dense: nn::Linear,
@ -319,7 +319,7 @@ impl MobileBertModel {
///
/// * `p` - Variable store path for the root of the MobileBERT model
/// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
/// * `add_poling_layer` - boolean flag indicating if a pooling layer shuld be added after the encoder
/// * `add_pooling_layer` - boolean flag indicating if a pooling layer should be added after the encoder
///
/// # Example
///

View File

@ -45,7 +45,7 @@ pub struct OpenAiGptVocabResources;
pub struct OpenAiGptMergesResources;
impl OpenAiGptModelResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
/// Shared under MIT license by the OpenAI team at <https://github.com/openai/finetune-transformer-lm>. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/model",
"https://huggingface.co/openai-gpt/resolve/main/rust_model.ot",
@ -53,7 +53,7 @@ impl OpenAiGptModelResources {
}
impl OpenAiGptConfigResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
/// Shared under MIT license by the OpenAI team at <https://github.com/openai/finetune-transformer-lm>. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/config",
"https://huggingface.co/openai-gpt/resolve/main/config.json",
@ -61,7 +61,7 @@ impl OpenAiGptConfigResources {
}
impl OpenAiGptVocabResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
/// Shared under MIT license by the OpenAI team at <https://github.com/openai/finetune-transformer-lm>. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/vocab",
"https://huggingface.co/openai-gpt/resolve/main/vocab.json",
@ -69,7 +69,7 @@ impl OpenAiGptVocabResources {
}
impl OpenAiGptMergesResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
/// Shared under MIT license by the OpenAI team at <https://github.com/openai/finetune-transformer-lm>. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/merges",
"https://huggingface.co/openai-gpt/resolve/main/merges.txt",

View File

@ -41,7 +41,7 @@ pub struct PegasusConfigResources;
pub struct PegasusVocabResources;
impl PegasusModelResources {
/// Shared under Apache 2.0 license by the Pegasus team at https://huggingface.co/google/pegasus-cnn_dailymail. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Pegasus team at <https://huggingface.co/google/pegasus-cnn_dailymail>. Modified with conversion to C-array format.
pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
"pegasus-cnn_dailymail/model",
"https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/rust_model.ot",
@ -49,7 +49,7 @@ impl PegasusModelResources {
}
impl PegasusConfigResources {
/// Shared under Apache 2.0 license by the Pegasus team at https://huggingface.co/google/pegasus-cnn_dailymail.
/// Shared under Apache 2.0 license by the Pegasus team at <https://huggingface.co/google/pegasus-cnn_dailymail>.
pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
"pegasus-cnn_dailymail/config",
"https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/config.json",
@ -57,7 +57,7 @@ impl PegasusConfigResources {
}
impl PegasusVocabResources {
/// Shared under Apache 2.0 license by the Pegasus team at https://huggingface.co/google/pegasus-cnn_dailymail.
/// Shared under Apache 2.0 license by the Pegasus team at <https://huggingface.co/google/pegasus-cnn_dailymail>.
pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
"pegasus-cnn_dailymail/spiece",
"https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/spiece.model",
@ -156,7 +156,7 @@ impl PegasusModel {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
@ -329,7 +329,7 @@ impl PegasusForConditionalGeneration {
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
@ -437,7 +437,7 @@ impl LMHeadModel for PegasusForConditionalGeneration {
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1
/// * `input_embeds` - Unused for Pegasus
/// * `token_type_ids` - Unused for Pegasus

View File

@ -25,6 +25,7 @@ use crate::electra::ElectraConfig;
use crate::gpt2::Gpt2Config;
use crate::gpt_neo::GptNeoConfig;
use crate::longformer::LongformerConfig;
use crate::m2m_100::M2M100Config;
use crate::mbart::MBartConfig;
use crate::mobilebert::MobileBertConfig;
use crate::pegasus::PegasusConfig;
@ -34,14 +35,15 @@ use crate::t5::T5Config;
use crate::xlnet::XLNetConfig;
use crate::Config;
use rust_tokenizers::tokenizer::{
AlbertTokenizer, BertTokenizer, Gpt2Tokenizer, MBart50Tokenizer, MarianTokenizer,
MultiThreadedTokenizer, OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer,
ReformerTokenizer, RobertaTokenizer, T5Tokenizer, Tokenizer, TruncationStrategy,
XLMRobertaTokenizer, XLNetTokenizer,
AlbertTokenizer, BertTokenizer, Gpt2Tokenizer, M2M100Tokenizer, MBart50Tokenizer,
MarianTokenizer, MultiThreadedTokenizer, OpenAiGptTokenizer, PegasusTokenizer,
ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer, Tokenizer,
TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
};
use rust_tokenizers::vocab::{
AlbertVocab, BertVocab, Gpt2Vocab, MBart50Vocab, MarianVocab, OpenAiGptVocab, PegasusVocab,
ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab, XLNetVocab,
AlbertVocab, BertVocab, Gpt2Vocab, M2M100Vocab, MBart50Vocab, MarianVocab, OpenAiGptVocab,
PegasusVocab, ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab,
XLNetVocab,
};
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
use serde::{Deserialize, Serialize};
@ -70,6 +72,7 @@ pub enum ModelType {
Pegasus,
GPTNeo,
MBart,
M2M100,
}
/// # Abstraction that holds a model configuration, can be of any of the supported models
@ -106,6 +109,8 @@ pub enum ConfigOption {
GPTNeo(GptNeoConfig),
/// MBart configuration
MBart(MBartConfig),
/// M2M100 configuration
M2M100(M2M100Config),
}
/// # Abstraction that holds a particular tokenizer, can be of any of the supported models
@ -136,6 +141,8 @@ pub enum TokenizerOption {
Pegasus(PegasusTokenizer),
/// MBart50 Tokenizer
MBart50(MBart50Tokenizer),
/// M2M100 Tokenizer
M2M100(M2M100Tokenizer),
}
impl ConfigOption {
@ -160,6 +167,7 @@ impl ConfigOption {
ModelType::Longformer => ConfigOption::Longformer(LongformerConfig::from_file(path)),
ModelType::Pegasus => ConfigOption::Pegasus(PegasusConfig::from_file(path)),
ModelType::MBart => ConfigOption::MBart(MBartConfig::from_file(path)),
ModelType::M2M100 => ConfigOption::M2M100(M2M100Config::from_file(path)),
}
}
@ -201,6 +209,9 @@ impl ConfigOption {
Self::MBart(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::M2M100(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::T5(_) => panic!("T5 does not use a label mapping"),
Self::GPT2(_) => panic!("GPT2 does not use a label mapping"),
Self::GPTNeo(_) => panic!("GPT-Neo does not use a label mapping"),
@ -403,6 +414,26 @@ impl TokenizerOption {
}
TokenizerOption::MBart50(MBart50Tokenizer::from_file(vocab_path, lower_case)?)
}
// M2M100 tokenizer construction. Neither `add_prefix_space` nor `strip_accents`
// can be used by this model type, so an explicitly supplied value for either is
// rejected with an `InvalidConfigurationError` instead of being silently dropped.
ModelType::M2M100 => {
if add_prefix_space.is_some() {
return Err(RustBertError::InvalidConfigurationError(
format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
// `unwrap` cannot panic here: guarded by the `is_some()` check above.
add_prefix_space.unwrap(),
model_type)));
}
if strip_accents.is_some() {
return Err(RustBertError::InvalidConfigurationError(format!(
"Optional input `strip_accents` set to value {} but cannot be used by {:?}",
// `unwrap` cannot panic here: guarded by the `is_some()` check above.
strip_accents.unwrap(),
model_type
)));
}
// The tokenizer is built from `vocab_path` plus the file given by `merges_path`.
// NOTE: a missing `merges_path` panics with "No merges specified!" rather than
// returning an error; failures reported by `from_files` are propagated via `?`.
TokenizerOption::M2M100(M2M100Tokenizer::from_files(
vocab_path,
merges_path.expect("No merges specified!"),
lower_case,
)?)
}
};
Ok(tokenizer)
}
@ -423,6 +454,7 @@ impl TokenizerOption {
Self::ProphetNet(_) => ModelType::ProphetNet,
Self::Pegasus(_) => ModelType::Pegasus,
Self::MBart50(_) => ModelType::MBart,
Self::M2M100(_) => ModelType::M2M100,
}
}
@ -526,6 +558,13 @@ impl TokenizerOption {
truncation_strategy,
stride,
),
Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
}
}
@ -629,6 +668,13 @@ impl TokenizerOption {
truncation_strategy,
stride,
),
Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
}
}
@ -681,6 +727,9 @@ impl TokenizerOption {
Self::MBart50(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::M2M100(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
}
}
@ -700,6 +749,7 @@ impl TokenizerOption {
Self::ProphetNet(ref tokenizer) => tokenizer.tokenize(text),
Self::Pegasus(ref tokenizer) => tokenizer.tokenize(text),
Self::MBart50(ref tokenizer) => tokenizer.tokenize(text),
Self::M2M100(ref tokenizer) => tokenizer.tokenize(text),
}
}
@ -719,6 +769,7 @@ impl TokenizerOption {
Self::ProphetNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::Pegasus(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::MBart50(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::M2M100(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
}
}
@ -744,6 +795,7 @@ impl TokenizerOption {
}
Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
}
}
@ -794,6 +846,9 @@ impl TokenizerOption {
Self::MBart50(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::M2M100(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
}
}
@ -856,6 +911,10 @@ impl TokenizerOption {
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::M2M100(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
};
TokenizedInput {
token_ids: token_ids_with_special_tokens.token_ids,
@ -889,6 +948,7 @@ impl TokenizerOption {
Self::ProphetNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Pegasus(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::MBart50(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::M2M100(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
}
}
@ -947,6 +1007,10 @@ impl TokenizerOption {
.special_values
.get(MBart50Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::M2M100(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(M2M100Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
}
}
@ -1013,6 +1077,12 @@ impl TokenizerOption {
.get(MBart50Vocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::M2M100(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(M2M100Vocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Reformer(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
@ -1064,6 +1134,12 @@ impl TokenizerOption {
.get(MBart50Vocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::M2M100(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(M2M100Vocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Marian(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,

View File

@ -622,7 +622,7 @@ impl ConversationManager {
///
/// # Returns
///
/// * `Option<Conversation>` deregistered conversation
/// * `Option<Conversation>` de-registered conversation
///
/// # Example
///
@ -643,7 +643,7 @@ impl ConversationManager {
///
/// # Returns
///
/// * `HashMap<Uuid, Conversation>` deregistered conversations
/// * `HashMap<Uuid, Conversation>` de-registered conversations
///
/// # Example
///

View File

@ -42,12 +42,12 @@ pub struct ProphetNetConfigResources;
pub struct ProphetNetVocabResources;
impl ProphetNetModelResources {
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
"prophetnet-large-uncased/model",
"https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
"prophetnet-large-uncased-cnndm/model",
"https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/rust_model.ot",
@ -55,12 +55,12 @@ impl ProphetNetModelResources {
}
impl ProphetNetConfigResources {
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
"prophetnet-large-uncased/config",
"https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/config.json",
);
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
"prophetnet-large-uncased-cnndm/config",
"https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/config.json",
@ -68,12 +68,12 @@ impl ProphetNetConfigResources {
}
impl ProphetNetVocabResources {
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
"prophetnet-large-uncased/vocab",
"https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer",
);
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
"prophetnet-large-uncased-cnndm/vocab",
"https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/prophetnet.tokenizer",
@ -120,7 +120,7 @@ pub struct ProphetNetConfig {
pub add_cross_attention: Option<bool>,
}
impl Config<ProphetNetConfig> for ProphetNetConfig {}
impl Config for ProphetNetConfig {}
/// # ProphetNet Base model
/// Base architecture for ProphetNet models. Task-specific models will be built from this common base model
@ -190,7 +190,7 @@ impl ProphetNetModel {
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_hidden_states` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) corresponding to pre-calculated encoder hidden states (useful for conditional generation)
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
@ -393,7 +393,7 @@ impl ProphetNetForConditionalGeneration {
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_hidden_states` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) corresponding to pre-calculated encoder hidden states (useful for conditional generation)
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
@ -693,7 +693,7 @@ impl ProphetNetForCausalGeneration {
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token)
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `old_layer_states` - Optional Vector `Option<Vec<(Option<&LayerState>, Option<&LayerState>)>>` of length *n_layer* containing tuples with the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.

View File

@ -47,7 +47,7 @@ pub struct ReformerConfigResources;
pub struct ReformerVocabResources;
impl ReformerModelResources {
/// Shared under Apache 2.0 license by the Trax Authors at https://github.com/google/trax/tree/master/trax/models/reformer. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Trax Authors at <https://github.com/google/trax/tree/master/trax/models/reformer>. Modified with conversion to C-array format.
pub const CRIME_AND_PUNISHMENT: (&'static str, &'static str) = (
"reformer-crime-punishment/model",
"https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/rust_model.ot",
@ -55,7 +55,7 @@ impl ReformerModelResources {
}
impl ReformerConfigResources {
/// Shared under Apache 2.0 license by the Trax Authors at https://github.com/google/trax/tree/master/trax/models/reformer. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Trax Authors at <https://github.com/google/trax/tree/master/trax/models/reformer>. Modified with conversion to C-array format.
pub const CRIME_AND_PUNISHMENT: (&'static str, &'static str) = (
"reformer-crime-punishment/config",
"https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/config.json",
@ -63,7 +63,7 @@ impl ReformerConfigResources {
}
impl ReformerVocabResources {
/// Shared under Apache 2.0 license by the Trax Authors at https://github.com/google/trax/tree/master/trax/models/reformer. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Trax Authors at <https://github.com/google/trax/tree/master/trax/models/reformer>. Modified with conversion to C-array format.
pub const CRIME_AND_PUNISHMENT: (&'static str, &'static str) = (
"reformer-crime-punishment/spiece",
"https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model",
@ -115,7 +115,7 @@ pub struct ReformerConfig {
pub output_hidden_states: Option<bool>,
}
impl Config<ReformerConfig> for ReformerConfig {}
impl Config for ReformerConfig {}
pub struct ReformerLMHead {
decoder: nn::Linear,

View File

@ -33,37 +33,37 @@ pub struct RobertaVocabResources;
pub struct RobertaMergesResources;
impl RobertaModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/model",
"https://huggingface.co/roberta-base/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the Hugging Face Inc. team at https://huggingface.co/distilroberta-base. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
"distilroberta-base/model",
"https://cdn.huggingface.co/distilroberta-base-rust_model.ot",
);
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
pub const ROBERTA_QA: (&'static str, &'static str) = (
"roberta-qa/model",
"https://huggingface.co/deepset/roberta-base-squad2/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
"xlm-roberta-ner-en/model",
"https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
"xlm-roberta-ner-de/model",
"https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
"xlm-roberta-ner-nl/model",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
"xlm-roberta-ner-es/model",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/rust_model.ot",
@ -71,37 +71,37 @@ impl RobertaModelResources {
}
impl RobertaConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/config",
"https://huggingface.co/roberta-base/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the Hugging Face Inc. team at https://huggingface.co/distilroberta-base. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
"distilroberta-base/config",
"https://cdn.huggingface.co/distilroberta-base-config.json",
);
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
pub const ROBERTA_QA: (&'static str, &'static str) = (
"roberta-qa/config",
"https://huggingface.co/deepset/roberta-base-squad2/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
"xlm-roberta-ner-en/config",
"https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
"xlm-roberta-ner-de/config",
"https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
"xlm-roberta-ner-nl/config",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
"xlm-roberta-ner-es/config",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json",
@ -109,37 +109,37 @@ impl RobertaConfigResources {
}
impl RobertaVocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/vocab",
"https://huggingface.co/roberta-base/resolve/main/vocab.json",
);
/// Shared under Apache 2.0 license by the Hugging Face Inc. team at https://huggingface.co/distilroberta-base. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
"distilroberta-base/vocab",
"https://cdn.huggingface.co/distilroberta-base-vocab.json",
);
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
pub const ROBERTA_QA: (&'static str, &'static str) = (
"roberta-qa/vocab",
"https://huggingface.co/deepset/roberta-base-squad2/resolve/main/vocab.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
"xlm-roberta-ner-en/spiece",
"https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
"xlm-roberta-ner-de/spiece",
"https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
"xlm-roberta-ner-nl/spiece",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
"xlm-roberta-ner-es/spiece",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model",
@ -147,17 +147,17 @@ impl RobertaVocabResources {
}
impl RobertaMergesResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/merges",
"https://huggingface.co/roberta-base/resolve/main/merges.txt",
);
/// Shared under Apache 2.0 license by the Hugging Face Inc. team at https://huggingface.co/distilroberta-base. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
"distilroberta-base/merges",
"https://cdn.huggingface.co/distilroberta-base-merges.txt",
);
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
pub const ROBERTA_QA: (&'static str, &'static str) = (
"roberta-qa/merges",
"https://huggingface.co/deepset/roberta-base-squad2/resolve/main/merges.txt",

View File

@ -50,12 +50,12 @@ pub struct T5SourceLanguages;
pub type T5TargetLanguages = T5SourceLanguages;
impl T5ModelResources {
/// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the T5 Authors at <https://github.com/google-research/text-to-text-transfer-transformer>. Modified with conversion to C-array format.
pub const T5_SMALL: (&'static str, &'static str) = (
"t5-small/model",
"https://huggingface.co/t5-small/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the T5 Authors at <https://github.com/google-research/text-to-text-transfer-transformer>. Modified with conversion to C-array format.
pub const T5_BASE: (&'static str, &'static str) = (
"t5-base/model",
"https://huggingface.co/t5-base/resolve/main/rust_model.ot",
@ -63,12 +63,12 @@ impl T5ModelResources {
}
impl T5ConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
pub const T5_SMALL: (&'static str, &'static str) = (
"t5-small/config",
"https://huggingface.co/t5-small/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
pub const T5_BASE: (&'static str, &'static str) = (
"t5-base/config",
"https://huggingface.co/t5-base/resolve/main/config.json",
@ -76,12 +76,12 @@ impl T5ConfigResources {
}
impl T5VocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
pub const T5_SMALL: (&'static str, &'static str) = (
"t5-small/spiece",
"https://huggingface.co/t5-small/resolve/main/spiece.model",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
pub const T5_BASE: (&'static str, &'static str) = (
"t5-base/spiece",
"https://huggingface.co/t5-base/resolve/main/spiece.model",
@ -172,7 +172,7 @@ pub struct TranslationEnToRo {
prefix: String,
}
impl Config<T5Config> for T5Config {}
impl Config for T5Config {}
/// # T5 Base model
/// Base architecture for T5 model. Usually complemented with a task-specific head, such as a language model head.
@ -464,7 +464,7 @@ impl T5ForConditionalGeneration {
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
/// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
/// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
@ -572,7 +572,7 @@ impl LMHeadModel for T5ForConditionalGeneration {
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1
/// * `input_embeds` - Unused for T5
/// * `token_type_ids` - Unused for T5

View File

@ -43,7 +43,7 @@ pub struct XLNetConfigResources;
pub struct XLNetVocabResources;
impl XLNetModelResources {
/// Shared under Apache 2.0 license by the XLNet Authors at https://github.com/zihangdai/xlnet. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the XLNet Authors at <https://github.com/zihangdai/xlnet>. Modified with conversion to C-array format.
pub const XLNET_BASE_CASED: (&'static str, &'static str) = (
"xlnet-base-cased/model",
"https://huggingface.co/xlnet-base-cased/resolve/main/rust_model.ot",
@ -51,7 +51,7 @@ impl XLNetModelResources {
}
impl XLNetConfigResources {
/// Shared under Apache 2.0 license by the XLNet Authors at https://github.com/zihangdai/xlnet. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the XLNet Authors at <https://github.com/zihangdai/xlnet>. Modified with conversion to C-array format.
pub const XLNET_BASE_CASED: (&'static str, &'static str) = (
"xlnet-base-cased/config",
"https://huggingface.co/xlnet-base-cased/resolve/main/config.json",
@ -59,7 +59,7 @@ impl XLNetConfigResources {
}
impl XLNetVocabResources {
/// Shared under Apache 2.0 license by the XLNet Authors at https://github.com/zihangdai/xlnet. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the XLNet Authors at <https://github.com/zihangdai/xlnet>. Modified with conversion to C-array format.
pub const XLNET_BASE_CASED: (&'static str, &'static str) = (
"xlnet-base-cased/spiece",
"https://huggingface.co/xlnet-base-cased/resolve/main/spiece.model",
@ -116,7 +116,7 @@ pub struct XLNetConfig {
pub chunk_size_feed_forward: Option<i64>,
}
impl Config<XLNetConfig> for XLNetConfig {}
impl Config for XLNetConfig {}
/// # XLNet Base model
/// Base architecture for XLNet models. Task-specific models will be built from this common base model

117
tests/m2m100.rs Normal file
View File

@ -0,0 +1,117 @@
use rust_bert::m2m_100::{
M2M100Config, M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100Model,
M2M100ModelResources, M2M100VocabResources,
};
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{M2M100Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
/// Runs a single forward pass of the base M2M100 model on one tokenized
/// sentence and checks the output shapes and the first decoder output value
/// against reference numbers.
#[test]
fn m2m100_lm_model() -> anyhow::Result<()> {
    // Resolve a pretrained remote resource to its local path.
    let fetch = |resource: (&'static str, &'static str)| {
        Resource::Remote(RemoteResource::from_pretrained(resource)).get_local_path()
    };
    let config_path = fetch(M2M100ConfigResources::M2M100_418M)?;
    let vocab_path = fetch(M2M100VocabResources::M2M100_418M)?;
    let merges_path = fetch(M2M100MergesResources::M2M100_418M)?;
    let weights_path = fetch(M2M100ModelResources::M2M100_418M)?;

    // Build the tokenizer and the model, then load the weights into the variable store.
    let device = Device::Cpu;
    let mut var_store = nn::VarStore::new(device);
    let vocab_file = vocab_path.to_str().unwrap();
    let merges_file = merges_path.to_str().unwrap();
    let tokenizer = M2M100Tokenizer::from_files(vocab_file, merges_file, false)?;
    let config = M2M100Config::from_file(config_path);
    let m2m100_model = M2M100Model::new(&var_store.root() / "model", &config);
    var_store.load(weights_path)?;

    // Encode the input and right-pad every sequence with 0 up to the longest one.
    let sentences = ["One two three four"];
    let encodings =
        tokenizer.encode_list(&sentences, 128, &TruncationStrategy::LongestFirst, 0);
    let longest = encodings
        .iter()
        .map(|encoding| encoding.token_ids.len())
        .max()
        .unwrap();
    let rows: Vec<Tensor> = encodings
        .iter()
        .map(|encoding| {
            let mut ids = encoding.token_ids.clone();
            ids.resize(longest, 0);
            Tensor::of_slice(&ids)
        })
        .collect();
    let input_tensor = Tensor::stack(rows.as_slice(), 0).to(device);

    // Forward pass with only the input ids provided, dropout disabled.
    let model_output =
        m2m100_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);

    let expected_shape = vec![1, 5, 1024];
    assert_eq!(model_output.decoder_output.size(), expected_shape);
    assert_eq!(
        model_output.encoder_hidden_state.unwrap().size(),
        expected_shape
    );
    let first_value = model_output.decoder_output.double_value(&[0, 0, 0]);
    assert!((first_value - -2.047429323196411).abs() < 1e-4);

    Ok(())
}
/// Generates text with the M2M100 generator using beam search (3 beams, no
/// sampling), forcing the `>>es.<<` token id as the first generated token,
/// and compares the single output against the expected Spanish sentence.
#[test]
fn m2m100_translation() -> anyhow::Result<()> {
    // Wrap a pretrained resource descriptor into a remote resource.
    let remote = |resource: (&'static str, &'static str)| {
        Resource::Remote(RemoteResource::from_pretrained(resource))
    };

    // Generation settings; everything not listed keeps its default value.
    let generate_config = GenerateConfig {
        model_resource: remote(M2M100ModelResources::M2M100_418M),
        config_resource: remote(M2M100ConfigResources::M2M100_418M),
        vocab_resource: remote(M2M100VocabResources::M2M100_418M),
        merges_resource: remote(M2M100MergesResources::M2M100_418M),
        max_length: 56,
        do_sample: false,
        num_beams: 3,
        ..Default::default()
    };
    let generator = M2M100Generator::new(generate_config)?;

    let input_context = ">>en.<< The dog did not wake up.";
    // Id of the target-language token, passed to `generate` below.
    let target_language = generator
        .get_tokenizer()
        .convert_tokens_to_ids([">>es.<<"])[0];

    let translations = generator.generate(
        Some(&[input_context]),
        None,
        None,
        None,
        None,
        target_language,
        None,
        false,
    );

    assert_eq!(translations.len(), 1);
    assert_eq!(translations[0].text, ">>es.<< El perro no se despertó.");

    Ok(())
}