Initial commit for GPT model

Guillaume B 2020-03-01 10:59:48 +01:00
parent b49c194ea4
commit ad1dbfcbc2
10 changed files with 276 additions and 5 deletions

examples/openai_gpt.rs Normal file

@@ -0,0 +1,74 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::common::config::Config;
use rust_bert::openai_gpt::openai_gpt::OpenAiGptModel;
use rust_bert::Gpt2Config;
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("openai-gpt");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let merges_path = &home.as_path().join("merges.txt");
let weights_path = &home.as_path().join("model.ot");
//    Set-up GPT model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = Gpt2Config::from_file(config_path);
let gpt_model = OpenAiGptModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One two three four five six seven eight nine ten eleven"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            // Right-pad each sequence with 0s up to the longest sequence in the batch
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&input))
        .collect::<Vec<_>>();
let _input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
//    let (output, _, _) = gpt_model.forward_t(
//        &Some(input_tensor),
//        &None,
//        &None,
//        &None,
//        &None,
//        false).unwrap();
//
// let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
// let next_word = tokenizer.decode(vec!(next_word_id), true, true);
// println!("Provided input: {}", input[0]);
// println!("Next word: {}", next_word);
Ok(())
}
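
As an aside, the padding step in the example right-pads every token sequence to the batch maximum before stacking the tensors. A standalone sketch of the same transformation (hypothetical pad_to_max helper, plain Vec<i64>, no tch dependency):

fn pad_to_max(sequences: Vec<Vec<i64>>, pad_id: i64) -> Vec<Vec<i64>> {
    // Right-pad every sequence to the length of the longest one
    let max_len = sequences.iter().map(|s| s.len()).max().unwrap_or(0);
    sequences
        .into_iter()
        .map(|mut s| {
            s.resize(max_len, pad_id);
            s
        })
        .collect()
}

fn main() {
    let padded = pad_to_max(vec![vec![1, 2, 3], vec![4]], 0);
    assert_eq!(padded, vec![vec![1, 2, 3], vec![4, 0, 0]]);
}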

@@ -5,6 +5,8 @@ pub fn _gelu(x: &Tensor) -> Tensor { x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())
pub fn _relu(x: &Tensor) -> Tensor { x.relu() }
pub fn _swish(x: &Tensor) -> Tensor {x * x.sigmoid()}
pub fn _mish(x: &Tensor) -> Tensor { x * (x.softplus().tanh()) }
pub fn _gelu_new(x: &Tensor) -> Tensor { x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1) }
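
For intuition, here are scalar versions of these activations (standalone f64 re-implementations for illustration only; the crate's functions operate on tch Tensors):

use std::f64::consts::PI;

// gelu_new: tanh approximation of GELU used by GPT-style models
fn gelu_new(x: f64) -> f64 {
    x * 0.5 * (((x.powi(3) * 0.044715 + x) * (2.0 / PI).sqrt()).tanh() + 1.0)
}

// swish: x * sigmoid(x)
fn swish(x: f64) -> f64 {
    x / (1.0 + (-x).exp())
}

// mish: x * tanh(softplus(x)), with softplus(x) = ln(1 + e^x)
fn mish(x: f64) -> f64 {
    x * (1.0 + x.exp()).ln().tanh()
}

fn main() {
    for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
        println!(
            "x={:+.1}  gelu_new={:+.4}  swish={:+.4}  mish={:+.4}",
            x, gelu_new(x), swish(x), mish(x)
        );
    }
}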

@@ -22,11 +22,20 @@ use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;
use crate::common::linear::{LinearNoBias, linear_no_bias};
#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
pub enum GptActivation {
gelu,
relu,
swish,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Gpt2Config {
pub attn_pdrop: Option<f64>,
pub embd_pdrop: Option<f64>,
pub hidden_dropout_prob: Option<f64>,
pub afn: Option<GptActivation>,
pub initializer_range: f64,
pub layer_norm_epsilon: f64,
pub n_ctx: i64,

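The Option-typed fields exist because GPT and GPT-2 config.json files do not expose the same keys; missing values are resolved to defaults when the model is built. A minimal serde sketch of the pattern (hypothetical MiniConfig type, not the crate's, assuming serde with the derive feature plus serde_json):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct MiniConfig {
    // Optional: absent from some checkpoints, defaulted at model-build time
    attn_pdrop: Option<f64>,
    afn: Option<String>,
    // Required: present in every checkpoint
    n_ctx: i64,
}

fn main() {
    let json = r#"{ "n_ctx": 512 }"#;
    let config: MiniConfig = serde_json::from_str(json).unwrap();
    assert_eq!(config.attn_pdrop.unwrap_or(0.1), 0.1); // mirrors the 0.1 fallback used in the model code
    assert_eq!(config.n_ctx, 512);
}
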
@@ -1,3 +1,3 @@
 pub mod gpt2;
-pub mod attention;
-pub mod transformer;
+pub(crate) mod attention;
+pub(crate) mod transformer;

@@ -15,8 +15,8 @@
 use crate::gpt2::attention::{GPTConv1D, Attention};
 use tch::{Tensor, nn};
 use crate::common::dropout::Dropout;
-use crate::gpt2::gpt2::Gpt2Config;
-use crate::common::activations::_gelu_new;
+use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
+use crate::common::activations::{_gelu_new, _relu, _swish};
 
 pub struct MLP {
     c_fc: GPTConv1D,
@@ -29,7 +29,14 @@ impl MLP {
     pub fn new(p: &nn::Path, config: &Gpt2Config) -> MLP {
         let c_fc = GPTConv1D::new(&(p / "c_fc"), config.n_embd * 4, config.n_embd);
         let c_proj = GPTConv1D::new(&(p / "c_proj"), config.n_embd, config.n_embd * 4);
-        let activation = Box::new(_gelu_new);
+        let activation = Box::new(match &config.afn {
+            Some(activation_enum) => match activation_enum {
+                GptActivation::gelu => _gelu_new,
+                GptActivation::relu => _relu,
+                GptActivation::swish => _swish,
+            },
+            None => _gelu_new
+        });
         let resid_pdrop = match config.resid_pdrop {
             Some(value) => value,
             None => 0.1

@@ -1,6 +1,7 @@
 pub mod distilbert;
 pub mod bert;
 pub mod roberta;
+pub mod openai_gpt;
 pub mod gpt2;
 pub mod common;
 pub mod pipelines;
@@ -12,5 +13,7 @@ pub use bert::bert::{BertModel, BertForSequenceClassification, BertForMaskedLM,
pub use roberta::roberta::{RobertaForSequenceClassification, RobertaForMaskedLM, RobertaForQuestionAnswering, RobertaForTokenClassification, RobertaForMultipleChoice};
pub use gpt2::gpt2::{Gpt2Config, Gpt2Model, GPT2LMHeadModel};
pub use pipelines::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
pub use pipelines::ner::{Entity, NERModel};

src/openai_gpt/mod.rs Normal file

@@ -0,0 +1 @@
pub mod openai_gpt;

@@ -0,0 +1,126 @@
// Copyright 2018-present, the HuggingFace Inc. team
// Copyright 2018-present, The OpenAI Team Authors
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::gpt2::transformer::Block;
use crate::Gpt2Config;
use tch::nn::embedding;
use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;
pub struct OpenAiGptModel {
tokens_embed: nn::Embedding,
positions_embed: nn::Embedding,
drop: Dropout,
h: Vec<Block>,
output_hidden_states: bool,
output_attentions: bool,
}
impl OpenAiGptModel {
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAiGptModel {
let tokens_embed = embedding(&(p / "tokens_embed"), config.vocab_size, config.n_embd, Default::default());
let positions_embed = embedding(&(p / "positions_embed"), config.n_positions, config.n_embd, Default::default());
let embd_pdrop = match config.embd_pdrop {
Some(value) => value,
None => 0.1
};
let drop = Dropout::new(embd_pdrop);
let mut h: Vec<Block> = vec!();
let h_path = &(p / "h");
for layer_index in 0..config.n_layer {
h.push(Block::new(&(h_path / layer_index), config, true));
};
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
};
OpenAiGptModel { tokens_embed, positions_embed, drop, h, output_hidden_states, output_attentions }
}
pub fn forward_t(&self,
input_ids: &Option<Tensor>,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply(&self.tokens_embed), *input_value.size().last().unwrap())
}
None => match input_embeds {
Some(embeds) => (embeds.copy(), embeds.size()[1]),
None => { return Err("At least one of input ids or input embeddings must be set"); }
}
};
        // Default position ids: 0..seq_length, broadcast over the batch
        let position_ids = match position_ids {
Some(value) => value.copy(),
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0)
};
        // Convert the {0, 1} padding mask into an additive bias: 1 -> 0.0, 0 -> -10000.0
        let attention_mask: Option<Tensor> = match attention_mask {
Some(value) => {
Some(
(value
.view((input_embeddings.size()[0], -1))
.unsqueeze(1)
.unsqueeze(2)
- 1.0
) * 10000.0)
}
None => None
};
let position_embeds = position_ids.apply(&self.positions_embed);
let token_type_embeds = match token_type_ids {
Some(value) => value.apply(&self.tokens_embed),
None => Tensor::zeros_like(&position_embeds)
};
        // Sum token, position and token-type embeddings, then apply dropout to form the first hidden state
        let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
        // Run the transformer blocks in order, optionally collecting hidden states and attentions
        for layer in &self.h {
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(hidden_state.as_ref().copy());
            };
            let temp = layer.forward_t(&hidden_state, &None, &attention_mask, train);
            hidden_state = temp.0;
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(temp.2.as_ref().unwrap().copy());
            };
        }
Ok((hidden_state, all_hidden_states, all_attentions))
}
}
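
The attention-mask arithmetic in forward_t maps a {0, 1} padding mask to an additive bias (1 becomes 0.0 for visible tokens, 0 becomes -10000.0 for padding), which drives the masked positions' softmax weights to roughly zero. A scalar sketch of the mapping (plain f64, for illustration):

fn additive_mask(mask: &[f64]) -> Vec<f64> {
    // (m - 1.0) * 10000.0: 1.0 -> 0.0 (keep), 0.0 -> -10000.0 (suppress in the softmax)
    mask.iter().map(|&m| (m - 1.0) * 10000.0).collect()
}

fn main() {
    assert_eq!(additive_mask(&[1.0, 1.0, 0.0]), vec![0.0, 0.0, -10000.0]);
}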

@@ -56,6 +56,7 @@ fn gpt2_lm_model() -> failure::Fallible<()> {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 11, 64));
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4);
assert_eq!(next_word_id, 14104i64);
assert_eq!(next_word, String::from(" twelve"));
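
The test's next_word_id comes from a greedy argmax over the last position's logits; in plain Rust terms the selection amounts to the following (hypothetical logit values, only -48.7065 is taken from the test above):

fn argmax(logits: &[f64]) -> usize {
    // Index of the largest logit = id of the greedily chosen next token
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

fn main() {
    let logits = [-50.3, -48.7065, -52.1];
    assert_eq!(argmax(&logits), 1);
}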

@@ -0,0 +1,48 @@
from transformers import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.tokenization_openai import PRETRAINED_VOCAB_FILES_MAP
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess
# Resolve the remote URLs for the pretrained openai-gpt config, vocab, merges and weights
config_path = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP["openai-gpt"]
vocab_path = PRETRAINED_VOCAB_FILES_MAP["vocab_file"]["openai-gpt"]
merges_path = PRETRAINED_VOCAB_FILES_MAP["merges_file"]["openai-gpt"]
weights_path = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP["openai-gpt"]
target_path = Path.home() / 'rustbert' / 'openai-gpt'
# Fetch each file into the local transformers cache
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_merges = get_from_cache(merges_path)
temp_weights = get_from_cache(weights_path)
os.makedirs(str(target_path), exist_ok=True)
config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.txt')
merges_path = str(target_path / 'merges.txt')
model_path = str(target_path / 'model.bin')
shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_merges, merges_path)
shutil.copy(temp_weights, model_path)
# Re-save the PyTorch weights as contiguous numpy arrays in an .npz archive
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
nps[k] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)
source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
# Convert the .npz archive to the .ot format expected by tch via the convert-tensor binary
subprocess.call(
    ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
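
The final step shells out to the repository's convert-tensor binary. A minimal sketch of what such an .npz-to-.ot converter looks like with tch (illustrative, not necessarily the crate's exact source; assumes the failure and tch crates as used elsewhere in this commit):

use tch::Tensor;

fn main() -> failure::Fallible<()> {
    // usage: convert-tensor source.npz target.ot
    let args: Vec<String> = std::env::args().collect();
    let (source, target) = (&args[1], &args[2]);
    // Read all named arrays from the .npz archive...
    let tensors = Tensor::read_npz(source)?;
    // ...and write them back out in the multi-tensor format tch loads via VarStore::load
    Tensor::save_multi(&tensors, target)?;
    Ok(())
}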