mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
a90950ea25
Add model shapes flag to model_info.py script through `--matrix-shapes` flag This will print something like: ``` ... encoder_l6_ffn_W1 (1024, 4096) encoder_l6_ffn_W2 (4096, 1024) encoder_l6_ffn_b1 (1, 4096) encoder_l6_ffn_b2 (1, 1024) encoder_l6_ffn_ffn_ln_bias (1, 1024) encoder_l6_ffn_ffn_ln_scale (1, 1024) encoder_l6_self_Wk (1024, 1024) encoder_l6_self_Wo (1024, 1024) encoder_l6_self_Wo_ln_bias (1, 1024) encoder_l6_self_Wo_ln_scale (1, 1024) encoder_l6_self_Wq (1024, 1024) encoder_l6_self_Wv (1024, 1024) encoder_l6_self_bk (1, 1024) encoder_l6_self_bo (1, 1024) encoder_l6_self_bq (1, 1024) encoder_l6_self_bv (1, 1024) special:model.yml (1264,) ```
67 lines
2.0 KiB
Python
Executable File
67 lines
2.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
import sys

import numpy as np
import yaml
|
|
|
|
|
|
# Description text handed to argparse (shown by ``--help``).
DESC: str = "Prints keys and values from model.npz file."
# Name of the npz entry holding the Marian YAML model configuration.
S2S_SPECIAL_NODE: str = "special:model.yml"
|
|
|
|
|
|
def main():
    """Load a Marian ``model.npz`` archive and print what the CLI flags request.

    * ``--special`` prints the embedded ``special:model.yml`` YAML text, or the
      value of a single top-level YAML key when ``--key`` is also given.
    * ``--key`` prints a single array (one value per line with
      ``--full-matrix``).
    * Otherwise every array name is listed, followed by its shape when
      ``--matrix-shapes`` is given.

    Exits with status 1 and a message on stderr when the special node or the
    requested key is missing.
    """
    args = parse_args()
    # For .npz archives np.load returns an NpzFile that keeps the zip file
    # open; the context manager closes it (assumes .npz input, as DESC says).
    with np.load(args.model) as model:
        if args.special:
            _print_special_node(model, args.key)
        else:
            _print_arrays(model, args)


def _print_special_node(model, key):
    """Print the special YAML node, or the value of one top-level key in it."""
    if S2S_SPECIAL_NODE not in model:
        sys.exit("No special Marian YAML node found in the model")

    # The node is stored as raw bytes. The C++ cnpy library appends an invalid
    # trailing '\0' to the YAML string; strip it once so it is neither echoed
    # nor fed to the YAML parser. UTF-8 is a superset of ASCII, so non-ASCII
    # config values decode instead of raising UnicodeDecodeError.
    yaml_text = bytes(model[S2S_SPECIAL_NODE]).decode("utf-8").rstrip("\x00")
    if not key:
        print(yaml_text)
        return

    yaml_node = yaml.safe_load(yaml_text)
    # Report a missing key the same way as the array branch instead of
    # crashing with a KeyError/TypeError traceback.
    if not isinstance(yaml_node, dict) or key not in yaml_node:
        sys.exit("Key not found")
    print(yaml_node[key])


def _print_arrays(model, args):
    """Print the array selected by ``--key``, or list all array names."""
    if args.key:
        if args.key not in model:
            sys.exit("Key not found")
        array = model[args.key]
        if args.full_matrix:
            # One value per line in row-major order. Iterating ``.flat`` works
            # for arrays of any rank; unpacking np.ndenumerate indices as
            # ``(x, y)`` would fail for anything but 2-D arrays.
            for value in array.flat:
                print(value)
        else:
            print(array)
        return

    for key in model:
        if args.matrix_shapes:
            print(key, model[key].shape)
        else:
            print(key)
|
|
|
|
|
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``model``, ``key``, ``special``,
        ``full_matrix`` and ``matrix_shapes``.
    """
    parser = argparse.ArgumentParser(description=DESC)
    parser.add_argument("-m", "--model", required=True, help="model file")
    parser.add_argument("-k", "--key", help="print value for specific key")

    # Boolean switches as (short flag, long flag, help text); registration
    # order is kept so ``--help`` lists them as before.
    switches = (
        ("-s", "--special",
         "print values from special:model.yml node"),
        ("-f", "--full-matrix",
         "force numpy to print full arrays for single key"),
        ("-ms", "--matrix-shapes",
         "print shapes of all arrays in the model"),
    )
    for short_flag, long_flag, help_text in switches:
        parser.add_argument(short_flag, long_flag,
                            action="store_true", help=help_text)

    return parser.parse_args()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, so that
    # importing this module has no side effects.
    main()
|