rust-bert/utils/convert_model.py

86 lines
2.5 KiB
Python
Raw Normal View History

import argparse
2020-12-04 18:35:13 +03:00
import numpy as np
import subprocess
2021-06-05 12:47:56 +03:00
import sys
import torch
2020-12-04 18:35:13 +03:00
from pathlib import Path
2021-11-28 14:43:35 +03:00
from torch import Tensor
if __name__ == "__main__":
    # Convert a PyTorch weights file into a NumPy archive (`model.npz`) and
    # then into `rust_model.ot` by running the `convert-tensor` cargo binary.
    # Both outputs are written next to the source file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "source_file", help="Absolute path to the Pytorch weights file to convert"
    )
    parser.add_argument(
        "--skip_embeddings",
        action="store_true",
        help="Skip shared embeddings",
    )
    parser.add_argument(
        "--skip_lm_head", action="store_true", help="Skip language model head"
    )
    parser.add_argument("--prefix", help="Add a prefix on weight names")
    parser.add_argument(
        "--suffix",
        action="store_true",
        help="Split weight names on '.' and keep only last part",
    )
    parser.add_argument(
        "--dtype",
        help="Convert weights to a specific numpy DataType (float32, float16, ...)",
    )
    args = parser.parse_args()

    # Resolve the requested dtype once, up front, so an invalid name fails
    # before the (potentially large) checkpoint is loaded.
    target_dtype = np.dtype(args.dtype) if args.dtype is not None else None

    source_file = Path(args.source_file)
    target_folder = source_file.parent

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this script on checkpoints you trust.
    weights = torch.load(str(source_file), map_location="cpu")

    # Weight names dropped when the corresponding --skip_* flag is set.
    embedding_names = {
        "model.encoder.embed_tokens.weight",
        "encoder.embed_tokens.weight",
        "model.decoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    }
    lm_head_names = {
        "lm_head.weight",
    }

    nps = {}
    for k, v in weights.items():
        # Map the gamma/beta naming onto weight/bias.
        k = k.replace("gamma", "weight").replace("beta", "bias")
        # The skip checks run on the renamed key, before --prefix/--suffix.
        if args.skip_embeddings and k in embedding_names:
            continue
        if args.skip_lm_head and k in lm_head_names:
            continue
        if args.prefix:
            k = args.prefix + k
        if args.suffix:
            # Entries sharing the same last component overwrite each other.
            k = k.split(".")[-1]

        if not isinstance(v, Tensor):
            print(f"skipped non-tensor object: {k}")
            continue

        # detach() first: Tensor.numpy() refuses tensors that require grad.
        tensor = v.detach().cpu().numpy()
        if target_dtype is not None:
            tensor = tensor.astype(target_dtype)
        nps[k] = np.ascontiguousarray(tensor)
        # Report the size of the array data. sys.getsizeof measures the Python
        # object and can leave out the data buffer (e.g. for arrays that share
        # memory with the tensor), so it is not a reliable byte count here.
        print(f"converted {k} - {nps[k].nbytes} bytes")

    source = str(target_folder / "model.npz")
    target = str(target_folder / "rust_model.ot")
    np.savez(source, **nps)

    # Cargo.toml sits two levels above this script (the parent of its folder).
    toml_location = Path(__file__).resolve().parents[1] / "Cargo.toml"
    result = subprocess.run(
        [
            "cargo",
            "run",
            "--bin=convert-tensor",
            f"--manifest-path={toml_location}",
            "--",
            source,
            target,
        ],
    )
    # Surface a failed conversion through the exit status instead of
    # finishing with success while rust_model.ot was not produced.
    if result.returncode != 0:
        sys.exit(result.returncode)