initial commit for GPT2

Guillaume B 2020-02-27 18:55:17 +01:00
parent f1fc3f7dc0
commit b0e84fc2b9
6 changed files with 177 additions and 4 deletions


@ -61,10 +61,6 @@ fn main() -> failure::Fallible<()> {
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
//    let mask = Tensor::of_slice(&[1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0.]).view((-1, 11));
//    let encoder_hidden_state = Some(Tensor::ones(&[2, 11, 768], (Float, input_tensor.device())));
//    mask.print();
    let (output, _, _) = no_grad(|| {
        bert_model
            .forward_t(Some(input_tensor),

examples/gpt2.rs (new file, 62 lines)

@ -0,0 +1,62 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer};
use rust_bert::gpt2::gpt2::{Gpt2Config, Gpt2Model};
use rust_bert::common::config::Config;
fn main() -> failure::Fallible<()> {
// Resources paths
    let mut home: PathBuf = dirs::home_dir().unwrap();
    home.push("rustbert");
    home.push("distilgpt2");
    let config_path = &home.as_path().join("config.json");
    let vocab_path = &home.as_path().join("vocab.txt");
    let merges_path = &home.as_path().join("merges.txt");
    let _weights_path = &home.as_path().join("model.ot");

// Set-up GPT2 model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
    let config = Gpt2Config::from_file(config_path);
    let _gpt2_model = Gpt2Model::new(&vs.root(), &config);
//    vs.load(weights_path)?;
// Define input
    let input = ["Looks like one thing is missing", "It's like comparing oranges to apples"];
    let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            // Right-pad each sequence with 0s up to the longest sequence in the batch
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let _input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

// Forward pass
    Ok(())
}
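
The batching idiom above, right-pad each token id sequence with 0 to the longest sequence in the batch, then stack the per-sequence tensors, is the core of the input preparation. A standalone sketch of just that step (the helper name is invented for illustration; only tch-rs calls already used above appear):

use tch::Tensor;

// Hypothetical helper isolating the padding + stacking step from the example
fn pad_and_stack(batch: Vec<Vec<i64>>) -> Tensor {
    let max_len = batch.iter().map(|ids| ids.len()).max().unwrap();
    let padded = batch
        .into_iter()
        .map(|mut ids| {
            ids.extend(vec![0; max_len - ids.len()]);
            Tensor::of_slice(&ids)
        })
        .collect::<Vec<_>>();
    // N tensors of shape [max_len] stacked along dim 0 give a [N, max_len] batch
    Tensor::stack(padded.as_slice(), 0)
}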

src/gpt2/gpt2.rs (new file, 65 lines)

@ -0,0 +1,65 @@
// Copyright 2018-present, the HuggingFace Inc. team
// Copyright 2018-present, The OpenAI Team Authors
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::config::Config;
use serde::{Deserialize, Serialize};
use tch::nn;
use crate::common::dropout::Dropout;
use tch::nn::embedding;
#[derive(Debug, Serialize, Deserialize)]
pub struct Gpt2Config {
pub attn_pdrop: Option<f64>,
pub embd_pdrop: Option<f64>,
pub hidden_dropout_prob: Option<f64>,
pub initializer_range: f64,
pub layer_norm_epsilon: f64,
pub n_ctx: i64,
pub n_embd: i64,
pub n_head: i64,
pub n_layer: i64,
pub n_positions: i64,
pub num_labels: Option<i64>,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub resid_pdrop: Option<f64>,
pub vocab_size: i64,
}
impl Config<Gpt2Config> for Gpt2Config {}
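
Gpt2Config derives Deserialize, so Config::from_file can populate it from the pretrained model's config.json. A minimal sketch of the JSON shape it accepts, parsed here with serde_json directly (that from_file uses serde_json is an assumption, and the values are illustrative DistilGPT-2-style numbers, not taken from this commit):

use rust_bert::gpt2::gpt2::Gpt2Config;

fn parse_config() -> serde_json::Result<Gpt2Config> {
    // Non-Option fields are required; Option fields may be omitted and deserialize to None
    let raw = r#"{
        "initializer_range": 0.02,
        "layer_norm_epsilon": 1e-5,
        "n_ctx": 1024,
        "n_embd": 768,
        "n_head": 12,
        "n_layer": 6,
        "n_positions": 1024,
        "vocab_size": 50257
    }"#;
    serde_json::from_str(raw)
}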
pub struct Gpt2Model {
    // Fields are underscore-prefixed to silence unused warnings until a forward pass uses them
    _wte: nn::Embedding,   // token embeddings
    _wpe: nn::Embedding,   // position embeddings
    _drop: Dropout,        // embedding dropout
    _ln_f: nn::LayerNorm,  // final layer norm
}
impl Gpt2Model {
    pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model {
        let wte = embedding(&(p / "wte"), config.vocab_size, config.n_embd, Default::default());
        let wpe = embedding(&(p / "wpe"), config.n_positions, config.n_embd, Default::default());
        // Fall back to the standard 0.1 dropout when the config does not specify one
        let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
        let drop = Dropout::new(embd_pdrop);
        let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
        let ln_f = nn::layer_norm(p / "ln_f", vec![config.n_embd], layer_norm_config);
        Gpt2Model { _wte: wte, _wpe: wpe, _drop: drop, _ln_f: ln_f }
    }
}
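
Dropout is imported from rust_bert::common::dropout rather than tch, whose nn module has no dropout layer struct. A plausible minimal shape for that wrapper, assuming it implements ModuleT so behavior can differ between training and evaluation (a sketch, not the actual contents of the dropout module):

use tch::{Tensor, nn::ModuleT};

#[derive(Debug)]
pub struct Dropout {
    p: f64,
}

impl Dropout {
    pub fn new(p: f64) -> Dropout {
        Dropout { p }
    }
}

impl ModuleT for Dropout {
    fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
        // Zeroes elements with probability p during training; identity in eval mode
        input.dropout(self.p, train)
    }
}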

src/gpt2/mod.rs (new file, 1 line)

@ -0,0 +1 @@
pub mod gpt2;

src/lib.rs (modified)

@ -1,6 +1,7 @@
pub mod distilbert;
pub mod bert;
pub mod roberta;
pub mod gpt2;
pub mod common;
pub mod pipelines;

(new Python file, 48 lines: downloads the DistilGPT-2 files and converts the weights)

@ -0,0 +1,48 @@
from transformers import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.tokenization_gpt2 import PRETRAINED_VOCAB_FILES_MAP
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess
config_path = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP["distilgpt2"]
vocab_path = PRETRAINED_VOCAB_FILES_MAP["vocab_file"]["distilgpt2"]
merges_path = PRETRAINED_VOCAB_FILES_MAP["merges_file"]["distilgpt2"]
weights_path = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP["distilgpt2"]
target_path = Path.home() / 'rustbert' / 'distilgpt2'
# Download the remote files (or reuse the local HuggingFace cache if already present)
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_merges = get_from_cache(merges_path)
temp_weights = get_from_cache(weights_path)
os.makedirs(str(target_path), exist_ok=True)
config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.txt')
merges_path = str(target_path / 'merges.txt')
model_path = str(target_path / 'model.bin')
shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_merges, merges_path)
shutil.copy(temp_weights, model_path)
# Re-serialize the weights as a NumPy .npz archive so the convert-tensor
# binary can turn them into the .ot format read by tch
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
    nps[k] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.call(
    ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
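
Once the script has written model.ot, the commented-out vs.load(weights_path) line in examples/gpt2.rs can be enabled so the variables created by Gpt2Model::new are filled with the converted weights. A minimal sketch of that loading path, assembled from the code above (the function name is invented; the failure-based error handling matches the example's signature):

use std::path::PathBuf;
use tch::{Device, nn};
use rust_bert::common::config::Config;
use rust_bert::gpt2::gpt2::{Gpt2Config, Gpt2Model};

fn load_distilgpt2() -> failure::Fallible<(nn::VarStore, Gpt2Model)> {
    let home: PathBuf = dirs::home_dir().unwrap().join("rustbert").join("distilgpt2");
    let mut vs = nn::VarStore::new(Device::Cpu);
    let config = Gpt2Config::from_file(&home.join("config.json"));
    let model = Gpt2Model::new(&vs.root(), &config);
    // Tensors saved under "wte", "wpe" and "ln_f" are matched to the variables created above
    vs.load(home.join("model.ot"))?;
    Ok((vs, model))
}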