Mirror of https://github.com/guillaume-be/rust-bert.git (synced 2024-10-26 14:07:25 +03:00)
GPT2 Model implemented
Parent: c12c0c479e
Commit: a276e65cf4
@@ -18,8 +18,6 @@ use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer};
use rust_bert::gpt2::gpt2::{Gpt2Config, Gpt2Model};
use rust_bert::common::config::Config;
use tch::kind::Kind::Float;
use rust_bert::gpt2::transformer::Block;

fn main() -> failure::Fallible<()> {

@@ -30,15 +28,15 @@ fn main() -> failure::Fallible<()> {
    let config_path = &home.as_path().join("config.json");
    let vocab_path = &home.as_path().join("vocab.txt");
    let merges_path = &home.as_path().join("merges.txt");
    let _weights_path = &home.as_path().join("model.ot");
    let weights_path = &home.as_path().join("model.ot");

    // Set-up masked LM model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let mut vs = nn::VarStore::new(device);
    let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
    let config = Gpt2Config::from_file(config_path);
    let _gpt2_model = Gpt2Model::new(&vs.root(), &config);
    // vs.load(weights_path)?;
    vs.load(weights_path)?;

    // Define input
    let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];

@@ -54,14 +52,20 @@ fn main() -> failure::Fallible<()> {
        map(|input|
            Tensor::of_slice(&(input))).
        collect::<Vec<_>>();
    let _input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let gpt2_model = Gpt2Model::new(&vs.root(), &config);
    let _input_tensor = Tensor::ones(&[32, 56, 768], (Float, vs.device()));

    // let output = attention.forward_t(&_input_tensor, &None, &None, false);
    // println!("{:?}", output);
    let output = gpt2_model.forward_t(
        &Some(input_tensor),
        &None,
        &None,
        &None,
        &None,
        &None,
        false);
    println!("{:?}", output);

    Ok(())
}
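The example above stacks one tensor per sentence with Tensor::stack, which only works when every id sequence has the same length; the elided tokenization lines presumably pad shorter sequences up to the longest one first. A minimal, standalone sketch of that padding step in plain Rust (the pad id of 0 and the sample sequences are assumptions for illustration, not taken from the commit):

fn pad_to_max(mut sequences: Vec<Vec<i64>>, pad_id: i64) -> Vec<Vec<i64>> {
    // extend every sequence with pad_id until it matches the longest one,
    // so the per-sentence tensors built from them can be stacked along dim 0
    let max_len = sequences.iter().map(|s| s.len()).max().unwrap_or(0);
    for seq in sequences.iter_mut() {
        seq.resize(max_len, pad_id);
    }
    sequences
}

fn main() {
    let tokenized = vec![vec![5, 6, 7], vec![8, 9]];
    assert_eq!(pad_to_max(tokenized, 0), vec![vec![5, 6, 7], vec![8, 9, 0]]);
}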
@@ -19,6 +19,7 @@ use crate::common::dropout::Dropout;
use tch::nn::embedding;
use crate::gpt2::transformer::Block;
use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;

#[derive(Debug, Serialize, Deserialize)]
pub struct Gpt2Config {

@@ -33,6 +34,7 @@ pub struct Gpt2Config {
    pub n_layer: i64,
    pub n_positions: i64,
    pub num_labels: Option<i64>,
    pub output_past: Option<bool>,
    pub output_attentions: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub resid_pdrop: Option<f64>,

@@ -47,10 +49,14 @@ pub struct Gpt2Model {
    drop: Dropout,
    ln_f: nn::LayerNorm,
    h: Vec<Block>,
    output_past: bool,
    output_hidden_states: bool,
    output_attentions: bool,
}

impl Gpt2Model {
    pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model {
        let p = &(p / "transformer");
        let wte = embedding(&(p / "wte"), config.vocab_size, config.n_embd, Default::default());
        let wpe = embedding(&(p / "wpe"), config.n_positions, config.n_embd, Default::default());
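The new output_past, output_hidden_states and output_attentions fields switch on the per-layer collections that forward_t fills further down. The pattern used there, an Option<Vec<_>> that only accumulates when the flag was set, is sketched below in plain Rust (as_mut() plays the role of the borrow_mut() call imported above; flag and values are illustrative):

fn main() {
    let output_hidden_states = true; // stand-in for the flag read from Gpt2Config
    let mut all_hidden_states: Option<Vec<i64>> = if output_hidden_states { Some(vec![]) } else { None };
    for layer_output in vec![1_i64, 2, 3] {
        // push only when collection was requested; when the Option is None nothing is stored
        if let Some(values) = all_hidden_states.as_mut() {
            values.push(layer_output);
        }
    }
    assert_eq!(all_hidden_states, Some(vec![1, 2, 3]));
}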
@@ -60,14 +66,25 @@ impl Gpt2Model {
        };
        let drop = Dropout::new(embd_pdrop);
        let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
        let ln_f = nn::layer_norm(p / "ln_f ", vec![config.n_embd], layer_norm_config);
        let ln_f = nn::layer_norm(p / "ln_f", vec![config.n_embd], layer_norm_config);
        let mut h: Vec<Block> = vec!();
        let h_path = &(p / "h");
        for layer_index in 0..config.n_layer {
            h.push(Block::new(&(h_path / layer_index), config, true));
        };

        Gpt2Model { wte, wpe, drop, ln_f, h }
        let output_attentions = match config.output_attentions {
            Some(value) => value,
            None => false
        };
        let output_past = match config.output_past {
            Some(value) => value,
            None => true
        };
        let output_hidden_states = match config.output_hidden_states {
            Some(value) => value,
            None => false
        };
        Gpt2Model { wte, wpe, drop, ln_f, h, output_past, output_hidden_states, output_attentions }
    }

    pub fn forward_t(&self,
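Each of the match blocks above only substitutes a default when the corresponding field is absent from config.json; Option::unwrap_or expresses the same logic more compactly. A standalone sketch of the equivalence (the field values are made up for illustration):

fn main() {
    // stand-ins for config.output_past and config.output_attentions
    let output_past_field: Option<bool> = None;
    let output_attentions_field: Option<bool> = Some(true);

    // same effect as the match { Some(value) => value, None => default } blocks above
    let output_past = output_past_field.unwrap_or(true);
    let output_attentions = output_attentions_field.unwrap_or(false);

    assert!(output_past);
    assert!(output_attentions);
}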
@@ -77,7 +94,7 @@ impl Gpt2Model {
                     token_type_ids: &Option<Tensor>,
                     position_ids: &Option<Tensor>,
                     input_embeds: &Option<Tensor>,
                     train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
                     train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (input_embeddings, seq_length) = match input_ids {
            Some(input_value) => match input_embeds {
                Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
@@ -92,23 +109,23 @@ impl Gpt2Model {
        let (layer_past, layer_past_length) = match layer_past {
            Some(value) => {
                assert_eq!(value.len(), self.h.len(), "Past activations vector must be of length equal to the number of layers");
                (&value.iter().map(|&v| Some(v)).collect::<Vec<Option<Tensor>>>(), value[0].size()[3])
                (value.iter().map(|v| Some(v.copy())).collect::<Vec<Option<Tensor>>>(), value[0].size()[3])
            }
            None => {
                let mut out = Vec::with_capacity(self.h.len());
                out.resize_with(self.h.len(), || None::<Tensor>);
                (&out, 0)
                (out, 0)
            }
        };

        let position_ids = match position_ids {
            Some(value) => value,
            None => &Tensor::arange1(layer_past_length, seq_length + layer_past_length, (Int64, input_embeddings.device())).unsqueeze(0)
            Some(value) => value.copy(),
            None => Tensor::arange1(layer_past_length, seq_length + layer_past_length, (Int64, input_embeddings.device())).unsqueeze(0)
        };

        let attention_mask: &Option<Tensor> = match attention_mask {
        let attention_mask: Option<Tensor> = match attention_mask {
            Some(value) => {
                &Some(
                Some(
                    (value
                        .view((input_embeddings.size()[0], -1))
                        .unsqueeze(1)
@@ -116,7 +133,7 @@ impl Gpt2Model {
                        - 1.0
                    ) * 10000.0)
            }
            None => &None
            None => None
        };

        let position_embeds = position_ids.apply(&self.wpe);
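The arithmetic above turns a 1/0 padding mask into an additive mask: positions holding real tokens map to 0.0 and padded positions map to -10000.0, which drives their attention weights to roughly zero after the softmax. The same computation on a plain slice, as an illustrative sketch:

fn main() {
    // 1.0 marks a real token, 0.0 marks padding, as in a typical attention mask
    let mask = [1.0_f64, 1.0, 0.0];
    // mirrors the (mask - 1.0) * 10000.0 expression in the diff above
    let additive: Vec<f64> = mask.iter().map(|m| (m - 1.0) * 10000.0).collect();
    assert_eq!(additive, vec![0.0, 0.0, -10000.0]);
}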
@@ -124,10 +141,34 @@ impl Gpt2Model {
            Some(value) => value.apply(&self.wte),
            None => Tensor::zeros_like(&position_embeds)
        };
        let hidden_states: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
        let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);

        let mut all_presents: Option<Vec<Tensor>> = if self.output_past { Some(vec!()) } else { None };
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };

        let mut layer_iter = self.h.iter().zip(layer_past);
        loop {
            match layer_iter.next() {
                Some(layer_values) => {
                    let (layer, past) = layer_values;
                    if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                        hidden_states.push(hidden_state.as_ref().copy());
                    };

                    let temp = layer.forward_t(&hidden_state, &past, &attention_mask, train);
                    hidden_state = temp.0;
                    if let Some(presents) = all_presents.borrow_mut() {
                        presents.push(temp.1.as_ref().copy());
                    };
                    if let Some(attentions) = all_attentions.borrow_mut() {
                        attentions.push(temp.2.as_ref().unwrap().copy());
                    };
                }
                None => break
            };
        };

        Ok((hidden_state.apply(&self.ln_f), all_presents, all_hidden_states, all_attentions))
    }
}

}
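A note on the position id computation in forward_t above: when a cached layer_past is supplied, Tensor::arange1 starts at the cached length rather than at zero, so newly fed tokens continue the position sequence instead of restarting it. A plain-Rust sketch of that offset with assumed lengths:

fn main() {
    // assumed: 4 tokens already cached in layer_past, 3 new tokens in this call
    let layer_past_length: i64 = 4;
    let seq_length: i64 = 3;
    // same half-open range as arange1(layer_past_length, seq_length + layer_past_length, ...)
    let position_ids: Vec<i64> = (layer_past_length..seq_length + layer_past_length).collect();
    assert_eq!(position_ids, vec![4, 5, 6]);
}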