rust-bert/utils/convert_model.py
Romain Leroux c448862185
Add GPT-J support (#285) (#288)
* Add GPT-J support (#285)

* Improve GPT-J implementation

* Improve GPT-J tests

* Adapt GPT-J to latest master branch

* Specify how to convert GPT-J weights instead of providing them
2023-02-15 19:10:47 +00:00

86 lines
2.5 KiB
Python

import argparse
import numpy as np
import subprocess
import sys
import torch
from pathlib import Path
from torch import Tensor
# Shared token-embedding weights dropped by --skip_embeddings (listed with and
# without the leading "model." scope).
_EMBEDDING_KEYS = frozenset(
    {
        "model.encoder.embed_tokens.weight",
        "encoder.embed_tokens.weight",
        "model.decoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    }
)
# Language-model head weights dropped by --skip_lm_head.
_LM_HEAD_KEYS = frozenset({"lm_head.weight"})


def _parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "source_file", help="Absolute path to the Pytorch weights file to convert"
    )
    parser.add_argument(
        "--skip_embeddings",
        action="store_true",
        help="Skip shared embeddings",
    )
    parser.add_argument(
        "--skip_lm_head", action="store_true", help="Skip language model head"
    )
    parser.add_argument("--prefix", help="Add a prefix on weight names")
    parser.add_argument(
        "--suffix",
        action="store_true",
        help="Split weight names on '.' and keep only last part",
    )
    parser.add_argument(
        "--dtype",
        help="Convert weights to a specific numpy DataType (float32, float16, ...)",
    )
    return parser.parse_args(argv)


def _tensor_to_numpy(tensor):
    """Return ``tensor`` as a CPU numpy array.

    numpy has no bfloat16 dtype and ``Tensor.numpy()`` refuses bfloat16
    tensors, so those are upcast (losslessly) to float32 first. ``detach()``
    lets tensors that require grad be converted too.
    """
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.to(torch.float32)
    return tensor.numpy()


def convert_state_dict(
    weights,
    *,
    skip_embeddings=False,
    skip_lm_head=False,
    prefix=None,
    suffix=False,
    dtype=None,
):
    """Convert a PyTorch state dict into a dict of C-contiguous numpy arrays.

    Args:
        weights: mapping of weight name -> value (e.g. the result of
            ``torch.load``). Non-tensor values are reported and skipped.
        skip_embeddings: drop the weights listed in ``_EMBEDDING_KEYS``.
        skip_lm_head: drop the weights listed in ``_LM_HEAD_KEYS``.
        prefix: string prepended to every kept weight name.
        suffix: keep only the last '.'-separated component of each name
            (applied after ``prefix``). This can make names collide; a warning
            is printed and the later weight wins.
        dtype: numpy dtype name (e.g. "float16") that floating-point arrays
            are cast to. Integer and boolean arrays keep their dtype.

    Returns:
        dict mapping the (renamed) weight names to numpy arrays.
    """
    # Resolve the dtype once, outside the loop (also validates it early).
    target_dtype = np.dtype(dtype) if dtype is not None else None
    arrays = {}
    for name, value in weights.items():
        # Presumably renames legacy LayerNorm parameters (gamma/beta). NOTE: a
        # plain substring replace, so any name containing "gamma"/"beta" is
        # rewritten as well.
        name = name.replace("gamma", "weight").replace("beta", "bias")
        # Skip checks run on the un-prefixed name, as the key sets expect.
        if skip_embeddings and name in _EMBEDDING_KEYS:
            continue
        if skip_lm_head and name in _LM_HEAD_KEYS:
            continue
        if prefix:
            name = prefix + name
        if suffix:
            name = name.split(".")[-1]
        if not isinstance(value, Tensor):
            print(f"skipped non-tensor object: {name}")
            continue
        array = _tensor_to_numpy(value)
        if target_dtype is not None and np.issubdtype(array.dtype, np.floating):
            array = array.astype(target_dtype)
        if name in arrays:
            print(f"warning: duplicate weight name {name}, overwriting previous value")
        arrays[name] = np.ascontiguousarray(array)
        # nbytes is the data size; sys.getsizeof only counts the header for
        # arrays that do not own their buffer (such as views of torch memory).
        print(f"converted {name} - {arrays[name].nbytes} bytes")
    return arrays


def _run_tensor_converter(source, target):
    """Run this crate's ``convert-tensor`` binary and return its exit code."""
    # This script lives in <crate>/utils/, so the manifest is one level up.
    manifest = Path(__file__).resolve().parents[1] / "Cargo.toml"
    completed = subprocess.run(
        [
            "cargo",
            "run",
            "--bin=convert-tensor",
            f"--manifest-path={manifest}",
            "--",
            str(source),
            str(target),
        ],
    )
    return completed.returncode


def main(argv=None):
    """Convert a PyTorch weights file into ``model.npz`` and ``rust_model.ot``.

    Both outputs are written next to the source file. Returns the exit code of
    the ``cargo run`` conversion step.
    """
    args = _parse_args(argv)
    source_file = Path(args.source_file)
    target_folder = source_file.parent
    # SECURITY: torch.load unpickles the file and can execute arbitrary code;
    # only convert checkpoints from trusted sources.
    weights = torch.load(str(source_file), map_location="cpu")
    arrays = convert_state_dict(
        weights,
        skip_embeddings=args.skip_embeddings,
        skip_lm_head=args.skip_lm_head,
        prefix=args.prefix,
        suffix=args.suffix,
        dtype=args.dtype,
    )
    npz_path = target_folder / "model.npz"
    np.savez(npz_path, **arrays)
    return _run_tensor_converter(npz_path, target_folder / "rust_model.ot")


if __name__ == "__main__":
    # Propagate a failed cargo conversion as a non-zero exit status.
    sys.exit(main())