Updated GPT2 implementation, validated

This commit is contained in:
Guillaume B 2020-02-29 14:13:08 +01:00
parent a276e65cf4
commit d038626d37
3 changed files with 11 additions and 10 deletions

View File

@ -33,9 +33,9 @@ fn main() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = Gpt2Config::from_file(config_path);
let _gpt2_model = Gpt2Model::new(&vs.root(), &config);
let gpt2_model = Gpt2Model::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
@ -55,17 +55,16 @@ fn main() -> failure::Fallible<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let gpt2_model = Gpt2Model::new(&vs.root(), &config);
let output = gpt2_model.forward_t(
let (output, _, _, _) = gpt2_model.forward_t(
&Some(input_tensor),
&None,
&None,
&None,
&None,
&None,
false);
println!("{:?}", output);
false).unwrap();
output.print();
Ok(())
}

View File

@ -142,7 +142,6 @@ impl Gpt2Model {
None => Tensor::zeros_like(&position_embeds)
};
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_presents: Option<Vec<Tensor>> = if self.output_past { Some(vec!()) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };

View File

@ -65,7 +65,10 @@ impl Block {
/// Runs one block: the attention sub-layer on `x` after `ln_1`, then the MLP sub-layer
/// after `ln_2`.
///
/// `layer_past` and `attention_mask` are forwarded to `self.attn.forward_t` only;
/// `train` is forwarded to both the attention and the MLP sub-layers.
///
/// Returns the block's output tensor together with the `present` and `attentions`
/// values produced by `self.attn.forward_t`, passed through unchanged.
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Tensor, Option<Tensor>) {
let (output, present, attentions) = self.attn.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
// NOTE(review): this listing looks like a diff with its +/- markers stripped. The next
// two lines appear to be the removed version: the MLP result is added to the original
// `x`, so the attention output reaches the result only as input to the MLP.
let output = x + self.mlp.forward_t(&(x + output).apply(&self.ln_2), train);
(output, present, attentions)
// NOTE(review): the next four lines appear to be the replacement: the attention output
// is first added to `x`, `ln_2` and the MLP run on that sum, and the MLP result is then
// added to the same sum.
let x = x + output;
let m = self.mlp.forward_t(&x.apply(&self.ln_2), train);
let x = x + m;
(x, present, attentions)
}
}