diff --git a/examples/translation_m2m100.rs b/examples/translation_m2m100.rs
new file mode 100644
index 0000000..97ee4d0
--- /dev/null
+++ b/examples/translation_m2m100.rs
@@ -0,0 +1,62 @@
+// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+// Copyright 2019 Guillaume Becquin
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern crate anyhow;
+
+use rust_bert::m2m_100::{
+    M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100ModelResources,
+    M2M100VocabResources,
+};
+use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
+use rust_bert::resources::{RemoteResource, Resource};
+
+fn main() -> anyhow::Result<()> {
+    let generate_config = GenerateConfig {
+        max_length: 142,
+        model_resource: Resource::Remote(RemoteResource::from_pretrained(
+            M2M100ModelResources::M2M100_418M,
+        )),
+        config_resource: Resource::Remote(RemoteResource::from_pretrained(
+            M2M100ConfigResources::M2M100_418M,
+        )),
+        vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
+            M2M100VocabResources::M2M100_418M,
+        )),
+        merges_resource: Resource::Remote(RemoteResource::from_pretrained(
+            M2M100MergesResources::M2M100_418M,
+        )),
+        do_sample: false,
+        early_stopping: true,
+        num_beams: 3,
+        ..Default::default()
+    };
+
+    let model = M2M100Generator::new(generate_config)?;
+
+    let input_context_1 = ">>en.<< The quick brown fox jumps over the lazy dog.";
+    let target_language = model.get_tokenizer().convert_tokens_to_ids([">>de.<<"])[0];
+
+    let output = model.generate(
+        Some(&[input_context_1]),
+        None,
+        None,
+        None,
+        None,
+        target_language,
+        None,
+    );
+
+    for sentence in output {
+        println!("{:?}", sentence);
+    }
+    Ok(())
+}
diff --git a/src/m2m_100/embeddings.rs b/src/m2m_100/embeddings.rs
index 0700cc8..24b3988 100644
--- a/src/m2m_100/embeddings.rs
+++ b/src/m2m_100/embeddings.rs
@@ -46,7 +46,7 @@ impl SinusoidalPositionalEmbedding {
         ));
         embedding.write().unwrap().ws = SinusoidalPositionalEmbedding::build_positional_embeddings(
-            num_embeddings,
+            num_embeddings + offset,
             embedding_dim,
             padding_idx,
             device,
         );
@@ -69,7 +69,7 @@ impl SinusoidalPositionalEmbedding {
     ) -> Tensor {
         let half_dim = embedding_dim / 2;

-        let emb = -(10000f64.log(2.0) as f64) / ((half_dim - 1) as f64);
+        let emb = -(10000f64.ln() as f64) / ((half_dim - 1) as f64);
         let emb = (Tensor::arange(half_dim, (Kind::Float, device)) * emb).exp();
         let emb = Tensor::arange(num_embeddings, (Kind::Float, device)).unsqueeze(1)
             * emb.unsqueeze(0);
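Note on the `embeddings.rs` hunks above: Rust's `f64::log` takes the logarithm *base* as its argument, so the old `10000f64.log(2.0)` computed log2(10000), about 13.29, where the sinusoidal embedding formula calls for ln(10000), about 9.21; every frequency exponent was off by a constant factor of log2(e). The `num_embeddings + offset` change sizes the embedding table so positions shifted by the padding offset still have rows. A standalone sketch of the scalar math only (plain Rust, not the crate's code; the helper name `inv_freq` is made up for illustration):

```rust
/// Inverse frequencies used by sinusoidal position embeddings:
/// freq_i = exp(-ln(10000) * i / (half_dim - 1)).
fn inv_freq(embedding_dim: i64) -> Vec<f64> {
    let half_dim = embedding_dim / 2;
    // `ln()` is the natural log; `10000f64.log(2.0)` would instead be
    // log2(10000), which distorts every frequency.
    let emb = -(10000f64.ln()) / ((half_dim - 1) as f64);
    (0..half_dim).map(|i| (i as f64 * emb).exp()).collect()
}

fn main() {
    // First few inverse frequencies for a 16-dimensional embedding.
    println!("{:?}", &inv_freq(16)[..4]);
    // The buggy variant differed by a constant factor of log2(e) in the
    // exponent: log2(x) == ln(x) * log2(e).
    assert!((10000f64.log(2.0) - 10000f64.ln() * std::f64::consts::LOG2_E).abs() < 1e-9);
}
```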
diff --git a/src/m2m_100/m2m_100_model.rs b/src/m2m_100/m2m_100_model.rs
index 536d3b0..daa28f3 100644
--- a/src/m2m_100/m2m_100_model.rs
+++ b/src/m2m_100/m2m_100_model.rs
@@ -114,7 +114,7 @@ impl M2M100Model {
     ///
     /// # Arguments
     ///
-    /// * `p` - Variable store path for the root of the MBart model
+    /// * `p` - Variable store path for the root of the M2M100 model
     /// * `config` - `M2M100Config` object defining the model architecture
     ///
     /// # Example
@@ -129,9 +129,9 @@ impl M2M100Model {
     /// let device = Device::Cpu;
     /// let p = nn::VarStore::new(device);
     /// let config = M2M100Config::from_file(config_path);
-    /// let mbart: M2M100Model = M2M100Model::new(&p.root() / "bart", &config);
+    /// let m2m100: M2M100Model = M2M100Model::new(&p.root() / "m2m100", &config);
     /// ```
-    pub fn new<'p, P>(p: P, config: &MBartConfig) -> M2M100Model
+    pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Model
     where
         P: Borrow<nn::Path<'p>>,
     {
@@ -305,7 +305,7 @@ impl M2M100ForConditionalGeneration {
     ///
     /// # Arguments
     ///
-    /// * `p` - Variable store path for the root of the BART model
+    /// * `p` - Variable store path for the root of the M2M100 model
     /// * `config` - `M2M100Config` object defining the model architecture
     ///
     /// # Example
@@ -320,7 +320,7 @@ impl M2M100ForConditionalGeneration {
     /// let device = Device::Cpu;
     /// let p = nn::VarStore::new(device);
     /// let config = M2M100Config::from_file(config_path);
-    /// let bart: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&p.root(), &config);
+    /// let m2m100: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&p.root(), &config);
     /// ```
     pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100ForConditionalGeneration
     where
@@ -360,13 +360,13 @@ impl M2M100ForConditionalGeneration {
     /// # use rust_bert::Config;
     /// # use std::path::Path;
     /// # use tch::kind::Kind::{Int64, Double};
-    /// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
+    /// # use rust_bert::m2m_100::{M2M100Config, M2M100ForConditionalGeneration};
     /// # let config_path = Path::new("path/to/config.json");
     /// # let vocab_path = Path::new("path/to/vocab.txt");
     /// # let device = Device::Cpu;
     /// # let vs = nn::VarStore::new(device);
-    /// # let config = BartConfig::from_file(config_path);
-    /// # let m2m100_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config);
+    /// # let config = M2M100Config::from_file(config_path);
+    /// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config);
     /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
     /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
     /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@@ -434,9 +434,9 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
     /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
     /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
     /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked value 1. If None set to 1
-    /// * `input_embeds` - Unused for BART
-    /// * `token_type_ids` - Unused for BART
-    /// * `position_ids` - Unused for BART
+    /// * `input_embeds` - Unused for M2M100
+    /// * `token_type_ids` - Unused for M2M100
+    /// * `position_ids` - Unused for M2M100
     /// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
     /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
     /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
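The doc examples above construct the model from a config file and a `VarStore` path but stop short of loading pretrained weights. A minimal end-to-end sketch, assuming local files and the usual tch-rs `VarStore::load` flow; the file paths and the `"model"` prefix are placeholders and not taken from this diff (the prefix must match the variable names under which the weights were saved):

```rust
extern crate anyhow;

use rust_bert::m2m_100::{M2M100Config, M2M100Model};
use rust_bert::Config;
use std::path::Path;
use tch::{nn, Device};

fn main() -> anyhow::Result<()> {
    // Placeholder paths; a real setup would resolve these via RemoteResource.
    let config_path = Path::new("path/to/config.json");
    let weights_path = Path::new("path/to/rust_model.ot");

    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let config = M2M100Config::from_file(config_path);
    // Build the model first so its variables exist in the store, then load.
    let _m2m100 = M2M100Model::new(&vs.root() / "model", &config);
    vs.load(weights_path)?;
    Ok(())
}
```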
@@ -445,7 +445,7 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
     ///
     /// * `LMModelOutput` containing:
     ///   - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
-    ///   - `cache` - `BartCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
+    ///   - `cache` - `BARTCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
     ///     both the self attention and the encoder cross attention of each layer of the decoder.
     ///
     /// # Example
@@ -456,13 +456,13 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
     /// # use std::path::Path;
     /// # use tch::kind::Kind::{Int64, Double};
     /// use rust_bert::pipelines::generation_utils::LMHeadModel;
-    /// use rust_bert::bart::{BartForConditionalGeneration, BartConfig};
+    /// use rust_bert::m2m_100::{M2M100ForConditionalGeneration, M2M100Config};
     /// # let config_path = Path::new("path/to/config.json");
     /// # let vocab_path = Path::new("path/to/vocab.txt");
     /// # let device = Device::Cpu;
     /// # let vs = nn::VarStore::new(device);
-    /// # let config = BartConfig::from_file(config_path);
-    /// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config);
+    /// # let config = M2M100Config::from_file(config_path);
+    /// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config);
     /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
     /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
     /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@@ -470,7 +470,7 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
     /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
     ///
     /// let model_output = no_grad(|| {
-    ///     bart_model
+    ///     m2m100_model
     ///         .forward_t(Some(&input_tensor),
     ///             Some(&encoder_attention_mask),
     ///             None,
@@ -615,14 +615,14 @@ impl M2M100Generator {
             generate_config.vocab_resource.clone()
         };

-        let merges_resource = if generate_config.vocab_resource
+        let merges_resource = if generate_config.merges_resource
            == Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
        {
            Resource::Remote(RemoteResource::from_pretrained(
                M2M100MergesResources::M2M100_418M,
            ))
        } else {
-            generate_config.vocab_resource.clone()
+            generate_config.merges_resource.clone()
        };

        let config_path = config_resource.get_local_path()?;
@@ -835,7 +835,7 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M100Tokenizer>
                }
            }
            _ => {
-                panic!("Invalid cache for BART model");
+                panic!("Invalid cache for M2M100 model");
            }
        };
        encoder_outputs
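For reference, a hypothetical extension of `examples/translation_m2m100.rs` that reuses one `M2M100Generator` for several target languages. It assumes the `model` built exactly as in the example above and the same `generate` signature; the `>>fr.<<` and `>>es.<<` codes follow the `>>de.<<` pattern but are not taken from this diff:

```rust
// Translate one source sentence into several target languages by swapping
// the token id forced at the start of decoding.
let input_context = ">>en.<< The quick brown fox jumps over the lazy dog.";
for &language_code in &[">>de.<<", ">>fr.<<", ">>es.<<"] {
    // Look up the id of the language-code token, as in the example above.
    let target_language = model.get_tokenizer().convert_tokens_to_ids([language_code])[0];
    let output = model.generate(
        Some(&[input_context]),
        None,
        None,
        None,
        None,
        target_language,
        None,
    );
    for sentence in output {
        println!("{} {:?}", language_code, sentence);
    }
}
```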