albert config implementation, download scripts

This commit is contained in:
Guillaume B 2020-06-16 19:07:37 +02:00
parent 7a86436f38
commit f4afd35ed6
6 changed files with 200 additions and 1 deletions

View File

@ -30,7 +30,7 @@ all-tests = []
features = [ "doc-only" ]
[dependencies]
rust_tokenizers = "~3.1.2"
rust_tokenizers = {version = "~3.1.3", path = "E:/Coding/backup-rust/rust-tokenizers/main"}
tch = "~0.1.7"
serde_json = "1.0.51"
serde = {version = "1.0.106", features = ["derive"]}

79
examples/albert.rs Normal file
View File

@ -0,0 +1,79 @@
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
use tch::Device;
use rust_tokenizers::AlbertTokenizer;
use rust_bert::Config;
use rust_bert::resources::{Resource, download_resource, LocalResource};
use rust_bert::albert::AlbertConfig;
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Local(LocalResource { local_path: "E:/Coding/cache/rustbert/albert-base-v2/config.json".parse().unwrap() });
let vocab_resource = Resource::Local(LocalResource { local_path: "E:/Coding/cache/rustbert/albert-base-v2/spiece.model".parse().unwrap() });
let weights_resource = Resource::Local(LocalResource { local_path: "E:/Coding/cache/rustbert/albert-base-v2/model.ot".parse().unwrap() });
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let _weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let _device = Device::Cpu;
// let mut vs = nn::VarStore::new(device);
let _tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let _config = AlbertConfig::from_file(config_path);
// let bert_model = BertForMaskedLM::new(&vs.root(), &config);
// vs.load(weights_path)?;
//
// // Define input
// let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
// let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
// let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
// let tokenized_input = tokenized_input.
// iter().
// map(|input| input.token_ids.clone()).
// map(|mut input| {
// input.extend(vec![0; max_len - input.len()]);
// input
// }).
// map(|input|
// Tensor::of_slice(&(input))).
// collect::<Vec<_>>();
// let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
//
// // Forward pass
// let (output, _, _) = no_grad(|| {
// bert_model
// .forward_t(Some(input_tensor),
// None,
// None,
// None,
// None,
// &None,
// &None,
// false)
// });
//
// // Print masked tokens
// let index_1 = output.get(0).get(4).argmax(0, false);
// let index_2 = output.get(1).get(7).argmax(0, false);
// let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
// let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
//
// println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
// println!("{}", word_2);// Outputs "pear" : "It was a very nice and [pleasant] day"
Ok(())
}

67
src/albert/albert.rs Normal file
View File

@ -0,0 +1,67 @@
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use crate::Config;
use serde::{Deserialize, Serialize};
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize)]
/// # Activation function used in the attention layer and masked language model head
pub enum Activation {
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu_new,
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu,
/// Rectified Linear Unit
relu,
/// Mish ([Misra, 2019](https://arxiv.org/abs/1908.08681))
mish,
}
#[derive(Debug, Serialize, Deserialize)]
/// # ALBERT model configuration
/// Defines the ALBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct AlbertConfig {
pub hidden_act: Activation,
pub attention_probs_dropout_prob: f64,
pub bos_token_id: i64,
pub eos_token_id: i64,
pub down_scale_factor: i64,
pub embedding_size: i64,
pub gap_size: i64,
pub hidden_dropout_prob: f64,
pub hidden_size: i64,
pub initializer_range: f32,
pub inner_group_num: i64,
pub intermediate_size: i64,
pub layer_norm_eps: f64,
pub max_position_embeddings: i64,
pub net_structure_type: i64,
pub num_attention_heads: i64,
pub num_hidden_groups: i64,
pub num_hidden_layers: i64,
pub num_memory_blocks: i64,
pub pad_token_id: i64,
pub type_vocab_size: i64,
pub vocab_size: i64,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub is_decoder: Option<bool>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
}
impl Config<AlbertConfig> for AlbertConfig {}

3
src/albert/mod.rs Normal file
View File

@ -0,0 +1,3 @@
mod albert;
pub use albert::{AlbertConfig};

View File

@ -65,6 +65,7 @@ pub mod gpt2;
pub mod bart;
pub mod electra;
pub mod marian;
pub mod albert;
mod common;
pub mod pipelines;

View File

@ -0,0 +1,49 @@
from transformers import ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.tokenization_albert import PRETRAINED_VOCAB_FILES_MAP
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess
config_path = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP["albert-base-v2"]
vocab_path = PRETRAINED_VOCAB_FILES_MAP["vocab_file"]["albert-base-v2"]
weights_path = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP["albert-base-v2"]
target_path = Path.home() / 'rustbert' / 'albert-base-v2'
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)
os.makedirs(str(target_path), exist_ok=True)
config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'spiece.model')
model_path = str(target_path / 'model.bin')
shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_weights, model_path)
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
nps[k] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)
source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))