mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Working example for M2M100 Translation
This commit is contained in:
parent
c71df2be5c
commit
9a04d1527a
62
examples/translation_m2m100.rs
Normal file
62
examples/translation_m2m100.rs
Normal file
@ -0,0 +1,62 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::m2m_100::{
|
||||
M2M100ConfigResources, M2M100Generator, M2M100MergesResources, M2M100ModelResources,
|
||||
M2M100VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: 142,
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100ModelResources::M2M100_418M,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100ConfigResources::M2M100_418M,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100VocabResources::M2M100_418M,
|
||||
)),
|
||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100MergesResources::M2M100_418M,
|
||||
)),
|
||||
do_sample: false,
|
||||
early_stopping: true,
|
||||
num_beams: 3,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let model = M2M100Generator::new(generate_config)?;
|
||||
|
||||
let input_context_1 = ">>en.<< The quick brown fox jumps over the lazy dog.";
|
||||
let target_language = model.get_tokenizer().convert_tokens_to_ids([">>de.<<"])[0];
|
||||
|
||||
let output = model.generate(
|
||||
Some(&[input_context_1]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
target_language,
|
||||
None,
|
||||
);
|
||||
|
||||
for sentence in output {
|
||||
println!("{:?}", sentence);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -46,7 +46,7 @@ impl SinusoidalPositionalEmbedding {
|
||||
));
|
||||
|
||||
embedding.write().unwrap().ws = SinusoidalPositionalEmbedding::build_positional_embeddings(
|
||||
num_embeddings,
|
||||
num_embeddings + offset,
|
||||
embedding_dim,
|
||||
padding_idx,
|
||||
device,
|
||||
@ -69,7 +69,7 @@ impl SinusoidalPositionalEmbedding {
|
||||
) -> Tensor {
|
||||
let half_dim = embedding_dim / 2;
|
||||
|
||||
let emb = -(10000f64.log(2.0) as f64) / ((half_dim - 1) as f64);
|
||||
let emb = -(10000f64.ln() as f64) / ((half_dim - 1) as f64);
|
||||
let emb = (Tensor::arange(half_dim, (Kind::Float, device)) * emb).exp();
|
||||
let emb =
|
||||
Tensor::arange(num_embeddings, (Kind::Float, device)).unsqueeze(1) * emb.unsqueeze(0);
|
||||
|
@ -114,7 +114,7 @@ impl M2M100Model {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the MBart model
|
||||
/// * `p` - Variable store path for the root of the M2M100 model
|
||||
/// * `config` - `M2M100Config` object defining the model architecture
|
||||
///
|
||||
/// # Example
|
||||
@ -129,9 +129,9 @@ impl M2M100Model {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = M2M100Config::from_file(config_path);
|
||||
/// let mbart: M2M100Model = M2M100Model::new(&p.root() / "bart", &config);
|
||||
/// let m2m100: M2M100Model = M2M100Model::new(&p.root() / "m2m100", &config);
|
||||
/// ```
|
||||
pub fn new<'p, P>(p: P, config: &MBartConfig) -> M2M100Model
|
||||
pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Model
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
@ -305,7 +305,7 @@ impl M2M100ForConditionalGeneration {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the BART model
|
||||
/// * `p` - Variable store path for the root of the M2M100 model
|
||||
/// * `config` - `M2M100Config` object defining the model architecture
|
||||
///
|
||||
/// # Example
|
||||
@ -320,7 +320,7 @@ impl M2M100ForConditionalGeneration {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = M2M100Config::from_file(config_path);
|
||||
/// let bart: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&p.root(), &config);
|
||||
/// let m2m100: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&p.root(), &config);
|
||||
/// ```
|
||||
pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100ForConditionalGeneration
|
||||
where
|
||||
@ -360,13 +360,13 @@ impl M2M100ForConditionalGeneration {
|
||||
/// # use rust_bert::Config;
|
||||
/// # use std::path::Path;
|
||||
/// # use tch::kind::Kind::{Int64, Double};
|
||||
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
||||
/// # use rust_bert::m2m_100::{M2M100Config, M2M100ForConditionalGeneration};
|
||||
/// # let config_path = Path::new("path/to/config.json");
|
||||
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||
/// # let device = Device::Cpu;
|
||||
/// # let vs = nn::VarStore::new(device);
|
||||
/// # let config = BartConfig::from_file(config_path);
|
||||
/// # let m2m100_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config);
|
||||
/// # let config = M2M100Config::from_file(config_path);
|
||||
/// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config);
|
||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||
@ -434,9 +434,9 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
|
||||
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
|
||||
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
|
||||
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
|
||||
/// * `input_embeds` - Unused for BART
|
||||
/// * `token_type_ids` - Unused for BART
|
||||
/// * `position_ids` - Unused for BART
|
||||
/// * `input_embeds` - Unused for M2M100
|
||||
/// * `token_type_ids` - Unused for M2M100
|
||||
/// * `position_ids` - Unused for M2M100
|
||||
/// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
|
||||
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
|
||||
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
|
||||
@ -445,7 +445,7 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
|
||||
///
|
||||
/// * `LMModelOutput` containing:
|
||||
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
|
||||
/// - `cache` - `BartCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
|
||||
/// - `cache` - `BARTCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
|
||||
/// both the self attention and the encoder cross attention of each layer of the decoder.
|
||||
///
|
||||
/// # Example
|
||||
@ -456,13 +456,13 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
|
||||
/// # use std::path::Path;
|
||||
/// # use tch::kind::Kind::{Int64, Double};
|
||||
/// use rust_bert::pipelines::generation_utils::LMHeadModel;
|
||||
/// use rust_bert::bart::{BartForConditionalGeneration, BartConfig};
|
||||
/// use rust_bert::m2m_100::{M2M100ForConditionalGeneration, M2M100Config};
|
||||
/// # let config_path = Path::new("path/to/config.json");
|
||||
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||
/// # let device = Device::Cpu;
|
||||
/// # let vs = nn::VarStore::new(device);
|
||||
/// # let config = BartConfig::from_file(config_path);
|
||||
/// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config);
|
||||
/// # let config = M2M100Config::from_file(config_path);
|
||||
/// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config);
|
||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||
@ -470,7 +470,7 @@ impl LMHeadModel for M2M100ForConditionalGeneration {
|
||||
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||
///
|
||||
/// let model_output = no_grad(|| {
|
||||
/// bart_model
|
||||
/// m2m100_model
|
||||
/// .forward_t(Some(&input_tensor),
|
||||
/// Some(&encoder_attention_mask),
|
||||
/// None,
|
||||
@ -615,14 +615,14 @@ impl M2M100Generator {
|
||||
generate_config.vocab_resource.clone()
|
||||
};
|
||||
|
||||
let merges_resource = if generate_config.vocab_resource
|
||||
let merges_resource = if generate_config.merges_resource
|
||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
|
||||
{
|
||||
Resource::Remote(RemoteResource::from_pretrained(
|
||||
M2M100MergesResources::M2M100_418M,
|
||||
))
|
||||
} else {
|
||||
generate_config.vocab_resource.clone()
|
||||
generate_config.merges_resource.clone()
|
||||
};
|
||||
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
@ -835,7 +835,7 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
|
||||
},
|
||||
Cache::None => {}
|
||||
_ => {
|
||||
panic!("Invalid cache for BART model");
|
||||
panic!("Invalid cache for M2M100 model");
|
||||
}
|
||||
};
|
||||
encoder_outputs
|
||||
|
Loading…
Reference in New Issue
Block a user