Initial commit for BERT implementation

This commit is contained in:
Guillaume B 2020-02-16 14:29:35 +01:00
parent 7762259906
commit 85d17e2736
12 changed files with 148 additions and 15 deletions

26
examples/bert.rs Normal file
View File

@ -0,0 +1,26 @@
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn};
use rust_tokenizers::BertTokenizer;
use rust_bert::bert::bert::BertConfig;
use rust_bert::common::config::Config;
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("bert");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let _weights_path = &home.as_path().join("model.ot");
let device = Device::Cpu;
let _vs = nn::VarStore::new(device);
let _tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap());
let _config = BertConfig::from_file(config_path);
Ok(())
}

39
src/bert/bert.rs Normal file
View File

@ -0,0 +1,39 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use crate::common::config::Config;
#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
pub enum Activation {
gelu,
relu,
mish,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct BertConfig {
pub hidden_act: Activation,
pub attention_probs_dropout_prob: f64,
pub hidden_dropout_prob: f64,
pub hidden_size: i64,
pub initializer_range: f32,
pub intermediate_size: f32,
pub max_position_embeddings: i64,
pub num_attention_heads: i64,
pub num_hidden_layers: i64,
pub type_vocab_size: i64,
pub vocab_size: i64,
}
impl Config<BertConfig> for BertConfig {}

1
src/bert/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod bert;

26
src/common/config.rs Normal file
View File

@ -0,0 +1,26 @@
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use std::fs::File;
use std::io::BufReader;
use serde::Deserialize;
pub trait Config<T>
where for<'de> T: Deserialize<'de> {
fn from_file(path: &Path) -> T {
let f = File::open(path).expect("Could not open configuration file.");
let br = BufReader::new(f);
let config: T = serde_json::from_reader(br).expect("could not parse configuration");
config
}
}

1
src/common/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod config;

View File

@ -12,15 +12,13 @@
extern crate tch;
use std::path::Path;
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use serde::{Deserialize, Serialize};
use crate::distilbert::embeddings::BertEmbedding;
use crate::distilbert::transformer::Transformer;
use self::tch::{nn, Tensor};
use crate::distilbert::dropout::Dropout;
use crate::common::config::Config;
#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
@ -56,14 +54,7 @@ pub struct DistilBertConfig {
pub vocab_size: i64,
}
impl DistilBertConfig {
pub fn from_file(path: &Path) -> DistilBertConfig {
let f = File::open(path).expect("Could not open configuration file.");
let br = BufReader::new(f);
let config: DistilBertConfig = serde_json::from_reader(br).expect("could not parse configuration");
config
}
}
impl Config<DistilBertConfig> for DistilBertConfig {}
pub struct DistilBertModel {
embeddings: BertEmbedding,

View File

@ -4,6 +4,7 @@ use std::path::Path;
use tch::{Device, Tensor, Kind, no_grad};
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
use crate::common::config::Config;
#[derive(Debug, PartialEq)]

View File

@ -1,4 +1,8 @@
pub mod distilbert;
pub mod bert;
pub mod common;
pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM};
pub use distilbert::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
pub use distilbert::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
pub use bert::bert::BertConfig;

View File

@ -5,6 +5,7 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, Trunc
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use rust_bert::{SentimentClassifier, SentimentPolarity};
use rust_bert::common::config::Config;
extern crate failure;
extern crate dirs;
@ -47,7 +48,6 @@ fn sentiment_classifier() -> failure::Fallible<()> {
}
#[test]
fn distilbert_masked_lm() -> failure::Fallible<()> {

View File

@ -0,0 +1,44 @@
from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.tokenization_bert import PRETRAINED_VOCAB_FILES_MAP
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess
config_path = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP["bert-base-uncased"]
vocab_path = PRETRAINED_VOCAB_FILES_MAP["vocab_file"]["bert-base-uncased"]
weights_path = BERT_PRETRAINED_MODEL_ARCHIVE_MAP["bert-base-uncased"]
target_path = Path.home() / 'rustbert' / 'bert'
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)
os.makedirs(str(target_path), exist_ok=True)
config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.txt')
model_path = str(target_path / 'model.bin')
shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_weights, model_path)
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
nps[k] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)
source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])

View File

@ -31,7 +31,7 @@ shutil.copy(temp_weights, model_path)
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
nps[k] = v.cpu().numpy()
nps[k] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)

View File

@ -31,7 +31,7 @@ shutil.copy(temp_weights, model_path)
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
nps[k] = v.cpu().numpy()
nps[k] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)