"""Convert a PyTorch checkpoint to the ``rust_model.ot`` format used by rust-bert.

Reads a PyTorch weights file, renames legacy LayerNorm parameters
(``gamma``/``beta`` -> ``weight``/``bias``), optionally drops the shared
embedding / LM-head tensors, dumps everything to ``model.npz`` next to the
source file, then invokes the ``convert-tensor`` cargo binary to produce
``rust_model.ot`` in the same folder.

Usage:
    python convert_model.py /abs/path/to/pytorch_model.bin [--skip_embeddings]
"""
import argparse
import subprocess
import sys
from pathlib import Path

import numpy as np
import torch

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert")
    parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings / language model head")
    args = parser.parse_args()

    source_file = Path(args.source_file)
    # Outputs (model.npz, rust_model.ot) are written next to the source file.
    target_folder = source_file.parent

    # NOTE(review): torch.load unpickles arbitrary objects — only run this
    # script on checkpoints from a trusted source.
    weights = torch.load(str(source_file), map_location='cpu')

    nps = {}
    for k, v in weights.items():
        # Older checkpoints name LayerNorm parameters gamma/beta; the Rust
        # side expects the modern weight/bias naming.
        k = k.replace("gamma", "weight").replace("beta", "bias")
        if args.skip_embeddings:
            if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}:
                continue
        # convert-tensor expects contiguous float32 arrays.
        nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
        print(f'converted {k} - {sys.getsizeof(nps[k])} bytes')

    np.savez(target_folder / 'model.npz', **nps)

    source = str(target_folder / 'model.npz')
    target = str(target_folder / 'rust_model.ot')

    # Cargo.toml lives two levels above this script (repository root).
    toml_location = Path(__file__).resolve().parent.parent / 'Cargo.toml'

    subprocess.run(
        ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target],
        check=True,  # fail loudly if the cargo conversion step errors out
    )
|