Reuse gpt2 embeddings (#160)

* Updated GPT2 to re-use embeddings for LM head

* Updated conversion utilities

* Updated changelog
guillaume-be 2021-06-12 11:11:34 +02:00 committed by GitHub
parent 24fdb2dfb4
commit 4282d7b5c4
3 changed files with 20 additions and 24 deletions

@@ -7,6 +7,9 @@ All notable changes to this project will be documented in this file. The format
- (BREAKING) Support for `forced_bos_token_id` argument for generation, allowing users to force a given BOS token for generation (useful for MBart/M2M-class models)
- Addition of the MBart language model and support for text generation / direct translation between 50 languages
## Changed
- Updated GPT2 architecture to re-use embeddings for the output projection layer (resulting in smaller model weights files and memory footprint)
## [0.15.1] - 2021-06-01
### Fixed
- Fixed conversation model panic for user inputs exceeding the maximum model length (1000 tokens)
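
For a rough sense of the size reduction mentioned in the "Changed" entry above, assuming GPT-2 base dimensions (`vocab_size = 50257`, `n_embd = 768`, which are not stated in this diff), the projection matrix that no longer needs to be stored separately is about 38.6M float32 parameters, on the order of 150 MB:

```rust
fn main() {
    // Assumed GPT-2 base dimensions; not taken from this commit.
    let (vocab_size, n_embd): (i64, i64) = (50_257, 768);

    // Parameters of the output projection that no longer need to be stored
    // separately once it is tied to the token embedding matrix.
    let tied_params = vocab_size * n_embd;
    let saved_bytes = tied_params * 4; // float32

    println!(
        "tying saves ~{} parameters (~{:.0} MB in float32)",
        tied_params,
        saved_bytes as f64 / 1e6
    );
}
```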

@@ -14,7 +14,6 @@
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::transformer::Block;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
@@ -482,10 +481,8 @@ impl Gpt2Model {
/// GPT2 model with a decoding head (linear layer without bias). The weights of the linear layer are tied to the word embeddings
/// It is made of the following blocks:
/// - `transformer`: Base Gpt2Model
/// - `lm_head`: Linear layer without bias tied to the weights of the token id embeddings
pub struct GPT2LMHeadModel {
transformer: Gpt2Model,
lm_head: LinearNoBias,
}
impl GPT2LMHeadModel {
@@ -517,16 +514,8 @@ impl GPT2LMHeadModel {
let p = p.borrow();
let transformer = Gpt2Model::new(p, config);
let lm_head = linear_no_bias(
p / "lm_head",
config.n_embd,
config.vocab_size,
Default::default(),
);
GPT2LMHeadModel {
transformer,
lm_head,
}
GPT2LMHeadModel { transformer }
}
}
@@ -641,7 +630,9 @@ impl LMHeadModel for GPT2LMHeadModel {
}
}?;
let lm_logits = base_model_output.output.apply(&self.lm_head);
let lm_logits = base_model_output
.output
.linear::<Tensor>(&self.transformer.wte.ws, None);
Ok(LMModelOutput {
lm_logits,
cache: Cache::GPT2Cache(base_model_output.cache),

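The updated forward pass above computes the logits by projecting the hidden states onto the transposed token embedding matrix instead of applying a separate `lm_head`. A minimal sketch of that tying pattern with the `tch` crate (dimensions and variable paths here are illustrative, not taken from the crate's real configuration structs):

```rust
use tch::{nn, nn::Module, Device, Kind, Tensor};

fn main() {
    // Illustrative GPT-2 base-like dimensions, not read from a real config.
    let (vocab_size, n_embd): (i64, i64) = (50_257, 768);

    let vs = nn::VarStore::new(Device::Cpu);
    let p = vs.root();
    let wte = nn::embedding(&p / "wte", vocab_size, n_embd, Default::default());

    // Token ids of shape [batch, sequence_length] looked up through the embedding.
    let input_ids = Tensor::arange(4_i64, (Kind::Int64, Device::Cpu)).unsqueeze(0);
    let hidden = wte.forward(&input_ids); // stand-in for the transformer output

    // Output projection re-uses `wte.ws` ([vocab_size, n_embd]) with no bias,
    // mirroring `output.linear::<Tensor>(&self.transformer.wte.ws, None)` above.
    let lm_logits = hidden.linear::<Tensor>(&wte.ws, None);
    assert_eq!(lm_logits.size(), vec![1, 4, vocab_size]);
}
```
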
@@ -8,6 +8,7 @@ import sys
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert")
parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings / language model head")
args = parser.parse_args()
source_file = Path(args.source_file)
@@ -18,16 +19,17 @@ if __name__ == "__main__":
nps = {}
for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}:
continue
if args.skip_embeddings:
if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}:
continue
nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
print(k + str(sys.getsizeof(nps[k])))
print(f'converted {k} - {str(sys.getsizeof(nps[k]))} bytes')
np.savez(target_folder / 'model.npz', **nps)
# source = str(target_folder / 'model.npz')
# target = str(target_folder / 'rust_model.ot')
#
# toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
# subprocess.run(
# ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target],
# )
source = str(target_folder / 'model.npz')
target = str(target_folder / 'rust_model.ot')
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.run(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target],
)
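
One quick way to sanity-check the `rust_model.ot` archive produced by the conversion step above is to list its tensors from Rust with `tch` (a sketch; the path below is a placeholder):

```rust
use tch::Tensor;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path to the `rust_model.ot` written by the script above.
    let tensors = Tensor::load_multi("path/to/rust_model.ot")?;
    for (name, tensor) in &tensors {
        // When the script is run with --skip_embeddings, the shared
        // `lm_head.weight` / embed_tokens entries are expected to be absent.
        println!("{} {:?}", name, tensor.size());
    }
    Ok(())
}
```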