marian/scripts/contrib/model_info.py
Alex Muzio a90950ea25 Merged PR 25154: Add model shapes flag to model_info.py script
Add a `--matrix-shapes` flag to the model_info.py script that prints the shape of every array in the model alongside its name.

With `--matrix-shapes`, the script prints output like the following (truncated):
```
...
encoder_l6_ffn_W1 (1024, 4096)
encoder_l6_ffn_W2 (4096, 1024)
encoder_l6_ffn_b1 (1, 4096)
encoder_l6_ffn_b2 (1, 1024)
encoder_l6_ffn_ffn_ln_bias (1, 1024)
encoder_l6_ffn_ffn_ln_scale (1, 1024)
encoder_l6_self_Wk (1024, 1024)
encoder_l6_self_Wo (1024, 1024)
encoder_l6_self_Wo_ln_bias (1, 1024)
encoder_l6_self_Wo_ln_scale (1, 1024)
encoder_l6_self_Wq (1024, 1024)
encoder_l6_self_Wv (1024, 1024)
encoder_l6_self_bk (1, 1024)
encoder_l6_self_bo (1, 1024)
encoder_l6_self_bq (1, 1024)
encoder_l6_self_bv (1, 1024)
special:model.yml (1264,)
```
2022-08-10 22:23:47 +00:00

67 lines
2.0 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import sys

import numpy as np
import yaml
# Description shown in the --help output (passed to argparse in parse_args).
DESC = "Prints keys and values from model.npz file."
# Name of the npz entry that holds Marian's YAML model config; main() decodes
# it as text and parses it with yaml.safe_load when --special is given.
S2S_SPECIAL_NODE = "special:model.yml"
def _print_special(model, key):
    """Print the Marian YAML config stored in the model's special node.

    Args:
        model: the loaded npz archive (mapping of array name -> array).
        key: if given, print only the value of this top-level YAML key;
            otherwise print the whole YAML text.

    Exits with status 1 (message on stderr) if the special node or *key*
    is missing.
    """
    if S2S_SPECIAL_NODE not in model:
        sys.exit("No special Marian YAML node found in the model")
    # The C++ cnpy library appends an invalid trailing unicode character
    # '#x0000' (NUL) to the stored YAML string; strip it so the text can be
    # both printed cleanly and parsed as YAML.
    raw = model[S2S_SPECIAL_NODE].tobytes()
    yaml_text = raw.decode("utf-8").rstrip("\x00")
    if not key:
        print(yaml_text)
        return
    yaml_node = yaml.safe_load(yaml_text)
    # Report a missing key the same way the array branch does, instead of
    # crashing with a KeyError traceback.
    if not isinstance(yaml_node, dict) or key not in yaml_node:
        sys.exit("Key not found")
    print(yaml_node[key])


def _print_array(model, key, full_matrix):
    """Print the array stored under *key* in the model.

    With *full_matrix*, every element is printed on its own line in row-major
    order, so nothing is hidden by numpy's abbreviated array printing. This
    works for arrays of any rank, not only 2-D matrices.

    Exits with status 1 (message on stderr) if *key* is not in the model.
    """
    if key not in model:
        sys.exit("Key not found")
    array = model[key]
    if full_matrix:
        for value in array.flat:
            print(value)
    else:
        print(array)


def _print_keys(model, matrix_shapes):
    """Print the name of every array in *model*, optionally with its shape."""
    for key in model:
        if matrix_shapes:
            print(key, model[key].shape)
        else:
            print(key)


def main():
    """Print information about a Marian model.npz file.

    What is printed depends on the command-line flags:

    * ``--special``: the embedded YAML config, or one key of it with ``--key``.
    * ``--key``: a single array (element by element with ``--full-matrix``).
    * otherwise: the names of all arrays (with shapes if ``--matrix-shapes``).
    """
    args = parse_args()
    # The object returned by np.load for an .npz keeps the archive open;
    # the context manager closes it, including when sys.exit() is raised.
    with np.load(args.model) as model:
        if args.special:
            _print_special(model, args.key)
        elif args.key:
            _print_array(model, args.key, args.full_matrix)
        else:
            _print_keys(model, args.matrix_shapes)
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: list of argument strings to parse. When None (the default),
            argparse reads them from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with attributes ``model``, ``key``, ``special``,
        ``full_matrix`` and ``matrix_shapes``.
    """
    parser = argparse.ArgumentParser(description=DESC)
    parser.add_argument("-m", "--model", required=True,
                        help="path to the model.npz file")
    parser.add_argument("-k", "--key",
                        help="print the value for a specific key (an array "
                             "name, or a YAML key when used with --special)")
    parser.add_argument("-s", "--special", action="store_true",
                        help="print values from special:model.yml node")
    parser.add_argument("-f", "--full-matrix", action="store_true",
                        help="with --key, print every element of the array "
                             "on its own line instead of numpy's abbreviated "
                             "output")
    parser.add_argument("-ms", "--matrix-shapes", action="store_true",
                        help="when listing all keys, also print the shape of "
                             "each array")
    return parser.parse_args(argv)
# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()