rust-bert/utils/convert_model.py
guillaume-be d7e9c03694
LongT5 implementation (#333)
* LongT5 config implementation

* LongT5 WIP: utility functions (1)

* LongT5 WIP: utility functions (2)

* LongT5 WIP: utility functions (3)

* LongT5 WIP: utility functions (4)

* made T5 FF activations generic, expose T5 modules to crate

* LongT5 local attention WIP

* LongT5 local attention

* LongT5 global attention WIP

* LongT5 global attention

* LongT5 attention modules (WIP)

* align LongT5 position bias with T5

* Addition of LongT5Block

* LongT5Stack WIP

* LongT5Stack implementation

* LongT5Model implementation

* LongT5ForConditionalGeneration implementation

* Addition of LongT5Generator, inclusion in pipelines

* LongT5 attention fixes

* Fix MIN/MAX dtype computation, mask for longt5

* Updated min/max and infinity computation across models

* GlobalTransient attention fixes

* Updated changelog, readme, tests, clippy
2023-02-12 16:18:20 +00:00

58 lines
2.1 KiB
Python

import argparse
import subprocess
import sys
from pathlib import Path

import numpy as np
import torch
from torch import Tensor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert")
    parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings")
    parser.add_argument("--skip_lm_head", action="store_true", help="Skip language model head")
    parser.add_argument("--prefix", help="Add a prefix on weight names")
    parser.add_argument("--suffix", action="store_true", help="Split weight names on '.' and keep only last part")
    args = parser.parse_args()

    source_file = Path(args.source_file)
    target_folder = source_file.parent
    weights = torch.load(str(source_file), map_location='cpu')

    nps = {}
    for k, v in weights.items():
        # Older checkpoints store LayerNorm parameters as gamma/beta: rename to weight/bias.
        k = k.replace("gamma", "weight").replace("beta", "bias")
        if args.skip_embeddings:
            if k in {
                "model.encoder.embed_tokens.weight",
                "encoder.embed_tokens.weight",
                "model.decoder.embed_tokens.weight",
                "decoder.embed_tokens.weight",
            }:
                continue
        if args.skip_lm_head:
            if k in {
                "lm_head.weight",
            }:
                continue
        if args.prefix:
            k = args.prefix + k
        if args.suffix:
            k = k.split('.')[-1]
        if isinstance(v, Tensor):
            # Export each tensor as a contiguous float32 array for the intermediate NPZ archive.
            nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
            print(f'converted {k} - {sys.getsizeof(nps[k])} bytes')
        else:
            print(f'skipped non-tensor object: {k}')
    np.savez(target_folder / 'model.npz', **nps)

    source = str(target_folder / 'model.npz')
    target = str(target_folder / 'rust_model.ot')

    # Cargo.toml sits two directories above this script, at the repository root.
    toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()

    # Convert the NPZ archive to rust-bert's `.ot` weights via the convert-tensor binary.
    subprocess.run(
        ['cargo', 'run', '--bin=convert-tensor', f'--manifest-path={toml_location}', '--', source, target],
    )