mirror of
https://github.com/marian-nmt/marian.git
synced 2024-10-26 09:09:10 +03:00
Update script exporting embeddings to support tied embeddings (#569)
This commit is contained in:
parent
22ad592a1d
commit
533604024b
@ -9,18 +9,22 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
desc = """Export word embedding from model"""
|
desc = """Export word embeddings from model"""
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter, description=desc)
|
formatter_class=argparse.RawDescriptionHelpFormatter, description=desc)
|
||||||
parser.add_argument("-m", "--model", help="Model file", required=True)
|
parser.add_argument("-m", "--model", help="path to model.npz file", required=True)
|
||||||
parser.add_argument(
|
parser.add_argument("-o", "--output-prefix", help="prefix for output files", required=True)
|
||||||
"-o", "--output-prefix", help="Output files prefix", required=True)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print("Loading model")
|
print("Loading model")
|
||||||
model = np.load(args.model)
|
model = np.load(args.model)
|
||||||
special = yaml.load(model["special:model.yml"][:-1].tobytes())
|
special = yaml.load(model["special:model.yml"][:-1].tobytes())
|
||||||
|
|
||||||
|
if special["tied-embeddings-all"] or special["tied-embeddings-src"]:
|
||||||
|
all_emb = model["Wemb"]
|
||||||
|
export_emb(args.output_prefix + ".all", all_emb)
|
||||||
|
exit()
|
||||||
|
|
||||||
if special["type"] == "amun":
|
if special["type"] == "amun":
|
||||||
enc_emb = model["Wemb"]
|
enc_emb = model["Wemb"]
|
||||||
dec_emb = model["Wemb_dec"]
|
dec_emb = model["Wemb_dec"]
|
||||||
@ -28,16 +32,15 @@ def main():
|
|||||||
enc_emb = model["encoder_Wemb"]
|
enc_emb = model["encoder_Wemb"]
|
||||||
dec_emb = model["decoder_Wemb"]
|
dec_emb = model["decoder_Wemb"]
|
||||||
|
|
||||||
with open(args.output_prefix + ".src", "w") as out:
|
export_emb(args.output_prefix + ".src", enc_emb)
|
||||||
out.write("{0} {1}\n".format(*enc_emb.shape))
|
export_emb(args.output_prefix + ".trg", dec_emb)
|
||||||
for i in range(enc_emb.shape[0]):
|
|
||||||
vec = " ".join("{0:.8f}".format(v) for v in enc_emb[i])
|
|
||||||
out.write("{0} {1}\n".format(i, vec))
|
|
||||||
|
|
||||||
with open(args.output_prefix + ".trg", "w") as out:
|
|
||||||
out.write("{0} {1}\n".format(*dec_emb.shape))
|
def export_emb(filename, emb):
|
||||||
for i in range(dec_emb.shape[0]):
|
with open(filename, "w") as out:
|
||||||
vec = " ".join("{0:.8f}".format(v) for v in dec_emb[i])
|
out.write("{0} {1}\n".format(*emb.shape))
|
||||||
|
for i in range(emb.shape[0]):
|
||||||
|
vec = " ".join("{0:.8f}".format(v) for v in emb[i])
|
||||||
out.write("{0} {1}\n".format(i, vec))
|
out.write("{0} {1}\n".format(i, vec))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user