rust-bert/utils/convert_model.py

36 lines
1.3 KiB
Python
Raw Normal View History

2020-12-04 18:35:13 +03:00
from pathlib import Path
import numpy as np
import torch
import subprocess
import argparse
2021-06-05 12:47:56 +03:00
import sys
2020-12-04 18:35:13 +03:00
def convert_weights(weights, skip_embeddings=False):
    """Convert a PyTorch state dict into contiguous float32 NumPy arrays.

    Parameter names are rewritten on the way: every ``gamma`` substring in a
    key becomes ``weight`` and every ``beta`` substring becomes ``bias``.

    Args:
        weights: Mapping of parameter name to ``torch.Tensor``.
        skip_embeddings: When true, drop the shared embeddings / language
            model head entries (``lm_head.weight`` and the encoder/decoder
            ``embed_tokens.weight``) from the result.

    Returns:
        Dict mapping the renamed parameter names to C-contiguous
        ``np.float32`` arrays, in the iteration order of ``weights``.
    """
    skipped_keys = {
        "lm_head.weight",
        "model.encoder.embed_tokens.weight",
        "model.decoder.embed_tokens.weight",
    }
    nps = {}
    for k, v in weights.items():
        k = k.replace("gamma", "weight").replace("beta", "bias")
        # The skip list is matched against the *renamed* key.
        if skip_embeddings and k in skipped_keys:
            continue
        nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
        # nbytes is the size of the array data itself; sys.getsizeof measures
        # the Python object and omits the buffer when the array is a view.
        print(f'converted {k} - {nps[k].nbytes} bytes')
    return nps


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert")
    parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings / language model head")
    args = parser.parse_args()

    source_file = Path(args.source_file)
    # All outputs are written next to the source weights file.
    target_folder = source_file.parent

    # NOTE(security): torch.load is built on pickle and may execute arbitrary
    # code while loading (unless restricted via weights_only); only run this
    # script on weight files from a trusted source.
    weights = torch.load(str(source_file), map_location='cpu')

    # Intermediate NumPy archive consumed by the Rust converter below.
    npz_path = target_folder / 'model.npz'
    np.savez(npz_path, **convert_weights(weights, args.skip_embeddings))

    source = str(npz_path)
    target = str(target_folder / 'rust_model.ot')
    # Cargo.toml sits one directory above the folder containing this script.
    toml_location = Path(__file__).resolve().parents[1] / 'Cargo.toml'
    # check=True raises CalledProcessError when cargo exits non-zero, so a
    # failed conversion is not silently reported as success.
    subprocess.run(
        ['cargo', 'run', '--bin=convert-tensor', f'--manifest-path={toml_location}', '--', source, target],
        check=True,
    )