Working example for M2M100 Translation

Guillaume B 2021-06-26 10:49:42 +02:00
parent c71df2be5c
commit 9a04d1527a
3 changed files with 83 additions and 21 deletions

View File

@@ -0,0 +1,62 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::m2m_100::{
M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100ModelResources,
M2M100VocabResources,
};
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> anyhow::Result<()> {
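// Beam search (3 beams, no sampling) with all four resources pointing at the remote pretrained M2M100 418M checkpoint.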
let generate_config = GenerateConfig {
max_length: 142,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
)),
do_sample: false,
early_stopping: true,
num_beams: 3,
..Default::default()
};
let model = M2M100Generator::new(generate_config)?;
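// M2M100 inputs carry a ">>xx.<<" language token: the source sentence is prefixed with >>en.<<, and the id of the German token >>de.<< is looked up below to select the target language.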
let input_context_1 = ">>en.<< The quick brown fox jumps over the lazy dog.";
let target_language = model.get_tokenizer().convert_tokens_to_ids([">>de.<<"])[0];
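// The positional `None`s leave the optional generation arguments at their defaults; `target_language` steers decoding towards German (the corresponding parameter name is not shown in this diff).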
let output = model.generate(
Some(&[input_context_1]),
None,
None,
None,
None,
target_language,
None,
);
for sentence in output {
println!("{:?}", sentence);
}
Ok(())
}
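Assuming the new file is added under examples/ (the diff viewer does not show the path, so translation_m2m100 below is a hypothetical name), it could be run with:

cargo run --example translation_m2m100

Each item printed by the loop is the decoded translation of the corresponding input sentence.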

View File

@@ -46,7 +46,7 @@ impl SinusoidalPositionalEmbedding {
));
embedding.write().unwrap().ws = SinusoidalPositionalEmbedding::build_positional_embeddings(
- num_embeddings,
+ num_embeddings + offset,
embedding_dim,
padding_idx,
device,
@@ -69,7 +69,7 @@ impl SinusoidalPositionalEmbedding {
) -> Tensor {
let half_dim = embedding_dim / 2;
- let emb = -(10000f64.log(2.0) as f64) / ((half_dim - 1) as f64);
+ let emb = -(10000f64.ln() as f64) / ((half_dim - 1) as f64);
let emb = (Tensor::arange(half_dim, (Kind::Float, device)) * emb).exp();
let emb =
Tensor::arange(num_embeddings, (Kind::Float, device)).unsqueeze(1) * emb.unsqueeze(0);
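Both hunks above are functional fixes rather than renames. Rust's `f64::log` takes the logarithm base as its argument, so `10000f64.log(2.0)` computed log2(10000) ≈ 13.29 where the sinusoidal-embedding formula calls for ln(10000) ≈ 9.21; `ln()` is the correct call. The first hunk also grows the embedding table to `num_embeddings + offset`, presumably so that position indices shifted by the padding offset stay within the table. A minimal standalone sketch of the corrected frequency term (hypothetical helper, no tch dependency):

fn inverse_frequencies(embedding_dim: i64) -> Vec<f64> {
    let half_dim = embedding_dim / 2;
    // Inverse frequencies 10000^(-i / (half_dim - 1)); `log(2.0)` here would
    // silently switch the base to 2 and distort every frequency.
    let emb = -10000f64.ln() / ((half_dim - 1) as f64);
    (0..half_dim).map(|i| (i as f64 * emb).exp()).collect()
}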

View File

@@ -114,7 +114,7 @@ impl M2M100Model {
///
/// # Arguments
///
- /// * `p` - Variable store path for the root of the MBart model
+ /// * `p` - Variable store path for the root of the M2M100 model
/// * `config` - `M2M100Config` object defining the model architecture
///
/// # Example
@@ -129,9 +129,9 @@ impl M2M100Model {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = M2M100Config::from_file(config_path);
- /// let mbart: M2M100Model = M2M100Model::new(&p.root() / "bart", &config);
+ /// let m2m100: M2M100Model = M2M100Model::new(&p.root() / "m2m100", &config);
/// ```
- pub fn new<'p, P>(p: P, config: &MBartConfig) -> M2M100Model
+ pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Model
where
P: Borrow<nn::Path<'p>>,
{
@@ -305,7 +305,7 @@ impl M2M100ForConditionalGeneration {
///
/// # Arguments
///
- /// * `p` - Variable store path for the root of the BART model
+ /// * `p` - Variable store path for the root of the M2M100 model
/// * `config` - `M2M100Config` object defining the model architecture
///
/// # Example
@@ -320,7 +320,7 @@ impl M2M100ForConditionalGeneration {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = M2M100Config::from_file(config_path);
- /// let bart: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&p.root(), &config);
+ /// let m2m100: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&p.root(), &config);
/// ```
pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100ForConditionalGeneration
where
@@ -360,13 +360,13 @@ impl M2M100ForConditionalGeneration {
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
- /// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
+ /// # use rust_bert::m2m_100::{M2M100Config, M2M100ForConditionalGeneration};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
- /// # let config = BartConfig::from_file(config_path);
- /// # let m2m100_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config);
+ /// # let config = M2M100Config::from_file(config_path);
+ /// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@@ -434,9 +434,9 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions value 1. If None, set to 1
- /// * `input_embeds` - Unused for BART
- /// * `token_type_ids` - Unused for BART
- /// * `position_ids` - Unused for BART
+ /// * `input_embeds` - Unused for M2M100
+ /// * `token_type_ids` - Unused for M2M100
+ /// * `position_ids` - Unused for M2M100
/// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@@ -445,7 +445,7 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
///
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
- /// - `cache` - `BartCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
+ /// - `cache` - `BARTCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
///
/// # Example
@@ -456,13 +456,13 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::pipelines::generation_utils::LMHeadModel;
- /// use rust_bert::bart::{BartForConditionalGeneration, BartConfig};
+ /// use rust_bert::m2m_100::{M2M100ForConditionalGeneration, M2M100Config};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
- /// # let config = BartConfig::from_file(config_path);
- /// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config);
+ /// # let config = M2M100Config::from_file(config_path);
+ /// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@@ -470,7 +470,7 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
- /// bart_model
+ /// m2m100_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
@@ -615,14 +615,14 @@ impl M2M100Generator {
generate_config.vocab_resource.clone()
};
- let merges_resource = if generate_config.vocab_resource
+ let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
))
} else {
- generate_config.vocab_resource.clone()
+ generate_config.merges_resource.clone()
};
let config_path = config_resource.get_local_path()?;
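The hunk above fixes a copy-paste slip in the resource-defaulting logic: the merges fallback previously tested `vocab_resource` against the default GPT-2 merges resource and then cloned `vocab_resource`, so a merges file supplied through `GenerateConfig` was silently replaced by the vocabulary resource. The fix inspects and clones `merges_resource`, mirroring the handling of the other resources.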
@@ -835,7 +835,7 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
},
Cache::None => {}
_ => {
panic!("Invalid cache for BART model");
panic!("Invalid cache for M2M100 model");
}
};
encoder_outputs
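A note on the panic-message fix: M2M100 reuses BART's decoder cache representation (the `Cache::BARTCache` variant also referenced in the updated doc comments above), so the cache type keeps the BART name while the error message now correctly identifies the model as M2M100.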