diff --git a/CHANGELOG.md b/CHANGELOG.md index 91877b3..b2c8477 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. The format - (BREAKING) Support for `forced_bos_token_id` argument for generation, allowing users to force a given BOS token for generation (useful for MBart/M2M-class models) - (BREAKING) Support for `output_scores` boolean argument for generation, allowing users to output the log-probability scores of generated sequences. Updated the return type of low-level generate API to `GeneratedTextOutput` and `GeneratedIndicesOutput` containing optional scores along with the generated output. - Addition of the MBart Language model and support for text generation / direct translation between 50 language +- Addition of the M2M100 Language model and support for text generation / direct translation between 100 language ## Changed - Updated GPT2 architecture to re-use embeddings for the output projection layer (resulting in smaller model weights files and memory footprint) diff --git a/Cargo.toml b/Cargo.toml index 510e2cd..243f8f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,7 @@ all-tests = [] features = ["doc-only"] [dependencies] -rust_tokenizers = "~6.2.3" +rust_tokenizers = { version = "~6.2.4", path = "E:/Coding/backup-rust/rust-tokenizers/main" } tch = "~0.5.0" serde_json = "1.0.64" serde = { version = "1.0.126", features = ["derive"] } diff --git a/README.md b/README.md index 82512e2..e6bc200 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,7 @@ GPT-Neo| | | |✅ | | | | BART|✅| | |✅ |✅| | | Marian| | | | | |✅| | MBart|✅| | |✅ | | | | +M2M100| | | |✅ | | | | Electra | |✅| | | | |✅| ALBERT |✅|✅|✅| | | |✅| T5 | | | |✅ |✅|✅| | @@ -62,7 +63,7 @@ Pegasus| | | | |✅| | | ## Getting started This library relies on the [tch](https://github.com/LaurentMazare/tch-rs) crate for bindings to the C++ Libtorch API. -The libtorch library is required can be downloaded either automatically or manually. The following provides a reference on how to set-up yoru environment +The libtorch library is required can be downloaded either automatically or manually. The following provides a reference on how to set-up your environment to use these bindings, please refer to the [tch](https://github.com/LaurentMazare/tch-rs) for detailed information or support. Furthermore, this library relies on a cache folder for downloading pre-trained models. diff --git a/examples/translation_m2m100.rs b/examples/translation_m2m100.rs new file mode 100644 index 0000000..5edeb23 --- /dev/null +++ b/examples/translation_m2m100.rs @@ -0,0 +1,64 @@ +// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +// Copyright 2019 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +extern crate anyhow; + +use rust_bert::m2m_100::{ + M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100ModelResources, + M2M100VocabResources, +}; +use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; +use rust_bert::resources::{RemoteResource, Resource}; + +fn main() -> anyhow::Result<()> { + let generate_config = GenerateConfig { + max_length: 142, + min_length: 0, + model_resource: Resource::Remote(RemoteResource::from_pretrained( + M2M100ModelResources::M2M100_418M, + )), + config_resource: Resource::Remote(RemoteResource::from_pretrained( + M2M100ConfigResources::M2M100_418M, + )), + vocab_resource: Resource::Remote(RemoteResource::from_pretrained( + M2M100VocabResources::M2M100_418M, + )), + merges_resource: Resource::Remote(RemoteResource::from_pretrained( + M2M100MergesResources::M2M100_418M, + )), + do_sample: false, + early_stopping: true, + num_beams: 3, + ..Default::default() + }; + + let model = M2M100Generator::new(generate_config)?; + + let input_context_1 = ">>en.<< The dog did not wake up."; + let target_language = model.get_tokenizer().convert_tokens_to_ids([">>es.<<"])[0]; + + let output = model.generate( + Some(&[input_context_1]), + None, + None, + None, + None, + target_language, + None, + false, + ); + + for sentence in output { + println!("{:?}", sentence); + } + Ok(()) +} diff --git a/src/albert/albert_model.rs b/src/albert/albert_model.rs index 606a29c..8317295 100644 --- a/src/albert/albert_model.rs +++ b/src/albert/albert_model.rs @@ -31,7 +31,7 @@ pub struct AlbertConfigResources; pub struct AlbertVocabResources; impl AlbertModelResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const ALBERT_BASE_V2: (&'static str, &'static str) = ( "albert-base-v2/model", "https://huggingface.co/albert-base-v2/resolve/main/rust_model.ot", @@ -39,7 +39,7 @@ impl AlbertModelResources { } impl AlbertConfigResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const ALBERT_BASE_V2: (&'static str, &'static str) = ( "albert-base-v2/config", "https://huggingface.co/albert-base-v2/resolve/main/config.json", @@ -47,7 +47,7 @@ impl AlbertConfigResources { } impl AlbertVocabResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const ALBERT_BASE_V2: (&'static str, &'static str) = ( "albert-base-v2/spiece", "https://huggingface.co/albert-base-v2/resolve/main/spiece.model", @@ -88,7 +88,7 @@ pub struct AlbertConfig { pub label2id: Option>, } -impl Config for AlbertConfig {} +impl Config for AlbertConfig {} /// # ALBERT Base model /// Base architecture for ALBERT models. Task-specific models will be built from this common base model diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs index db2984b..9ceb66f 100644 --- a/src/bart/bart_model.rs +++ b/src/bart/bart_model.rs @@ -50,32 +50,32 @@ pub struct BartVocabResources; pub struct BartMergesResources; impl BartModelResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART: (&'static str, &'static str) = ( "bart/model", "https://huggingface.co/facebook/bart-large/resolve/main/rust_model.ot", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_CNN: (&'static str, &'static str) = ( "bart-cnn/model", "https://huggingface.co/facebook/bart-large-cnn/resolve/main/rust_model.ot", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_XSUM: (&'static str, &'static str) = ( "bart-xsum/model", "https://huggingface.co/facebook/bart-large-xsum/resolve/main/rust_model.ot", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_MNLI: (&'static str, &'static str) = ( "bart-large-mnli/model", "https://huggingface.co/facebook/bart-large-mnli/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-6-6. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face team at . Modified with conversion to C-array format. pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = ( "distilbart-cnn-6-6/model", "https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-12-6. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face team at . Modified with conversion to C-array format. pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = ( "distilbart-cnn-12-6/model", "https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/rust_model.ot", @@ -83,32 +83,32 @@ impl BartModelResources { } impl BartConfigResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART: (&'static str, &'static str) = ( "bart/config", "https://huggingface.co/facebook/bart-large/resolve/main/config.json", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_CNN: (&'static str, &'static str) = ( "bart-cnn/config", "https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_XSUM: (&'static str, &'static str) = ( "bart-xsum/config", "https://huggingface.co/facebook/bart-large-xsum/resolve/main/config.json", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_MNLI: (&'static str, &'static str) = ( "bart-large-mnli/config", "https://huggingface.co/facebook/bart-large-mnli/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-6-6. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face team at . Modified with conversion to C-array format. pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = ( "distilbart-cnn-6-6/config", "https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/config.json", ); - /// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-12-6. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face team at . Modified with conversion to C-array format. pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = ( "distilbart-cnn-12-6/config", "https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/config.json", @@ -116,32 +116,32 @@ impl BartConfigResources { } impl BartVocabResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART: (&'static str, &'static str) = ( "bart/vocab", "https://huggingface.co/roberta-large/resolve/main/vocab.json", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_CNN: (&'static str, &'static str) = ( "bart-cnn/vocab", "https://huggingface.co/roberta-large/resolve/main/vocab.json", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_XSUM: (&'static str, &'static str) = ( "bart-xsum/vocab", "https://huggingface.co/roberta-large/resolve/main/vocab.json", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_MNLI: (&'static str, &'static str) = ( "bart-large-mnli/vocab", "https://huggingface.co/roberta-large/resolve/main/vocab.json", ); - /// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-6-6. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face team at . Modified with conversion to C-array format. pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = ( "distilbart-cnn-6-6/vocab", "https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/vocab.json", ); - /// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-12-6. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face team at . Modified with conversion to C-array format. pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = ( "distilbart-cnn-12-6/vocab", "https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/vocab.json", @@ -149,32 +149,32 @@ impl BartVocabResources { } impl BartMergesResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART: (&'static str, &'static str) = ( "bart/merges", "https://huggingface.co/roberta-large/resolve/main/merges.txt", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_CNN: (&'static str, &'static str) = ( "bart-cnn/merges", "https://huggingface.co/roberta-large/resolve/main/merges.txt", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_XSUM: (&'static str, &'static str) = ( "bart-xsum/merges", "https://huggingface.co/roberta-large/resolve/main/merges.txt", ); - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const BART_MNLI: (&'static str, &'static str) = ( "bart-large-mnli/merges", "https://huggingface.co/roberta-large/resolve/main/merges.txt", ); - /// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-6-6. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face team at . Modified with conversion to C-array format. pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = ( "distilbart-cnn-6-6/merges", "https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/merges.txt", ); - /// Shared under Apache 2.0 license by the Hugging Face team at https://huggingface.co/sshleifer/distilbart-cnn-12-6. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face team at . Modified with conversion to C-array format. pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = ( "distilbart-cnn-12-6/merges", "https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/merges.txt", @@ -222,7 +222,7 @@ pub struct BartConfig { pub vocab_size: i64, } -impl Config for BartConfig {} +impl Config for BartConfig {} pub(crate) fn _make_causal_mask( input_ids_shape: &[i64], @@ -391,7 +391,7 @@ impl BartModel { /// /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. @@ -554,7 +554,7 @@ impl BartForConditionalGeneration { /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. /// @@ -748,7 +748,7 @@ impl BartForSequenceClassification { /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. /// @@ -847,7 +847,7 @@ impl LMHeadModel for BartForConditionalGeneration { /// # Arguments /// /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. + /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 /// * `input_embeds` - Unused for BART /// * `token_type_ids` - Unused for BART diff --git a/src/bart/mod.rs b/src/bart/mod.rs index 0d0397a..f779493 100644 --- a/src/bart/mod.rs +++ b/src/bart/mod.rs @@ -70,7 +70,7 @@ pub use bart_model::{ }; pub(crate) use attention::BartAttention; -pub(crate) use bart_model::{_expand_mask, _prepare_decoder_attention_mask}; +pub(crate) use bart_model::{_expand_mask, _make_causal_mask, _prepare_decoder_attention_mask}; pub(crate) use decoder::BartDecoderOutput; pub(crate) use embeddings::LearnedPositionalEmbedding; pub(crate) use encoder::BartEncoderOutput; diff --git a/src/bert/bert_model.rs b/src/bert/bert_model.rs index 0a52dd2..1faf7cd 100644 --- a/src/bert/bert_model.rs +++ b/src/bert/bert_model.rs @@ -37,17 +37,17 @@ pub struct BertConfigResources; pub struct BertVocabResources; impl BertModelResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const BERT: (&'static str, &'static str) = ( "bert/model", "https://huggingface.co/bert-base-uncased/resolve/main/rust_model.ot", ); - /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format. + /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at . Modified with conversion to C-array format. pub const BERT_NER: (&'static str, &'static str) = ( "bert-ner/model", "https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by Hugging Face Inc at . Modified with conversion to C-array format. pub const BERT_QA: (&'static str, &'static str) = ( "bert-qa/model", "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/rust_model.ot", @@ -55,17 +55,17 @@ impl BertModelResources { } impl BertConfigResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const BERT: (&'static str, &'static str) = ( "bert/config", "https://huggingface.co/bert-base-uncased/resolve/main/config.json", ); - /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format. + /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at . Modified with conversion to C-array format. pub const BERT_NER: (&'static str, &'static str) = ( "bert-ner/config", "https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by Hugging Face Inc at . Modified with conversion to C-array format. pub const BERT_QA: (&'static str, &'static str) = ( "bert-qa/config", "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json", @@ -73,17 +73,17 @@ impl BertConfigResources { } impl BertVocabResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const BERT: (&'static str, &'static str) = ( "bert/vocab", "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", ); - /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format. + /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at . Modified with conversion to C-array format. pub const BERT_NER: (&'static str, &'static str) = ( "bert-ner/vocab", "https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/vocab.txt", ); - /// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by Hugging Face Inc at . Modified with conversion to C-array format. pub const BERT_QA: (&'static str, &'static str) = ( "bert-qa/vocab", "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", @@ -112,7 +112,7 @@ pub struct BertConfig { pub label2id: Option>, } -impl Config for BertConfig {} +impl Config for BertConfig {} /// # BERT Base model /// Base architecture for BERT models. Task-specific models will be built from this common base model diff --git a/src/common/config.rs b/src/common/config.rs index 3f2f36c..4051597 100644 --- a/src/common/config.rs +++ b/src/common/config.rs @@ -15,9 +15,9 @@ use std::io::BufReader; use std::path::Path; /// # Utility to deserialize JSON config files -pub trait Config +pub trait Config where - for<'de> T: Deserialize<'de>, + for<'de> Self: Deserialize<'de>, { /// Loads a `Config` object from a JSON file. The format is expected to be aligned with the [Transformers library](https://github.com/huggingface/transformers) configuration files for each model. /// The parsing will fail if non-optional keys expected by the model are missing. @@ -36,10 +36,10 @@ where /// let config_path = Path::new("path/to/config.json"); /// let config = Gpt2Config::from_file(config_path); /// ``` - fn from_file>(path: P) -> T { + fn from_file>(path: P) -> Self { let f = File::open(path).expect("Could not open configuration file."); let br = BufReader::new(f); - let config: T = serde_json::from_reader(br).expect("could not parse configuration"); + let config: Self = serde_json::from_reader(br).expect("could not parse configuration"); config } } diff --git a/src/common/resources.rs b/src/common/resources.rs index 65b8228..aea64d7 100644 --- a/src/common/resources.rs +++ b/src/common/resources.rs @@ -115,7 +115,7 @@ impl RemoteResource { } /// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to - /// ~/.cache/.rusbert/model_name. Note that this does not download the resource (only declares + /// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares /// the remote and local locations) /// /// # Arguments diff --git a/src/distilbert/distilbert_model.rs b/src/distilbert/distilbert_model.rs index 8d0ebd9..2bb7833 100644 --- a/src/distilbert/distilbert_model.rs +++ b/src/distilbert/distilbert_model.rs @@ -31,17 +31,17 @@ pub struct DistilBertConfigResources; pub struct DistilBertVocabResources; impl DistilBertModelResources { - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ( "distilbert-sst2/model", "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_BERT: (&'static str, &'static str) = ( "distilbert/model", "https://huggingface.co/distilbert-base-uncased/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ( "distilbert-qa/model", "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/rust_model.ot", @@ -49,17 +49,17 @@ impl DistilBertModelResources { } impl DistilBertConfigResources { - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ( "distilbert-sst2/config", "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_BERT: (&'static str, &'static str) = ( "distilbert/config", "https://huggingface.co/distilbert-base-uncased/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ( "distilbert-qa/config", "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json", @@ -67,17 +67,17 @@ impl DistilBertConfigResources { } impl DistilBertVocabResources { - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ( "distilbert-sst2/vocab", "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_BERT: (&'static str, &'static str) = ( "distilbert/vocab", "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ( "distilbert-qa/vocab", "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt", @@ -112,7 +112,7 @@ pub struct DistilBertConfig { pub vocab_size: i64, } -impl Config for DistilBertConfig {} +impl Config for DistilBertConfig {} /// # DistilBERT Base model /// Base architecture for DistilBERT models. Task-specific models will be built from this common base model diff --git a/src/electra/electra_model.rs b/src/electra/electra_model.rs index 633fe79..723acb3 100644 --- a/src/electra/electra_model.rs +++ b/src/electra/electra_model.rs @@ -32,12 +32,12 @@ pub struct ElectraConfigResources; pub struct ElectraVocabResources; impl ElectraModelResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const BASE_GENERATOR: (&'static str, &'static str) = ( "electra-base-generator/model", "https://huggingface.co/google/electra-base-generator/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ( "electra-base-discriminator/model", "https://huggingface.co/google/electra-base-discriminator/resolve/main/rust_model.ot", @@ -45,12 +45,12 @@ impl ElectraModelResources { } impl ElectraConfigResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const BASE_GENERATOR: (&'static str, &'static str) = ( "electra-base-generator/config", "https://huggingface.co/google/electra-base-generator/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ( "electra-base-discriminator/config", "https://huggingface.co/google/electra-base-discriminator/resolve/main/config.json", @@ -58,12 +58,12 @@ impl ElectraConfigResources { } impl ElectraVocabResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const BASE_GENERATOR: (&'static str, &'static str) = ( "electra-base-generator/vocab", "https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt", ); - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ( "electra-base-discriminator/vocab", "https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt", @@ -95,7 +95,7 @@ pub struct ElectraConfig { pub label2id: Option>, } -impl Config for ElectraConfig {} +impl Config for ElectraConfig {} /// # Electra Base model /// Base architecture for Electra models. diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs index 7f884e8..77e5046 100644 --- a/src/gpt2/gpt2_model.rs +++ b/src/gpt2/gpt2_model.rs @@ -44,32 +44,32 @@ pub struct Gpt2VocabResources; pub struct Gpt2MergesResources; impl Gpt2ModelResources { - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2: (&'static str, &'static str) = ( "gpt2/model", "https://huggingface.co/gpt2/resolve/main/rust_model.ot", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_MEDIUM: (&'static str, &'static str) = ( "gpt2-medium/model", "https://huggingface.co/gpt2-medium/resolve/main/rust_model.ot", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_LARGE: (&'static str, &'static str) = ( "gpt2-large/model", "https://huggingface.co/gpt2-large/resolve/main/rust_model.ot", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_XL: (&'static str, &'static str) = ( "gpt2-xl/model", "https://huggingface.co/gpt2-xl/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_GPT2: (&'static str, &'static str) = ( "distilgpt2/model", "https://huggingface.co/distilgpt2/resolve/main/rust_model.ot", ); - /// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format. + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = ( "dialogpt-medium/model", "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/rust_model.ot", @@ -77,32 +77,32 @@ impl Gpt2ModelResources { } impl Gpt2ConfigResources { - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2: (&'static str, &'static str) = ( "gpt2/config", "https://huggingface.co/gpt2/resolve/main/config.json", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_MEDIUM: (&'static str, &'static str) = ( "gpt2-medium/config", "https://huggingface.co/gpt2-medium/resolve/main/config.json", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_LARGE: (&'static str, &'static str) = ( "gpt2-large/config", "https://huggingface.co/gpt2-large/resolve/main/config.json", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_XL: (&'static str, &'static str) = ( "gpt2-xl/config", "https://huggingface.co/gpt2-xl/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_GPT2: (&'static str, &'static str) = ( "distilgpt2/config", "https://huggingface.co/distilgpt2/resolve/main/config.json", ); - /// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format. + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = ( "dialogpt-medium/config", "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/config.json", @@ -110,32 +110,32 @@ impl Gpt2ConfigResources { } impl Gpt2VocabResources { - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2: (&'static str, &'static str) = ( "gpt2/vocab", "https://huggingface.co/gpt2/resolve/main/vocab.json", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_MEDIUM: (&'static str, &'static str) = ( "gpt2-medium/vocab", "https://huggingface.co/gpt2-medium/resolve/main/vocab.json", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_LARGE: (&'static str, &'static str) = ( "gpt2-large/vocab", "https://huggingface.co/gpt2-large/resolve/main/vocab.json", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_XL: (&'static str, &'static str) = ( "gpt2-xl/vocab", "https://huggingface.co/gpt2-xl/resolve/main/vocab.json", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_GPT2: (&'static str, &'static str) = ( "distilgpt2/vocab", "https://huggingface.co/distilgpt2/resolve/main/vocab.json", ); - /// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format. + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = ( "dialogpt-medium/vocab", "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/vocab.json", @@ -143,32 +143,32 @@ impl Gpt2VocabResources { } impl Gpt2MergesResources { - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2: (&'static str, &'static str) = ( "gpt2/merges", "https://huggingface.co/gpt2/resolve/main/merges.txt", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_MEDIUM: (&'static str, &'static str) = ( "gpt2-medium/merges", "https://huggingface.co/gpt2-medium/resolve/main/merges.txt", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_LARGE: (&'static str, &'static str) = ( "gpt2-large/merges", "https://huggingface.co/gpt2-large/resolve/main/merges.txt", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT2_XL: (&'static str, &'static str) = ( "gpt2-xl/merges", "https://huggingface.co/gpt2-xl/resolve/main/merges.txt", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const DISTIL_GPT2: (&'static str, &'static str) = ( "distilgpt2/merges", "https://huggingface.co/distilgpt2/resolve/main/merges.txt", ); - /// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format. + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = ( "dialogpt-medium/merges", "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/merges.txt", @@ -199,7 +199,7 @@ pub struct Gpt2Config { pub vocab_size: i64, } -impl Config for Gpt2Config {} +impl Config for Gpt2Config {} /// # GPT2 Base model /// Base architecture for GPT2 model. Usually complemented with a task-specific head, such as a language model head. diff --git a/src/gpt_neo/gpt_neo_model.rs b/src/gpt_neo/gpt_neo_model.rs index 56de616..e26880b 100644 --- a/src/gpt_neo/gpt_neo_model.rs +++ b/src/gpt_neo/gpt_neo_model.rs @@ -41,17 +41,17 @@ pub struct GptNeoVocabResources; pub struct GptNeoMergesResources; impl GptNeoModelResources { - /// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. pub const GPT_NEO_125M: (&'static str, &'static str) = ( "gpt-neo-125M/model", "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. pub const GPT_NEO_1_3B: (&'static str, &'static str) = ( "gpt-neo-1_3B/model", "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. pub const GPT_NEO_2_7B: (&'static str, &'static str) = ( "gpt-neo-2_7B/model", "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/rust_model.ot", @@ -59,17 +59,17 @@ impl GptNeoModelResources { } impl GptNeoConfigResources { - /// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. pub const GPT_NEO_125M: (&'static str, &'static str) = ( "gpt-neo-125M/config", "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. pub const GPT_NEO_1_3B: (&'static str, &'static str) = ( "gpt-neo-1_3B/config", "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the EleutherAI contributors at https://www.eleuther.ai. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. pub const GPT_NEO_2_7B: (&'static str, &'static str) = ( "gpt-neo-2_7B/config", "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/config.json", @@ -77,17 +77,17 @@ impl GptNeoConfigResources { } impl GptNeoVocabResources { - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT_NEO_125M: (&'static str, &'static str) = ( "gpt-neo-125M/vocab", "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/vocab.json", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT_NEO_1_3B: (&'static str, &'static str) = ( "gpt-neo-1_3B/vocab", "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/vocab.json", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Modified MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT_NEO_2_7B: (&'static str, &'static str) = ( "gpt-neo-2_7B/vocab", "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/vocab.json", @@ -95,17 +95,17 @@ impl GptNeoVocabResources { } impl GptNeoMergesResources { - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. pub const GPT_NEO_125M: (&'static str, &'static str) = ( "gpt-neo-125M/merges", "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/merges.txt", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. pub const GPT_NEO_1_3B: (&'static str, &'static str) = ( "gpt-neo-1_3B/merges", "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/merges.txt", ); - /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. pub const GPT_NEO_2_7B: (&'static str, &'static str) = ( "gpt-neo-2_7B/merges", "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/merges.txt", @@ -146,7 +146,7 @@ pub struct GptNeoConfig { pub resid_dropout: f64, } -impl Config for GptNeoConfig {} +impl Config for GptNeoConfig {} /// # GPT-Neo Base model /// Base architecture for GPT-Neo models. Task-specific models will be built from this common base model @@ -263,7 +263,7 @@ impl GptNeoModel { /// - `hidden_states` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state /// - `next_cache` - `Option>>` of length *n_layer* containing the past content for the the attention layers /// - `all_hidden_states` - `Option>` of length *n_layer + 1* with shape (*batch size*, *sequence_length*, *hidden_size*) - /// - `all_attentions` - `Option>` of length *n_layer* containign the attention weights for each layer + /// - `all_attentions` - `Option>` of length *n_layer* containing the attention weights for each layer /// /// # Example /// @@ -504,7 +504,7 @@ impl GptNeoForCausalLM { /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position /// - `next_cache` - `Option>>` of length *n_layer* containing the past content for the the attention layers /// - `all_hidden_states` - `Option>` of length *n_layer + 1* with shape (*batch size*, *sequence_length*, *hidden_size*) - /// - `all_attentions` - `Option>` of length *n_layer* containign the attention weights for each layer + /// - `all_attentions` - `Option>` of length *n_layer* containing the attention weights for each layer /// /// # Example /// diff --git a/src/lib.rs b/src/lib.rs index 0d4303d..4cfd12b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,6 +59,7 @@ //! BART|✅| | |✅ |✅| | | //! Marian| | | | | |✅| | //! MBart|✅| | |✅ | | | | +//! M2M100| | | |✅ | | | | //! Electra | |✅| | | | |✅| //! ALBERT |✅|✅|✅| | | |✅| //! T5 | | | |✅ |✅|✅| | @@ -80,7 +81,7 @@ //! //! ### Manual installation (recommended) //! -//! 1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.8.1`: if this version is no longer available on the "get started" page, +//! 1. Download `libtorch` from . This package requires `v1.8.1`: if this version is no longer available on the "get started" page, //! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.8.1%2Bcu111.zip` for a Linux version with CUDA11. //! 2. Extract the library to a location of your choice //! 3. Set the following environment variables @@ -579,6 +580,7 @@ pub mod electra; pub mod gpt2; pub mod gpt_neo; pub mod longformer; +pub mod m2m_100; pub mod marian; pub mod mbart; pub mod mobilebert; diff --git a/src/longformer/longformer_model.rs b/src/longformer/longformer_model.rs index 0fe4b1f..9307a54 100644 --- a/src/longformer/longformer_model.rs +++ b/src/longformer/longformer_model.rs @@ -34,12 +34,12 @@ pub struct LongformerVocabResources; pub struct LongformerMergesResources; impl LongformerModelResources { - /// Shared under Apache 2.0 license by the AllenAI team at https://github.com/allenai/longformer. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the AllenAI team at . Modified with conversion to C-array format. pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = ( "longformer-base-4096/model", "https://huggingface.co/allenai/longformer-base-4096/resolve/main/rust_model.ot", ); - /// Shared under MIT license at https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1. Modified with conversion to C-array format. + /// Shared under MIT license at . Modified with conversion to C-array format. pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = ( "longformer-base-4096/model", "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/rust_model.ot", @@ -47,12 +47,12 @@ impl LongformerModelResources { } impl LongformerConfigResources { - /// Shared under Apache 2.0 license by the AllenAI team at https://github.com/allenai/longformer. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the AllenAI team at . Modified with conversion to C-array format. pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = ( "longformer-base-4096/config", "https://huggingface.co/allenai/longformer-base-4096/resolve/main/config.json", ); - /// Shared under MIT license at https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1. Modified with conversion to C-array format. + /// Shared under MIT license at . Modified with conversion to C-array format. pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = ( "longformer-base-4096/config", "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/config.json", @@ -60,12 +60,12 @@ impl LongformerConfigResources { } impl LongformerVocabResources { - /// Shared under Apache 2.0 license by the AllenAI team at https://github.com/allenai/longformer. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the AllenAI team at . Modified with conversion to C-array format. pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = ( "longformer-base-4096/vocab", "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json", ); - /// Shared under MIT license at https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1. Modified with conversion to C-array format. + /// Shared under MIT license at . Modified with conversion to C-array format. pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = ( "longformer-base-4096/vocab", "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/vocab.json", @@ -73,12 +73,12 @@ impl LongformerVocabResources { } impl LongformerMergesResources { - /// Shared under Apache 2.0 license by the AllenAI team at https://github.com/allenai/longformer. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the AllenAI team at . Modified with conversion to C-array format. pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = ( "longformer-base-4096/merges", "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt", ); - /// Shared under MIT license at https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1. Modified with conversion to C-array format. + /// Shared under MIT license at . Modified with conversion to C-array format. pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = ( "longformer-base-4096/merges", "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/merges.txt", @@ -120,7 +120,7 @@ pub struct LongformerConfig { pub label2id: Option>, } -impl Config for LongformerConfig {} +impl Config for LongformerConfig {} fn get_question_end_index(input_ids: &Tensor, sep_token_id: i64) -> Tensor { input_ids diff --git a/src/m2m_100/attention.rs b/src/m2m_100/attention.rs new file mode 100644 index 0000000..5449741 --- /dev/null +++ b/src/m2m_100/attention.rs @@ -0,0 +1,15 @@ +// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +// Copyright 2020 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::bart::LayerState as BartLayerState; + +pub type LayerState = BartLayerState; diff --git a/src/m2m_100/decoder.rs b/src/m2m_100/decoder.rs new file mode 100644 index 0000000..49b14a4 --- /dev/null +++ b/src/m2m_100/decoder.rs @@ -0,0 +1,195 @@ +// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +// Copyright 2020 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::bart::{BartDecoderOutput, _expand_mask, _make_causal_mask}; +use crate::common::dropout::Dropout; +use crate::m2m_100::embeddings::SinusoidalPositionalEmbedding; +use crate::m2m_100::{LayerState, M2M100Config}; +use crate::mbart::MBartDecoderLayer; +use std::borrow::{Borrow, BorrowMut}; +use tch::{nn, Tensor}; + +pub type M2M100DecoderLayer = MBartDecoderLayer; + +pub struct M2M100Decoder { + dropout: Dropout, + layer_norm: nn::LayerNorm, + layers: Vec, + embed_positions: SinusoidalPositionalEmbedding, + output_attentions: bool, + output_hidden_states: bool, + output_past: bool, + scale_embedding: f64, +} + +impl M2M100Decoder { + pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Decoder + where + P: Borrow>, + { + let p = p.borrow(); + let output_past = config.output_past.unwrap_or(true); + let output_attentions = config.output_attentions.unwrap_or(false); + let output_hidden_states = config.output_hidden_states.unwrap_or(false); + + let scale_embedding = if let Some(scale_embeddings) = config.scale_embedding { + if scale_embeddings { + (config.d_model as f64).sqrt() + } else { + 1.0 + } + } else { + 1.0 + }; + + let dropout = Dropout::new(config.dropout); + + let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default()); + + let embed_positions = SinusoidalPositionalEmbedding::new( + p / "embed_positions", + config.max_position_embeddings, + config.d_model, + config.pad_token_id.unwrap_or(1), + ); + + let mut layers: Vec = vec![]; + let p_layers = p / "layers"; + for layer_index in 0..config.decoder_layers { + layers.push(M2M100DecoderLayer::new(&p_layers / layer_index, config)); + } + + M2M100Decoder { + dropout, + layer_norm, + layers, + embed_positions, + output_attentions, + output_hidden_states, + output_past, + scale_embedding, + } + } + + pub fn forward_t( + &self, + input_ids: &Tensor, + encoder_hidden_states: &Tensor, + encoder_attention_mask: Option<&Tensor>, + decoder_attention_mask: Option<&Tensor>, + embeddings: &nn::Embedding, + old_layer_states: Option, Option)>>, + train: bool, + ) -> M2M100DecoderOutput { + let past_key_values_length = if let Some(old_layer_states_values) = &old_layer_states { + if let Some(old_value_state) = &old_layer_states_values[0].0 { + old_value_state.prev_key.size()[2] + } else { + 0 + } + } else { + 0 + }; + let input_shape = input_ids.size(); + let sequence_length = input_shape[1]; + + let positions = self + .embed_positions + .forward(input_ids, past_key_values_length); + + let x: Tensor = input_ids.apply(embeddings) * self.scale_embedding + positions; + + let causal_mask = if sequence_length > 1 { + Some(_make_causal_mask( + input_ids.size().as_slice(), + x.kind(), + x.device(), + past_key_values_length, + )) + } else { + None + }; + + let decoder_attention_mask = decoder_attention_mask.map(|attention_mask| { + if let Some(causal_mask) = causal_mask { + causal_mask + _expand_mask(&attention_mask, Some(sequence_length)) + } else { + _expand_mask(&attention_mask, Some(sequence_length)) + } + }); + + let encoder_attention_mask = encoder_attention_mask + .map(|mask| _expand_mask(mask, Some(*input_ids.size().last().unwrap()))); + + let mut hidden_state = x.apply_t(&self.dropout, train); + + let mut all_hidden_states: Option> = if self.output_hidden_states { + Some(Vec::with_capacity(self.layers.len())) + } else { + None + }; + let mut all_attentions: Option> = if self.output_attentions { + Some(Vec::with_capacity(self.layers.len())) + } else { + None + }; + let mut next_decoder_cache: Option, Option)>> = + if self.output_past { + if old_layer_states.is_some() { + old_layer_states + } else { + Some(vec![(None, None); self.layers.len()]) + } + } else { + None + }; + + let mut attention_weights: Option; + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let layer_state = match &next_decoder_cache { + Some(values) => values[layer_idx].to_owned(), + None => (None, None), + }; + let temp = layer.forward_t( + &hidden_state, + &encoder_hidden_states, + encoder_attention_mask.as_ref(), + decoder_attention_mask.as_ref(), + layer_state, + train, + ); + hidden_state = temp.0; + attention_weights = temp.1; + if let Some(hidden_states) = all_hidden_states.borrow_mut() { + hidden_states.push(hidden_state.as_ref().copy()); + }; + if let Some(attentions) = all_attentions.borrow_mut() { + attentions.push(attention_weights.as_ref().unwrap().copy()); + }; + if let Some(value) = &mut next_decoder_cache { + value[layer_idx] = temp.2 + }; + } + + M2M100DecoderOutput { + hidden_state: hidden_state.apply(&self.layer_norm), + encoder_attention_mask, + next_decoder_cache, + all_hidden_states, + all_attentions, + } + } +} + +/// Container holding a M2M100 decoder output +pub type M2M100DecoderOutput = BartDecoderOutput; diff --git a/src/m2m_100/embeddings.rs b/src/m2m_100/embeddings.rs new file mode 100644 index 0000000..24b3988 --- /dev/null +++ b/src/m2m_100/embeddings.rs @@ -0,0 +1,122 @@ +// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +// Copyright 2020 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Borrow; +use std::ops::Deref; +use std::sync::RwLock; +use tch::nn::embedding; +use tch::{nn, Device, Kind, Tensor}; + +#[derive(Debug)] +pub struct SinusoidalPositionalEmbedding { + embedding: RwLock, + embedding_dim: i64, + padding_idx: i64, + offset: i64, +} + +impl SinusoidalPositionalEmbedding { + pub fn new<'p, P>( + p: P, + num_embeddings: i64, + embedding_dim: i64, + padding_idx: i64, + ) -> SinusoidalPositionalEmbedding + where + P: Borrow>, + { + let device = p.borrow().device(); + let mut local_varstore = nn::VarStore::new(device); + let offset = 2; + + let embedding = RwLock::new(embedding( + local_varstore.root(), + num_embeddings + offset, + embedding_dim, + Default::default(), + )); + + embedding.write().unwrap().ws = SinusoidalPositionalEmbedding::build_positional_embeddings( + num_embeddings + offset, + embedding_dim, + padding_idx, + device, + ); + + local_varstore.freeze(); + SinusoidalPositionalEmbedding { + embedding, + embedding_dim, + padding_idx, + offset, + } + } + + fn build_positional_embeddings( + num_embeddings: i64, + embedding_dim: i64, + padding_idx: i64, + device: Device, + ) -> Tensor { + let half_dim = embedding_dim / 2; + + let emb = -(10000f64.ln() as f64) / ((half_dim - 1) as f64); + let emb = (Tensor::arange(half_dim, (Kind::Float, device)) * emb).exp(); + let emb = + Tensor::arange(num_embeddings, (Kind::Float, device)).unsqueeze(1) * emb.unsqueeze(0); + let mut sinusoidal_embedding = + Tensor::cat(&[&emb.sin(), &emb.cos()], 1).view([num_embeddings, -1]); + + if embedding_dim % 2 == 1 { + sinusoidal_embedding = Tensor::cat( + &[ + sinusoidal_embedding, + Tensor::zeros(&[num_embeddings, 1], (Kind::Float, device)), + ], + 1, + ); + } + let _ = sinusoidal_embedding.select(0, padding_idx).fill_(0); + + let _ = sinusoidal_embedding.requires_grad_(false); + sinusoidal_embedding + } + + fn create_position_ids_from_input_ids( + &self, + input_ids: &Tensor, + past_key_values_length: i64, + ) -> Tensor { + let mask = input_ids.ne(self.padding_idx).to_kind(Kind::Int64); + let incremental_indices = (mask.cumsum(1, Kind::Int64) + past_key_values_length) * mask; + incremental_indices + self.padding_idx + } + + pub fn forward(&self, input_ids: &Tensor, past_key_values_length: i64) -> Tensor { + let position_ids = + self.create_position_ids_from_input_ids(input_ids, past_key_values_length); + let input_size = input_ids.size(); + let seq_length = input_size[1]; + + let max_pos = self.padding_idx + 1 + seq_length; + if max_pos > self.embedding.read().unwrap().ws.size()[0] { + self.embedding.write().unwrap().ws = + SinusoidalPositionalEmbedding::build_positional_embeddings( + max_pos + self.offset, + self.embedding_dim, + self.padding_idx, + input_ids.device(), + ); + } + position_ids.apply(self.embedding.read().unwrap().deref()) + } +} diff --git a/src/m2m_100/encoder.rs b/src/m2m_100/encoder.rs new file mode 100644 index 0000000..1a13de7 --- /dev/null +++ b/src/m2m_100/encoder.rs @@ -0,0 +1,134 @@ +// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +// Copyright 2020 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::bart::{BartEncoderOutput, _expand_mask}; +use crate::common::dropout::Dropout; +use crate::m2m_100::embeddings::SinusoidalPositionalEmbedding; +use crate::m2m_100::M2M100Config; +use crate::mbart::MBartEncoderLayer; +use std::borrow::{Borrow, BorrowMut}; +use tch::{nn, Tensor}; + +pub type M2M100EncoderLayer = MBartEncoderLayer; + +pub struct M2M100Encoder { + dropout: Dropout, + layer_norm: nn::LayerNorm, + layers: Vec, + embed_positions: SinusoidalPositionalEmbedding, + output_attentions: bool, + output_hidden_states: bool, + scale_embedding: f64, +} + +impl M2M100Encoder { + pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Encoder + where + P: Borrow>, + { + let p = p.borrow(); + let output_attentions = config.output_attentions.unwrap_or(false); + let output_hidden_states = config.output_hidden_states.unwrap_or(false); + + let scale_embedding = if let Some(scale_embeddings) = config.scale_embedding { + if scale_embeddings { + (config.d_model as f64).sqrt() + } else { + 1.0 + } + } else { + 1.0 + }; + + let dropout = Dropout::new(config.dropout); + + let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default()); + + let embed_positions = SinusoidalPositionalEmbedding::new( + p / "embed_positions", + config.max_position_embeddings, + config.d_model, + config.pad_token_id.unwrap_or(1), + ); + + let mut layers: Vec = vec![]; + let p_layers = p / "layers"; + for layer_index in 0..config.encoder_layers { + layers.push(M2M100EncoderLayer::new(&p_layers / layer_index, config)); + } + + M2M100Encoder { + dropout, + layer_norm, + layers, + embed_positions, + output_attentions, + output_hidden_states, + scale_embedding, + } + } + + pub fn forward_t( + &self, + input_ids: &Tensor, + attention_mask: Option<&Tensor>, + embeddings: &nn::Embedding, + train: bool, + ) -> M2M100EncoderOutput { + let attention_mask = attention_mask.map(|mask| _expand_mask(mask, None)); + + let x = input_ids.apply(embeddings) * self.scale_embedding; + let x = x + &self.embed_positions.forward(input_ids, 0); + let mut hidden_state = x.apply_t(&self.dropout, train); + + let mut all_hidden_states: Option> = if self.output_hidden_states { + Some(vec![]) + } else { + None + }; + let mut all_attentions: Option> = if self.output_attentions { + Some(vec![]) + } else { + None + }; + + let mut attention_weights: Option; + + for layer in &self.layers { + if let Some(hidden_states) = all_hidden_states.borrow_mut() { + hidden_states.push(hidden_state.as_ref().copy()); + }; + + let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train); + hidden_state = temp.0; + attention_weights = temp.1; + if let Some(attentions) = all_attentions.borrow_mut() { + attentions.push(attention_weights.as_ref().unwrap().copy()); + }; + } + + if let Some(hidden_states) = all_hidden_states.borrow_mut() { + hidden_states.push(hidden_state.as_ref().copy()); + }; + + hidden_state = hidden_state.apply(&self.layer_norm); + + M2M100EncoderOutput { + hidden_state, + all_hidden_states, + all_attentions, + } + } +} + +/// Container holding a M2M100 encoder output +pub type M2M100EncoderOutput = BartEncoderOutput; diff --git a/src/m2m_100/m2m_100_model.rs b/src/m2m_100/m2m_100_model.rs new file mode 100644 index 0000000..9e8137e --- /dev/null +++ b/src/m2m_100/m2m_100_model.rs @@ -0,0 +1,869 @@ +// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +// Copyright 2020 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::gpt2::{ + Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources, +}; +use crate::m2m_100::decoder::M2M100Decoder; +use crate::m2m_100::encoder::M2M100Encoder; +use crate::m2m_100::LayerState; +use crate::mbart::{MBartConfig, MBartModelOutput}; +use crate::pipelines::common::{ModelType, TokenizerOption}; +use crate::pipelines::generation_utils::private_generation_utils::{ + PreparedInput, PrivateLanguageGenerator, +}; +use crate::pipelines::generation_utils::{ + Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, +}; +use crate::resources::{RemoteResource, Resource}; +use crate::{Config, RustBertError}; +use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy}; +use rust_tokenizers::vocab::{M2M100Vocab, Vocab}; +use std::borrow::Borrow; +use tch::nn::{embedding, EmbeddingConfig}; +use tch::{nn, Kind, Tensor}; + +/// # M2M100 Pretrained model weight files +pub struct M2M100ModelResources; + +/// # M2M100 Pretrained model config files +pub struct M2M100ConfigResources; + +/// # M2M100 Pretrained model vocab files +pub struct M2M100VocabResources; + +/// # M2M100 Pretrained model merges files +pub struct M2M100MergesResources; + +impl M2M100ModelResources { + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. + pub const M2M100_418M: (&'static str, &'static str) = ( + "m2m100-418m/model", + "https://huggingface.co/facebook/m2m100_418M/resolve/main/rust_model.ot", + ); + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. + pub const M2M100_1_2B: (&'static str, &'static str) = ( + "m2m100-1_2b/model", + "https://huggingface.co/facebook/m2m100_1.2B/resolve/main/rust_model.ot", + ); +} + +impl M2M100ConfigResources { + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. + pub const M2M100_418M: (&'static str, &'static str) = ( + "m2m100-418m/config", + "https://huggingface.co/facebook/m2m100_418M/resolve/main/config.json", + ); + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. + pub const M2M100_1_2B: (&'static str, &'static str) = ( + "m2m100-1_2b/config", + "https://huggingface.co/facebook/m2m100_1.2B/resolve/main/config.json", + ); +} + +impl M2M100VocabResources { + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. + pub const M2M100_418M: (&'static str, &'static str) = ( + "m2m100-418m/vocab", + "https://huggingface.co/facebook/m2m100_418M/resolve/main/vocab.json", + ); + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. + pub const M2M100_1_2B: (&'static str, &'static str) = ( + "m2m100-1_2b/vocab", + "https://huggingface.co/facebook/m2m100_1.2B/resolve/main/vocab.json", + ); +} + +impl M2M100MergesResources { + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. + pub const M2M100_418M: (&'static str, &'static str) = ( + "m2m100-418m/merges", + "https://huggingface.co/facebook/m2m100_418M/resolve/main/sentencepiece.bpe.model", + ); + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. + pub const M2M100_1_2B: (&'static str, &'static str) = ( + "m2m100-1_2b/merges", + "https://huggingface.co/facebook/m2m100_1.2B/resolve/main/sentencepiece.bpe.model", + ); +} + +pub type M2M100Config = MBartConfig; + +fn _shift_tokens_right( + input_ids: &Tensor, + pad_token_id: i64, + decoder_start_token_id: i64, +) -> Tensor { + let shifted_input_ids = Tensor::zeros( + input_ids.size().as_slice(), + (Kind::Int64, input_ids.device()), + ); + let _ = shifted_input_ids.select(1, 0).fill_(decoder_start_token_id); + let _ = shifted_input_ids + .slice(1, 1, *shifted_input_ids.size().last().unwrap(), 1) + .copy_(&input_ids.slice(1, 0, *input_ids.size().last().unwrap() - 1, 1)); + shifted_input_ids.masked_fill(&shifted_input_ids.eq(-100), pad_token_id) +} + +/// # M2M100 Base model +/// Base architecture for M2M100 model. Usually complemented with a task-specific head, such as a language model head. +/// It is made of the following blocks: +/// - `encoder`: `M2M100Encoder` (transformer) made of a vector of encoding layers +/// - `decoder`: `M2M100Decoder` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention. +/// caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values) +/// - `pad_token_id`: padding token id +pub struct M2M100Model { + pub(crate) encoder: M2M100Encoder, + decoder: M2M100Decoder, + pub(crate) embeddings: nn::Embedding, + pad_token_id: i64, + decoder_start_token_id: i64, +} + +impl M2M100Model { + /// Build a new `M2M100Model` + /// + /// # Arguments + /// + /// * `p` - Variable store path for the root of the M2M100 model + /// * `config` - `M2M100Config` object defining the model architecture + /// + /// # Example + /// + /// ```no_run + /// use rust_bert::m2m_100::{M2M100Config, M2M100Model}; + /// use rust_bert::Config; + /// use std::path::Path; + /// use tch::{nn, Device}; + /// + /// let config_path = Path::new("path/to/config.json"); + /// let device = Device::Cpu; + /// let p = nn::VarStore::new(device); + /// let config = M2M100Config::from_file(config_path); + /// let m2m100: M2M100Model = M2M100Model::new(&p.root() / "m2m100", &config); + /// ``` + pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Model + where + P: Borrow>, + { + let p = p.borrow(); + + let pad_token_id = config.pad_token_id.unwrap_or(1); + let decoder_start_token_id = config.decoder_start_token_id.unwrap_or(2); + let embedding_config = EmbeddingConfig { + padding_idx: pad_token_id, + ..Default::default() + }; + let embeddings: nn::Embedding = embedding( + p / "shared", + config.vocab_size, + config.d_model, + embedding_config, + ); + + let encoder = M2M100Encoder::new(p / "encoder", config); + let decoder = M2M100Decoder::new(p / "decoder", config); + + M2M100Model { + encoder, + decoder, + embeddings, + pad_token_id, + decoder_start_token_id, + } + } + + /// Forward pass through the model + /// + /// # Arguments + /// + /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode + /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) + /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). + /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. + /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. + /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. + /// + /// # Returns + /// + /// * `M2M100ModelOutput` containing: + /// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state + /// - `encoder_hidden_states` - `Option` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None + /// - `cache` - `(Option, Option>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder. + /// - `all_encoder_hidden_states` - `Option>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*) + /// - `all_encoder_attentions` - `Option>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*) + /// - `all_decoder_hidden_states` - `Option>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*) + /// - `all_decoder_attentions` - `Option>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*) + /// + /// # Example + /// + /// ```no_run + /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use rust_bert::Config; + /// # use std::path::Path; + /// # use tch::kind::Kind::{Int64, Double}; + /// use rust_bert::m2m_100::{M2M100Config, M2M100Model}; + /// # let config_path = Path::new("path/to/config.json"); + /// # let vocab_path = Path::new("path/to/vocab.txt"); + /// # let device = Device::Cpu; + /// # let vs = nn::VarStore::new(device); + /// # let config = M2M100Config::from_file(config_path); + /// # let m2m100_model: M2M100Model = M2M100Model::new(&vs.root(), &config); + /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); + /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); + /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); + /// let encoder_attention_mask = + /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); + /// let decoder_attention_mask = + /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); + /// + /// let model_output = no_grad(|| { + /// m2m100_model.forward_t( + /// Some(&input_tensor), + /// Some(&encoder_attention_mask), + /// Some(&target_tensor), + /// None, + /// Some(&decoder_attention_mask), + /// None, + /// false, + /// ) + /// }); + /// ``` + pub fn forward_t( + &self, + input_ids: Option<&Tensor>, + attention_mask: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + encoder_output: Option<&Tensor>, + decoder_attention_mask: Option<&Tensor>, + layer_states: Option, Option)>>, + train: bool, + ) -> M2M100ModelOutput { + let calc_decoder_input_ids = if decoder_input_ids.is_none() { + Some(_shift_tokens_right( + input_ids.unwrap(), + self.pad_token_id, + self.decoder_start_token_id, + )) + } else { + None + }; + + let decoder_input_ids = + decoder_input_ids.unwrap_or_else(|| calc_decoder_input_ids.as_ref().unwrap()); + + let calc_encoder_output = if encoder_output.is_none() { + Some(self.encoder.forward_t( + input_ids.unwrap(), + attention_mask, + &self.embeddings, + train, + )) + } else { + None + }; + + let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) = + if let Some(calc_encoder_output) = calc_encoder_output { + ( + Some(calc_encoder_output.hidden_state), + calc_encoder_output.all_hidden_states, + calc_encoder_output.all_attentions, + ) + } else { + (None, None, None) + }; + + let encoder_output = encoder_output.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap()); + + let decoder_output = self.decoder.forward_t( + &decoder_input_ids, + &encoder_output, + attention_mask, + decoder_attention_mask, + &self.embeddings, + layer_states, + train, + ); + + M2M100ModelOutput { + decoder_output: decoder_output.hidden_state, + encoder_hidden_state: calc_hidden_states, + cache: decoder_output.next_decoder_cache, + all_decoder_hidden_states: decoder_output.all_hidden_states, + all_decoder_attentions: decoder_output.all_attentions, + all_encoder_hidden_states, + all_encoder_attentions, + } + } +} + +/// Container holding a M2M100 model output +pub type M2M100ModelOutput = MBartModelOutput; + +/// # M2M100 Model for conditional generation +/// M2M100 model with a vocabulary decoding head +/// It is made of the following blocks: +/// - `base_model`: `M2M100Model` Base M2M100 model +/// - `linear`: Linear layer without bias tied to the weights of the token id embeddings +pub struct M2M100ForConditionalGeneration { + base_model: M2M100Model, +} + +impl M2M100ForConditionalGeneration { + /// Build a new `M2M100ForConditionalGeneration` + /// + /// # Arguments + /// + /// * `p` - Variable store path for the root of the M2M100 model + /// * `config` - `M2M100Config` object defining the model architecture + /// + /// # Example + /// + /// ```no_run + /// use rust_bert::m2m_100::{M2M100Config, M2M100ForConditionalGeneration}; + /// use rust_bert::Config; + /// use std::path::Path; + /// use tch::{nn, Device}; + /// + /// let config_path = Path::new("path/to/config.json"); + /// let device = Device::Cpu; + /// let p = nn::VarStore::new(device); + /// let config = M2M100Config::from_file(config_path); + /// let m2m100: M2M100ForConditionalGeneration = + /// M2M100ForConditionalGeneration::new(&p.root(), &config); + /// ``` + pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100ForConditionalGeneration + where + P: Borrow>, + { + let base_model = M2M100Model::new(p.borrow() / "model", config); + M2M100ForConditionalGeneration { base_model } + } + + /// Forward pass through the model + /// + /// # Arguments + /// + /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode + /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. + /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). + /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) + /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. + /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. + /// + /// # Returns + /// + /// * `M2M100ModelOutput` containing: + /// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position + /// - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state + /// - `cache` - `(Option, Option>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder. + /// - `all_encoder_hidden_states` - `Option>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*) + /// - `all_encoder_attentions` - `Option>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*) + /// - `all_decoder_hidden_states` - `Option>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*) + /// - `all_decoder_attentions` - `Option>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*) + /// + /// # Example + /// + /// ```no_run + /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use rust_bert::Config; + /// # use std::path::Path; + /// # use tch::kind::Kind::{Int64, Double}; + /// # use rust_bert::m2m_100::{M2M100Config, M2M100ForConditionalGeneration}; + /// # let config_path = Path::new("path/to/config.json"); + /// # let vocab_path = Path::new("path/to/vocab.txt"); + /// # let device = Device::Cpu; + /// # let vs = nn::VarStore::new(device); + /// # let config = M2M100Config::from_file(config_path); + /// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config); + /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); + /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); + /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); + /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); + /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); + /// + /// let model_output = no_grad(|| { + /// m2m100_model + /// .forward_t(Some(&input_tensor), + /// Some(&encoder_attention_mask), + /// None, + /// Some(&target_tensor), + /// Some(&decoder_attention_mask), + /// None, + /// false) + /// }); + /// ``` + pub fn forward_t( + &self, + input_ids: Option<&Tensor>, + attention_mask: Option<&Tensor>, + encoder_output: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + decoder_attention_mask: Option<&Tensor>, + old_layer_states: Option, Option)>>, + train: bool, + ) -> M2M100ModelOutput { + let base_model_output = self.base_model.forward_t( + input_ids, + attention_mask, + decoder_input_ids, + encoder_output, + decoder_attention_mask, + old_layer_states, + train, + ); + + let lm_logits = base_model_output + .decoder_output + .linear::(&self.base_model.embeddings.ws, None); + M2M100ModelOutput { + decoder_output: lm_logits, + ..base_model_output + } + } + + pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor { + self.base_model + .encoder + .forward_t( + input_ids, + attention_mask, + &self.base_model.embeddings, + false, + ) + .hidden_state + } +} + +impl LMHeadModel for M2M100ForConditionalGeneration { + /// Forward pass through the model + /// + /// # Arguments + /// + /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) + /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. + /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 + /// * `input_embeds` - Unused for M2M100 + /// * `token_type_ids` - Unused for M2M100 + /// * `position_ids` - Unused for M2M100 + /// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) + /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. + /// + /// # Returns + /// + /// * `LMModelOutput` containing: + /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position + /// - `cache` - `BARTCache` made of `Option>)>>` of length *n_layer* containing the encoder past keys and values for + /// both the self attention and the encoder cross attention of each layer of the decoder. + /// + /// # Example + /// + /// ```no_run + /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use rust_bert::Config; + /// # use std::path::Path; + /// # use tch::kind::Kind::{Int64, Double}; + /// use rust_bert::pipelines::generation_utils::LMHeadModel; + /// use rust_bert::m2m_100::{M2M100ForConditionalGeneration, M2M100Config}; + /// # let config_path = Path::new("path/to/config.json"); + /// # let vocab_path = Path::new("path/to/vocab.txt"); + /// # let device = Device::Cpu; + /// # let vs = nn::VarStore::new(device); + /// # let config = M2M100Config::from_file(config_path); + /// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config); + /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); + /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); + /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); + /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); + /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); + /// + /// let model_output = no_grad(|| { + /// m2m100_model + /// .forward_t(Some(&input_tensor), + /// Some(&encoder_attention_mask), + /// None, + /// Some(&target_tensor), + /// Some(&decoder_attention_mask), + /// None, + /// false) + /// }); + /// ``` + fn forward_t( + &self, + input_ids: &Option, + cache: Cache, + attention_mask: &Option, + _token_type_ids: &Option, + _position_ids: &Option, + _input_embeds: &Option, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: &Option, + train: bool, + ) -> Result { + let base_model_output = match cache { + Cache::BARTCache(cached_layer_states) => self.base_model.forward_t( + input_ids.as_ref(), + attention_mask.as_ref(), + decoder_input_ids.as_ref(), + encoder_outputs, + None, + cached_layer_states, + train, + ), + + Cache::None => self.base_model.forward_t( + input_ids.as_ref(), + attention_mask.as_ref(), + decoder_input_ids.as_ref(), + encoder_outputs, + None, + None, + train, + ), + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with M2M100 Model".into(), + )); + } + }; + + let lm_logits = base_model_output + .decoder_output + .linear::(&self.base_model.embeddings.ws, None); + Ok(LMModelOutput { + lm_logits, + cache: Cache::BARTCache(base_model_output.cache), + }) + } +} + +/// # Language generation model based on the M2M100 architecture +pub struct M2M100Generator { + model: M2M100ForConditionalGeneration, + tokenizer: TokenizerOption, + var_store: nn::VarStore, + generate_config: GenerateConfig, + bos_token_id: Option, + eos_token_ids: Option>, + pad_token_id: Option, + is_encoder_decoder: bool, + vocab_size: i64, + decoder_start_id: Option, + max_position_embeddings: i64, +} + +impl M2M100Generator { + /// Build a new `M2M100Generator` + /// + /// # Arguments + /// + /// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention + /// * `merges_path` - Path to the bpe merges, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention + /// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention + /// * `weights_path` - Path to the model weight files. These need to be converted form the `.bin` to `.ot` format using the utility script provided. + /// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)` + /// + /// # Example + /// + /// ```no_run + /// # use std::path::PathBuf; + /// # use tch::Device; + /// # fn main() -> anyhow::Result<()> { + /// use rust_bert::m2m_100::M2M100Generator; + /// use rust_bert::pipelines::generation_utils::GenerateConfig; + /// # let mut home: PathBuf = dirs::home_dir().unwrap(); + /// # home.push("rustbert"); + /// # home.push("openai-gpt"); + /// # let config_path = &home.as_path().join("config.json"); + /// # let vocab_path = &home.as_path().join("vocab.txt"); + /// # let merges_path = &home.as_path().join("merges.txt"); + /// # let weights_path = &home.as_path().join("model.ot"); + /// let device = Device::cuda_if_available(); + /// let generate_config = GenerateConfig { + /// max_length: 30, + /// do_sample: true, + /// num_beams: 5, + /// temperature: 1.1, + /// num_return_sequences: 3, + /// ..Default::default() + /// }; + /// let m2m100_generator = M2M100Generator::new(generate_config)?; + /// # Ok(()) + /// # } + /// ``` + pub fn new(generate_config: GenerateConfig) -> Result { + // The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models + let model_resource = if generate_config.model_resource + == Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) + { + Resource::Remote(RemoteResource::from_pretrained( + M2M100ModelResources::M2M100_418M, + )) + } else { + generate_config.model_resource.clone() + }; + + let config_resource = if generate_config.config_resource + == Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)) + { + Resource::Remote(RemoteResource::from_pretrained( + M2M100ConfigResources::M2M100_418M, + )) + } else { + generate_config.config_resource.clone() + }; + + let vocab_resource = if generate_config.vocab_resource + == Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)) + { + Resource::Remote(RemoteResource::from_pretrained( + M2M100VocabResources::M2M100_418M, + )) + } else { + generate_config.vocab_resource.clone() + }; + + let merges_resource = if generate_config.merges_resource + == Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)) + { + Resource::Remote(RemoteResource::from_pretrained( + M2M100MergesResources::M2M100_418M, + )) + } else { + generate_config.merges_resource.clone() + }; + + let config_path = config_resource.get_local_path()?; + let vocab_path = vocab_resource.get_local_path()?; + let merges_path = merges_resource.get_local_path()?; + let weights_path = model_resource.get_local_path()?; + let device = generate_config.device; + + generate_config.validate(); + let mut var_store = nn::VarStore::new(device); + let tokenizer = TokenizerOption::from_file( + ModelType::M2M100, + vocab_path.to_str().unwrap(), + Some(merges_path.to_str().unwrap()), + false, + None, + None, + )?; + let config = M2M100Config::from_file(config_path); + let model = M2M100ForConditionalGeneration::new(&var_store.root(), &config); + var_store.load(weights_path)?; + + let bos_token_id = Some(0); + let eos_token_ids = Some(match config.eos_token_id { + Some(value) => vec![value], + None => vec![2], + }); + let pad_token_id = Some(config.pad_token_id.unwrap_or(1)); + let vocab_size = config.vocab_size; + let is_encoder_decoder = true; + let decoder_start_id = Some(2); + let max_position_embeddings = config.max_position_embeddings; + + Ok(M2M100Generator { + model, + tokenizer, + var_store, + generate_config, + bos_token_id, + eos_token_ids, + pad_token_id, + is_encoder_decoder, + vocab_size, + decoder_start_id, + max_position_embeddings, + }) + } + + fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) { + let impossible_tokens: Vec = (0..self.get_vocab_size() as i64) + .filter(|pos| !token_ids.contains(pos)) + .collect(); + let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device()); + let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY); + } +} + +impl PrivateLanguageGenerator + for M2M100Generator +{ + fn get_model(&self) -> &M2M100ForConditionalGeneration { + &self.model + } + fn _get_tokenizer(&self) -> &TokenizerOption { + &self.tokenizer + } + fn get_var_store(&self) -> &nn::VarStore { + &self.var_store + } + fn get_config(&self) -> &GenerateConfig { + &self.generate_config + } + fn get_bos_id(&self) -> &Option { + &self.bos_token_id + } + fn get_eos_ids(&self) -> &Option> { + &self.eos_token_ids + } + fn get_pad_id(&self) -> &Option { + &self.pad_token_id + } + fn is_encoder_decoder(&self) -> bool { + self.is_encoder_decoder + } + fn get_vocab_size(&self) -> i64 { + self.vocab_size + } + fn get_decoder_start_id(&self) -> Option { + self.decoder_start_id + } + + fn get_max_positions_embeddings(&self) -> i64 { + self.max_position_embeddings + } + + fn prepare_scores_for_generation( + &self, + scores: &mut Tensor, + current_length: i64, + max_length: i64, + forced_bos_token_id: Option, + ) { + if current_length == 1 { + self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]); + } else if current_length == max_length - 1 { + self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap()); + } + } + + fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option { + Some(self.get_model().encode(input_ids, attention_mask)) + } + + fn prepare_inputs_for_generation<'a>( + &self, + input_ids: Tensor, + encoder_outputs: Option<&'a Tensor>, + past: Cache, + attention_mask: Tensor, + ) -> PreparedInput<'a> { + match past { + Cache::BARTCache(past) => PreparedInput { + prepared_input: None, + prepared_attention_mask: Some(attention_mask), + prepared_encoder_output: encoder_outputs, + prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)), + prepared_position_ids: None, + prepared_past: Cache::BARTCache(past), + }, + Cache::None => PreparedInput { + prepared_input: None, + prepared_attention_mask: Some(attention_mask), + prepared_encoder_output: encoder_outputs, + prepared_decoder_input: Some(input_ids), + prepared_position_ids: None, + prepared_past: Cache::BARTCache(None), + }, + _ => panic!("Cache type incompatible with M2M100"), + } + } + + fn encode_prompt_text<'a, S>( + &self, + prompt_text: S, + max_len: i64, + pad_token_id: Option, + ) -> Tensor + where + S: AsRef<[&'a str]>, + { + let tokens = self._get_tokenizer().encode_list( + prompt_text.as_ref(), + max_len as usize, + &TruncationStrategy::LongestFirst, + 0, + ); + let token_ids = tokens + .into_iter() + .map(|tokenized_input| tokenized_input.token_ids) + .collect::>>(); + + let max_len = token_ids.iter().map(|input| input.len()).max().unwrap(); + + let pad_token = match pad_token_id { + Some(value) => value, + None => self + ._get_tokenizer() + .convert_tokens_to_ids(&[M2M100Vocab::unknown_value()])[0], + }; + + let token_ids = token_ids + .into_iter() + .map(|mut input| { + let temp = vec![pad_token; max_len - input.len()]; + input.extend(temp); + input + }) + .map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device())) + .collect::>(); + + Tensor::stack(&token_ids, 0) + } + + fn reorder_cache( + &self, + past: &mut Cache, + encoder_outputs: Option, + beam_indices: &Tensor, + ) -> Option { + let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices)); + match past { + Cache::BARTCache(old_cache_option) => match old_cache_option { + Some(old_cache) => { + for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() { + if self_layer_state.is_some() { + self_layer_state + .as_mut() + .unwrap() + .reorder_cache(beam_indices) + }; + if encoder_layer_state.is_some() { + encoder_layer_state + .as_mut() + .unwrap() + .reorder_cache(beam_indices) + }; + } + } + None => {} + }, + Cache::None => {} + _ => { + panic!("Invalid cache for M2M100 model"); + } + }; + encoder_outputs + } +} + +impl LanguageGenerator + for M2M100Generator +{ +} diff --git a/src/m2m_100/mod.rs b/src/m2m_100/mod.rs new file mode 100644 index 0000000..e3fa63b --- /dev/null +++ b/src/m2m_100/mod.rs @@ -0,0 +1,70 @@ +//! # M2M-100 (Fan et al.) +//! +//! Implementation of the M2M-100 language model ([Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) Fan, Bhosale, Schwenk, Ma, El-Kishky, Goyal, Baines, Celebi, Wenzel, Chaudhary, Goyal, Birch, Liptchinsky, Edunov, Grave, Auli, Joulin, 2020). +//! The base model is implemented in the `m2m_100::M2M100Model` struct. The model also includes a language model head: `m2m_100::M2M100ForConditionalGeneration` +//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! This model allows for direct translation between 100 languages. +//! The translation capabilities are illustrated in `examples/translation_m2m100`, run with `cargo run --example translation_m2m100`. +//! +//! # Model set-up and pre-trained weights loading +//! +//! All models expect the following resources: +//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) +//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. +//! - `M2M100Tokenizer` using a `config.json` vocabulary and a `spiece.model` SentencePiece BPE model +//! Pretrained models are available and can be downloaded using RemoteResources. +//! +//! ```no_run +//! # fn main() -> anyhow::Result<()> { +//! # +//! use tch::{nn, Device}; +//! # use std::path::PathBuf; +//! use rust_bert::m2m_100::{M2M100Config, M2M100Model}; +//! use rust_bert::resources::{LocalResource, Resource}; +//! use rust_bert::Config; +//! use rust_tokenizers::tokenizer::M2M100Tokenizer; +//! +//! let config_resource = Resource::Local(LocalResource { +//! local_path: PathBuf::from("path/to/config.json"), +//! }); +//! let vocab_resource = Resource::Local(LocalResource { +//! local_path: PathBuf::from("path/to/vocab.txt"), +//! }); +//! let merges_resource = Resource::Local(LocalResource { +//! local_path: PathBuf::from("path/to/spiece.model"), +//! }); +//! let weights_resource = Resource::Local(LocalResource { +//! local_path: PathBuf::from("path/to/model.ot"), +//! }); +//! let config_path = config_resource.get_local_path()?; +//! let vocab_path = vocab_resource.get_local_path()?; +//! let merges_path = merges_resource.get_local_path()?; +//! let weights_path = weights_resource.get_local_path()?; +//! +//! let device = Device::cuda_if_available(); +//! let mut vs = nn::VarStore::new(device); +//! let tokenizer: M2M100Tokenizer = M2M100Tokenizer::from_file( +//! vocab_path.to_str().unwrap(), +//! merges_path.to_str().unwrap(), +//! false, +//! )?; +//! let config = M2M100Config::from_file(config_path); +//! let m2m100_model = M2M100Model::new(&vs.root(), &config); +//! vs.load(weights_path)?; +//! +//! # Ok(()) +//! # } +//! ``` + +mod attention; +mod decoder; +mod embeddings; +mod encoder; +mod m2m_100_model; + +pub use m2m_100_model::{ + M2M100Config, M2M100ConfigResources, M2M100ForConditionalGeneration, M2M100Generator, + M2M100MergesResources, M2M100Model, M2M100ModelResources, M2M100VocabResources, +}; + +pub use attention::LayerState; diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs index 50b8525..5957314 100644 --- a/src/marian/marian_model.rs +++ b/src/marian/marian_model.rs @@ -49,102 +49,102 @@ pub struct MarianSourceLanguages; pub struct MarianTargetLanguages; impl MarianModelResources { - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ( "marian-mt-en-ROMANCE/model", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ROMANCE-en/model", "https://huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const ENGLISH2GERMAN: (&'static str, &'static str) = ( "marian-mt-en-de/model", "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const GERMAN2ENGLISH: (&'static str, &'static str) = ( "marian-mt-de-en/model", "https://huggingface.co/Helsinki-NLP/opus-mt-de-en/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ( "marian-mt-en-ru/model", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ru/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ru-en/model", "https://huggingface.co/Helsinki-NLP/opus-mt-ru-en/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const FRENCH2GERMAN: (&'static str, &'static str) = ( "marian-mt-fr-de/model", "https://huggingface.co/Helsinki-NLP/opus-mt-fr-de/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const GERMAN2FRENCH: (&'static str, &'static str) = ( "marian-mt-de-fr/model", "https://huggingface.co/Helsinki-NLP/opus-mt-de-fr/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const ENGLISH2DUTCH: (&'static str, &'static str) = ( "marian-mt-en-nl/model", "https://huggingface.co/Helsinki-NLP/opus-mt-en-nl/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const DUTCH2ENGLISH: (&'static str, &'static str) = ( "marian-mt-nl-en/model", "https://huggingface.co/Helsinki-NLP/opus-mt-nl-en/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const ENGLISH2CHINESE: (&'static str, &'static str) = ( "marian-mt-en-zh/model", "https://huggingface.co/Helsinki-NLP/opus-mt-en-zh/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const CHINESE2ENGLISH: (&'static str, &'static str) = ( "marian-mt-zh-en/model", "https://huggingface.co/Helsinki-NLP/opus-mt-zh-en/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const ENGLISH2SWEDISH: (&'static str, &'static str) = ( "marian-mt-en-sv/model", "https://huggingface.co/Helsinki-NLP/opus-mt-en-sv/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const SWEDISH2ENGLISH: (&'static str, &'static str) = ( "marian-mt-sv-en/model", "https://huggingface.co/Helsinki-NLP/opus-mt-sv-en/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const ARABIC2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ar-en/model", "https://huggingface.co/Helsinki-NLP/opus-mt-ar-en/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const ENGLISH2ARABIC: (&'static str, &'static str) = ( "marian-mt-en-ar/model", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ar/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const HINDI2ENGLISH: (&'static str, &'static str) = ( "marian-mt-hi-en/model", "https://huggingface.co/Helsinki-NLP/opus-mt-hi-en/resolve/main/rust_model.ot", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . Modified with conversion to C-array format. pub const ENGLISH2HINDI: (&'static str, &'static str) = ( "marian-mt-en-hi/model", "https://huggingface.co/Helsinki-NLP/opus-mt-en-hi/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 License license at https://huggingface.co/tiedeman/opus-mt-he-en. Modified with conversion to C-array format. + /// Shared under Apache 2.0 License license at . Modified with conversion to C-array format. pub const HEBREW2ENGLISH: (&'static str, &'static str) = ( "marian-mt-he-en/model", "https://huggingface.co/Helsinki-NLP/opus-mt-he-en/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 License license at https://huggingface.co/tiedeman/opus-mt-en-he. Modified with conversion to C-array format. + /// Shared under Apache 2.0 License license at . Modified with conversion to C-array format. pub const ENGLISH2HEBREW: (&'static str, &'static str) = ( "marian-mt-en-he/model", "https://huggingface.co/Helsinki-NLP/opus-mt-en-he/resolve/main/rust_model.ot", @@ -152,102 +152,102 @@ impl MarianModelResources { } impl MarianConfigResources { - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ( "marian-mt-en-ROMANCE/config", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ROMANCE-en/config", "https://huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2GERMAN: (&'static str, &'static str) = ( "marian-mt-en-de/config", "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const GERMAN2ENGLISH: (&'static str, &'static str) = ( "marian-mt-de-en/config", "https://huggingface.co/Helsinki-NLP/opus-mt-de-en/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ( "marian-mt-en-ru/config", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ru/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ru-en/config", "https://huggingface.co/Helsinki-NLP/opus-mt-ru-en/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const FRENCH2GERMAN: (&'static str, &'static str) = ( "marian-mt-fr-de/config", "https://huggingface.co/Helsinki-NLP/opus-mt-fr-de/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const GERMAN2FRENCH: (&'static str, &'static str) = ( "marian-mt-de-fr/config", "https://huggingface.co/Helsinki-NLP/opus-mt-de-fr/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2DUTCH: (&'static str, &'static str) = ( "marian-mt-en-nl/config", "https://huggingface.co/Helsinki-NLP/opus-mt-en-nl/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const DUTCH2ENGLISH: (&'static str, &'static str) = ( "marian-mt-nl-en/config", "https://huggingface.co/Helsinki-NLP/opus-mt-nl-en/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const CHINESE2ENGLISH: (&'static str, &'static str) = ( "marian-mt-zh-en/config", "https://huggingface.co/Helsinki-NLP/opus-mt-zh-en/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2CHINESE: (&'static str, &'static str) = ( "marian-mt-en-zh/config", "https://huggingface.co/Helsinki-NLP/opus-mt-en-zh/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2SWEDISH: (&'static str, &'static str) = ( "marian-mt-en-sv/config", "https://huggingface.co/Helsinki-NLP/opus-mt-en-sv/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const SWEDISH2ENGLISH: (&'static str, &'static str) = ( "marian-mt-sv-en/config", "https://huggingface.co/Helsinki-NLP/opus-mt-sv-en/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ARABIC2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ar-en/config", "https://huggingface.co/Helsinki-NLP/opus-mt-ar-en/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2ARABIC: (&'static str, &'static str) = ( "marian-mt-en-ar/config", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ar/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const HINDI2ENGLISH: (&'static str, &'static str) = ( "marian-mt-hi-en/config", "https://huggingface.co/Helsinki-NLP/opus-mt-hi-en/resolve/main/config.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2HINDI: (&'static str, &'static str) = ( "marian-mt-en-hi/config", "https://huggingface.co/Helsinki-NLP/opus-mt-en-hi/resolve/main/config.json", ); - /// Shared under Apache 2.0 License license at https://huggingface.co/tiedeman/opus-mt-he-en. Modified with conversion to C-array format. + /// Shared under Apache 2.0 License license at . Modified with conversion to C-array format. pub const HEBREW2ENGLISH: (&'static str, &'static str) = ( "marian-mt-he-en/config", "https://huggingface.co/Helsinki-NLP/opus-mt-he-en/resolve/main/config.json", ); - /// Shared under Apache 2.0 License license at https://huggingface.co/tiedeman/opus-mt-en-he. Modified with conversion to C-array format. + /// Shared under Apache 2.0 License license at . Modified with conversion to C-array format. pub const ENGLISH2HEBREW: (&'static str, &'static str) = ( "marian-mt-en-he/config", "https://huggingface.co/Helsinki-NLP/opus-mt-en-he/resolve/main/config.json", @@ -255,102 +255,102 @@ impl MarianConfigResources { } impl MarianVocabResources { - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ( "marian-mt-en-ROMANCE/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ROMANCE-en/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2GERMAN: (&'static str, &'static str) = ( "marian-mt-en-de/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const GERMAN2ENGLISH: (&'static str, &'static str) = ( "marian-mt-de-en/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-de-en/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ( "marian-mt-en-ru/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ru/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ru-en/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-ru-en/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const FRENCH2GERMAN: (&'static str, &'static str) = ( "marian-mt-fr-de/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-fr-de/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const GERMAN2FRENCH: (&'static str, &'static str) = ( "marian-mt-de-fr/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-de-fr/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2DUTCH: (&'static str, &'static str) = ( "marian-mt-en-nl/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-en-nl/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const DUTCH2ENGLISH: (&'static str, &'static str) = ( "marian-mt-nl-en/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-nl-en/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const CHINESE2ENGLISH: (&'static str, &'static str) = ( "marian-mt-zh-en/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-zh-en/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2CHINESE: (&'static str, &'static str) = ( "marian-mt-en-zh/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-en-zh/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2SWEDISH: (&'static str, &'static str) = ( "marian-mt-en-sv/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-en-sv/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const SWEDISH2ENGLISH: (&'static str, &'static str) = ( "marian-mt-sv-en/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-sv-en/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ARABIC2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ar-en/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-ar-en/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2ARABIC: (&'static str, &'static str) = ( "marian-mt-en-ar/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ar/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const HINDI2ENGLISH: (&'static str, &'static str) = ( "marian-mt-hi-en/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-hi-en/resolve/main/vocab.json", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2HINDI: (&'static str, &'static str) = ( "marian-mt-en-hi/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-en-hi/resolve/main/vocab.json", ); - /// Shared under Apache 2.0 License license at https://huggingface.co/tiedeman/opus-mt-he-en. Modified with conversion to C-array format. + /// Shared under Apache 2.0 License license at . Modified with conversion to C-array format. pub const HEBREW2ENGLISH: (&'static str, &'static str) = ( "marian-mt-he-en/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-he-en/resolve/main/vocab.json", ); - /// Shared under Apache 2.0 License license at https://huggingface.co/tiedeman/opus-mt-en-he. Modified with conversion to C-array format. + /// Shared under Apache 2.0 License license at . Modified with conversion to C-array format. pub const ENGLISH2HEBREW: (&'static str, &'static str) = ( "marian-mt-en-he/vocab", "https://huggingface.co/Helsinki-NLP/opus-mt-en-he/resolve/main/vocab.json", @@ -358,102 +358,102 @@ impl MarianVocabResources { } impl MarianSpmResources { - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ( "marian-mt-en-ROMANCE/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ROMANCE-en/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2GERMAN: (&'static str, &'static str) = ( "marian-mt-en-de/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const GERMAN2ENGLISH: (&'static str, &'static str) = ( "marian-mt-de-en/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-de-en/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ( "marian-mt-en-ru/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ru/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ru-en/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-ru-en/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const FRENCH2GERMAN: (&'static str, &'static str) = ( "marian-mt-fr-de/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-fr-de/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const GERMAN2FRENCH: (&'static str, &'static str) = ( "marian-mt-de-fr/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-de-fr/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2DUTCH: (&'static str, &'static str) = ( "marian-mt-en-nl/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-en-nl/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const DUTCH2ENGLISH: (&'static str, &'static str) = ( "marian-mt-nl-en/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-nl-en/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const CHINESE2ENGLISH: (&'static str, &'static str) = ( "marian-mt-zh-en/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-zh-en/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2CHINESE: (&'static str, &'static str) = ( "marian-mt-en-zh/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-en-zh/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2SWEDISH: (&'static str, &'static str) = ( "marian-mt-en-sv/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-en-sv/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const SWEDISH2ENGLISH: (&'static str, &'static str) = ( "marian-mt-sv-en/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-sv-en/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ARABIC2ENGLISH: (&'static str, &'static str) = ( "marian-mt-ar-en/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-ar-en/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2ARABIC: (&'static str, &'static str) = ( "marian-mt-en-ar/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-en-ar/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const HINDI2ENGLISH: (&'static str, &'static str) = ( "marian-mt-hi-en/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-hi-en/resolve/main/source.spm", ); - /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. + /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . pub const ENGLISH2HINDI: (&'static str, &'static str) = ( "marian-mt-en-hi/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-en-hi/resolve/main/source.spm", ); - /// Shared under Apache 2.0 License license at https://huggingface.co/tiedeman/opus-mt-he-en. Modified with conversion to C-array format. + /// Shared under Apache 2.0 License license at . Modified with conversion to C-array format. pub const HEBREW2ENGLISH: (&'static str, &'static str) = ( "marian-mt-he-en/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-he-en/resolve/main/source.spm", ); - /// Shared under Apache 2.0 License license at https://huggingface.co/tiedeman/opus-mt-en-he. Modified with conversion to C-array format. + /// Shared under Apache 2.0 License license at . Modified with conversion to C-array format. pub const ENGLISH2HEBREW: (&'static str, &'static str) = ( "marian-mt-en-he/spiece", "https://huggingface.co/Helsinki-NLP/opus-mt-en-he/resolve/main/source.spm", @@ -615,7 +615,7 @@ impl MarianForConditionalGeneration { /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. /// @@ -717,7 +717,7 @@ impl LMHeadModel for MarianForConditionalGeneration { /// * `position_ids` - Unused for BART /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. /// /// diff --git a/src/marian/mod.rs b/src/marian/mod.rs index 3e66438..1f55218 100644 --- a/src/marian/mod.rs +++ b/src/marian/mod.rs @@ -12,7 +12,7 @@ //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. //! - `MarianTokenizer` using a `vocab.json` vocabulary and `spiece.model` sentence piece model //! -//! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. These are shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. +//! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. These are shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at . //! //! ```no_run //! # fn main() -> anyhow::Result<()> { diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs index 072057b..9c2584e 100644 --- a/src/mbart/mbart_model.rs +++ b/src/mbart/mbart_model.rs @@ -44,7 +44,7 @@ pub struct MBartConfigResources; pub struct MBartVocabResources; impl MBartModelResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const MBART50_MANY_TO_MANY: (&'static str, &'static str) = ( "mbart-50-many-to-many-mmt/model", "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/rust_model.ot", @@ -52,7 +52,7 @@ impl MBartModelResources { } impl MBartConfigResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const MBART50_MANY_TO_MANY: (&'static str, &'static str) = ( "mbart-50-many-to-many-mmt/config", "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/config.json", @@ -60,7 +60,7 @@ impl MBartConfigResources { } impl MBartVocabResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const MBART50_MANY_TO_MANY: (&'static str, &'static str) = ( "mbart-50-many-to-many-mmt/vocab", "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model", @@ -107,7 +107,7 @@ pub struct MBartConfig { pub static_position_embeddings: Option, } -impl Config for MBartConfig {} +impl Config for MBartConfig {} fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor { let output = input_ids.masked_fill(&input_ids.eq(-100), pad_token_id); @@ -247,7 +247,7 @@ impl MBartModel { /// /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. @@ -423,7 +423,7 @@ impl MBartForConditionalGeneration { /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. /// @@ -571,7 +571,7 @@ impl MBartForSequenceClassification { /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. /// @@ -670,7 +670,7 @@ impl LMHeadModel for MBartForConditionalGeneration { /// # Arguments /// /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. + /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 /// * `input_embeds` - Unused for BART /// * `token_type_ids` - Unused for BART diff --git a/src/mbart/mod.rs b/src/mbart/mod.rs index 021bfd9..ab65fda 100644 --- a/src/mbart/mod.rs +++ b/src/mbart/mod.rs @@ -6,7 +6,7 @@ //! //! # Model set-up and pre-trained weights loading //! -//! The summarization capabilities are illustrated in `examples/translation_mbart`, run with `cargo run --example translation_mbart`. +//! The translation capabilities are illustrated in `examples/translation_mbart`, run with `cargo run --example translation_mbart`. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. @@ -41,7 +41,7 @@ //! let tokenizer: MBart50Tokenizer = //! MBart50Tokenizer::from_file(vocab_path.to_str().unwrap(), false)?; //! let config = MBartConfig::from_file(config_path); -//! let bart_model = MBartModel::new(&vs.root(), &config); +//! let mbart_model = MBartModel::new(&vs.root(), &config); //! vs.load(weights_path)?; //! //! # Ok(()) diff --git a/src/mobilebert/mobilebert_model.rs b/src/mobilebert/mobilebert_model.rs index dcd665c..d116dc0 100644 --- a/src/mobilebert/mobilebert_model.rs +++ b/src/mobilebert/mobilebert_model.rs @@ -31,12 +31,12 @@ pub struct MobileBertConfigResources; pub struct MobileBertVocabResources; impl MobileBertModelResources { - /// Shared under Apache 2.0 license by the Google team at https://huggingface.co/google/mobilebert-uncased. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = ( "mobilebert-uncased/model", "https://huggingface.co/google/mobilebert-uncased/resolve/main/rust_model.ot", ); - /// Shared under MIT license at https://huggingface.co/mrm8488/mobilebert-finetuned-pos. Modified with conversion to C-array format. + /// Shared under MIT license at . Modified with conversion to C-array format. pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = ( "mobilebert-finetuned-pos/model", "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/rust_model.ot", @@ -44,12 +44,12 @@ impl MobileBertModelResources { } impl MobileBertConfigResources { - /// Shared under Apache 2.0 license by the Google team at https://huggingface.co/google/mobilebert-uncased. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = ( "mobilebert-uncased/config", "https://huggingface.co/google/mobilebert-uncased/resolve/main/config.json", ); - /// Shared under MIT license at https://huggingface.co/mrm8488/mobilebert-finetuned-pos. Modified with conversion to C-array format. + /// Shared under MIT license at . Modified with conversion to C-array format. pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = ( "mobilebert-finetuned-pos/config", "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/config.json", @@ -57,12 +57,12 @@ impl MobileBertConfigResources { } impl MobileBertVocabResources { - /// Shared under Apache 2.0 license by the Google team at https://huggingface.co/google/mobilebert-uncased. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Google team at . Modified with conversion to C-array format. pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = ( "mobilebert-uncased/vocab", "https://huggingface.co/google/mobilebert-uncased/resolve/main/vocab.txt", ); - /// Shared under MIT license at https://huggingface.co/mrm8488/mobilebert-finetuned-pos. Modified with conversion to C-array format. + /// Shared under MIT license at . Modified with conversion to C-array format. pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = ( "mobilebert-finetuned-pos/vocab", "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/vocab.txt", @@ -194,7 +194,7 @@ pub struct MobileBertConfig { pub label2id: Option>, } -impl Config for MobileBertConfig {} +impl Config for MobileBertConfig {} pub struct MobileBertPredictionHeadTransform { dense: nn::Linear, @@ -319,7 +319,7 @@ impl MobileBertModel { /// /// * `p` - Variable store path for the root of the MobileBERT model /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status - /// * `add_poling_layer` - boolean flag indicating if a pooling layer shuld be added after the encoder + /// * `add_poling_layer` - boolean flag indicating if a pooling layer should be added after the encoder /// /// # Example /// diff --git a/src/openai_gpt/openai_gpt_model.rs b/src/openai_gpt/openai_gpt_model.rs index bb97132..332df33 100644 --- a/src/openai_gpt/openai_gpt_model.rs +++ b/src/openai_gpt/openai_gpt_model.rs @@ -45,7 +45,7 @@ pub struct OpenAiGptVocabResources; pub struct OpenAiGptMergesResources; impl OpenAiGptModelResources { - /// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format. + /// Shared under MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT: (&'static str, &'static str) = ( "openai-gpt/model", "https://huggingface.co/openai-gpt/resolve/main/rust_model.ot", @@ -53,7 +53,7 @@ impl OpenAiGptModelResources { } impl OpenAiGptConfigResources { - /// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format. + /// Shared under MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT: (&'static str, &'static str) = ( "openai-gpt/config", "https://huggingface.co/openai-gpt/resolve/main/config.json", @@ -61,7 +61,7 @@ impl OpenAiGptConfigResources { } impl OpenAiGptVocabResources { - /// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format. + /// Shared under MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT: (&'static str, &'static str) = ( "openai-gpt/vocab", "https://huggingface.co/openai-gpt/resolve/main/vocab.json", @@ -69,7 +69,7 @@ impl OpenAiGptVocabResources { } impl OpenAiGptMergesResources { - /// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format. + /// Shared under MIT license by the OpenAI team at . Modified with conversion to C-array format. pub const GPT: (&'static str, &'static str) = ( "openai-gpt/merges", "https://huggingface.co/openai-gpt/resolve/main/merges.txt", diff --git a/src/pegasus/pegasus_model.rs b/src/pegasus/pegasus_model.rs index 8e0e4e1..4b9df89 100644 --- a/src/pegasus/pegasus_model.rs +++ b/src/pegasus/pegasus_model.rs @@ -41,7 +41,7 @@ pub struct PegasusConfigResources; pub struct PegasusVocabResources; impl PegasusModelResources { - /// Shared under Apache 2.0 license by the Pegasus team at https://huggingface.co/google/pegasus-cnn_dailymail. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Pegasus team at . Modified with conversion to C-array format. pub const CNN_DAILYMAIL: (&'static str, &'static str) = ( "pegasus-cnn_dailymail/model", "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/rust_model.ot", @@ -49,7 +49,7 @@ impl PegasusModelResources { } impl PegasusConfigResources { - /// Shared under Apache 2.0 license by the Pegasus team at https://huggingface.co/google/pegasus-cnn_dailymail. + /// Shared under Apache 2.0 license by the Pegasus team at . pub const CNN_DAILYMAIL: (&'static str, &'static str) = ( "pegasus-cnn_dailymail/config", "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/config.json", @@ -57,7 +57,7 @@ impl PegasusConfigResources { } impl PegasusVocabResources { - /// Shared under Apache 2.0 license by the Pegasus team at https://huggingface.co/google/pegasus-cnn_dailymail. + /// Shared under Apache 2.0 license by the Pegasus team at . pub const CNN_DAILYMAIL: (&'static str, &'static str) = ( "pegasus-cnn_dailymail/spiece", "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/spiece.model", @@ -156,7 +156,7 @@ impl PegasusModel { /// /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. @@ -329,7 +329,7 @@ impl PegasusForConditionalGeneration { /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. /// @@ -437,7 +437,7 @@ impl LMHeadModel for PegasusForConditionalGeneration { /// # Arguments /// /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. + /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 /// * `input_embeds` - Unused for Pegasus /// * `token_type_ids` - Unused for Pegasus diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index 9f6963d..ddc6c20 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -25,6 +25,7 @@ use crate::electra::ElectraConfig; use crate::gpt2::Gpt2Config; use crate::gpt_neo::GptNeoConfig; use crate::longformer::LongformerConfig; +use crate::m2m_100::M2M100Config; use crate::mbart::MBartConfig; use crate::mobilebert::MobileBertConfig; use crate::pegasus::PegasusConfig; @@ -34,14 +35,15 @@ use crate::t5::T5Config; use crate::xlnet::XLNetConfig; use crate::Config; use rust_tokenizers::tokenizer::{ - AlbertTokenizer, BertTokenizer, Gpt2Tokenizer, MBart50Tokenizer, MarianTokenizer, - MultiThreadedTokenizer, OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer, - ReformerTokenizer, RobertaTokenizer, T5Tokenizer, Tokenizer, TruncationStrategy, - XLMRobertaTokenizer, XLNetTokenizer, + AlbertTokenizer, BertTokenizer, Gpt2Tokenizer, M2M100Tokenizer, MBart50Tokenizer, + MarianTokenizer, MultiThreadedTokenizer, OpenAiGptTokenizer, PegasusTokenizer, + ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer, Tokenizer, + TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer, }; use rust_tokenizers::vocab::{ - AlbertVocab, BertVocab, Gpt2Vocab, MBart50Vocab, MarianVocab, OpenAiGptVocab, PegasusVocab, - ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab, XLNetVocab, + AlbertVocab, BertVocab, Gpt2Vocab, M2M100Vocab, MBart50Vocab, MarianVocab, OpenAiGptVocab, + PegasusVocab, ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab, + XLNetVocab, }; use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets}; use serde::{Deserialize, Serialize}; @@ -70,6 +72,7 @@ pub enum ModelType { Pegasus, GPTNeo, MBart, + M2M100, } /// # Abstraction that holds a model configuration, can be of any of the supported models @@ -106,6 +109,8 @@ pub enum ConfigOption { GPTNeo(GptNeoConfig), /// MBart configuration MBart(MBartConfig), + /// M2M100 configuration + M2M100(M2M100Config), } /// # Abstraction that holds a particular tokenizer, can be of any of the supported models @@ -136,6 +141,8 @@ pub enum TokenizerOption { Pegasus(PegasusTokenizer), /// MBart50 Tokenizer MBart50(MBart50Tokenizer), + /// M2M100 Tokenizer + M2M100(M2M100Tokenizer), } impl ConfigOption { @@ -160,6 +167,7 @@ impl ConfigOption { ModelType::Longformer => ConfigOption::Longformer(LongformerConfig::from_file(path)), ModelType::Pegasus => ConfigOption::Pegasus(PegasusConfig::from_file(path)), ModelType::MBart => ConfigOption::MBart(MBartConfig::from_file(path)), + ModelType::M2M100 => ConfigOption::M2M100(M2M100Config::from_file(path)), } } @@ -201,6 +209,9 @@ impl ConfigOption { Self::MBart(config) => config .id2label .expect("No label dictionary (id2label) provided in configuration file"), + Self::M2M100(config) => config + .id2label + .expect("No label dictionary (id2label) provided in configuration file"), Self::T5(_) => panic!("T5 does not use a label mapping"), Self::GPT2(_) => panic!("GPT2 does not use a label mapping"), Self::GPTNeo(_) => panic!("GPT-Neo does not use a label mapping"), @@ -403,6 +414,26 @@ impl TokenizerOption { } TokenizerOption::MBart50(MBart50Tokenizer::from_file(vocab_path, lower_case)?) } + ModelType::M2M100 => { + if add_prefix_space.is_some() { + return Err(RustBertError::InvalidConfigurationError( + format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}", + add_prefix_space.unwrap(), + model_type))); + } + if strip_accents.is_some() { + return Err(RustBertError::InvalidConfigurationError(format!( + "Optional input `strip_accents` set to value {} but cannot be used by {:?}", + strip_accents.unwrap(), + model_type + ))); + } + TokenizerOption::M2M100(M2M100Tokenizer::from_files( + vocab_path, + merges_path.expect("No merges specified!"), + lower_case, + )?) + } }; Ok(tokenizer) } @@ -423,6 +454,7 @@ impl TokenizerOption { Self::ProphetNet(_) => ModelType::ProphetNet, Self::Pegasus(_) => ModelType::Pegasus, Self::MBart50(_) => ModelType::MBart, + Self::M2M100(_) => ModelType::M2M100, } } @@ -526,6 +558,13 @@ impl TokenizerOption { truncation_strategy, stride, ), + Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::encode_list( + tokenizer, + text_list, + max_len, + truncation_strategy, + stride, + ), } } @@ -629,6 +668,13 @@ impl TokenizerOption { truncation_strategy, stride, ), + Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list( + tokenizer, + text_pair_list, + max_len, + truncation_strategy, + stride, + ), } } @@ -681,6 +727,9 @@ impl TokenizerOption { Self::MBart50(ref tokenizer) => { tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride) } + Self::M2M100(ref tokenizer) => { + tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride) + } } } @@ -700,6 +749,7 @@ impl TokenizerOption { Self::ProphetNet(ref tokenizer) => tokenizer.tokenize(text), Self::Pegasus(ref tokenizer) => tokenizer.tokenize(text), Self::MBart50(ref tokenizer) => tokenizer.tokenize(text), + Self::M2M100(ref tokenizer) => tokenizer.tokenize(text), } } @@ -719,6 +769,7 @@ impl TokenizerOption { Self::ProphetNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text), Self::Pegasus(ref tokenizer) => tokenizer.tokenize_with_offsets(text), Self::MBart50(ref tokenizer) => tokenizer.tokenize_with_offsets(text), + Self::M2M100(ref tokenizer) => tokenizer.tokenize_with_offsets(text), } } @@ -744,6 +795,7 @@ impl TokenizerOption { } Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), + Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), } } @@ -794,6 +846,9 @@ impl TokenizerOption { Self::MBart50(ref tokenizer) => { tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces) } + Self::M2M100(ref tokenizer) => { + tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces) + } } } @@ -856,6 +911,10 @@ impl TokenizerOption { token_ids_with_offsets_1, token_ids_with_offsets_2, ), + Self::M2M100(ref tokenizer) => tokenizer.build_input_with_special_tokens( + token_ids_with_offsets_1, + token_ids_with_offsets_2, + ), }; TokenizedInput { token_ids: token_ids_with_special_tokens.token_ids, @@ -889,6 +948,7 @@ impl TokenizerOption { Self::ProphetNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), Self::Pegasus(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), Self::MBart50(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), + Self::M2M100(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), } } @@ -947,6 +1007,10 @@ impl TokenizerOption { .special_values .get(MBart50Vocab::unknown_value()) .expect("UNK token not found in vocabulary"), + Self::M2M100(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(M2M100Vocab::unknown_value()) + .expect("UNK token not found in vocabulary"), } } @@ -1013,6 +1077,12 @@ impl TokenizerOption { .get(MBart50Vocab::pad_value()) .expect("PAD token not found in vocabulary"), ), + Self::M2M100(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(M2M100Vocab::pad_value()) + .expect("PAD token not found in vocabulary"), + ), Self::Reformer(_) => None, Self::GPT2(_) => None, Self::OpenAiGpt(_) => None, @@ -1064,6 +1134,12 @@ impl TokenizerOption { .get(MBart50Vocab::sep_value()) .expect("SEP token not found in vocabulary"), ), + Self::M2M100(ref tokenizer) => Some( + *MultiThreadedTokenizer::vocab(tokenizer) + .special_values + .get(M2M100Vocab::sep_value()) + .expect("SEP token not found in vocabulary"), + ), Self::Marian(_) => None, Self::T5(_) => None, Self::GPT2(_) => None, diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs index a97a7ae..1d7a3ae 100644 --- a/src/pipelines/conversation.rs +++ b/src/pipelines/conversation.rs @@ -622,7 +622,7 @@ impl ConversationManager { /// /// # Returns /// - /// * `Option` deregistered conversation + /// * `Option` de-registered conversation /// /// # Example /// @@ -643,7 +643,7 @@ impl ConversationManager { /// /// # Returns /// - /// * `HashMap` deregistered conversations + /// * `HashMap` de-registered conversations /// /// # Example /// diff --git a/src/prophetnet/prophetnet_model.rs b/src/prophetnet/prophetnet_model.rs index 1c52fd8..84041b8 100644 --- a/src/prophetnet/prophetnet_model.rs +++ b/src/prophetnet/prophetnet_model.rs @@ -42,12 +42,12 @@ pub struct ProphetNetConfigResources; pub struct ProphetNetVocabResources; impl ProphetNetModelResources { - /// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format. + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = ( "prophetnet-large-uncased/model", "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/rust_model.ot", ); - /// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format. + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = ( "prophetnet-large-uncased-cnndm/model", "https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/rust_model.ot", @@ -55,12 +55,12 @@ impl ProphetNetModelResources { } impl ProphetNetConfigResources { - /// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format. + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = ( "prophetnet-large-uncased/config", "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/config.json", ); - /// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format. + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = ( "prophetnet-large-uncased-cnndm/config", "https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/config.json", @@ -68,12 +68,12 @@ impl ProphetNetConfigResources { } impl ProphetNetVocabResources { - /// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format. + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = ( "prophetnet-large-uncased/vocab", "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer", ); - /// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format. + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = ( "prophetnet-large-uncased-cnndm/vocab", "https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/prophetnet.tokenizer", @@ -120,7 +120,7 @@ pub struct ProphetNetConfig { pub add_cross_attention: Option, } -impl Config for ProphetNetConfig {} +impl Config for ProphetNetConfig {} /// # ProphetNet Base model /// Base architecture for ProphetNet models. Task-specific models will be built from this common base model @@ -190,7 +190,7 @@ impl ProphetNetModel { /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided. /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. /// * `encoder_hidden_states` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) corresponding to pre-calculated encoder hidden states (useful for conditional generation) /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. @@ -393,7 +393,7 @@ impl ProphetNetForConditionalGeneration { /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided. /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. /// * `encoder_hidden_states` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) corresponding to pre-calculated encoder hidden states (useful for conditional generation) /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. @@ -693,7 +693,7 @@ impl ProphetNetForCausalGeneration { /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided. /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialiazed with a BOS token) + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) /// * `old_layer_states` - Optional Vector `Option, Option<&LayerState>>>` of length *n_layer* containing tuples with the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder. /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided. /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. diff --git a/src/reformer/reformer_model.rs b/src/reformer/reformer_model.rs index ba74aba..a034566 100644 --- a/src/reformer/reformer_model.rs +++ b/src/reformer/reformer_model.rs @@ -47,7 +47,7 @@ pub struct ReformerConfigResources; pub struct ReformerVocabResources; impl ReformerModelResources { - /// Shared under Apache 2.0 license by the Trax Authors at https://github.com/google/trax/tree/master/trax/models/reformer. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Trax Authors at . Modified with conversion to C-array format. pub const CRIME_AND_PUNISHMENT: (&'static str, &'static str) = ( "reformer-crime-punishment/model", "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/rust_model.ot", @@ -55,7 +55,7 @@ impl ReformerModelResources { } impl ReformerConfigResources { - /// Shared under Apache 2.0 license by the Trax Authors at https://github.com/google/trax/tree/master/trax/models/reformer. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Trax Authors at . Modified with conversion to C-array format. pub const CRIME_AND_PUNISHMENT: (&'static str, &'static str) = ( "reformer-crime-punishment/config", "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/config.json", @@ -63,7 +63,7 @@ impl ReformerConfigResources { } impl ReformerVocabResources { - /// Shared under Apache 2.0 license by the Trax Authors at https://github.com/google/trax/tree/master/trax/models/reformer. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Trax Authors at . Modified with conversion to C-array format. pub const CRIME_AND_PUNISHMENT: (&'static str, &'static str) = ( "reformer-crime-punishment/spiece", "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model", @@ -115,7 +115,7 @@ pub struct ReformerConfig { pub output_hidden_states: Option, } -impl Config for ReformerConfig {} +impl Config for ReformerConfig {} pub struct ReformerLMHead { decoder: nn::Linear, diff --git a/src/roberta/roberta_model.rs b/src/roberta/roberta_model.rs index 16912e2..48c1ad3 100644 --- a/src/roberta/roberta_model.rs +++ b/src/roberta/roberta_model.rs @@ -33,37 +33,37 @@ pub struct RobertaVocabResources; pub struct RobertaMergesResources; impl RobertaModelResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const ROBERTA: (&'static str, &'static str) = ( "roberta/model", "https://huggingface.co/roberta-base/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the Hugging Face Inc. team at https://huggingface.co/distilroberta-base. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face Inc. team at . Modified with conversion to C-array format. pub const DISTILROBERTA_BASE: (&'static str, &'static str) = ( "distilroberta-base/model", "https://cdn.huggingface.co/distilroberta-base-rust_model.ot", ); - /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at . Modified with conversion to C-array format. pub const ROBERTA_QA: (&'static str, &'static str) = ( "roberta-qa/model", "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = ( "xlm-roberta-ner-en/model", "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = ( "xlm-roberta-ner-de/model", "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = ( "xlm-roberta-ner-nl/model", "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = ( "xlm-roberta-ner-es/model", "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/rust_model.ot", @@ -71,37 +71,37 @@ impl RobertaModelResources { } impl RobertaConfigResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const ROBERTA: (&'static str, &'static str) = ( "roberta/config", "https://huggingface.co/roberta-base/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the Hugging Face Inc. team at https://huggingface.co/distilroberta-base. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face Inc. team at . Modified with conversion to C-array format. pub const DISTILROBERTA_BASE: (&'static str, &'static str) = ( "distilroberta-base/config", "https://cdn.huggingface.co/distilroberta-base-config.json", ); - /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at . Modified with conversion to C-array format. pub const ROBERTA_QA: (&'static str, &'static str) = ( "roberta-qa/config", "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = ( "xlm-roberta-ner-en/config", "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = ( "xlm-roberta-ner-de/config", "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = ( "xlm-roberta-ner-nl/config", "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = ( "xlm-roberta-ner-es/config", "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json", @@ -109,37 +109,37 @@ impl RobertaConfigResources { } impl RobertaVocabResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const ROBERTA: (&'static str, &'static str) = ( "roberta/vocab", "https://huggingface.co/roberta-base/resolve/main/vocab.json", ); - /// Shared under Apache 2.0 license by the Hugging Face Inc. team at https://huggingface.co/distilroberta-base. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face Inc. team at . Modified with conversion to C-array format. pub const DISTILROBERTA_BASE: (&'static str, &'static str) = ( "distilroberta-base/vocab", "https://cdn.huggingface.co/distilroberta-base-vocab.json", ); - /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at . Modified with conversion to C-array format. pub const ROBERTA_QA: (&'static str, &'static str) = ( "roberta-qa/vocab", "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/vocab.json", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = ( "xlm-roberta-ner-en/spiece", "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = ( "xlm-roberta-ner-de/spiece", "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = ( "xlm-roberta-ner-nl/spiece", "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model", ); - /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = ( "xlm-roberta-ner-es/spiece", "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model", @@ -147,17 +147,17 @@ impl RobertaVocabResources { } impl RobertaMergesResources { - /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. + /// Shared under MIT license by the Facebook AI Research Fairseq team at . Modified with conversion to C-array format. pub const ROBERTA: (&'static str, &'static str) = ( "roberta/merges", "https://huggingface.co/roberta-base/resolve/main/merges.txt", ); - /// Shared under Apache 2.0 license by the Hugging Face Inc. team at https://huggingface.co/distilroberta-base. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the Hugging Face Inc. team at . Modified with conversion to C-array format. pub const DISTILROBERTA_BASE: (&'static str, &'static str) = ( "distilroberta-base/merges", "https://cdn.huggingface.co/distilroberta-base-merges.txt", ); - /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at . Modified with conversion to C-array format. pub const ROBERTA_QA: (&'static str, &'static str) = ( "roberta-qa/merges", "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/merges.txt", diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 943d424..f5e75cc 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -50,12 +50,12 @@ pub struct T5SourceLanguages; pub type T5TargetLanguages = T5SourceLanguages; impl T5ModelResources { - /// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the T5 Authors at . Modified with conversion to C-array format. pub const T5_SMALL: (&'static str, &'static str) = ( "t5-small/model", "https://huggingface.co/t5-small/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the T5 Authors at . Modified with conversion to C-array format. pub const T5_BASE: (&'static str, &'static str) = ( "t5-base/model", "https://huggingface.co/t5-base/resolve/main/rust_model.ot", @@ -63,12 +63,12 @@ impl T5ModelResources { } impl T5ConfigResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer. + /// Shared under Apache 2.0 license by the Google team at . pub const T5_SMALL: (&'static str, &'static str) = ( "t5-small/config", "https://huggingface.co/t5-small/resolve/main/config.json", ); - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer. + /// Shared under Apache 2.0 license by the Google team at . pub const T5_BASE: (&'static str, &'static str) = ( "t5-base/config", "https://huggingface.co/t5-base/resolve/main/config.json", @@ -76,12 +76,12 @@ impl T5ConfigResources { } impl T5VocabResources { - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer. + /// Shared under Apache 2.0 license by the Google team at . pub const T5_SMALL: (&'static str, &'static str) = ( "t5-small/spiece", "https://huggingface.co/t5-small/resolve/main/spiece.model", ); - /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer. + /// Shared under Apache 2.0 license by the Google team at . pub const T5_BASE: (&'static str, &'static str) = ( "t5-base/spiece", "https://huggingface.co/t5-base/resolve/main/spiece.model", @@ -172,7 +172,7 @@ pub struct TranslationEnToRo { prefix: String, } -impl Config for T5Config {} +impl Config for T5Config {} /// # T5 Base model /// Base architecture for T5 model. Usually complemented with a task-specific head, such as a language model head. @@ -464,7 +464,7 @@ impl T5ForConditionalGeneration { /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. /// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided. /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided. - /// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. + /// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. /// /// # Returns @@ -572,7 +572,7 @@ impl LMHeadModel for T5ForConditionalGeneration { /// # Arguments /// /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. + /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 /// * `input_embeds` - Unused for T5 /// * `token_type_ids` - Unused for T5 diff --git a/src/xlnet/xlnet_model.rs b/src/xlnet/xlnet_model.rs index 5f76076..8fab78a 100644 --- a/src/xlnet/xlnet_model.rs +++ b/src/xlnet/xlnet_model.rs @@ -43,7 +43,7 @@ pub struct XLNetConfigResources; pub struct XLNetVocabResources; impl XLNetModelResources { - /// Shared under Apache 2.0 license by the XLNet Authors at https://github.com/zihangdai/xlnet. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the XLNet Authors at . Modified with conversion to C-array format. pub const XLNET_BASE_CASED: (&'static str, &'static str) = ( "xlnet-base-cased/model", "https://huggingface.co/xlnet-base-cased/resolve/main/rust_model.ot", @@ -51,7 +51,7 @@ impl XLNetModelResources { } impl XLNetConfigResources { - /// Shared under Apache 2.0 license by the XLNet Authors at https://github.com/zihangdai/xlnet. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the XLNet Authors at . Modified with conversion to C-array format. pub const XLNET_BASE_CASED: (&'static str, &'static str) = ( "xlnet-base-cased/config", "https://huggingface.co/xlnet-base-cased/resolve/main/config.json", @@ -59,7 +59,7 @@ impl XLNetConfigResources { } impl XLNetVocabResources { - /// Shared under Apache 2.0 license by the XLNet Authors at https://github.com/zihangdai/xlnet. Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the XLNet Authors at . Modified with conversion to C-array format. pub const XLNET_BASE_CASED: (&'static str, &'static str) = ( "xlnet-base-cased/spiece", "https://huggingface.co/xlnet-base-cased/resolve/main/spiece.model", @@ -116,7 +116,7 @@ pub struct XLNetConfig { pub chunk_size_feed_forward: Option, } -impl Config for XLNetConfig {} +impl Config for XLNetConfig {} /// # XLNet Base model /// Base architecture for XLNet models. Task-specific models will be built from this common base model diff --git a/tests/m2m100.rs b/tests/m2m100.rs new file mode 100644 index 0000000..60c2cc4 --- /dev/null +++ b/tests/m2m100.rs @@ -0,0 +1,117 @@ +use rust_bert::m2m_100::{ + M2M100Config, M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100Model, + M2M100ModelResources, M2M100VocabResources, +}; +use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; +use rust_bert::resources::{RemoteResource, Resource}; +use rust_bert::Config; +use rust_tokenizers::tokenizer::{M2M100Tokenizer, Tokenizer, TruncationStrategy}; +use tch::{nn, Device, Tensor}; + +#[test] +fn m2m100_lm_model() -> anyhow::Result<()> { + // Resources paths + let config_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100ConfigResources::M2M100_418M, + )); + let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100VocabResources::M2M100_418M, + )); + let merges_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100MergesResources::M2M100_418M, + )); + let weights_resource = Resource::Remote(RemoteResource::from_pretrained( + M2M100ModelResources::M2M100_418M, + )); + let config_path = config_resource.get_local_path()?; + let vocab_path = vocab_resource.get_local_path()?; + let merges_path = merges_resource.get_local_path()?; + let weights_path = weights_resource.get_local_path()?; + + // Set-up masked LM model + let device = Device::Cpu; + let mut vs = nn::VarStore::new(device); + let tokenizer = M2M100Tokenizer::from_files( + vocab_path.to_str().unwrap(), + merges_path.to_str().unwrap(), + false, + )?; + let config = M2M100Config::from_file(config_path); + let m2m100_model = M2M100Model::new(&vs.root() / "model", &config); + vs.load(weights_path)?; + + // Define input + let input = ["One two three four"]; + let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); + let max_len = tokenized_input + .iter() + .map(|input| input.token_ids.len()) + .max() + .unwrap(); + let tokenized_input = tokenized_input + .iter() + .map(|input| input.token_ids.clone()) + .map(|mut input| { + input.extend(vec![0; max_len - input.len()]); + input + }) + .map(|input| Tensor::of_slice(&(input))) + .collect::>(); + let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); + + // Forward pass + let model_output = + m2m100_model.forward_t(Some(&input_tensor), None, None, None, None, None, false); + assert_eq!(model_output.decoder_output.size(), vec!(1, 5, 1024)); + assert_eq!( + model_output.encoder_hidden_state.unwrap().size(), + vec!(1, 5, 1024) + ); + assert!( + (model_output.decoder_output.double_value(&[0, 0, 0]) - -2.047429323196411).abs() < 1e-4 + ); + Ok(()) +} + +#[test] +fn m2m100_translation() -> anyhow::Result<()> { + // Resources paths + let generate_config = GenerateConfig { + max_length: 56, + model_resource: Resource::Remote(RemoteResource::from_pretrained( + M2M100ModelResources::M2M100_418M, + )), + config_resource: Resource::Remote(RemoteResource::from_pretrained( + M2M100ConfigResources::M2M100_418M, + )), + vocab_resource: Resource::Remote(RemoteResource::from_pretrained( + M2M100VocabResources::M2M100_418M, + )), + merges_resource: Resource::Remote(RemoteResource::from_pretrained( + M2M100MergesResources::M2M100_418M, + )), + do_sample: false, + num_beams: 3, + ..Default::default() + }; + let model = M2M100Generator::new(generate_config)?; + + let input_context = ">>en.<< The dog did not wake up."; + let target_language = model.get_tokenizer().convert_tokens_to_ids([">>es.<<"])[0]; + + let output = model.generate( + Some(&[input_context]), + None, + None, + None, + None, + target_language, + None, + false, + ); + + assert_eq!(output.len(), 1); + assert_eq!(output[0].text, ">>es.<< El perro no se despertó."); + + Ok(()) +}