End-to-end OpenAI GPT model implemented

This commit is contained in:
Guillaume B 2020-03-01 11:26:34 +01:00
parent ad1dbfcbc2
commit d820ea8eea
6 changed files with 106 additions and 23 deletions
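The commit wires the original OpenAI GPT into rust-bert end to end: a dedicated post-norm transformer Block under openai_gpt, an OpenAIGPTLMHeadModel that adds a bias-free language-modeling head on top of OpenAiGptModel, an example updated to run next-word prediction, and a conversion-script change that duplicates the tied token-embedding weights under lm_head.weight.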

View File

@ -17,7 +17,7 @@ use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::common::config::Config;
-use rust_bert::openai_gpt::openai_gpt::OpenAiGptModel;
+use rust_bert::openai_gpt::openai_gpt::OpenAIGPTLMHeadModel;
use rust_bert::Gpt2Config;
@ -34,13 +34,13 @@ fn main() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
-let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
+let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let config = Gpt2Config::from_file(config_path);
-let gpt_model = OpenAiGptModel::new(&vs.root(), &config);
+let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One two three four five six seven eight nine ten eleven"];
let input = ["Wondering what the next word will"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
@ -53,22 +53,22 @@ fn main() -> failure::Fallible<()> {
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
-let _input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
+let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
-// let (output, _, _, _) = gpt2_model.forward_t(
-// &Some(input_tensor),
-// &None,
-// &None,
-// &None,
-// &None,
-// &None,
-// false).unwrap();
-//
-// let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
-// let next_word = tokenizer.decode(vec!(next_word_id), true, true);
-// println!("Provided input: {}", input[0]);
-// println!("Next word: {}", next_word);
+let (output, _, _) = openai_gpt.forward_t(
+&Some(input_tensor),
+&None,
+&None,
+&None,
+&None,
+false).unwrap();
+let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
+let next_word = tokenizer.decode(vec!(next_word_id), true, true);
+println!("Output: {:?}", output);
+println!("Provided input: {}", input[0]);
+println!("Next word: {}", next_word);
Ok(())
}
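The example stops after a single next-word prediction. As a rough sketch of how the same objects extend to greedy decoding (hypothetical code, not part of this commit; it naively re-runs the full growing sequence at each step, since this model keeps no layer_past cache):

// Greedy-decoding sketch; reuses `openai_gpt`, `tokenizer`, `tokenized_input`
// and `device` from the example above. The 5-token budget is arbitrary.
let mut generated = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
let mut new_ids: Vec<i64> = vec!();
for _ in 0..5 {
    let (logits, _, _) = openai_gpt.forward_t(&Some(generated.copy()), &None, &None, &None, &None, false).unwrap();
    // logits: (batch, sequence length, vocab size) -> argmax over the last position
    let next_id = logits.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
    new_ids.push(next_id);
    generated = Tensor::cat(&[generated, Tensor::of_slice(&[next_id]).unsqueeze(0).to(device)], 1);
}
println!("Greedy continuation: {}", tokenizer.decode(new_ids, true, true));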

View File

@ -72,7 +72,6 @@ impl Block {
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Tensor, Option<Tensor>) {
let (output, present, attentions) = self.attn.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
let x = x + output;
let m = self.mlp.forward_t(&x.apply(&self.ln_2), train);
let x = x + m;

View File

@ -1 +1,2 @@
pub mod openai_gpt;
+mod transformer;

View File

@ -15,11 +15,12 @@
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
-use crate::gpt2::transformer::Block;
use crate::Gpt2Config;
use tch::nn::embedding;
use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;
+use crate::common::linear::{LinearNoBias, linear_no_bias};
+use crate::openai_gpt::transformer::Block;
pub struct OpenAiGptModel {
@ -111,10 +112,10 @@ impl OpenAiGptModel {
hidden_states.push(hidden_state.as_ref().copy());
};
-let temp = layer.forward_t(&hidden_state, &None, &attention_mask, train);
+let temp = layer.forward_t(&hidden_state, &attention_mask, train);
hidden_state = temp.0;
if let Some(attentions) = all_attentions.borrow_mut() {
-attentions.push(temp.2.as_ref().unwrap().copy());
+attentions.push(temp.1.as_ref().unwrap().copy());
};
}
None => break
@ -123,4 +124,38 @@ impl OpenAiGptModel {
Ok((hidden_state, all_hidden_states, all_attentions))
}
}
pub struct OpenAIGPTLMHeadModel {
transformer: OpenAiGptModel,
lm_head: LinearNoBias,
}
impl OpenAIGPTLMHeadModel {
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAIGPTLMHeadModel {
let transformer = OpenAiGptModel::new(&p, config);
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default());
OpenAIGPTLMHeadModel { transformer, lm_head }
}
pub fn forward_t(&self,
input_ids: &Option<Tensor>,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output,
all_hidden_states,
all_attentions) = self.transformer.forward_t(input_ids,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train)?;
let lm_logits = output.apply(&self.lm_head);
Ok((lm_logits, all_hidden_states, all_attentions))
}
}
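A note on the head added here: lm_head is a bias-free projection from n_embd to vocab_size. The original OpenAI GPT ties this matrix to the token embedding, and (assuming linear_no_bias stores its weight as (out, in), as tch's nn::Linear does) both tensors have shape (vocab_size, n_embd), so the tied weights can be copied across without a transpose; the conversion-script change in the last file below relies on exactly this.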

View File

@ -0,0 +1,46 @@
// Copyright 2018-present, the HuggingFace Inc. team
// Copyright 2018-present, The OpenAI Team Authors
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::gpt2::attention::Attention;
use tch::{Tensor, nn};
use crate::gpt2::transformer::MLP;
use crate::Gpt2Config;
pub struct Block {
ln_1: nn::LayerNorm,
attn: Attention,
ln_2: nn::LayerNorm,
mlp: MLP,
}
impl Block {
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
let attn = Attention::new(&(p / "attn"), config, scale);
let mlp = MLP::new(&(p / "mlp"), config);
Block { ln_1, attn, ln_2, mlp }
}
pub fn forward_t(&self, x: &Tensor, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Option<Tensor>) {
let (output, _, attentions) = self.attn.forward_t(x, &None, attention_mask, train);
let x = (x + output).apply(&self.ln_1);
let m = self.mlp.forward_t(&x, train);
let x = (x + m).apply(&self.ln_2);
(x, attentions)
}
}
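The ordering above is why this module gets its own Block rather than reusing GPT-2's: GPT-2 is pre-norm while the original GPT is post-norm, roughly:

// GPT-2 (pre-norm):  x = x + attn(ln_1(x)); x = x + mlp(ln_2(x))
// GPT-1 (post-norm): x = ln_1(x + attn(x)); x = ln_2(x + mlp(x))

The new block also drops the layer_past argument and the present output, which is what drove the call-site changes in openai_gpt.rs (three-argument forward_t, attentions moving from index 2 to index 1 of the returned tuple).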

View File

@ -36,6 +36,8 @@ weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
    nps[k] = np.ascontiguousarray(v.cpu().numpy())
+    if k == 'tokens_embed.weight':
+        nps['lm_head.weight'] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)
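The two added lines implement the weight tying on the conversion side: the PyTorch checkpoint stores only tokens_embed.weight, but the Rust VarStore resolves variables by name and OpenAIGPTLMHeadModel registers an lm_head.weight entry, so the converter writes the same array into the archive under both names.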