mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Add scripts printing special:mode.yml from model.npz
This commit is contained in:
parent
2b369a54f9
commit
43fbaa6c10
44
scripts/contrib/model_info.py
Normal file
44
scripts/contrib/model_info.py
Normal file
@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
|
||||
DESC = "Prints version and model type from model.npz file."
|
||||
S2S_SPECIAL_NODE = "special:model.yml"
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
model = np.load(args.model)
|
||||
if S2S_SPECIAL_NODE not in model:
|
||||
print("No special Marian YAML node found in the model")
|
||||
exit(1)
|
||||
|
||||
yaml_text = bytes(model[S2S_SPECIAL_NODE]).decode('ascii')
|
||||
if not args.key:
|
||||
print(yaml_text)
|
||||
exit(0)
|
||||
|
||||
# fix the invalid trailing unicode character '#x0000' added to the YAML
|
||||
# string by the C++ cnpy library
|
||||
try:
|
||||
yaml_node = yaml.load(yaml_text)
|
||||
except yaml.reader.ReaderError:
|
||||
yaml_node = yaml.load(yaml_text[:-1])
|
||||
|
||||
print(yaml_node[args.key])
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description=DESC)
|
||||
parser.add_argument("-m", "--model", help="model file", required=True)
|
||||
parser.add_argument("-k", "--key", help="print value for specific key")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user