BART config & initial input preparation

This commit is contained in:
Guillaume B 2020-03-29 12:01:13 +02:00
parent 6f98655832
commit efd2ed0509
7 changed files with 313 additions and 1 deletions

View File

@ -29,7 +29,7 @@ doc-only = ["tch/doc-only"]
features = [ "doc-only" ]
[dependencies]
rust_tokenizers = "2.0.3"
rust_tokenizers = "2.0.4"
tch = "0.1.6"
serde_json = "1.0.45"
serde = {version = "1.0.104", features = ["derive"]}

113
examples/bart.rs Normal file
View File

@ -0,0 +1,113 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::RobertaTokenizer;
use failure::err_msg;
use rust_bert::bart::BartConfig;
use rust_bert::Config;
use tch::kind::Kind::{Float, Int64};
use rust_bert::bart::bart::shift_tokens_right;
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("bart-large-cnn");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let merges_path = &home.as_path().join("merges.txt");
let weights_path = &home.as_path().join("model.ot");
if !config_path.is_file() | !vocab_path.is_file() | !merges_path.is_file() | !weights_path.is_file() {
return Err(
err_msg("Could not find required resources to run example. \
Please run ../utils/download_dependencies_bart.py \
in a Python environment with dependencies listed in ../requirements.txt"));
}
// Set-up masked LM model
let device = Device::Cpu;
let _vs = nn::VarStore::new(device);
let _tokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = BartConfig::from_file(config_path);
// let bert_model = BertForMaskedLM::new(&vs.root(), &config);
// vs.load(weights_path)?;
//
//// Define input
// let input = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York.
// A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband.
// Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared \"I do\" five more times, sometimes only within two weeks of each other. \
// In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her \"first and only\" marriage. \
// Barrientos, now 39, is facing two criminal counts of \"offering a false instrument for filing in the first degree,\" referring to her false statements on the
// 2010 marriage license application, according to court documents.
// Prosecutors said the marriages were part of an immigration scam.
// On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further.
// After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective
// Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002.
// All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say.
// Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages.
// Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted.
// The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s \
// Investigation Division. Seven of the men are from so-called \"red-flagged\" countries, including Egypt, Turkey, Georgia, Pakistan and Mali.
// Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force.
// If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18."];
// let tokenized_input = tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
// let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
// let tokenized_input = tokenized_input.
// iter().
// map(|input| input.token_ids.clone()).
// map(|mut input| {
// input.extend(vec![0; max_len - input.len()]);
// input
// }).
// map(|input|
// Tensor::of_slice(&(input))).
// collect::<Vec<_>>();
// let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
let input_id = Tensor::randint1(2, 25, &[2, 10], (Float, device));
let _ = input_id.get(0).index_fill_(0, &Tensor::of_slice(&[7, 8, 9]).to_kind(Int64), 1);
let _ = input_id.get(1).index_fill_(0, &Tensor::of_slice(&[5, 6, 7, 8, 9]).to_kind(Int64), 1);
input_id.print();
let output = shift_tokens_right(&input_id, config.pad_token_id.unwrap());
output.print();
//// Forward pass
// let (output, _, _) = no_grad(|| {
// bert_model
// .forward_t(Some(input_tensor),
// None,
// None,
// None,
// None,
// &None,
// &None,
// false)
// });
//
//// Print masked tokens
// let index_1 = output.get(0).get(4).argmax(0, false);
// let index_2 = output.get(1).get(7).argmax(0, false);
// let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
// let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
//
// println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
// println!("{}", word_2);// Outputs "pear" : "It was a very nice and [pleasant] day"
Ok(())
}

95
src/bart/bart.rs Normal file
View File

@ -0,0 +1,95 @@
// Copyright 2020 The Facebook AI Research Team Authors
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::Config;
use tch::Tensor;
use tch::kind::Kind::Int64;
#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
pub enum Activation {
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu,
/// Rectified Linear Unit
relu,
/// Swish ([Ramachandran, 2017](https://arxiv.org/abs/1710.05941))
swish,
/// Gaussian Error Linear Unit - OpenAI version ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu_new,
/// Tanh
tanh,
}
#[derive(Debug, Serialize, Deserialize)]
/// # BART model configuration
/// Defines the BART model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct BartConfig {
pub num_labels: Option<i64>,
pub activation_function: Option<Activation>,
pub activation_dropout: f64,
pub attention_dropout: f64,
pub classif_dropout: f64,
pub d_model: i64,
pub decoder_attention_heads: i64,
pub decoder_ffn_dim: i64,
pub decoder_layerdrop: f64,
pub decoder_layers: i64,
pub decoder_start_token_id: Option<i64>,
pub do_sample: bool,
pub dropout: f64,
pub early_stopping: bool,
pub encoder_attention_heads: i64,
pub encoder_ffn_dim: i64,
pub encoder_layerdrop: f64,
pub encoder_layers: i64,
pub bos_token_id: Option<i64>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
pub init_std: f64,
pub is_decoder: Option<bool>,
pub is_encoder_decoder: Option<bool>,
pub length_penalty: f64,
pub max_length: i64,
pub max_position_embeddings: i64,
pub min_length: Option<i64>,
pub no_repeat_ngram_size: Option<i64>,
pub num_beams: i64,
pub num_hidden_layers: i64,
pub num_return_sequences: i64,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub output_past: Option<bool>,
pub pad_token_id: Option<i64>,
pub repetition_penalty: f64,
pub temperature: f64,
pub top_k: i64,
pub top_p: f64,
pub vocab_size: i64,
}
impl Config<BartConfig> for BartConfig {}
pub fn shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[-1], true, Int64) - 1;
let output = input_ids.empty_like().to_kind(Int64);
output
.select(1, 0)
.copy_(&input_ids.gather(1, &index_eos, true).squeeze());
output
.slice(1, 1, *output.size().last().unwrap(), 1)
.copy_(&input_ids.slice(1, 0, *output.size().last().unwrap() - 1, 1));
output
}

3
src/bart/mod.rs Normal file
View File

@ -0,0 +1,3 @@
pub mod bart;
pub use bart::BartConfig;

View File

@ -64,6 +64,7 @@ pub mod bert;
pub mod roberta;
pub mod openai_gpt;
pub mod gpt2;
pub mod bart;
mod common;
pub mod pipelines;

View File

@ -0,0 +1,50 @@
from transformers import BART_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.tokenization_bart import vocab_url, merges_url
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess
config_path = BART_PRETRAINED_CONFIG_ARCHIVE_MAP['bart-large']
vocab_path = vocab_url
merges_path = merges_url
weights_path = BART_PRETRAINED_MODEL_ARCHIVE_MAP['bart-large']
target_path = Path.home() / 'rustbert' / 'bart-large'
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_merges = get_from_cache(merges_path)
temp_weights = get_from_cache(weights_path)
os.makedirs(str(target_path), exist_ok=True)
config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.txt')
merges_path = str(target_path / 'merges.txt')
model_path = str(target_path / 'model.bin')
shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_merges, merges_path)
shutil.copy(temp_weights, model_path)
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
nps[k] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)
source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])

View File

@ -0,0 +1,50 @@
from transformers import BART_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.tokenization_bart import vocab_url, merges_url
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess
config_path = BART_PRETRAINED_CONFIG_ARCHIVE_MAP['bart-large-cnn']
vocab_path = vocab_url
merges_path = merges_url
weights_path = BART_PRETRAINED_MODEL_ARCHIVE_MAP['bart-large-cnn']
target_path = Path.home() / 'rustbert' / 'bart-large-cnn'
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_merges = get_from_cache(merges_path)
temp_weights = get_from_cache(weights_path)
os.makedirs(str(target_path), exist_ok=True)
config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.txt')
merges_path = str(target_path / 'merges.txt')
model_path = str(target_path / 'model.bin')
shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_merges, merges_path)
shutil.copy(temp_weights, model_path)
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
nps[k] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)
source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])