# Mirror of https://github.com/facebookresearch/fairseq.git
# (synced 2024-10-04 04:37:58 +03:00), at commit a48f235636.
#
# Commit summary:
#   Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1357
#   Reviewed By: alexeib
#   Differential Revision: D24377772
#   fbshipit-source-id: 51581af041d42d62166b33a35a1a4228b1a76f0c
#
# File: 86 lines, 2.4 KiB, Python.
#!/usr/bin/env python3 -u
|
|
|
|
import argparse
|
|
import fileinput
|
|
import logging
|
|
import os
|
|
import sys
|
|
|
|
from fairseq.models.transformer import TransformerModel
|
|
|
|
|
|
# Lower the root logger's threshold to INFO so the progress messages this
# script emits with logging.info(...) are not filtered out.
logging.root.setLevel(level=logging.INFO)
|
|
|
|
|
|
def _default_user_dir():
    """Return the default location of the ``translation_moe`` task code.

    Resolves to ``<examples>/translation_moe/src``, where ``<examples>`` is the
    parent of the directory that contains this script.
    """
    examples_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    return os.path.join(examples_dir, "translation_moe", "src")


def _parse_args():
    """Parse and validate the command-line arguments.

    When ``--user-dir`` is omitted, ``args.user_dir`` is filled in with
    :func:`_default_user_dir` and that directory must exist.

    Returns:
        argparse.Namespace: the parsed arguments, with ``user_dir`` always set.

    Raises:
        RuntimeError: if ``--user-dir`` was omitted and the default location
            does not exist.
        SystemExit: via ``parser.error`` if ``--num-experts`` is not positive.
    """
    parser = argparse.ArgumentParser(
        description="Generate paraphrases of English sentences by round-trip "
        "translation through French with a mixture-of-experts fr2en model."
    )
    parser.add_argument("--en2fr", required=True, help="path to en2fr model")
    parser.add_argument(
        "--fr2en", required=True, help="path to fr2en mixture of experts model"
    )
    parser.add_argument(
        "--user-dir", help="path to fairseq examples/translation_moe/src directory"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        default=10,
        help="(keep at 10 unless using a different model)",
    )
    parser.add_argument(
        "files",
        nargs="*",
        default=["-"],
        help='input files to paraphrase; "-" for stdin',
    )
    args = parser.parse_args()

    # With zero or negative experts, range() below is empty and every input
    # line would silently produce no output -- reject it up front instead.
    if args.num_experts < 1:
        parser.error("--num-experts must be a positive integer")

    if args.user_dir is None:
        args.user_dir = _default_user_dir()
        if not os.path.exists(args.user_dir):
            raise RuntimeError(
                "cannot find fairseq examples/translation_moe/src "
                "(tried looking here: {})".format(args.user_dir)
            )
        logging.info("found user_dir: %s", args.user_dir)
    return args


def _load_models(args):
    """Load the en->fr model and the fr->en mixture-of-experts model.

    Both models are switched to evaluation mode via ``.eval()``. The fr->en
    model additionally needs ``args.user_dir`` so that fairseq can find the
    ``translation_moe`` task.

    Returns:
        tuple: ``(en2fr, fr2en)``.
    """
    logging.info("loading en2fr model from: %s", args.en2fr)
    en2fr = TransformerModel.from_pretrained(
        model_name_or_path=args.en2fr,
        tokenizer="moses",
        bpe="sentencepiece",
    ).eval()

    logging.info("loading fr2en model from: %s", args.fr2en)
    fr2en = TransformerModel.from_pretrained(
        model_name_or_path=args.fr2en,
        tokenizer="moses",
        bpe="sentencepiece",
        user_dir=args.user_dir,
        task="translation_moe",
    ).eval()
    return en2fr, fr2en


def _gen_paraphrases(en, en2fr, fr2en, num_experts):
    """Return one paraphrase of ``en`` per expert, in expert order.

    ``en`` is translated to French once. That French text is then translated
    back with each expert ``0 .. num_experts - 1`` selected through
    ``inference_step_args``.
    """
    fr = en2fr.translate(en)
    return [
        fr2en.translate(fr, inference_step_args={"expert": i})
        for i in range(num_experts)
    ]


def main():
    """Paraphrase English input by round-trip translation (en->fr->en).

    Reads the files given on the command line (stdin for ``-`` or when none
    are given). Blank lines are skipped. For every other line, prints
    ``--num-experts`` paraphrases, one per output line.
    """
    args = _parse_args()
    en2fr, fr2en = _load_models(args)

    # Only prompt when a human is typing at a terminal; for files or piped
    # input the prompt is just noise on stderr.
    if "-" in args.files and sys.stdin.isatty():
        logging.info("Type the input sentence and press return:")

    # The context manager guarantees the underlying file handles are closed.
    with fileinput.input(args.files) as lines:
        for line in lines:
            line = line.strip()
            if not line:
                continue
            for paraphrase in _gen_paraphrases(
                line, en2fr, fr2en, args.num_experts
            ):
                print(paraphrase)


if __name__ == "__main__":
    main()
|