Updates to BART for Marian compatibility

Guillaume B 2020-05-24 17:00:04 +02:00
parent 2b498bb10a
commit cccad54194
7 changed files with 161 additions and 36 deletions

examples/translation.rs

@@ -0,0 +1,43 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
use rust_bert::pipelines::generation::{LanguageGenerator, GenerateConfig, MarianGenerator};
use rust_bert::resources::{Resource, LocalResource};
use std::path::PathBuf;
fn main() -> failure::Fallible<()> {
// Set up the Marian translation model (English to French), using local resources
let generate_config = GenerateConfig {
config_resource: Resource::Local(LocalResource { local_path: PathBuf::from("E:/Coding/cache/rustbert/marian-mt-en-fr/config.json")}),
model_resource: Resource::Local(LocalResource { local_path: PathBuf::from("E:/Coding/cache/rustbert/marian-mt-en-fr/model.ot")}),
vocab_resource: Resource::Local(LocalResource { local_path: PathBuf::from("E:/Coding/cache/rustbert/marian-mt-en-fr/vocab.json")}),
merges_resource: Resource::Local(LocalResource { local_path: PathBuf::from("E:/Coding/cache/rustbert/marian-mt-en-fr/spiece.model")}),
max_length: 512,
do_sample: false,
num_beams: 6,
temperature: 1.0,
num_return_sequences: 1,
..Default::default()
};
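// Build the Marian generator from the configuration (loads the weights referenced above)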
let mut model = MarianGenerator::new(generate_config)?;
let input_context = "The quick brown fox jumps over the lazy dog";
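// Translate the input with beam search (6 beams, sampling disabled, as configured above)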
let output = model.generate(Some(vec!(input_context)), None);
for sentence in output {
println!("{:?}", sentence);
}
Ok(())
}


@@ -110,10 +110,12 @@ pub struct BartConfig {
pub max_position_embeddings: i64,
pub min_length: Option<i64>,
pub no_repeat_ngram_size: Option<i64>,
pub normalize_embedding: Option<bool>,
pub num_hidden_layers: i64,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub output_past: Option<bool>,
pub static_position_embeddings: Option<bool>,
pub vocab_size: i64,
}


@@ -17,9 +17,9 @@ use crate::common::dropout::Dropout;
use crate::bart::BartConfig;
use crate::bart::bart::Activation;
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh};
-use crate::bart::embeddings::PositionalEmbedding;
use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;
+use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding};
pub struct DecoderLayer {
self_attention: SelfAttention,
@@ -124,9 +124,9 @@ impl DecoderLayer {
pub struct BartDecoder {
dropout: Dropout,
-layer_norm_embedding: nn::LayerNorm,
+layer_norm_embedding: Option<nn::LayerNorm>,
layers: Vec<DecoderLayer>,
-embed_positions: PositionalEmbedding,
+embed_positions: EmbeddingOption,
output_attentions: bool,
output_hidden_states: bool,
output_past: bool,
@@ -147,23 +147,41 @@ impl BartDecoder {
Some(value) => value,
None => false
};
+let normalize_embedding = match config.normalize_embedding {
+Some(value) => value,
+None => true
+};
+let static_position_embeddings = match config.static_position_embeddings {
+Some(value) => value,
+None => false
+};
let dropout = Dropout::new(config.dropout);
-let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
-let layer_norm_embedding = nn::layer_norm(&p / "layernorm_embedding",
-vec![config.d_model],
-layer_norm_config);
+let layer_norm_embedding = if normalize_embedding {
+let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
+Some(nn::layer_norm(&p / "layernorm_embedding",
+vec![config.d_model],
+layer_norm_config))
+} else {
+None
+};
let pad_token_id = match config.pad_token_id {
Some(value) => value,
None => 1
};
-let embed_positions = PositionalEmbedding::new(&p / "embed_positions",
-config.max_position_embeddings,
-config.d_model,
-pad_token_id);
+let embed_positions = if static_position_embeddings {
+EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions",
+config.max_position_embeddings,
+config.d_model))
+} else {
+EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions",
+config.max_position_embeddings,
+config.d_model,
+pad_token_id))
+};
let mut layers: Vec<DecoderLayer> = vec!();
let p_layers = &p / "layers";
@@ -213,8 +231,8 @@ impl BartDecoder {
};
let x: Tensor = input_ids.as_ref().apply(embeddings) + positions;
+let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
let x = x
-.apply(&self.layer_norm_embedding)
.apply_t(&self.dropout, train)
.transpose(0, 1);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };


@@ -16,14 +16,31 @@ use tch::nn::{EmbeddingConfig, embedding};
use tch::kind::Kind::Int64;
+/// # Abstraction that holds an embeddings configuration
+pub enum EmbeddingOption {
+/// PositionalEmbedding
+LearnedPositionalEmbedding(LearnedPositionalEmbedding),
+SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding),
+}
+impl EmbeddingOption {
+/// Interface method to forward() of the particular embeddings.
+pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
+match *self {
+Self::LearnedPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode),
+Self::SinusoidalPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode)
+}
+}
+}
#[derive(Debug)]
-pub struct PositionalEmbedding {
+pub struct LearnedPositionalEmbedding {
embedding: nn::Embedding,
padding_index: i64,
}
-impl PositionalEmbedding {
-pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64, padding_index: i64) -> PositionalEmbedding {
+impl LearnedPositionalEmbedding {
+pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64, padding_index: i64) -> LearnedPositionalEmbedding {
let embedding_config = EmbeddingConfig { padding_idx: padding_index, ..Default::default() };
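// Position ids are offset by the padding index, so reserve padding_index + 1 extra rows in the table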
let num_embeddings = num_embeddings + padding_index + 1;
@@ -31,7 +48,7 @@ impl PositionalEmbedding {
num_embeddings,
embedding_dim,
embedding_config);
-PositionalEmbedding { embedding, padding_index }
+LearnedPositionalEmbedding { embedding, padding_index }
}
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
@@ -50,3 +67,27 @@ impl PositionalEmbedding {
position_ids
}
}
#[derive(Debug)]
pub struct SinusoidalPositionalEmbedding {
embedding: nn::Embedding,
}
impl SinusoidalPositionalEmbedding {
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64) -> SinusoidalPositionalEmbedding {
let embedding: nn::Embedding = embedding(p,
num_embeddings,
embedding_dim,
Default::default());
SinusoidalPositionalEmbedding { embedding }
}
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
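// In generation mode only the last position is embedded; otherwise positions 0..sequence_length are used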
let positions = if generation_mode {
Tensor::full(&[1, 1], input.size()[1] - 1, (Int64, input.device()))
} else {
Tensor::arange(input.size()[1],(Int64, input.device()))
};
positions.apply(&self.embedding)
}
}


@@ -17,7 +17,7 @@ use crate::common::dropout::Dropout;
use crate::bart::BartConfig;
use crate::bart::bart::Activation;
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh};
-use crate::bart::embeddings::PositionalEmbedding;
+use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding};
use tch::kind::Kind::Bool;
use std::borrow::BorrowMut;
@@ -90,9 +90,9 @@ impl EncoderLayer {
pub struct BartEncoder {
dropout: Dropout,
-layer_norm_embedding: nn::LayerNorm,
+layer_norm_embedding: Option<nn::LayerNorm>,
layers: Vec<EncoderLayer>,
-embed_positions: PositionalEmbedding,
+embed_positions: EmbeddingOption,
output_attentions: bool,
output_hidden_states: bool,
}
@@ -107,22 +107,41 @@ impl BartEncoder {
Some(value) => value,
None => false
};
+let normalize_embedding = match config.normalize_embedding {
+Some(value) => value,
+None => true
+};
+let static_position_embeddings = match config.static_position_embeddings {
+Some(value) => value,
+None => false
+};
let dropout = Dropout::new(config.dropout);
-let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
-let layer_norm_embedding = nn::layer_norm(&p / "layernorm_embedding",
-vec![config.d_model],
-layer_norm_config);
+let layer_norm_embedding = if normalize_embedding {
+let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
+Some(nn::layer_norm(&p / "layernorm_embedding",
+vec![config.d_model],
+layer_norm_config))
+} else {
+None
+};
let pad_token_id = match config.pad_token_id {
Some(value) => value,
None => 1
};
-let embed_positions = PositionalEmbedding::new(&p / "embed_positions",
-config.max_position_embeddings,
-config.d_model,
-pad_token_id);
+let embed_positions = if static_position_embeddings {
+EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions",
+config.max_position_embeddings,
+config.d_model))
+} else {
+EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions",
+config.max_position_embeddings,
+config.d_model,
+pad_token_id))
+};
let mut layers: Vec<EncoderLayer> = vec!();
let p_layers = &p / "layers";
@@ -153,8 +172,8 @@ impl BartEncoder {
let x = input_ids.apply(embeddings);
let x: Tensor = x + &self.embed_positions.forward(input_ids, false);
+let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
let x = x
-.apply(&self.layer_norm_embedding)
.apply_t(&self.dropout, train)
.transpose(0, 1);


@@ -724,6 +724,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, MarianVocab, MarianTokenizer
}
}
impl LanguageGenerator<BartForConditionalGeneration, MarianVocab, MarianTokenizer> for MarianGenerator {}
mod private_generation_utils {
use rust_tokenizers::{Vocab, Tokenizer, TruncationStrategy};


@@ -1,7 +1,7 @@
from transformers import BART_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.tokenization_bart import vocab_url, merges_url
-from transformers.file_utils import get_from_cache
+from transformers.file_utils import get_from_cache, S3_BUCKET_PREFIX
from pathlib import Path
import shutil
import os
@@ -9,12 +9,13 @@ import numpy as np
import torch
import subprocess
-config_path = BART_PRETRAINED_CONFIG_ARCHIVE_MAP['bart-large']
-vocab_path = vocab_url
-merges_path = merges_url
-weights_path = BART_PRETRAINED_MODEL_ARCHIVE_MAP['bart-large']
+ROOT_PATH = S3_BUCKET_PREFIX + '/' + 'Helsinki-NLP/opus-mt-en-fr'
+config_path = ROOT_PATH + '/config.json'
+vocab_path = ROOT_PATH + '/vocab.json'
+merges_path = ROOT_PATH + '/source.spm'
+weights_path = ROOT_PATH + '/pytorch_model.bin'
-target_path = Path.home() / 'rustbert' / 'bart-large'
+target_path = Path.home() / 'rustbert' / 'marian-mt-en-fr'
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
@@ -24,8 +25,8 @@ temp_weights = get_from_cache(weights_path)
os.makedirs(str(target_path), exist_ok=True)
config_path = str(target_path / 'config.json')
-vocab_path = str(target_path / 'vocab.txt')
-merges_path = str(target_path / 'merges.txt')
+vocab_path = str(target_path / 'vocab.json')
+merges_path = str(target_path / 'spiece.model')
model_path = str(target_path / 'model.bin')
shutil.copy(temp_config, config_path)