mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-05 13:17:39 +03:00
Apply black+isort (#1357)
Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1357 Reviewed By: alexeib Differential Revision: D24377772 fbshipit-source-id: 51581af041d42d62166b33a35a1a4228b1a76f0c
This commit is contained in:
parent
5695cdfb2c
commit
a48f235636
57
docs/conf.py
57
docs/conf.py
@ -20,10 +20,11 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# source code directory, relative to this file, for sphinx-autobuild
|
|
||||||
sys.path.insert(0, os.path.abspath('..'))
|
|
||||||
|
|
||||||
source_suffix = ['.rst']
|
# source code directory, relative to this file, for sphinx-autobuild
|
||||||
|
sys.path.insert(0, os.path.abspath(".."))
|
||||||
|
|
||||||
|
source_suffix = [".rst"]
|
||||||
|
|
||||||
# -- General configuration ------------------------------------------------
|
# -- General configuration ------------------------------------------------
|
||||||
|
|
||||||
@ -35,34 +36,34 @@ source_suffix = ['.rst']
|
|||||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||||
# ones.
|
# ones.
|
||||||
extensions = [
|
extensions = [
|
||||||
'sphinx.ext.autodoc',
|
"sphinx.ext.autodoc",
|
||||||
'sphinx.ext.intersphinx',
|
"sphinx.ext.intersphinx",
|
||||||
'sphinx.ext.viewcode',
|
"sphinx.ext.viewcode",
|
||||||
'sphinx.ext.napoleon',
|
"sphinx.ext.napoleon",
|
||||||
'sphinxarg.ext',
|
"sphinxarg.ext",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add any paths that contain templates here, relative to this directory.
|
# Add any paths that contain templates here, relative to this directory.
|
||||||
templates_path = ['_templates']
|
templates_path = ["_templates"]
|
||||||
|
|
||||||
# The master toctree document.
|
# The master toctree document.
|
||||||
master_doc = 'index'
|
master_doc = "index"
|
||||||
|
|
||||||
# General information about the project.
|
# General information about the project.
|
||||||
project = 'fairseq'
|
project = "fairseq"
|
||||||
copyright = '2019, Facebook AI Research (FAIR)'
|
copyright = "2019, Facebook AI Research (FAIR)"
|
||||||
author = 'Facebook AI Research (FAIR)'
|
author = "Facebook AI Research (FAIR)"
|
||||||
|
|
||||||
github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
|
github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/"
|
||||||
|
|
||||||
# The version info for the project you're documenting, acts as replacement for
|
# The version info for the project you're documenting, acts as replacement for
|
||||||
# |version| and |release|, also used in various other places throughout the
|
# |version| and |release|, also used in various other places throughout the
|
||||||
# built documents.
|
# built documents.
|
||||||
#
|
#
|
||||||
# The short X.Y version.
|
# The short X.Y version.
|
||||||
version = '0.9.0'
|
version = "0.9.0"
|
||||||
# The full version, including alpha/beta/rc tags.
|
# The full version, including alpha/beta/rc tags.
|
||||||
release = '0.9.0'
|
release = "0.9.0"
|
||||||
|
|
||||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||||
# for a list of supported languages.
|
# for a list of supported languages.
|
||||||
@ -74,11 +75,11 @@ language = None
|
|||||||
# List of patterns, relative to source directory, that match files and
|
# List of patterns, relative to source directory, that match files and
|
||||||
# directories to ignore when looking for source files.
|
# directories to ignore when looking for source files.
|
||||||
# This patterns also effect to html_static_path and html_extra_path
|
# This patterns also effect to html_static_path and html_extra_path
|
||||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||||
|
|
||||||
# The name of the Pygments (syntax highlighting) style to use.
|
# The name of the Pygments (syntax highlighting) style to use.
|
||||||
pygments_style = 'sphinx'
|
pygments_style = "sphinx"
|
||||||
highlight_language = 'python'
|
highlight_language = "python"
|
||||||
|
|
||||||
# If true, `todo` and `todoList` produce output, else they produce nothing.
|
# If true, `todo` and `todoList` produce output, else they produce nothing.
|
||||||
todo_include_todos = False
|
todo_include_todos = False
|
||||||
@ -89,7 +90,7 @@ todo_include_todos = False
|
|||||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||||
# a list of builtin themes.
|
# a list of builtin themes.
|
||||||
#
|
#
|
||||||
html_theme = 'sphinx_rtd_theme'
|
html_theme = "sphinx_rtd_theme"
|
||||||
|
|
||||||
# Theme options are theme-specific and customize the look and feel of a theme
|
# Theme options are theme-specific and customize the look and feel of a theme
|
||||||
# further. For a list of options available for each theme, see the
|
# further. For a list of options available for each theme, see the
|
||||||
@ -100,11 +101,11 @@ html_theme = 'sphinx_rtd_theme'
|
|||||||
# Add any paths that contain custom static files (such as style sheets) here,
|
# Add any paths that contain custom static files (such as style sheets) here,
|
||||||
# relative to this directory. They are copied after the builtin static files,
|
# relative to this directory. They are copied after the builtin static files,
|
||||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||||
html_static_path = ['_static']
|
html_static_path = ["_static"]
|
||||||
|
|
||||||
html_context = {
|
html_context = {
|
||||||
'css_files': [
|
"css_files": [
|
||||||
'_static/theme_overrides.css', # override wide tables in RTD theme
|
"_static/theme_overrides.css", # override wide tables in RTD theme
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,7 +114,7 @@ html_context = {
|
|||||||
#
|
#
|
||||||
# This is required for the alabaster theme
|
# This is required for the alabaster theme
|
||||||
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
|
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
|
||||||
#html_sidebars = {
|
# html_sidebars = {
|
||||||
# '**': [
|
# '**': [
|
||||||
# 'about.html',
|
# 'about.html',
|
||||||
# 'navigation.html',
|
# 'navigation.html',
|
||||||
@ -121,12 +122,12 @@ html_context = {
|
|||||||
# 'searchbox.html',
|
# 'searchbox.html',
|
||||||
# 'donate.html',
|
# 'donate.html',
|
||||||
# ]
|
# ]
|
||||||
#}
|
# }
|
||||||
|
|
||||||
|
|
||||||
# Example configuration for intersphinx: refer to the Python standard library.
|
# Example configuration for intersphinx: refer to the Python standard library.
|
||||||
intersphinx_mapping = {
|
intersphinx_mapping = {
|
||||||
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
|
"numpy": ("http://docs.scipy.org/doc/numpy/", None),
|
||||||
'python': ('https://docs.python.org/', None),
|
"python": ("https://docs.python.org/", None),
|
||||||
'torch': ('https://pytorch.org/docs/master/', None),
|
"torch": ("https://pytorch.org/docs/master/", None),
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,6 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
__version__ = '0.9.0'
|
__version__ = "0.9.0"
|
||||||
|
|
||||||
import examples.noisychannel # noqa
|
import examples.noisychannel # noqa
|
||||||
|
@ -7,8 +7,8 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import fileinput
|
import fileinput
|
||||||
import hashlib
|
import hashlib
|
||||||
from multiprocessing import Pool
|
|
||||||
import sys
|
import sys
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
|
|
||||||
def get_hashes_and_lines(raw_line):
|
def get_hashes_and_lines(raw_line):
|
||||||
@ -18,12 +18,12 @@ def get_hashes_and_lines(raw_line):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--workers', type=int, default=10)
|
parser.add_argument("--workers", type=int, default=10)
|
||||||
parser.add_argument('files', nargs='*', help='input files')
|
parser.add_argument("files", nargs="*", help="input files")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
seen = set()
|
seen = set()
|
||||||
with fileinput.input(args.files, mode='rb') as h:
|
with fileinput.input(args.files, mode="rb") as h:
|
||||||
pool = Pool(args.workers)
|
pool = Pool(args.workers)
|
||||||
results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
|
results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
|
||||||
for i, (hash, raw_line) in enumerate(results):
|
for i, (hash, raw_line) in enumerate(results):
|
||||||
@ -37,5 +37,5 @@ def main():
|
|||||||
print(file=sys.stderr, flush=True)
|
print(file=sys.stderr, flush=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -11,26 +11,38 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description=(
|
parser = argparse.ArgumentParser(
|
||||||
'Extract back-translations from the stdout of fairseq-generate. '
|
description=(
|
||||||
'If there are multiply hypotheses for a source, we only keep the first one. '
|
"Extract back-translations from the stdout of fairseq-generate. "
|
||||||
))
|
"If there are multiply hypotheses for a source, we only keep the first one. "
|
||||||
parser.add_argument('--output', required=True, help='output prefix')
|
)
|
||||||
parser.add_argument('--srclang', required=True, help='source language (extracted from H-* lines)')
|
)
|
||||||
parser.add_argument('--tgtlang', required=True, help='target language (extracted from S-* lines)')
|
parser.add_argument("--output", required=True, help="output prefix")
|
||||||
parser.add_argument('--minlen', type=int, help='min length filter')
|
parser.add_argument(
|
||||||
parser.add_argument('--maxlen', type=int, help='max length filter')
|
"--srclang", required=True, help="source language (extracted from H-* lines)"
|
||||||
parser.add_argument('--ratio', type=float, help='ratio filter')
|
)
|
||||||
parser.add_argument('files', nargs='*', help='input files')
|
parser.add_argument(
|
||||||
|
"--tgtlang", required=True, help="target language (extracted from S-* lines)"
|
||||||
|
)
|
||||||
|
parser.add_argument("--minlen", type=int, help="min length filter")
|
||||||
|
parser.add_argument("--maxlen", type=int, help="max length filter")
|
||||||
|
parser.add_argument("--ratio", type=float, help="ratio filter")
|
||||||
|
parser.add_argument("files", nargs="*", help="input files")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def validate(src, tgt):
|
def validate(src, tgt):
|
||||||
srclen = len(src.split(' ')) if src != '' else 0
|
srclen = len(src.split(" ")) if src != "" else 0
|
||||||
tgtlen = len(tgt.split(' ')) if tgt != '' else 0
|
tgtlen = len(tgt.split(" ")) if tgt != "" else 0
|
||||||
if (
|
if (
|
||||||
(args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
|
(args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
|
||||||
or (args.maxlen is not None and (srclen > args.maxlen or tgtlen > args.maxlen))
|
or (
|
||||||
or (args.ratio is not None and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio))
|
args.maxlen is not None
|
||||||
|
and (srclen > args.maxlen or tgtlen > args.maxlen)
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
args.ratio is not None
|
||||||
|
and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@ -41,19 +53,20 @@ def main():
|
|||||||
except IndexError:
|
except IndexError:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
with open(args.output + '.' + args.srclang, 'w') as src_h, \
|
with open(args.output + "." + args.srclang, "w") as src_h, open(
|
||||||
open(args.output + '.' + args.tgtlang, 'w') as tgt_h:
|
args.output + "." + args.tgtlang, "w"
|
||||||
|
) as tgt_h:
|
||||||
for line in tqdm(fileinput.input(args.files)):
|
for line in tqdm(fileinput.input(args.files)):
|
||||||
if line.startswith('S-'):
|
if line.startswith("S-"):
|
||||||
tgt = safe_index(line.rstrip().split('\t'), 1, '')
|
tgt = safe_index(line.rstrip().split("\t"), 1, "")
|
||||||
elif line.startswith('H-'):
|
elif line.startswith("H-"):
|
||||||
if tgt is not None:
|
if tgt is not None:
|
||||||
src = safe_index(line.rstrip().split('\t'), 2, '')
|
src = safe_index(line.rstrip().split("\t"), 2, "")
|
||||||
if validate(src, tgt):
|
if validate(src, tgt):
|
||||||
print(src, file=src_h)
|
print(src, file=src_h)
|
||||||
print(tgt, file=tgt_h)
|
print(tgt, file=tgt_h)
|
||||||
tgt = None
|
tgt = None
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -4,203 +4,251 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
import os.path as op
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from multiprocessing import cpu_count
|
import os.path as op
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Optional, List
|
from multiprocessing import cpu_count
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
import sentencepiece as sp
|
import sentencepiece as sp
|
||||||
|
|
||||||
from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
|
|
||||||
from fairseq.data.encoders.byte_utils import byte_encode
|
|
||||||
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
|
|
||||||
from fairseq.data.encoders.characters import Characters
|
|
||||||
from fairseq.data.encoders.byte_bpe import ByteBPE
|
from fairseq.data.encoders.byte_bpe import ByteBPE
|
||||||
|
from fairseq.data.encoders.byte_utils import byte_encode
|
||||||
from fairseq.data.encoders.bytes import Bytes
|
from fairseq.data.encoders.bytes import Bytes
|
||||||
|
from fairseq.data.encoders.characters import Characters
|
||||||
|
from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
|
||||||
|
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
|
||||||
|
|
||||||
|
|
||||||
SPLITS = ['train', 'valid', 'test']
|
SPLITS = ["train", "valid", "test"]
|
||||||
|
|
||||||
|
|
||||||
def _convert_xml(in_path: str, out_path: str):
|
def _convert_xml(in_path: str, out_path: str):
|
||||||
with open(in_path) as f, open(out_path, 'w') as f_o:
|
with open(in_path) as f, open(out_path, "w") as f_o:
|
||||||
for s in f:
|
for s in f:
|
||||||
ss = s.strip()
|
ss = s.strip()
|
||||||
if not ss.startswith('<seg'):
|
if not ss.startswith("<seg"):
|
||||||
continue
|
continue
|
||||||
ss = ss.replace('</seg>', '').split('">')
|
ss = ss.replace("</seg>", "").split('">')
|
||||||
assert len(ss) == 2
|
assert len(ss) == 2
|
||||||
f_o.write(ss[1].strip() + '\n')
|
f_o.write(ss[1].strip() + "\n")
|
||||||
|
|
||||||
|
|
||||||
def _convert_train(in_path: str, out_path: str):
|
def _convert_train(in_path: str, out_path: str):
|
||||||
with open(in_path) as f, open(out_path, 'w') as f_o:
|
with open(in_path) as f, open(out_path, "w") as f_o:
|
||||||
for s in f:
|
for s in f:
|
||||||
ss = s.strip()
|
ss = s.strip()
|
||||||
if ss.startswith('<'):
|
if ss.startswith("<"):
|
||||||
continue
|
continue
|
||||||
f_o.write(ss.strip() + '\n')
|
f_o.write(ss.strip() + "\n")
|
||||||
|
|
||||||
|
|
||||||
def _get_bytes(in_path: str, out_path: str):
|
def _get_bytes(in_path: str, out_path: str):
|
||||||
with open(in_path) as f, open(out_path, 'w') as f_o:
|
with open(in_path) as f, open(out_path, "w") as f_o:
|
||||||
for s in f:
|
for s in f:
|
||||||
f_o.write(Bytes.encode(s.strip()) + '\n')
|
f_o.write(Bytes.encode(s.strip()) + "\n")
|
||||||
|
|
||||||
|
|
||||||
def _get_chars(in_path: str, out_path: str):
|
def _get_chars(in_path: str, out_path: str):
|
||||||
with open(in_path) as f, open(out_path, 'w') as f_o:
|
with open(in_path) as f, open(out_path, "w") as f_o:
|
||||||
for s in f:
|
for s in f:
|
||||||
f_o.write(Characters.encode(s.strip()) + '\n')
|
f_o.write(Characters.encode(s.strip()) + "\n")
|
||||||
|
|
||||||
|
|
||||||
def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
|
def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
|
||||||
Args = namedtuple('Args', ['moses_source_lang', 'moses_target_lang',
|
Args = namedtuple(
|
||||||
'moses_no_dash_splits', 'moses_no_escape'])
|
"Args",
|
||||||
args = Args(moses_source_lang=src, moses_target_lang=tgt,
|
[
|
||||||
moses_no_dash_splits=False, moses_no_escape=False)
|
"moses_source_lang",
|
||||||
|
"moses_target_lang",
|
||||||
|
"moses_no_dash_splits",
|
||||||
|
"moses_no_escape",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
args = Args(
|
||||||
|
moses_source_lang=src,
|
||||||
|
moses_target_lang=tgt,
|
||||||
|
moses_no_dash_splits=False,
|
||||||
|
moses_no_escape=False,
|
||||||
|
)
|
||||||
pretokenizer = MosesTokenizer(args)
|
pretokenizer = MosesTokenizer(args)
|
||||||
with open(in_path) as f, open(out_path, 'w') as f_o:
|
with open(in_path) as f, open(out_path, "w") as f_o:
|
||||||
for s in f:
|
for s in f:
|
||||||
f_o.write(pretokenizer.encode(s.strip()) + '\n')
|
f_o.write(pretokenizer.encode(s.strip()) + "\n")
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
|
def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
|
||||||
with open(out_path, 'w') as f_o:
|
with open(out_path, "w") as f_o:
|
||||||
for lang in [src, tgt]:
|
for lang in [src, tgt]:
|
||||||
with open(f'{in_path_prefix}.{lang}') as f:
|
with open(f"{in_path_prefix}.{lang}") as f:
|
||||||
for s in f:
|
for s in f:
|
||||||
f_o.write(byte_encode(s.strip()) + '\n')
|
f_o.write(byte_encode(s.strip()) + "\n")
|
||||||
|
|
||||||
|
|
||||||
def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
|
def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
|
||||||
arguments = [
|
arguments = [
|
||||||
f'--input={in_path}', f'--model_prefix={model_prefix}',
|
f"--input={in_path}",
|
||||||
f'--model_type=bpe', f'--vocab_size={vocab_size}',
|
f"--model_prefix={model_prefix}",
|
||||||
'--character_coverage=1.0', '--normalization_rule_name=identity',
|
f"--model_type=bpe",
|
||||||
f'--num_threads={cpu_count()}'
|
f"--vocab_size={vocab_size}",
|
||||||
|
"--character_coverage=1.0",
|
||||||
|
"--normalization_rule_name=identity",
|
||||||
|
f"--num_threads={cpu_count()}",
|
||||||
]
|
]
|
||||||
sp.SentencePieceTrainer.Train(' '.join(arguments))
|
sp.SentencePieceTrainer.Train(" ".join(arguments))
|
||||||
|
|
||||||
|
|
||||||
def _apply_bbpe(model_path: str, in_path: str, out_path: str):
|
def _apply_bbpe(model_path: str, in_path: str, out_path: str):
|
||||||
Args = namedtuple('Args', ['sentencepiece_model_path'])
|
Args = namedtuple("Args", ["sentencepiece_model_path"])
|
||||||
args = Args(sentencepiece_model_path=model_path)
|
args = Args(sentencepiece_model_path=model_path)
|
||||||
tokenizer = ByteBPE(args)
|
tokenizer = ByteBPE(args)
|
||||||
with open(in_path) as f, open(out_path, 'w') as f_o:
|
with open(in_path) as f, open(out_path, "w") as f_o:
|
||||||
for s in f:
|
for s in f:
|
||||||
f_o.write(tokenizer.encode(s.strip()) + '\n')
|
f_o.write(tokenizer.encode(s.strip()) + "\n")
|
||||||
|
|
||||||
|
|
||||||
def _apply_bpe(model_path: str, in_path: str, out_path: str):
|
def _apply_bpe(model_path: str, in_path: str, out_path: str):
|
||||||
Args = namedtuple('Args', ['sentencepiece_model'])
|
Args = namedtuple("Args", ["sentencepiece_model"])
|
||||||
args = Args(sentencepiece_model=model_path)
|
args = Args(sentencepiece_model=model_path)
|
||||||
tokenizer = SentencepieceBPE(args)
|
tokenizer = SentencepieceBPE(args)
|
||||||
with open(in_path) as f, open(out_path, 'w') as f_o:
|
with open(in_path) as f, open(out_path, "w") as f_o:
|
||||||
for s in f:
|
for s in f:
|
||||||
f_o.write(tokenizer.encode(s.strip()) + '\n')
|
f_o.write(tokenizer.encode(s.strip()) + "\n")
|
||||||
|
|
||||||
|
|
||||||
def _concat_files(in_paths: List[str], out_path: str):
|
def _concat_files(in_paths: List[str], out_path: str):
|
||||||
with open(out_path, 'w') as f_o:
|
with open(out_path, "w") as f_o:
|
||||||
for p in in_paths:
|
for p in in_paths:
|
||||||
with open(p) as f:
|
with open(p) as f:
|
||||||
for r in f:
|
for r in f:
|
||||||
f_o.write(r)
|
f_o.write(r)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_iwslt17(root: str, src: str, tgt: str, bpe_size: Optional[int],
|
def preprocess_iwslt17(
|
||||||
need_chars: bool, bbpe_size: Optional[int],
|
root: str,
|
||||||
need_bytes: bool):
|
src: str,
|
||||||
|
tgt: str,
|
||||||
|
bpe_size: Optional[int],
|
||||||
|
need_chars: bool,
|
||||||
|
bbpe_size: Optional[int],
|
||||||
|
need_bytes: bool,
|
||||||
|
):
|
||||||
# extract bitext
|
# extract bitext
|
||||||
in_root = op.join(root, f'{src}-{tgt}')
|
in_root = op.join(root, f"{src}-{tgt}")
|
||||||
for lang in [src, tgt]:
|
for lang in [src, tgt]:
|
||||||
_convert_train(
|
_convert_train(
|
||||||
op.join(in_root, f'train.tags.{src}-{tgt}.{lang}'),
|
op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
|
||||||
op.join(root, f'train.{lang}')
|
op.join(root, f"train.{lang}"),
|
||||||
)
|
)
|
||||||
_convert_xml(
|
_convert_xml(
|
||||||
op.join(in_root, f'IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml'),
|
op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
|
||||||
op.join(root, f'valid.{lang}')
|
op.join(root, f"valid.{lang}"),
|
||||||
)
|
)
|
||||||
_convert_xml(
|
_convert_xml(
|
||||||
op.join(in_root, f'IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml'),
|
op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
|
||||||
op.join(root, f'test.{lang}')
|
op.join(root, f"test.{lang}"),
|
||||||
)
|
)
|
||||||
# pre-tokenize
|
# pre-tokenize
|
||||||
for lang in [src, tgt]:
|
for lang in [src, tgt]:
|
||||||
for split in SPLITS:
|
for split in SPLITS:
|
||||||
pretokenize(op.join(root, f'{split}.{lang}'),
|
pretokenize(
|
||||||
op.join(root, f'{split}.moses.{lang}'), src, tgt)
|
op.join(root, f"{split}.{lang}"),
|
||||||
|
op.join(root, f"{split}.moses.{lang}"),
|
||||||
|
src,
|
||||||
|
tgt,
|
||||||
|
)
|
||||||
# tokenize with BPE vocabulary
|
# tokenize with BPE vocabulary
|
||||||
if bpe_size is not None:
|
if bpe_size is not None:
|
||||||
# learn vocabulary
|
# learn vocabulary
|
||||||
concated_train_path = op.join(root, 'train.all')
|
concated_train_path = op.join(root, "train.all")
|
||||||
_concat_files(
|
_concat_files(
|
||||||
[op.join(root, 'train.moses.fr'), op.join(root, 'train.moses.en')],
|
[op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")],
|
||||||
concated_train_path
|
concated_train_path,
|
||||||
)
|
)
|
||||||
bpe_model_prefix = op.join(root, f'spm_bpe{bpe_size}')
|
bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
|
||||||
_get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
|
_get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
|
||||||
os.remove(concated_train_path)
|
os.remove(concated_train_path)
|
||||||
# apply
|
# apply
|
||||||
for lang in [src, tgt]:
|
for lang in [src, tgt]:
|
||||||
for split in SPLITS:
|
for split in SPLITS:
|
||||||
_apply_bpe(
|
_apply_bpe(
|
||||||
bpe_model_prefix + '.model',
|
bpe_model_prefix + ".model",
|
||||||
op.join(root, f'{split}.moses.{lang}'),
|
op.join(root, f"{split}.moses.{lang}"),
|
||||||
op.join(root, f'{split}.moses.bpe{bpe_size}.{lang}')
|
op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
|
||||||
)
|
)
|
||||||
# tokenize with bytes vocabulary
|
# tokenize with bytes vocabulary
|
||||||
if need_bytes:
|
if need_bytes:
|
||||||
for lang in [src, tgt]:
|
for lang in [src, tgt]:
|
||||||
for split in SPLITS:
|
for split in SPLITS:
|
||||||
_get_bytes(op.join(root, f'{split}.moses.{lang}'),
|
_get_bytes(
|
||||||
op.join(root, f'{split}.moses.bytes.{lang}'))
|
op.join(root, f"{split}.moses.{lang}"),
|
||||||
|
op.join(root, f"{split}.moses.bytes.{lang}"),
|
||||||
|
)
|
||||||
# tokenize with characters vocabulary
|
# tokenize with characters vocabulary
|
||||||
if need_chars:
|
if need_chars:
|
||||||
for lang in [src, tgt]:
|
for lang in [src, tgt]:
|
||||||
for split in SPLITS:
|
for split in SPLITS:
|
||||||
_get_chars(op.join(root, f'{split}.moses.{lang}'),
|
_get_chars(
|
||||||
op.join(root, f'{split}.moses.chars.{lang}'))
|
op.join(root, f"{split}.moses.{lang}"),
|
||||||
|
op.join(root, f"{split}.moses.chars.{lang}"),
|
||||||
|
)
|
||||||
# tokenize with byte-level BPE vocabulary
|
# tokenize with byte-level BPE vocabulary
|
||||||
if bbpe_size is not None:
|
if bbpe_size is not None:
|
||||||
# learn vocabulary
|
# learn vocabulary
|
||||||
bchar_path = op.join(root, 'train.bchar')
|
bchar_path = op.join(root, "train.bchar")
|
||||||
_convert_to_bchar(op.join(root, 'train.moses'), src, tgt, bchar_path)
|
_convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
|
||||||
bbpe_model_prefix = op.join(root, f'spm_bbpe{bbpe_size}')
|
bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
|
||||||
_get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
|
_get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
|
||||||
os.remove(bchar_path)
|
os.remove(bchar_path)
|
||||||
# apply
|
# apply
|
||||||
for lang in [src, tgt]:
|
for lang in [src, tgt]:
|
||||||
for split in SPLITS:
|
for split in SPLITS:
|
||||||
_apply_bbpe(
|
_apply_bbpe(
|
||||||
bbpe_model_prefix + '.model',
|
bbpe_model_prefix + ".model",
|
||||||
op.join(root, f'{split}.moses.{lang}'),
|
op.join(root, f"{split}.moses.{lang}"),
|
||||||
op.join(root, f'{split}.moses.bbpe{bbpe_size}.{lang}')
|
op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--root', type=str, default='data')
|
parser.add_argument("--root", type=str, default="data")
|
||||||
parser.add_argument('--bpe-vocab', default=None, type=int,
|
parser.add_argument(
|
||||||
help='Generate tokenized bitext with BPE of size K.'
|
"--bpe-vocab",
|
||||||
'Default to None (disabled).')
|
default=None,
|
||||||
parser.add_argument('--bbpe-vocab', default=None, type=int,
|
type=int,
|
||||||
help='Generate tokenized bitext with BBPE of size K.'
|
help="Generate tokenized bitext with BPE of size K."
|
||||||
'Default to None (disabled).')
|
"Default to None (disabled).",
|
||||||
parser.add_argument('--byte-vocab', action='store_true',
|
)
|
||||||
help='Generate tokenized bitext with bytes vocabulary')
|
parser.add_argument(
|
||||||
parser.add_argument('--char-vocab', action='store_true',
|
"--bbpe-vocab",
|
||||||
help='Generate tokenized bitext with chars vocabulary')
|
default=None,
|
||||||
|
type=int,
|
||||||
|
help="Generate tokenized bitext with BBPE of size K."
|
||||||
|
"Default to None (disabled).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--byte-vocab",
|
||||||
|
action="store_true",
|
||||||
|
help="Generate tokenized bitext with bytes vocabulary",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--char-vocab",
|
||||||
|
action="store_true",
|
||||||
|
help="Generate tokenized bitext with chars vocabulary",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
preprocess_iwslt17(args.root, 'fr', 'en', args.bpe_vocab, args.char_vocab,
|
preprocess_iwslt17(
|
||||||
args.bbpe_vocab, args.byte_vocab)
|
args.root,
|
||||||
|
"fr",
|
||||||
|
"en",
|
||||||
|
args.bpe_vocab,
|
||||||
|
args.char_vocab,
|
||||||
|
args.bbpe_vocab,
|
||||||
|
args.byte_vocab,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from fairseq.models import register_model, register_model_architecture
|
from fairseq.models import register_model, register_model_architecture
|
||||||
from fairseq.models.transformer import TransformerModel, TransformerEncoder
|
from fairseq.models.transformer import TransformerEncoder, TransformerModel
|
||||||
|
|
||||||
|
|
||||||
@register_model("gru_transformer")
|
@register_model("gru_transformer")
|
||||||
@ -24,9 +24,12 @@ class GRUTransformerModel(TransformerModel):
|
|||||||
class GRUTransformerEncoder(TransformerEncoder):
|
class GRUTransformerEncoder(TransformerEncoder):
|
||||||
def __init__(self, args, dictionary, embed_tokens):
|
def __init__(self, args, dictionary, embed_tokens):
|
||||||
super().__init__(args, dictionary, embed_tokens)
|
super().__init__(args, dictionary, embed_tokens)
|
||||||
self.emb_ctx = nn.GRU(input_size=embed_tokens.embedding_dim,
|
self.emb_ctx = nn.GRU(
|
||||||
hidden_size=embed_tokens.embedding_dim // 2,
|
input_size=embed_tokens.embedding_dim,
|
||||||
num_layers=1, bidirectional=True)
|
hidden_size=embed_tokens.embedding_dim // 2,
|
||||||
|
num_layers=1,
|
||||||
|
bidirectional=True,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_embedding(self, src_tokens):
|
def forward_embedding(self, src_tokens):
|
||||||
# embed tokens and positions
|
# embed tokens and positions
|
||||||
|
@ -16,11 +16,12 @@ def main(args):
|
|||||||
print(normalizer.normalize(line.rstrip()), flush=True)
|
print(normalizer.normalize(line.rstrip()), flush=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--lang', '-l', default='en')
|
parser.add_argument("--lang", "-l", default="en")
|
||||||
parser.add_argument('--penn', '-p', action='store_true')
|
parser.add_argument("--penn", "-p", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -6,12 +6,14 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import sacremoses
|
import sacremoses
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
"""Tokenizes, preserving tabs"""
|
"""Tokenizes, preserving tabs"""
|
||||||
mt = sacremoses.MosesTokenizer(lang=args.lang)
|
mt = sacremoses.MosesTokenizer(lang=args.lang)
|
||||||
|
|
||||||
def tok(s):
|
def tok(s):
|
||||||
return mt.tokenize(s, return_str=True)
|
return mt.tokenize(s, return_str=True)
|
||||||
|
|
||||||
@ -20,12 +22,13 @@ def main(args):
|
|||||||
print(*parts, sep="\t", flush=True)
|
print(*parts, sep="\t", flush=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--lang', '-l', default='en')
|
parser.add_argument("--lang", "-l", default="en")
|
||||||
parser.add_argument('--penn', '-p', action='store_true')
|
parser.add_argument("--penn", "-p", action="store_true")
|
||||||
parser.add_argument('--fields', '-f', help="fields to tokenize")
|
parser.add_argument("--fields", "-f", help="fields to tokenize")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -3,14 +3,15 @@
|
|||||||
#
|
#
|
||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
import faiss
|
|
||||||
import numpy as np
|
|
||||||
import glob
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import glob
|
||||||
from subprocess import check_call
|
from subprocess import check_call
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
GB = 1024*1024*1024
|
|
||||||
|
GB = 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
def call(cmd):
|
def call(cmd):
|
||||||
@ -18,14 +19,14 @@ def call(cmd):
|
|||||||
check_call(cmd, shell=True)
|
check_call(cmd, shell=True)
|
||||||
|
|
||||||
|
|
||||||
def get_batches(directory, lang, prefix='all_avg_pool'):
|
def get_batches(directory, lang, prefix="all_avg_pool"):
|
||||||
print(f"Finding in {directory}/{prefix}.{lang}*")
|
print(f"Finding in {directory}/{prefix}.{lang}*")
|
||||||
files = glob.glob(f'{directory}/{prefix}.{lang}*')
|
files = glob.glob(f"{directory}/{prefix}.{lang}*")
|
||||||
emb_files = []
|
emb_files = []
|
||||||
txt_files = []
|
txt_files = []
|
||||||
for emb_fi in files:
|
for emb_fi in files:
|
||||||
emb_files.append(emb_fi)
|
emb_files.append(emb_fi)
|
||||||
txt_fi = emb_fi.replace(prefix, 'sentences')
|
txt_fi = emb_fi.replace(prefix, "sentences")
|
||||||
txt_files.append(txt_fi)
|
txt_files.append(txt_fi)
|
||||||
return emb_files, txt_files
|
return emb_files, txt_files
|
||||||
|
|
||||||
@ -38,7 +39,7 @@ def load_batch(emb_file, dim):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'):
|
def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
|
||||||
sims = []
|
sims = []
|
||||||
inds = []
|
inds = []
|
||||||
xfrom = 0
|
xfrom = 0
|
||||||
@ -53,7 +54,7 @@ def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'):
|
|||||||
y_batch = load_batch(y_batch_f, dim)
|
y_batch = load_batch(y_batch_f, dim)
|
||||||
neighbor_size = min(k, y_batch.shape[0])
|
neighbor_size = min(k, y_batch.shape[0])
|
||||||
yto = yfrom + y_batch.shape[0]
|
yto = yfrom + y_batch.shape[0]
|
||||||
print('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto))
|
print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto))
|
||||||
idx = faiss.IndexFlatIP(dim)
|
idx = faiss.IndexFlatIP(dim)
|
||||||
idx = faiss.index_cpu_to_all_gpus(idx)
|
idx = faiss.index_cpu_to_all_gpus(idx)
|
||||||
idx.add(y_batch)
|
idx.add(y_batch)
|
||||||
@ -86,8 +87,10 @@ def score(sim, fwd_mean, bwd_mean, margin):
|
|||||||
return margin(sim, (fwd_mean + bwd_mean) / 2)
|
return margin(sim, (fwd_mean + bwd_mean) / 2)
|
||||||
|
|
||||||
|
|
||||||
def score_candidates(sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False):
|
def score_candidates(
|
||||||
print(' - scoring {:d} candidates'.format(sim_mat.shape[0]))
|
sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
|
||||||
|
):
|
||||||
|
print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
|
||||||
scores = np.zeros(candidate_inds.shape)
|
scores = np.zeros(candidate_inds.shape)
|
||||||
for i in range(scores.shape[0]):
|
for i in range(scores.shape[0]):
|
||||||
for j in range(scores.shape[1]):
|
for j in range(scores.shape[1]):
|
||||||
@ -106,42 +109,50 @@ def load_text(files):
|
|||||||
return all_sentences
|
return all_sentences
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='Mine bitext')
|
parser = argparse.ArgumentParser(description="Mine bitext")
|
||||||
parser.add_argument('--src-lang', help='Source language')
|
parser.add_argument("--src-lang", help="Source language")
|
||||||
parser.add_argument('--tgt-lang', help='Target language')
|
parser.add_argument("--tgt-lang", help="Target language")
|
||||||
parser.add_argument('--dict-path', help='Path to dictionary file', default='dict.txt')
|
parser.add_argument(
|
||||||
parser.add_argument('--spm-path', help='Path to SPM model file', default='sentence.bpe.model')
|
"--dict-path", help="Path to dictionary file", default="dict.txt"
|
||||||
parser.add_argument('--dim', type=int, default=1024,
|
)
|
||||||
help='Embedding dimension')
|
parser.add_argument(
|
||||||
parser.add_argument('--mem', type=int, default=5,
|
"--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
|
||||||
help='Memory in GB')
|
)
|
||||||
parser.add_argument('--src-dir', help='Source directory')
|
parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
|
||||||
parser.add_argument('--tgt-dir', help='Target directory')
|
parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
|
||||||
parser.add_argument('--output', help='Output path')
|
parser.add_argument("--src-dir", help="Source directory")
|
||||||
parser.add_argument('--neighborhood', type=int, default=4,
|
parser.add_argument("--tgt-dir", help="Target directory")
|
||||||
help='Embedding dimension')
|
parser.add_argument("--output", help="Output path")
|
||||||
parser.add_argument('--threshold', type=float, default=1.06,
|
parser.add_argument(
|
||||||
help='Threshold on mined bitext')
|
"--neighborhood", type=int, default=4, help="Embedding dimension"
|
||||||
parser.add_argument('--valid-size', type=int, default=2000,
|
)
|
||||||
help='Number of sentences used for validation set')
|
parser.add_argument(
|
||||||
parser.add_argument('--min-count', type=int, default=50000,
|
"--threshold", type=float, default=1.06, help="Threshold on mined bitext"
|
||||||
help='Min num sentences used for each language')
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--valid-size",
|
||||||
|
type=int,
|
||||||
|
default=2000,
|
||||||
|
help="Number of sentences used for validation set",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min-count",
|
||||||
|
type=int,
|
||||||
|
default=50000,
|
||||||
|
help="Min num sentences used for each language",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
|
x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
|
||||||
y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
|
y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
|
||||||
margin = lambda a, b: a / b
|
margin = lambda a, b: a / b
|
||||||
y2x_sim, y2x_ind = knnGPU_sharded(
|
y2x_sim, y2x_ind = knnGPU_sharded(
|
||||||
y_batches_f, x_batches_f,
|
y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
|
||||||
args.dim,
|
)
|
||||||
args.neighborhood,
|
|
||||||
direction='y2x')
|
|
||||||
x2y_sim, x2y_ind = knnGPU_sharded(
|
x2y_sim, x2y_ind = knnGPU_sharded(
|
||||||
x_batches_f, y_batches_f,
|
x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
|
||||||
args.dim,
|
)
|
||||||
args.neighborhood,
|
|
||||||
direction='x2y')
|
|
||||||
|
|
||||||
x2y_mean = x2y_sim.mean(axis=1)
|
x2y_mean = x2y_sim.mean(axis=1)
|
||||||
y2x_mean = y2x_sim.mean(axis=1)
|
y2x_mean = y2x_sim.mean(axis=1)
|
||||||
@ -149,8 +160,13 @@ if __name__ == '__main__':
|
|||||||
bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
|
bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
|
||||||
fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
|
fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
|
||||||
bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
|
bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
|
||||||
indices = np.stack((np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
|
indices = np.stack(
|
||||||
np.concatenate((fwd_best, np.arange(y2x_ind.shape[0])))), axis=1)
|
(
|
||||||
|
np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
|
||||||
|
np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
|
||||||
|
),
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
|
scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
|
||||||
|
|
||||||
x_sentences = load_text(x_sents_f)
|
x_sentences = load_text(x_sents_f)
|
||||||
@ -162,20 +178,20 @@ if __name__ == '__main__':
|
|||||||
directory = args.output
|
directory = args.output
|
||||||
call(f"mkdir -p {directory}")
|
call(f"mkdir -p {directory}")
|
||||||
src_out = open(
|
src_out = open(
|
||||||
f'{directory}/all.{args.src_lang}',
|
f"{directory}/all.{args.src_lang}",
|
||||||
mode='w',
|
mode="w",
|
||||||
encoding='utf-8',
|
encoding="utf-8",
|
||||||
errors='surrogateescape')
|
errors="surrogateescape",
|
||||||
|
)
|
||||||
tgt_out = open(
|
tgt_out = open(
|
||||||
f'{directory}/all.{args.tgt_lang}',
|
f"{directory}/all.{args.tgt_lang}",
|
||||||
mode='w',
|
mode="w",
|
||||||
encoding='utf-8',
|
encoding="utf-8",
|
||||||
errors='surrogateescape')
|
errors="surrogateescape",
|
||||||
|
)
|
||||||
scores_out = open(
|
scores_out = open(
|
||||||
f'{directory}/all.scores',
|
f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
|
||||||
mode='w',
|
)
|
||||||
encoding='utf-8',
|
|
||||||
errors='surrogateescape')
|
|
||||||
count = 0
|
count = 0
|
||||||
for i in np.argsort(-scores):
|
for i in np.argsort(-scores):
|
||||||
src_ind, trg_ind = indices[i]
|
src_ind, trg_ind = indices[i]
|
||||||
@ -195,20 +211,23 @@ if __name__ == '__main__':
|
|||||||
scores_out.close()
|
scores_out.close()
|
||||||
|
|
||||||
print(f"Found {count} pairs for threshold={threshold}")
|
print(f"Found {count} pairs for threshold={threshold}")
|
||||||
with open(f'{directory}/all.{args.src_lang}') as all_s, \
|
with open(f"{directory}/all.{args.src_lang}") as all_s, open(
|
||||||
open(f'{directory}/all.{args.tgt_lang}') as all_t, \
|
f"{directory}/all.{args.tgt_lang}"
|
||||||
open(f'{directory}/valid.{args.src_lang}', 'w') as valid_s, \
|
) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
|
||||||
open(f'{directory}/valid.{args.tgt_lang}', 'w') as valid_t, \
|
f"{directory}/valid.{args.tgt_lang}", "w"
|
||||||
open(f'{directory}/train.{args.src_lang}', 'w') as train_s, \
|
) as valid_t, open(
|
||||||
open(f'{directory}/train.{args.tgt_lang}', 'w') as train_t:
|
f"{directory}/train.{args.src_lang}", "w"
|
||||||
count = 0
|
) as train_s, open(
|
||||||
for s_line, t_line in zip(all_s, all_t):
|
f"{directory}/train.{args.tgt_lang}", "w"
|
||||||
s_line = s_line.split('\t')[1]
|
) as train_t:
|
||||||
t_line = t_line.split('\t')[1]
|
count = 0
|
||||||
if count >= args.valid_size:
|
for s_line, t_line in zip(all_s, all_t):
|
||||||
train_s.write(s_line)
|
s_line = s_line.split("\t")[1]
|
||||||
train_t.write(t_line)
|
t_line = t_line.split("\t")[1]
|
||||||
else:
|
if count >= args.valid_size:
|
||||||
valid_s.write(s_line)
|
train_s.write(s_line)
|
||||||
valid_t.write(t_line)
|
train_t.write(t_line)
|
||||||
count += 1
|
else:
|
||||||
|
valid_s.write(s_line)
|
||||||
|
valid_t.write(t_line)
|
||||||
|
count += 1
|
||||||
|
@ -7,27 +7,29 @@
|
|||||||
Translate pre-processed data with a trained model.
|
Translate pre-processed data with a trained model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
|
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
|
||||||
from fairseq.sequence_generator import EnsembleModel
|
from fairseq.sequence_generator import EnsembleModel
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False):
|
def get_avg_pool(
|
||||||
|
models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
|
||||||
|
):
|
||||||
model = EnsembleModel(models)
|
model = EnsembleModel(models)
|
||||||
|
|
||||||
# model.forward normally channels prev_output_tokens into the decoder
|
# model.forward normally channels prev_output_tokens into the decoder
|
||||||
# separately, but SequenceGenerator directly calls model.encoder
|
# separately, but SequenceGenerator directly calls model.encoder
|
||||||
encoder_input = {
|
encoder_input = {
|
||||||
k: v for k, v in sample['net_input'].items()
|
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
|
||||||
if k != 'prev_output_tokens'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# compute the encoder output for each beam
|
# compute the encoder output for each beam
|
||||||
encoder_outs = model.forward_encoder(encoder_input)
|
encoder_outs = model.forward_encoder(encoder_input)
|
||||||
np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
|
np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
|
||||||
encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(np.float32)
|
encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
|
||||||
|
np.float32
|
||||||
|
)
|
||||||
encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
|
encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
|
||||||
if has_langtok:
|
if has_langtok:
|
||||||
encoder_mask = encoder_mask[1:, :, :]
|
encoder_mask = encoder_mask[1:, :, :]
|
||||||
@ -38,13 +40,15 @@ def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langto
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
assert args.path is not None, '--path required for generation!'
|
assert args.path is not None, "--path required for generation!"
|
||||||
assert not args.sampling or args.nbest == args.beam, \
|
assert (
|
||||||
'--sampling requires --nbest to be equal to --beam'
|
not args.sampling or args.nbest == args.beam
|
||||||
assert args.replace_unk is None or args.raw_text, \
|
), "--sampling requires --nbest to be equal to --beam"
|
||||||
'--replace-unk requires a raw text dataset (--raw-text)'
|
assert (
|
||||||
|
args.replace_unk is None or args.raw_text
|
||||||
|
), "--replace-unk requires a raw text dataset (--raw-text)"
|
||||||
|
|
||||||
args.beam=1
|
args.beam = 1
|
||||||
utils.import_user_module(args)
|
utils.import_user_module(args)
|
||||||
|
|
||||||
if args.max_tokens is None:
|
if args.max_tokens is None:
|
||||||
@ -58,15 +62,15 @@ def main(args):
|
|||||||
|
|
||||||
# Set dictionaries
|
# Set dictionaries
|
||||||
try:
|
try:
|
||||||
src_dict = getattr(task, 'source_dictionary', None)
|
src_dict = getattr(task, "source_dictionary", None)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
src_dict = None
|
src_dict = None
|
||||||
tgt_dict = task.target_dictionary
|
tgt_dict = task.target_dictionary
|
||||||
|
|
||||||
# Load ensemble
|
# Load ensemble
|
||||||
print('| loading model(s) from {}'.format(args.path))
|
print("| loading model(s) from {}".format(args.path))
|
||||||
models, _model_args = checkpoint_utils.load_model_ensemble(
|
models, _model_args = checkpoint_utils.load_model_ensemble(
|
||||||
args.path.split(':'),
|
args.path.split(":"),
|
||||||
arg_overrides=eval(args.model_overrides),
|
arg_overrides=eval(args.model_overrides),
|
||||||
task=task,
|
task=task,
|
||||||
)
|
)
|
||||||
@ -105,9 +109,9 @@ def main(args):
|
|||||||
shard_id = 0
|
shard_id = 0
|
||||||
all_avg_pool = None
|
all_avg_pool = None
|
||||||
encoder_has_langtok = (
|
encoder_has_langtok = (
|
||||||
hasattr(task.args, 'encoder_langtok')
|
hasattr(task.args, "encoder_langtok")
|
||||||
and task.args.encoder_langtok is not None
|
and task.args.encoder_langtok is not None
|
||||||
and hasattr(task.args, 'lang_tok_replacing_bos_eos')
|
and hasattr(task.args, "lang_tok_replacing_bos_eos")
|
||||||
and not task.args.lang_tok_replacing_bos_eos
|
and not task.args.lang_tok_replacing_bos_eos
|
||||||
)
|
)
|
||||||
with progress_bar.build_progress_bar(args, itr) as t:
|
with progress_bar.build_progress_bar(args, itr) as t:
|
||||||
@ -116,34 +120,42 @@ def main(args):
|
|||||||
print("Skipping None")
|
print("Skipping None")
|
||||||
continue
|
continue
|
||||||
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
||||||
if 'net_input' not in sample:
|
if "net_input" not in sample:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
prefix_tokens = None
|
prefix_tokens = None
|
||||||
if args.prefix_size > 0:
|
if args.prefix_size > 0:
|
||||||
prefix_tokens = sample['target'][:, :args.prefix_size]
|
prefix_tokens = sample["target"][:, : args.prefix_size]
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
avg_pool = get_avg_pool(
|
avg_pool = get_avg_pool(
|
||||||
models, sample, prefix_tokens, src_dict,
|
models,
|
||||||
args.remove_bpe,
|
sample,
|
||||||
has_langtok=encoder_has_langtok)
|
prefix_tokens,
|
||||||
|
src_dict,
|
||||||
|
args.remove_bpe,
|
||||||
|
has_langtok=encoder_has_langtok,
|
||||||
|
)
|
||||||
if all_avg_pool is not None:
|
if all_avg_pool is not None:
|
||||||
all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
|
all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
|
||||||
else:
|
else:
|
||||||
all_avg_pool = avg_pool
|
all_avg_pool = avg_pool
|
||||||
|
|
||||||
if not isinstance(sample['id'], list):
|
if not isinstance(sample["id"], list):
|
||||||
sample_ids = sample['id'].tolist()
|
sample_ids = sample["id"].tolist()
|
||||||
else:
|
else:
|
||||||
sample_ids = sample['id']
|
sample_ids = sample["id"]
|
||||||
for i, sample_id in enumerate(sample_ids):
|
for i, sample_id in enumerate(sample_ids):
|
||||||
# Remove padding
|
# Remove padding
|
||||||
src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
|
src_tokens = utils.strip_pad(
|
||||||
|
sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
|
||||||
|
)
|
||||||
|
|
||||||
# Either retrieve the original sentences or regenerate them from tokens.
|
# Either retrieve the original sentences or regenerate them from tokens.
|
||||||
if align_dict is not None:
|
if align_dict is not None:
|
||||||
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
|
src_str = task.dataset(args.gen_subset).src.get_original_text(
|
||||||
|
sample_id
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if src_dict is not None:
|
if src_dict is not None:
|
||||||
src_str = src_dict.string(src_tokens, args.remove_bpe)
|
src_str = src_dict.string(src_tokens, args.remove_bpe)
|
||||||
@ -152,37 +164,50 @@ def main(args):
|
|||||||
|
|
||||||
if not args.quiet:
|
if not args.quiet:
|
||||||
if src_dict is not None:
|
if src_dict is not None:
|
||||||
print('S-{}\t{}'.format(sample_id, src_str))
|
print("S-{}\t{}".format(sample_id, src_str))
|
||||||
|
|
||||||
source_sentences.append(f"{sample_id}\t{src_str}")
|
source_sentences.append(f"{sample_id}\t{src_str}")
|
||||||
|
|
||||||
num_sentences += sample['nsentences']
|
num_sentences += sample["nsentences"]
|
||||||
if all_avg_pool.shape[0] >= 1000000:
|
if all_avg_pool.shape[0] >= 1000000:
|
||||||
with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
|
with open(
|
||||||
'w') as avg_pool_file:
|
f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
|
||||||
|
"w",
|
||||||
|
) as avg_pool_file:
|
||||||
all_avg_pool.tofile(avg_pool_file)
|
all_avg_pool.tofile(avg_pool_file)
|
||||||
with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
|
with open(
|
||||||
sentence_file.writelines(f'{line}\n' for line in source_sentences)
|
f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
|
||||||
|
"w",
|
||||||
|
) as sentence_file:
|
||||||
|
sentence_file.writelines(f"{line}\n" for line in source_sentences)
|
||||||
all_avg_pool = None
|
all_avg_pool = None
|
||||||
source_sentences = []
|
source_sentences = []
|
||||||
shard_id += 1
|
shard_id += 1
|
||||||
|
|
||||||
if all_avg_pool is not None:
|
if all_avg_pool is not None:
|
||||||
with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
|
with open(
|
||||||
'w') as avg_pool_file:
|
f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
|
||||||
|
) as avg_pool_file:
|
||||||
all_avg_pool.tofile(avg_pool_file)
|
all_avg_pool.tofile(avg_pool_file)
|
||||||
with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
|
with open(
|
||||||
sentence_file.writelines(f'{line}\n' for line in source_sentences)
|
f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
|
||||||
|
) as sentence_file:
|
||||||
|
sentence_file.writelines(f"{line}\n" for line in source_sentences)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def cli_main():
|
def cli_main():
|
||||||
parser = options.get_generation_parser()
|
parser = options.get_generation_parser()
|
||||||
parser.add_argument('--encoder-save-dir', default='', type=str, metavar='N',
|
parser.add_argument(
|
||||||
help='directory to save encoder outputs')
|
"--encoder-save-dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
metavar="N",
|
||||||
|
help="directory to save encoder outputs",
|
||||||
|
)
|
||||||
args = options.parse_args_and_arch(parser)
|
args = options.parse_args_and_arch(parser)
|
||||||
main(args)
|
main(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
cli_main()
|
cli_main()
|
||||||
|
@ -3,10 +3,11 @@
|
|||||||
#
|
#
|
||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
import numpy as np
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
DIM = 1024
|
DIM = 1024
|
||||||
|
|
||||||
@ -14,9 +15,13 @@ DIM = 1024
|
|||||||
def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
|
def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
|
||||||
target_ids = [tid for tid in target_embs]
|
target_ids = [tid for tid in target_embs]
|
||||||
source_mat = np.stack(source_embs.values(), axis=0)
|
source_mat = np.stack(source_embs.values(), axis=0)
|
||||||
normalized_source_mat = source_mat / np.linalg.norm(source_mat, axis=1, keepdims=True)
|
normalized_source_mat = source_mat / np.linalg.norm(
|
||||||
|
source_mat, axis=1, keepdims=True
|
||||||
|
)
|
||||||
target_mat = np.stack(target_embs.values(), axis=0)
|
target_mat = np.stack(target_embs.values(), axis=0)
|
||||||
normalized_target_mat = target_mat / np.linalg.norm(target_mat, axis=1, keepdims=True)
|
normalized_target_mat = target_mat / np.linalg.norm(
|
||||||
|
target_mat, axis=1, keepdims=True
|
||||||
|
)
|
||||||
sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
|
sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
|
||||||
if return_sim_mat:
|
if return_sim_mat:
|
||||||
return sim_mat
|
return sim_mat
|
||||||
@ -36,14 +41,14 @@ def load_embeddings(directory, LANGS):
|
|||||||
lang_dir = f"{directory}/{lang}"
|
lang_dir = f"{directory}/{lang}"
|
||||||
embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
|
embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
|
||||||
for embed_file in embedding_files:
|
for embed_file in embedding_files:
|
||||||
shard_id = embed_file.split('.')[-1]
|
shard_id = embed_file.split(".")[-1]
|
||||||
embeddings = np.fromfile(embed_file, dtype=np.float32)
|
embeddings = np.fromfile(embed_file, dtype=np.float32)
|
||||||
num_rows = embeddings.shape[0] // DIM
|
num_rows = embeddings.shape[0] // DIM
|
||||||
embeddings = embeddings.reshape((num_rows, DIM))
|
embeddings = embeddings.reshape((num_rows, DIM))
|
||||||
|
|
||||||
with open(f'{lang_dir}/sentences.{lang}.{shard_id}') as sentence_file:
|
with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
|
||||||
for idx, line in enumerate(sentence_file):
|
for idx, line in enumerate(sentence_file):
|
||||||
sentence_id, sentence = line.strip().split('\t')
|
sentence_id, sentence = line.strip().split("\t")
|
||||||
sentence_texts[lang][sentence_id] = sentence
|
sentence_texts[lang][sentence_id] = sentence
|
||||||
sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
|
sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
|
||||||
|
|
||||||
@ -55,7 +60,7 @@ def compute_accuracy(directory, LANGS):
|
|||||||
|
|
||||||
top_1_accuracy = {}
|
top_1_accuracy = {}
|
||||||
|
|
||||||
top1_str = " ".join(LANGS) + '\n'
|
top1_str = " ".join(LANGS) + "\n"
|
||||||
for source_lang in LANGS:
|
for source_lang in LANGS:
|
||||||
top_1_accuracy[source_lang] = {}
|
top_1_accuracy[source_lang] = {}
|
||||||
top1_str += f"{source_lang} "
|
top1_str += f"{source_lang} "
|
||||||
@ -63,8 +68,8 @@ def compute_accuracy(directory, LANGS):
|
|||||||
top1 = 0
|
top1 = 0
|
||||||
top5 = 0
|
top5 = 0
|
||||||
neighbors_map = compute_dist(
|
neighbors_map = compute_dist(
|
||||||
sentence_embeddings[source_lang],
|
sentence_embeddings[source_lang], sentence_embeddings[target_lang]
|
||||||
sentence_embeddings[target_lang])
|
)
|
||||||
for sentence_id, neighbors in neighbors_map.items():
|
for sentence_id, neighbors in neighbors_map.items():
|
||||||
if sentence_id == neighbors[0]:
|
if sentence_id == neighbors[0]:
|
||||||
top1 += 1
|
top1 += 1
|
||||||
@ -75,17 +80,13 @@ def compute_accuracy(directory, LANGS):
|
|||||||
top1_str += "\n"
|
top1_str += "\n"
|
||||||
|
|
||||||
print(top1_str)
|
print(top1_str)
|
||||||
print(top1_str, file=open(f"{directory}/accuracy", 'w'))
|
print(top1_str, file=open(f"{directory}/accuracy", "w"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='Analyze encoder outputs')
|
parser = argparse.ArgumentParser(description="Analyze encoder outputs")
|
||||||
parser.add_argument('directory',
|
parser.add_argument("directory", help="Source language corpus")
|
||||||
help='Source language corpus'
|
parser.add_argument("--langs", help="List of langs")
|
||||||
)
|
|
||||||
parser.add_argument('--langs',
|
|
||||||
help='List of langs'
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
langs = args.langs.split(',')
|
langs = args.langs.split(",")
|
||||||
compute_accuracy(args.directory, langs)
|
compute_accuracy(args.directory, langs)
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from .models import latent_multilingual_transformer # noqa
|
from . import multilingual_translation_latent_depth # noqa
|
||||||
from .modules import latent_layers # noqa
|
from .loss import latent_depth # noqa
|
||||||
from .loss import latent_depth # noqa
|
from .models import latent_multilingual_transformer # noqa
|
||||||
from . import multilingual_translation_latent_depth # noqa
|
from .modules import latent_layers # noqa
|
||||||
|
@ -3,8 +3,9 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import torch
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.nn.modules.loss import _Loss
|
from torch.nn.modules.loss import _Loss
|
||||||
|
|
||||||
|
|
||||||
@ -19,17 +20,16 @@ class LatentLayersKLLoss(_Loss):
|
|||||||
eps = 1e-7
|
eps = 1e-7
|
||||||
if prior == "uniform":
|
if prior == "uniform":
|
||||||
# uniform prior
|
# uniform prior
|
||||||
kl_loss = (samples * (
|
kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
|
||||||
torch.log(samples + eps) - math.log(0.5)
|
|
||||||
)).sum(-1)
|
|
||||||
elif prior == "agged_posterior":
|
elif prior == "agged_posterior":
|
||||||
# aggregated posterior
|
# aggregated posterior
|
||||||
y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
|
y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
|
||||||
agged_q = torch.sum(y_t, dim=0)
|
agged_q = torch.sum(y_t, dim=0)
|
||||||
row_norm = agged_q.sum(-1)
|
row_norm = agged_q.sum(-1)
|
||||||
normed_agg_q = agged_q / row_norm
|
normed_agg_q = agged_q / row_norm
|
||||||
kl_loss = (samples * (
|
kl_loss = (
|
||||||
torch.log(samples + eps) - torch.log(normed_agg_q + eps))).sum(-1)
|
samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
|
||||||
|
).sum(-1)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The specified prior is not implemented.")
|
raise NotImplementedError("The specified prior is not implemented.")
|
||||||
|
|
||||||
@ -37,7 +37,9 @@ class LatentLayersKLLoss(_Loss):
|
|||||||
kl_loss /= layer_samples[0].size()[0]
|
kl_loss /= layer_samples[0].size()[0]
|
||||||
kl_weight = min(
|
kl_weight = min(
|
||||||
self.args.sparsity_weight,
|
self.args.sparsity_weight,
|
||||||
(update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates
|
(update_num - self.args.soft_update)
|
||||||
|
* self.args.sparsity_weight
|
||||||
|
/ self.args.anneal_updates,
|
||||||
)
|
)
|
||||||
kl_loss *= kl_weight * sample_size
|
kl_loss *= kl_weight * sample_size
|
||||||
return kl_loss
|
return kl_loss
|
||||||
@ -58,15 +60,17 @@ class LatentLayersSparsityLoss(_Loss):
|
|||||||
share_loss = 0
|
share_loss = 0
|
||||||
global_sparsity_loss = 0
|
global_sparsity_loss = 0
|
||||||
layer_samples = torch.stack(layer_samples_list, dim=0)
|
layer_samples = torch.stack(layer_samples_list, dim=0)
|
||||||
if ((self.args.target_layers > 0 or self.args.share_weight > 0) and
|
if (
|
||||||
update_num > (self.args.soft_update + self.args.anneal_updates)):
|
self.args.target_layers > 0 or self.args.share_weight > 0
|
||||||
|
) and update_num > (self.args.soft_update + self.args.anneal_updates):
|
||||||
# anneal sparsity weight
|
# anneal sparsity weight
|
||||||
if update_num < (self.args.anneal_updates + self.args.soft_update):
|
if update_num < (self.args.anneal_updates + self.args.soft_update):
|
||||||
weight_anneal = 0
|
weight_anneal = 0
|
||||||
elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
|
elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
|
||||||
weight_anneal = (
|
weight_anneal = (
|
||||||
(update_num - self.args.soft_update - self.args.anneal_updates)
|
(update_num - self.args.soft_update - self.args.anneal_updates)
|
||||||
* self.args.share_weight / self.args.anneal_updates
|
* self.args.share_weight
|
||||||
|
/ self.args.anneal_updates
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
weight_anneal = 1
|
weight_anneal = 1
|
||||||
@ -75,12 +79,21 @@ class LatentLayersSparsityLoss(_Loss):
|
|||||||
layer_utilization /= layer_samples.size()[0]
|
layer_utilization /= layer_samples.size()[0]
|
||||||
if self.args.share_weight > 0:
|
if self.args.share_weight > 0:
|
||||||
# encouraging sharing across languages
|
# encouraging sharing across languages
|
||||||
share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0)
|
share_loss = sum(
|
||||||
batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss
|
-1.0 * v * math.log(v) for v in layer_utilization if v > 0
|
||||||
|
)
|
||||||
|
batch_loss += (
|
||||||
|
weight_anneal * self.args.share_weight * sample_size * share_loss
|
||||||
|
)
|
||||||
if self.args.target_layers > 0:
|
if self.args.target_layers > 0:
|
||||||
# computed expected number of layers selected
|
# computed expected number of layers selected
|
||||||
expeted_layers = sum(layer_utilization)
|
expeted_layers = sum(layer_utilization)
|
||||||
# compute l2 loss wrt target number of layers
|
# compute l2 loss wrt target number of layers
|
||||||
global_sparsity_loss = (expeted_layers - self.args.target_layers) ** 2
|
global_sparsity_loss = (expeted_layers - self.args.target_layers) ** 2
|
||||||
batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss
|
batch_loss += (
|
||||||
|
weight_anneal
|
||||||
|
* self.args.share_weight
|
||||||
|
* sample_size
|
||||||
|
* global_sparsity_loss
|
||||||
|
)
|
||||||
return batch_loss
|
return batch_loss
|
||||||
|
@ -3,34 +3,31 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from fairseq.models import (
|
from fairseq.models import register_model, register_model_architecture
|
||||||
register_model,
|
|
||||||
register_model_architecture,
|
|
||||||
)
|
|
||||||
from fairseq.models.transformer import (
|
|
||||||
base_architecture,
|
|
||||||
TransformerEncoder,
|
|
||||||
TransformerDecoder,
|
|
||||||
)
|
|
||||||
from fairseq.models.multilingual_transformer import MultilingualTransformerModel
|
from fairseq.models.multilingual_transformer import MultilingualTransformerModel
|
||||||
|
from fairseq.models.transformer import (
|
||||||
from .latent_transformer import (
|
TransformerDecoder,
|
||||||
LatentTransformerEncoder,
|
TransformerEncoder,
|
||||||
LatentTransformerDecoder,
|
base_architecture,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder
|
||||||
|
|
||||||
@register_model('latent_multilingual_transformer')
|
|
||||||
|
@register_model("latent_multilingual_transformer")
|
||||||
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
|
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
|
||||||
"""A variant of standard multilingual Transformer models which encoder and/or
|
"""A variant of standard multilingual Transformer models which encoder and/or
|
||||||
decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
|
decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
|
||||||
(https://arxiv.org/abs/2009.13102).
|
(https://arxiv.org/abs/2009.13102).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
|
def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
|
||||||
if is_encoder:
|
if is_encoder:
|
||||||
if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
|
if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
|
||||||
return LatentTransformerEncoder(args, lang_dict, embed_tokens, num_logits=len(langs))
|
return LatentTransformerEncoder(
|
||||||
|
args, lang_dict, embed_tokens, num_logits=len(langs)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return TransformerEncoder(args, lang_dict, embed_tokens)
|
return TransformerEncoder(args, lang_dict, embed_tokens)
|
||||||
else:
|
else:
|
||||||
@ -42,19 +39,21 @@ class LatentMultilingualTransformerModel(MultilingualTransformerModel):
|
|||||||
return TransformerDecoder(args, lang_dict, embed_tokens)
|
return TransformerDecoder(args, lang_dict, embed_tokens)
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer')
|
@register_model_architecture(
|
||||||
|
"latent_multilingual_transformer", "latent_multilingual_transformer"
|
||||||
|
)
|
||||||
def latent_multilingual_architecture(args):
|
def latent_multilingual_architecture(args):
|
||||||
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
|
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
||||||
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
|
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
|
||||||
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
|
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
|
||||||
args.encoder_layers = getattr(args, 'encoder_layers', 12)
|
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
||||||
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
|
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
|
||||||
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
|
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
|
||||||
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
|
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
|
||||||
args.decoder_layers = getattr(args, 'decoder_layers', 24)
|
args.decoder_layers = getattr(args, "decoder_layers", 24)
|
||||||
args.share_encoders = getattr(args, 'share_encoders', True)
|
args.share_encoders = getattr(args, "share_encoders", True)
|
||||||
args.share_decoders = getattr(args, 'share_decoders', True)
|
args.share_decoders = getattr(args, "share_decoders", True)
|
||||||
args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True)
|
args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
|
||||||
args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True)
|
args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)
|
||||||
|
|
||||||
base_architecture(args)
|
base_architecture(args)
|
||||||
|
@ -7,26 +7,27 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from fairseq.models.fairseq_encoder import EncoderOut
|
from fairseq.models.fairseq_encoder import EncoderOut
|
||||||
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
|
from fairseq.models.transformer import TransformerDecoder, TransformerEncoder
|
||||||
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
|
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
|
||||||
from ..modules.latent_layers import LayerSelect
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from ..modules.latent_layers import LayerSelect
|
||||||
|
|
||||||
|
|
||||||
class LatentTransformerEncoder(TransformerEncoder):
|
class LatentTransformerEncoder(TransformerEncoder):
|
||||||
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
|
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
|
||||||
TransformerEncoder.
|
TransformerEncoder.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, args, dictionary, embed_tokens, num_logits=1):
|
def __init__(self, args, dictionary, embed_tokens, num_logits=1):
|
||||||
self.num_logits = num_logits
|
self.num_logits = num_logits
|
||||||
self.num_layers = args.encoder_layers
|
self.num_layers = args.encoder_layers
|
||||||
super().__init__(args, dictionary, embed_tokens)
|
super().__init__(args, dictionary, embed_tokens)
|
||||||
self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
|
self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
|
||||||
self.lang_idx = None
|
self.lang_idx = None
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList(
|
||||||
self._build_encoder_layer(args, idx)
|
[self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)]
|
||||||
for idx in range(args.encoder_layers)
|
)
|
||||||
])
|
|
||||||
|
|
||||||
def set_lang_idx(self, lang_idx):
|
def set_lang_idx(self, lang_idx):
|
||||||
self.lang_idx = lang_idx
|
self.lang_idx = lang_idx
|
||||||
@ -50,6 +51,7 @@ class LatentTransformerEncoderLayer(TransformerEncoderLayer):
|
|||||||
layer_select (LayerSelect, optional): instance of LayerSelect module with logits
|
layer_select (LayerSelect, optional): instance of LayerSelect module with logits
|
||||||
parameters and sampling method.
|
parameters and sampling method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, args, idx, layer_select=None):
|
def __init__(self, args, idx, layer_select=None):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
self.idx = idx
|
self.idx = idx
|
||||||
@ -63,7 +65,10 @@ class LatentTransformerDecoder(TransformerDecoder):
|
|||||||
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
|
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
|
||||||
TransformerDecoder.
|
TransformerDecoder.
|
||||||
"""
|
"""
|
||||||
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1):
|
|
||||||
|
def __init__(
|
||||||
|
self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1
|
||||||
|
):
|
||||||
self.num_logits = num_logits
|
self.num_logits = num_logits
|
||||||
self.num_layers = args.decoder_layers
|
self.num_layers = args.decoder_layers
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -71,16 +76,20 @@ class LatentTransformerDecoder(TransformerDecoder):
|
|||||||
)
|
)
|
||||||
self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
|
self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
|
||||||
self.lang_idx = None
|
self.lang_idx = None
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList(
|
||||||
self._build_decoder_layer(args, no_encoder_attn, idx)
|
[
|
||||||
for idx in range(args.decoder_layers)
|
self._build_decoder_layer(args, no_encoder_attn, idx)
|
||||||
])
|
for idx in range(args.decoder_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def set_lang_idx(self, lang_idx):
|
def set_lang_idx(self, lang_idx):
|
||||||
self.lang_idx = lang_idx
|
self.lang_idx = lang_idx
|
||||||
|
|
||||||
def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
|
def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
|
||||||
return LatentTransformerDecoderLayer(args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn)
|
return LatentTransformerDecoderLayer(
|
||||||
|
args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -119,8 +128,15 @@ class LatentTransformerDecoderLayer(TransformerDecoderLayer):
|
|||||||
(default: False).
|
(default: False).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
|
self,
|
||||||
|
args,
|
||||||
|
idx,
|
||||||
|
layer_select=None,
|
||||||
|
no_encoder_attn=False,
|
||||||
|
add_bias_kv=False,
|
||||||
|
add_zero_attn=False,
|
||||||
):
|
):
|
||||||
super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
|
super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
|
||||||
self.idx = idx
|
self.idx = idx
|
||||||
|
@ -12,6 +12,7 @@ class LayerSelect(nn.Module):
|
|||||||
either (soft) weighting or (hard) selection of residual connection.
|
either (soft) weighting or (hard) selection of residual connection.
|
||||||
https://arxiv.org/abs/2009.13102
|
https://arxiv.org/abs/2009.13102
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_layers, num_logits, args):
|
def __init__(self, num_layers, num_logits, args):
|
||||||
super(LayerSelect, self).__init__()
|
super(LayerSelect, self).__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
@ -27,14 +28,14 @@ class LayerSelect(nn.Module):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--soft-select',
|
"--soft-select",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
help='use soft samples in training an inference'
|
help="use soft samples in training an inference",
|
||||||
)
|
)
|
||||||
parser.add_argument('--sampling-tau', type=float, help='sampling temperature')
|
parser.add_argument("--sampling-tau", type=float, help="sampling temperature")
|
||||||
|
|
||||||
def sample(self, logit_idx):
|
def sample(self, logit_idx):
|
||||||
""" To leverage the efficiency of distributed training, samples for all
|
"""To leverage the efficiency of distributed training, samples for all
|
||||||
layers are computed at once for each logit_idx. Logits are parameters
|
layers are computed at once for each logit_idx. Logits are parameters
|
||||||
learnt independent of each other.
|
learnt independent of each other.
|
||||||
|
|
||||||
@ -43,7 +44,9 @@ class LayerSelect(nn.Module):
|
|||||||
"""
|
"""
|
||||||
assert logit_idx is not None
|
assert logit_idx is not None
|
||||||
self.samples = self._gumbel_sigmoid(
|
self.samples = self._gumbel_sigmoid(
|
||||||
self.layer_logits[logit_idx, :].detach() if self.detach_grad else self.layer_logits[logit_idx, :],
|
self.layer_logits[logit_idx, :].detach()
|
||||||
|
if self.detach_grad
|
||||||
|
else self.layer_logits[logit_idx, :],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
tau=self.tau,
|
tau=self.tau,
|
||||||
hard=self.hard_select,
|
hard=self.hard_select,
|
||||||
@ -54,10 +57,20 @@ class LayerSelect(nn.Module):
|
|||||||
sample = self.samples[i]
|
sample = self.samples[i]
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
def _gumbel_sigmoid(self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5):
|
def _gumbel_sigmoid(
|
||||||
|
self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
|
||||||
|
):
|
||||||
# ~Gumbel(0,1)
|
# ~Gumbel(0,1)
|
||||||
gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
|
gumbels1 = (
|
||||||
gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
|
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
|
||||||
|
.exponential_()
|
||||||
|
.log()
|
||||||
|
)
|
||||||
|
gumbels2 = (
|
||||||
|
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
|
||||||
|
.exponential_()
|
||||||
|
.log()
|
||||||
|
)
|
||||||
# Difference of two gumbels because we apply a sigmoid
|
# Difference of two gumbels because we apply a sigmoid
|
||||||
gumbels1 = (logits + gumbels1 - gumbels2) / tau
|
gumbels1 = (logits + gumbels1 - gumbels2) / tau
|
||||||
y_soft = gumbels1.sigmoid()
|
y_soft = gumbels1.sigmoid()
|
||||||
|
@ -5,10 +5,11 @@
|
|||||||
|
|
||||||
from fairseq.tasks import register_task
|
from fairseq.tasks import register_task
|
||||||
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
|
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
|
||||||
|
|
||||||
from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss
|
from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss
|
||||||
|
|
||||||
|
|
||||||
@register_task('multilingual_translation_latent_depth')
|
@register_task("multilingual_translation_latent_depth")
|
||||||
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
|
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
|
||||||
"""A task for multiple translation with latent depth.
|
"""A task for multiple translation with latent depth.
|
||||||
|
|
||||||
@ -39,7 +40,9 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
|
|||||||
|
|
||||||
def __init__(self, args, dicts, training):
|
def __init__(self, args, dicts, training):
|
||||||
super().__init__(args, dicts, training)
|
super().__init__(args, dicts, training)
|
||||||
self.src_langs, self.tgt_langs = zip(*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs])
|
self.src_langs, self.tgt_langs = zip(
|
||||||
|
*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]
|
||||||
|
)
|
||||||
if self.training and self.encoder_latent_layer:
|
if self.training and self.encoder_latent_layer:
|
||||||
assert self.args.share_encoders
|
assert self.args.share_encoders
|
||||||
if self.training and self.decoder_latent_layer:
|
if self.training and self.decoder_latent_layer:
|
||||||
@ -47,46 +50,56 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
|
|||||||
if training or self.encoder_latent_layer or self.decoder_latent_layer:
|
if training or self.encoder_latent_layer or self.decoder_latent_layer:
|
||||||
self.lang_pairs = args.lang_pairs
|
self.lang_pairs = args.lang_pairs
|
||||||
else:
|
else:
|
||||||
self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
|
self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
|
||||||
self.eval_lang_pairs = self.lang_pairs
|
self.eval_lang_pairs = self.lang_pairs
|
||||||
self.model_lang_pairs = self.lang_pairs
|
self.model_lang_pairs = self.lang_pairs
|
||||||
if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
|
if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
|
||||||
self.kl_loss = LatentLayersKLLoss(self.args)
|
self.kl_loss = LatentLayersKLLoss(self.args)
|
||||||
self.sparsity_loss = LatentLayersSparsityLoss(self.args)
|
self.sparsity_loss = LatentLayersSparsityLoss(self.args)
|
||||||
|
|
||||||
def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad):
|
def _per_lang_pair_train_loss(
|
||||||
|
self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
|
||||||
|
):
|
||||||
src, tgt = lang_pair.split("-")
|
src, tgt = lang_pair.split("-")
|
||||||
if self.encoder_latent_layer:
|
if self.encoder_latent_layer:
|
||||||
src_lang_idx = self.src_lang_idx_dict[src]
|
src_lang_idx = self.src_lang_idx_dict[src]
|
||||||
model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
|
model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
|
||||||
model.models[lang_pair].encoder.layer_select.hard_select = update_num > self.args.soft_update
|
model.models[lang_pair].encoder.layer_select.hard_select = (
|
||||||
|
update_num > self.args.soft_update
|
||||||
|
)
|
||||||
if self.decoder_latent_layer:
|
if self.decoder_latent_layer:
|
||||||
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
|
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
|
||||||
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
|
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
|
||||||
model.models[lang_pair].decoder.layer_select.hard_select = update_num > self.args.soft_update
|
model.models[lang_pair].decoder.layer_select.hard_select = (
|
||||||
|
update_num > self.args.soft_update
|
||||||
|
)
|
||||||
|
|
||||||
loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
|
loss, sample_size, logging_output = criterion(
|
||||||
|
model.models[lang_pair], sample[lang_pair]
|
||||||
|
)
|
||||||
if self.encoder_latent_layer:
|
if self.encoder_latent_layer:
|
||||||
none_samples = sum(
|
none_samples = sum(
|
||||||
1 if x is None else 0 for x in model.models[lang_pair].encoder.layer_select.layer_samples
|
1 if x is None else 0
|
||||||
|
for x in model.models[lang_pair].encoder.layer_select.layer_samples
|
||||||
)
|
)
|
||||||
if none_samples == 0 or self.args.prior != "agged_posterior":
|
if none_samples == 0 or self.args.prior != "agged_posterior":
|
||||||
loss += self.kl_loss(
|
loss += self.kl_loss(
|
||||||
model.models[lang_pair].encoder.layer_select.layer_samples,
|
model.models[lang_pair].encoder.layer_select.layer_samples,
|
||||||
src_lang_idx,
|
src_lang_idx,
|
||||||
update_num,
|
update_num,
|
||||||
sample_size
|
sample_size,
|
||||||
)
|
)
|
||||||
if self.decoder_latent_layer:
|
if self.decoder_latent_layer:
|
||||||
none_samples = sum(
|
none_samples = sum(
|
||||||
1 if x is None else 0 for x in model.models[lang_pair].decoder.layer_select.layer_samples
|
1 if x is None else 0
|
||||||
|
for x in model.models[lang_pair].decoder.layer_select.layer_samples
|
||||||
)
|
)
|
||||||
if none_samples == 0 or self.args.prior != "agged_posterior":
|
if none_samples == 0 or self.args.prior != "agged_posterior":
|
||||||
loss += self.kl_loss(
|
loss += self.kl_loss(
|
||||||
model.models[lang_pair].decoder.layer_select.layer_samples,
|
model.models[lang_pair].decoder.layer_select.layer_samples,
|
||||||
tgt_lang_idx,
|
tgt_lang_idx,
|
||||||
update_num,
|
update_num,
|
||||||
sample_size
|
sample_size,
|
||||||
)
|
)
|
||||||
if ignore_grad:
|
if ignore_grad:
|
||||||
loss *= 0
|
loss *= 0
|
||||||
@ -99,18 +112,31 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
|
|||||||
|
|
||||||
return loss, sample_size, logging_output
|
return loss, sample_size, logging_output
|
||||||
|
|
||||||
def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
|
def train_step(
|
||||||
|
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
|
||||||
|
):
|
||||||
agg_loss, agg_sample_size, agg_logging_output = super().train_step(
|
agg_loss, agg_sample_size, agg_logging_output = super().train_step(
|
||||||
sample, model, criterion, optimizer, update_num, ignore_grad)
|
sample, model, criterion, optimizer, update_num, ignore_grad
|
||||||
|
)
|
||||||
# compute auxiliary loss from layere sparsity, based on all samples from all languages
|
# compute auxiliary loss from layere sparsity, based on all samples from all languages
|
||||||
if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
|
if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
|
||||||
sparsity_loss = 0
|
sparsity_loss = 0
|
||||||
if self.encoder_latent_layer:
|
if self.encoder_latent_layer:
|
||||||
sparsity_loss += self.sparsity_loss(
|
sparsity_loss += self.sparsity_loss(
|
||||||
next(iter(model.models.values())).encoder.layer_select.layer_samples, update_num, agg_sample_size)
|
next(
|
||||||
|
iter(model.models.values())
|
||||||
|
).encoder.layer_select.layer_samples,
|
||||||
|
update_num,
|
||||||
|
agg_sample_size,
|
||||||
|
)
|
||||||
if self.decoder_latent_layer:
|
if self.decoder_latent_layer:
|
||||||
sparsity_loss += self.sparsity_loss(
|
sparsity_loss += self.sparsity_loss(
|
||||||
next(iter(model.models.values())).decoder.layer_select.layer_samples, update_num, agg_sample_size)
|
next(
|
||||||
|
iter(model.models.values())
|
||||||
|
).decoder.layer_select.layer_samples,
|
||||||
|
update_num,
|
||||||
|
agg_sample_size,
|
||||||
|
)
|
||||||
if sparsity_loss > 0:
|
if sparsity_loss > 0:
|
||||||
optimizer.backward(sparsity_loss)
|
optimizer.backward(sparsity_loss)
|
||||||
return agg_loss, agg_sample_size, agg_logging_output
|
return agg_loss, agg_sample_size, agg_logging_output
|
||||||
@ -123,10 +149,14 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
|
|||||||
if self.decoder_latent_layer:
|
if self.decoder_latent_layer:
|
||||||
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
|
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
|
||||||
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
|
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
|
||||||
loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
|
loss, sample_size, logging_output = criterion(
|
||||||
|
model.models[lang_pair], sample[lang_pair]
|
||||||
|
)
|
||||||
return loss, sample_size, logging_output
|
return loss, sample_size, logging_output
|
||||||
|
|
||||||
def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
|
def inference_step(
|
||||||
|
self, generator, models, sample, prefix_tokens=None, constraints=None
|
||||||
|
):
|
||||||
if self.encoder_latent_layer or self.decoder_latent_layer:
|
if self.encoder_latent_layer or self.decoder_latent_layer:
|
||||||
for model in models:
|
for model in models:
|
||||||
if self.encoder_latent_layer:
|
if self.encoder_latent_layer:
|
||||||
@ -137,15 +167,23 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
|
|||||||
assert model.decoder.layer_select is not None
|
assert model.decoder.layer_select is not None
|
||||||
tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
|
tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
|
||||||
model.decoder.set_lang_idx(tgt_lang_idx)
|
model.decoder.set_lang_idx(tgt_lang_idx)
|
||||||
return super().inference_step(generator, models, sample, prefix_tokens, constraints)
|
return super().inference_step(
|
||||||
|
generator, models, sample, prefix_tokens, constraints
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoder_latent_layer(self):
|
def encoder_latent_layer(self):
|
||||||
return hasattr(self.args, "encoder_latent_layer") and self.args.encoder_latent_layer
|
return (
|
||||||
|
hasattr(self.args, "encoder_latent_layer")
|
||||||
|
and self.args.encoder_latent_layer
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def decoder_latent_layer(self):
|
def decoder_latent_layer(self):
|
||||||
return hasattr(self.args, "decoder_latent_layer") and self.args.decoder_latent_layer
|
return (
|
||||||
|
hasattr(self.args, "decoder_latent_layer")
|
||||||
|
and self.args.decoder_latent_layer
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def src_lang_idx_dict(self):
|
def src_lang_idx_dict(self):
|
||||||
|
@ -8,37 +8,40 @@ Linformer: Self-Attention with Linear Complexity
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fairseq.models import (
|
from fairseq.models import register_model, register_model_architecture
|
||||||
register_model,
|
from fairseq.models.roberta import RobertaEncoder, RobertaModel
|
||||||
register_model_architecture,
|
|
||||||
)
|
|
||||||
from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder
|
|
||||||
|
|
||||||
from fairseq.models.roberta import (
|
from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder
|
||||||
RobertaModel,
|
|
||||||
RobertaEncoder,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@register_model('linformer_roberta')
|
@register_model("linformer_roberta")
|
||||||
class LinformerModel(RobertaModel):
|
class LinformerModel(RobertaModel):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
RobertaModel.add_args(parser)
|
RobertaModel.add_args(parser)
|
||||||
|
|
||||||
# add args for Linformer
|
# add args for Linformer
|
||||||
parser.add_argument('--compressed', type=int,
|
parser.add_argument(
|
||||||
help='compressed ratio of sequence length')
|
"--compressed", type=int, help="compressed ratio of sequence length"
|
||||||
parser.add_argument('--shared-kv-compressed', type=int,
|
)
|
||||||
help='share compressed matrix between k and v, in each layer')
|
parser.add_argument(
|
||||||
parser.add_argument('--shared-layer-kv-compressed', type=int,
|
"--shared-kv-compressed",
|
||||||
help='share compressed matrix between k and v and across all layers')
|
type=int,
|
||||||
parser.add_argument('--freeze-compress', type=int,
|
help="share compressed matrix between k and v, in each layer",
|
||||||
help='freeze the parameters in compressed layer')
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--shared-layer-kv-compressed",
|
||||||
|
type=int,
|
||||||
|
help="share compressed matrix between k and v and across all layers",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--freeze-compress",
|
||||||
|
type=int,
|
||||||
|
help="freeze the parameters in compressed layer",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_model(cls, args, task):
|
def build_model(cls, args, task):
|
||||||
@ -47,7 +50,7 @@ class LinformerModel(RobertaModel):
|
|||||||
# make sure all arguments are present
|
# make sure all arguments are present
|
||||||
base_architecture(args)
|
base_architecture(args)
|
||||||
|
|
||||||
if not hasattr(args, 'max_positions'):
|
if not hasattr(args, "max_positions"):
|
||||||
args.max_positions = args.tokens_per_sample
|
args.max_positions = args.tokens_per_sample
|
||||||
|
|
||||||
encoder = LinformerEncoder(args, task.source_dictionary)
|
encoder = LinformerEncoder(args, task.source_dictionary)
|
||||||
@ -85,47 +88,47 @@ class LinformerEncoder(RobertaEncoder):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture('linformer_roberta', 'linformer_roberta')
|
@register_model_architecture("linformer_roberta", "linformer_roberta")
|
||||||
def base_architecture(args):
|
def base_architecture(args):
|
||||||
args.encoder_layers = getattr(args, 'encoder_layers', 12)
|
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
||||||
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
|
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
||||||
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
|
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
|
||||||
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
|
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
||||||
|
|
||||||
args.activation_fn = getattr(args, 'activation_fn', 'gelu')
|
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||||
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
|
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
|
||||||
|
|
||||||
args.dropout = getattr(args, 'dropout', 0.1)
|
args.dropout = getattr(args, "dropout", 0.1)
|
||||||
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
|
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||||
args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
|
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
||||||
args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
|
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
|
||||||
args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None)
|
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
|
||||||
args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0)
|
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
|
||||||
args.compressed = getattr(args, 'compressed', 4)
|
args.compressed = getattr(args, "compressed", 4)
|
||||||
args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0)
|
args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
|
||||||
args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0)
|
args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
|
||||||
args.freeze_compress = getattr(args, 'freeze_compress', 0)
|
args.freeze_compress = getattr(args, "freeze_compress", 0)
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture('linformer_roberta', 'linformer_roberta_base')
|
@register_model_architecture("linformer_roberta", "linformer_roberta_base")
|
||||||
def linformer_roberta_base_architecture(args):
|
def linformer_roberta_base_architecture(args):
|
||||||
base_architecture(args)
|
base_architecture(args)
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture('linformer_roberta', 'linformer_roberta_large')
|
@register_model_architecture("linformer_roberta", "linformer_roberta_large")
|
||||||
def linformer_roberta_large_architecture(args):
|
def linformer_roberta_large_architecture(args):
|
||||||
args.encoder_layers = getattr(args, 'encoder_layers', 24)
|
args.encoder_layers = getattr(args, "encoder_layers", 24)
|
||||||
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
|
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
||||||
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
|
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
||||||
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
|
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
||||||
|
|
||||||
args.activation_fn = getattr(args, 'activation_fn', 'gelu')
|
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||||
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
|
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
|
||||||
|
|
||||||
args.dropout = getattr(args, 'dropout', 0.1)
|
args.dropout = getattr(args, "dropout", 0.1)
|
||||||
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
|
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||||
args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
|
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
||||||
args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
|
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
|
||||||
args.compressed = getattr(args, 'compressed', 4)
|
args.compressed = getattr(args, "compressed", 4)
|
||||||
args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0)
|
args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
|
||||||
args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0)
|
args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
|
||||||
|
@ -6,8 +6,8 @@
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from fairseq.modules import TransformerSentenceEncoder
|
from fairseq.modules import TransformerSentenceEncoder
|
||||||
|
|
||||||
from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer
|
from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer
|
||||||
|
|
||||||
|
|
||||||
@ -117,7 +117,9 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
|
|||||||
qn_block_size,
|
qn_block_size,
|
||||||
):
|
):
|
||||||
if self.shared_layer_kv_compressed == 1:
|
if self.shared_layer_kv_compressed == 1:
|
||||||
compress_layer = nn.Linear(self.max_seq_len, self.max_seq_len // self.compressed)
|
compress_layer = nn.Linear(
|
||||||
|
self.max_seq_len, self.max_seq_len // self.compressed
|
||||||
|
)
|
||||||
# intialize parameters for compressed layer
|
# intialize parameters for compressed layer
|
||||||
nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2))
|
nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2))
|
||||||
if self.freeze_compress == 1:
|
if self.freeze_compress == 1:
|
||||||
@ -139,8 +141,7 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
|
|||||||
max_seq_len=self.max_seq_len,
|
max_seq_len=self.max_seq_len,
|
||||||
shared_kv_compressed=self.shared_kv_compressed,
|
shared_kv_compressed=self.shared_kv_compressed,
|
||||||
shared_compress_layer=(
|
shared_compress_layer=(
|
||||||
None if self.shared_layer_kv_compressed == 0
|
None if self.shared_layer_kv_compressed == 0 else self.compress_layer
|
||||||
else self.compress_layer
|
|
||||||
),
|
),
|
||||||
freeze_compress=self.freeze_compress,
|
freeze_compress=self.freeze_compress,
|
||||||
)
|
)
|
||||||
@ -156,7 +157,8 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
|
|||||||
if self.shared_layer_kv_compressed:
|
if self.shared_layer_kv_compressed:
|
||||||
for layer_idx in range(len(self.layers)):
|
for layer_idx in range(len(self.layers)):
|
||||||
new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format(
|
new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format(
|
||||||
layer_idx, k[len(prefix + 'compress_layer.'):],
|
layer_idx,
|
||||||
|
k[len(prefix + "compress_layer.") :],
|
||||||
)
|
)
|
||||||
items_to_add[new_k] = state_dict[k]
|
items_to_add[new_k] = state_dict[k]
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from fairseq.modules import TransformerSentenceEncoderLayer
|
from fairseq.modules import TransformerSentenceEncoderLayer
|
||||||
|
|
||||||
from .multihead_linear_attention import MultiheadLinearAttention
|
from .multihead_linear_attention import MultiheadLinearAttention
|
||||||
|
|
||||||
|
|
||||||
@ -23,7 +24,7 @@ class LinformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
attention_dropout: float = 0.1,
|
attention_dropout: float = 0.1,
|
||||||
activation_dropout: float = 0.1,
|
activation_dropout: float = 0.1,
|
||||||
activation_fn: str = 'relu',
|
activation_fn: str = "relu",
|
||||||
export: bool = False,
|
export: bool = False,
|
||||||
q_noise: float = 0.0,
|
q_noise: float = 0.0,
|
||||||
qn_block_size: int = 8,
|
qn_block_size: int = 8,
|
||||||
|
@ -9,10 +9,10 @@ from typing import Dict, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from fairseq import utils
|
from fairseq import utils
|
||||||
from torch import Tensor, nn
|
|
||||||
from torch.nn import Parameter
|
|
||||||
from fairseq.incremental_decoding_utils import with_incremental_state
|
from fairseq.incremental_decoding_utils import with_incremental_state
|
||||||
from fairseq.modules.quant_noise import quant_noise
|
from fairseq.modules.quant_noise import quant_noise
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
|
||||||
@with_incremental_state
|
@with_incremental_state
|
||||||
@ -65,16 +65,24 @@ class MultiheadLinearAttention(nn.Module):
|
|||||||
"Self-attention requires query, key and " "value to be of the same size"
|
"Self-attention requires query, key and " "value to be of the same size"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
self.k_proj = quant_noise(
|
||||||
self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
)
|
||||||
|
self.v_proj = quant_noise(
|
||||||
|
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
self.q_proj = quant_noise(
|
||||||
|
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
|
||||||
# used for compress sequence to subsequence
|
# used for compress sequence to subsequence
|
||||||
if shared_compress_layer is None:
|
if shared_compress_layer is None:
|
||||||
self.compress_seq_len = max_seq_len // compressed
|
self.compress_seq_len = max_seq_len // compressed
|
||||||
self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
|
self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
|
||||||
if shared_kv_compressed == 0:
|
if shared_kv_compressed == 0:
|
||||||
self.compress_v = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
|
self.compress_v = nn.Linear(
|
||||||
|
max_seq_len, self.compress_seq_len, bias=False
|
||||||
|
)
|
||||||
self.layerwise_sharing = False
|
self.layerwise_sharing = False
|
||||||
else:
|
else:
|
||||||
self.compress_k = shared_compress_layer
|
self.compress_k = shared_compress_layer
|
||||||
@ -83,7 +91,9 @@ class MultiheadLinearAttention(nn.Module):
|
|||||||
self.layerwise_sharing = True
|
self.layerwise_sharing = True
|
||||||
self.shared_kv_compressed = shared_kv_compressed
|
self.shared_kv_compressed = shared_kv_compressed
|
||||||
|
|
||||||
self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
self.out_proj = quant_noise(
|
||||||
|
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
|
||||||
if add_bias_kv:
|
if add_bias_kv:
|
||||||
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||||
@ -116,22 +126,28 @@ class MultiheadLinearAttention(nn.Module):
|
|||||||
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||||
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||||
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||||
if not self.layerwise_sharing: # otherwise, we already initialize the parameters
|
if (
|
||||||
nn.init.xavier_uniform_(self.compress_k.weight, gain=1/math.sqrt(2))
|
not self.layerwise_sharing
|
||||||
|
): # otherwise, we already initialize the parameters
|
||||||
|
nn.init.xavier_uniform_(self.compress_k.weight, gain=1 / math.sqrt(2))
|
||||||
if self.shared_kv_compressed == 0:
|
if self.shared_kv_compressed == 0:
|
||||||
nn.init.xavier_uniform_(self.compress_v.weight, gain=1/math.sqrt(2))
|
nn.init.xavier_uniform_(
|
||||||
|
self.compress_v.weight, gain=1 / math.sqrt(2)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
nn.init.xavier_uniform_(self.k_proj.weight)
|
nn.init.xavier_uniform_(self.k_proj.weight)
|
||||||
nn.init.xavier_uniform_(self.v_proj.weight)
|
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||||
nn.init.xavier_uniform_(self.q_proj.weight)
|
nn.init.xavier_uniform_(self.q_proj.weight)
|
||||||
if not self.layerwise_sharing: # otherwise, we already initialize the parameters
|
if (
|
||||||
|
not self.layerwise_sharing
|
||||||
|
): # otherwise, we already initialize the parameters
|
||||||
nn.init.xavier_uniform_(self.compress_k.weight)
|
nn.init.xavier_uniform_(self.compress_k.weight)
|
||||||
if self.shared_kv_compressed == 0:
|
if self.shared_kv_compressed == 0:
|
||||||
nn.init.xavier_uniform_(self.compress_v.weight)
|
nn.init.xavier_uniform_(self.compress_v.weight)
|
||||||
|
|
||||||
nn.init.xavier_uniform_(self.out_proj.weight)
|
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||||
if self.out_proj.bias is not None:
|
if self.out_proj.bias is not None:
|
||||||
nn.init.constant_(self.out_proj.bias, 0.)
|
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||||
if self.bias_k is not None:
|
if self.bias_k is not None:
|
||||||
nn.init.xavier_normal_(self.bias_k)
|
nn.init.xavier_normal_(self.bias_k)
|
||||||
if self.bias_v is not None:
|
if self.bias_v is not None:
|
||||||
@ -189,14 +205,26 @@ class MultiheadLinearAttention(nn.Module):
|
|||||||
q = self.q_proj(query)
|
q = self.q_proj(query)
|
||||||
|
|
||||||
k_input = query.permute(1, 2, 0).contiguous() # B * C * T
|
k_input = query.permute(1, 2, 0).contiguous() # B * C * T
|
||||||
k_input = F.linear(k_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
|
k_input = (
|
||||||
|
F.linear(k_input, self.compress_k.weight[:, 0:tgt_len])
|
||||||
|
.permute(2, 0, 1)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
k = self.k_proj(k_input)
|
k = self.k_proj(k_input)
|
||||||
|
|
||||||
v_input = query.permute(1, 2, 0).contiguous() # B * C * T
|
v_input = query.permute(1, 2, 0).contiguous() # B * C * T
|
||||||
if self.shared_kv_compressed == 0:
|
if self.shared_kv_compressed == 0:
|
||||||
v_input = F.linear(v_input, self.compress_v.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
|
v_input = (
|
||||||
|
F.linear(v_input, self.compress_v.weight[:, 0:tgt_len])
|
||||||
|
.permute(2, 0, 1)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
if self.shared_kv_compressed == 1: # use shared kv compressed linear layer
|
if self.shared_kv_compressed == 1: # use shared kv compressed linear layer
|
||||||
v_input = F.linear(v_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
|
v_input = (
|
||||||
|
F.linear(v_input, self.compress_k.weight[:, 0:tgt_len])
|
||||||
|
.permute(2, 0, 1)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
v = self.v_proj(v_input)
|
v = self.v_proj(v_input)
|
||||||
elif self.encoder_decoder_attention:
|
elif self.encoder_decoder_attention:
|
||||||
# encoder-decoder attention
|
# encoder-decoder attention
|
||||||
@ -302,7 +330,9 @@ class MultiheadLinearAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||||
attn_weights = MultiheadLinearAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
attn_weights = MultiheadLinearAttention.apply_sparse_mask(
|
||||||
|
attn_weights, tgt_len, src_len, bsz
|
||||||
|
)
|
||||||
|
|
||||||
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
||||||
|
|
||||||
@ -385,7 +415,9 @@ class MultiheadLinearAttention(nn.Module):
|
|||||||
|
|
||||||
@torch.jit.export
|
@torch.jit.export
|
||||||
def reorder_incremental_state(
|
def reorder_incremental_state(
|
||||||
self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
|
self,
|
||||||
|
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
||||||
|
new_order: Tensor,
|
||||||
):
|
):
|
||||||
"""Reorder buffered internal state (for incremental generation)."""
|
"""Reorder buffered internal state (for incremental generation)."""
|
||||||
input_buffer = self._get_input_buffer(incremental_state)
|
input_buffer = self._get_input_buffer(incremental_state)
|
||||||
@ -393,7 +425,9 @@ class MultiheadLinearAttention(nn.Module):
|
|||||||
for k in input_buffer.keys():
|
for k in input_buffer.keys():
|
||||||
input_buffer_k = input_buffer[k]
|
input_buffer_k = input_buffer[k]
|
||||||
if input_buffer_k is not None:
|
if input_buffer_k is not None:
|
||||||
if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(0):
|
if self.encoder_decoder_attention and input_buffer_k.size(
|
||||||
|
0
|
||||||
|
) == new_order.size(0):
|
||||||
break
|
break
|
||||||
input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
||||||
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
|
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
|
||||||
@ -428,8 +462,8 @@ class MultiheadLinearAttention(nn.Module):
|
|||||||
# in_proj_weight used to be q + k + v with same dimensions
|
# in_proj_weight used to be q + k + v with same dimensions
|
||||||
dim = int(state_dict[k].shape[0] / 3)
|
dim = int(state_dict[k].shape[0] / 3)
|
||||||
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
||||||
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim:2 * dim]
|
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
|
||||||
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim:]
|
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
|
||||||
|
|
||||||
keys_to_remove.append(k)
|
keys_to_remove.append(k)
|
||||||
|
|
||||||
@ -438,9 +472,9 @@ class MultiheadLinearAttention(nn.Module):
|
|||||||
dim = int(state_dict[k].shape[0] / 3)
|
dim = int(state_dict[k].shape[0] / 3)
|
||||||
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
||||||
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
|
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
|
||||||
dim:2 * dim
|
dim : 2 * dim
|
||||||
]
|
]
|
||||||
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim:]
|
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
|
||||||
|
|
||||||
keys_to_remove.append(prefix + "in_proj_bias")
|
keys_to_remove.append(prefix + "in_proj_bias")
|
||||||
|
|
||||||
|
@ -8,14 +8,16 @@
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from indicnlp.tokenize.indic_tokenize import trivial_tokenize
|
|
||||||
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
|
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
|
||||||
|
from indicnlp.tokenize.indic_tokenize import trivial_tokenize
|
||||||
|
|
||||||
factory=IndicNormalizerFactory()
|
|
||||||
normalizer=factory.get_normalizer(sys.argv[1],remove_nuktas=False,nasals_mode='do_nothing')
|
factory = IndicNormalizerFactory()
|
||||||
|
normalizer = factory.get_normalizer(
|
||||||
|
sys.argv[1], remove_nuktas=False, nasals_mode="do_nothing"
|
||||||
|
)
|
||||||
|
|
||||||
for line in sys.stdin:
|
for line in sys.stdin:
|
||||||
normalized_line=normalizer.normalize(line.strip())
|
normalized_line = normalizer.normalize(line.strip())
|
||||||
tokenized_line=' '.join(trivial_tokenize(normalized_line, sys.argv[1]))
|
tokenized_line = " ".join(trivial_tokenize(normalized_line, sys.argv[1]))
|
||||||
print(tokenized_line)
|
print(tokenized_line)
|
||||||
|
|
||||||
|
@ -8,5 +8,6 @@ import sys
|
|||||||
|
|
||||||
from pythainlp import word_tokenize
|
from pythainlp import word_tokenize
|
||||||
|
|
||||||
|
|
||||||
for line in sys.stdin:
|
for line in sys.stdin:
|
||||||
print(" ".join(word_tokenize(line.strip())))
|
print(" ".join(word_tokenize(line.strip())))
|
||||||
|
@ -6,7 +6,9 @@
|
|||||||
|
|
||||||
|
|
||||||
import fileinput
|
import fileinput
|
||||||
|
|
||||||
import sacrebleu
|
import sacrebleu
|
||||||
|
|
||||||
|
|
||||||
for line in fileinput.input():
|
for line in fileinput.input():
|
||||||
print(sacrebleu.tokenize_zh(line))
|
print(sacrebleu.tokenize_zh(line))
|
||||||
|
@ -6,19 +6,27 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import fileinput
|
import fileinput
|
||||||
|
|
||||||
import sacremoses
|
import sacremoses
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='')
|
parser = argparse.ArgumentParser(description="")
|
||||||
parser.add_argument('files', nargs='*', help='input files')
|
parser.add_argument("files", nargs="*", help="input files")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
detok = sacremoses.MosesDetokenizer()
|
detok = sacremoses.MosesDetokenizer()
|
||||||
|
|
||||||
for line in fileinput.input(args.files, openhook=fileinput.hook_compressed):
|
for line in fileinput.input(args.files, openhook=fileinput.hook_compressed):
|
||||||
print(detok.detokenize(line.strip().split(' ')).replace(' @', '').replace('@ ', '').replace(' =', '=').replace('= ', '=').replace(' – ', '–'))
|
print(
|
||||||
|
detok.detokenize(line.strip().split(" "))
|
||||||
|
.replace(" @", "")
|
||||||
|
.replace("@ ", "")
|
||||||
|
.replace(" =", "=")
|
||||||
|
.replace("= ", "=")
|
||||||
|
.replace(" – ", "–")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -7,21 +7,22 @@ import math
|
|||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from fairseq import options
|
from fairseq import options
|
||||||
from fairseq.data import dictionary
|
from fairseq.data import dictionary
|
||||||
from fairseq.scoring import bleu
|
from fairseq.scoring import bleu
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
rerank_generate,
|
rerank_generate,
|
||||||
|
rerank_options,
|
||||||
rerank_score_bw,
|
rerank_score_bw,
|
||||||
rerank_score_lm,
|
rerank_score_lm,
|
||||||
rerank_options,
|
|
||||||
rerank_utils,
|
rerank_utils,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize):
|
def score_target_hypo(
|
||||||
|
args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
|
||||||
|
):
|
||||||
|
|
||||||
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
|
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
|
||||||
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
|
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
|
||||||
@ -61,11 +62,21 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
|
|||||||
bitext2_score = None
|
bitext2_score = None
|
||||||
bitext2_backwards = None
|
bitext2_backwards = None
|
||||||
|
|
||||||
score = rerank_utils.get_score(a, b, c, target_len,
|
score = rerank_utils.get_score(
|
||||||
bitext1.rescore_score[i], bitext2_score, lm_score=lm_score,
|
a,
|
||||||
lenpen=lenpen, src_len=bitext1.source_lengths[i],
|
b,
|
||||||
tgt_len=bitext1.target_lengths[i], bitext1_backwards=bitext1.backwards,
|
c,
|
||||||
bitext2_backwards=bitext2_backwards, normalize=normalize)
|
target_len,
|
||||||
|
bitext1.rescore_score[i],
|
||||||
|
bitext2_score,
|
||||||
|
lm_score=lm_score,
|
||||||
|
lenpen=lenpen,
|
||||||
|
src_len=bitext1.source_lengths[i],
|
||||||
|
tgt_len=bitext1.target_lengths[i],
|
||||||
|
bitext1_backwards=bitext1.backwards,
|
||||||
|
bitext2_backwards=bitext2_backwards,
|
||||||
|
normalize=normalize,
|
||||||
|
)
|
||||||
|
|
||||||
if score > best_score:
|
if score > best_score:
|
||||||
best_score = score
|
best_score = score
|
||||||
@ -88,8 +99,11 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
|
|||||||
for key in range(len(gen_keys)):
|
for key in range(len(gen_keys)):
|
||||||
if args.prefix_len is None:
|
if args.prefix_len is None:
|
||||||
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
|
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
|
||||||
"pred and rescore hypo mismatch: i: " + str(key) + ", "
|
"pred and rescore hypo mismatch: i: "
|
||||||
+ str(hypo_lst[key]) + str(gen_keys[key])
|
+ str(key)
|
||||||
|
+ ", "
|
||||||
|
+ str(hypo_lst[key])
|
||||||
|
+ str(gen_keys[key])
|
||||||
+ str(gen_output.no_bpe_hypo[key])
|
+ str(gen_output.no_bpe_hypo[key])
|
||||||
)
|
)
|
||||||
sys_tok = dict.encode_line(hypo_lst[key])
|
sys_tok = dict.encode_line(hypo_lst[key])
|
||||||
@ -97,7 +111,9 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
|
|||||||
scorer.add(ref_tok, sys_tok)
|
scorer.add(ref_tok, sys_tok)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
|
full_hypo = rerank_utils.get_full_from_prefix(
|
||||||
|
hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
|
||||||
|
)
|
||||||
sys_tok = dict.encode_line(full_hypo)
|
sys_tok = dict.encode_line(full_hypo)
|
||||||
ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
|
ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
|
||||||
scorer.add(ref_tok, sys_tok)
|
scorer.add(ref_tok, sys_tok)
|
||||||
@ -107,20 +123,31 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
|
|||||||
# recover the orinal ids from n best list generation
|
# recover the orinal ids from n best list generation
|
||||||
for key in range(len(gen_output.no_bpe_target)):
|
for key in range(len(gen_output.no_bpe_target)):
|
||||||
if args.prefix_len is None:
|
if args.prefix_len is None:
|
||||||
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \
|
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
|
||||||
"pred and rescore hypo mismatch:"+"i:"+str(key)+str(hypo_lst[key]) + str(gen_output.no_bpe_hypo[key])
|
"pred and rescore hypo mismatch:"
|
||||||
|
+ "i:"
|
||||||
|
+ str(key)
|
||||||
|
+ str(hypo_lst[key])
|
||||||
|
+ str(gen_output.no_bpe_hypo[key])
|
||||||
|
)
|
||||||
ordered_hypos[gen_keys[key]] = hypo_lst[key]
|
ordered_hypos[gen_keys[key]] = hypo_lst[key]
|
||||||
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
|
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
|
||||||
|
gen_keys[key]
|
||||||
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
|
full_hypo = rerank_utils.get_full_from_prefix(
|
||||||
|
hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
|
||||||
|
)
|
||||||
ordered_hypos[gen_keys[key]] = full_hypo
|
ordered_hypos[gen_keys[key]] = full_hypo
|
||||||
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
|
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
|
||||||
|
gen_keys[key]
|
||||||
|
]
|
||||||
|
|
||||||
# write the hypos in the original order from nbest list generation
|
# write the hypos in the original order from nbest list generation
|
||||||
if args.num_shards == (len(bitext1_lst)):
|
if args.num_shards == (len(bitext1_lst)):
|
||||||
with open(target_outfile, 'w') as t:
|
with open(target_outfile, "w") as t:
|
||||||
with open(hypo_outfile, 'w') as h:
|
with open(hypo_outfile, "w") as h:
|
||||||
for key in range(len(ordered_hypos)):
|
for key in range(len(ordered_hypos)):
|
||||||
t.write(ordered_targets[key])
|
t.write(ordered_targets[key])
|
||||||
h.write(ordered_hypos[key])
|
h.write(ordered_hypos[key])
|
||||||
@ -135,17 +162,38 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
|
|||||||
def match_target_hypo(args, target_outfile, hypo_outfile):
|
def match_target_hypo(args, target_outfile, hypo_outfile):
|
||||||
"""combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
|
"""combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
|
||||||
if len(args.weight1) == 1:
|
if len(args.weight1) == 1:
|
||||||
res = score_target_hypo(args, args.weight1[0], args.weight2[0],
|
res = score_target_hypo(
|
||||||
args.weight3[0], args.lenpen[0], target_outfile,
|
args,
|
||||||
hypo_outfile, True, args.normalize)
|
args.weight1[0],
|
||||||
|
args.weight2[0],
|
||||||
|
args.weight3[0],
|
||||||
|
args.lenpen[0],
|
||||||
|
target_outfile,
|
||||||
|
hypo_outfile,
|
||||||
|
True,
|
||||||
|
args.normalize,
|
||||||
|
)
|
||||||
rerank_scores = [res]
|
rerank_scores = [res]
|
||||||
else:
|
else:
|
||||||
print("launching pool")
|
print("launching pool")
|
||||||
with Pool(32) as p:
|
with Pool(32) as p:
|
||||||
rerank_scores = p.starmap(score_target_hypo,
|
rerank_scores = p.starmap(
|
||||||
[(args, args.weight1[i], args.weight2[i], args.weight3[i],
|
score_target_hypo,
|
||||||
args.lenpen[i], target_outfile, hypo_outfile,
|
[
|
||||||
False, args.normalize) for i in range(len(args.weight1))])
|
(
|
||||||
|
args,
|
||||||
|
args.weight1[i],
|
||||||
|
args.weight2[i],
|
||||||
|
args.weight3[i],
|
||||||
|
args.lenpen[i],
|
||||||
|
target_outfile,
|
||||||
|
hypo_outfile,
|
||||||
|
False,
|
||||||
|
args.normalize,
|
||||||
|
)
|
||||||
|
for i in range(len(args.weight1))
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
if len(rerank_scores) > 1:
|
if len(rerank_scores) > 1:
|
||||||
best_index = np.argmax(rerank_scores)
|
best_index = np.argmax(rerank_scores)
|
||||||
@ -155,11 +203,22 @@ def match_target_hypo(args, target_outfile, hypo_outfile):
|
|||||||
print("best weight1", args.weight1[best_index])
|
print("best weight1", args.weight1[best_index])
|
||||||
print("best weight2", args.weight2[best_index])
|
print("best weight2", args.weight2[best_index])
|
||||||
print("best weight3", args.weight3[best_index])
|
print("best weight3", args.weight3[best_index])
|
||||||
return args.lenpen[best_index], args.weight1[best_index], \
|
return (
|
||||||
args.weight2[best_index], args.weight3[best_index], best_score
|
args.lenpen[best_index],
|
||||||
|
args.weight1[best_index],
|
||||||
|
args.weight2[best_index],
|
||||||
|
args.weight3[best_index],
|
||||||
|
best_score,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return args.lenpen[0], args.weight1[0], args.weight2[0], args.weight3[0], rerank_scores[0]
|
return (
|
||||||
|
args.lenpen[0],
|
||||||
|
args.weight1[0],
|
||||||
|
args.weight2[0],
|
||||||
|
args.weight3[0],
|
||||||
|
rerank_scores[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_score_files(args):
|
def load_score_files(args):
|
||||||
@ -175,55 +234,100 @@ def load_score_files(args):
|
|||||||
|
|
||||||
for shard_id in shard_ids:
|
for shard_id in shard_ids:
|
||||||
using_nbest = args.nbest_list is not None
|
using_nbest = args.nbest_list is not None
|
||||||
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
|
(
|
||||||
backwards_preprocessed_dir, lm_preprocessed_dir = \
|
pre_gen,
|
||||||
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
|
left_to_right_preprocessed_dir,
|
||||||
args.gen_model_name, shard_id, args.num_shards, args.sampling,
|
right_to_left_preprocessed_dir,
|
||||||
args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
|
backwards_preprocessed_dir,
|
||||||
|
lm_preprocessed_dir,
|
||||||
|
) = rerank_utils.get_directories(
|
||||||
|
args.data_dir_name,
|
||||||
|
args.num_rescore,
|
||||||
|
args.gen_subset,
|
||||||
|
args.gen_model_name,
|
||||||
|
shard_id,
|
||||||
|
args.num_shards,
|
||||||
|
args.sampling,
|
||||||
|
args.prefix_len,
|
||||||
|
args.target_prefix_frac,
|
||||||
|
args.source_prefix_frac,
|
||||||
|
)
|
||||||
|
|
||||||
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
rerank1_is_gen = (
|
||||||
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
||||||
|
)
|
||||||
|
rerank2_is_gen = (
|
||||||
|
args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
||||||
|
)
|
||||||
|
|
||||||
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
|
score1_file = rerank_utils.rescore_file_name(
|
||||||
target_prefix_frac=args.target_prefix_frac,
|
pre_gen,
|
||||||
source_prefix_frac=args.source_prefix_frac,
|
args.prefix_len,
|
||||||
backwards=args.backwards1)
|
args.model1_name,
|
||||||
|
target_prefix_frac=args.target_prefix_frac,
|
||||||
|
source_prefix_frac=args.source_prefix_frac,
|
||||||
|
backwards=args.backwards1,
|
||||||
|
)
|
||||||
if args.score_model2 is not None:
|
if args.score_model2 is not None:
|
||||||
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
|
score2_file = rerank_utils.rescore_file_name(
|
||||||
target_prefix_frac=args.target_prefix_frac,
|
pre_gen,
|
||||||
source_prefix_frac=args.source_prefix_frac,
|
args.prefix_len,
|
||||||
backwards=args.backwards2)
|
args.model2_name,
|
||||||
|
target_prefix_frac=args.target_prefix_frac,
|
||||||
|
source_prefix_frac=args.source_prefix_frac,
|
||||||
|
backwards=args.backwards2,
|
||||||
|
)
|
||||||
if args.language_model is not None:
|
if args.language_model is not None:
|
||||||
lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
|
lm_score_file = rerank_utils.rescore_file_name(
|
||||||
|
pre_gen, args.prefix_len, args.lm_name, lm_file=True
|
||||||
|
)
|
||||||
|
|
||||||
# get gen output
|
# get gen output
|
||||||
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
|
predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
|
||||||
if using_nbest:
|
if using_nbest:
|
||||||
print("Using predefined n-best list from interactive.py")
|
print("Using predefined n-best list from interactive.py")
|
||||||
predictions_bpe_file = args.nbest_list
|
predictions_bpe_file = args.nbest_list
|
||||||
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
|
gen_output = rerank_utils.BitextOutputFromGen(
|
||||||
nbest=using_nbest, prefix_len=args.prefix_len,
|
predictions_bpe_file,
|
||||||
target_prefix_frac=args.target_prefix_frac)
|
bpe_symbol=args.remove_bpe,
|
||||||
|
nbest=using_nbest,
|
||||||
|
prefix_len=args.prefix_len,
|
||||||
|
target_prefix_frac=args.target_prefix_frac,
|
||||||
|
)
|
||||||
|
|
||||||
if rerank1_is_gen:
|
if rerank1_is_gen:
|
||||||
bitext1 = gen_output
|
bitext1 = gen_output
|
||||||
else:
|
else:
|
||||||
bitext1 = rerank_utils.BitextOutput(score1_file, args.backwards1, args.right_to_left1,
|
bitext1 = rerank_utils.BitextOutput(
|
||||||
args.remove_bpe, args.prefix_len, args.target_prefix_frac,
|
score1_file,
|
||||||
args.source_prefix_frac)
|
args.backwards1,
|
||||||
|
args.right_to_left1,
|
||||||
|
args.remove_bpe,
|
||||||
|
args.prefix_len,
|
||||||
|
args.target_prefix_frac,
|
||||||
|
args.source_prefix_frac,
|
||||||
|
)
|
||||||
|
|
||||||
if args.score_model2 is not None or args.nbest_list is not None:
|
if args.score_model2 is not None or args.nbest_list is not None:
|
||||||
if rerank2_is_gen:
|
if rerank2_is_gen:
|
||||||
bitext2 = gen_output
|
bitext2 = gen_output
|
||||||
else:
|
else:
|
||||||
bitext2 = rerank_utils.BitextOutput(score2_file, args.backwards2, args.right_to_left2,
|
bitext2 = rerank_utils.BitextOutput(
|
||||||
args.remove_bpe, args.prefix_len, args.target_prefix_frac,
|
score2_file,
|
||||||
args.source_prefix_frac)
|
args.backwards2,
|
||||||
|
args.right_to_left2,
|
||||||
|
args.remove_bpe,
|
||||||
|
args.prefix_len,
|
||||||
|
args.target_prefix_frac,
|
||||||
|
args.source_prefix_frac,
|
||||||
|
)
|
||||||
|
|
||||||
assert bitext2.source_lengths == bitext1.source_lengths, \
|
assert (
|
||||||
"source lengths for rescoring models do not match"
|
bitext2.source_lengths == bitext1.source_lengths
|
||||||
assert bitext2.target_lengths == bitext1.target_lengths, \
|
), "source lengths for rescoring models do not match"
|
||||||
"target lengths for rescoring models do not match"
|
assert (
|
||||||
|
bitext2.target_lengths == bitext1.target_lengths
|
||||||
|
), "target lengths for rescoring models do not match"
|
||||||
else:
|
else:
|
||||||
if args.diff_bpe:
|
if args.diff_bpe:
|
||||||
assert args.score_model2 is None
|
assert args.score_model2 is None
|
||||||
@ -232,8 +336,13 @@ def load_score_files(args):
|
|||||||
bitext2 = None
|
bitext2 = None
|
||||||
|
|
||||||
if args.language_model is not None:
|
if args.language_model is not None:
|
||||||
lm_res1 = rerank_utils.LMOutput(lm_score_file, args.lm_dict, args.prefix_len,
|
lm_res1 = rerank_utils.LMOutput(
|
||||||
args.remove_bpe, args.target_prefix_frac)
|
lm_score_file,
|
||||||
|
args.lm_dict,
|
||||||
|
args.prefix_len,
|
||||||
|
args.remove_bpe,
|
||||||
|
args.target_prefix_frac,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
lm_res1 = None
|
lm_res1 = None
|
||||||
|
|
||||||
@ -259,28 +368,46 @@ def rerank(args):
|
|||||||
shard_ids = [args.shard_id]
|
shard_ids = [args.shard_id]
|
||||||
|
|
||||||
for shard_id in shard_ids:
|
for shard_id in shard_ids:
|
||||||
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
|
(
|
||||||
backwards_preprocessed_dir, lm_preprocessed_dir = \
|
pre_gen,
|
||||||
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
|
left_to_right_preprocessed_dir,
|
||||||
args.gen_model_name, shard_id, args.num_shards, args.sampling,
|
right_to_left_preprocessed_dir,
|
||||||
args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
|
backwards_preprocessed_dir,
|
||||||
|
lm_preprocessed_dir,
|
||||||
|
) = rerank_utils.get_directories(
|
||||||
|
args.data_dir_name,
|
||||||
|
args.num_rescore,
|
||||||
|
args.gen_subset,
|
||||||
|
args.gen_model_name,
|
||||||
|
shard_id,
|
||||||
|
args.num_shards,
|
||||||
|
args.sampling,
|
||||||
|
args.prefix_len,
|
||||||
|
args.target_prefix_frac,
|
||||||
|
args.source_prefix_frac,
|
||||||
|
)
|
||||||
rerank_generate.gen_and_reprocess_nbest(args)
|
rerank_generate.gen_and_reprocess_nbest(args)
|
||||||
rerank_score_bw.score_bw(args)
|
rerank_score_bw.score_bw(args)
|
||||||
rerank_score_lm.score_lm(args)
|
rerank_score_lm.score_lm(args)
|
||||||
|
|
||||||
if args.write_hypos is None:
|
if args.write_hypos is None:
|
||||||
write_targets = pre_gen+"/matched_targets"
|
write_targets = pre_gen + "/matched_targets"
|
||||||
write_hypos = pre_gen+"/matched_hypos"
|
write_hypos = pre_gen + "/matched_hypos"
|
||||||
else:
|
else:
|
||||||
write_targets = args.write_hypos+"_targets" + args.gen_subset
|
write_targets = args.write_hypos + "_targets" + args.gen_subset
|
||||||
write_hypos = args.write_hypos+"_hypos" + args.gen_subset
|
write_hypos = args.write_hypos + "_hypos" + args.gen_subset
|
||||||
|
|
||||||
if args.all_shards:
|
if args.all_shards:
|
||||||
write_targets += "_all_shards"
|
write_targets += "_all_shards"
|
||||||
write_hypos += "_all_shards"
|
write_hypos += "_all_shards"
|
||||||
|
|
||||||
best_lenpen, best_weight1, best_weight2, best_weight3, best_score = \
|
(
|
||||||
match_target_hypo(args, write_targets, write_hypos)
|
best_lenpen,
|
||||||
|
best_weight1,
|
||||||
|
best_weight2,
|
||||||
|
best_weight3,
|
||||||
|
best_score,
|
||||||
|
) = match_target_hypo(args, write_targets, write_hypos)
|
||||||
|
|
||||||
return best_lenpen, best_weight1, best_weight2, best_weight3, best_score
|
return best_lenpen, best_weight1, best_weight2, best_weight3, best_score
|
||||||
|
|
||||||
@ -291,5 +418,5 @@ def cli_main():
|
|||||||
rerank(args)
|
rerank(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
cli_main()
|
cli_main()
|
||||||
|
@ -8,9 +8,9 @@
|
|||||||
Generate n-best translations using a trained model.
|
Generate n-best translations using a trained model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from contextlib import redirect_stdout
|
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
from fairseq import options
|
from fairseq import options
|
||||||
from fairseq_cli import generate, preprocess
|
from fairseq_cli import generate, preprocess
|
||||||
@ -22,8 +22,12 @@ def gen_and_reprocess_nbest(args):
|
|||||||
if args.score_dict_dir is None:
|
if args.score_dict_dir is None:
|
||||||
args.score_dict_dir = args.data
|
args.score_dict_dir = args.data
|
||||||
if args.prefix_len is not None:
|
if args.prefix_len is not None:
|
||||||
assert args.right_to_left1 is False, "prefix length not compatible with right to left models"
|
assert (
|
||||||
assert args.right_to_left2 is False, "prefix length not compatible with right to left models"
|
args.right_to_left1 is False
|
||||||
|
), "prefix length not compatible with right to left models"
|
||||||
|
assert (
|
||||||
|
args.right_to_left2 is False
|
||||||
|
), "prefix length not compatible with right to left models"
|
||||||
|
|
||||||
if args.nbest_list is not None:
|
if args.nbest_list is not None:
|
||||||
assert args.score_model2 is None
|
assert args.score_model2 is None
|
||||||
@ -35,27 +39,50 @@ def gen_and_reprocess_nbest(args):
|
|||||||
scorer1_src = args.source_lang
|
scorer1_src = args.source_lang
|
||||||
scorer1_tgt = args.target_lang
|
scorer1_tgt = args.target_lang
|
||||||
|
|
||||||
store_data = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+args.data_dir_name
|
store_data = (
|
||||||
|
os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name
|
||||||
|
)
|
||||||
if not os.path.exists(store_data):
|
if not os.path.exists(store_data):
|
||||||
os.makedirs(store_data)
|
os.makedirs(store_data)
|
||||||
|
|
||||||
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
|
(
|
||||||
backwards_preprocessed_dir, lm_preprocessed_dir = \
|
pre_gen,
|
||||||
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
|
left_to_right_preprocessed_dir,
|
||||||
args.gen_model_name, args.shard_id, args.num_shards,
|
right_to_left_preprocessed_dir,
|
||||||
args.sampling, args.prefix_len, args.target_prefix_frac,
|
backwards_preprocessed_dir,
|
||||||
args.source_prefix_frac)
|
lm_preprocessed_dir,
|
||||||
assert not (args.right_to_left1 and args.backwards1), "backwards right to left not supported"
|
) = rerank_utils.get_directories(
|
||||||
assert not (args.right_to_left2 and args.backwards2), "backwards right to left not supported"
|
args.data_dir_name,
|
||||||
assert not (args.prefix_len is not None and args.target_prefix_frac is not None), \
|
args.num_rescore,
|
||||||
"target prefix frac and target prefix len incompatible"
|
args.gen_subset,
|
||||||
|
args.gen_model_name,
|
||||||
|
args.shard_id,
|
||||||
|
args.num_shards,
|
||||||
|
args.sampling,
|
||||||
|
args.prefix_len,
|
||||||
|
args.target_prefix_frac,
|
||||||
|
args.source_prefix_frac,
|
||||||
|
)
|
||||||
|
assert not (
|
||||||
|
args.right_to_left1 and args.backwards1
|
||||||
|
), "backwards right to left not supported"
|
||||||
|
assert not (
|
||||||
|
args.right_to_left2 and args.backwards2
|
||||||
|
), "backwards right to left not supported"
|
||||||
|
assert not (
|
||||||
|
args.prefix_len is not None and args.target_prefix_frac is not None
|
||||||
|
), "target prefix frac and target prefix len incompatible"
|
||||||
|
|
||||||
# make directory to store generation results
|
# make directory to store generation results
|
||||||
if not os.path.exists(pre_gen):
|
if not os.path.exists(pre_gen):
|
||||||
os.makedirs(pre_gen)
|
os.makedirs(pre_gen)
|
||||||
|
|
||||||
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
rerank1_is_gen = (
|
||||||
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
||||||
|
)
|
||||||
|
rerank2_is_gen = (
|
||||||
|
args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
||||||
|
)
|
||||||
|
|
||||||
if args.nbest_list is not None:
|
if args.nbest_list is not None:
|
||||||
rerank2_is_gen = True
|
rerank2_is_gen = True
|
||||||
@ -70,17 +97,25 @@ def gen_and_reprocess_nbest(args):
|
|||||||
if not os.path.exists(backwards_preprocessed_dir):
|
if not os.path.exists(backwards_preprocessed_dir):
|
||||||
os.makedirs(backwards_preprocessed_dir)
|
os.makedirs(backwards_preprocessed_dir)
|
||||||
|
|
||||||
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
|
score1_file = rerank_utils.rescore_file_name(
|
||||||
target_prefix_frac=args.target_prefix_frac,
|
pre_gen,
|
||||||
source_prefix_frac=args.source_prefix_frac,
|
args.prefix_len,
|
||||||
backwards=args.backwards1)
|
args.model1_name,
|
||||||
|
target_prefix_frac=args.target_prefix_frac,
|
||||||
|
source_prefix_frac=args.source_prefix_frac,
|
||||||
|
backwards=args.backwards1,
|
||||||
|
)
|
||||||
if args.score_model2 is not None:
|
if args.score_model2 is not None:
|
||||||
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
|
score2_file = rerank_utils.rescore_file_name(
|
||||||
target_prefix_frac=args.target_prefix_frac,
|
pre_gen,
|
||||||
source_prefix_frac=args.source_prefix_frac,
|
args.prefix_len,
|
||||||
backwards=args.backwards2)
|
args.model2_name,
|
||||||
|
target_prefix_frac=args.target_prefix_frac,
|
||||||
|
source_prefix_frac=args.source_prefix_frac,
|
||||||
|
backwards=args.backwards2,
|
||||||
|
)
|
||||||
|
|
||||||
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
|
predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
|
||||||
|
|
||||||
using_nbest = args.nbest_list is not None
|
using_nbest = args.nbest_list is not None
|
||||||
|
|
||||||
@ -92,17 +127,29 @@ def gen_and_reprocess_nbest(args):
|
|||||||
if not os.path.isfile(predictions_bpe_file):
|
if not os.path.isfile(predictions_bpe_file):
|
||||||
print("STEP 1: generate predictions using the p(T|S) model with bpe")
|
print("STEP 1: generate predictions using the p(T|S) model with bpe")
|
||||||
print(args.data)
|
print(args.data)
|
||||||
param1 = [args.data,
|
param1 = [
|
||||||
"--path", args.gen_model,
|
args.data,
|
||||||
"--shard-id", str(args.shard_id),
|
"--path",
|
||||||
"--num-shards", str(args.num_shards),
|
args.gen_model,
|
||||||
"--nbest", str(args.num_rescore),
|
"--shard-id",
|
||||||
"--batch-size", str(args.batch_size),
|
str(args.shard_id),
|
||||||
"--beam", str(args.num_rescore),
|
"--num-shards",
|
||||||
"--batch-size", str(args.num_rescore),
|
str(args.num_shards),
|
||||||
"--gen-subset", args.gen_subset,
|
"--nbest",
|
||||||
"--source-lang", args.source_lang,
|
str(args.num_rescore),
|
||||||
"--target-lang", args.target_lang]
|
"--batch-size",
|
||||||
|
str(args.batch_size),
|
||||||
|
"--beam",
|
||||||
|
str(args.num_rescore),
|
||||||
|
"--batch-size",
|
||||||
|
str(args.num_rescore),
|
||||||
|
"--gen-subset",
|
||||||
|
args.gen_subset,
|
||||||
|
"--source-lang",
|
||||||
|
args.source_lang,
|
||||||
|
"--target-lang",
|
||||||
|
args.target_lang,
|
||||||
|
]
|
||||||
if args.sampling:
|
if args.sampling:
|
||||||
param1 += ["--sampling"]
|
param1 += ["--sampling"]
|
||||||
|
|
||||||
@ -110,124 +157,229 @@ def gen_and_reprocess_nbest(args):
|
|||||||
input_args = options.parse_args_and_arch(gen_parser, param1)
|
input_args = options.parse_args_and_arch(gen_parser, param1)
|
||||||
|
|
||||||
print(input_args)
|
print(input_args)
|
||||||
with open(predictions_bpe_file, 'w') as f:
|
with open(predictions_bpe_file, "w") as f:
|
||||||
with redirect_stdout(f):
|
with redirect_stdout(f):
|
||||||
generate.main(input_args)
|
generate.main(input_args)
|
||||||
|
|
||||||
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
|
gen_output = rerank_utils.BitextOutputFromGen(
|
||||||
nbest=using_nbest, prefix_len=args.prefix_len,
|
predictions_bpe_file,
|
||||||
target_prefix_frac=args.target_prefix_frac)
|
bpe_symbol=args.remove_bpe,
|
||||||
|
nbest=using_nbest,
|
||||||
|
prefix_len=args.prefix_len,
|
||||||
|
target_prefix_frac=args.target_prefix_frac,
|
||||||
|
)
|
||||||
|
|
||||||
if args.diff_bpe:
|
if args.diff_bpe:
|
||||||
rerank_utils.write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
|
rerank_utils.write_reprocessed(
|
||||||
gen_output.no_bpe_target, pre_gen+"/source_gen_bpe."+args.source_lang,
|
gen_output.no_bpe_source,
|
||||||
pre_gen+"/target_gen_bpe."+args.target_lang,
|
gen_output.no_bpe_hypo,
|
||||||
pre_gen+"/reference_gen_bpe."+args.target_lang)
|
gen_output.no_bpe_target,
|
||||||
|
pre_gen + "/source_gen_bpe." + args.source_lang,
|
||||||
|
pre_gen + "/target_gen_bpe." + args.target_lang,
|
||||||
|
pre_gen + "/reference_gen_bpe." + args.target_lang,
|
||||||
|
)
|
||||||
bitext_bpe = args.rescore_bpe_code
|
bitext_bpe = args.rescore_bpe_code
|
||||||
bpe_src_param = ["-c", bitext_bpe,
|
bpe_src_param = [
|
||||||
"--input", pre_gen+"/source_gen_bpe."+args.source_lang,
|
"-c",
|
||||||
"--output", pre_gen+"/rescore_data."+args.source_lang]
|
bitext_bpe,
|
||||||
bpe_tgt_param = ["-c", bitext_bpe,
|
"--input",
|
||||||
"--input", pre_gen+"/target_gen_bpe."+args.target_lang,
|
pre_gen + "/source_gen_bpe." + args.source_lang,
|
||||||
"--output", pre_gen+"/rescore_data."+args.target_lang]
|
"--output",
|
||||||
|
pre_gen + "/rescore_data." + args.source_lang,
|
||||||
|
]
|
||||||
|
bpe_tgt_param = [
|
||||||
|
"-c",
|
||||||
|
bitext_bpe,
|
||||||
|
"--input",
|
||||||
|
pre_gen + "/target_gen_bpe." + args.target_lang,
|
||||||
|
"--output",
|
||||||
|
pre_gen + "/rescore_data." + args.target_lang,
|
||||||
|
]
|
||||||
|
|
||||||
subprocess.call(["python",
|
subprocess.call(
|
||||||
os.path.join(os.path.dirname(__file__),
|
[
|
||||||
"subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
|
"python",
|
||||||
shell=False)
|
os.path.join(
|
||||||
|
os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
+ bpe_src_param,
|
||||||
|
shell=False,
|
||||||
|
)
|
||||||
|
|
||||||
subprocess.call(["python",
|
subprocess.call(
|
||||||
os.path.join(os.path.dirname(__file__),
|
[
|
||||||
"subword-nmt/subword_nmt/apply_bpe.py")] + bpe_tgt_param,
|
"python",
|
||||||
shell=False)
|
os.path.join(
|
||||||
|
os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
+ bpe_tgt_param,
|
||||||
|
shell=False,
|
||||||
|
)
|
||||||
|
|
||||||
if (not os.path.isfile(score1_file) and not rerank1_is_gen) or \
|
if (not os.path.isfile(score1_file) and not rerank1_is_gen) or (
|
||||||
(args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen):
|
args.score_model2 is not None
|
||||||
print("STEP 2: process the output of generate.py so we have clean text files with the translations")
|
and not os.path.isfile(score2_file)
|
||||||
|
and not rerank2_is_gen
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
"STEP 2: process the output of generate.py so we have clean text files with the translations"
|
||||||
|
)
|
||||||
|
|
||||||
rescore_file = "/rescore_data"
|
rescore_file = "/rescore_data"
|
||||||
if args.prefix_len is not None:
|
if args.prefix_len is not None:
|
||||||
prefix_len_rescore_file = rescore_file + "prefix"+str(args.prefix_len)
|
prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
|
||||||
if args.target_prefix_frac is not None:
|
if args.target_prefix_frac is not None:
|
||||||
target_prefix_frac_rescore_file = rescore_file + "target_prefix_frac"+str(args.target_prefix_frac)
|
target_prefix_frac_rescore_file = (
|
||||||
|
rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
|
||||||
|
)
|
||||||
if args.source_prefix_frac is not None:
|
if args.source_prefix_frac is not None:
|
||||||
source_prefix_frac_rescore_file = rescore_file + "source_prefix_frac"+str(args.source_prefix_frac)
|
source_prefix_frac_rescore_file = (
|
||||||
|
rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
|
||||||
|
)
|
||||||
|
|
||||||
if not args.right_to_left1 or not args.right_to_left2:
|
if not args.right_to_left1 or not args.right_to_left2:
|
||||||
if not args.diff_bpe:
|
if not args.diff_bpe:
|
||||||
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
|
rerank_utils.write_reprocessed(
|
||||||
pre_gen+rescore_file+"."+args.source_lang,
|
gen_output.source,
|
||||||
pre_gen+rescore_file+"."+args.target_lang,
|
gen_output.hypo,
|
||||||
pre_gen+"/reference_file", bpe_symbol=args.remove_bpe)
|
gen_output.target,
|
||||||
|
pre_gen + rescore_file + "." + args.source_lang,
|
||||||
|
pre_gen + rescore_file + "." + args.target_lang,
|
||||||
|
pre_gen + "/reference_file",
|
||||||
|
bpe_symbol=args.remove_bpe,
|
||||||
|
)
|
||||||
if args.prefix_len is not None:
|
if args.prefix_len is not None:
|
||||||
bw_rescore_file = prefix_len_rescore_file
|
bw_rescore_file = prefix_len_rescore_file
|
||||||
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
|
rerank_utils.write_reprocessed(
|
||||||
pre_gen+prefix_len_rescore_file+"."+args.source_lang,
|
gen_output.source,
|
||||||
pre_gen+prefix_len_rescore_file+"."+args.target_lang,
|
gen_output.hypo,
|
||||||
pre_gen+"/reference_file", prefix_len=args.prefix_len,
|
gen_output.target,
|
||||||
bpe_symbol=args.remove_bpe)
|
pre_gen + prefix_len_rescore_file + "." + args.source_lang,
|
||||||
|
pre_gen + prefix_len_rescore_file + "." + args.target_lang,
|
||||||
|
pre_gen + "/reference_file",
|
||||||
|
prefix_len=args.prefix_len,
|
||||||
|
bpe_symbol=args.remove_bpe,
|
||||||
|
)
|
||||||
elif args.target_prefix_frac is not None:
|
elif args.target_prefix_frac is not None:
|
||||||
bw_rescore_file = target_prefix_frac_rescore_file
|
bw_rescore_file = target_prefix_frac_rescore_file
|
||||||
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
|
rerank_utils.write_reprocessed(
|
||||||
pre_gen+target_prefix_frac_rescore_file+"."+args.source_lang,
|
gen_output.source,
|
||||||
pre_gen+target_prefix_frac_rescore_file+"."+args.target_lang,
|
gen_output.hypo,
|
||||||
pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
|
gen_output.target,
|
||||||
target_prefix_frac=args.target_prefix_frac)
|
pre_gen
|
||||||
|
+ target_prefix_frac_rescore_file
|
||||||
|
+ "."
|
||||||
|
+ args.source_lang,
|
||||||
|
pre_gen
|
||||||
|
+ target_prefix_frac_rescore_file
|
||||||
|
+ "."
|
||||||
|
+ args.target_lang,
|
||||||
|
pre_gen + "/reference_file",
|
||||||
|
bpe_symbol=args.remove_bpe,
|
||||||
|
target_prefix_frac=args.target_prefix_frac,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
bw_rescore_file = rescore_file
|
bw_rescore_file = rescore_file
|
||||||
|
|
||||||
if args.source_prefix_frac is not None:
|
if args.source_prefix_frac is not None:
|
||||||
fw_rescore_file = source_prefix_frac_rescore_file
|
fw_rescore_file = source_prefix_frac_rescore_file
|
||||||
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
|
rerank_utils.write_reprocessed(
|
||||||
pre_gen+source_prefix_frac_rescore_file+"."+args.source_lang,
|
gen_output.source,
|
||||||
pre_gen+source_prefix_frac_rescore_file+"."+args.target_lang,
|
gen_output.hypo,
|
||||||
pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
|
gen_output.target,
|
||||||
source_prefix_frac=args.source_prefix_frac)
|
pre_gen
|
||||||
|
+ source_prefix_frac_rescore_file
|
||||||
|
+ "."
|
||||||
|
+ args.source_lang,
|
||||||
|
pre_gen
|
||||||
|
+ source_prefix_frac_rescore_file
|
||||||
|
+ "."
|
||||||
|
+ args.target_lang,
|
||||||
|
pre_gen + "/reference_file",
|
||||||
|
bpe_symbol=args.remove_bpe,
|
||||||
|
source_prefix_frac=args.source_prefix_frac,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
fw_rescore_file = rescore_file
|
fw_rescore_file = rescore_file
|
||||||
|
|
||||||
if args.right_to_left1 or args.right_to_left2:
|
if args.right_to_left1 or args.right_to_left2:
|
||||||
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
|
rerank_utils.write_reprocessed(
|
||||||
pre_gen+"/right_to_left_rescore_data."+args.source_lang,
|
gen_output.source,
|
||||||
pre_gen+"/right_to_left_rescore_data."+args.target_lang,
|
gen_output.hypo,
|
||||||
pre_gen+"/right_to_left_reference_file",
|
gen_output.target,
|
||||||
right_to_left=True, bpe_symbol=args.remove_bpe)
|
pre_gen + "/right_to_left_rescore_data." + args.source_lang,
|
||||||
|
pre_gen + "/right_to_left_rescore_data." + args.target_lang,
|
||||||
|
pre_gen + "/right_to_left_reference_file",
|
||||||
|
right_to_left=True,
|
||||||
|
bpe_symbol=args.remove_bpe,
|
||||||
|
)
|
||||||
|
|
||||||
print("STEP 3: binarize the translations")
|
print("STEP 3: binarize the translations")
|
||||||
if not args.right_to_left1 or args.score_model2 is not None and not args.right_to_left2 or not rerank1_is_gen:
|
if (
|
||||||
|
not args.right_to_left1
|
||||||
|
or args.score_model2 is not None
|
||||||
|
and not args.right_to_left2
|
||||||
|
or not rerank1_is_gen
|
||||||
|
):
|
||||||
|
|
||||||
if args.backwards1 or args.backwards2:
|
if args.backwards1 or args.backwards2:
|
||||||
if args.backwards_score_dict_dir is not None:
|
if args.backwards_score_dict_dir is not None:
|
||||||
bw_dict = args.backwards_score_dict_dir
|
bw_dict = args.backwards_score_dict_dir
|
||||||
else:
|
else:
|
||||||
bw_dict = args.score_dict_dir
|
bw_dict = args.score_dict_dir
|
||||||
bw_preprocess_param = ["--source-lang", scorer1_src,
|
bw_preprocess_param = [
|
||||||
"--target-lang", scorer1_tgt,
|
"--source-lang",
|
||||||
"--trainpref", pre_gen+bw_rescore_file,
|
scorer1_src,
|
||||||
"--srcdict", bw_dict + "/dict." + scorer1_src + ".txt",
|
"--target-lang",
|
||||||
"--tgtdict", bw_dict + "/dict." + scorer1_tgt + ".txt",
|
scorer1_tgt,
|
||||||
"--destdir", backwards_preprocessed_dir]
|
"--trainpref",
|
||||||
|
pre_gen + bw_rescore_file,
|
||||||
|
"--srcdict",
|
||||||
|
bw_dict + "/dict." + scorer1_src + ".txt",
|
||||||
|
"--tgtdict",
|
||||||
|
bw_dict + "/dict." + scorer1_tgt + ".txt",
|
||||||
|
"--destdir",
|
||||||
|
backwards_preprocessed_dir,
|
||||||
|
]
|
||||||
preprocess_parser = options.get_preprocessing_parser()
|
preprocess_parser = options.get_preprocessing_parser()
|
||||||
input_args = preprocess_parser.parse_args(bw_preprocess_param)
|
input_args = preprocess_parser.parse_args(bw_preprocess_param)
|
||||||
preprocess.main(input_args)
|
preprocess.main(input_args)
|
||||||
|
|
||||||
preprocess_param = ["--source-lang", scorer1_src,
|
preprocess_param = [
|
||||||
"--target-lang", scorer1_tgt,
|
"--source-lang",
|
||||||
"--trainpref", pre_gen+fw_rescore_file,
|
scorer1_src,
|
||||||
"--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
|
"--target-lang",
|
||||||
"--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
|
scorer1_tgt,
|
||||||
"--destdir", left_to_right_preprocessed_dir]
|
"--trainpref",
|
||||||
|
pre_gen + fw_rescore_file,
|
||||||
|
"--srcdict",
|
||||||
|
args.score_dict_dir + "/dict." + scorer1_src + ".txt",
|
||||||
|
"--tgtdict",
|
||||||
|
args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
|
||||||
|
"--destdir",
|
||||||
|
left_to_right_preprocessed_dir,
|
||||||
|
]
|
||||||
preprocess_parser = options.get_preprocessing_parser()
|
preprocess_parser = options.get_preprocessing_parser()
|
||||||
input_args = preprocess_parser.parse_args(preprocess_param)
|
input_args = preprocess_parser.parse_args(preprocess_param)
|
||||||
preprocess.main(input_args)
|
preprocess.main(input_args)
|
||||||
|
|
||||||
if args.right_to_left1 or args.right_to_left2:
|
if args.right_to_left1 or args.right_to_left2:
|
||||||
preprocess_param = ["--source-lang", scorer1_src,
|
preprocess_param = [
|
||||||
"--target-lang", scorer1_tgt,
|
"--source-lang",
|
||||||
"--trainpref", pre_gen+"/right_to_left_rescore_data",
|
scorer1_src,
|
||||||
"--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
|
"--target-lang",
|
||||||
"--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
|
scorer1_tgt,
|
||||||
"--destdir", right_to_left_preprocessed_dir]
|
"--trainpref",
|
||||||
|
pre_gen + "/right_to_left_rescore_data",
|
||||||
|
"--srcdict",
|
||||||
|
args.score_dict_dir + "/dict." + scorer1_src + ".txt",
|
||||||
|
"--tgtdict",
|
||||||
|
args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
|
||||||
|
"--destdir",
|
||||||
|
right_to_left_preprocessed_dir,
|
||||||
|
]
|
||||||
preprocess_parser = options.get_preprocessing_parser()
|
preprocess_parser = options.get_preprocessing_parser()
|
||||||
input_args = preprocess_parser.parse_args(preprocess_param)
|
input_args = preprocess_parser.parse_args(preprocess_param)
|
||||||
preprocess.main(input_args)
|
preprocess.main(input_args)
|
||||||
@ -241,5 +393,5 @@ def cli_main():
|
|||||||
gen_and_reprocess_nbest(args)
|
gen_and_reprocess_nbest(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
cli_main()
|
cli_main()
|
||||||
|
@ -6,14 +6,14 @@
|
|||||||
from fairseq import options
|
from fairseq import options
|
||||||
|
|
||||||
|
|
||||||
def get_reranking_parser(default_task='translation'):
|
def get_reranking_parser(default_task="translation"):
|
||||||
parser = options.get_parser('Generation and reranking', default_task)
|
parser = options.get_parser("Generation and reranking", default_task)
|
||||||
add_reranking_args(parser)
|
add_reranking_args(parser)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_tuning_parser(default_task='translation'):
|
def get_tuning_parser(default_task="translation"):
|
||||||
parser = options.get_parser('Reranking tuning', default_task)
|
parser = options.get_parser("Reranking tuning", default_task)
|
||||||
add_reranking_args(parser)
|
add_reranking_args(parser)
|
||||||
add_tuning_args(parser)
|
add_tuning_args(parser)
|
||||||
return parser
|
return parser
|
||||||
@ -110,17 +110,40 @@ def add_reranking_args(parser):
|
|||||||
def add_tuning_args(parser):
|
def add_tuning_args(parser):
|
||||||
group = parser.add_argument_group("Tuning")
|
group = parser.add_argument_group("Tuning")
|
||||||
|
|
||||||
group.add_argument('--lower-bound', default=[-0.7], nargs='+', type=float,
|
group.add_argument(
|
||||||
help='lower bound of search space')
|
"--lower-bound",
|
||||||
group.add_argument('--upper-bound', default=[3], nargs='+', type=float,
|
default=[-0.7],
|
||||||
help='upper bound of search space')
|
nargs="+",
|
||||||
group.add_argument('--tune-param', default=['lenpen'], nargs='+',
|
type=float,
|
||||||
choices=['lenpen', 'weight1', 'weight2', 'weight3'],
|
help="lower bound of search space",
|
||||||
help='the parameter(s) to tune')
|
)
|
||||||
group.add_argument('--tune-subset', default='valid', choices=['valid', 'test', 'train'],
|
group.add_argument(
|
||||||
help='the subset to tune on ')
|
"--upper-bound",
|
||||||
group.add_argument('--num-trials', default=1000, type=int,
|
default=[3],
|
||||||
help='number of trials to do for random search')
|
nargs="+",
|
||||||
group.add_argument('--share-weights', action='store_true',
|
type=float,
|
||||||
help='share weight2 and weight 3')
|
help="upper bound of search space",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--tune-param",
|
||||||
|
default=["lenpen"],
|
||||||
|
nargs="+",
|
||||||
|
choices=["lenpen", "weight1", "weight2", "weight3"],
|
||||||
|
help="the parameter(s) to tune",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--tune-subset",
|
||||||
|
default="valid",
|
||||||
|
choices=["valid", "test", "train"],
|
||||||
|
help="the subset to tune on ",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-trials",
|
||||||
|
default=1000,
|
||||||
|
type=int,
|
||||||
|
help="number of trials to do for random search",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--share-weights", action="store_true", help="share weight2 and weight 3"
|
||||||
|
)
|
||||||
return group
|
return group
|
||||||
|
@ -3,8 +3,8 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from contextlib import redirect_stdout
|
|
||||||
import os
|
import os
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
from fairseq import options
|
from fairseq import options
|
||||||
from fairseq_cli import generate
|
from fairseq_cli import generate
|
||||||
@ -13,82 +13,124 @@ from . import rerank_options, rerank_utils
|
|||||||
|
|
||||||
|
|
||||||
def score_bw(args):
|
def score_bw(args):
|
||||||
if args.backwards1:
|
if args.backwards1:
|
||||||
scorer1_src = args.target_lang
|
scorer1_src = args.target_lang
|
||||||
scorer1_tgt = args.source_lang
|
scorer1_tgt = args.source_lang
|
||||||
|
else:
|
||||||
|
scorer1_src = args.source_lang
|
||||||
|
scorer1_tgt = args.target_lang
|
||||||
|
|
||||||
|
if args.score_model2 is not None:
|
||||||
|
if args.backwards2:
|
||||||
|
scorer2_src = args.target_lang
|
||||||
|
scorer2_tgt = args.source_lang
|
||||||
else:
|
else:
|
||||||
scorer1_src = args.source_lang
|
scorer2_src = args.source_lang
|
||||||
scorer1_tgt = args.target_lang
|
scorer2_tgt = args.target_lang
|
||||||
|
|
||||||
if args.score_model2 is not None:
|
rerank1_is_gen = (
|
||||||
if args.backwards2:
|
args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
||||||
scorer2_src = args.target_lang
|
)
|
||||||
scorer2_tgt = args.source_lang
|
rerank2_is_gen = (
|
||||||
else:
|
args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
||||||
scorer2_src = args.source_lang
|
)
|
||||||
scorer2_tgt = args.target_lang
|
|
||||||
|
|
||||||
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
(
|
||||||
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
pre_gen,
|
||||||
|
left_to_right_preprocessed_dir,
|
||||||
|
right_to_left_preprocessed_dir,
|
||||||
|
backwards_preprocessed_dir,
|
||||||
|
lm_preprocessed_dir,
|
||||||
|
) = rerank_utils.get_directories(
|
||||||
|
args.data_dir_name,
|
||||||
|
args.num_rescore,
|
||||||
|
args.gen_subset,
|
||||||
|
args.gen_model_name,
|
||||||
|
args.shard_id,
|
||||||
|
args.num_shards,
|
||||||
|
args.sampling,
|
||||||
|
args.prefix_len,
|
||||||
|
args.target_prefix_frac,
|
||||||
|
args.source_prefix_frac,
|
||||||
|
)
|
||||||
|
|
||||||
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
|
score1_file = rerank_utils.rescore_file_name(
|
||||||
backwards_preprocessed_dir, lm_preprocessed_dir = \
|
pre_gen,
|
||||||
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
|
args.prefix_len,
|
||||||
args.gen_model_name, args.shard_id, args.num_shards,
|
args.model1_name,
|
||||||
args.sampling, args.prefix_len, args.target_prefix_frac,
|
target_prefix_frac=args.target_prefix_frac,
|
||||||
args.source_prefix_frac)
|
source_prefix_frac=args.source_prefix_frac,
|
||||||
|
backwards=args.backwards1,
|
||||||
|
)
|
||||||
|
|
||||||
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
|
if args.score_model2 is not None:
|
||||||
target_prefix_frac=args.target_prefix_frac,
|
score2_file = rerank_utils.rescore_file_name(
|
||||||
source_prefix_frac=args.source_prefix_frac,
|
pre_gen,
|
||||||
backwards=args.backwards1)
|
args.prefix_len,
|
||||||
|
args.model2_name,
|
||||||
|
target_prefix_frac=args.target_prefix_frac,
|
||||||
|
source_prefix_frac=args.source_prefix_frac,
|
||||||
|
backwards=args.backwards2,
|
||||||
|
)
|
||||||
|
|
||||||
if args.score_model2 is not None:
|
if args.right_to_left1:
|
||||||
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
|
rerank_data1 = right_to_left_preprocessed_dir
|
||||||
target_prefix_frac=args.target_prefix_frac,
|
elif args.backwards1:
|
||||||
source_prefix_frac=args.source_prefix_frac,
|
rerank_data1 = backwards_preprocessed_dir
|
||||||
backwards=args.backwards2)
|
else:
|
||||||
|
rerank_data1 = left_to_right_preprocessed_dir
|
||||||
|
|
||||||
if args.right_to_left1:
|
gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
|
||||||
rerank_data1 = right_to_left_preprocessed_dir
|
if not rerank1_is_gen and not os.path.isfile(score1_file):
|
||||||
elif args.backwards1:
|
print("STEP 4: score the translations for model 1")
|
||||||
rerank_data1 = backwards_preprocessed_dir
|
|
||||||
|
model_param1 = [
|
||||||
|
"--path",
|
||||||
|
args.score_model1,
|
||||||
|
"--source-lang",
|
||||||
|
scorer1_src,
|
||||||
|
"--target-lang",
|
||||||
|
scorer1_tgt,
|
||||||
|
]
|
||||||
|
gen_model1_param = [rerank_data1] + gen_param + model_param1
|
||||||
|
|
||||||
|
gen_parser = options.get_generation_parser()
|
||||||
|
input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)
|
||||||
|
|
||||||
|
with open(score1_file, "w") as f:
|
||||||
|
with redirect_stdout(f):
|
||||||
|
generate.main(input_args)
|
||||||
|
|
||||||
|
if (
|
||||||
|
args.score_model2 is not None
|
||||||
|
and not os.path.isfile(score2_file)
|
||||||
|
and not rerank2_is_gen
|
||||||
|
):
|
||||||
|
print("STEP 4: score the translations for model 2")
|
||||||
|
|
||||||
|
if args.right_to_left2:
|
||||||
|
rerank_data2 = right_to_left_preprocessed_dir
|
||||||
|
elif args.backwards2:
|
||||||
|
rerank_data2 = backwards_preprocessed_dir
|
||||||
else:
|
else:
|
||||||
rerank_data1 = left_to_right_preprocessed_dir
|
rerank_data2 = left_to_right_preprocessed_dir
|
||||||
|
|
||||||
gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
|
model_param2 = [
|
||||||
if not rerank1_is_gen and not os.path.isfile(score1_file):
|
"--path",
|
||||||
print("STEP 4: score the translations for model 1")
|
args.score_model2,
|
||||||
|
"--source-lang",
|
||||||
|
scorer2_src,
|
||||||
|
"--target-lang",
|
||||||
|
scorer2_tgt,
|
||||||
|
]
|
||||||
|
gen_model2_param = [rerank_data2] + gen_param + model_param2
|
||||||
|
|
||||||
model_param1 = ["--path", args.score_model1, "--source-lang", scorer1_src, "--target-lang", scorer1_tgt]
|
gen_parser = options.get_generation_parser()
|
||||||
gen_model1_param = [rerank_data1] + gen_param + model_param1
|
input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)
|
||||||
|
|
||||||
gen_parser = options.get_generation_parser()
|
with open(score2_file, "w") as f:
|
||||||
input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)
|
with redirect_stdout(f):
|
||||||
|
generate.main(input_args)
|
||||||
with open(score1_file, 'w') as f:
|
|
||||||
with redirect_stdout(f):
|
|
||||||
generate.main(input_args)
|
|
||||||
|
|
||||||
if args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen:
|
|
||||||
print("STEP 4: score the translations for model 2")
|
|
||||||
|
|
||||||
if args.right_to_left2:
|
|
||||||
rerank_data2 = right_to_left_preprocessed_dir
|
|
||||||
elif args.backwards2:
|
|
||||||
rerank_data2 = backwards_preprocessed_dir
|
|
||||||
else:
|
|
||||||
rerank_data2 = left_to_right_preprocessed_dir
|
|
||||||
|
|
||||||
model_param2 = ["--path", args.score_model2, "--source-lang", scorer2_src, "--target-lang", scorer2_tgt]
|
|
||||||
gen_model2_param = [rerank_data2] + gen_param + model_param2
|
|
||||||
|
|
||||||
gen_parser = options.get_generation_parser()
|
|
||||||
input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)
|
|
||||||
|
|
||||||
with open(score2_file, 'w') as f:
|
|
||||||
with redirect_stdout(f):
|
|
||||||
generate.main(input_args)
|
|
||||||
|
|
||||||
|
|
||||||
def cli_main():
|
def cli_main():
|
||||||
@ -97,5 +139,5 @@ def cli_main():
|
|||||||
score_bw(args)
|
score_bw(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
cli_main()
|
cli_main()
|
||||||
|
@ -12,22 +12,38 @@ from . import rerank_options, rerank_utils
|
|||||||
|
|
||||||
def score_lm(args):
|
def score_lm(args):
|
||||||
using_nbest = args.nbest_list is not None
|
using_nbest = args.nbest_list is not None
|
||||||
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
|
(
|
||||||
backwards_preprocessed_dir, lm_preprocessed_dir = \
|
pre_gen,
|
||||||
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
|
left_to_right_preprocessed_dir,
|
||||||
args.gen_model_name, args.shard_id, args.num_shards,
|
right_to_left_preprocessed_dir,
|
||||||
args.sampling, args.prefix_len, args.target_prefix_frac,
|
backwards_preprocessed_dir,
|
||||||
args.source_prefix_frac)
|
lm_preprocessed_dir,
|
||||||
|
) = rerank_utils.get_directories(
|
||||||
|
args.data_dir_name,
|
||||||
|
args.num_rescore,
|
||||||
|
args.gen_subset,
|
||||||
|
args.gen_model_name,
|
||||||
|
args.shard_id,
|
||||||
|
args.num_shards,
|
||||||
|
args.sampling,
|
||||||
|
args.prefix_len,
|
||||||
|
args.target_prefix_frac,
|
||||||
|
args.source_prefix_frac,
|
||||||
|
)
|
||||||
|
|
||||||
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
|
predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
|
||||||
if using_nbest:
|
if using_nbest:
|
||||||
print("Using predefined n-best list from interactive.py")
|
print("Using predefined n-best list from interactive.py")
|
||||||
predictions_bpe_file = args.nbest_list
|
predictions_bpe_file = args.nbest_list
|
||||||
|
|
||||||
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest)
|
gen_output = rerank_utils.BitextOutputFromGen(
|
||||||
|
predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest
|
||||||
|
)
|
||||||
|
|
||||||
if args.language_model is not None:
|
if args.language_model is not None:
|
||||||
lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
|
lm_score_file = rerank_utils.rescore_file_name(
|
||||||
|
pre_gen, args.prefix_len, args.lm_name, lm_file=True
|
||||||
|
)
|
||||||
|
|
||||||
if args.language_model is not None and not os.path.isfile(lm_score_file):
|
if args.language_model is not None and not os.path.isfile(lm_score_file):
|
||||||
print("STEP 4.5: language modeling for P(T)")
|
print("STEP 4.5: language modeling for P(T)")
|
||||||
@ -38,10 +54,21 @@ def score_lm(args):
|
|||||||
else:
|
else:
|
||||||
bpe_status = "different"
|
bpe_status = "different"
|
||||||
|
|
||||||
rerank_utils.lm_scoring(lm_preprocessed_dir, bpe_status, gen_output, pre_gen,
|
rerank_utils.lm_scoring(
|
||||||
args.lm_dict, args.lm_name, args.language_model,
|
lm_preprocessed_dir,
|
||||||
args.lm_bpe_code, 128, lm_score_file, args.target_lang,
|
bpe_status,
|
||||||
args.source_lang, prefix_len=args.prefix_len)
|
gen_output,
|
||||||
|
pre_gen,
|
||||||
|
args.lm_dict,
|
||||||
|
args.lm_name,
|
||||||
|
args.language_model,
|
||||||
|
args.lm_bpe_code,
|
||||||
|
128,
|
||||||
|
lm_score_file,
|
||||||
|
args.target_lang,
|
||||||
|
args.source_lang,
|
||||||
|
prefix_len=args.prefix_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def cli_main():
|
def cli_main():
|
||||||
@ -50,5 +77,5 @@ def cli_main():
|
|||||||
score_lm(args)
|
score_lm(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
cli_main()
|
cli_main()
|
||||||
|
@ -5,8 +5,8 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
import random
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from fairseq import options
|
from fairseq import options
|
||||||
|
|
||||||
from . import rerank, rerank_options
|
from . import rerank, rerank_options
|
||||||
@ -14,7 +14,7 @@ from . import rerank, rerank_options
|
|||||||
|
|
||||||
def random_search(args):
|
def random_search(args):
|
||||||
param_values = []
|
param_values = []
|
||||||
tuneable_parameters = ['lenpen', 'weight1', 'weight2', 'weight3']
|
tuneable_parameters = ["lenpen", "weight1", "weight2", "weight3"]
|
||||||
initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3]
|
initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3]
|
||||||
for i, elem in enumerate(initial_params):
|
for i, elem in enumerate(initial_params):
|
||||||
if type(elem) is not list:
|
if type(elem) is not list:
|
||||||
@ -33,51 +33,60 @@ def random_search(args):
|
|||||||
param_values += initial_params
|
param_values += initial_params
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
|
|
||||||
random_params = np.array([
|
random_params = np.array(
|
||||||
[random.uniform(args.lower_bound[i], args.upper_bound[i]) for i in range(len(args.tune_param))]
|
[
|
||||||
for k in range(args.num_trials)
|
[
|
||||||
])
|
random.uniform(args.lower_bound[i], args.upper_bound[i])
|
||||||
set_params = np.array([
|
for i in range(len(args.tune_param))
|
||||||
[initial_params[i][0] for i in range(len(tuneable_parameters))]
|
]
|
||||||
for k in range(args.num_trials)
|
for k in range(args.num_trials)
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
set_params = np.array(
|
||||||
|
[
|
||||||
|
[initial_params[i][0] for i in range(len(tuneable_parameters))]
|
||||||
|
for k in range(args.num_trials)
|
||||||
|
]
|
||||||
|
)
|
||||||
random_params = np.concatenate((random_params, set_params), 1)
|
random_params = np.concatenate((random_params, set_params), 1)
|
||||||
|
|
||||||
rerank_args = vars(args).copy()
|
rerank_args = vars(args).copy()
|
||||||
if args.nbest_list:
|
if args.nbest_list:
|
||||||
rerank_args['gen_subset'] = 'test'
|
rerank_args["gen_subset"] = "test"
|
||||||
else:
|
else:
|
||||||
rerank_args['gen_subset'] = args.tune_subset
|
rerank_args["gen_subset"] = args.tune_subset
|
||||||
|
|
||||||
for k in range(len(tune_parameters)):
|
for k in range(len(tune_parameters)):
|
||||||
rerank_args[tune_parameters[k]] = list(random_params[:, k])
|
rerank_args[tune_parameters[k]] = list(random_params[:, k])
|
||||||
|
|
||||||
if args.share_weights:
|
if args.share_weights:
|
||||||
k = tune_parameters.index('weight2')
|
k = tune_parameters.index("weight2")
|
||||||
rerank_args['weight3'] = list(random_params[:, k])
|
rerank_args["weight3"] = list(random_params[:, k])
|
||||||
|
|
||||||
rerank_args = argparse.Namespace(**rerank_args)
|
rerank_args = argparse.Namespace(**rerank_args)
|
||||||
best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(rerank_args)
|
best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(
|
||||||
|
rerank_args
|
||||||
|
)
|
||||||
rerank_args = vars(args).copy()
|
rerank_args = vars(args).copy()
|
||||||
rerank_args['lenpen'] = [best_lenpen]
|
rerank_args["lenpen"] = [best_lenpen]
|
||||||
rerank_args['weight1'] = [best_weight1]
|
rerank_args["weight1"] = [best_weight1]
|
||||||
rerank_args['weight2'] = [best_weight2]
|
rerank_args["weight2"] = [best_weight2]
|
||||||
rerank_args['weight3'] = [best_weight3]
|
rerank_args["weight3"] = [best_weight3]
|
||||||
|
|
||||||
# write the hypothesis from the valid set from the best trial
|
# write the hypothesis from the valid set from the best trial
|
||||||
|
|
||||||
if args.gen_subset != "valid":
|
if args.gen_subset != "valid":
|
||||||
rerank_args['gen_subset'] = "valid"
|
rerank_args["gen_subset"] = "valid"
|
||||||
rerank_args = argparse.Namespace(**rerank_args)
|
rerank_args = argparse.Namespace(**rerank_args)
|
||||||
rerank.rerank(rerank_args)
|
rerank.rerank(rerank_args)
|
||||||
|
|
||||||
# test with the best hyperparameters on gen subset
|
# test with the best hyperparameters on gen subset
|
||||||
rerank_args = vars(args).copy()
|
rerank_args = vars(args).copy()
|
||||||
rerank_args['gen_subset'] = args.gen_subset
|
rerank_args["gen_subset"] = args.gen_subset
|
||||||
rerank_args['lenpen'] = [best_lenpen]
|
rerank_args["lenpen"] = [best_lenpen]
|
||||||
rerank_args['weight1'] = [best_weight1]
|
rerank_args["weight1"] = [best_weight1]
|
||||||
rerank_args['weight2'] = [best_weight2]
|
rerank_args["weight2"] = [best_weight2]
|
||||||
rerank_args['weight3'] = [best_weight3]
|
rerank_args["weight3"] = [best_weight3]
|
||||||
rerank_args = argparse.Namespace(**rerank_args)
|
rerank_args = argparse.Namespace(**rerank_args)
|
||||||
rerank.rerank(rerank_args)
|
rerank.rerank(rerank_args)
|
||||||
|
|
||||||
@ -89,5 +98,5 @@ def cli_main():
|
|||||||
random_search(args)
|
random_search(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
cli_main()
|
cli_main()
|
||||||
|
@ -3,11 +3,11 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from contextlib import redirect_stdout
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
from fairseq import options
|
from fairseq import options
|
||||||
from fairseq_cli import eval_lm, preprocess
|
from fairseq_cli import eval_lm, preprocess
|
||||||
@ -20,7 +20,7 @@ def reprocess(fle):
|
|||||||
# per source, so the values for hypothesis_dict are lists.
|
# per source, so the values for hypothesis_dict are lists.
|
||||||
# parses output of generate.py
|
# parses output of generate.py
|
||||||
|
|
||||||
with open(fle, 'r') as f:
|
with open(fle, "r") as f:
|
||||||
txt = f.read()
|
txt = f.read()
|
||||||
|
|
||||||
"""reprocess generate.py output"""
|
"""reprocess generate.py output"""
|
||||||
@ -45,7 +45,9 @@ def reprocess(fle):
|
|||||||
if line_type == "H":
|
if line_type == "H":
|
||||||
h_txt = line[j:]
|
h_txt = line[j:]
|
||||||
hypo = re.search(hp, h_txt)
|
hypo = re.search(hp, h_txt)
|
||||||
assert hypo is not None, ("regular expression failed to find the hypothesis scoring")
|
assert (
|
||||||
|
hypo is not None
|
||||||
|
), "regular expression failed to find the hypothesis scoring"
|
||||||
_, i = hypo.span()
|
_, i = hypo.span()
|
||||||
score = hypo.group()
|
score = hypo.group()
|
||||||
if id_num in hypothesis_dict:
|
if id_num in hypothesis_dict:
|
||||||
@ -56,9 +58,9 @@ def reprocess(fle):
|
|||||||
score_dict[id_num] = [float(score)]
|
score_dict[id_num] = [float(score)]
|
||||||
|
|
||||||
elif line_type == "S":
|
elif line_type == "S":
|
||||||
source_dict[id_num] = (line[j:])
|
source_dict[id_num] = line[j:]
|
||||||
elif line_type == "T":
|
elif line_type == "T":
|
||||||
target_dict[id_num] = (line[j:])
|
target_dict[id_num] = line[j:]
|
||||||
elif line_type == "P":
|
elif line_type == "P":
|
||||||
pos_scores = (line[j:]).split()
|
pos_scores = (line[j:]).split()
|
||||||
pos_scores = [float(x) for x in pos_scores]
|
pos_scores = [float(x) for x in pos_scores]
|
||||||
@ -72,7 +74,7 @@ def reprocess(fle):
|
|||||||
|
|
||||||
def reprocess_nbest(fle):
|
def reprocess_nbest(fle):
|
||||||
"""reprocess interactive.py output"""
|
"""reprocess interactive.py output"""
|
||||||
with open(fle, 'r') as f:
|
with open(fle, "r") as f:
|
||||||
txt = f.read()
|
txt = f.read()
|
||||||
|
|
||||||
source_dict = {}
|
source_dict = {}
|
||||||
@ -82,7 +84,7 @@ def reprocess_nbest(fle):
|
|||||||
pos_score_dict = {}
|
pos_score_dict = {}
|
||||||
lines = txt.split("\n")
|
lines = txt.split("\n")
|
||||||
|
|
||||||
hp = re.compile(r'[-]?\d+[.]?\d+')
|
hp = re.compile(r"[-]?\d+[.]?\d+")
|
||||||
j = -1
|
j = -1
|
||||||
|
|
||||||
for _i, line in enumerate(lines):
|
for _i, line in enumerate(lines):
|
||||||
@ -119,59 +121,76 @@ def reprocess_nbest(fle):
|
|||||||
return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
|
return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
|
||||||
|
|
||||||
|
|
||||||
def write_reprocessed(sources, hypos, targets, source_outfile,
|
def write_reprocessed(
|
||||||
hypo_outfile, target_outfile, right_to_left=False,
|
sources,
|
||||||
prefix_len=None, bpe_symbol=None,
|
hypos,
|
||||||
target_prefix_frac=None, source_prefix_frac=None):
|
targets,
|
||||||
|
source_outfile,
|
||||||
|
hypo_outfile,
|
||||||
|
target_outfile,
|
||||||
|
right_to_left=False,
|
||||||
|
prefix_len=None,
|
||||||
|
bpe_symbol=None,
|
||||||
|
target_prefix_frac=None,
|
||||||
|
source_prefix_frac=None,
|
||||||
|
):
|
||||||
|
|
||||||
"""writes nbest hypothesis for rescoring"""
|
"""writes nbest hypothesis for rescoring"""
|
||||||
assert not (prefix_len is not None and target_prefix_frac is not None), \
|
assert not (
|
||||||
"in writing reprocessed, only one type of prefix may be used"
|
prefix_len is not None and target_prefix_frac is not None
|
||||||
assert not (prefix_len is not None and source_prefix_frac is not None), \
|
), "in writing reprocessed, only one type of prefix may be used"
|
||||||
"in writing reprocessed, only one type of prefix may be used"
|
assert not (
|
||||||
assert not (target_prefix_frac is not None and source_prefix_frac is not None), \
|
prefix_len is not None and source_prefix_frac is not None
|
||||||
"in writing reprocessed, only one type of prefix may be used"
|
), "in writing reprocessed, only one type of prefix may be used"
|
||||||
|
assert not (
|
||||||
|
target_prefix_frac is not None and source_prefix_frac is not None
|
||||||
|
), "in writing reprocessed, only one type of prefix may be used"
|
||||||
|
|
||||||
with open(source_outfile, 'w') as source_file, \
|
with open(source_outfile, "w") as source_file, open(
|
||||||
open(hypo_outfile, 'w') as hypo_file, \
|
hypo_outfile, "w"
|
||||||
open(target_outfile, 'w') as target_file:
|
) as hypo_file, open(target_outfile, "w") as target_file:
|
||||||
|
|
||||||
assert len(sources) == len(hypos), "sources and hypos list length mismatch"
|
assert len(sources) == len(hypos), "sources and hypos list length mismatch"
|
||||||
if right_to_left:
|
if right_to_left:
|
||||||
for i in range(len(sources)):
|
for i in range(len(sources)):
|
||||||
for j in range(len(hypos[i])):
|
for j in range(len(hypos[i])):
|
||||||
if prefix_len is None:
|
if prefix_len is None:
|
||||||
hypo_file.write(make_right_to_left(hypos[i][j])+"\n")
|
hypo_file.write(make_right_to_left(hypos[i][j]) + "\n")
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
source_file.write(make_right_to_left(sources[i])+"\n")
|
source_file.write(make_right_to_left(sources[i]) + "\n")
|
||||||
target_file.write(make_right_to_left(targets[i])+"\n")
|
target_file.write(make_right_to_left(targets[i]) + "\n")
|
||||||
else:
|
else:
|
||||||
for i in sorted(sources.keys()):
|
for i in sorted(sources.keys()):
|
||||||
for j in range(len(hypos[i])):
|
for j in range(len(hypos[i])):
|
||||||
if prefix_len is not None:
|
if prefix_len is not None:
|
||||||
shortened = get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)+"\n"
|
shortened = (
|
||||||
hypo_file.write(shortened)
|
get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)
|
||||||
source_file.write(sources[i])
|
+ "\n"
|
||||||
target_file.write(targets[i])
|
)
|
||||||
elif target_prefix_frac is not None:
|
hypo_file.write(shortened)
|
||||||
num_words, shortened, num_bpe_tokens = \
|
source_file.write(sources[i])
|
||||||
calc_length_from_frac(hypos[i][j], target_prefix_frac, bpe_symbol)
|
target_file.write(targets[i])
|
||||||
shortened += "\n"
|
elif target_prefix_frac is not None:
|
||||||
hypo_file.write(shortened)
|
num_words, shortened, num_bpe_tokens = calc_length_from_frac(
|
||||||
source_file.write(sources[i])
|
hypos[i][j], target_prefix_frac, bpe_symbol
|
||||||
target_file.write(targets[i])
|
)
|
||||||
elif source_prefix_frac is not None:
|
shortened += "\n"
|
||||||
num_words, shortened, num_bpe_tokensn = \
|
hypo_file.write(shortened)
|
||||||
calc_length_from_frac(sources[i], source_prefix_frac, bpe_symbol)
|
source_file.write(sources[i])
|
||||||
shortened += "\n"
|
target_file.write(targets[i])
|
||||||
hypo_file.write(hypos[i][j])
|
elif source_prefix_frac is not None:
|
||||||
source_file.write(shortened)
|
num_words, shortened, num_bpe_tokensn = calc_length_from_frac(
|
||||||
target_file.write(targets[i])
|
sources[i], source_prefix_frac, bpe_symbol
|
||||||
else:
|
)
|
||||||
hypo_file.write(hypos[i][j])
|
shortened += "\n"
|
||||||
source_file.write(sources[i])
|
hypo_file.write(hypos[i][j])
|
||||||
target_file.write(targets[i])
|
source_file.write(shortened)
|
||||||
|
target_file.write(targets[i])
|
||||||
|
else:
|
||||||
|
hypo_file.write(hypos[i][j])
|
||||||
|
source_file.write(sources[i])
|
||||||
|
target_file.write(targets[i])
|
||||||
|
|
||||||
|
|
||||||
def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol):
|
def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol):
|
||||||
@ -207,7 +226,9 @@ def get_prefix_from_len(sentence, bpe_symbol, prefix_len):
|
|||||||
if bpe_count == 0:
|
if bpe_count == 0:
|
||||||
return sentence[:prefix_len]
|
return sentence[:prefix_len]
|
||||||
else:
|
else:
|
||||||
return sentence[:prefix_len]+get_prefix_from_len(sentence[prefix_len:], bpe_symbol, bpe_count)
|
return sentence[:prefix_len] + get_prefix_from_len(
|
||||||
|
sentence[prefix_len:], bpe_symbol, bpe_count
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len):
|
def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len):
|
||||||
@ -225,9 +246,9 @@ def make_right_to_left(line):
|
|||||||
|
|
||||||
|
|
||||||
def remove_bpe(line, bpe_symbol):
|
def remove_bpe(line, bpe_symbol):
|
||||||
line = line.replace("\n", '')
|
line = line.replace("\n", "")
|
||||||
line = (line + ' ').replace(bpe_symbol, '').rstrip()
|
line = (line + " ").replace(bpe_symbol, "").rstrip()
|
||||||
return line+("\n")
|
return line + ("\n")
|
||||||
|
|
||||||
|
|
||||||
def remove_bpe_dict(pred_dict, bpe_symbol):
|
def remove_bpe_dict(pred_dict, bpe_symbol):
|
||||||
@ -242,7 +263,7 @@ def remove_bpe_dict(pred_dict, bpe_symbol):
|
|||||||
|
|
||||||
|
|
||||||
def parse_bleu_scoring(line):
|
def parse_bleu_scoring(line):
|
||||||
p = re.compile(r'(BLEU4 = )\d+[.]\d+')
|
p = re.compile(r"(BLEU4 = )\d+[.]\d+")
|
||||||
res = re.search(p, line)
|
res = re.search(p, line)
|
||||||
assert res is not None, line
|
assert res is not None, line
|
||||||
return float(res.group()[8:])
|
return float(res.group()[8:])
|
||||||
@ -259,9 +280,21 @@ def get_full_from_prefix(hypo_prefix, hypos):
|
|||||||
raise Exception()
|
raise Exception()
|
||||||
|
|
||||||
|
|
||||||
def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=None,
|
def get_score(
|
||||||
lenpen=None, src_len=None, tgt_len=None, bitext1_backwards=False,
|
a,
|
||||||
bitext2_backwards=False, normalize=False):
|
b,
|
||||||
|
c,
|
||||||
|
target_len,
|
||||||
|
bitext_score1,
|
||||||
|
bitext_score2=None,
|
||||||
|
lm_score=None,
|
||||||
|
lenpen=None,
|
||||||
|
src_len=None,
|
||||||
|
tgt_len=None,
|
||||||
|
bitext1_backwards=False,
|
||||||
|
bitext2_backwards=False,
|
||||||
|
normalize=False,
|
||||||
|
):
|
||||||
if bitext1_backwards:
|
if bitext1_backwards:
|
||||||
bitext1_norm = src_len
|
bitext1_norm = src_len
|
||||||
else:
|
else:
|
||||||
@ -275,9 +308,13 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N
|
|||||||
bitext2_norm = 1
|
bitext2_norm = 1
|
||||||
bitext_score2 = 0
|
bitext_score2 = 0
|
||||||
if normalize:
|
if normalize:
|
||||||
score = a*bitext_score1/bitext1_norm + b*bitext_score2/bitext2_norm+c*lm_score/src_len
|
score = (
|
||||||
|
a * bitext_score1 / bitext1_norm
|
||||||
|
+ b * bitext_score2 / bitext2_norm
|
||||||
|
+ c * lm_score / src_len
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
score = a*bitext_score1 + b*bitext_score2+c*lm_score
|
score = a * bitext_score1 + b * bitext_score2 + c * lm_score
|
||||||
|
|
||||||
if lenpen is not None:
|
if lenpen is not None:
|
||||||
score /= (target_len) ** float(lenpen)
|
score /= (target_len) ** float(lenpen)
|
||||||
@ -286,8 +323,16 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N
|
|||||||
|
|
||||||
|
|
||||||
class BitextOutput(object):
|
class BitextOutput(object):
|
||||||
def __init__(self, output_file, backwards, right_to_left, bpe_symbol,
|
def __init__(
|
||||||
prefix_len=None, target_prefix_frac=None, source_prefix_frac=None):
|
self,
|
||||||
|
output_file,
|
||||||
|
backwards,
|
||||||
|
right_to_left,
|
||||||
|
bpe_symbol,
|
||||||
|
prefix_len=None,
|
||||||
|
target_prefix_frac=None,
|
||||||
|
source_prefix_frac=None,
|
||||||
|
):
|
||||||
"""process output from rescoring"""
|
"""process output from rescoring"""
|
||||||
source, hypo, score, target, pos_score = reprocess(output_file)
|
source, hypo, score, target, pos_score = reprocess(output_file)
|
||||||
if backwards:
|
if backwards:
|
||||||
@ -296,7 +341,9 @@ class BitextOutput(object):
|
|||||||
self.hypo_fracs = target_prefix_frac
|
self.hypo_fracs = target_prefix_frac
|
||||||
|
|
||||||
# remove length penalty so we can use raw scores
|
# remove length penalty so we can use raw scores
|
||||||
score, num_bpe_tokens = get_score_from_pos(pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards)
|
score, num_bpe_tokens = get_score_from_pos(
|
||||||
|
pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards
|
||||||
|
)
|
||||||
source_lengths = {}
|
source_lengths = {}
|
||||||
target_lengths = {}
|
target_lengths = {}
|
||||||
|
|
||||||
@ -341,7 +388,9 @@ class BitextOutput(object):
|
|||||||
score[i] = float(score[i][0])
|
score[i] = float(score[i][0])
|
||||||
pos_score[i] = pos_score[i][0]
|
pos_score[i] = pos_score[i][0]
|
||||||
else:
|
else:
|
||||||
assert len(hypo[i]) == 1, "expected only one hypothesis per source sentence"
|
assert (
|
||||||
|
len(hypo[i]) == 1
|
||||||
|
), "expected only one hypothesis per source sentence"
|
||||||
source[i] = remove_bpe(source[i], bpe_symbol)
|
source[i] = remove_bpe(source[i], bpe_symbol)
|
||||||
target[i] = remove_bpe(target[i], bpe_symbol)
|
target[i] = remove_bpe(target[i], bpe_symbol)
|
||||||
hypo[i] = remove_bpe(hypo[i][0], bpe_symbol)
|
hypo[i] = remove_bpe(hypo[i][0], bpe_symbol)
|
||||||
@ -360,11 +409,26 @@ class BitextOutput(object):
|
|||||||
|
|
||||||
|
|
||||||
class BitextOutputFromGen(object):
|
class BitextOutputFromGen(object):
|
||||||
def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_len=None, target_prefix_frac=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
predictions_bpe_file,
|
||||||
|
bpe_symbol=None,
|
||||||
|
nbest=False,
|
||||||
|
prefix_len=None,
|
||||||
|
target_prefix_frac=None,
|
||||||
|
):
|
||||||
if nbest:
|
if nbest:
|
||||||
pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess_nbest(predictions_bpe_file)
|
(
|
||||||
|
pred_source,
|
||||||
|
pred_hypo,
|
||||||
|
pred_score,
|
||||||
|
pred_target,
|
||||||
|
pred_pos_score,
|
||||||
|
) = reprocess_nbest(predictions_bpe_file)
|
||||||
else:
|
else:
|
||||||
pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(predictions_bpe_file)
|
pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(
|
||||||
|
predictions_bpe_file
|
||||||
|
)
|
||||||
|
|
||||||
assert len(pred_source) == len(pred_hypo)
|
assert len(pred_source) == len(pred_hypo)
|
||||||
assert len(pred_source) == len(pred_score)
|
assert len(pred_source) == len(pred_score)
|
||||||
@ -372,8 +436,9 @@ class BitextOutputFromGen(object):
|
|||||||
assert len(pred_source) == len(pred_pos_score)
|
assert len(pred_source) == len(pred_pos_score)
|
||||||
|
|
||||||
# remove length penalty so we can use raw scores
|
# remove length penalty so we can use raw scores
|
||||||
pred_score, num_bpe_tokens = get_score_from_pos(pred_pos_score, prefix_len, pred_hypo,
|
pred_score, num_bpe_tokens = get_score_from_pos(
|
||||||
bpe_symbol, target_prefix_frac, False)
|
pred_pos_score, prefix_len, pred_hypo, bpe_symbol, target_prefix_frac, False
|
||||||
|
)
|
||||||
|
|
||||||
self.source = pred_source
|
self.source = pred_source
|
||||||
self.target = pred_target
|
self.target = pred_target
|
||||||
@ -414,7 +479,9 @@ class BitextOutputFromGen(object):
|
|||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards):
|
def get_score_from_pos(
|
||||||
|
pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards
|
||||||
|
):
|
||||||
score_dict = {}
|
score_dict = {}
|
||||||
num_bpe_tokens_dict = {}
|
num_bpe_tokens_dict = {}
|
||||||
assert prefix_len is None or hypo_frac is None
|
assert prefix_len is None or hypo_frac is None
|
||||||
@ -423,11 +490,15 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f
|
|||||||
num_bpe_tokens_dict[key] = []
|
num_bpe_tokens_dict[key] = []
|
||||||
for i in range(len(pos_score_dict[key])):
|
for i in range(len(pos_score_dict[key])):
|
||||||
if prefix_len is not None and not backwards:
|
if prefix_len is not None and not backwards:
|
||||||
num_bpe_tokens = get_num_bpe_tokens_from_len(hypo_dict[key][i], bpe_symbol, prefix_len)
|
num_bpe_tokens = get_num_bpe_tokens_from_len(
|
||||||
|
hypo_dict[key][i], bpe_symbol, prefix_len
|
||||||
|
)
|
||||||
score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens]))
|
score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens]))
|
||||||
num_bpe_tokens_dict[key].append(num_bpe_tokens)
|
num_bpe_tokens_dict[key].append(num_bpe_tokens)
|
||||||
elif hypo_frac is not None:
|
elif hypo_frac is not None:
|
||||||
num_words, shortened, hypo_prefix_len = calc_length_from_frac(hypo_dict[key][i], hypo_frac, bpe_symbol)
|
num_words, shortened, hypo_prefix_len = calc_length_from_frac(
|
||||||
|
hypo_dict[key][i], hypo_frac, bpe_symbol
|
||||||
|
)
|
||||||
score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len]))
|
score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len]))
|
||||||
num_bpe_tokens_dict[key].append(hypo_prefix_len)
|
num_bpe_tokens_dict[key].append(hypo_prefix_len)
|
||||||
else:
|
else:
|
||||||
@ -437,10 +508,26 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f
|
|||||||
|
|
||||||
|
|
||||||
class LMOutput(object):
|
class LMOutput(object):
|
||||||
def __init__(self, lm_score_file, lm_dict=None, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
|
def __init__(
|
||||||
lm_sentences, lm_sen_scores, lm_sen_pos_scores, lm_no_bpe_sentences, lm_bpe_tokens = \
|
self,
|
||||||
parse_lm(lm_score_file, prefix_len=prefix_len,
|
lm_score_file,
|
||||||
bpe_symbol=bpe_symbol, target_prefix_frac=target_prefix_frac)
|
lm_dict=None,
|
||||||
|
prefix_len=None,
|
||||||
|
bpe_symbol=None,
|
||||||
|
target_prefix_frac=None,
|
||||||
|
):
|
||||||
|
(
|
||||||
|
lm_sentences,
|
||||||
|
lm_sen_scores,
|
||||||
|
lm_sen_pos_scores,
|
||||||
|
lm_no_bpe_sentences,
|
||||||
|
lm_bpe_tokens,
|
||||||
|
) = parse_lm(
|
||||||
|
lm_score_file,
|
||||||
|
prefix_len=prefix_len,
|
||||||
|
bpe_symbol=bpe_symbol,
|
||||||
|
target_prefix_frac=target_prefix_frac,
|
||||||
|
)
|
||||||
|
|
||||||
self.sentences = lm_sentences
|
self.sentences = lm_sentences
|
||||||
self.score = lm_sen_scores
|
self.score = lm_sen_scores
|
||||||
@ -452,7 +539,7 @@ class LMOutput(object):
|
|||||||
|
|
||||||
def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
|
def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
|
||||||
"""parse output of eval_lm"""
|
"""parse output of eval_lm"""
|
||||||
with open(input_file, 'r') as f:
|
with open(input_file, "r") as f:
|
||||||
text = f.readlines()
|
text = f.readlines()
|
||||||
text = text[7:]
|
text = text[7:]
|
||||||
cleaned_text = text[:-2]
|
cleaned_text = text[:-2]
|
||||||
@ -467,20 +554,23 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No
|
|||||||
if tokens[0].isdigit():
|
if tokens[0].isdigit():
|
||||||
line_id = int(tokens[0])
|
line_id = int(tokens[0])
|
||||||
scores = [float(x[1:-1]) for x in tokens[2::2]]
|
scores = [float(x[1:-1]) for x in tokens[2::2]]
|
||||||
sentences[line_id] = " ".join(tokens[1::2][:-1])+"\n"
|
sentences[line_id] = " ".join(tokens[1::2][:-1]) + "\n"
|
||||||
if bpe_symbol is not None:
|
if bpe_symbol is not None:
|
||||||
# exclude <eos> symbol to match output from generate.py
|
# exclude <eos> symbol to match output from generate.py
|
||||||
bpe_sen = " ".join(tokens[1::2][:-1])+"\n"
|
bpe_sen = " ".join(tokens[1::2][:-1]) + "\n"
|
||||||
no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol)
|
no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol)
|
||||||
no_bpe_sentences[line_id] = no_bpe_sen
|
no_bpe_sentences[line_id] = no_bpe_sen
|
||||||
|
|
||||||
if prefix_len is not None:
|
if prefix_len is not None:
|
||||||
num_bpe_tokens = get_num_bpe_tokens_from_len(bpe_sen, bpe_symbol, prefix_len)
|
num_bpe_tokens = get_num_bpe_tokens_from_len(
|
||||||
|
bpe_sen, bpe_symbol, prefix_len
|
||||||
|
)
|
||||||
sen_scores[line_id] = sum(scores[:num_bpe_tokens])
|
sen_scores[line_id] = sum(scores[:num_bpe_tokens])
|
||||||
num_bpe_tokens_dict[line_id] = num_bpe_tokens
|
num_bpe_tokens_dict[line_id] = num_bpe_tokens
|
||||||
elif target_prefix_frac is not None:
|
elif target_prefix_frac is not None:
|
||||||
num_words, shortened, target_prefix_len = calc_length_from_frac(bpe_sen, target_prefix_frac,
|
num_words, shortened, target_prefix_len = calc_length_from_frac(
|
||||||
bpe_symbol)
|
bpe_sen, target_prefix_frac, bpe_symbol
|
||||||
|
)
|
||||||
sen_scores[line_id] = sum(scores[:target_prefix_len])
|
sen_scores[line_id] = sum(scores[:target_prefix_len])
|
||||||
num_bpe_tokens_dict[line_id] = target_prefix_len
|
num_bpe_tokens_dict[line_id] = target_prefix_len
|
||||||
else:
|
else:
|
||||||
@ -492,160 +582,269 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No
|
|||||||
return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict
|
return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict
|
||||||
|
|
||||||
|
|
||||||
def get_directories(data_dir_name, num_rescore, gen_subset,
|
def get_directories(
|
||||||
fw_name, shard_id, num_shards,
|
data_dir_name,
|
||||||
sampling=False, prefix_len=None,
|
num_rescore,
|
||||||
target_prefix_frac=None, source_prefix_frac=None):
|
gen_subset,
|
||||||
nbest_file_id = "nbest_" + str(num_rescore) + \
|
fw_name,
|
||||||
"_subset_" + gen_subset + \
|
shard_id,
|
||||||
"_fw_name_" + fw_name + \
|
num_shards,
|
||||||
"_shard_" + str(shard_id) + \
|
sampling=False,
|
||||||
"_of_" + str(num_shards)
|
prefix_len=None,
|
||||||
|
target_prefix_frac=None,
|
||||||
|
source_prefix_frac=None,
|
||||||
|
):
|
||||||
|
nbest_file_id = (
|
||||||
|
"nbest_"
|
||||||
|
+ str(num_rescore)
|
||||||
|
+ "_subset_"
|
||||||
|
+ gen_subset
|
||||||
|
+ "_fw_name_"
|
||||||
|
+ fw_name
|
||||||
|
+ "_shard_"
|
||||||
|
+ str(shard_id)
|
||||||
|
+ "_of_"
|
||||||
|
+ str(num_shards)
|
||||||
|
)
|
||||||
|
|
||||||
if sampling:
|
if sampling:
|
||||||
nbest_file_id += "_sampling"
|
nbest_file_id += "_sampling"
|
||||||
|
|
||||||
# the directory containing all information for this nbest list
|
# the directory containing all information for this nbest list
|
||||||
pre_gen = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+data_dir_name+"/"+nbest_file_id
|
pre_gen = (
|
||||||
|
os.path.join(os.path.dirname(__file__))
|
||||||
|
+ "/rerank_data/"
|
||||||
|
+ data_dir_name
|
||||||
|
+ "/"
|
||||||
|
+ nbest_file_id
|
||||||
|
)
|
||||||
# the directory to store the preprocessed nbest list, for left to right rescoring
|
# the directory to store the preprocessed nbest list, for left to right rescoring
|
||||||
left_to_right_preprocessed_dir = pre_gen+"/left_to_right_preprocessed"
|
left_to_right_preprocessed_dir = pre_gen + "/left_to_right_preprocessed"
|
||||||
if source_prefix_frac is not None:
|
if source_prefix_frac is not None:
|
||||||
left_to_right_preprocessed_dir = left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
|
left_to_right_preprocessed_dir = (
|
||||||
|
left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
|
||||||
|
)
|
||||||
# the directory to store the preprocessed nbest list, for right to left rescoring
|
# the directory to store the preprocessed nbest list, for right to left rescoring
|
||||||
right_to_left_preprocessed_dir = pre_gen+"/right_to_left_preprocessed"
|
right_to_left_preprocessed_dir = pre_gen + "/right_to_left_preprocessed"
|
||||||
# the directory to store the preprocessed nbest list, for backwards rescoring
|
# the directory to store the preprocessed nbest list, for backwards rescoring
|
||||||
backwards_preprocessed_dir = pre_gen+"/backwards"
|
backwards_preprocessed_dir = pre_gen + "/backwards"
|
||||||
if target_prefix_frac is not None:
|
if target_prefix_frac is not None:
|
||||||
backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_frac"+str(target_prefix_frac)
|
backwards_preprocessed_dir = (
|
||||||
|
backwards_preprocessed_dir + "/prefix_frac" + str(target_prefix_frac)
|
||||||
|
)
|
||||||
elif prefix_len is not None:
|
elif prefix_len is not None:
|
||||||
backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_"+str(prefix_len)
|
backwards_preprocessed_dir = (
|
||||||
|
backwards_preprocessed_dir + "/prefix_" + str(prefix_len)
|
||||||
|
)
|
||||||
|
|
||||||
# the directory to store the preprocessed nbest list, for rescoring with P(T)
|
# the directory to store the preprocessed nbest list, for rescoring with P(T)
|
||||||
lm_preprocessed_dir = pre_gen+"/lm_preprocessed"
|
lm_preprocessed_dir = pre_gen + "/lm_preprocessed"
|
||||||
|
|
||||||
return pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
|
return (
|
||||||
backwards_preprocessed_dir, lm_preprocessed_dir
|
pre_gen,
|
||||||
|
left_to_right_preprocessed_dir,
|
||||||
|
right_to_left_preprocessed_dir,
|
||||||
|
backwards_preprocessed_dir,
|
||||||
|
lm_preprocessed_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def lm_scoring(preprocess_directory, bpe_status, gen_output, pre_gen,
|
def lm_scoring(
|
||||||
cur_lm_dict, cur_lm_name, cur_language_model, cur_lm_bpe_code,
|
preprocess_directory,
|
||||||
batch_size, lm_score_file, target_lang, source_lang, prefix_len=None):
|
bpe_status,
|
||||||
|
gen_output,
|
||||||
|
pre_gen,
|
||||||
|
cur_lm_dict,
|
||||||
|
cur_lm_name,
|
||||||
|
cur_language_model,
|
||||||
|
cur_lm_bpe_code,
|
||||||
|
batch_size,
|
||||||
|
lm_score_file,
|
||||||
|
target_lang,
|
||||||
|
source_lang,
|
||||||
|
prefix_len=None,
|
||||||
|
):
|
||||||
if prefix_len is not None:
|
if prefix_len is not None:
|
||||||
assert bpe_status == "different", "bpe status must be different to use prefix len"
|
assert (
|
||||||
|
bpe_status == "different"
|
||||||
|
), "bpe status must be different to use prefix len"
|
||||||
if bpe_status == "no bpe":
|
if bpe_status == "no bpe":
|
||||||
# run lm on output without bpe
|
# run lm on output without bpe
|
||||||
write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
|
write_reprocessed(
|
||||||
gen_output.no_bpe_target, pre_gen+"/rescore_data_no_bpe.de",
|
gen_output.no_bpe_source,
|
||||||
pre_gen+"/rescore_data_no_bpe.en", pre_gen+"/reference_file_no_bpe")
|
gen_output.no_bpe_hypo,
|
||||||
|
gen_output.no_bpe_target,
|
||||||
|
pre_gen + "/rescore_data_no_bpe.de",
|
||||||
|
pre_gen + "/rescore_data_no_bpe.en",
|
||||||
|
pre_gen + "/reference_file_no_bpe",
|
||||||
|
)
|
||||||
|
|
||||||
preprocess_lm_param = ["--only-source",
|
preprocess_lm_param = [
|
||||||
"--trainpref", pre_gen+"/rescore_data_no_bpe."+target_lang,
|
"--only-source",
|
||||||
"--srcdict", cur_lm_dict,
|
"--trainpref",
|
||||||
"--destdir", preprocess_directory]
|
pre_gen + "/rescore_data_no_bpe." + target_lang,
|
||||||
|
"--srcdict",
|
||||||
|
cur_lm_dict,
|
||||||
|
"--destdir",
|
||||||
|
preprocess_directory,
|
||||||
|
]
|
||||||
preprocess_parser = options.get_preprocessing_parser()
|
preprocess_parser = options.get_preprocessing_parser()
|
||||||
input_args = preprocess_parser.parse_args(preprocess_lm_param)
|
input_args = preprocess_parser.parse_args(preprocess_lm_param)
|
||||||
preprocess.main(input_args)
|
preprocess.main(input_args)
|
||||||
|
|
||||||
eval_lm_param = [preprocess_directory,
|
eval_lm_param = [
|
||||||
"--path", cur_language_model,
|
preprocess_directory,
|
||||||
"--output-word-probs",
|
"--path",
|
||||||
"--batch-size", str(batch_size),
|
cur_language_model,
|
||||||
"--max-tokens", "1024",
|
"--output-word-probs",
|
||||||
"--sample-break-mode", "eos",
|
"--batch-size",
|
||||||
"--gen-subset", "train"]
|
str(batch_size),
|
||||||
|
"--max-tokens",
|
||||||
|
"1024",
|
||||||
|
"--sample-break-mode",
|
||||||
|
"eos",
|
||||||
|
"--gen-subset",
|
||||||
|
"train",
|
||||||
|
]
|
||||||
|
|
||||||
eval_lm_parser = options.get_eval_lm_parser()
|
eval_lm_parser = options.get_eval_lm_parser()
|
||||||
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
|
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
|
||||||
|
|
||||||
with open(lm_score_file, 'w') as f:
|
with open(lm_score_file, "w") as f:
|
||||||
with redirect_stdout(f):
|
with redirect_stdout(f):
|
||||||
eval_lm.main(input_args)
|
eval_lm.main(input_args)
|
||||||
|
|
||||||
elif bpe_status == "shared":
|
elif bpe_status == "shared":
|
||||||
preprocess_lm_param = ["--only-source",
|
preprocess_lm_param = [
|
||||||
"--trainpref", pre_gen+"/rescore_data."+target_lang,
|
"--only-source",
|
||||||
"--srcdict", cur_lm_dict,
|
"--trainpref",
|
||||||
"--destdir", preprocess_directory]
|
pre_gen + "/rescore_data." + target_lang,
|
||||||
preprocess_parser = options.get_preprocessing_parser()
|
"--srcdict",
|
||||||
input_args = preprocess_parser.parse_args(preprocess_lm_param)
|
cur_lm_dict,
|
||||||
preprocess.main(input_args)
|
"--destdir",
|
||||||
|
preprocess_directory,
|
||||||
|
]
|
||||||
|
preprocess_parser = options.get_preprocessing_parser()
|
||||||
|
input_args = preprocess_parser.parse_args(preprocess_lm_param)
|
||||||
|
preprocess.main(input_args)
|
||||||
|
|
||||||
eval_lm_param = [preprocess_directory,
|
eval_lm_param = [
|
||||||
"--path", cur_language_model,
|
preprocess_directory,
|
||||||
"--output-word-probs",
|
"--path",
|
||||||
"--batch-size", str(batch_size),
|
cur_language_model,
|
||||||
"--sample-break-mode", "eos",
|
"--output-word-probs",
|
||||||
"--gen-subset", "train"]
|
"--batch-size",
|
||||||
|
str(batch_size),
|
||||||
|
"--sample-break-mode",
|
||||||
|
"eos",
|
||||||
|
"--gen-subset",
|
||||||
|
"train",
|
||||||
|
]
|
||||||
|
|
||||||
eval_lm_parser = options.get_eval_lm_parser()
|
eval_lm_parser = options.get_eval_lm_parser()
|
||||||
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
|
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
|
||||||
|
|
||||||
with open(lm_score_file, 'w') as f:
|
with open(lm_score_file, "w") as f:
|
||||||
with redirect_stdout(f):
|
with redirect_stdout(f):
|
||||||
eval_lm.main(input_args)
|
eval_lm.main(input_args)
|
||||||
|
|
||||||
elif bpe_status == "different":
|
elif bpe_status == "different":
|
||||||
rescore_file = pre_gen+"/rescore_data_no_bpe"
|
rescore_file = pre_gen + "/rescore_data_no_bpe"
|
||||||
rescore_bpe = pre_gen+"/rescore_data_new_bpe"
|
rescore_bpe = pre_gen + "/rescore_data_new_bpe"
|
||||||
|
|
||||||
rescore_file += "."
|
rescore_file += "."
|
||||||
rescore_bpe += "."
|
rescore_bpe += "."
|
||||||
|
|
||||||
write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
|
write_reprocessed(
|
||||||
gen_output.no_bpe_target, rescore_file+source_lang,
|
gen_output.no_bpe_source,
|
||||||
rescore_file+target_lang, pre_gen+"/reference_file_no_bpe",
|
gen_output.no_bpe_hypo,
|
||||||
bpe_symbol=None)
|
gen_output.no_bpe_target,
|
||||||
|
rescore_file + source_lang,
|
||||||
|
rescore_file + target_lang,
|
||||||
|
pre_gen + "/reference_file_no_bpe",
|
||||||
|
bpe_symbol=None,
|
||||||
|
)
|
||||||
|
|
||||||
# apply LM bpe to nbest list
|
# apply LM bpe to nbest list
|
||||||
bpe_src_param = ["-c", cur_lm_bpe_code,
|
bpe_src_param = [
|
||||||
"--input", rescore_file+target_lang,
|
"-c",
|
||||||
"--output", rescore_bpe+target_lang]
|
cur_lm_bpe_code,
|
||||||
subprocess.call(["python",
|
"--input",
|
||||||
os.path.join(os.path.dirname(__file__),
|
rescore_file + target_lang,
|
||||||
"subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
|
"--output",
|
||||||
shell=False)
|
rescore_bpe + target_lang,
|
||||||
|
]
|
||||||
|
subprocess.call(
|
||||||
|
[
|
||||||
|
"python",
|
||||||
|
os.path.join(
|
||||||
|
os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
+ bpe_src_param,
|
||||||
|
shell=False,
|
||||||
|
)
|
||||||
# uncomment to use fastbpe instead of subword-nmt bpe
|
# uncomment to use fastbpe instead of subword-nmt bpe
|
||||||
# bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
|
# bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
|
||||||
# subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)
|
# subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)
|
||||||
|
|
||||||
preprocess_dir = preprocess_directory
|
preprocess_dir = preprocess_directory
|
||||||
|
|
||||||
preprocess_lm_param = ["--only-source",
|
preprocess_lm_param = [
|
||||||
"--trainpref", rescore_bpe+target_lang,
|
"--only-source",
|
||||||
"--srcdict", cur_lm_dict,
|
"--trainpref",
|
||||||
"--destdir", preprocess_dir]
|
rescore_bpe + target_lang,
|
||||||
|
"--srcdict",
|
||||||
|
cur_lm_dict,
|
||||||
|
"--destdir",
|
||||||
|
preprocess_dir,
|
||||||
|
]
|
||||||
preprocess_parser = options.get_preprocessing_parser()
|
preprocess_parser = options.get_preprocessing_parser()
|
||||||
input_args = preprocess_parser.parse_args(preprocess_lm_param)
|
input_args = preprocess_parser.parse_args(preprocess_lm_param)
|
||||||
preprocess.main(input_args)
|
preprocess.main(input_args)
|
||||||
|
|
||||||
eval_lm_param = [preprocess_dir,
|
eval_lm_param = [
|
||||||
"--path", cur_language_model,
|
preprocess_dir,
|
||||||
"--output-word-probs",
|
"--path",
|
||||||
"--batch-size", str(batch_size),
|
cur_language_model,
|
||||||
"--max-tokens", "1024",
|
"--output-word-probs",
|
||||||
"--sample-break-mode", "eos",
|
"--batch-size",
|
||||||
"--gen-subset", "train"]
|
str(batch_size),
|
||||||
|
"--max-tokens",
|
||||||
|
"1024",
|
||||||
|
"--sample-break-mode",
|
||||||
|
"eos",
|
||||||
|
"--gen-subset",
|
||||||
|
"train",
|
||||||
|
]
|
||||||
|
|
||||||
eval_lm_parser = options.get_eval_lm_parser()
|
eval_lm_parser = options.get_eval_lm_parser()
|
||||||
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
|
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
|
||||||
|
|
||||||
with open(lm_score_file, 'w') as f:
|
with open(lm_score_file, "w") as f:
|
||||||
with redirect_stdout(f):
|
with redirect_stdout(f):
|
||||||
eval_lm.main(input_args)
|
eval_lm.main(input_args)
|
||||||
|
|
||||||
|
|
||||||
def rescore_file_name(nbest_dir, prefix_len, scorer_name, lm_file=False,
|
def rescore_file_name(
|
||||||
target_prefix_frac=None, source_prefix_frac=None, backwards=None):
|
nbest_dir,
|
||||||
|
prefix_len,
|
||||||
|
scorer_name,
|
||||||
|
lm_file=False,
|
||||||
|
target_prefix_frac=None,
|
||||||
|
source_prefix_frac=None,
|
||||||
|
backwards=None,
|
||||||
|
):
|
||||||
if lm_file:
|
if lm_file:
|
||||||
score_file = nbest_dir+"/lm_score_translations_model_"+scorer_name+".txt"
|
score_file = nbest_dir + "/lm_score_translations_model_" + scorer_name + ".txt"
|
||||||
else:
|
else:
|
||||||
score_file = nbest_dir+"/"+scorer_name+"_score_translations.txt"
|
score_file = nbest_dir + "/" + scorer_name + "_score_translations.txt"
|
||||||
if backwards:
|
if backwards:
|
||||||
if prefix_len is not None:
|
if prefix_len is not None:
|
||||||
score_file += "prefix_len"+str(prefix_len)
|
score_file += "prefix_len" + str(prefix_len)
|
||||||
elif target_prefix_frac is not None:
|
elif target_prefix_frac is not None:
|
||||||
score_file += "target_prefix_frac"+str(target_prefix_frac)
|
score_file += "target_prefix_frac" + str(target_prefix_frac)
|
||||||
else:
|
else:
|
||||||
if source_prefix_frac is not None:
|
if source_prefix_frac is not None:
|
||||||
score_file += "source_prefix_frac"+str(source_prefix_frac)
|
score_file += "source_prefix_frac" + str(source_prefix_frac)
|
||||||
return score_file
|
return score_file
|
||||||
|
@ -13,57 +13,66 @@ logging.getLogger().setLevel(logging.INFO)
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='')
|
parser = argparse.ArgumentParser(description="")
|
||||||
parser.add_argument('--en2fr', required=True,
|
parser.add_argument("--en2fr", required=True, help="path to en2fr model")
|
||||||
help='path to en2fr model')
|
parser.add_argument(
|
||||||
parser.add_argument('--fr2en', required=True,
|
"--fr2en", required=True, help="path to fr2en mixture of experts model"
|
||||||
help='path to fr2en mixture of experts model')
|
)
|
||||||
parser.add_argument('--user-dir',
|
parser.add_argument(
|
||||||
help='path to fairseq examples/translation_moe/src directory')
|
"--user-dir", help="path to fairseq examples/translation_moe/src directory"
|
||||||
parser.add_argument('--num-experts', type=int, default=10,
|
)
|
||||||
help='(keep at 10 unless using a different model)')
|
parser.add_argument(
|
||||||
parser.add_argument('files', nargs='*', default=['-'],
|
"--num-experts",
|
||||||
help='input files to paraphrase; "-" for stdin')
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="(keep at 10 unless using a different model)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"files",
|
||||||
|
nargs="*",
|
||||||
|
default=["-"],
|
||||||
|
help='input files to paraphrase; "-" for stdin',
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.user_dir is None:
|
if args.user_dir is None:
|
||||||
args.user_dir = os.path.join(
|
args.user_dir = os.path.join(
|
||||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # examples/
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # examples/
|
||||||
'translation_moe',
|
"translation_moe",
|
||||||
'src',
|
"src",
|
||||||
)
|
)
|
||||||
if os.path.exists(args.user_dir):
|
if os.path.exists(args.user_dir):
|
||||||
logging.info('found user_dir:' + args.user_dir)
|
logging.info("found user_dir:" + args.user_dir)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'cannot find fairseq examples/translation_moe/src '
|
"cannot find fairseq examples/translation_moe/src "
|
||||||
'(tried looking here: {})'.format(args.user_dir)
|
"(tried looking here: {})".format(args.user_dir)
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info('loading en2fr model from:' + args.en2fr)
|
logging.info("loading en2fr model from:" + args.en2fr)
|
||||||
en2fr = TransformerModel.from_pretrained(
|
en2fr = TransformerModel.from_pretrained(
|
||||||
model_name_or_path=args.en2fr,
|
model_name_or_path=args.en2fr,
|
||||||
tokenizer='moses',
|
tokenizer="moses",
|
||||||
bpe='sentencepiece',
|
bpe="sentencepiece",
|
||||||
).eval()
|
).eval()
|
||||||
|
|
||||||
logging.info('loading fr2en model from:' + args.fr2en)
|
logging.info("loading fr2en model from:" + args.fr2en)
|
||||||
fr2en = TransformerModel.from_pretrained(
|
fr2en = TransformerModel.from_pretrained(
|
||||||
model_name_or_path=args.fr2en,
|
model_name_or_path=args.fr2en,
|
||||||
tokenizer='moses',
|
tokenizer="moses",
|
||||||
bpe='sentencepiece',
|
bpe="sentencepiece",
|
||||||
user_dir=args.user_dir,
|
user_dir=args.user_dir,
|
||||||
task='translation_moe',
|
task="translation_moe",
|
||||||
).eval()
|
).eval()
|
||||||
|
|
||||||
def gen_paraphrases(en):
|
def gen_paraphrases(en):
|
||||||
fr = en2fr.translate(en)
|
fr = en2fr.translate(en)
|
||||||
return [
|
return [
|
||||||
fr2en.translate(fr, inference_step_args={'expert': i})
|
fr2en.translate(fr, inference_step_args={"expert": i})
|
||||||
for i in range(args.num_experts)
|
for i in range(args.num_experts)
|
||||||
]
|
]
|
||||||
|
|
||||||
logging.info('Type the input sentence and press return:')
|
logging.info("Type the input sentence and press return:")
|
||||||
for line in fileinput.input(args.files):
|
for line in fileinput.input(args.files):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if len(line) == 0:
|
if len(line) == 0:
|
||||||
@ -72,5 +81,5 @@ def main():
|
|||||||
print(paraphrase)
|
print(paraphrase)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -4,9 +4,9 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import sys
|
|
||||||
import re
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
class OOVIndexError(IndexError):
|
class OOVIndexError(IndexError):
|
||||||
@ -25,8 +25,8 @@ class OOVIndexError(IndexError):
|
|||||||
|
|
||||||
def replace_oovs(source_in, target_in, target_out):
|
def replace_oovs(source_in, target_in, target_out):
|
||||||
"""Replaces <unk-N> tokens in the target text with the corresponding word in
|
"""Replaces <unk-N> tokens in the target text with the corresponding word in
|
||||||
the source text.
|
the source text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
oov_re = re.compile("^<unk-([0-9]+)>$")
|
oov_re = re.compile("^<unk-([0-9]+)>$")
|
||||||
|
|
||||||
|
@ -10,8 +10,8 @@ from itertools import zip_longest
|
|||||||
|
|
||||||
def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
|
def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
|
||||||
"""Replaces out-of-vocabulary words in source and target text with <unk-N>,
|
"""Replaces out-of-vocabulary words in source and target text with <unk-N>,
|
||||||
where N in is the position of the word in the source sequence.
|
where N in is the position of the word in the source sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def format_unk(pos):
|
def format_unk(pos):
|
||||||
return "<unk-{}>".format(pos)
|
return "<unk-{}>".format(pos)
|
||||||
|
@ -8,19 +8,17 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from fairseq import metrics, utils
|
||||||
from fairseq import utils, metrics
|
|
||||||
from fairseq.models import register_model, register_model_architecture
|
from fairseq.models import register_model, register_model_architecture
|
||||||
from fairseq.models.fairseq_encoder import EncoderOut
|
from fairseq.models.fairseq_encoder import EncoderOut
|
||||||
from fairseq.models.transformer import (
|
from fairseq.models.transformer import (
|
||||||
TransformerModel,
|
|
||||||
TransformerDecoder,
|
|
||||||
TransformerEncoder,
|
|
||||||
base_architecture,
|
|
||||||
DEFAULT_MAX_SOURCE_POSITIONS,
|
DEFAULT_MAX_SOURCE_POSITIONS,
|
||||||
DEFAULT_MAX_TARGET_POSITIONS,
|
DEFAULT_MAX_TARGET_POSITIONS,
|
||||||
|
TransformerDecoder,
|
||||||
|
TransformerEncoder,
|
||||||
|
TransformerModel,
|
||||||
|
base_architecture,
|
||||||
)
|
)
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,40 +8,44 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from fairseq.data import (
|
from fairseq.data import (
|
||||||
data_utils,
|
|
||||||
Dictionary,
|
Dictionary,
|
||||||
encoders,
|
|
||||||
IdDataset,
|
IdDataset,
|
||||||
ListDataset,
|
ListDataset,
|
||||||
NestedDictionaryDataset,
|
NestedDictionaryDataset,
|
||||||
NumSamplesDataset,
|
|
||||||
NumelDataset,
|
NumelDataset,
|
||||||
|
NumSamplesDataset,
|
||||||
RawLabelDataset,
|
RawLabelDataset,
|
||||||
RightPadDataset,
|
RightPadDataset,
|
||||||
SortDataset,
|
SortDataset,
|
||||||
|
data_utils,
|
||||||
|
encoders,
|
||||||
)
|
)
|
||||||
from fairseq.tasks import register_task, LegacyFairseqTask
|
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||||
|
|
||||||
|
|
||||||
@register_task('commonsense_qa')
|
@register_task("commonsense_qa")
|
||||||
class CommonsenseQATask(LegacyFairseqTask):
|
class CommonsenseQATask(LegacyFairseqTask):
|
||||||
"""Task to finetune RoBERTa for Commonsense QA."""
|
"""Task to finetune RoBERTa for Commonsense QA."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
"""Add task-specific arguments to the parser."""
|
"""Add task-specific arguments to the parser."""
|
||||||
parser.add_argument('data', metavar='DIR',
|
parser.add_argument(
|
||||||
help='path to data directory; we load <split>.jsonl')
|
"data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
|
||||||
parser.add_argument('--init-token', type=int, default=None,
|
)
|
||||||
help='add token at the beginning of each batch item')
|
parser.add_argument(
|
||||||
parser.add_argument('--num-classes', type=int, default=5)
|
"--init-token",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="add token at the beginning of each batch item",
|
||||||
|
)
|
||||||
|
parser.add_argument("--num-classes", type=int, default=5)
|
||||||
|
|
||||||
def __init__(self, args, vocab):
|
def __init__(self, args, vocab):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.mask = vocab.add_symbol('<mask>')
|
self.mask = vocab.add_symbol("<mask>")
|
||||||
|
|
||||||
self.bpe = encoders.build_bpe(args)
|
self.bpe = encoders.build_bpe(args)
|
||||||
|
|
||||||
@ -53,20 +57,24 @@ class CommonsenseQATask(LegacyFairseqTask):
|
|||||||
filename (str): the filename
|
filename (str): the filename
|
||||||
"""
|
"""
|
||||||
dictionary = Dictionary.load(filename)
|
dictionary = Dictionary.load(filename)
|
||||||
dictionary.add_symbol('<mask>')
|
dictionary.add_symbol("<mask>")
|
||||||
return dictionary
|
return dictionary
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_task(cls, args, **kwargs):
|
def setup_task(cls, args, **kwargs):
|
||||||
assert args.criterion == 'sentence_ranking', 'Must set --criterion=sentence_ranking'
|
assert (
|
||||||
|
args.criterion == "sentence_ranking"
|
||||||
|
), "Must set --criterion=sentence_ranking"
|
||||||
|
|
||||||
# load data and label dictionaries
|
# load data and label dictionaries
|
||||||
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
|
vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
|
||||||
print('| dictionary: {} types'.format(len(vocab)))
|
print("| dictionary: {} types".format(len(vocab)))
|
||||||
|
|
||||||
return cls(args, vocab)
|
return cls(args, vocab)
|
||||||
|
|
||||||
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
|
def load_dataset(
|
||||||
|
self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
|
||||||
|
):
|
||||||
"""Load a given dataset split.
|
"""Load a given dataset split.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -77,16 +85,18 @@ class CommonsenseQATask(LegacyFairseqTask):
|
|||||||
if self.bpe is not None:
|
if self.bpe is not None:
|
||||||
s = self.bpe.encode(s)
|
s = self.bpe.encode(s)
|
||||||
tokens = self.vocab.encode_line(
|
tokens = self.vocab.encode_line(
|
||||||
s, append_eos=True, add_if_not_exist=False,
|
s,
|
||||||
|
append_eos=True,
|
||||||
|
add_if_not_exist=False,
|
||||||
).long()
|
).long()
|
||||||
if append_bos and self.args.init_token is not None:
|
if append_bos and self.args.init_token is not None:
|
||||||
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
|
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
if data_path is None:
|
if data_path is None:
|
||||||
data_path = os.path.join(self.args.data, split + '.jsonl')
|
data_path = os.path.join(self.args.data, split + ".jsonl")
|
||||||
if not os.path.exists(data_path):
|
if not os.path.exists(data_path):
|
||||||
raise FileNotFoundError('Cannot find data: {}'.format(data_path))
|
raise FileNotFoundError("Cannot find data: {}".format(data_path))
|
||||||
|
|
||||||
src_tokens = [[] for i in range(self.args.num_classes)]
|
src_tokens = [[] for i in range(self.args.num_classes)]
|
||||||
src_lengths = [[] for i in range(self.args.num_classes)]
|
src_lengths = [[] for i in range(self.args.num_classes)]
|
||||||
@ -95,20 +105,23 @@ class CommonsenseQATask(LegacyFairseqTask):
|
|||||||
with open(data_path) as h:
|
with open(data_path) as h:
|
||||||
for line in h:
|
for line in h:
|
||||||
example = json.loads(line.strip())
|
example = json.loads(line.strip())
|
||||||
if 'answerKey' in example:
|
if "answerKey" in example:
|
||||||
label = ord(example['answerKey']) - ord('A')
|
label = ord(example["answerKey"]) - ord("A")
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
question = example['question']['stem']
|
question = example["question"]["stem"]
|
||||||
assert len(example['question']['choices']) == self.args.num_classes
|
assert len(example["question"]["choices"]) == self.args.num_classes
|
||||||
# format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
|
# format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
|
||||||
question = 'Q: ' + question
|
question = "Q: " + question
|
||||||
question_toks = binarize(question, append_bos=True)
|
question_toks = binarize(question, append_bos=True)
|
||||||
for i, choice in enumerate(example['question']['choices']):
|
for i, choice in enumerate(example["question"]["choices"]):
|
||||||
src = 'A: ' + choice['text']
|
src = "A: " + choice["text"]
|
||||||
src_bin = torch.cat([question_toks, binarize(src)])
|
src_bin = torch.cat([question_toks, binarize(src)])
|
||||||
src_tokens[i].append(src_bin)
|
src_tokens[i].append(src_bin)
|
||||||
src_lengths[i].append(len(src_bin))
|
src_lengths[i].append(len(src_bin))
|
||||||
assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes))
|
assert all(
|
||||||
|
len(src_tokens[0]) == len(src_tokens[i])
|
||||||
|
for i in range(self.args.num_classes)
|
||||||
|
)
|
||||||
assert len(src_tokens[0]) == len(src_lengths[0])
|
assert len(src_tokens[0]) == len(src_lengths[0])
|
||||||
assert len(labels) == 0 or len(labels) == len(src_tokens[0])
|
assert len(labels) == 0 or len(labels) == len(src_tokens[0])
|
||||||
|
|
||||||
@ -118,24 +131,26 @@ class CommonsenseQATask(LegacyFairseqTask):
|
|||||||
src_lengths[i] = ListDataset(src_lengths[i])
|
src_lengths[i] = ListDataset(src_lengths[i])
|
||||||
|
|
||||||
dataset = {
|
dataset = {
|
||||||
'id': IdDataset(),
|
"id": IdDataset(),
|
||||||
'nsentences': NumSamplesDataset(),
|
"nsentences": NumSamplesDataset(),
|
||||||
'ntokens': NumelDataset(src_tokens[0], reduce=True),
|
"ntokens": NumelDataset(src_tokens[0], reduce=True),
|
||||||
}
|
}
|
||||||
|
|
||||||
for i in range(self.args.num_classes):
|
for i in range(self.args.num_classes):
|
||||||
dataset.update({
|
dataset.update(
|
||||||
'net_input{}'.format(i + 1): {
|
{
|
||||||
'src_tokens': RightPadDataset(
|
"net_input{}".format(i + 1): {
|
||||||
src_tokens[i],
|
"src_tokens": RightPadDataset(
|
||||||
pad_idx=self.source_dictionary.pad(),
|
src_tokens[i],
|
||||||
),
|
pad_idx=self.source_dictionary.pad(),
|
||||||
'src_lengths': src_lengths[i],
|
),
|
||||||
|
"src_lengths": src_lengths[i],
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
|
|
||||||
if len(labels) > 0:
|
if len(labels) > 0:
|
||||||
dataset.update({'target': RawLabelDataset(labels)})
|
dataset.update({"target": RawLabelDataset(labels)})
|
||||||
|
|
||||||
dataset = NestedDictionaryDataset(
|
dataset = NestedDictionaryDataset(
|
||||||
dataset,
|
dataset,
|
||||||
@ -149,17 +164,18 @@ class CommonsenseQATask(LegacyFairseqTask):
|
|||||||
sort_order=[np.random.permutation(len(dataset))],
|
sort_order=[np.random.permutation(len(dataset))],
|
||||||
)
|
)
|
||||||
|
|
||||||
print('| Loaded {} with {} samples'.format(split, len(dataset)))
|
print("| Loaded {} with {} samples".format(split, len(dataset)))
|
||||||
|
|
||||||
self.datasets[split] = dataset
|
self.datasets[split] = dataset
|
||||||
return self.datasets[split]
|
return self.datasets[split]
|
||||||
|
|
||||||
def build_model(self, args):
|
def build_model(self, args):
|
||||||
from fairseq import models
|
from fairseq import models
|
||||||
|
|
||||||
model = models.build_model(args, self)
|
model = models.build_model(args, self)
|
||||||
|
|
||||||
model.register_classification_head(
|
model.register_classification_head(
|
||||||
'sentence_classification_head',
|
"sentence_classification_head",
|
||||||
num_classes=1,
|
num_classes=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import contextlib
|
import contextlib
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
|
|
||||||
@ -26,23 +25,23 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-json",
|
"--encoder-json",
|
||||||
help='path to encoder.json',
|
help="path to encoder.json",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vocab-bpe",
|
"--vocab-bpe",
|
||||||
type=str,
|
type=str,
|
||||||
help='path to vocab.bpe',
|
help="path to vocab.bpe",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--inputs",
|
"--inputs",
|
||||||
nargs="+",
|
nargs="+",
|
||||||
default=['-'],
|
default=["-"],
|
||||||
help="input files to filter/encode",
|
help="input files to filter/encode",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--outputs",
|
"--outputs",
|
||||||
nargs="+",
|
nargs="+",
|
||||||
default=['-'],
|
default=["-"],
|
||||||
help="path to save encoded outputs",
|
help="path to save encoded outputs",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -53,18 +52,21 @@ def main():
|
|||||||
parser.add_argument("--workers", type=int, default=20)
|
parser.add_argument("--workers", type=int, default=20)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
assert len(args.inputs) == len(args.outputs), \
|
assert len(args.inputs) == len(
|
||||||
"number of input and output paths should match"
|
args.outputs
|
||||||
|
), "number of input and output paths should match"
|
||||||
|
|
||||||
with contextlib.ExitStack() as stack:
|
with contextlib.ExitStack() as stack:
|
||||||
inputs = [
|
inputs = [
|
||||||
stack.enter_context(open(input, "r", encoding="utf-8"))
|
stack.enter_context(open(input, "r", encoding="utf-8"))
|
||||||
if input != "-" else sys.stdin
|
if input != "-"
|
||||||
|
else sys.stdin
|
||||||
for input in args.inputs
|
for input in args.inputs
|
||||||
]
|
]
|
||||||
outputs = [
|
outputs = [
|
||||||
stack.enter_context(open(output, "w", encoding="utf-8"))
|
stack.enter_context(open(output, "w", encoding="utf-8"))
|
||||||
if output != "-" else sys.stdout
|
if output != "-"
|
||||||
|
else sys.stdout
|
||||||
for output in args.outputs
|
for output in args.outputs
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -87,7 +89,6 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
class MultiprocessingEncoder(object):
|
class MultiprocessingEncoder(object):
|
||||||
|
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ def get_examples(data_dir, set_type):
|
|||||||
examples = []
|
examples = []
|
||||||
|
|
||||||
levels = ["middle", "high"]
|
levels = ["middle", "high"]
|
||||||
set_type_c = set_type.split('-')
|
set_type_c = set_type.split("-")
|
||||||
if len(set_type_c) == 2:
|
if len(set_type_c) == 2:
|
||||||
levels = [set_type_c[1]]
|
levels = [set_type_c[1]]
|
||||||
set_type = set_type_c[0]
|
set_type = set_type_c[0]
|
||||||
@ -33,13 +33,13 @@ def get_examples(data_dir, set_type):
|
|||||||
cur_dir = os.path.join(data_dir, set_type, level)
|
cur_dir = os.path.join(data_dir, set_type, level)
|
||||||
for filename in os.listdir(cur_dir):
|
for filename in os.listdir(cur_dir):
|
||||||
cur_path = os.path.join(cur_dir, filename)
|
cur_path = os.path.join(cur_dir, filename)
|
||||||
with open(cur_path, 'r') as f:
|
with open(cur_path, "r") as f:
|
||||||
cur_data = json.load(f)
|
cur_data = json.load(f)
|
||||||
answers = cur_data["answers"]
|
answers = cur_data["answers"]
|
||||||
options = cur_data["options"]
|
options = cur_data["options"]
|
||||||
questions = cur_data["questions"]
|
questions = cur_data["questions"]
|
||||||
context = cur_data["article"].replace("\n", " ")
|
context = cur_data["article"].replace("\n", " ")
|
||||||
context = re.sub(r'\s+', ' ', context)
|
context = re.sub(r"\s+", " ", context)
|
||||||
for i in range(len(answers)):
|
for i in range(len(answers)):
|
||||||
label = ord(answers[i]) - ord("A")
|
label = ord(answers[i]) - ord("A")
|
||||||
qa_list = []
|
qa_list = []
|
||||||
@ -50,7 +50,7 @@ def get_examples(data_dir, set_type):
|
|||||||
qa_cat = question.replace("_", option)
|
qa_cat = question.replace("_", option)
|
||||||
else:
|
else:
|
||||||
qa_cat = " ".join([question, option])
|
qa_cat = " ".join([question, option])
|
||||||
qa_cat = re.sub(r'\s+', ' ', qa_cat)
|
qa_cat = re.sub(r"\s+", " ", qa_cat)
|
||||||
qa_list.append(qa_cat)
|
qa_list.append(qa_cat)
|
||||||
examples.append(InputExample(context, qa_list, label))
|
examples.append(InputExample(context, qa_list, label))
|
||||||
|
|
||||||
@ -64,11 +64,11 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input-dir",
|
"--input-dir",
|
||||||
help='input directory for downloaded RACE dataset',
|
help="input directory for downloaded RACE dataset",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-dir",
|
"--output-dir",
|
||||||
help='output directory for extracted data',
|
help="output directory for extracted data",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -77,17 +77,20 @@ def main():
|
|||||||
|
|
||||||
for set_type in ["train", "dev", "test-middle", "test-high"]:
|
for set_type in ["train", "dev", "test-middle", "test-high"]:
|
||||||
examples = get_examples(args.input_dir, set_type)
|
examples = get_examples(args.input_dir, set_type)
|
||||||
qa_file_paths = [os.path.join(args.output_dir, set_type + ".input" + str(i + 1)) for i in range(4)]
|
qa_file_paths = [
|
||||||
qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths]
|
os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
|
||||||
|
for i in range(4)
|
||||||
|
]
|
||||||
|
qa_files = [open(qa_file_path, "w") for qa_file_path in qa_file_paths]
|
||||||
outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
|
outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
|
||||||
outf_label_path = os.path.join(args.output_dir, set_type + ".label")
|
outf_label_path = os.path.join(args.output_dir, set_type + ".label")
|
||||||
outf_context = open(outf_context_path, 'w')
|
outf_context = open(outf_context_path, "w")
|
||||||
outf_label = open(outf_label_path, 'w')
|
outf_label = open(outf_label_path, "w")
|
||||||
for example in examples:
|
for example in examples:
|
||||||
outf_context.write(example.paragraph + '\n')
|
outf_context.write(example.paragraph + "\n")
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
qa_files[i].write(example.qa_list[i] + '\n')
|
qa_files[i].write(example.qa_list[i] + "\n")
|
||||||
outf_label.write(str(example.label) + '\n')
|
outf_label.write(str(example.label) + "\n")
|
||||||
|
|
||||||
for f in qa_files:
|
for f in qa_files:
|
||||||
f.close()
|
f.close()
|
||||||
@ -95,5 +98,5 @@ def main():
|
|||||||
outf_context.close()
|
outf_context.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -7,19 +7,17 @@ import math
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from fairseq import utils
|
from fairseq import utils
|
||||||
from fairseq.data import encoders
|
|
||||||
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
|
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
|
||||||
|
from fairseq.data import encoders
|
||||||
|
|
||||||
|
|
||||||
@register_criterion('wsc')
|
@register_criterion("wsc")
|
||||||
class WSCCriterion(LegacyFairseqCriterion):
|
class WSCCriterion(LegacyFairseqCriterion):
|
||||||
|
|
||||||
def __init__(self, args, task):
|
def __init__(self, args, task):
|
||||||
super().__init__(args, task)
|
super().__init__(args, task)
|
||||||
if self.args.save_predictions is not None:
|
if self.args.save_predictions is not None:
|
||||||
self.prediction_h = open(self.args.save_predictions, 'w')
|
self.prediction_h = open(self.args.save_predictions, "w")
|
||||||
else:
|
else:
|
||||||
self.prediction_h = None
|
self.prediction_h = None
|
||||||
self.bpe = encoders.build_bpe(args)
|
self.bpe = encoders.build_bpe(args)
|
||||||
@ -32,12 +30,16 @@ class WSCCriterion(LegacyFairseqCriterion):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
"""Add criterion-specific arguments to the parser."""
|
"""Add criterion-specific arguments to the parser."""
|
||||||
parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0)
|
parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
|
||||||
parser.add_argument('--wsc-margin-beta', type=float, metavar='B', default=0.0)
|
parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
|
||||||
parser.add_argument('--wsc-cross-entropy', action='store_true',
|
parser.add_argument(
|
||||||
help='use cross entropy formulation instead of margin loss')
|
"--wsc-cross-entropy",
|
||||||
parser.add_argument('--save-predictions', metavar='FILE',
|
action="store_true",
|
||||||
help='file to save predictions to')
|
help="use cross entropy formulation instead of margin loss",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-predictions", metavar="FILE", help="file to save predictions to"
|
||||||
|
)
|
||||||
|
|
||||||
def get_masked_input(self, tokens, mask):
|
def get_masked_input(self, tokens, mask):
|
||||||
masked_tokens = tokens.clone()
|
masked_tokens = tokens.clone()
|
||||||
@ -60,27 +62,26 @@ class WSCCriterion(LegacyFairseqCriterion):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return (
|
return (
|
||||||
- query_lprobs
|
-query_lprobs
|
||||||
+ self.args.wsc_margin_alpha * (
|
+ self.args.wsc_margin_alpha
|
||||||
cand_lprobs - query_lprobs + self.args.wsc_margin_beta
|
* (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
|
||||||
).clamp(min=0)
|
|
||||||
).sum()
|
).sum()
|
||||||
|
|
||||||
def forward(self, model, sample, reduce=True):
|
def forward(self, model, sample, reduce=True):
|
||||||
# compute loss and accuracy
|
# compute loss and accuracy
|
||||||
loss, nloss = 0., 0
|
loss, nloss = 0.0, 0
|
||||||
ncorrect, nqueries = 0, 0
|
ncorrect, nqueries = 0, 0
|
||||||
|
|
||||||
for i, label in enumerate(sample['labels']):
|
for i, label in enumerate(sample["labels"]):
|
||||||
query_lprobs = self.get_lprobs(
|
query_lprobs = self.get_lprobs(
|
||||||
model,
|
model,
|
||||||
sample['query_tokens'][i].unsqueeze(0),
|
sample["query_tokens"][i].unsqueeze(0),
|
||||||
sample['query_masks'][i].unsqueeze(0),
|
sample["query_masks"][i].unsqueeze(0),
|
||||||
)
|
)
|
||||||
cand_lprobs = self.get_lprobs(
|
cand_lprobs = self.get_lprobs(
|
||||||
model,
|
model,
|
||||||
sample['candidate_tokens'][i],
|
sample["candidate_tokens"][i],
|
||||||
sample['candidate_masks'][i],
|
sample["candidate_masks"][i],
|
||||||
)
|
)
|
||||||
|
|
||||||
pred = (query_lprobs >= cand_lprobs).all().item()
|
pred = (query_lprobs >= cand_lprobs).all().item()
|
||||||
@ -95,72 +96,72 @@ class WSCCriterion(LegacyFairseqCriterion):
|
|||||||
nloss += 1
|
nloss += 1
|
||||||
loss += self.get_loss(query_lprobs, cand_lprobs)
|
loss += self.get_loss(query_lprobs, cand_lprobs)
|
||||||
|
|
||||||
id = sample['id'][i].item()
|
id = sample["id"][i].item()
|
||||||
if self.prediction_h is not None:
|
if self.prediction_h is not None:
|
||||||
print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h)
|
print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
|
||||||
|
|
||||||
if nloss == 0:
|
if nloss == 0:
|
||||||
loss = torch.tensor(0.0, requires_grad=True)
|
loss = torch.tensor(0.0, requires_grad=True)
|
||||||
|
|
||||||
sample_size = nqueries if nqueries > 0 else 1
|
sample_size = nqueries if nqueries > 0 else 1
|
||||||
logging_output = {
|
logging_output = {
|
||||||
'loss': utils.item(loss.data) if reduce else loss.data,
|
"loss": utils.item(loss.data) if reduce else loss.data,
|
||||||
'ntokens': sample['ntokens'],
|
"ntokens": sample["ntokens"],
|
||||||
'nsentences': sample['nsentences'],
|
"nsentences": sample["nsentences"],
|
||||||
'sample_size': sample_size,
|
"sample_size": sample_size,
|
||||||
'ncorrect': ncorrect,
|
"ncorrect": ncorrect,
|
||||||
'nqueries': nqueries,
|
"nqueries": nqueries,
|
||||||
}
|
}
|
||||||
return loss, sample_size, logging_output
|
return loss, sample_size, logging_output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def aggregate_logging_outputs(logging_outputs):
|
def aggregate_logging_outputs(logging_outputs):
|
||||||
"""Aggregate logging outputs from data parallel training."""
|
"""Aggregate logging outputs from data parallel training."""
|
||||||
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
|
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
||||||
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
|
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
||||||
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
|
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
|
||||||
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
|
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||||
|
|
||||||
agg_output = {
|
agg_output = {
|
||||||
'loss': loss_sum / sample_size / math.log(2),
|
"loss": loss_sum / sample_size / math.log(2),
|
||||||
'ntokens': ntokens,
|
"ntokens": ntokens,
|
||||||
'nsentences': nsentences,
|
"nsentences": nsentences,
|
||||||
'sample_size': sample_size,
|
"sample_size": sample_size,
|
||||||
}
|
}
|
||||||
|
|
||||||
ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
|
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
|
||||||
nqueries = sum(log.get('nqueries', 0) for log in logging_outputs)
|
nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
|
||||||
if nqueries > 0:
|
if nqueries > 0:
|
||||||
agg_output['accuracy'] = ncorrect / float(nqueries)
|
agg_output["accuracy"] = ncorrect / float(nqueries)
|
||||||
|
|
||||||
return agg_output
|
return agg_output
|
||||||
|
|
||||||
|
|
||||||
@register_criterion('winogrande')
|
@register_criterion("winogrande")
|
||||||
class WinograndeCriterion(WSCCriterion):
|
class WinograndeCriterion(WSCCriterion):
|
||||||
def forward(self, model, sample, reduce=True):
|
def forward(self, model, sample, reduce=True):
|
||||||
# compute loss and accuracy
|
# compute loss and accuracy
|
||||||
query_lprobs = self.get_lprobs(
|
query_lprobs = self.get_lprobs(
|
||||||
model,
|
model,
|
||||||
sample['query_tokens'],
|
sample["query_tokens"],
|
||||||
sample['query_masks'],
|
sample["query_masks"],
|
||||||
)
|
)
|
||||||
cand_lprobs = self.get_lprobs(
|
cand_lprobs = self.get_lprobs(
|
||||||
model,
|
model,
|
||||||
sample['candidate_tokens'],
|
sample["candidate_tokens"],
|
||||||
sample['candidate_masks'],
|
sample["candidate_masks"],
|
||||||
)
|
)
|
||||||
pred = query_lprobs >= cand_lprobs
|
pred = query_lprobs >= cand_lprobs
|
||||||
loss = self.get_loss(query_lprobs, cand_lprobs)
|
loss = self.get_loss(query_lprobs, cand_lprobs)
|
||||||
|
|
||||||
sample_size = sample['query_tokens'].size(0)
|
sample_size = sample["query_tokens"].size(0)
|
||||||
ncorrect = pred.sum().item()
|
ncorrect = pred.sum().item()
|
||||||
logging_output = {
|
logging_output = {
|
||||||
'loss': utils.item(loss.data) if reduce else loss.data,
|
"loss": utils.item(loss.data) if reduce else loss.data,
|
||||||
'ntokens': sample['ntokens'],
|
"ntokens": sample["ntokens"],
|
||||||
'nsentences': sample['nsentences'],
|
"nsentences": sample["nsentences"],
|
||||||
'sample_size': sample_size,
|
"sample_size": sample_size,
|
||||||
'ncorrect': ncorrect,
|
"ncorrect": ncorrect,
|
||||||
'nqueries': sample_size,
|
"nqueries": sample_size,
|
||||||
}
|
}
|
||||||
return loss, sample_size, logging_output
|
return loss, sample_size, logging_output
|
||||||
|
@ -10,47 +10,51 @@ import tempfile
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from fairseq import utils
|
from fairseq import utils
|
||||||
from fairseq.data import (
|
from fairseq.data import (
|
||||||
data_utils,
|
|
||||||
Dictionary,
|
Dictionary,
|
||||||
encoders,
|
|
||||||
IdDataset,
|
IdDataset,
|
||||||
ListDataset,
|
ListDataset,
|
||||||
NestedDictionaryDataset,
|
NestedDictionaryDataset,
|
||||||
NumSamplesDataset,
|
|
||||||
NumelDataset,
|
NumelDataset,
|
||||||
|
NumSamplesDataset,
|
||||||
PadDataset,
|
PadDataset,
|
||||||
SortDataset,
|
SortDataset,
|
||||||
|
data_utils,
|
||||||
|
encoders,
|
||||||
)
|
)
|
||||||
from fairseq.tasks import register_task, LegacyFairseqTask
|
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||||
|
|
||||||
from . import wsc_utils
|
from . import wsc_utils
|
||||||
|
|
||||||
|
|
||||||
@register_task('wsc')
|
@register_task("wsc")
|
||||||
class WSCTask(LegacyFairseqTask):
|
class WSCTask(LegacyFairseqTask):
|
||||||
"""Task to finetune RoBERTa for Winograd Schemas."""
|
"""Task to finetune RoBERTa for Winograd Schemas."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
"""Add task-specific arguments to the parser."""
|
"""Add task-specific arguments to the parser."""
|
||||||
parser.add_argument('data', metavar='DIR',
|
parser.add_argument(
|
||||||
help='path to data directory; we load <split>.jsonl')
|
"data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
|
||||||
parser.add_argument('--init-token', type=int, default=None,
|
)
|
||||||
help='add token at the beginning of each batch item')
|
parser.add_argument(
|
||||||
|
"--init-token",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="add token at the beginning of each batch item",
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, args, vocab):
|
def __init__(self, args, vocab):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.mask = vocab.add_symbol('<mask>')
|
self.mask = vocab.add_symbol("<mask>")
|
||||||
|
|
||||||
self.bpe = encoders.build_bpe(args)
|
self.bpe = encoders.build_bpe(args)
|
||||||
self.tokenizer = encoders.build_tokenizer(args)
|
self.tokenizer = encoders.build_tokenizer(args)
|
||||||
|
|
||||||
# hack to handle GPT-2 BPE, which includes leading spaces
|
# hack to handle GPT-2 BPE, which includes leading spaces
|
||||||
if args.bpe == 'gpt2':
|
if args.bpe == "gpt2":
|
||||||
self.leading_space = True
|
self.leading_space = True
|
||||||
self.trailing_space = False
|
self.trailing_space = False
|
||||||
else:
|
else:
|
||||||
@ -65,16 +69,16 @@ class WSCTask(LegacyFairseqTask):
|
|||||||
filename (str): the filename
|
filename (str): the filename
|
||||||
"""
|
"""
|
||||||
dictionary = Dictionary.load(filename)
|
dictionary = Dictionary.load(filename)
|
||||||
dictionary.add_symbol('<mask>')
|
dictionary.add_symbol("<mask>")
|
||||||
return dictionary
|
return dictionary
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_task(cls, args, **kwargs):
|
def setup_task(cls, args, **kwargs):
|
||||||
assert args.criterion == 'wsc', 'Must set --criterion=wsc'
|
assert args.criterion == "wsc", "Must set --criterion=wsc"
|
||||||
|
|
||||||
# load data and label dictionaries
|
# load data and label dictionaries
|
||||||
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
|
vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
|
||||||
print('| dictionary: {} types'.format(len(vocab)))
|
print("| dictionary: {} types".format(len(vocab)))
|
||||||
|
|
||||||
return cls(args, vocab)
|
return cls(args, vocab)
|
||||||
|
|
||||||
@ -84,7 +88,9 @@ class WSCTask(LegacyFairseqTask):
|
|||||||
if self.bpe is not None:
|
if self.bpe is not None:
|
||||||
s = self.bpe.encode(s)
|
s = self.bpe.encode(s)
|
||||||
tokens = self.vocab.encode_line(
|
tokens = self.vocab.encode_line(
|
||||||
s, append_eos=append_eos, add_if_not_exist=False,
|
s,
|
||||||
|
append_eos=append_eos,
|
||||||
|
add_if_not_exist=False,
|
||||||
).long()
|
).long()
|
||||||
if self.args.init_token is not None:
|
if self.args.init_token is not None:
|
||||||
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
|
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
|
||||||
@ -98,19 +104,21 @@ class WSCTask(LegacyFairseqTask):
|
|||||||
mask = torch.zeros_like(toks, dtype=torch.bool)
|
mask = torch.zeros_like(toks, dtype=torch.bool)
|
||||||
mask_start = len(self.binarize(prefix))
|
mask_start = len(self.binarize(prefix))
|
||||||
mask_size = len(self.binarize(leading_space + txt))
|
mask_size = len(self.binarize(leading_space + txt))
|
||||||
mask[mask_start:mask_start + mask_size] = 1
|
mask[mask_start : mask_start + mask_size] = 1
|
||||||
return toks, mask
|
return toks, mask
|
||||||
|
|
||||||
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
|
def load_dataset(
|
||||||
|
self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
|
||||||
|
):
|
||||||
"""Load a given dataset split.
|
"""Load a given dataset split.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
split (str): name of the split (e.g., train, valid, test)
|
split (str): name of the split (e.g., train, valid, test)
|
||||||
"""
|
"""
|
||||||
if data_path is None:
|
if data_path is None:
|
||||||
data_path = os.path.join(self.args.data, split + '.jsonl')
|
data_path = os.path.join(self.args.data, split + ".jsonl")
|
||||||
if not os.path.exists(data_path):
|
if not os.path.exists(data_path):
|
||||||
raise FileNotFoundError('Cannot find data: {}'.format(data_path))
|
raise FileNotFoundError("Cannot find data: {}".format(data_path))
|
||||||
|
|
||||||
query_tokens = []
|
query_tokens = []
|
||||||
query_masks = []
|
query_masks = []
|
||||||
@ -121,13 +129,15 @@ class WSCTask(LegacyFairseqTask):
|
|||||||
labels = []
|
labels = []
|
||||||
|
|
||||||
for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
|
for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
|
||||||
prefix = sentence[:pronoun_span.start].text
|
prefix = sentence[: pronoun_span.start].text
|
||||||
suffix = sentence[pronoun_span.end:].text_with_ws
|
suffix = sentence[pronoun_span.end :].text_with_ws
|
||||||
|
|
||||||
# spaCy spans include trailing spaces, but we need to know about
|
# spaCy spans include trailing spaces, but we need to know about
|
||||||
# leading spaces for the GPT-2 BPE
|
# leading spaces for the GPT-2 BPE
|
||||||
leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
|
leading_space = (
|
||||||
trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''
|
" " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else ""
|
||||||
|
)
|
||||||
|
trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""
|
||||||
|
|
||||||
# get noun phrases, excluding pronouns and anything overlapping with the query
|
# get noun phrases, excluding pronouns and anything overlapping with the query
|
||||||
cand_spans = wsc_utils.filter_noun_chunks(
|
cand_spans = wsc_utils.filter_noun_chunks(
|
||||||
@ -152,7 +162,11 @@ class WSCTask(LegacyFairseqTask):
|
|||||||
cand_toks, cand_masks = [], []
|
cand_toks, cand_masks = [], []
|
||||||
for cand_span in cand_spans:
|
for cand_span in cand_spans:
|
||||||
toks, mask = self.binarize_with_mask(
|
toks, mask = self.binarize_with_mask(
|
||||||
cand_span.text, prefix, suffix, leading_space, trailing_space,
|
cand_span.text,
|
||||||
|
prefix,
|
||||||
|
suffix,
|
||||||
|
leading_space,
|
||||||
|
trailing_space,
|
||||||
)
|
)
|
||||||
cand_toks.append(toks)
|
cand_toks.append(toks)
|
||||||
cand_masks.append(mask)
|
cand_masks.append(mask)
|
||||||
@ -176,17 +190,17 @@ class WSCTask(LegacyFairseqTask):
|
|||||||
candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
|
candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
|
||||||
candidate_masks = ListDataset(candidate_masks, candidate_lengths)
|
candidate_masks = ListDataset(candidate_masks, candidate_lengths)
|
||||||
|
|
||||||
labels = ListDataset(labels, [1]*len(labels))
|
labels = ListDataset(labels, [1] * len(labels))
|
||||||
|
|
||||||
dataset = {
|
dataset = {
|
||||||
'id': IdDataset(),
|
"id": IdDataset(),
|
||||||
'query_tokens': query_tokens,
|
"query_tokens": query_tokens,
|
||||||
'query_masks': query_masks,
|
"query_masks": query_masks,
|
||||||
'candidate_tokens': candidate_tokens,
|
"candidate_tokens": candidate_tokens,
|
||||||
'candidate_masks': candidate_masks,
|
"candidate_masks": candidate_masks,
|
||||||
'labels': labels,
|
"labels": labels,
|
||||||
'nsentences': NumSamplesDataset(),
|
"nsentences": NumSamplesDataset(),
|
||||||
'ntokens': NumelDataset(query_tokens, reduce=True),
|
"ntokens": NumelDataset(query_tokens, reduce=True),
|
||||||
}
|
}
|
||||||
|
|
||||||
nested_dataset = NestedDictionaryDataset(
|
nested_dataset = NestedDictionaryDataset(
|
||||||
@ -210,9 +224,9 @@ class WSCTask(LegacyFairseqTask):
|
|||||||
|
|
||||||
def build_dataset_for_inference(self, sample_json):
|
def build_dataset_for_inference(self, sample_json):
|
||||||
with tempfile.NamedTemporaryFile(buffering=0) as h:
|
with tempfile.NamedTemporaryFile(buffering=0) as h:
|
||||||
h.write((json.dumps(sample_json) + '\n').encode('utf-8'))
|
h.write((json.dumps(sample_json) + "\n").encode("utf-8"))
|
||||||
dataset = self.load_dataset(
|
dataset = self.load_dataset(
|
||||||
'disambiguate_pronoun',
|
"disambiguate_pronoun",
|
||||||
data_path=h.name,
|
data_path=h.name,
|
||||||
return_only=True,
|
return_only=True,
|
||||||
)
|
)
|
||||||
@ -239,19 +253,19 @@ class WSCTask(LegacyFairseqTask):
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
cand_lprobs = get_lprobs(
|
cand_lprobs = get_lprobs(
|
||||||
sample['candidate_tokens'][0],
|
sample["candidate_tokens"][0],
|
||||||
sample['candidate_masks'][0],
|
sample["candidate_masks"][0],
|
||||||
)
|
)
|
||||||
if sample['query_tokens'][0] is not None:
|
if sample["query_tokens"][0] is not None:
|
||||||
query_lprobs = get_lprobs(
|
query_lprobs = get_lprobs(
|
||||||
sample['query_tokens'][0].unsqueeze(0),
|
sample["query_tokens"][0].unsqueeze(0),
|
||||||
sample['query_masks'][0].unsqueeze(0),
|
sample["query_masks"][0].unsqueeze(0),
|
||||||
)
|
)
|
||||||
return (query_lprobs >= cand_lprobs).all().item() == 1
|
return (query_lprobs >= cand_lprobs).all().item() == 1
|
||||||
else:
|
else:
|
||||||
best_idx = cand_lprobs.argmax().item()
|
best_idx = cand_lprobs.argmax().item()
|
||||||
full_cand = sample['candidate_tokens'][0][best_idx]
|
full_cand = sample["candidate_tokens"][0][best_idx]
|
||||||
mask = sample['candidate_masks'][0][best_idx]
|
mask = sample["candidate_masks"][0][best_idx]
|
||||||
toks = full_cand[mask.bool()]
|
toks = full_cand[mask.bool()]
|
||||||
return self.bpe.decode(self.source_dictionary.string(toks)).strip()
|
return self.bpe.decode(self.source_dictionary.string(toks)).strip()
|
||||||
|
|
||||||
@ -264,7 +278,7 @@ class WSCTask(LegacyFairseqTask):
|
|||||||
return self.vocab
|
return self.vocab
|
||||||
|
|
||||||
|
|
||||||
@register_task('winogrande')
|
@register_task("winogrande")
|
||||||
class WinograndeTask(WSCTask):
|
class WinograndeTask(WSCTask):
|
||||||
"""
|
"""
|
||||||
Task for WinoGrande dataset. Efficient implementation for Winograd schema
|
Task for WinoGrande dataset. Efficient implementation for Winograd schema
|
||||||
@ -273,24 +287,26 @@ class WinograndeTask(WSCTask):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_task(cls, args, **kwargs):
|
def setup_task(cls, args, **kwargs):
|
||||||
assert args.criterion == 'winogrande', 'Must set --criterion=winogrande'
|
assert args.criterion == "winogrande", "Must set --criterion=winogrande"
|
||||||
|
|
||||||
# load data and label dictionaries
|
# load data and label dictionaries
|
||||||
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
|
vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
|
||||||
print('| dictionary: {} types'.format(len(vocab)))
|
print("| dictionary: {} types".format(len(vocab)))
|
||||||
|
|
||||||
return cls(args, vocab)
|
return cls(args, vocab)
|
||||||
|
|
||||||
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
|
def load_dataset(
|
||||||
|
self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
|
||||||
|
):
|
||||||
"""Load a given dataset split.
|
"""Load a given dataset split.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
split (str): name of the split (e.g., train, valid, test)
|
split (str): name of the split (e.g., train, valid, test)
|
||||||
"""
|
"""
|
||||||
if data_path is None:
|
if data_path is None:
|
||||||
data_path = os.path.join(self.args.data, split + '.jsonl')
|
data_path = os.path.join(self.args.data, split + ".jsonl")
|
||||||
if not os.path.exists(data_path):
|
if not os.path.exists(data_path):
|
||||||
raise FileNotFoundError('Cannot find data: {}'.format(data_path))
|
raise FileNotFoundError("Cannot find data: {}".format(data_path))
|
||||||
|
|
||||||
query_tokens = []
|
query_tokens = []
|
||||||
query_masks = []
|
query_masks = []
|
||||||
@ -299,19 +315,23 @@ class WinograndeTask(WSCTask):
|
|||||||
candidate_masks = []
|
candidate_masks = []
|
||||||
candidate_lengths = []
|
candidate_lengths = []
|
||||||
|
|
||||||
itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == 'test'))
|
itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test"))
|
||||||
|
|
||||||
for sample in itr:
|
for sample in itr:
|
||||||
sentence, pronoun_span, query, cand_text = sample
|
sentence, pronoun_span, query, cand_text = sample
|
||||||
prefix = sentence[:pronoun_span[0]].rstrip()
|
prefix = sentence[: pronoun_span[0]].rstrip()
|
||||||
suffix = sentence[pronoun_span[1]:]
|
suffix = sentence[pronoun_span[1] :]
|
||||||
|
|
||||||
leading_space = ' ' if sentence[:pronoun_span[0]].endswith(' ') else ''
|
leading_space = " " if sentence[: pronoun_span[0]].endswith(" ") else ""
|
||||||
trailing_space = ''
|
trailing_space = ""
|
||||||
|
|
||||||
if query is not None:
|
if query is not None:
|
||||||
query_toks, query_mask = self.binarize_with_mask(
|
query_toks, query_mask = self.binarize_with_mask(
|
||||||
query, prefix, suffix, leading_space, trailing_space,
|
query,
|
||||||
|
prefix,
|
||||||
|
suffix,
|
||||||
|
leading_space,
|
||||||
|
trailing_space,
|
||||||
)
|
)
|
||||||
query_len = len(query_toks)
|
query_len = len(query_toks)
|
||||||
else:
|
else:
|
||||||
@ -322,7 +342,11 @@ class WinograndeTask(WSCTask):
|
|||||||
query_lengths.append(query_len)
|
query_lengths.append(query_len)
|
||||||
|
|
||||||
cand_toks, cand_mask = self.binarize_with_mask(
|
cand_toks, cand_mask = self.binarize_with_mask(
|
||||||
cand_text, prefix, suffix, leading_space, trailing_space,
|
cand_text,
|
||||||
|
prefix,
|
||||||
|
suffix,
|
||||||
|
leading_space,
|
||||||
|
trailing_space,
|
||||||
)
|
)
|
||||||
|
|
||||||
candidate_tokens.append(cand_toks)
|
candidate_tokens.append(cand_toks)
|
||||||
@ -342,17 +366,19 @@ class WinograndeTask(WSCTask):
|
|||||||
query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)
|
query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)
|
||||||
|
|
||||||
candidate_lengths = np.array(candidate_lengths)
|
candidate_lengths = np.array(candidate_lengths)
|
||||||
candidate_tokens = get_pad_dataset_fn(candidate_tokens, candidate_lengths, self.vocab.pad())
|
candidate_tokens = get_pad_dataset_fn(
|
||||||
|
candidate_tokens, candidate_lengths, self.vocab.pad()
|
||||||
|
)
|
||||||
candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)
|
candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)
|
||||||
|
|
||||||
dataset = {
|
dataset = {
|
||||||
'id': IdDataset(),
|
"id": IdDataset(),
|
||||||
'query_tokens': query_tokens,
|
"query_tokens": query_tokens,
|
||||||
'query_masks': query_masks,
|
"query_masks": query_masks,
|
||||||
'candidate_tokens': candidate_tokens,
|
"candidate_tokens": candidate_tokens,
|
||||||
'candidate_masks': candidate_masks,
|
"candidate_masks": candidate_masks,
|
||||||
'nsentences': NumSamplesDataset(),
|
"nsentences": NumSamplesDataset(),
|
||||||
'ntokens': NumelDataset(query_tokens, reduce=True),
|
"ntokens": NumelDataset(query_tokens, reduce=True),
|
||||||
}
|
}
|
||||||
|
|
||||||
nested_dataset = NestedDictionaryDataset(
|
nested_dataset = NestedDictionaryDataset(
|
||||||
|
@ -3,48 +3,48 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from functools import lru_cache
|
|
||||||
import json
|
import json
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
def convert_sentence_to_json(sentence):
|
def convert_sentence_to_json(sentence):
|
||||||
if '_' in sentence:
|
if "_" in sentence:
|
||||||
prefix, rest = sentence.split('_', 1)
|
prefix, rest = sentence.split("_", 1)
|
||||||
query, rest = rest.split('_', 1)
|
query, rest = rest.split("_", 1)
|
||||||
query_index = len(prefix.rstrip().split(' '))
|
query_index = len(prefix.rstrip().split(" "))
|
||||||
else:
|
else:
|
||||||
query, query_index = None, None
|
query, query_index = None, None
|
||||||
|
|
||||||
prefix, rest = sentence.split('[', 1)
|
prefix, rest = sentence.split("[", 1)
|
||||||
pronoun, rest = rest.split(']', 1)
|
pronoun, rest = rest.split("]", 1)
|
||||||
pronoun_index = len(prefix.rstrip().split(' '))
|
pronoun_index = len(prefix.rstrip().split(" "))
|
||||||
|
|
||||||
sentence = sentence.replace('_', '').replace('[', '').replace(']', '')
|
sentence = sentence.replace("_", "").replace("[", "").replace("]", "")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'idx': 0,
|
"idx": 0,
|
||||||
'text': sentence,
|
"text": sentence,
|
||||||
'target': {
|
"target": {
|
||||||
'span1_index': query_index,
|
"span1_index": query_index,
|
||||||
'span1_text': query,
|
"span1_text": query,
|
||||||
'span2_index': pronoun_index,
|
"span2_index": pronoun_index,
|
||||||
'span2_text': pronoun,
|
"span2_text": pronoun,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def extended_noun_chunks(sentence):
|
def extended_noun_chunks(sentence):
|
||||||
noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
|
noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
|
||||||
np_start, cur_np = 0, 'NONE'
|
np_start, cur_np = 0, "NONE"
|
||||||
for i, token in enumerate(sentence):
|
for i, token in enumerate(sentence):
|
||||||
np_type = token.pos_ if token.pos_ in {'NOUN', 'PROPN'} else 'NONE'
|
np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
|
||||||
if np_type != cur_np:
|
if np_type != cur_np:
|
||||||
if cur_np != 'NONE':
|
if cur_np != "NONE":
|
||||||
noun_chunks.add((np_start, i))
|
noun_chunks.add((np_start, i))
|
||||||
if np_type != 'NONE':
|
if np_type != "NONE":
|
||||||
np_start = i
|
np_start = i
|
||||||
cur_np = np_type
|
cur_np = np_type
|
||||||
if cur_np != 'NONE':
|
if cur_np != "NONE":
|
||||||
noun_chunks.add((np_start, len(sentence)))
|
noun_chunks.add((np_start, len(sentence)))
|
||||||
return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
|
return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
|
||||||
|
|
||||||
@ -61,14 +61,14 @@ def find_token(sentence, start_pos):
|
|||||||
def find_span(sentence, search_text, start=0):
|
def find_span(sentence, search_text, start=0):
|
||||||
search_text = search_text.lower()
|
search_text = search_text.lower()
|
||||||
for tok in sentence[start:]:
|
for tok in sentence[start:]:
|
||||||
remainder = sentence[tok.i:].text.lower()
|
remainder = sentence[tok.i :].text.lower()
|
||||||
if remainder.startswith(search_text):
|
if remainder.startswith(search_text):
|
||||||
len_to_consume = len(search_text)
|
len_to_consume = len(search_text)
|
||||||
start_idx = tok.idx
|
start_idx = tok.idx
|
||||||
for next_tok in sentence[tok.i:]:
|
for next_tok in sentence[tok.i :]:
|
||||||
end_idx = next_tok.idx + len(next_tok.text)
|
end_idx = next_tok.idx + len(next_tok.text)
|
||||||
if end_idx - start_idx == len_to_consume:
|
if end_idx - start_idx == len_to_consume:
|
||||||
span = sentence[tok.i:next_tok.i + 1]
|
span = sentence[tok.i : next_tok.i + 1]
|
||||||
return span
|
return span
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -76,13 +76,15 @@ def find_span(sentence, search_text, start=0):
|
|||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def get_detokenizer():
|
def get_detokenizer():
|
||||||
from sacremoses import MosesDetokenizer
|
from sacremoses import MosesDetokenizer
|
||||||
detok = MosesDetokenizer(lang='en')
|
|
||||||
|
detok = MosesDetokenizer(lang="en")
|
||||||
return detok
|
return detok
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def get_spacy_nlp():
|
def get_spacy_nlp():
|
||||||
import en_core_web_lg
|
import en_core_web_lg
|
||||||
|
|
||||||
nlp = en_core_web_lg.load()
|
nlp = en_core_web_lg.load()
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
@ -95,45 +97,45 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
|
|||||||
for line in fin:
|
for line in fin:
|
||||||
sample = json.loads(line.strip())
|
sample = json.loads(line.strip())
|
||||||
|
|
||||||
if positive_only and 'label' in sample and not sample['label']:
|
if positive_only and "label" in sample and not sample["label"]:
|
||||||
# only consider examples where the query is correct
|
# only consider examples where the query is correct
|
||||||
continue
|
continue
|
||||||
|
|
||||||
target = sample['target']
|
target = sample["target"]
|
||||||
|
|
||||||
# clean up the query
|
# clean up the query
|
||||||
query = target['span1_text']
|
query = target["span1_text"]
|
||||||
if query is not None:
|
if query is not None:
|
||||||
if '\n' in query:
|
if "\n" in query:
|
||||||
continue
|
continue
|
||||||
if query.endswith('.') or query.endswith(','):
|
if query.endswith(".") or query.endswith(","):
|
||||||
query = query[:-1]
|
query = query[:-1]
|
||||||
|
|
||||||
# split tokens
|
# split tokens
|
||||||
tokens = sample['text'].split(' ')
|
tokens = sample["text"].split(" ")
|
||||||
|
|
||||||
def strip_pronoun(x):
|
def strip_pronoun(x):
|
||||||
return x.rstrip('.,"')
|
return x.rstrip('.,"')
|
||||||
|
|
||||||
# find the pronoun
|
# find the pronoun
|
||||||
pronoun_idx = target['span2_index']
|
pronoun_idx = target["span2_index"]
|
||||||
pronoun = strip_pronoun(target['span2_text'])
|
pronoun = strip_pronoun(target["span2_text"])
|
||||||
if strip_pronoun(tokens[pronoun_idx]) != pronoun:
|
if strip_pronoun(tokens[pronoun_idx]) != pronoun:
|
||||||
# hack: sometimes the index is misaligned
|
# hack: sometimes the index is misaligned
|
||||||
if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
|
if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
|
||||||
pronoun_idx += 1
|
pronoun_idx += 1
|
||||||
else:
|
else:
|
||||||
raise Exception('Misaligned pronoun!')
|
raise Exception("Misaligned pronoun!")
|
||||||
assert strip_pronoun(tokens[pronoun_idx]) == pronoun
|
assert strip_pronoun(tokens[pronoun_idx]) == pronoun
|
||||||
|
|
||||||
# split tokens before and after the pronoun
|
# split tokens before and after the pronoun
|
||||||
before = tokens[:pronoun_idx]
|
before = tokens[:pronoun_idx]
|
||||||
after = tokens[pronoun_idx + 1:]
|
after = tokens[pronoun_idx + 1 :]
|
||||||
|
|
||||||
# the GPT BPE attaches leading spaces to tokens, so we keep track
|
# the GPT BPE attaches leading spaces to tokens, so we keep track
|
||||||
# of whether we need spaces before or after the pronoun
|
# of whether we need spaces before or after the pronoun
|
||||||
leading_space = ' ' if pronoun_idx > 0 else ''
|
leading_space = " " if pronoun_idx > 0 else ""
|
||||||
trailing_space = ' ' if len(after) > 0 else ''
|
trailing_space = " " if len(after) > 0 else ""
|
||||||
|
|
||||||
# detokenize
|
# detokenize
|
||||||
before = detok.detokenize(before, return_str=True)
|
before = detok.detokenize(before, return_str=True)
|
||||||
@ -142,14 +144,14 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
|
|||||||
|
|
||||||
# hack: when the pronoun ends in a period (or comma), move the
|
# hack: when the pronoun ends in a period (or comma), move the
|
||||||
# punctuation to the "after" part
|
# punctuation to the "after" part
|
||||||
if pronoun.endswith('.') or pronoun.endswith(','):
|
if pronoun.endswith(".") or pronoun.endswith(","):
|
||||||
after = pronoun[-1] + trailing_space + after
|
after = pronoun[-1] + trailing_space + after
|
||||||
pronoun = pronoun[:-1]
|
pronoun = pronoun[:-1]
|
||||||
|
|
||||||
# hack: when the "after" part begins with a comma or period, remove
|
# hack: when the "after" part begins with a comma or period, remove
|
||||||
# the trailing space
|
# the trailing space
|
||||||
if after.startswith('.') or after.startswith(','):
|
if after.startswith(".") or after.startswith(","):
|
||||||
trailing_space = ''
|
trailing_space = ""
|
||||||
|
|
||||||
# parse sentence with spacy
|
# parse sentence with spacy
|
||||||
sentence = nlp(before + leading_space + pronoun + trailing_space + after)
|
sentence = nlp(before + leading_space + pronoun + trailing_space + after)
|
||||||
@ -164,13 +166,13 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
|
|||||||
# convert to format where pronoun is surrounded by "[]" and
|
# convert to format where pronoun is surrounded by "[]" and
|
||||||
# query is surrounded by "_"
|
# query is surrounded by "_"
|
||||||
query_span = find_span(sentence, query)
|
query_span = find_span(sentence, query)
|
||||||
query_with_ws = '_{}_{}'.format(
|
query_with_ws = "_{}_{}".format(
|
||||||
query_span.text,
|
query_span.text,
|
||||||
(' ' if query_span.text_with_ws.endswith(' ') else '')
|
(" " if query_span.text_with_ws.endswith(" ") else ""),
|
||||||
)
|
)
|
||||||
pronoun_with_ws = '[{}]{}'.format(
|
pronoun_with_ws = "[{}]{}".format(
|
||||||
pronoun_span.text,
|
pronoun_span.text,
|
||||||
(' ' if pronoun_span.text_with_ws.endswith(' ') else '')
|
(" " if pronoun_span.text_with_ws.endswith(" ") else ""),
|
||||||
)
|
)
|
||||||
if query_span.start < pronoun_span.start:
|
if query_span.start < pronoun_span.start:
|
||||||
first = (query_span, query_with_ws)
|
first = (query_span, query_with_ws)
|
||||||
@ -179,41 +181,45 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
|
|||||||
first = (pronoun_span, pronoun_with_ws)
|
first = (pronoun_span, pronoun_with_ws)
|
||||||
second = (query_span, query_with_ws)
|
second = (query_span, query_with_ws)
|
||||||
sentence = (
|
sentence = (
|
||||||
sentence[:first[0].start].text_with_ws
|
sentence[: first[0].start].text_with_ws
|
||||||
+ first[1]
|
+ first[1]
|
||||||
+ sentence[first[0].end:second[0].start].text_with_ws
|
+ sentence[first[0].end : second[0].start].text_with_ws
|
||||||
+ second[1]
|
+ second[1]
|
||||||
+ sentence[second[0].end:].text
|
+ sentence[second[0].end :].text
|
||||||
)
|
)
|
||||||
yield sentence, sample.get('label', None)
|
yield sentence, sample.get("label", None)
|
||||||
else:
|
else:
|
||||||
yield sentence, pronoun_span, query, sample.get('label', None)
|
yield sentence, pronoun_span, query, sample.get("label", None)
|
||||||
|
|
||||||
|
|
||||||
def winogrande_jsonl_iterator(input_fname, eval=False):
|
def winogrande_jsonl_iterator(input_fname, eval=False):
|
||||||
with open(input_fname) as fin:
|
with open(input_fname) as fin:
|
||||||
for line in fin:
|
for line in fin:
|
||||||
sample = json.loads(line.strip())
|
sample = json.loads(line.strip())
|
||||||
sentence, option1, option2 = sample['sentence'], sample['option1'],\
|
sentence, option1, option2 = (
|
||||||
sample['option2']
|
sample["sentence"],
|
||||||
|
sample["option1"],
|
||||||
|
sample["option2"],
|
||||||
|
)
|
||||||
|
|
||||||
pronoun_span = (sentence.index('_'), sentence.index('_') + 1)
|
pronoun_span = (sentence.index("_"), sentence.index("_") + 1)
|
||||||
|
|
||||||
if eval:
|
if eval:
|
||||||
query, cand = option1, option2
|
query, cand = option1, option2
|
||||||
else:
|
else:
|
||||||
query = option1 if sample['answer'] == '1' else option2
|
query = option1 if sample["answer"] == "1" else option2
|
||||||
cand = option2 if sample['answer'] == '1' else option1
|
cand = option2 if sample["answer"] == "1" else option1
|
||||||
yield sentence, pronoun_span, query, cand
|
yield sentence, pronoun_span, query, cand
|
||||||
|
|
||||||
|
|
||||||
def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact_match=False):
|
def filter_noun_chunks(
|
||||||
|
chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
|
||||||
|
):
|
||||||
if exclude_pronouns:
|
if exclude_pronouns:
|
||||||
chunks = [
|
chunks = [
|
||||||
np for np in chunks if (
|
np
|
||||||
np.lemma_ != '-PRON-'
|
for np in chunks
|
||||||
and not all(tok.pos_ == 'PRON' for tok in np)
|
if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if exclude_query is not None:
|
if exclude_query is not None:
|
||||||
@ -224,9 +230,8 @@ def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact
|
|||||||
found = False
|
found = False
|
||||||
for excl in excl_txt:
|
for excl in excl_txt:
|
||||||
if (
|
if (
|
||||||
(not exact_match and (lower_chunk in excl or excl in lower_chunk))
|
not exact_match and (lower_chunk in excl or excl in lower_chunk)
|
||||||
or lower_chunk == excl
|
) or lower_chunk == excl:
|
||||||
):
|
|
||||||
found = True
|
found = True
|
||||||
break
|
break
|
||||||
if not found:
|
if not found:
|
||||||
|
@ -3,4 +3,4 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from . import criterions, models, eval # noqa
|
from . import criterions, eval, models # noqa
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
for file in os.listdir(os.path.dirname(__file__)):
|
for file in os.listdir(os.path.dirname(__file__)):
|
||||||
if file.endswith(".py") and not file.startswith("_"):
|
if file.endswith(".py") and not file.startswith("_"):
|
||||||
criterion_name = file[: file.find(".py")]
|
criterion_name = file[: file.find(".py")]
|
||||||
|
@ -3,21 +3,17 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from examples.simultaneous_translation.utils.latency import LatencyTraining
|
||||||
from fairseq.criterions import register_criterion
|
from fairseq.criterions import register_criterion
|
||||||
from fairseq.criterions.label_smoothed_cross_entropy import (
|
from fairseq.criterions.label_smoothed_cross_entropy import (
|
||||||
LabelSmoothedCrossEntropyCriterion
|
LabelSmoothedCrossEntropyCriterion,
|
||||||
)
|
|
||||||
|
|
||||||
from examples.simultaneous_translation.utils.latency import (
|
|
||||||
LatencyTraining
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_criterion('latency_augmented_label_smoothed_cross_entropy')
|
@register_criterion("latency_augmented_label_smoothed_cross_entropy")
|
||||||
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
|
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
|
||||||
LabelSmoothedCrossEntropyCriterion
|
LabelSmoothedCrossEntropyCriterion
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(self, args, task):
|
def __init__(self, args, task):
|
||||||
super().__init__(args, task)
|
super().__init__(args, task)
|
||||||
self.eps = args.label_smoothing
|
self.eps = args.label_smoothing
|
||||||
@ -40,7 +36,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
|
|||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
super(
|
super(
|
||||||
LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
|
LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
|
||||||
LatencyAugmentedLabelSmoothedCrossEntropyCriterion
|
LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
|
||||||
).add_args(parser)
|
).add_args(parser)
|
||||||
"""Add criterion-specific arguments to the parser."""
|
"""Add criterion-specific arguments to the parser."""
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -69,7 +65,8 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
|
|||||||
|
|
||||||
# Get latency loss
|
# Get latency loss
|
||||||
latency_loss = self.latency_train.loss(
|
latency_loss = self.latency_train.loss(
|
||||||
attn_list, source_padding_mask, target_padding_mask)
|
attn_list, source_padding_mask, target_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
loss += latency_loss
|
loss += latency_loss
|
||||||
|
|
||||||
|
@ -5,16 +5,20 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from fairseq import registry
|
from fairseq import registry
|
||||||
|
|
||||||
build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry('--agent-type')
|
|
||||||
|
build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry(
|
||||||
|
"--agent-type"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_EOS = '</s>'
|
DEFAULT_EOS = "</s>"
|
||||||
GET = 0
|
GET = 0
|
||||||
SEND = 1
|
SEND = 1
|
||||||
|
|
||||||
for file in os.listdir(os.path.dirname(__file__)):
|
for file in os.listdir(os.path.dirname(__file__)):
|
||||||
if file.endswith('.py') and not file.startswith('_'):
|
if file.endswith(".py") and not file.startswith("_"):
|
||||||
module = file[:file.find('.py')]
|
module = file[: file.find(".py")]
|
||||||
importlib.import_module('agents.' + module)
|
importlib.import_module("agents." + module)
|
||||||
|
@ -3,14 +3,16 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from . import GET, SEND, DEFAULT_EOS
|
|
||||||
import time
|
import time
|
||||||
from multiprocessing.pool import ThreadPool as Pool
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from multiprocessing.pool import ThreadPool as Pool
|
||||||
|
|
||||||
|
from . import DEFAULT_EOS, GET, SEND
|
||||||
|
|
||||||
|
|
||||||
class Agent(object):
|
class Agent(object):
|
||||||
"an agent needs to follow this pattern"
|
"an agent needs to follow this pattern"
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -40,26 +42,26 @@ class Agent(object):
|
|||||||
with Pool(10) as p:
|
with Pool(10) as p:
|
||||||
p.map(
|
p.map(
|
||||||
partial(self._decode_one, session),
|
partial(self._decode_one, session),
|
||||||
[sent_id for sent_id in range(low, high + 1)]
|
[sent_id for sent_id in range(low, high + 1)],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
for sent_id in range(low, high + 1):
|
for sent_id in range(low, high + 1):
|
||||||
self._decode_one(session, sent_id)
|
self._decode_one(session, sent_id)
|
||||||
|
|
||||||
print(f'Finished {low} to {high} in {time.time() - t0}s')
|
print(f"Finished {low} to {high} in {time.time() - t0}s")
|
||||||
|
|
||||||
def _decode_one(self, session, sent_id):
|
def _decode_one(self, session, sent_id):
|
||||||
action = {}
|
action = {}
|
||||||
self.reset()
|
self.reset()
|
||||||
states = self.init_states()
|
states = self.init_states()
|
||||||
while action.get('value', None) != DEFAULT_EOS:
|
while action.get("value", None) != DEFAULT_EOS:
|
||||||
# take an action
|
# take an action
|
||||||
action = self.policy(states)
|
action = self.policy(states)
|
||||||
|
|
||||||
if action['key'] == GET:
|
if action["key"] == GET:
|
||||||
new_states = session.get_src(sent_id, action["value"])
|
new_states = session.get_src(sent_id, action["value"])
|
||||||
states = self.update_states(states, new_states)
|
states = self.update_states(states, new_states)
|
||||||
|
|
||||||
elif action['key'] == SEND:
|
elif action["key"] == SEND:
|
||||||
session.send_hypo(sent_id, action['value'])
|
session.send_hypo(sent_id, action["value"])
|
||||||
print(" ".join(states["tokens"]["tgt"]))
|
print(" ".join(states["tokens"]["tgt"]))
|
||||||
|
@ -3,11 +3,13 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from . agent import Agent
|
|
||||||
from . import DEFAULT_EOS, GET, SEND
|
|
||||||
from fairseq import checkpoint_utils, utils, tasks
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from fairseq import checkpoint_utils, tasks, utils
|
||||||
|
|
||||||
|
from . import DEFAULT_EOS, GET, SEND
|
||||||
|
from .agent import Agent
|
||||||
|
|
||||||
|
|
||||||
class SimulTransAgent(Agent):
|
class SimulTransAgent(Agent):
|
||||||
@ -51,13 +53,15 @@ class SimulTransAgent(Agent):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def load_model(self, args):
|
def load_model(self, args):
|
||||||
args.user_dir = os.path.join(os.path.dirname(__file__), '..', '..')
|
args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
|
||||||
utils.import_user_module(args)
|
utils.import_user_module(args)
|
||||||
filename = args.model_path
|
filename = args.model_path
|
||||||
if not os.path.exists(filename):
|
if not os.path.exists(filename):
|
||||||
raise IOError("Model file not found: {}".format(filename))
|
raise IOError("Model file not found: {}".format(filename))
|
||||||
|
|
||||||
state = checkpoint_utils.load_checkpoint_to_cpu(filename, json.loads(args.model_overrides))
|
state = checkpoint_utils.load_checkpoint_to_cpu(
|
||||||
|
filename, json.loads(args.model_overrides)
|
||||||
|
)
|
||||||
|
|
||||||
saved_args = state["args"]
|
saved_args = state["args"]
|
||||||
saved_args.data = args.data_bin
|
saved_args.data = args.data_bin
|
||||||
@ -79,7 +83,7 @@ class SimulTransAgent(Agent):
|
|||||||
"steps": {"src": 0, "tgt": 0},
|
"steps": {"src": 0, "tgt": 0},
|
||||||
"finished": False,
|
"finished": False,
|
||||||
"finish_read": False,
|
"finish_read": False,
|
||||||
"model_states": {}
|
"model_states": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
def update_states(self, states, new_state):
|
def update_states(self, states, new_state):
|
||||||
@ -115,38 +119,38 @@ class SimulTransAgent(Agent):
|
|||||||
def write_action(self, states):
|
def write_action(self, states):
|
||||||
token, index = self.model.predict_from_states(states)
|
token, index = self.model.predict_from_states(states)
|
||||||
|
|
||||||
if index == self.dict["tgt"].eos() or len(states["tokens"]["tgt"]) > self.max_len:
|
if (
|
||||||
|
index == self.dict["tgt"].eos()
|
||||||
|
or len(states["tokens"]["tgt"]) > self.max_len
|
||||||
|
):
|
||||||
# Finish this sentence is predict EOS
|
# Finish this sentence is predict EOS
|
||||||
states["finished"] = True
|
states["finished"] = True
|
||||||
end_idx_last_full_word = self._target_length(states)
|
end_idx_last_full_word = self._target_length(states)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
states["tokens"]["tgt"] += [token]
|
states["tokens"]["tgt"] += [token]
|
||||||
end_idx_last_full_word = (
|
end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
|
||||||
self.word_splitter["tgt"]
|
states["tokens"]["tgt"]
|
||||||
.end_idx_last_full_word(states["tokens"]["tgt"])
|
|
||||||
)
|
)
|
||||||
self._append_indices(states, [index], "tgt")
|
self._append_indices(states, [index], "tgt")
|
||||||
|
|
||||||
if end_idx_last_full_word > states["steps"]["tgt"]:
|
if end_idx_last_full_word > states["steps"]["tgt"]:
|
||||||
# Only sent detokenized full words to the server
|
# Only sent detokenized full words to the server
|
||||||
word = self.word_splitter["tgt"].merge(
|
word = self.word_splitter["tgt"].merge(
|
||||||
states["tokens"]["tgt"][
|
states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
|
||||||
states["steps"]["tgt"]: end_idx_last_full_word
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
states["steps"]["tgt"] = end_idx_last_full_word
|
states["steps"]["tgt"] = end_idx_last_full_word
|
||||||
states["segments"]["tgt"] += [word]
|
states["segments"]["tgt"] += [word]
|
||||||
|
|
||||||
return {'key': SEND, 'value': word}
|
return {"key": SEND, "value": word}
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def read_action(self, states):
|
def read_action(self, states):
|
||||||
return {'key': GET, 'value': None}
|
return {"key": GET, "value": None}
|
||||||
|
|
||||||
def finish_action(self):
|
def finish_action(self):
|
||||||
return {'key': SEND, 'value': DEFAULT_EOS}
|
return {"key": SEND, "value": DEFAULT_EOS}
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
pass
|
pass
|
||||||
@ -160,4 +164,4 @@ class SimulTransAgent(Agent):
|
|||||||
states["indices"][key] += new_indices
|
states["indices"][key] += new_indices
|
||||||
|
|
||||||
def _target_length(self, states):
|
def _target_length(self, states):
|
||||||
return len(states["tokens"]['tgt'])
|
return len(states["tokens"]["tgt"])
|
||||||
|
@ -3,10 +3,9 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from . simul_trans_agent import SimulTransAgent
|
from . import DEFAULT_EOS, GET, register_agent
|
||||||
from . import DEFAULT_EOS, GET
|
from .simul_trans_agent import SimulTransAgent
|
||||||
from . import register_agent
|
from .word_splitter import SPLITTER_DICT
|
||||||
from . word_splitter import SPLITTER_DICT
|
|
||||||
|
|
||||||
|
|
||||||
@register_agent("simul_trans_text")
|
@register_agent("simul_trans_text")
|
||||||
@ -15,11 +14,11 @@ class SimulTransTextAgent(SimulTransAgent):
|
|||||||
self.word_splitter = {}
|
self.word_splitter = {}
|
||||||
|
|
||||||
self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type](
|
self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type](
|
||||||
getattr(args, f"src_splitter_path")
|
getattr(args, f"src_splitter_path")
|
||||||
)
|
)
|
||||||
self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type](
|
self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type](
|
||||||
getattr(args, f"tgt_splitter_path")
|
getattr(args, f"tgt_splitter_path")
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_dictionary(self, task):
|
def load_dictionary(self, task):
|
||||||
self.dict = {}
|
self.dict = {}
|
||||||
@ -37,12 +36,16 @@ class SimulTransTextAgent(SimulTransAgent):
|
|||||||
tokens = self.word_splitter["src"].split(new_word)
|
tokens = self.word_splitter["src"].split(new_word)
|
||||||
# Get indices from dictionary
|
# Get indices from dictionary
|
||||||
# You can change to you own dictionary
|
# You can change to you own dictionary
|
||||||
indices = self.dict["src"].encode_line(
|
indices = (
|
||||||
tokens,
|
self.dict["src"]
|
||||||
line_tokenizer=lambda x: x,
|
.encode_line(
|
||||||
add_if_not_exist=False,
|
tokens,
|
||||||
append_eos=False
|
line_tokenizer=lambda x: x,
|
||||||
).tolist()
|
add_if_not_exist=False,
|
||||||
|
append_eos=False,
|
||||||
|
)
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
tokens = [new_word]
|
tokens = [new_word]
|
||||||
indices = [self.dict["src"].eos()]
|
indices = [self.dict["src"].eos()]
|
||||||
@ -61,11 +64,11 @@ class SimulTransTextAgent(SimulTransAgent):
|
|||||||
|
|
||||||
# At leat one word is read
|
# At leat one word is read
|
||||||
if len(states["tokens"]["src"]) == 0:
|
if len(states["tokens"]["src"]) == 0:
|
||||||
return {'key': GET, 'value': None}
|
return {"key": GET, "value": None}
|
||||||
|
|
||||||
# Only request new word if there is no buffered tokens
|
# Only request new word if there is no buffered tokens
|
||||||
if len(states["tokens"]["src"]) <= states["steps"]["src"]:
|
if len(states["tokens"]["src"]) <= states["steps"]["src"]:
|
||||||
return {'key': GET, 'value': None}
|
return {"key": GET, "value": None}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ class BPEWordSplitter(object):
|
|||||||
def __init__(self, model_path):
|
def __init__(self, model_path):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
from subword_nmt.apply_bpe import BPE
|
from subword_nmt.apply_bpe import BPE
|
||||||
|
|
||||||
with open(model_path) as f:
|
with open(model_path) as f:
|
||||||
self.model = BPE(f)
|
self.model = BPE(f)
|
||||||
|
|
||||||
@ -48,7 +49,7 @@ class BPEWordSplitter(object):
|
|||||||
|
|
||||||
def end_idx_last_full_word(self, tokens):
|
def end_idx_last_full_word(self, tokens):
|
||||||
# Begin of word indices
|
# Begin of word indices
|
||||||
bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != '@@']
|
bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != "@@"]
|
||||||
|
|
||||||
if len(bow_indices) < 2:
|
if len(bow_indices) < 2:
|
||||||
return 0
|
return 0
|
||||||
@ -63,6 +64,7 @@ class SentencePieceModelWordSplitter(object):
|
|||||||
def __init__(self, model_path):
|
def __init__(self, model_path):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
|
||||||
self.model = spm.SentencePieceProcessor()
|
self.model = spm.SentencePieceProcessor()
|
||||||
self.model.Load(model_path)
|
self.model.Load(model_path)
|
||||||
|
|
||||||
@ -71,7 +73,7 @@ class SentencePieceModelWordSplitter(object):
|
|||||||
|
|
||||||
def end_idx_last_full_word(self, tokens):
|
def end_idx_last_full_word(self, tokens):
|
||||||
# Begin of word indices
|
# Begin of word indices
|
||||||
bow_indices = [i for i, t in enumerate(tokens) if t[0] == '\u2581']
|
bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"]
|
||||||
|
|
||||||
if len(bow_indices) < 2:
|
if len(bow_indices) < 2:
|
||||||
return 0
|
return 0
|
||||||
|
@ -3,19 +3,20 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import requests
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
from scorers import build_scorer
|
from scorers import build_scorer
|
||||||
|
|
||||||
|
|
||||||
class SimulSTEvaluationService(object):
|
class SimulSTEvaluationService(object):
|
||||||
DEFAULT_HOSTNAME = 'localhost'
|
DEFAULT_HOSTNAME = "localhost"
|
||||||
DEFAULT_PORT = 12321
|
DEFAULT_PORT = 12321
|
||||||
|
|
||||||
def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
|
def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
|
||||||
self.hostname = hostname
|
self.hostname = hostname
|
||||||
self.port = port
|
self.port = port
|
||||||
self.base_url = f'http://{self.hostname}:{self.port}'
|
self.base_url = f"http://{self.hostname}:{self.port}"
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.new_session()
|
self.new_session()
|
||||||
@ -25,56 +26,53 @@ class SimulSTEvaluationService(object):
|
|||||||
|
|
||||||
def new_session(self):
|
def new_session(self):
|
||||||
# start eval session
|
# start eval session
|
||||||
url = f'{self.base_url}'
|
url = f"{self.base_url}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_ = requests.post(url)
|
_ = requests.post(url)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Failed to start an evaluation session: {e}')
|
print(f"Failed to start an evaluation session: {e}")
|
||||||
|
|
||||||
print('Evaluation session started.')
|
print("Evaluation session started.")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_scores(self):
|
def get_scores(self):
|
||||||
# end eval session
|
# end eval session
|
||||||
url = f'{self.base_url}/result'
|
url = f"{self.base_url}/result"
|
||||||
try:
|
try:
|
||||||
r = requests.get(url)
|
r = requests.get(url)
|
||||||
print('Scores: {}'.format(r.json()))
|
print("Scores: {}".format(r.json()))
|
||||||
print('Evaluation session finished.')
|
print("Evaluation session finished.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Failed to end an evaluation session: {e}')
|
print(f"Failed to end an evaluation session: {e}")
|
||||||
|
|
||||||
def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
|
def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
|
||||||
url = f'{self.base_url}/src'
|
url = f"{self.base_url}/src"
|
||||||
params = {"sent_id": sent_id}
|
params = {"sent_id": sent_id}
|
||||||
if extra_params is not None:
|
if extra_params is not None:
|
||||||
for key in extra_params.keys():
|
for key in extra_params.keys():
|
||||||
params[key] = extra_params[key]
|
params[key] = extra_params[key]
|
||||||
try:
|
try:
|
||||||
r = requests.get(
|
r = requests.get(url, params=params)
|
||||||
url,
|
|
||||||
params=params
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Failed to request a source segment: {e}')
|
print(f"Failed to request a source segment: {e}")
|
||||||
return r.json()
|
return r.json()
|
||||||
|
|
||||||
def send_hypo(self, sent_id: int, hypo: str) -> None:
|
def send_hypo(self, sent_id: int, hypo: str) -> None:
|
||||||
url = f'{self.base_url}/hypo'
|
url = f"{self.base_url}/hypo"
|
||||||
params = {"sent_id": sent_id}
|
params = {"sent_id": sent_id}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
requests.put(url, params=params, data=hypo.encode("utf-8"))
|
requests.put(url, params=params, data=hypo.encode("utf-8"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Failed to send a translated segment: {e}')
|
print(f"Failed to send a translated segment: {e}")
|
||||||
|
|
||||||
def corpus_info(self):
|
def corpus_info(self):
|
||||||
url = f'{self.base_url}'
|
url = f"{self.base_url}"
|
||||||
try:
|
try:
|
||||||
r = requests.get(url)
|
r = requests.get(url)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Failed to request corpus information: {e}')
|
print(f"Failed to request corpus information: {e}")
|
||||||
|
|
||||||
return r.json()
|
return r.json()
|
||||||
|
|
||||||
|
@ -3,20 +3,21 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from examples.simultaneous_translation.utils.latency import LatencyInference
|
|
||||||
import argparse
|
import argparse
|
||||||
import torch
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from examples.simultaneous_translation.utils.latency import LatencyInference
|
||||||
|
|
||||||
|
|
||||||
LATENCY_METRICS = [
|
LATENCY_METRICS = [
|
||||||
'differentiable_average_lagging',
|
"differentiable_average_lagging",
|
||||||
'average_lagging',
|
"average_lagging",
|
||||||
'average_proportion',
|
"average_proportion",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class LatencyScorer():
|
class LatencyScorer:
|
||||||
def __init__(self, start_from_zero=True):
|
def __init__(self, start_from_zero=True):
|
||||||
self.recorder = []
|
self.recorder = []
|
||||||
self.scores = {}
|
self.scores = {}
|
||||||
@ -26,10 +27,7 @@ class LatencyScorer():
|
|||||||
def update_reorder(self, list_of_dict):
|
def update_reorder(self, list_of_dict):
|
||||||
self.recorder = []
|
self.recorder = []
|
||||||
for info in list_of_dict:
|
for info in list_of_dict:
|
||||||
delays = [
|
delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]]
|
||||||
int(x) - int(not self.start_from_zero)
|
|
||||||
for x in info["delays"]
|
|
||||||
]
|
|
||||||
delays = torch.LongTensor(delays).unsqueeze(0)
|
delays = torch.LongTensor(delays).unsqueeze(0)
|
||||||
src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
|
src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
|
||||||
|
|
||||||
@ -59,7 +57,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
scorer = LatencyInference()
|
scorer = LatencyInference()
|
||||||
recorder = []
|
recorder = []
|
||||||
with open(args.input, 'r') as f:
|
with open(args.input, "r") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
info = json.loads(line)
|
info = json.loads(line)
|
||||||
|
|
||||||
@ -74,7 +72,7 @@ if __name__ == "__main__":
|
|||||||
average_results = {}
|
average_results = {}
|
||||||
|
|
||||||
for metric in LATENCY_METRICS:
|
for metric in LATENCY_METRICS:
|
||||||
average_results[metric] = sum(
|
average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len(
|
||||||
[x[metric][0, 0].item() for x in recorder]
|
recorder
|
||||||
) / len(recorder)
|
)
|
||||||
print(f"{metric}: {average_results[metric]}")
|
print(f"{metric}: {average_results[metric]}")
|
||||||
|
@ -5,37 +5,48 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
from agents import build_agent
|
||||||
from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
|
from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
|
||||||
from fairseq.registry import REGISTRIES
|
from fairseq.registry import REGISTRIES
|
||||||
from agents import build_agent
|
|
||||||
|
|
||||||
DEFAULT_HOSTNAME = 'localhost'
|
|
||||||
|
DEFAULT_HOSTNAME = "localhost"
|
||||||
DEFAULT_PORT = 12321
|
DEFAULT_PORT = 12321
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
|
parser.add_argument(
|
||||||
help='server hostname')
|
"--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
|
||||||
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
|
)
|
||||||
help='server port number')
|
parser.add_argument(
|
||||||
parser.add_argument('--agent-type', default='simul_trans_text',
|
"--port", type=int, default=DEFAULT_PORT, help="server port number"
|
||||||
help='Agent type')
|
)
|
||||||
parser.add_argument('--scorer-type', default='text',
|
parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
|
||||||
help='Scorer type')
|
parser.add_argument("--scorer-type", default="text", help="Scorer type")
|
||||||
parser.add_argument('--start-idx', type=int, default=0,
|
parser.add_argument(
|
||||||
help='Start index of the sentence to evaluate')
|
"--start-idx",
|
||||||
parser.add_argument('--end-idx', type=int, default=float('inf'),
|
type=int,
|
||||||
help='End index of the sentence to evaluate')
|
default=0,
|
||||||
parser.add_argument('--scores', action="store_true",
|
help="Start index of the sentence to evaluate",
|
||||||
help='Request scores from server')
|
)
|
||||||
parser.add_argument('--reset-server', action="store_true",
|
parser.add_argument(
|
||||||
help='Reset the server')
|
"--end-idx",
|
||||||
parser.add_argument('--num-threads', type=int, default=10,
|
type=int,
|
||||||
help='Number of threads used by agent')
|
default=float("inf"),
|
||||||
parser.add_argument('--local', action="store_true", default=False,
|
help="End index of the sentence to evaluate",
|
||||||
help='Local evaluation')
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scores", action="store_true", help="Request scores from server"
|
||||||
|
)
|
||||||
|
parser.add_argument("--reset-server", action="store_true", help="Reset the server")
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-threads", type=int, default=10, help="Number of threads used by agent"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--local", action="store_true", default=False, help="Local evaluation"
|
||||||
|
)
|
||||||
|
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
@ -5,15 +5,15 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from fairseq import registry
|
from fairseq import registry
|
||||||
(
|
|
||||||
build_scorer,
|
|
||||||
register_scorer,
|
(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry(
|
||||||
SCORER_REGISTRIES,
|
"--scorer-type"
|
||||||
_
|
)
|
||||||
) = registry.setup_registry('--scorer-type')
|
|
||||||
|
|
||||||
for file in os.listdir(os.path.dirname(__file__)):
|
for file in os.listdir(os.path.dirname(__file__)):
|
||||||
if file.endswith('.py') and not file.startswith('_'):
|
if file.endswith(".py") and not file.startswith("_"):
|
||||||
module = file[:file.find('.py')]
|
module = file[: file.find(".py")]
|
||||||
importlib.import_module('scorers.' + module)
|
importlib.import_module("scorers." + module)
|
||||||
|
@ -3,16 +3,17 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from vizseq.scorers.bleu import BLEUScorer
|
|
||||||
from vizseq.scorers.ter import TERScorer
|
|
||||||
from vizseq.scorers.meteor import METEORScorer
|
|
||||||
from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
|
|
||||||
from collections import defaultdict
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
|
||||||
|
from vizseq.scorers.bleu import BLEUScorer
|
||||||
|
from vizseq.scorers.meteor import METEORScorer
|
||||||
|
from vizseq.scorers.ter import TERScorer
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_EOS = '</s>'
|
DEFAULT_EOS = "</s>"
|
||||||
|
|
||||||
|
|
||||||
class SimulScorer(object):
|
class SimulScorer(object):
|
||||||
@ -23,7 +24,7 @@ class SimulScorer(object):
|
|||||||
self.output_files = {
|
self.output_files = {
|
||||||
"text": os.path.join(args.output, "text"),
|
"text": os.path.join(args.output, "text"),
|
||||||
"delay": os.path.join(args.output, "delay"),
|
"delay": os.path.join(args.output, "delay"),
|
||||||
"scores": os.path.join(args.output, "scores")
|
"scores": os.path.join(args.output, "scores"),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
self.output_files = None
|
self.output_files = None
|
||||||
@ -52,14 +53,7 @@ class SimulScorer(object):
|
|||||||
|
|
||||||
def recv_hyp(self, sent_id, list_of_tokens):
|
def recv_hyp(self, sent_id, list_of_tokens):
|
||||||
for token in list_of_tokens:
|
for token in list_of_tokens:
|
||||||
self.translations[
|
self.translations[sent_id].append((token, self.steps[sent_id]))
|
||||||
sent_id
|
|
||||||
].append(
|
|
||||||
(
|
|
||||||
token,
|
|
||||||
self.steps[sent_id]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.steps = defaultdict(int)
|
self.steps = defaultdict(int)
|
||||||
@ -76,8 +70,9 @@ class SimulScorer(object):
|
|||||||
delays += [[t[1] for t in self.translations[i]]]
|
delays += [[t[1] for t in self.translations[i]]]
|
||||||
|
|
||||||
bleu_score = BLEUScorer(
|
bleu_score = BLEUScorer(
|
||||||
sent_level=False, corpus_level=True,
|
sent_level=False,
|
||||||
extra_args={'bleu_tokenizer': self.tokenizer}
|
corpus_level=True,
|
||||||
|
extra_args={"bleu_tokenizer": self.tokenizer},
|
||||||
).score(translations, [self.data["tgt"]])
|
).score(translations, [self.data["tgt"]])
|
||||||
|
|
||||||
ter_score = TERScorer(sent_level=False, corpus_level=True).score(
|
ter_score = TERScorer(sent_level=False, corpus_level=True).score(
|
||||||
@ -92,16 +87,16 @@ class SimulScorer(object):
|
|||||||
{"src_len": src_len, "delays": delay}
|
{"src_len": src_len, "delays": delay}
|
||||||
for src_len, delay in zip(self.src_lengths(), delays)
|
for src_len, delay in zip(self.src_lengths(), delays)
|
||||||
],
|
],
|
||||||
start_from_zero=False
|
start_from_zero=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
scores = {
|
scores = {
|
||||||
'BLEU': bleu_score[0],
|
"BLEU": bleu_score[0],
|
||||||
'TER': ter_score[0],
|
"TER": ter_score[0],
|
||||||
'METEOR': meteor_score[0],
|
"METEOR": meteor_score[0],
|
||||||
'DAL': latency_score['differentiable_average_lagging'],
|
"DAL": latency_score["differentiable_average_lagging"],
|
||||||
'AL': latency_score['average_lagging'],
|
"AL": latency_score["average_lagging"],
|
||||||
'AP': latency_score['average_proportion'],
|
"AP": latency_score["average_proportion"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.output_files is not None:
|
if self.output_files is not None:
|
||||||
@ -109,9 +104,9 @@ class SimulScorer(object):
|
|||||||
os.makedirs(self.output_dir, exist_ok=True)
|
os.makedirs(self.output_dir, exist_ok=True)
|
||||||
self.write_results_to_file(translations, delays, scores)
|
self.write_results_to_file(translations, delays, scores)
|
||||||
except BaseException as be:
|
except BaseException as be:
|
||||||
print(f'Failed to write results to {self.output_dir}.')
|
print(f"Failed to write results to {self.output_dir}.")
|
||||||
print(be)
|
print(be)
|
||||||
print('Skip writing predictions')
|
print("Skip writing predictions")
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@ -125,12 +120,8 @@ class SimulScorer(object):
|
|||||||
with open(self.output_files["delay"], "w") as f:
|
with open(self.output_files["delay"], "w") as f:
|
||||||
for i, delay in enumerate(delays):
|
for i, delay in enumerate(delays):
|
||||||
f.write(
|
f.write(
|
||||||
json.dumps(
|
json.dumps({"src_len": self.src_lengths()[i], "delays": delay})
|
||||||
{
|
+ "\n"
|
||||||
"src_len": self.src_lengths()[i],
|
|
||||||
"delays": delay
|
|
||||||
}
|
|
||||||
) + "\n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(self.output_files["scores"], "w") as f:
|
with open(self.output_files["scores"], "w") as f:
|
||||||
@ -163,7 +154,7 @@ class SimulScorer(object):
|
|||||||
list_to_return.append(
|
list_to_return.append(
|
||||||
{
|
{
|
||||||
"path": item["input"]["path"].strip(),
|
"path": item["input"]["path"].strip(),
|
||||||
"length": item["input"]["length_ms"]
|
"length": item["input"]["length_ms"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return list_to_return
|
return list_to_return
|
||||||
|
@ -3,8 +3,8 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from . scorer import SimulScorer
|
|
||||||
from . import register_scorer
|
from . import register_scorer
|
||||||
|
from .scorer import SimulScorer
|
||||||
|
|
||||||
|
|
||||||
@register_scorer("text")
|
@register_scorer("text")
|
||||||
@ -13,7 +13,7 @@ class SimulTextScorer(SimulScorer):
|
|||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
self.data = {
|
self.data = {
|
||||||
"src": self._load_text_file(args.src_file, split=True),
|
"src": self._load_text_file(args.src_file, split=True),
|
||||||
"tgt": self._load_text_file(args.tgt_file, split=False)
|
"tgt": self._load_text_file(args.tgt_file, split=False),
|
||||||
}
|
}
|
||||||
|
|
||||||
def send_src(self, sent_id, *args):
|
def send_src(self, sent_id, *args):
|
||||||
@ -21,7 +21,7 @@ class SimulTextScorer(SimulScorer):
|
|||||||
dict_to_return = {
|
dict_to_return = {
|
||||||
"sent_id": sent_id,
|
"sent_id": sent_id,
|
||||||
"segment_id": self.steps[sent_id],
|
"segment_id": self.steps[sent_id],
|
||||||
"segment": self.eos
|
"segment": self.eos,
|
||||||
}
|
}
|
||||||
# Consider EOS
|
# Consider EOS
|
||||||
self.steps[sent_id] = len(self.data["src"][sent_id]) + 1
|
self.steps[sent_id] = len(self.data["src"][sent_id]) + 1
|
||||||
@ -29,7 +29,7 @@ class SimulTextScorer(SimulScorer):
|
|||||||
dict_to_return = {
|
dict_to_return = {
|
||||||
"sent_id": sent_id,
|
"sent_id": sent_id,
|
||||||
"segment_id": self.steps[sent_id],
|
"segment_id": self.steps[sent_id],
|
||||||
"segment": self.data["src"][sent_id][self.steps[sent_id]]
|
"segment": self.data["src"][sent_id][self.steps[sent_id]],
|
||||||
}
|
}
|
||||||
|
|
||||||
self.steps[sent_id] += 1
|
self.steps[sent_id] += 1
|
||||||
|
@ -3,12 +3,14 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
import argparse
|
import argparse
|
||||||
import sys
|
|
||||||
import json
|
import json
|
||||||
from tornado import web, ioloop
|
import sys
|
||||||
from scorers import build_scorer
|
|
||||||
|
|
||||||
DEFAULT_HOSTNAME = 'localhost'
|
from scorers import build_scorer
|
||||||
|
from tornado import ioloop, web
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_HOSTNAME = "localhost"
|
||||||
DEFAULT_PORT = 12321
|
DEFAULT_PORT = 12321
|
||||||
|
|
||||||
|
|
||||||
@ -34,10 +36,10 @@ class ResultHandler(ScorerHandler):
|
|||||||
|
|
||||||
class SourceHandler(ScorerHandler):
|
class SourceHandler(ScorerHandler):
|
||||||
def get(self):
|
def get(self):
|
||||||
sent_id = int(self.get_argument('sent_id'))
|
sent_id = int(self.get_argument("sent_id"))
|
||||||
segment_size = None
|
segment_size = None
|
||||||
if "segment_size" in self.request.arguments:
|
if "segment_size" in self.request.arguments:
|
||||||
string = self.get_argument('segment_size')
|
string = self.get_argument("segment_size")
|
||||||
if len(string) > 0:
|
if len(string) > 0:
|
||||||
segment_size = int(string)
|
segment_size = int(string)
|
||||||
|
|
||||||
@ -48,8 +50,8 @@ class SourceHandler(ScorerHandler):
|
|||||||
|
|
||||||
class HypothesisHandler(ScorerHandler):
|
class HypothesisHandler(ScorerHandler):
|
||||||
def put(self):
|
def put(self):
|
||||||
sent_id = int(self.get_argument('sent_id'))
|
sent_id = int(self.get_argument("sent_id"))
|
||||||
list_of_tokens = self.request.body.decode('utf-8').strip().split()
|
list_of_tokens = self.request.body.decode("utf-8").strip().split()
|
||||||
self.scorer.recv_hyp(sent_id, list_of_tokens)
|
self.scorer.recv_hyp(sent_id, list_of_tokens)
|
||||||
|
|
||||||
|
|
||||||
@ -67,18 +69,21 @@ def add_args():
|
|||||||
|
|
||||||
|
|
||||||
def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
|
def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
|
||||||
app = web.Application([
|
app = web.Application(
|
||||||
(r'/result', ResultHandler, dict(scorer=scorer)),
|
[
|
||||||
(r'/src', SourceHandler, dict(scorer=scorer)),
|
(r"/result", ResultHandler, dict(scorer=scorer)),
|
||||||
(r'/hypo', HypothesisHandler, dict(scorer=scorer)),
|
(r"/src", SourceHandler, dict(scorer=scorer)),
|
||||||
(r'/', EvalSessionHandler, dict(scorer=scorer)),
|
(r"/hypo", HypothesisHandler, dict(scorer=scorer)),
|
||||||
], debug=debug)
|
(r"/", EvalSessionHandler, dict(scorer=scorer)),
|
||||||
|
],
|
||||||
|
debug=debug,
|
||||||
|
)
|
||||||
app.listen(port, max_buffer_size=1024 ** 3)
|
app.listen(port, max_buffer_size=1024 ** 3)
|
||||||
sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
|
sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
|
||||||
ioloop.IOLoop.current().start()
|
ioloop.IOLoop.current().start()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
args = add_args()
|
args = add_args()
|
||||||
scorer = build_scorer(args)
|
scorer = build_scorer(args)
|
||||||
start_server(scorer, args.hostname, args.port, args.debug)
|
start_server(scorer, args.hostname, args.port, args.debug)
|
||||||
|
@ -6,7 +6,10 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
for file in os.listdir(os.path.dirname(__file__)):
|
for file in os.listdir(os.path.dirname(__file__)):
|
||||||
if file.endswith('.py') and not file.startswith('_'):
|
if file.endswith(".py") and not file.startswith("_"):
|
||||||
model_name = file[:file.find('.py')]
|
model_name = file[: file.find(".py")]
|
||||||
importlib.import_module('examples.simultaneous_translation.models.' + model_name)
|
importlib.import_module(
|
||||||
|
"examples.simultaneous_translation.models." + model_name
|
||||||
|
)
|
||||||
|
@ -6,42 +6,34 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
|
||||||
from fairseq.models import (
|
TransformerMonotonicDecoderLayer,
|
||||||
register_model,
|
TransformerMonotonicEncoderLayer,
|
||||||
register_model_architecture,
|
|
||||||
)
|
)
|
||||||
|
from fairseq.models import register_model, register_model_architecture
|
||||||
|
|
||||||
from fairseq.models.transformer import (
|
from fairseq.models.transformer import (
|
||||||
TransformerModel,
|
|
||||||
TransformerEncoder,
|
|
||||||
TransformerDecoder,
|
TransformerDecoder,
|
||||||
|
TransformerEncoder,
|
||||||
|
TransformerModel,
|
||||||
base_architecture,
|
base_architecture,
|
||||||
transformer_iwslt_de_en,
|
transformer_iwslt_de_en,
|
||||||
transformer_vaswani_wmt_en_de_big,
|
transformer_vaswani_wmt_en_de_big,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
|
|
||||||
TransformerMonotonicDecoderLayer,
|
|
||||||
TransformerMonotonicEncoderLayer
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
||||||
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
||||||
|
|
||||||
|
|
||||||
@register_model('transformer_unidirectional')
|
@register_model("transformer_unidirectional")
|
||||||
class TransformerUnidirectionalModel(TransformerModel):
|
class TransformerUnidirectionalModel(TransformerModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_encoder(cls, args, src_dict, embed_tokens):
|
def build_encoder(cls, args, src_dict, embed_tokens):
|
||||||
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
|
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
|
||||||
|
|
||||||
|
|
||||||
@register_model('transformer_monotonic')
|
@register_model("transformer_monotonic")
|
||||||
class TransformerMonotonicModel(TransformerModel):
|
class TransformerMonotonicModel(TransformerModel):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_encoder(cls, args, src_dict, embed_tokens):
|
def build_encoder(cls, args, src_dict, embed_tokens):
|
||||||
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
|
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
|
||||||
@ -62,26 +54,17 @@ class TransformerMonotonicModel(TransformerModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
tgt_indices = tensor(
|
tgt_indices = tensor(
|
||||||
[
|
[[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
|
||||||
[self.decoder.dictionary.eos()]
|
|
||||||
+ states["indices"]["tgt"]
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
src_indices = states["indices"]["src"][: 1 +
|
src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
|
||||||
states["steps"]["src"]]
|
|
||||||
tgt_indices = states["indices"]["tgt"]
|
tgt_indices = states["indices"]["tgt"]
|
||||||
|
|
||||||
return src_indices, None, tgt_indices
|
return src_indices, None, tgt_indices
|
||||||
|
|
||||||
def predict_from_states(self, states):
|
def predict_from_states(self, states):
|
||||||
decoder_states = self.decoder.output_layer(
|
decoder_states = self.decoder.output_layer(states["decoder_features"])
|
||||||
states["decoder_features"]
|
lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)
|
||||||
)
|
|
||||||
lprobs = self.get_normalized_probs(
|
|
||||||
[decoder_states[:, -1:]],
|
|
||||||
log_probs=True
|
|
||||||
)
|
|
||||||
|
|
||||||
index = lprobs.argmax(dim=-1)
|
index = lprobs.argmax(dim=-1)
|
||||||
|
|
||||||
@ -90,25 +73,24 @@ class TransformerMonotonicModel(TransformerModel):
|
|||||||
return token, index[0, 0].item()
|
return token, index[0, 0].item()
|
||||||
|
|
||||||
def decision_from_states(self, states):
|
def decision_from_states(self, states):
|
||||||
'''
|
"""
|
||||||
This funcion take states dictionary as input, and gives the agent
|
This funcion take states dictionary as input, and gives the agent
|
||||||
a decision of whether read a token from server. Moreover, the decoder
|
a decision of whether read a token from server. Moreover, the decoder
|
||||||
states are also calculated here so we can directly generate a target
|
states are also calculated here so we can directly generate a target
|
||||||
token without recompute every thing
|
token without recompute every thing
|
||||||
'''
|
"""
|
||||||
|
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
if len(states["tokens"]["src"]) == 0:
|
if len(states["tokens"]["src"]) == 0:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
src_indices, src_lengths, tgt_indices = self._indices_from_states(
|
src_indices, src_lengths, tgt_indices = self._indices_from_states(states)
|
||||||
states)
|
|
||||||
|
|
||||||
# Update encoder states if needed
|
# Update encoder states if needed
|
||||||
if (
|
if (
|
||||||
"encoder_states" not in states or
|
"encoder_states" not in states
|
||||||
states["encoder_states"][0].size(1) <= states["steps"]["src"]
|
or states["encoder_states"][0].size(1) <= states["steps"]["src"]
|
||||||
):
|
):
|
||||||
encoder_out_dict = self.encoder(src_indices, src_lengths)
|
encoder_out_dict = self.encoder(src_indices, src_lengths)
|
||||||
states["encoder_states"] = encoder_out_dict
|
states["encoder_states"] = encoder_out_dict
|
||||||
@ -136,16 +118,14 @@ class TransformerMonotonicModel(TransformerModel):
|
|||||||
|
|
||||||
|
|
||||||
class TransformerMonotonicEncoder(TransformerEncoder):
|
class TransformerMonotonicEncoder(TransformerEncoder):
|
||||||
|
|
||||||
def __init__(self, args, dictionary, embed_tokens):
|
def __init__(self, args, dictionary, embed_tokens):
|
||||||
super().__init__(args, dictionary, embed_tokens)
|
super().__init__(args, dictionary, embed_tokens)
|
||||||
|
|
||||||
self.dictionary = dictionary
|
self.dictionary = dictionary
|
||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
self.layers.extend([
|
self.layers.extend(
|
||||||
TransformerMonotonicEncoderLayer(args)
|
[TransformerMonotonicEncoderLayer(args) for i in range(args.encoder_layers)]
|
||||||
for i in range(args.encoder_layers)
|
)
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerMonotonicDecoder(TransformerDecoder):
|
class TransformerMonotonicDecoder(TransformerDecoder):
|
||||||
@ -166,19 +146,24 @@ class TransformerMonotonicDecoder(TransformerDecoder):
|
|||||||
|
|
||||||
self.dictionary = dictionary
|
self.dictionary = dictionary
|
||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
self.layers.extend([
|
self.layers.extend(
|
||||||
TransformerMonotonicDecoderLayer(args, no_encoder_attn)
|
[
|
||||||
for _ in range(args.decoder_layers)
|
TransformerMonotonicDecoderLayer(args, no_encoder_attn)
|
||||||
])
|
for _ in range(args.decoder_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def pre_attention(
|
def pre_attention(
|
||||||
self, prev_output_tokens, encoder_out_dict,
|
self, prev_output_tokens, encoder_out_dict, incremental_state=None
|
||||||
incremental_state=None
|
|
||||||
):
|
):
|
||||||
positions = self.embed_positions(
|
positions = (
|
||||||
prev_output_tokens,
|
self.embed_positions(
|
||||||
incremental_state=incremental_state,
|
prev_output_tokens,
|
||||||
) if self.embed_positions is not None else None
|
incremental_state=incremental_state,
|
||||||
|
)
|
||||||
|
if self.embed_positions is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
if incremental_state is not None:
|
if incremental_state is not None:
|
||||||
prev_output_tokens = prev_output_tokens[:, -1:]
|
prev_output_tokens = prev_output_tokens[:, -1:]
|
||||||
@ -216,8 +201,7 @@ class TransformerMonotonicDecoder(TransformerDecoder):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def extract_features(
|
def extract_features(
|
||||||
self, prev_output_tokens, encoder_out,
|
self, prev_output_tokens, encoder_out, incremental_state=None, **unused
|
||||||
incremental_state=None, **unused
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Similar to *forward* but only return features.
|
Similar to *forward* but only return features.
|
||||||
@ -228,14 +212,8 @@ class TransformerMonotonicDecoder(TransformerDecoder):
|
|||||||
- a dictionary with any model-specific outputs
|
- a dictionary with any model-specific outputs
|
||||||
"""
|
"""
|
||||||
# incremental_state = None
|
# incremental_state = None
|
||||||
(
|
(x, encoder_outs, encoder_padding_mask) = self.pre_attention(
|
||||||
x,
|
prev_output_tokens, encoder_out, incremental_state
|
||||||
encoder_outs,
|
|
||||||
encoder_padding_mask
|
|
||||||
) = self.pre_attention(
|
|
||||||
prev_output_tokens,
|
|
||||||
encoder_out,
|
|
||||||
incremental_state
|
|
||||||
)
|
)
|
||||||
attn = None
|
attn = None
|
||||||
inner_states = [x]
|
inner_states = [x]
|
||||||
@ -250,7 +228,8 @@ class TransformerMonotonicDecoder(TransformerDecoder):
|
|||||||
encoder_padding_mask=encoder_padding_mask,
|
encoder_padding_mask=encoder_padding_mask,
|
||||||
incremental_state=incremental_state,
|
incremental_state=incremental_state,
|
||||||
self_attn_mask=self.buffered_future_mask(x)
|
self_attn_mask=self.buffered_future_mask(x)
|
||||||
if incremental_state is None else None,
|
if incremental_state is None
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
inner_states.append(x)
|
inner_states.append(x)
|
||||||
@ -261,38 +240,30 @@ class TransformerMonotonicDecoder(TransformerDecoder):
|
|||||||
step_list.append(curr_steps)
|
step_list.append(curr_steps)
|
||||||
|
|
||||||
if incremental_state.get("online", False):
|
if incremental_state.get("online", False):
|
||||||
p_choose = attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
|
p_choose = (
|
||||||
|
attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
|
||||||
new_steps = (
|
|
||||||
curr_steps
|
|
||||||
+ (p_choose < 0.5).t().type_as(curr_steps)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)
|
||||||
|
|
||||||
if (new_steps >= incremental_state["steps"]["src"]).any():
|
if (new_steps >= incremental_state["steps"]["src"]).any():
|
||||||
# We need to prune the last self_attn saved_state
|
# We need to prune the last self_attn saved_state
|
||||||
# if model decide not to read
|
# if model decide not to read
|
||||||
# otherwise there will be duplicated saved_state
|
# otherwise there will be duplicated saved_state
|
||||||
for j in range(i + 1):
|
for j in range(i + 1):
|
||||||
self.layers[j].prune_incremental_state(
|
self.layers[j].prune_incremental_state(incremental_state)
|
||||||
incremental_state)
|
|
||||||
|
|
||||||
return x, {"action": 0}
|
return x, {"action": 0}
|
||||||
|
|
||||||
if (
|
if incremental_state is not None and not incremental_state.get("online", False):
|
||||||
incremental_state is not None
|
|
||||||
and not incremental_state.get("online", False)
|
|
||||||
):
|
|
||||||
# Here is for fast evaluation
|
# Here is for fast evaluation
|
||||||
fastest_step = torch.max(
|
fastest_step = (
|
||||||
torch.cat(step_list, dim=1),
|
torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
|
||||||
dim=1,
|
)
|
||||||
keepdim=True
|
|
||||||
)[0] + 1
|
|
||||||
|
|
||||||
if "fastest_step" in incremental_state:
|
if "fastest_step" in incremental_state:
|
||||||
incremental_state["fastest_step"] = torch.cat(
|
incremental_state["fastest_step"] = torch.cat(
|
||||||
[incremental_state["fastest_step"], fastest_step],
|
[incremental_state["fastest_step"], fastest_step], dim=1
|
||||||
dim=1
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
incremental_state["fastest_step"] = fastest_step
|
incremental_state["fastest_step"] = fastest_step
|
||||||
@ -310,25 +281,19 @@ class TransformerMonotonicDecoder(TransformerDecoder):
|
|||||||
def reorder_incremental_state(self, incremental_state, new_order):
|
def reorder_incremental_state(self, incremental_state, new_order):
|
||||||
super().reorder_incremental_state(incremental_state, new_order)
|
super().reorder_incremental_state(incremental_state, new_order)
|
||||||
if "fastest_step" in incremental_state:
|
if "fastest_step" in incremental_state:
|
||||||
incremental_state["fastest_step"] = (
|
incremental_state["fastest_step"] = incremental_state[
|
||||||
incremental_state["fastest_step"]
|
"fastest_step"
|
||||||
.index_select(0, new_order)
|
].index_select(0, new_order)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture(
|
@register_model_architecture("transformer_monotonic", "transformer_monotonic")
|
||||||
'transformer_monotonic',
|
|
||||||
'transformer_monotonic'
|
|
||||||
)
|
|
||||||
def base_monotonic_rchitecture(args):
|
def base_monotonic_rchitecture(args):
|
||||||
base_architecture(args)
|
base_architecture(args)
|
||||||
args.encoder_unidirectional = getattr(
|
args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)
|
||||||
args, 'encoder_unidirectional', False)
|
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture(
|
@register_model_architecture(
|
||||||
'transformer_monotonic',
|
"transformer_monotonic", "transformer_monotonic_iwslt_de_en"
|
||||||
'transformer_monotonic_iwslt_de_en'
|
|
||||||
)
|
)
|
||||||
def transformer_monotonic_iwslt_de_en(args):
|
def transformer_monotonic_iwslt_de_en(args):
|
||||||
transformer_iwslt_de_en(args)
|
transformer_iwslt_de_en(args)
|
||||||
@ -337,24 +302,21 @@ def transformer_monotonic_iwslt_de_en(args):
|
|||||||
|
|
||||||
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
|
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
|
||||||
@register_model_architecture(
|
@register_model_architecture(
|
||||||
'transformer_monotonic',
|
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
|
||||||
'transformer_monotonic_vaswani_wmt_en_de_big'
|
|
||||||
)
|
)
|
||||||
def transformer_monotonic_vaswani_wmt_en_de_big(args):
|
def transformer_monotonic_vaswani_wmt_en_de_big(args):
|
||||||
transformer_vaswani_wmt_en_de_big(args)
|
transformer_vaswani_wmt_en_de_big(args)
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture(
|
@register_model_architecture(
|
||||||
'transformer_monotonic',
|
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
|
||||||
'transformer_monotonic_vaswani_wmt_en_fr_big'
|
|
||||||
)
|
)
|
||||||
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
|
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
|
||||||
transformer_monotonic_vaswani_wmt_en_fr_big(args)
|
transformer_monotonic_vaswani_wmt_en_fr_big(args)
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture(
|
@register_model_architecture(
|
||||||
'transformer_unidirectional',
|
"transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
|
||||||
'transformer_unidirectional_iwslt_de_en'
|
|
||||||
)
|
)
|
||||||
def transformer_unidirectional_iwslt_de_en(args):
|
def transformer_unidirectional_iwslt_de_en(args):
|
||||||
transformer_iwslt_de_en(args)
|
transformer_iwslt_de_en(args)
|
||||||
|
@ -7,14 +7,18 @@ import importlib
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from fairseq import registry
|
from fairseq import registry
|
||||||
|
|
||||||
|
|
||||||
(
|
(
|
||||||
build_monotonic_attention,
|
build_monotonic_attention,
|
||||||
register_monotonic_attention,
|
register_monotonic_attention,
|
||||||
MONOTONIC_ATTENTION_REGISTRY,
|
MONOTONIC_ATTENTION_REGISTRY,
|
||||||
_
|
_,
|
||||||
) = registry.setup_registry('--simul-type')
|
) = registry.setup_registry("--simul-type")
|
||||||
|
|
||||||
for file in os.listdir(os.path.dirname(__file__)):
|
for file in os.listdir(os.path.dirname(__file__)):
|
||||||
if file.endswith('.py') and not file.startswith('_'):
|
if file.endswith(".py") and not file.startswith("_"):
|
||||||
model_name = file[:file.find('.py')]
|
model_name = file[: file.find(".py")]
|
||||||
importlib.import_module('examples.simultaneous_translation.modules.' + model_name)
|
importlib.import_module(
|
||||||
|
"examples.simultaneous_translation.modules." + model_name
|
||||||
|
)
|
||||||
|
@ -4,22 +4,19 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from fairseq import utils
|
|
||||||
|
|
||||||
from fairseq.modules import MultiheadAttention
|
|
||||||
|
|
||||||
from examples.simultaneous_translation.utils.functions import (
|
from examples.simultaneous_translation.utils.functions import (
|
||||||
exclusive_cumprod,
|
exclusive_cumprod,
|
||||||
lengths_to_mask
|
lengths_to_mask,
|
||||||
)
|
)
|
||||||
|
from fairseq import utils
|
||||||
|
|
||||||
from fairseq.incremental_decoding_utils import with_incremental_state
|
from fairseq.incremental_decoding_utils import with_incremental_state
|
||||||
|
from fairseq.modules import MultiheadAttention
|
||||||
from fairseq.utils import convert_padding_direction
|
from fairseq.utils import convert_padding_direction
|
||||||
|
|
||||||
from . import register_monotonic_attention
|
from . import register_monotonic_attention
|
||||||
|
|
||||||
|
|
||||||
@ -28,6 +25,7 @@ class MonotonicAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Abstract class of monotonic attentions
|
Abstract class of monotonic attentions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.eps = args.attention_eps
|
self.eps = args.attention_eps
|
||||||
self.mass_preservation = args.mass_preservation
|
self.mass_preservation = args.mass_preservation
|
||||||
@ -38,7 +36,8 @@ class MonotonicAttention(nn.Module):
|
|||||||
self.energy_bias_init = args.energy_bias_init
|
self.energy_bias_init = args.energy_bias_init
|
||||||
self.energy_bias = (
|
self.energy_bias = (
|
||||||
nn.Parameter(self.energy_bias_init * torch.ones([1]))
|
nn.Parameter(self.energy_bias_init * torch.ones([1]))
|
||||||
if args.energy_bias is True else 0
|
if args.energy_bias is True
|
||||||
|
else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -90,7 +89,7 @@ class MonotonicAttention(nn.Module):
|
|||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None:
|
||||||
attn_energy = attn_energy.masked_fill(
|
attn_energy = attn_energy.masked_fill(
|
||||||
key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
|
key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
|
||||||
float('-inf'),
|
float("-inf"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return attn_energy
|
return attn_energy
|
||||||
@ -131,10 +130,7 @@ class MonotonicAttention(nn.Module):
|
|||||||
alpha_i = (
|
alpha_i = (
|
||||||
p_choose[:, i]
|
p_choose[:, i]
|
||||||
* cumprod_1mp[:, i]
|
* cumprod_1mp[:, i]
|
||||||
* torch.cumsum(
|
* torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
|
||||||
previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i],
|
|
||||||
dim=1
|
|
||||||
)
|
|
||||||
).clamp(0, 1.0)
|
).clamp(0, 1.0)
|
||||||
previous_attn.append(alpha_i.unsqueeze(1))
|
previous_attn.append(alpha_i.unsqueeze(1))
|
||||||
|
|
||||||
@ -170,8 +166,7 @@ class MonotonicAttention(nn.Module):
|
|||||||
# prev_monotonic_step: bsz, num_heads
|
# prev_monotonic_step: bsz, num_heads
|
||||||
bsz = bsz_num_heads // self.num_heads
|
bsz = bsz_num_heads // self.num_heads
|
||||||
prev_monotonic_step = monotonic_cache.get(
|
prev_monotonic_step = monotonic_cache.get(
|
||||||
"step",
|
"step", p_choose.new_zeros([bsz, self.num_heads]).long()
|
||||||
p_choose.new_zeros([bsz, self.num_heads]).long()
|
|
||||||
)
|
)
|
||||||
bsz, num_heads = prev_monotonic_step.size()
|
bsz, num_heads = prev_monotonic_step.size()
|
||||||
assert num_heads == self.num_heads
|
assert num_heads == self.num_heads
|
||||||
@ -181,8 +176,7 @@ class MonotonicAttention(nn.Module):
|
|||||||
p_choose = p_choose.view(bsz, num_heads, src_len)
|
p_choose = p_choose.view(bsz, num_heads, src_len)
|
||||||
|
|
||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None:
|
||||||
src_lengths = src_len - \
|
src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long()
|
||||||
key_padding_mask.sum(dim=1, keepdim=True).long()
|
|
||||||
else:
|
else:
|
||||||
src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len
|
src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len
|
||||||
|
|
||||||
@ -197,10 +191,7 @@ class MonotonicAttention(nn.Module):
|
|||||||
# left_pad_source = True:
|
# left_pad_source = True:
|
||||||
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
|
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
max_steps = (
|
max_steps = src_lengths - 1 if self.mass_preservation else src_lengths
|
||||||
src_lengths - 1 if self.mass_preservation
|
|
||||||
else src_lengths
|
|
||||||
)
|
|
||||||
|
|
||||||
# finish_read: bsz, num_heads
|
# finish_read: bsz, num_heads
|
||||||
finish_read = new_monotonic_step.eq(max_steps)
|
finish_read = new_monotonic_step.eq(max_steps)
|
||||||
@ -210,11 +201,11 @@ class MonotonicAttention(nn.Module):
|
|||||||
# only choose the p at monotonic steps
|
# only choose the p at monotonic steps
|
||||||
# p_choose_i: bsz , self.num_heads
|
# p_choose_i: bsz , self.num_heads
|
||||||
p_choose_i = (
|
p_choose_i = (
|
||||||
p_choose
|
p_choose.gather(
|
||||||
.gather(
|
|
||||||
2,
|
2,
|
||||||
(step_offset + new_monotonic_step).unsqueeze(2)
|
(step_offset + new_monotonic_step)
|
||||||
.clamp(0, src_len - 1)
|
.unsqueeze(2)
|
||||||
|
.clamp(0, src_len - 1),
|
||||||
)
|
)
|
||||||
).squeeze(2)
|
).squeeze(2)
|
||||||
|
|
||||||
@ -239,21 +230,17 @@ class MonotonicAttention(nn.Module):
|
|||||||
|
|
||||||
# alpha: bsz * num_heads, 1, src_len
|
# alpha: bsz * num_heads, 1, src_len
|
||||||
# new_monotonic_step: bsz, num_heads
|
# new_monotonic_step: bsz, num_heads
|
||||||
alpha = (
|
alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter(
|
||||||
p_choose
|
1,
|
||||||
.new_zeros([bsz * self.num_heads, src_len])
|
(step_offset + new_monotonic_step)
|
||||||
.scatter(
|
.view(bsz * self.num_heads, 1)
|
||||||
1,
|
.clamp(0, src_len - 1),
|
||||||
(step_offset + new_monotonic_step).view(bsz *
|
1,
|
||||||
self.num_heads, 1).clamp(0, src_len - 1),
|
|
||||||
1
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.mass_preservation:
|
if not self.mass_preservation:
|
||||||
alpha = alpha.masked_fill(
|
alpha = alpha.masked_fill(
|
||||||
(new_monotonic_step == max_steps).view(bsz * self.num_heads, 1),
|
(new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0
|
||||||
0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
alpha = alpha.unsqueeze(1)
|
alpha = alpha.unsqueeze(1)
|
||||||
@ -266,8 +253,14 @@ class MonotonicAttention(nn.Module):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, query, key, value,
|
self,
|
||||||
key_padding_mask=None, incremental_state=None, *args, **kwargs,
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_padding_mask=None,
|
||||||
|
incremental_state=None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
tgt_len, bsz, embed_dim = query.size()
|
tgt_len, bsz, embed_dim = query.size()
|
||||||
@ -280,25 +273,24 @@ class MonotonicAttention(nn.Module):
|
|||||||
# expected alignment alpha
|
# expected alignment alpha
|
||||||
# bsz * self.num_heads, tgt_len, src_len
|
# bsz * self.num_heads, tgt_len, src_len
|
||||||
if incremental_state is not None:
|
if incremental_state is not None:
|
||||||
alpha = self.expected_alignment_infer(p_choose, key_padding_mask, incremental_state)
|
alpha = self.expected_alignment_infer(
|
||||||
|
p_choose, key_padding_mask, incremental_state
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
alpha = self.expected_alignment_train(p_choose, key_padding_mask)
|
alpha = self.expected_alignment_train(p_choose, key_padding_mask)
|
||||||
|
|
||||||
# expected attention beta
|
# expected attention beta
|
||||||
# bsz * self.num_heads, tgt_len, src_len
|
# bsz * self.num_heads, tgt_len, src_len
|
||||||
beta = self.expected_attention(alpha, query, key, value, key_padding_mask, incremental_state)
|
beta = self.expected_attention(
|
||||||
|
alpha, query, key, value, key_padding_mask, incremental_state
|
||||||
|
)
|
||||||
|
|
||||||
attn_weights = beta
|
attn_weights = beta
|
||||||
|
|
||||||
v_proj = self.v_proj_output(value)
|
v_proj = self.v_proj_output(value)
|
||||||
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
|
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
|
||||||
|
|
||||||
attn = (
|
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||||
attn
|
|
||||||
.transpose(0, 1)
|
|
||||||
.contiguous()
|
|
||||||
.view(tgt_len, bsz, embed_dim)
|
|
||||||
)
|
|
||||||
|
|
||||||
attn = self.out_proj(attn)
|
attn = self.out_proj(attn)
|
||||||
|
|
||||||
@ -318,26 +310,32 @@ class MonotonicAttention(nn.Module):
|
|||||||
self._set_monotonic_buffer(incremental_state, input_buffer)
|
self._set_monotonic_buffer(incremental_state, input_buffer)
|
||||||
|
|
||||||
def _get_monotonic_buffer(self, incremental_state):
|
def _get_monotonic_buffer(self, incremental_state):
|
||||||
return utils.get_incremental_state(
|
return (
|
||||||
self,
|
utils.get_incremental_state(
|
||||||
incremental_state,
|
self,
|
||||||
'monotonic',
|
incremental_state,
|
||||||
) or {}
|
"monotonic",
|
||||||
|
)
|
||||||
|
or {}
|
||||||
|
)
|
||||||
|
|
||||||
def _set_monotonic_buffer(self, incremental_state, buffer):
|
def _set_monotonic_buffer(self, incremental_state, buffer):
|
||||||
utils.set_incremental_state(
|
utils.set_incremental_state(
|
||||||
self,
|
self,
|
||||||
incremental_state,
|
incremental_state,
|
||||||
'monotonic',
|
"monotonic",
|
||||||
buffer,
|
buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_pointer(self, incremental_state):
|
def get_pointer(self, incremental_state):
|
||||||
return utils.get_incremental_state(
|
return (
|
||||||
self,
|
utils.get_incremental_state(
|
||||||
incremental_state,
|
self,
|
||||||
'monotonic',
|
incremental_state,
|
||||||
) or {}
|
"monotonic",
|
||||||
|
)
|
||||||
|
or {}
|
||||||
|
)
|
||||||
|
|
||||||
def get_fastest_pointer(self, incremental_state):
|
def get_fastest_pointer(self, incremental_state):
|
||||||
return self.get_pointer(incremental_state)["step"].max(0)[0]
|
return self.get_pointer(incremental_state)["step"].max(0)[0]
|
||||||
@ -354,23 +352,22 @@ class MonotonicAttention(nn.Module):
|
|||||||
utils.set_incremental_state(
|
utils.set_incremental_state(
|
||||||
self,
|
self,
|
||||||
incremental_state,
|
incremental_state,
|
||||||
'monotonic',
|
"monotonic",
|
||||||
{"step": buffer},
|
{"step": buffer},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_monotonic_attention("hard_aligned")
|
@register_monotonic_attention("hard_aligned")
|
||||||
class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
|
class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
|
||||||
|
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
MultiheadAttention.__init__(
|
MultiheadAttention.__init__(
|
||||||
self,
|
self,
|
||||||
embed_dim=args.decoder_embed_dim,
|
embed_dim=args.decoder_embed_dim,
|
||||||
num_heads=args.decoder_attention_heads,
|
num_heads=args.decoder_attention_heads,
|
||||||
kdim=getattr(args, 'encoder_embed_dim', None),
|
kdim=getattr(args, "encoder_embed_dim", None),
|
||||||
vdim=getattr(args, 'encoder_embed_dim', None),
|
vdim=getattr(args, "encoder_embed_dim", None),
|
||||||
dropout=args.attention_dropout,
|
dropout=args.attention_dropout,
|
||||||
encoder_decoder_attention=True
|
encoder_decoder_attention=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
MonotonicAttention.__init__(self, args)
|
MonotonicAttention.__init__(self, args)
|
||||||
@ -395,21 +392,33 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
|
|||||||
bsz = query.size(1)
|
bsz = query.size(1)
|
||||||
q = self.q_in_proj[name](query)
|
q = self.q_in_proj[name](query)
|
||||||
q *= self.scaling
|
q *= self.scaling
|
||||||
q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
q = (
|
||||||
|
q.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
q = None
|
q = None
|
||||||
|
|
||||||
if key is not None:
|
if key is not None:
|
||||||
bsz = key.size(1)
|
bsz = key.size(1)
|
||||||
k = self.k_in_proj[name](key)
|
k = self.k_in_proj[name](key)
|
||||||
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
k = (
|
||||||
|
k.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
k = None
|
k = None
|
||||||
|
|
||||||
if value is not None:
|
if value is not None:
|
||||||
bsz = value.size(1)
|
bsz = value.size(1)
|
||||||
v = self.v_in_proj[name](value)
|
v = self.v_in_proj[name](value)
|
||||||
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
v = (
|
||||||
|
v.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
v = None
|
v = None
|
||||||
|
|
||||||
@ -441,8 +450,7 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
|
|||||||
if self.training:
|
if self.training:
|
||||||
# add noise here to encourage discretness
|
# add noise here to encourage discretness
|
||||||
noise = (
|
noise = (
|
||||||
torch
|
torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
|
||||||
.normal(self.noise_mean, self.noise_var, attn_energy.size())
|
|
||||||
.type_as(attn_energy)
|
.type_as(attn_energy)
|
||||||
.to(attn_energy.device)
|
.to(attn_energy.device)
|
||||||
)
|
)
|
||||||
@ -454,9 +462,9 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
|
|||||||
return p_choose.view(-1, tgt_len, src_len)
|
return p_choose.view(-1, tgt_len, src_len)
|
||||||
|
|
||||||
def expected_attention(self, alpha, *args):
|
def expected_attention(self, alpha, *args):
|
||||||
'''
|
"""
|
||||||
For MMA-H, beta = alpha
|
For MMA-H, beta = alpha
|
||||||
'''
|
"""
|
||||||
return alpha
|
return alpha
|
||||||
|
|
||||||
def v_proj_output(self, value):
|
def v_proj_output(self, value):
|
||||||
@ -479,13 +487,19 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
|
|||||||
if self.qkv_same_dim:
|
if self.qkv_same_dim:
|
||||||
# Empirically observed the convergence to be much better with
|
# Empirically observed the convergence to be much better with
|
||||||
# the scaled initialization
|
# the scaled initialization
|
||||||
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2))
|
nn.init.xavier_uniform_(
|
||||||
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2))
|
self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
|
||||||
|
)
|
||||||
|
nn.init.xavier_uniform_(
|
||||||
|
self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
|
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
|
||||||
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
|
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
|
||||||
|
|
||||||
def expected_attention(self, alpha, query, key, value, key_padding_mask, incremental_state):
|
def expected_attention(
|
||||||
|
self, alpha, query, key, value, key_padding_mask, incremental_state
|
||||||
|
):
|
||||||
# monotonic attention, we will calculate milk here
|
# monotonic attention, we will calculate milk here
|
||||||
bsz_x_num_heads, tgt_len, src_len = alpha.size()
|
bsz_x_num_heads, tgt_len, src_len = alpha.size()
|
||||||
bsz = int(bsz_x_num_heads / self.num_heads)
|
bsz = int(bsz_x_num_heads / self.num_heads)
|
||||||
@ -507,9 +521,10 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
|
|||||||
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
|
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
|
||||||
monotonic_step += step_offset
|
monotonic_step += step_offset
|
||||||
mask = lengths_to_mask(
|
mask = lengths_to_mask(
|
||||||
monotonic_step.view(-1), soft_energy.size(2), 1).unsqueeze(1)
|
monotonic_step.view(-1), soft_energy.size(2), 1
|
||||||
|
).unsqueeze(1)
|
||||||
|
|
||||||
soft_energy = soft_energy.masked_fill(~ mask.bool(), float('-inf'))
|
soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
|
||||||
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
|
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
|
||||||
exp_soft_energy = torch.exp(soft_energy)
|
exp_soft_energy = torch.exp(soft_energy)
|
||||||
exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
|
exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
|
||||||
@ -524,14 +539,20 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
|
|||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None:
|
||||||
if key_padding_mask.any():
|
if key_padding_mask.any():
|
||||||
exp_soft_energy_cumsum = (
|
exp_soft_energy_cumsum = (
|
||||||
exp_soft_energy_cumsum.view(-1, self.num_heads, tgt_len, src_len)
|
exp_soft_energy_cumsum.view(
|
||||||
.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps)
|
-1, self.num_heads, tgt_len, src_len
|
||||||
|
)
|
||||||
|
.masked_fill(
|
||||||
|
key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
|
||||||
|
)
|
||||||
.view(-1, tgt_len, src_len)
|
.view(-1, tgt_len, src_len)
|
||||||
)
|
)
|
||||||
|
|
||||||
inner_items = alpha / exp_soft_energy_cumsum
|
inner_items = alpha / exp_soft_energy_cumsum
|
||||||
|
|
||||||
beta = exp_soft_energy * torch.cumsum(inner_items.flip(dims=[2]), dim=2).flip(dims=[2])
|
beta = exp_soft_energy * torch.cumsum(
|
||||||
|
inner_items.flip(dims=[2]), dim=2
|
||||||
|
).flip(dims=[2])
|
||||||
|
|
||||||
beta = self.dropout_module(beta)
|
beta = self.dropout_module(beta)
|
||||||
|
|
||||||
@ -547,7 +568,9 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
|
|||||||
self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
|
self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
|
||||||
self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
|
self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
|
||||||
self.waitk_lagging = args.waitk_lagging
|
self.waitk_lagging = args.waitk_lagging
|
||||||
assert self.waitk_lagging > 0, f"Lagging has to been larger than 0, get {self.waitk_lagging}."
|
assert (
|
||||||
|
self.waitk_lagging > 0
|
||||||
|
), f"Lagging has to been larger than 0, get {self.waitk_lagging}."
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
@ -556,10 +579,13 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
|
|||||||
MonotonicMultiheadAttentionWaitk,
|
MonotonicMultiheadAttentionWaitk,
|
||||||
).add_args(parser)
|
).add_args(parser)
|
||||||
|
|
||||||
parser.add_argument('--waitk-lagging', type=int, required=True,
|
parser.add_argument(
|
||||||
help='Wait k lagging')
|
"--waitk-lagging", type=int, required=True, help="Wait k lagging"
|
||||||
|
)
|
||||||
|
|
||||||
def p_choose(self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None):
|
def p_choose(
|
||||||
|
self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
query: bsz, tgt_len
|
query: bsz, tgt_len
|
||||||
key: bsz, src_len
|
key: bsz, src_len
|
||||||
@ -574,16 +600,22 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
|
|||||||
if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
|
if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
|
||||||
# Left pad source
|
# Left pad source
|
||||||
# add -1 to the end
|
# add -1 to the end
|
||||||
p_choose = p_choose.masked_fill(key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1)
|
p_choose = p_choose.masked_fill(
|
||||||
p_choose = convert_padding_direction(p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True)
|
key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1
|
||||||
|
)
|
||||||
|
p_choose = convert_padding_direction(
|
||||||
|
p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True
|
||||||
|
)
|
||||||
p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)
|
p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)
|
||||||
# remove -1
|
# remove -1
|
||||||
p_choose[p_choose.eq(-1)] = 0
|
p_choose[p_choose.eq(-1)] = 0
|
||||||
|
|
||||||
# Extend to each head
|
# Extend to each head
|
||||||
p_choose = (
|
p_choose = (
|
||||||
p_choose.contiguous().unsqueeze(1)
|
p_choose.contiguous()
|
||||||
.expand(-1, self.num_heads, -1, -1).contiguous()
|
.unsqueeze(1)
|
||||||
|
.expand(-1, self.num_heads, -1, -1)
|
||||||
|
.contiguous()
|
||||||
.view(-1, tgt_len, src_len)
|
.view(-1, tgt_len, src_len)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -3,37 +3,32 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from fairseq.modules import (
|
from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer
|
||||||
LayerNorm,
|
|
||||||
TransformerEncoderLayer,
|
|
||||||
TransformerDecoderLayer
|
|
||||||
)
|
|
||||||
|
|
||||||
from . import build_monotonic_attention
|
from . import build_monotonic_attention
|
||||||
|
|
||||||
|
|
||||||
class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
|
class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
|
||||||
|
|
||||||
def forward(self, x, encoder_padding_mask):
|
def forward(self, x, encoder_padding_mask):
|
||||||
seq_len, _, _ = x.size()
|
seq_len, _, _ = x.size()
|
||||||
attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
|
attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
|
||||||
attn_mask = attn_mask.masked_fill(attn_mask.bool(), float('-inf'))
|
attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
|
||||||
return super().forward(x, encoder_padding_mask, attn_mask)
|
return super().forward(x, encoder_padding_mask, attn_mask)
|
||||||
|
|
||||||
|
|
||||||
class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
|
class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
|
||||||
|
def __init__(
|
||||||
def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
|
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
|
||||||
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
args,
|
args,
|
||||||
no_encoder_attn=True,
|
no_encoder_attn=True,
|
||||||
add_bias_kv=add_bias_kv,
|
add_bias_kv=add_bias_kv,
|
||||||
add_zero_attn=add_zero_attn
|
add_zero_attn=add_zero_attn,
|
||||||
)
|
)
|
||||||
self.encoder_attn = build_monotonic_attention(args)
|
self.encoder_attn = build_monotonic_attention(args)
|
||||||
self.encoder_attn_layer_norm = LayerNorm(
|
self.encoder_attn_layer_norm = LayerNorm(
|
||||||
self.embed_dim,
|
self.embed_dim, export=getattr(args, "char_inputs", False)
|
||||||
export=getattr(args, 'char_inputs', False)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def prune_incremental_state(self, incremental_state):
|
def prune_incremental_state(self, incremental_state):
|
||||||
@ -46,12 +41,8 @@ class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
|
|||||||
input_buffer = {}
|
input_buffer = {}
|
||||||
break
|
break
|
||||||
module._set_input_buffer(incremental_state, input_buffer)
|
module._set_input_buffer(incremental_state, input_buffer)
|
||||||
|
|
||||||
prune(self.self_attn)
|
prune(self.self_attn)
|
||||||
|
|
||||||
def get_steps(self, incremental_state):
|
def get_steps(self, incremental_state):
|
||||||
return (
|
return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0)
|
||||||
self.encoder_attn
|
|
||||||
._get_monotonic_buffer(
|
|
||||||
incremental_state
|
|
||||||
).get("step", 0)
|
|
||||||
)
|
|
||||||
|
@ -9,6 +9,6 @@ import os
|
|||||||
|
|
||||||
# automatically import any Python files in the criterions/ directory
|
# automatically import any Python files in the criterions/ directory
|
||||||
for file in os.listdir(os.path.dirname(__file__)):
|
for file in os.listdir(os.path.dirname(__file__)):
|
||||||
if file.endswith('.py') and not file.startswith('_'):
|
if file.endswith(".py") and not file.startswith("_"):
|
||||||
module = file[:file.find('.py')]
|
module = file[: file.find(".py")]
|
||||||
importlib.import_module('examples.simultaneous_translation.utils.' + module)
|
importlib.import_module("examples.simultaneous_translation.utils." + module)
|
||||||
|
@ -16,7 +16,9 @@ def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
|
|||||||
tensor_size = list(tensor.size())
|
tensor_size = list(tensor.size())
|
||||||
tensor_size[dim] = 1
|
tensor_size[dim] = 1
|
||||||
return_tensor = safe_cumprod(
|
return_tensor = safe_cumprod(
|
||||||
torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim), dim=dim, eps=eps
|
torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim),
|
||||||
|
dim=dim,
|
||||||
|
eps=eps,
|
||||||
)
|
)
|
||||||
|
|
||||||
if dim == 0:
|
if dim == 0:
|
||||||
@ -132,12 +134,14 @@ def moving_sum(x, start_idx: int, end_idx: int):
|
|||||||
# batch_size, 1, src_len
|
# batch_size, 1, src_len
|
||||||
moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1])
|
moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1])
|
||||||
|
|
||||||
moving_sum = torch.nn.functional.conv1d(
|
moving_sum = (
|
||||||
x,
|
torch.nn.functional.conv1d(
|
||||||
moving_sum_weight,
|
x, moving_sum_weight, padding=start_idx + end_idx - 1
|
||||||
padding=start_idx + end_idx - 1
|
)
|
||||||
).squeeze(1).t()
|
.squeeze(1)
|
||||||
moving_sum = moving_sum[end_idx: -start_idx]
|
.t()
|
||||||
|
)
|
||||||
|
moving_sum = moving_sum[end_idx:-start_idx]
|
||||||
|
|
||||||
assert src_len == moving_sum.size(0)
|
assert src_len == moving_sum.size(0)
|
||||||
assert batch_size == moving_sum.size(1)
|
assert batch_size == moving_sum.size(1)
|
||||||
|
@ -18,7 +18,7 @@ class LatencyMetric(object):
|
|||||||
src_lens,
|
src_lens,
|
||||||
target_padding_mask=None,
|
target_padding_mask=None,
|
||||||
batch_first: bool = False,
|
batch_first: bool = False,
|
||||||
start_from_zero: bool = True
|
start_from_zero: bool = True,
|
||||||
):
|
):
|
||||||
assert len(delays.size()) == 2
|
assert len(delays.size()) == 2
|
||||||
assert len(src_lens.size()) == 2
|
assert len(src_lens.size()) == 2
|
||||||
@ -59,11 +59,7 @@ class LatencyMetric(object):
|
|||||||
start_from_zero: bool = True,
|
start_from_zero: bool = True,
|
||||||
):
|
):
|
||||||
delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
|
delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
|
||||||
delays,
|
delays, src_lens, target_padding_mask, batch_first, start_from_zero
|
||||||
src_lens,
|
|
||||||
target_padding_mask,
|
|
||||||
batch_first,
|
|
||||||
start_from_zero
|
|
||||||
)
|
)
|
||||||
return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)
|
return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)
|
||||||
|
|
||||||
@ -89,10 +85,13 @@ class AverageProportion(LatencyMetric):
|
|||||||
|
|
||||||
AP = 1 / (|x||y]) sum_i^|Y| deleys_i
|
AP = 1 / (|x||y]) sum_i^|Y| deleys_i
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
|
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
|
||||||
if target_padding_mask is not None:
|
if target_padding_mask is not None:
|
||||||
AP = torch.sum(delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True)
|
AP = torch.sum(
|
||||||
|
delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
AP = torch.sum(delays, dim=0, keepdim=True)
|
AP = torch.sum(delays, dim=0, keepdim=True)
|
||||||
|
|
||||||
@ -116,14 +115,24 @@ class AverageLagging(LatencyMetric):
|
|||||||
gamma = |y| / |x|
|
gamma = |y| / |x|
|
||||||
tau = argmin_i(delays_i = |x|)
|
tau = argmin_i(delays_i = |x|)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
|
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
|
||||||
# tau = argmin_i(delays_i = |x|)
|
# tau = argmin_i(delays_i = |x|)
|
||||||
tgt_len, bsz = delays.size()
|
tgt_len, bsz = delays.size()
|
||||||
lagging_padding_mask = delays >= src_lens
|
lagging_padding_mask = delays >= src_lens
|
||||||
lagging_padding_mask = torch.nn.functional.pad(lagging_padding_mask.t(), (1, 0)).t()[:-1, :]
|
lagging_padding_mask = torch.nn.functional.pad(
|
||||||
|
lagging_padding_mask.t(), (1, 0)
|
||||||
|
).t()[:-1, :]
|
||||||
gamma = tgt_lens / src_lens
|
gamma = tgt_lens / src_lens
|
||||||
lagging = delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma
|
lagging = (
|
||||||
|
delays
|
||||||
|
- torch.arange(delays.size(0))
|
||||||
|
.unsqueeze(1)
|
||||||
|
.type_as(delays)
|
||||||
|
.expand_as(delays)
|
||||||
|
/ gamma
|
||||||
|
)
|
||||||
lagging.masked_fill_(lagging_padding_mask, 0)
|
lagging.masked_fill_(lagging_padding_mask, 0)
|
||||||
tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
|
tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
|
||||||
AL = lagging.sum(dim=0, keepdim=True) / tau
|
AL = lagging.sum(dim=0, keepdim=True) / tau
|
||||||
@ -149,6 +158,7 @@ class DifferentiableAverageLagging(LatencyMetric):
|
|||||||
2. max(delays_i, delays'_{i-1} + 1 / gamma)
|
2. max(delays_i, delays'_{i-1} + 1 / gamma)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
|
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
|
||||||
tgt_len, bsz = delays.size()
|
tgt_len, bsz = delays.size()
|
||||||
@ -163,13 +173,18 @@ class DifferentiableAverageLagging(LatencyMetric):
|
|||||||
new_delays[i] = torch.cat(
|
new_delays[i] = torch.cat(
|
||||||
[
|
[
|
||||||
new_delays[i - 1].unsqueeze(0) + 1 / gamma,
|
new_delays[i - 1].unsqueeze(0) + 1 / gamma,
|
||||||
delays[i].unsqueeze(0)
|
delays[i].unsqueeze(0),
|
||||||
],
|
],
|
||||||
dim=0
|
dim=0,
|
||||||
).max(dim=0)[0]
|
).max(dim=0)[0]
|
||||||
|
|
||||||
DAL = (
|
DAL = (
|
||||||
new_delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma
|
new_delays
|
||||||
|
- torch.arange(delays.size(0))
|
||||||
|
.unsqueeze(1)
|
||||||
|
.type_as(delays)
|
||||||
|
.expand_as(delays)
|
||||||
|
/ gamma
|
||||||
)
|
)
|
||||||
if target_padding_mask is not None:
|
if target_padding_mask is not None:
|
||||||
DAL = DAL.masked_fill(target_padding_mask, 0)
|
DAL = DAL.masked_fill(target_padding_mask, 0)
|
||||||
@ -186,7 +201,7 @@ class LatencyMetricVariance(LatencyMetric):
|
|||||||
src_lens,
|
src_lens,
|
||||||
target_padding_mask=None,
|
target_padding_mask=None,
|
||||||
batch_first: bool = True,
|
batch_first: bool = True,
|
||||||
start_from_zero: bool = True
|
start_from_zero: bool = True,
|
||||||
):
|
):
|
||||||
assert batch_first
|
assert batch_first
|
||||||
assert len(delays.size()) == 3
|
assert len(delays.size()) == 3
|
||||||
@ -256,25 +271,21 @@ class LatencyInference(object):
|
|||||||
|
|
||||||
src_lens = src_lens
|
src_lens = src_lens
|
||||||
|
|
||||||
delays = (
|
delays = monotonic_step.view(
|
||||||
monotonic_step
|
monotonic_step.size(0), -1, monotonic_step.size(-1)
|
||||||
.view(monotonic_step.size(0), -1, monotonic_step.size(-1))
|
).max(dim=1)[0]
|
||||||
.max(dim=1)[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
delays = (
|
delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as(
|
||||||
delays.masked_fill(delays >= src_lens, 0)
|
delays
|
||||||
+ (src_lens - 1)
|
).masked_fill(delays < src_lens, 0)
|
||||||
.expand_as(delays)
|
|
||||||
.masked_fill(delays < src_lens, 0)
|
|
||||||
)
|
|
||||||
return_dict = {}
|
return_dict = {}
|
||||||
for key, func in self.metric_calculator.items():
|
for key, func in self.metric_calculator.items():
|
||||||
return_dict[key] = func(
|
return_dict[key] = func(
|
||||||
delays.float(), src_lens.float(),
|
delays.float(),
|
||||||
|
src_lens.float(),
|
||||||
target_padding_mask=None,
|
target_padding_mask=None,
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
start_from_zero=True
|
start_from_zero=True,
|
||||||
).t()
|
).t()
|
||||||
|
|
||||||
return return_dict
|
return return_dict
|
||||||
@ -282,8 +293,13 @@ class LatencyInference(object):
|
|||||||
|
|
||||||
class LatencyTraining(object):
|
class LatencyTraining(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, avg_weight, var_weight, avg_type, var_type,
|
self,
|
||||||
stay_on_last_token, average_method,
|
avg_weight,
|
||||||
|
var_weight,
|
||||||
|
avg_type,
|
||||||
|
var_type,
|
||||||
|
stay_on_last_token,
|
||||||
|
average_method,
|
||||||
):
|
):
|
||||||
self.avg_weight = avg_weight
|
self.avg_weight = avg_weight
|
||||||
self.var_weight = var_weight
|
self.var_weight = var_weight
|
||||||
@ -319,17 +335,12 @@ class LatencyTraining(object):
|
|||||||
attention = attention.view(-1, tgt_len, src_len)
|
attention = attention.view(-1, tgt_len, src_len)
|
||||||
|
|
||||||
if not self.stay_on_last_token:
|
if not self.stay_on_last_token:
|
||||||
residual_attention = \
|
residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
|
||||||
1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
|
attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2)
|
||||||
attention = torch.cat(
|
|
||||||
[attention[:, :, :-1], residual_attention],
|
|
||||||
dim=2
|
|
||||||
)
|
|
||||||
|
|
||||||
# bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
|
# bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
|
||||||
steps = (
|
steps = (
|
||||||
torch
|
torch.arange(1, 1 + src_len)
|
||||||
.arange(1, 1 + src_len)
|
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.unsqueeze(1)
|
.unsqueeze(1)
|
||||||
.expand_as(attention)
|
.expand_as(attention)
|
||||||
@ -355,15 +366,12 @@ class LatencyTraining(object):
|
|||||||
src_lens = src_lens.view(-1, 1)
|
src_lens = src_lens.view(-1, 1)
|
||||||
|
|
||||||
# bsz * num_heads_num_layers, tgt_len, src_len
|
# bsz * num_heads_num_layers, tgt_len, src_len
|
||||||
expected_delays = (steps * attention).sum(dim=2).view(
|
expected_delays = (
|
||||||
bsz, num_heads_x_layers, tgt_len
|
(steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len)
|
||||||
)
|
)
|
||||||
|
|
||||||
if target_padding_mask is not None:
|
if target_padding_mask is not None:
|
||||||
expected_delays.masked_fill_(
|
expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0)
|
||||||
target_padding_mask.unsqueeze(1),
|
|
||||||
0
|
|
||||||
)
|
|
||||||
|
|
||||||
return expected_delays, src_lens
|
return expected_delays, src_lens
|
||||||
|
|
||||||
@ -371,8 +379,7 @@ class LatencyTraining(object):
|
|||||||
|
|
||||||
bsz, num_heads_x_layers, tgt_len = expected_delays.size()
|
bsz, num_heads_x_layers, tgt_len = expected_delays.size()
|
||||||
target_padding_mask = (
|
target_padding_mask = (
|
||||||
target_padding_mask
|
target_padding_mask.unsqueeze(1)
|
||||||
.unsqueeze(1)
|
|
||||||
.expand_as(expected_delays)
|
.expand_as(expected_delays)
|
||||||
.contiguous()
|
.contiguous()
|
||||||
.view(-1, tgt_len)
|
.view(-1, tgt_len)
|
||||||
@ -396,8 +403,11 @@ class LatencyTraining(object):
|
|||||||
if self.avg_weight > 0.0:
|
if self.avg_weight > 0.0:
|
||||||
if self.avg_type in self.metric_calculator:
|
if self.avg_type in self.metric_calculator:
|
||||||
average_delays = self.metric_calculator[self.avg_type](
|
average_delays = self.metric_calculator[self.avg_type](
|
||||||
expected_delays, src_lens, target_padding_mask,
|
expected_delays,
|
||||||
batch_first=True, start_from_zero=False
|
src_lens,
|
||||||
|
target_padding_mask,
|
||||||
|
batch_first=True,
|
||||||
|
start_from_zero=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"{self.avg_type} is not supported.")
|
raise RuntimeError(f"{self.avg_type} is not supported.")
|
||||||
@ -408,12 +418,17 @@ class LatencyTraining(object):
|
|||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
def var_loss(self, expected_delays, src_lens, target_padding_mask):
|
def var_loss(self, expected_delays, src_lens, target_padding_mask):
|
||||||
src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[:, :1]
|
src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[
|
||||||
|
:, :1
|
||||||
|
]
|
||||||
if self.var_weight > 0.0:
|
if self.var_weight > 0.0:
|
||||||
if self.var_type in self.variance_calculator:
|
if self.var_type in self.variance_calculator:
|
||||||
variance_delays = self.variance_calculator[self.var_type](
|
variance_delays = self.variance_calculator[self.var_type](
|
||||||
expected_delays, src_lens, target_padding_mask,
|
expected_delays,
|
||||||
batch_first=True, start_from_zero=False
|
src_lens,
|
||||||
|
target_padding_mask,
|
||||||
|
batch_first=True,
|
||||||
|
start_from_zero=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"{self.var_type} is not supported.")
|
raise RuntimeError(f"{self.var_type} is not supported.")
|
||||||
|
@ -1 +1 @@
|
|||||||
from . import tasks, criterions, models # noqa
|
from . import criterions, models, tasks # noqa
|
||||||
|
@ -6,9 +6,9 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from examples.speech_recognition.data.replabels import pack_replabels
|
||||||
from fairseq import utils
|
from fairseq import utils
|
||||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||||
from examples.speech_recognition.data.replabels import pack_replabels
|
|
||||||
|
|
||||||
|
|
||||||
@register_criterion("asg_loss")
|
@register_criterion("asg_loss")
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
from .asr_dataset import AsrDataset
|
from .asr_dataset import AsrDataset
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'AsrDataset',
|
"AsrDataset",
|
||||||
]
|
]
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fairseq.data import FairseqDataset
|
from fairseq.data import FairseqDataset
|
||||||
|
|
||||||
@ -30,16 +31,22 @@ class AsrDataset(FairseqDataset):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, aud_paths, aud_durations_ms, tgt,
|
self,
|
||||||
tgt_dict, ids, speakers,
|
aud_paths,
|
||||||
num_mel_bins=80, frame_length=25.0, frame_shift=10.0
|
aud_durations_ms,
|
||||||
|
tgt,
|
||||||
|
tgt_dict,
|
||||||
|
ids,
|
||||||
|
speakers,
|
||||||
|
num_mel_bins=80,
|
||||||
|
frame_length=25.0,
|
||||||
|
frame_shift=10.0,
|
||||||
):
|
):
|
||||||
assert frame_length > 0
|
assert frame_length > 0
|
||||||
assert frame_shift > 0
|
assert frame_shift > 0
|
||||||
assert all(x > frame_length for x in aud_durations_ms)
|
assert all(x > frame_length for x in aud_durations_ms)
|
||||||
self.frame_sizes = [
|
self.frame_sizes = [
|
||||||
int(1 + (d - frame_length) / frame_shift)
|
int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
|
||||||
for d in aud_durations_ms
|
|
||||||
]
|
]
|
||||||
|
|
||||||
assert len(aud_paths) > 0
|
assert len(aud_paths) > 0
|
||||||
@ -57,13 +64,17 @@ class AsrDataset(FairseqDataset):
|
|||||||
self.frame_shift = frame_shift
|
self.frame_shift = frame_shift
|
||||||
|
|
||||||
self.s2s_collater = Seq2SeqCollater(
|
self.s2s_collater = Seq2SeqCollater(
|
||||||
0, 1, pad_index=self.tgt_dict.pad(),
|
0,
|
||||||
eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True
|
1,
|
||||||
|
pad_index=self.tgt_dict.pad(),
|
||||||
|
eos_index=self.tgt_dict.eos(),
|
||||||
|
move_eos_to_beginning=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import torchaudio.compliance.kaldi as kaldi
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
|
||||||
tgt_item = self.tgt[index] if self.tgt is not None else None
|
tgt_item = self.tgt[index] if self.tgt is not None else None
|
||||||
|
|
||||||
path = self.aud_paths[index]
|
path = self.aud_paths[index]
|
||||||
@ -74,7 +85,7 @@ class AsrDataset(FairseqDataset):
|
|||||||
sound,
|
sound,
|
||||||
num_mel_bins=self.num_mel_bins,
|
num_mel_bins=self.num_mel_bins,
|
||||||
frame_length=self.frame_length,
|
frame_length=self.frame_length,
|
||||||
frame_shift=self.frame_shift
|
frame_shift=self.frame_shift,
|
||||||
)
|
)
|
||||||
output_cmvn = data_utils.apply_mv_norm(output)
|
output_cmvn = data_utils.apply_mv_norm(output)
|
||||||
|
|
||||||
|
@ -12,18 +12,18 @@
|
|||||||
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from fairseq.data import data_utils as fairseq_data_utils
|
from fairseq.data import data_utils as fairseq_data_utils
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqCollater(object):
|
class Seq2SeqCollater(object):
|
||||||
"""
|
"""
|
||||||
Implements collate function mainly for seq2seq tasks
|
Implements collate function mainly for seq2seq tasks
|
||||||
This expects each sample to contain feature (src_tokens) and
|
This expects each sample to contain feature (src_tokens) and
|
||||||
targets.
|
targets.
|
||||||
This collator is also used for aligned training task.
|
This collator is also used for aligned training task.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -6,52 +6,74 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
from collections import namedtuple
|
|
||||||
import concurrent.futures
|
|
||||||
from itertools import chain
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import concurrent.futures
|
||||||
import json
|
import json
|
||||||
import sentencepiece as spm
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
from collections import namedtuple
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
from fairseq.data import Dictionary
|
from fairseq.data import Dictionary
|
||||||
|
|
||||||
|
|
||||||
MILLISECONDS_TO_SECONDS = 0.001
|
MILLISECONDS_TO_SECONDS = 0.001
|
||||||
|
|
||||||
|
|
||||||
def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
|
def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
input = {}
|
input = {}
|
||||||
output = {}
|
output = {}
|
||||||
si, ei = torchaudio.info(aud_path)
|
si, ei = torchaudio.info(aud_path)
|
||||||
input["length_ms"] = int(si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS)
|
input["length_ms"] = int(
|
||||||
|
si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
|
||||||
|
)
|
||||||
input["path"] = aud_path
|
input["path"] = aud_path
|
||||||
|
|
||||||
token = " ".join(sp.EncodeAsPieces(lable))
|
token = " ".join(sp.EncodeAsPieces(lable))
|
||||||
ids = tgt_dict.encode_line(token, append_eos=False)
|
ids = tgt_dict.encode_line(token, append_eos=False)
|
||||||
output["text"] = lable
|
output["text"] = lable
|
||||||
output["token"] = token
|
output["token"] = token
|
||||||
output["tokenid"] = ', '.join(map(str, [t.tolist() for t in ids]))
|
output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
|
||||||
return {utt_id: {"input": input, "output": output}}
|
return {utt_id: {"input": input, "output": output}}
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--audio-dirs", nargs="+", default=['-'], required=True,
|
parser.add_argument(
|
||||||
help="input directories with audio files")
|
"--audio-dirs",
|
||||||
parser.add_argument("--labels", required=True,
|
nargs="+",
|
||||||
help="aggregated input labels with format <ID LABEL> per line",
|
default=["-"],
|
||||||
type=argparse.FileType('r', encoding='UTF-8'))
|
required=True,
|
||||||
parser.add_argument("--spm-model", required=True,
|
help="input directories with audio files",
|
||||||
help="sentencepiece model to use for encoding",
|
)
|
||||||
type=argparse.FileType('r', encoding='UTF-8'))
|
parser.add_argument(
|
||||||
parser.add_argument("--dictionary", required=True,
|
"--labels",
|
||||||
help="file to load fairseq dictionary from",
|
required=True,
|
||||||
type=argparse.FileType('r', encoding='UTF-8'))
|
help="aggregated input labels with format <ID LABEL> per line",
|
||||||
|
type=argparse.FileType("r", encoding="UTF-8"),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--spm-model",
|
||||||
|
required=True,
|
||||||
|
help="sentencepiece model to use for encoding",
|
||||||
|
type=argparse.FileType("r", encoding="UTF-8"),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dictionary",
|
||||||
|
required=True,
|
||||||
|
help="file to load fairseq dictionary from",
|
||||||
|
type=argparse.FileType("r", encoding="UTF-8"),
|
||||||
|
)
|
||||||
parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
|
parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
|
||||||
parser.add_argument("--output", required=True, type=argparse.FileType('w'),
|
parser.add_argument(
|
||||||
help="path to save json output")
|
"--output",
|
||||||
|
required=True,
|
||||||
|
type=argparse.FileType("w"),
|
||||||
|
help="path to save json output",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
@ -64,15 +86,17 @@ def main():
|
|||||||
(utt_id, label) = line.split(" ", 1)
|
(utt_id, label) = line.split(" ", 1)
|
||||||
labels[utt_id] = label
|
labels[utt_id] = label
|
||||||
if len(labels) == 0:
|
if len(labels) == 0:
|
||||||
raise Exception('No labels found in ', args.labels_path)
|
raise Exception("No labels found in ", args.labels_path)
|
||||||
|
|
||||||
Sample = namedtuple('Sample', 'aud_path utt_id')
|
Sample = namedtuple("Sample", "aud_path utt_id")
|
||||||
samples = []
|
samples = []
|
||||||
for path, _, files in chain.from_iterable(os.walk(path) for path in args.audio_dirs):
|
for path, _, files in chain.from_iterable(
|
||||||
|
os.walk(path) for path in args.audio_dirs
|
||||||
|
):
|
||||||
for f in files:
|
for f in files:
|
||||||
if f.endswith(args.audio_format):
|
if f.endswith(args.audio_format):
|
||||||
if len(os.path.splitext(f)) != 2:
|
if len(os.path.splitext(f)) != 2:
|
||||||
raise Exception('Expect <utt_id.extension> file name. Got: ', f)
|
raise Exception("Expect <utt_id.extension> file name. Got: ", f)
|
||||||
utt_id = os.path.splitext(f)[0]
|
utt_id = os.path.splitext(f)[0]
|
||||||
if utt_id not in labels:
|
if utt_id not in labels:
|
||||||
continue
|
continue
|
||||||
@ -81,12 +105,17 @@ def main():
|
|||||||
utts = {}
|
utts = {}
|
||||||
num_cpu = multiprocessing.cpu_count()
|
num_cpu = multiprocessing.cpu_count()
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
|
||||||
future_to_sample = {executor.submit(process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict): s for s in samples}
|
future_to_sample = {
|
||||||
|
executor.submit(
|
||||||
|
process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
|
||||||
|
): s
|
||||||
|
for s in samples
|
||||||
|
}
|
||||||
for future in concurrent.futures.as_completed(future_to_sample):
|
for future in concurrent.futures.as_completed(future_to_sample):
|
||||||
try:
|
try:
|
||||||
data = future.result()
|
data = future.result()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print('generated an exception: ', exc)
|
print("generated an exception: ", exc)
|
||||||
else:
|
else:
|
||||||
utts.update(data)
|
utts.update(data)
|
||||||
json.dump({"utts": utts}, args.output, indent=4)
|
json.dump({"utts": utts}, args.output, indent=4)
|
||||||
|
@ -8,17 +8,17 @@
|
|||||||
Run inference for pre-processed data with a trained model.
|
Run inference for pre-processed data with a trained model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import editdistance
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import editdistance
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from fairseq import checkpoint_utils, options, progress_bar, utils, tasks
|
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
|
||||||
from fairseq.logging.meters import StopwatchMeter, TimeMeter
|
|
||||||
from fairseq.data.data_utils import post_process
|
from fairseq.data.data_utils import post_process
|
||||||
|
from fairseq.logging.meters import StopwatchMeter, TimeMeter
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig()
|
logging.basicConfig()
|
||||||
@ -52,10 +52,12 @@ output units",
|
|||||||
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
|
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--w2l-decoder", choices=["viterbi", "kenlm", "fairseqlm"], help="use a w2l decoder"
|
"--w2l-decoder",
|
||||||
|
choices=["viterbi", "kenlm", "fairseqlm"],
|
||||||
|
help="use a w2l decoder",
|
||||||
)
|
)
|
||||||
parser.add_argument("--lexicon", help="lexicon for w2l decoder")
|
parser.add_argument("--lexicon", help="lexicon for w2l decoder")
|
||||||
parser.add_argument("--unit-lm", action='store_true', help="if using a unit lm")
|
parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
|
||||||
parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
|
parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
|
||||||
parser.add_argument("--beam-threshold", type=float, default=25.0)
|
parser.add_argument("--beam-threshold", type=float, default=25.0)
|
||||||
parser.add_argument("--beam-size-token", type=float, default=100)
|
parser.add_argument("--beam-size-token", type=float, default=100)
|
||||||
@ -87,10 +89,10 @@ def check_args(args):
|
|||||||
# assert args.path is not None, "--path required for generation!"
|
# assert args.path is not None, "--path required for generation!"
|
||||||
# assert args.results_path is not None, "--results_path required for generation!"
|
# assert args.results_path is not None, "--results_path required for generation!"
|
||||||
assert (
|
assert (
|
||||||
not args.sampling or args.nbest == args.beam
|
not args.sampling or args.nbest == args.beam
|
||||||
), "--sampling requires --nbest to be equal to --beam"
|
), "--sampling requires --nbest to be equal to --beam"
|
||||||
assert (
|
assert (
|
||||||
args.replace_unk is None or args.raw_text
|
args.replace_unk is None or args.raw_text
|
||||||
), "--replace-unk requires a raw text dataset (--raw-text)"
|
), "--replace-unk requires a raw text dataset (--raw-text)"
|
||||||
|
|
||||||
|
|
||||||
@ -110,7 +112,7 @@ def get_dataset_itr(args, task, models):
|
|||||||
|
|
||||||
|
|
||||||
def process_predictions(
|
def process_predictions(
|
||||||
args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
|
args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
|
||||||
):
|
):
|
||||||
for hypo in hypos[: min(len(hypos), args.nbest)]:
|
for hypo in hypos[: min(len(hypos), args.nbest)]:
|
||||||
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
|
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
|
||||||
@ -122,16 +124,25 @@ def process_predictions(
|
|||||||
|
|
||||||
if res_files is not None:
|
if res_files is not None:
|
||||||
print(
|
print(
|
||||||
"{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"]
|
"{} ({}-{})".format(hyp_pieces, speaker, id),
|
||||||
|
file=res_files["hypo.units"],
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"{} ({}-{})".format(hyp_words, speaker, id),
|
||||||
|
file=res_files["hypo.words"],
|
||||||
)
|
)
|
||||||
print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"])
|
|
||||||
|
|
||||||
tgt_pieces = tgt_dict.string(target_tokens)
|
tgt_pieces = tgt_dict.string(target_tokens)
|
||||||
tgt_words = post_process(tgt_pieces, args.remove_bpe)
|
tgt_words = post_process(tgt_pieces, args.remove_bpe)
|
||||||
|
|
||||||
if res_files is not None:
|
if res_files is not None:
|
||||||
print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"])
|
print(
|
||||||
print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"])
|
"{} ({}-{})".format(tgt_pieces, speaker, id),
|
||||||
|
file=res_files["ref.units"],
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
|
||||||
|
)
|
||||||
# only score top hypothesis
|
# only score top hypothesis
|
||||||
if not args.quiet:
|
if not args.quiet:
|
||||||
logger.debug("HYPO:" + hyp_words)
|
logger.debug("HYPO:" + hyp_words)
|
||||||
@ -146,7 +157,7 @@ def process_predictions(
|
|||||||
def prepare_result_files(args):
|
def prepare_result_files(args):
|
||||||
def get_res_file(file_prefix):
|
def get_res_file(file_prefix):
|
||||||
if args.num_shards > 1:
|
if args.num_shards > 1:
|
||||||
file_prefix = f'{args.shard_id}_{file_prefix}'
|
file_prefix = f"{args.shard_id}_{file_prefix}"
|
||||||
path = os.path.join(
|
path = os.path.join(
|
||||||
args.results_path,
|
args.results_path,
|
||||||
"{}-{}-{}.txt".format(
|
"{}-{}-{}.txt".format(
|
||||||
@ -166,15 +177,17 @@ def prepare_result_files(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=None, model_state=None):
|
def load_models_and_criterions(
|
||||||
|
filenames, data_path, arg_overrides=None, task=None, model_state=None
|
||||||
|
):
|
||||||
models = []
|
models = []
|
||||||
criterions = []
|
criterions = []
|
||||||
|
|
||||||
if arg_overrides is None:
|
if arg_overrides is None:
|
||||||
arg_overrides = {}
|
arg_overrides = {}
|
||||||
|
|
||||||
arg_overrides['wer_args'] = None
|
arg_overrides["wer_args"] = None
|
||||||
arg_overrides['data'] = data_path
|
arg_overrides["data"] = data_path
|
||||||
|
|
||||||
if filenames is None:
|
if filenames is None:
|
||||||
assert model_state is not None
|
assert model_state is not None
|
||||||
@ -205,8 +218,7 @@ def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=No
|
|||||||
|
|
||||||
|
|
||||||
def optimize_models(args, use_cuda, models):
|
def optimize_models(args, use_cuda, models):
|
||||||
"""Optimize ensemble for generation
|
"""Optimize ensemble for generation"""
|
||||||
"""
|
|
||||||
for model in models:
|
for model in models:
|
||||||
model.make_generation_fast_(
|
model.make_generation_fast_(
|
||||||
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
|
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
|
||||||
@ -229,7 +241,7 @@ class ExistingEmissionsDecoder(object):
|
|||||||
emissions = np.stack(self.emissions[ids])
|
emissions = np.stack(self.emissions[ids])
|
||||||
except:
|
except:
|
||||||
print([x.shape for x in self.emissions[ids]])
|
print([x.shape for x in self.emissions[ids]])
|
||||||
raise Exception('invalid sizes')
|
raise Exception("invalid sizes")
|
||||||
emissions = torch.from_numpy(emissions)
|
emissions = torch.from_numpy(emissions)
|
||||||
return self.decoder.decode(emissions)
|
return self.decoder.decode(emissions)
|
||||||
|
|
||||||
@ -300,7 +312,9 @@ def main(args, task=None, model_state=None):
|
|||||||
|
|
||||||
return W2lFairseqLMDecoder(args, task.target_dictionary)
|
return W2lFairseqLMDecoder(args, task.target_dictionary)
|
||||||
else:
|
else:
|
||||||
print('only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment')
|
print(
|
||||||
|
"only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
|
||||||
|
)
|
||||||
|
|
||||||
# please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
|
# please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
|
||||||
generator = build_generator(args)
|
generator = build_generator(args)
|
||||||
@ -361,7 +375,11 @@ def main(args, task=None, model_state=None):
|
|||||||
encoder_out = models[0](**sample["net_input"])
|
encoder_out = models[0](**sample["net_input"])
|
||||||
feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
|
feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
|
||||||
for i, id in enumerate(sample["id"]):
|
for i, id in enumerate(sample["id"]):
|
||||||
padding = encoder_out["encoder_padding_mask"][i].cpu().numpy() if encoder_out["encoder_padding_mask"] is not None else None
|
padding = (
|
||||||
|
encoder_out["encoder_padding_mask"][i].cpu().numpy()
|
||||||
|
if encoder_out["encoder_padding_mask"] is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
features[id.item()] = (feat[i], padding)
|
features[id.item()] = (feat[i], padding)
|
||||||
continue
|
continue
|
||||||
hypos = task.inference_step(generator, models, sample, prefix_tokens)
|
hypos = task.inference_step(generator, models, sample, prefix_tokens)
|
||||||
@ -372,20 +390,31 @@ def main(args, task=None, model_state=None):
|
|||||||
speaker = None
|
speaker = None
|
||||||
# id = task.dataset(args.gen_subset).ids[int(sample_id)]
|
# id = task.dataset(args.gen_subset).ids[int(sample_id)]
|
||||||
id = sample_id
|
id = sample_id
|
||||||
toks = sample["target"][i, :] if 'target_label' not in sample else sample["target_label"][i, :]
|
toks = (
|
||||||
target_tokens = (
|
sample["target"][i, :]
|
||||||
utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
|
if "target_label" not in sample
|
||||||
|
else sample["target_label"][i, :]
|
||||||
)
|
)
|
||||||
|
target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
|
||||||
# Process top predictions
|
# Process top predictions
|
||||||
errs, length = process_predictions(
|
errs, length = process_predictions(
|
||||||
args, hypos[i], None, tgt_dict, target_tokens, res_files, speaker, id
|
args,
|
||||||
|
hypos[i],
|
||||||
|
None,
|
||||||
|
tgt_dict,
|
||||||
|
target_tokens,
|
||||||
|
res_files,
|
||||||
|
speaker,
|
||||||
|
id,
|
||||||
)
|
)
|
||||||
errs_t += errs
|
errs_t += errs
|
||||||
lengths_t += length
|
lengths_t += length
|
||||||
|
|
||||||
wps_meter.update(num_generated_tokens)
|
wps_meter.update(num_generated_tokens)
|
||||||
t.log({"wps": round(wps_meter.avg)})
|
t.log({"wps": round(wps_meter.avg)})
|
||||||
num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
|
num_sentences += (
|
||||||
|
sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
|
||||||
|
)
|
||||||
|
|
||||||
wer = None
|
wer = None
|
||||||
if args.dump_emissions:
|
if args.dump_emissions:
|
||||||
@ -413,7 +442,7 @@ def main(args, task=None, model_state=None):
|
|||||||
gen_timer.sum,
|
gen_timer.sum,
|
||||||
num_sentences / gen_timer.sum,
|
num_sentences / gen_timer.sum,
|
||||||
1.0 / gen_timer.avg,
|
1.0 / gen_timer.avg,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
|
logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
|
||||||
return task, wer
|
return task, wer
|
||||||
@ -424,6 +453,7 @@ def make_parser():
|
|||||||
parser = add_asr_eval_argument(parser)
|
parser = add_asr_eval_argument(parser)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def cli_main():
|
def cli_main():
|
||||||
parser = make_parser()
|
parser = make_parser()
|
||||||
args = options.parse_args_and_arch(parser)
|
args = options.parse_args_and_arch(parser)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
for file in os.listdir(os.path.dirname(__file__)):
|
for file in os.listdir(os.path.dirname(__file__)):
|
||||||
if file.endswith('.py') and not file.startswith('_'):
|
if file.endswith(".py") and not file.startswith("_"):
|
||||||
model_name = file[:file.find('.py')]
|
model_name = file[: file.find(".py")]
|
||||||
importlib.import_module('examples.speech_recognition.models.' + model_name)
|
importlib.import_module("examples.speech_recognition.models." + model_name)
|
||||||
|
@ -9,18 +9,22 @@ from collections.abc import Iterable
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
|
||||||
from fairseq import utils
|
from fairseq import utils
|
||||||
from fairseq.models import (
|
from fairseq.models import (
|
||||||
FairseqEncoder,
|
FairseqEncoder,
|
||||||
|
FairseqEncoderDecoderModel,
|
||||||
FairseqEncoderModel,
|
FairseqEncoderModel,
|
||||||
FairseqIncrementalDecoder,
|
FairseqIncrementalDecoder,
|
||||||
FairseqEncoderDecoderModel,
|
|
||||||
register_model,
|
register_model,
|
||||||
register_model_architecture,
|
register_model_architecture,
|
||||||
)
|
)
|
||||||
from fairseq.modules import LinearizedConvolution
|
from fairseq.modules import (
|
||||||
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
|
LinearizedConvolution,
|
||||||
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VGGBlock
|
TransformerDecoderLayer,
|
||||||
|
TransformerEncoderLayer,
|
||||||
|
VGGBlock,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_model("asr_vggtransformer")
|
@register_model("asr_vggtransformer")
|
||||||
@ -29,6 +33,7 @@ class VGGTransformerModel(FairseqEncoderDecoderModel):
|
|||||||
Transformers with convolutional context for ASR
|
Transformers with convolutional context for ASR
|
||||||
https://arxiv.org/abs/1904.11660
|
https://arxiv.org/abs/1904.11660
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, encoder, decoder):
|
def __init__(self, encoder, decoder):
|
||||||
super().__init__(encoder, decoder)
|
super().__init__(encoder, decoder)
|
||||||
|
|
||||||
@ -602,18 +607,22 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
|||||||
self.layers = nn.ModuleList()
|
self.layers = nn.ModuleList()
|
||||||
if conv_config[-1][0] != transformer_config[0][0]:
|
if conv_config[-1][0] != transformer_config[0][0]:
|
||||||
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
|
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
|
||||||
self.layers.append(TransformerDecoderLayer(
|
self.layers.append(
|
||||||
prepare_transformer_decoder_params(*transformer_config[0])
|
TransformerDecoderLayer(
|
||||||
))
|
prepare_transformer_decoder_params(*transformer_config[0])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(1, len(transformer_config)):
|
for i in range(1, len(transformer_config)):
|
||||||
if transformer_config[i - 1][0] != transformer_config[i][0]:
|
if transformer_config[i - 1][0] != transformer_config[i][0]:
|
||||||
self.layers.append(
|
self.layers.append(
|
||||||
Linear(transformer_config[i - 1][0], transformer_config[i][0])
|
Linear(transformer_config[i - 1][0], transformer_config[i][0])
|
||||||
)
|
)
|
||||||
self.layers.append(TransformerDecoderLayer(
|
self.layers.append(
|
||||||
prepare_transformer_decoder_params(*transformer_config[i])
|
TransformerDecoderLayer(
|
||||||
))
|
prepare_transformer_decoder_params(*transformer_config[i])
|
||||||
|
)
|
||||||
|
)
|
||||||
self.fc_out = Linear(transformer_config[-1][0], vocab_size)
|
self.fc_out = Linear(transformer_config[-1][0], vocab_size)
|
||||||
|
|
||||||
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
|
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
|
||||||
@ -713,6 +722,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
|||||||
x = x.transpose(0, 1)
|
x = x.transpose(0, 1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@register_model("asr_vggtransformer_encoder")
|
@register_model("asr_vggtransformer_encoder")
|
||||||
class VGGTransformerEncoderModel(FairseqEncoderModel):
|
class VGGTransformerEncoderModel(FairseqEncoderModel):
|
||||||
def __init__(self, encoder):
|
def __init__(self, encoder):
|
||||||
|
@ -10,7 +10,6 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from fairseq.models import (
|
from fairseq.models import (
|
||||||
FairseqEncoder,
|
FairseqEncoder,
|
||||||
FairseqEncoderModel,
|
FairseqEncoderModel,
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
for file in os.listdir(os.path.dirname(__file__)):
|
for file in os.listdir(os.path.dirname(__file__)):
|
||||||
if file.endswith('.py') and not file.startswith('_'):
|
if file.endswith(".py") and not file.startswith("_"):
|
||||||
task_name = file[:file.find('.py')]
|
task_name = file[: file.find(".py")]
|
||||||
importlib.import_module('examples.speech_recognition.tasks.' + task_name)
|
importlib.import_module("examples.speech_recognition.tasks." + task_name)
|
||||||
|
@ -9,10 +9,10 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from fairseq.data import Dictionary
|
|
||||||
from fairseq.tasks import register_task, LegacyFairseqTask
|
|
||||||
from examples.speech_recognition.data import AsrDataset
|
from examples.speech_recognition.data import AsrDataset
|
||||||
from examples.speech_recognition.data.replabels import replabel_symbol
|
from examples.speech_recognition.data.replabels import replabel_symbol
|
||||||
|
from fairseq.data import Dictionary
|
||||||
|
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||||
|
|
||||||
|
|
||||||
def get_asr_dataset_from_json(data_json_path, tgt_dict):
|
def get_asr_dataset_from_json(data_json_path, tgt_dict):
|
||||||
@ -78,10 +78,20 @@ class SpeechRecognitionTask(LegacyFairseqTask):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
|
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
|
||||||
)
|
)
|
||||||
parser.add_argument('--max-source-positions', default=sys.maxsize, type=int, metavar='N',
|
parser.add_argument(
|
||||||
help='max number of frames in the source sequence')
|
"--max-source-positions",
|
||||||
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
|
default=sys.maxsize,
|
||||||
help='max number of tokens in the target sequence')
|
type=int,
|
||||||
|
metavar="N",
|
||||||
|
help="max number of frames in the source sequence",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-target-positions",
|
||||||
|
default=1024,
|
||||||
|
type=int,
|
||||||
|
metavar="N",
|
||||||
|
help="max number of tokens in the target sequence",
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, args, tgt_dict):
|
def __init__(self, args, tgt_dict):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
|
@ -9,16 +9,18 @@
|
|||||||
Wav2letter decoders.
|
Wav2letter decoders.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections import namedtuple, deque
|
|
||||||
import gc
|
import gc
|
||||||
import itertools as it
|
import itertools as it
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import deque, namedtuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from examples.speech_recognition.data.replabels import unpack_replabels
|
||||||
from fairseq import tasks
|
from fairseq import tasks
|
||||||
from fairseq.utils import apply_to_sample
|
from fairseq.utils import apply_to_sample
|
||||||
from examples.speech_recognition.data.replabels import unpack_replabels
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from wav2letter.common import create_word_dict, load_words
|
from wav2letter.common import create_word_dict, load_words
|
||||||
|
@ -4,66 +4,76 @@
|
|||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from multiprocessing import cpu_count
|
import csv
|
||||||
import os
|
import os
|
||||||
import os.path as op
|
import os.path as op
|
||||||
from glob import glob
|
|
||||||
import zipfile
|
import zipfile
|
||||||
import csv
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Dict, Any, List
|
from glob import glob
|
||||||
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
|
from multiprocessing import cpu_count
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import sentencepiece as sp
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import sentencepiece as sp
|
||||||
|
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
|
||||||
from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
|
from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
UNK_TOKEN, UNK_TOKEN_ID = '<unk>', 3
|
|
||||||
BOS_TOKEN, BOS_TOKEN_ID = '<s>', 0
|
UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
|
||||||
EOS_TOKEN, EOS_TOKEN_ID = '</s>', 2
|
BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
|
||||||
PAD_TOKEN, PAD_TOKEN_ID = '<pad>', 1
|
EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
|
||||||
|
PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
|
||||||
|
|
||||||
|
|
||||||
def gen_vocab(
|
def gen_vocab(
|
||||||
input_path: str, output_path_prefix: str, model_type='bpe',
|
input_path: str,
|
||||||
vocab_size=1000,
|
output_path_prefix: str,
|
||||||
|
model_type="bpe",
|
||||||
|
vocab_size=1000,
|
||||||
):
|
):
|
||||||
# Train SentencePiece Model
|
# Train SentencePiece Model
|
||||||
arguments = [
|
arguments = [
|
||||||
f'--input={input_path}',
|
f"--input={input_path}",
|
||||||
f'--model_prefix={output_path_prefix}',
|
f"--model_prefix={output_path_prefix}",
|
||||||
f'--model_type={model_type}',
|
f"--model_type={model_type}",
|
||||||
f'--vocab_size={vocab_size}',
|
f"--vocab_size={vocab_size}",
|
||||||
'--character_coverage=1.0',
|
"--character_coverage=1.0",
|
||||||
f'--num_threads={cpu_count()}',
|
f"--num_threads={cpu_count()}",
|
||||||
f'--unk_id={UNK_TOKEN_ID}',
|
f"--unk_id={UNK_TOKEN_ID}",
|
||||||
f'--bos_id={BOS_TOKEN_ID}',
|
f"--bos_id={BOS_TOKEN_ID}",
|
||||||
f'--eos_id={EOS_TOKEN_ID}',
|
f"--eos_id={EOS_TOKEN_ID}",
|
||||||
f'--pad_id={PAD_TOKEN_ID}'
|
f"--pad_id={PAD_TOKEN_ID}",
|
||||||
]
|
]
|
||||||
sp.SentencePieceTrainer.Train(' '.join(arguments))
|
sp.SentencePieceTrainer.Train(" ".join(arguments))
|
||||||
# Export fairseq dictionary
|
# Export fairseq dictionary
|
||||||
spm = sp.SentencePieceProcessor()
|
spm = sp.SentencePieceProcessor()
|
||||||
spm.Load(output_path_prefix + '.model')
|
spm.Load(output_path_prefix + ".model")
|
||||||
vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
|
vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
|
||||||
assert vocab.get(UNK_TOKEN_ID) == UNK_TOKEN and \
|
assert (
|
||||||
vocab.get(PAD_TOKEN_ID) == PAD_TOKEN and \
|
vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
|
||||||
vocab.get(BOS_TOKEN_ID) == BOS_TOKEN and \
|
and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
|
||||||
vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
|
and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
|
||||||
|
and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
|
||||||
|
)
|
||||||
vocab = {
|
vocab = {
|
||||||
i: s for i, s in vocab.items()
|
i: s
|
||||||
|
for i, s in vocab.items()
|
||||||
if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
|
if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
|
||||||
}
|
}
|
||||||
with open(output_path_prefix + '.txt', 'w') as f_out:
|
with open(output_path_prefix + ".txt", "w") as f_out:
|
||||||
for _, s in sorted(vocab.items(), key=lambda x: x[0]):
|
for _, s in sorted(vocab.items(), key=lambda x: x[0]):
|
||||||
f_out.write(f'{s} 1\n')
|
f_out.write(f"{s} 1\n")
|
||||||
|
|
||||||
|
|
||||||
def extract_fbank_features(waveform, sample_rate, output_path=None,
|
def extract_fbank_features(
|
||||||
n_mel_bins=80, apply_utterance_cmvn=True,
|
waveform,
|
||||||
overwrite=False):
|
sample_rate,
|
||||||
|
output_path=None,
|
||||||
|
n_mel_bins=80,
|
||||||
|
apply_utterance_cmvn=True,
|
||||||
|
overwrite=False,
|
||||||
|
):
|
||||||
if output_path is not None and op.exists(output_path) and not overwrite:
|
if output_path is not None and op.exists(output_path) and not overwrite:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -74,8 +84,10 @@ def extract_fbank_features(waveform, sample_rate, output_path=None,
|
|||||||
if features is None:
|
if features is None:
|
||||||
features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
|
features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
|
||||||
if features is None:
|
if features is None:
|
||||||
raise ImportError('Please install pyKaldi or torchaudio to enable '
|
raise ImportError(
|
||||||
'online filterbank feature extraction')
|
"Please install pyKaldi or torchaudio to enable "
|
||||||
|
"online filterbank feature extraction"
|
||||||
|
)
|
||||||
|
|
||||||
if apply_utterance_cmvn:
|
if apply_utterance_cmvn:
|
||||||
cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
|
cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
|
||||||
@ -89,8 +101,8 @@ def extract_fbank_features(waveform, sample_rate, output_path=None,
|
|||||||
def create_zip(data_root, zip_path):
|
def create_zip(data_root, zip_path):
|
||||||
cwd = os.path.abspath(os.curdir)
|
cwd = os.path.abspath(os.curdir)
|
||||||
os.chdir(data_root)
|
os.chdir(data_root)
|
||||||
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_STORED) as f:
|
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
|
||||||
for filename in tqdm(glob('*.npy')):
|
for filename in tqdm(glob("*.npy")):
|
||||||
f.write(filename)
|
f.write(filename)
|
||||||
os.chdir(cwd)
|
os.chdir(cwd)
|
||||||
|
|
||||||
@ -101,69 +113,80 @@ def is_npy_data(data: bytes) -> bool:
|
|||||||
|
|
||||||
def get_zip_manifest(zip_root, zip_filename):
|
def get_zip_manifest(zip_root, zip_filename):
|
||||||
zip_path = op.join(zip_root, zip_filename)
|
zip_path = op.join(zip_root, zip_filename)
|
||||||
with zipfile.ZipFile(zip_path, mode='r') as f:
|
with zipfile.ZipFile(zip_path, mode="r") as f:
|
||||||
info = f.infolist()
|
info = f.infolist()
|
||||||
manifest = {}
|
manifest = {}
|
||||||
for i in tqdm(info):
|
for i in tqdm(info):
|
||||||
utt_id = op.splitext(i.filename)[0]
|
utt_id = op.splitext(i.filename)[0]
|
||||||
offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
|
offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
|
||||||
manifest[utt_id] = f'{zip_filename}:{offset}:{file_size}'
|
manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}"
|
||||||
with open(zip_path, 'rb') as f:
|
with open(zip_path, "rb") as f:
|
||||||
f.seek(offset)
|
f.seek(offset)
|
||||||
data = f.read(file_size)
|
data = f.read(file_size)
|
||||||
assert len(data) > 1 and is_npy_data(data)
|
assert len(data) > 1 and is_npy_data(data)
|
||||||
return manifest
|
return manifest
|
||||||
|
|
||||||
|
|
||||||
def gen_config_yaml(data_root, spm_filename, yaml_filename='config.yaml',
|
def gen_config_yaml(
|
||||||
specaugment_policy='lb'):
|
data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb"
|
||||||
assert specaugment_policy in {'lb', 'ld'}
|
):
|
||||||
|
assert specaugment_policy in {"lb", "ld"}
|
||||||
data_root = op.abspath(data_root)
|
data_root = op.abspath(data_root)
|
||||||
writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
|
writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
|
||||||
writer.set_audio_root(op.abspath(data_root))
|
writer.set_audio_root(op.abspath(data_root))
|
||||||
writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
|
writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
|
||||||
writer.set_input_channels(1)
|
writer.set_input_channels(1)
|
||||||
writer.set_input_feat_per_channel(80)
|
writer.set_input_feat_per_channel(80)
|
||||||
if specaugment_policy == 'lb':
|
if specaugment_policy == "lb":
|
||||||
writer.set_specaugment_lb_policy()
|
writer.set_specaugment_lb_policy()
|
||||||
else:
|
else:
|
||||||
writer.set_specaugment_ld_policy()
|
writer.set_specaugment_ld_policy()
|
||||||
writer.set_bpe_tokenizer(
|
writer.set_bpe_tokenizer(
|
||||||
{'bpe': 'sentencepiece',
|
{
|
||||||
'sentencepiece_model': op.join(data_root, spm_filename)}
|
"bpe": "sentencepiece",
|
||||||
|
"sentencepiece_model": op.join(data_root, spm_filename),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
writer.set_feature_transforms('_train', ['specaugment'])
|
writer.set_feature_transforms("_train", ["specaugment"])
|
||||||
writer.flush()
|
writer.flush()
|
||||||
|
|
||||||
|
|
||||||
def save_df_to_tsv(dataframe, path):
|
def save_df_to_tsv(dataframe, path):
|
||||||
dataframe.to_csv(path, sep="\t", header=True, index=False, encoding="utf-8",
|
dataframe.to_csv(
|
||||||
escapechar='\\', quoting=csv.QUOTE_NONE)
|
path,
|
||||||
|
sep="\t",
|
||||||
|
header=True,
|
||||||
|
index=False,
|
||||||
|
encoding="utf-8",
|
||||||
|
escapechar="\\",
|
||||||
|
quoting=csv.QUOTE_NONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def filter_manifest_df(df, is_train_split=False, extra_filters=None,
|
def filter_manifest_df(
|
||||||
min_n_frames=5, max_n_frames=3000):
|
df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
|
||||||
|
):
|
||||||
filters = {
|
filters = {
|
||||||
'no speech': df['audio'] == '',
|
"no speech": df["audio"] == "",
|
||||||
f'short speech (<{min_n_frames} frames)': df['n_frames'] < min_n_frames,
|
f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
|
||||||
'empty sentence': df['tgt_text'] == '',
|
"empty sentence": df["tgt_text"] == "",
|
||||||
}
|
}
|
||||||
if is_train_split:
|
if is_train_split:
|
||||||
filters[f'long speech (>{max_n_frames} frames)'] = \
|
filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
|
||||||
df['n_frames'] > max_n_frames
|
|
||||||
if extra_filters is not None:
|
if extra_filters is not None:
|
||||||
filters.update(extra_filters)
|
filters.update(extra_filters)
|
||||||
invalid = reduce(lambda x, y: x | y, filters.values())
|
invalid = reduce(lambda x, y: x | y, filters.values())
|
||||||
valid = ~invalid
|
valid = ~invalid
|
||||||
print(
|
print(
|
||||||
'| ' + ', '.join(f'{n}: {f.sum()}' for n, f in filters.items()) +
|
"| "
|
||||||
f', total {invalid.sum()} filtered, {valid.sum()} remained.'
|
+ ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
|
||||||
|
+ f", total {invalid.sum()} filtered, {valid.sum()} remained."
|
||||||
)
|
)
|
||||||
return df[valid]
|
return df[valid]
|
||||||
|
|
||||||
|
|
||||||
class S2TDataConfigWriter(object):
|
class S2TDataConfigWriter(object):
|
||||||
DEFAULT_VOCAB_FILENAME = 'dict.txt'
|
DEFAULT_VOCAB_FILENAME = "dict.txt"
|
||||||
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
|
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
|
||||||
DEFAULT_INPUT_CHANNELS = 1
|
DEFAULT_INPUT_CHANNELS = 1
|
||||||
|
|
||||||
@ -171,48 +194,69 @@ class S2TDataConfigWriter(object):
|
|||||||
try:
|
try:
|
||||||
import yaml
|
import yaml
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print('Please install PyYAML to load YAML files for S2T data config')
|
print("Please install PyYAML to load YAML files for S2T data config")
|
||||||
self.yaml = yaml
|
self.yaml = yaml
|
||||||
self.yaml_path = yaml_path
|
self.yaml_path = yaml_path
|
||||||
self.config = {}
|
self.config = {}
|
||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
with open(self.yaml_path, 'w') as f:
|
with open(self.yaml_path, "w") as f:
|
||||||
self.yaml.dump(self.config, f)
|
self.yaml.dump(self.config, f)
|
||||||
|
|
||||||
def set_audio_root(self, audio_root=''):
|
def set_audio_root(self, audio_root=""):
|
||||||
self.config['audio_root'] = audio_root
|
self.config["audio_root"] = audio_root
|
||||||
|
|
||||||
def set_vocab_filename(self, vocab_filename='dict.txt'):
|
def set_vocab_filename(self, vocab_filename="dict.txt"):
|
||||||
self.config['vocab_filename'] = vocab_filename
|
self.config["vocab_filename"] = vocab_filename
|
||||||
|
|
||||||
def set_specaugment(self, time_wrap_w: int, freq_mask_n: int,
|
def set_specaugment(
|
||||||
freq_mask_f: int, time_mask_n: int, time_mask_t: int,
|
self,
|
||||||
time_mask_p: float):
|
time_wrap_w: int,
|
||||||
self.config['specaugment'] = {
|
freq_mask_n: int,
|
||||||
'time_wrap_W': time_wrap_w, 'freq_mask_N': freq_mask_n,
|
freq_mask_f: int,
|
||||||
'freq_mask_F': freq_mask_f, 'time_mask_N': time_mask_n,
|
time_mask_n: int,
|
||||||
'time_mask_T': time_mask_t, 'time_mask_p': time_mask_p,
|
time_mask_t: int,
|
||||||
|
time_mask_p: float,
|
||||||
|
):
|
||||||
|
self.config["specaugment"] = {
|
||||||
|
"time_wrap_W": time_wrap_w,
|
||||||
|
"freq_mask_N": freq_mask_n,
|
||||||
|
"freq_mask_F": freq_mask_f,
|
||||||
|
"time_mask_N": time_mask_n,
|
||||||
|
"time_mask_T": time_mask_t,
|
||||||
|
"time_mask_p": time_mask_p,
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_specaugment_lb_policy(self):
|
def set_specaugment_lb_policy(self):
|
||||||
self.set_specaugment(time_wrap_w=0, freq_mask_n=1, freq_mask_f=27,
|
self.set_specaugment(
|
||||||
time_mask_n=1, time_mask_t=100, time_mask_p=1.0)
|
time_wrap_w=0,
|
||||||
|
freq_mask_n=1,
|
||||||
|
freq_mask_f=27,
|
||||||
|
time_mask_n=1,
|
||||||
|
time_mask_t=100,
|
||||||
|
time_mask_p=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
def set_specaugment_ld_policy(self):
|
def set_specaugment_ld_policy(self):
|
||||||
self.set_specaugment(time_wrap_w=0, freq_mask_n=2, freq_mask_f=27,
|
self.set_specaugment(
|
||||||
time_mask_n=2, time_mask_t=100, time_mask_p=1.0)
|
time_wrap_w=0,
|
||||||
|
freq_mask_n=2,
|
||||||
|
freq_mask_f=27,
|
||||||
|
time_mask_n=2,
|
||||||
|
time_mask_t=100,
|
||||||
|
time_mask_p=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
def set_input_channels(self, input_channels=1):
|
def set_input_channels(self, input_channels=1):
|
||||||
self.config['input_channels'] = input_channels
|
self.config["input_channels"] = input_channels
|
||||||
|
|
||||||
def set_input_feat_per_channel(self, input_feat_per_channel=80):
|
def set_input_feat_per_channel(self, input_feat_per_channel=80):
|
||||||
self.config['input_feat_per_channel'] = input_feat_per_channel
|
self.config["input_feat_per_channel"] = input_feat_per_channel
|
||||||
|
|
||||||
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
|
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
|
||||||
self.config['bpe_tokenizer'] = bpe_tokenizer
|
self.config["bpe_tokenizer"] = bpe_tokenizer
|
||||||
|
|
||||||
def set_feature_transforms(self, split, transforms: List[str]):
|
def set_feature_transforms(self, split, transforms: List[str]):
|
||||||
if 'transforms' not in self.config:
|
if "transforms" not in self.config:
|
||||||
self.config['transforms'] = {}
|
self.config["transforms"] = {}
|
||||||
self.config['transforms'][split] = transforms
|
self.config["transforms"][split] = transforms
|
||||||
|
@ -5,30 +5,35 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import csv
|
||||||
import logging
|
import logging
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
import os
|
import os
|
||||||
import os.path as op
|
import os.path as op
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Tuple, Optional
|
from tempfile import NamedTemporaryFile
|
||||||
import csv
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torchaudio
|
||||||
|
from examples.speech_to_text.data_utils import (
|
||||||
|
create_zip,
|
||||||
|
extract_fbank_features,
|
||||||
|
filter_manifest_df,
|
||||||
|
gen_config_yaml,
|
||||||
|
gen_vocab,
|
||||||
|
get_zip_manifest,
|
||||||
|
save_df_to_tsv,
|
||||||
|
)
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data import Dataset
|
||||||
from torchaudio.datasets.utils import download_url, extract_archive
|
from torchaudio.datasets.utils import download_url, extract_archive
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import pandas as pd
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
import torchaudio
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from examples.speech_to_text.data_utils import (
|
|
||||||
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
|
|
||||||
extract_fbank_features, gen_config_yaml, filter_manifest_df
|
|
||||||
)
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
|
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
|
||||||
|
|
||||||
|
|
||||||
class CoVoST(Dataset):
|
class CoVoST(Dataset):
|
||||||
@ -44,40 +49,82 @@ class CoVoST(Dataset):
|
|||||||
found at root path. (default: ``False``).
|
found at root path. (default: ``False``).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CV_URL_TEMPLATE = "https://voice-prod-bundler-ee1969a6ce8178826482b88" \
|
CV_URL_TEMPLATE = (
|
||||||
"e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
|
"https://voice-prod-bundler-ee1969a6ce8178826482b88"
|
||||||
COVOST_URL_TEMPLATE = "https://dl.fbaipublicfiles.com/covost/" \
|
"e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
|
||||||
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
|
)
|
||||||
|
COVOST_URL_TEMPLATE = (
|
||||||
|
"https://dl.fbaipublicfiles.com/covost/"
|
||||||
|
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
|
||||||
|
)
|
||||||
|
|
||||||
VERSIONS = {2}
|
VERSIONS = {2}
|
||||||
SPLITS = ['train', 'dev', 'test']
|
SPLITS = ["train", "dev", "test"]
|
||||||
|
|
||||||
CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}
|
CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}
|
||||||
|
|
||||||
XX_EN_LANGUAGES = {
|
XX_EN_LANGUAGES = {
|
||||||
1: ['fr', 'de', 'nl', 'ru', 'es', 'it', 'tr', 'fa', 'sv-SE', 'mn',
|
1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
|
||||||
'zh-CN'],
|
2: [
|
||||||
2: ['fr', 'de', 'es', 'ca', 'it', 'ru', 'zh-CN', 'pt', 'fa', 'et', 'mn',
|
"fr",
|
||||||
'nl', 'tr', 'ar', 'sv-SE', 'lv', 'sl', 'ta', 'ja', 'id', 'cy']
|
"de",
|
||||||
|
"es",
|
||||||
|
"ca",
|
||||||
|
"it",
|
||||||
|
"ru",
|
||||||
|
"zh-CN",
|
||||||
|
"pt",
|
||||||
|
"fa",
|
||||||
|
"et",
|
||||||
|
"mn",
|
||||||
|
"nl",
|
||||||
|
"tr",
|
||||||
|
"ar",
|
||||||
|
"sv-SE",
|
||||||
|
"lv",
|
||||||
|
"sl",
|
||||||
|
"ta",
|
||||||
|
"ja",
|
||||||
|
"id",
|
||||||
|
"cy",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
EN_XX_LANGUAGES = {
|
EN_XX_LANGUAGES = {
|
||||||
1: [],
|
1: [],
|
||||||
2: ['de', 'tr', 'fa', 'sv-SE', 'mn', 'zh-CN', 'cy', 'ca', 'sl', 'et',
|
2: [
|
||||||
'id',
|
"de",
|
||||||
'ar', 'ta', 'lv', 'ja']
|
"tr",
|
||||||
|
"fa",
|
||||||
|
"sv-SE",
|
||||||
|
"mn",
|
||||||
|
"zh-CN",
|
||||||
|
"cy",
|
||||||
|
"ca",
|
||||||
|
"sl",
|
||||||
|
"et",
|
||||||
|
"id",
|
||||||
|
"ar",
|
||||||
|
"ta",
|
||||||
|
"lv",
|
||||||
|
"ja",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, root: str, split: str, source_language: str,
|
self,
|
||||||
target_language: Optional[str] = None, version: int = 2,
|
root: str,
|
||||||
download: bool = False
|
split: str,
|
||||||
|
source_language: str,
|
||||||
|
target_language: Optional[str] = None,
|
||||||
|
version: int = 2,
|
||||||
|
download: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert version in self.VERSIONS and split in self.SPLITS
|
assert version in self.VERSIONS and split in self.SPLITS
|
||||||
assert source_language is not None
|
assert source_language is not None
|
||||||
self.no_translation = (target_language is None)
|
self.no_translation = target_language is None
|
||||||
if not self.no_translation:
|
if not self.no_translation:
|
||||||
assert 'en' in {source_language, target_language}
|
assert "en" in {source_language, target_language}
|
||||||
if source_language == 'en':
|
if source_language == "en":
|
||||||
assert target_language in self.EN_XX_LANGUAGES[version]
|
assert target_language in self.EN_XX_LANGUAGES[version]
|
||||||
else:
|
else:
|
||||||
assert source_language in self.XX_EN_LANGUAGES[version]
|
assert source_language in self.XX_EN_LANGUAGES[version]
|
||||||
@ -85,51 +132,60 @@ class CoVoST(Dataset):
|
|||||||
# Hack here so that we can get "split" column from CoVoST TSV.
|
# Hack here so that we can get "split" column from CoVoST TSV.
|
||||||
# Note that we use CoVoST train split for ASR which is an extension
|
# Note that we use CoVoST train split for ASR which is an extension
|
||||||
# to Common Voice train split.
|
# to Common Voice train split.
|
||||||
target_language = 'de' if source_language == 'en' else 'en'
|
target_language = "de" if source_language == "en" else "en"
|
||||||
|
|
||||||
self.root = os.path.join(root, 'raw')
|
self.root = os.path.join(root, "raw")
|
||||||
os.makedirs(self.root, exist_ok=True)
|
os.makedirs(self.root, exist_ok=True)
|
||||||
|
|
||||||
cv_url = self.CV_URL_TEMPLATE.format(ver=self.CV_VERSION_ID[version],
|
cv_url = self.CV_URL_TEMPLATE.format(
|
||||||
lang=source_language)
|
ver=self.CV_VERSION_ID[version], lang=source_language
|
||||||
|
)
|
||||||
cv_archive = os.path.join(self.root, os.path.basename(cv_url))
|
cv_archive = os.path.join(self.root, os.path.basename(cv_url))
|
||||||
if download:
|
if download:
|
||||||
if not os.path.isfile(cv_archive):
|
if not os.path.isfile(cv_archive):
|
||||||
download_url(cv_url, self.root, hash_value=None)
|
download_url(cv_url, self.root, hash_value=None)
|
||||||
extract_archive(cv_archive)
|
extract_archive(cv_archive)
|
||||||
|
|
||||||
covost_url = self.COVOST_URL_TEMPLATE.format(src_lang=source_language,
|
covost_url = self.COVOST_URL_TEMPLATE.format(
|
||||||
tgt_lang=target_language)
|
src_lang=source_language, tgt_lang=target_language
|
||||||
|
)
|
||||||
covost_archive = os.path.join(self.root, os.path.basename(covost_url))
|
covost_archive = os.path.join(self.root, os.path.basename(covost_url))
|
||||||
if download:
|
if download:
|
||||||
if not os.path.isfile(covost_archive):
|
if not os.path.isfile(covost_archive):
|
||||||
download_url(covost_url, self.root, hash_value=None)
|
download_url(covost_url, self.root, hash_value=None)
|
||||||
extract_archive(covost_archive)
|
extract_archive(covost_archive)
|
||||||
|
|
||||||
cv_tsv = self.load_from_tsv(os.path.join(self.root, 'validated.tsv'))
|
cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv"))
|
||||||
covost_tsv = self.load_from_tsv(
|
covost_tsv = self.load_from_tsv(
|
||||||
os.path.join(self.root,
|
os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", ""))
|
||||||
os.path.basename(covost_url).replace('.tar.gz', ''))
|
|
||||||
)
|
)
|
||||||
df = pd.merge(left=cv_tsv[['path', 'sentence', 'client_id']],
|
df = pd.merge(
|
||||||
right=covost_tsv[['path', 'translation', 'split']],
|
left=cv_tsv[["path", "sentence", "client_id"]],
|
||||||
how='inner', on='path')
|
right=covost_tsv[["path", "translation", "split"]],
|
||||||
if split == 'train':
|
how="inner",
|
||||||
df = df[(df['split'] == split) | (df['split'] == f'{split}_covost')]
|
on="path",
|
||||||
|
)
|
||||||
|
if split == "train":
|
||||||
|
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
|
||||||
else:
|
else:
|
||||||
df = df[df['split'] == split]
|
df = df[df["split"] == split]
|
||||||
self.data = df.to_dict(orient='index').items()
|
self.data = df.to_dict(orient="index").items()
|
||||||
self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]
|
self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_from_tsv(cls, path: str):
|
def load_from_tsv(cls, path: str):
|
||||||
return pd.read_csv(
|
return pd.read_csv(
|
||||||
path, sep='\t', header=0, encoding='utf-8', escapechar='\\',
|
path,
|
||||||
quoting=csv.QUOTE_NONE, na_filter=False
|
sep="\t",
|
||||||
|
header=0,
|
||||||
|
encoding="utf-8",
|
||||||
|
escapechar="\\",
|
||||||
|
quoting=csv.QUOTE_NONE,
|
||||||
|
na_filter=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getitem__(
|
def __getitem__(
|
||||||
self, n: int
|
self, n: int
|
||||||
) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
|
) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
|
||||||
"""Load the n-th sample from the dataset.
|
"""Load the n-th sample from the dataset.
|
||||||
|
|
||||||
@ -141,12 +197,12 @@ class CoVoST(Dataset):
|
|||||||
sample_id)``
|
sample_id)``
|
||||||
"""
|
"""
|
||||||
data = self.data[n]
|
data = self.data[n]
|
||||||
path = os.path.join(self.root, 'clips', data['path'])
|
path = os.path.join(self.root, "clips", data["path"])
|
||||||
waveform, sample_rate = torchaudio.load(path)
|
waveform, sample_rate = torchaudio.load(path)
|
||||||
sentence = data['sentence']
|
sentence = data["sentence"]
|
||||||
translation = None if self.no_translation else data['translation']
|
translation = None if self.no_translation else data["translation"]
|
||||||
speaker_id = data['client_id']
|
speaker_id = data["client_id"]
|
||||||
_id = data['path'].replace('.mp3', '')
|
_id = data["path"].replace(".mp3", "")
|
||||||
return waveform, sample_rate, sentence, translation, speaker_id, _id
|
return waveform, sample_rate, sentence, translation, speaker_id, _id
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@ -157,76 +213,82 @@ def process(args):
|
|||||||
root = op.join(args.data_root, args.src_lang)
|
root = op.join(args.data_root, args.src_lang)
|
||||||
os.makedirs(root, exist_ok=True)
|
os.makedirs(root, exist_ok=True)
|
||||||
# Extract features
|
# Extract features
|
||||||
feature_root = op.join(root, 'fbank80')
|
feature_root = op.join(root, "fbank80")
|
||||||
os.makedirs(feature_root, exist_ok=True)
|
os.makedirs(feature_root, exist_ok=True)
|
||||||
for split in CoVoST.SPLITS:
|
for split in CoVoST.SPLITS:
|
||||||
print(f'Fetching split {split}...')
|
print(f"Fetching split {split}...")
|
||||||
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang,
|
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True)
|
||||||
download=True)
|
print("Extracting log mel filter bank features...")
|
||||||
print('Extracting log mel filter bank features...')
|
|
||||||
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
|
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
|
||||||
extract_fbank_features(waveform, sample_rate,
|
extract_fbank_features(
|
||||||
op.join(feature_root, f'{utt_id}.npy'))
|
waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
|
||||||
|
)
|
||||||
# Pack features into ZIP
|
# Pack features into ZIP
|
||||||
zip_filename = 'fbank80.zip'
|
zip_filename = "fbank80.zip"
|
||||||
zip_path = op.join(root, zip_filename)
|
zip_path = op.join(root, zip_filename)
|
||||||
print('ZIPing features...')
|
print("ZIPing features...")
|
||||||
create_zip(feature_root, zip_path)
|
create_zip(feature_root, zip_path)
|
||||||
print('Fetching ZIP manifest...')
|
print("Fetching ZIP manifest...")
|
||||||
zip_manifest = get_zip_manifest(args.data_root,
|
zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}")
|
||||||
f'{args.src_lang}/{zip_filename}')
|
|
||||||
# Generate TSV manifest
|
# Generate TSV manifest
|
||||||
print('Generating manifest...')
|
print("Generating manifest...")
|
||||||
train_text = []
|
train_text = []
|
||||||
task = f'asr_{args.src_lang}'
|
task = f"asr_{args.src_lang}"
|
||||||
if args.tgt_lang is not None:
|
if args.tgt_lang is not None:
|
||||||
task = f'st_{args.src_lang}_{args.tgt_lang}'
|
task = f"st_{args.src_lang}_{args.tgt_lang}"
|
||||||
for split in CoVoST.SPLITS:
|
for split in CoVoST.SPLITS:
|
||||||
manifest = {c: [] for c in MANIFEST_COLUMNS}
|
manifest = {c: [] for c in MANIFEST_COLUMNS}
|
||||||
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
|
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
|
||||||
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
|
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
|
||||||
manifest['id'].append(utt_id)
|
manifest["id"].append(utt_id)
|
||||||
manifest['audio'].append(zip_manifest[utt_id])
|
manifest["audio"].append(zip_manifest[utt_id])
|
||||||
duration_ms = int(wav.size(1) / sr * 1000)
|
duration_ms = int(wav.size(1) / sr * 1000)
|
||||||
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
|
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
|
||||||
manifest['tgt_text'].append(
|
manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
|
||||||
src_utt if args.tgt_lang is None else tgt_utt
|
manifest["speaker"].append(speaker_id)
|
||||||
)
|
is_train_split = split.startswith("train")
|
||||||
manifest['speaker'].append(speaker_id)
|
|
||||||
is_train_split = split.startswith('train')
|
|
||||||
if is_train_split:
|
if is_train_split:
|
||||||
train_text.extend(manifest['tgt_text'])
|
train_text.extend(manifest["tgt_text"])
|
||||||
df = pd.DataFrame.from_dict(manifest)
|
df = pd.DataFrame.from_dict(manifest)
|
||||||
df = filter_manifest_df(df, is_train_split=is_train_split)
|
df = filter_manifest_df(df, is_train_split=is_train_split)
|
||||||
save_df_to_tsv(df, op.join(root, f'{split}_{task}.tsv'))
|
save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv"))
|
||||||
# Generate vocab
|
# Generate vocab
|
||||||
vocab_size_str = '' if args.vocab_type == 'char' else str(args.vocab_size)
|
vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
|
||||||
spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size_str}_{task}'
|
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
|
||||||
with NamedTemporaryFile(mode='w') as f:
|
with NamedTemporaryFile(mode="w") as f:
|
||||||
for t in train_text:
|
for t in train_text:
|
||||||
f.write(t + '\n')
|
f.write(t + "\n")
|
||||||
gen_vocab(f.name, op.join(root, spm_filename_prefix),
|
gen_vocab(
|
||||||
args.vocab_type, args.vocab_size)
|
f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size
|
||||||
|
)
|
||||||
# Generate config YAML
|
# Generate config YAML
|
||||||
gen_config_yaml(root, spm_filename_prefix + '.model',
|
gen_config_yaml(
|
||||||
yaml_filename=f'config_{task}.yaml',
|
root,
|
||||||
specaugment_policy='lb')
|
spm_filename_prefix + ".model",
|
||||||
|
yaml_filename=f"config_{task}.yaml",
|
||||||
|
specaugment_policy="lb",
|
||||||
|
)
|
||||||
# Clean up
|
# Clean up
|
||||||
shutil.rmtree(feature_root)
|
shutil.rmtree(feature_root)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--data-root', '-d', required=True, type=str)
|
parser.add_argument("--data-root", "-d", required=True, type=str)
|
||||||
parser.add_argument('--vocab-type', default='unigram', required=True,
|
parser.add_argument(
|
||||||
type=str, choices=['bpe', 'unigram', 'char']),
|
"--vocab-type",
|
||||||
parser.add_argument('--vocab-size', default=1000, type=int)
|
default="unigram",
|
||||||
parser.add_argument('--src-lang', '-s', required=True, type=str)
|
required=True,
|
||||||
parser.add_argument('--tgt-lang', '-t', type=str)
|
type=str,
|
||||||
|
choices=["bpe", "unigram", "char"],
|
||||||
|
),
|
||||||
|
parser.add_argument("--vocab-size", default=1000, type=int)
|
||||||
|
parser.add_argument("--src-lang", "-s", required=True, type=str)
|
||||||
|
parser.add_argument("--tgt-lang", "-t", type=str)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
process(args)
|
process(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -6,91 +6,114 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import os.path as op
|
import os.path as op
|
||||||
|
import shutil
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
from torchaudio.datasets import LIBRISPEECH
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from examples.speech_to_text.data_utils import (
|
from examples.speech_to_text.data_utils import (
|
||||||
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
|
create_zip,
|
||||||
extract_fbank_features, gen_config_yaml
|
extract_fbank_features,
|
||||||
|
gen_config_yaml,
|
||||||
|
gen_vocab,
|
||||||
|
get_zip_manifest,
|
||||||
|
save_df_to_tsv,
|
||||||
)
|
)
|
||||||
|
from torchaudio.datasets import LIBRISPEECH
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
SPLITS = ['train-clean-100', 'train-clean-360', 'train-other-500', 'dev-clean',
|
SPLITS = [
|
||||||
'dev-other', 'test-clean', 'test-other']
|
"train-clean-100",
|
||||||
|
"train-clean-360",
|
||||||
|
"train-other-500",
|
||||||
|
"dev-clean",
|
||||||
|
"dev-other",
|
||||||
|
"test-clean",
|
||||||
|
"test-other",
|
||||||
|
]
|
||||||
|
|
||||||
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
|
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
|
||||||
|
|
||||||
|
|
||||||
def process(args):
|
def process(args):
|
||||||
os.makedirs(args.output_root, exist_ok=True)
|
os.makedirs(args.output_root, exist_ok=True)
|
||||||
# Extract features
|
# Extract features
|
||||||
feature_root = op.join(args.output_root, 'fbank80')
|
feature_root = op.join(args.output_root, "fbank80")
|
||||||
os.makedirs(feature_root, exist_ok=True)
|
os.makedirs(feature_root, exist_ok=True)
|
||||||
for split in SPLITS:
|
for split in SPLITS:
|
||||||
print(f'Fetching split {split}...')
|
print(f"Fetching split {split}...")
|
||||||
dataset = LIBRISPEECH(args.output_root, url=split, download=True)
|
dataset = LIBRISPEECH(args.output_root, url=split, download=True)
|
||||||
print('Extracting log mel filter bank features...')
|
print("Extracting log mel filter bank features...")
|
||||||
for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
|
for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
|
||||||
sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
|
sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
|
||||||
extract_fbank_features(wav, sample_rate,
|
extract_fbank_features(
|
||||||
op.join(feature_root, f'{sample_id}.npy'))
|
wav, sample_rate, op.join(feature_root, f"{sample_id}.npy")
|
||||||
|
)
|
||||||
# Pack features into ZIP
|
# Pack features into ZIP
|
||||||
zip_filename = 'fbank80.zip'
|
zip_filename = "fbank80.zip"
|
||||||
zip_path = op.join(args.output_root, zip_filename)
|
zip_path = op.join(args.output_root, zip_filename)
|
||||||
print('ZIPing features...')
|
print("ZIPing features...")
|
||||||
create_zip(feature_root, zip_path)
|
create_zip(feature_root, zip_path)
|
||||||
print('Fetching ZIP manifest...')
|
print("Fetching ZIP manifest...")
|
||||||
zip_manifest = get_zip_manifest(args.output_root, zip_filename)
|
zip_manifest = get_zip_manifest(args.output_root, zip_filename)
|
||||||
# Generate TSV manifest
|
# Generate TSV manifest
|
||||||
print('Generating manifest...')
|
print("Generating manifest...")
|
||||||
train_text = []
|
train_text = []
|
||||||
for split in SPLITS:
|
for split in SPLITS:
|
||||||
manifest = {c: [] for c in MANIFEST_COLUMNS}
|
manifest = {c: [] for c in MANIFEST_COLUMNS}
|
||||||
dataset = LIBRISPEECH(args.output_root, url=split)
|
dataset = LIBRISPEECH(args.output_root, url=split)
|
||||||
for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
|
for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
|
||||||
sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
|
sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
|
||||||
manifest['id'].append(sample_id)
|
manifest["id"].append(sample_id)
|
||||||
manifest['audio'].append(zip_manifest[sample_id])
|
manifest["audio"].append(zip_manifest[sample_id])
|
||||||
duration_ms = int(wav.size(1) / sample_rate * 1000)
|
duration_ms = int(wav.size(1) / sample_rate * 1000)
|
||||||
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
|
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
|
||||||
manifest['tgt_text'].append(utt)
|
manifest["tgt_text"].append(utt)
|
||||||
manifest['speaker'].append(spk_id)
|
manifest["speaker"].append(spk_id)
|
||||||
save_df_to_tsv(pd.DataFrame.from_dict(manifest),
|
save_df_to_tsv(
|
||||||
op.join(args.output_root, f'{split}.tsv'))
|
pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv")
|
||||||
if split.startswith('train'):
|
)
|
||||||
train_text.extend(manifest['tgt_text'])
|
if split.startswith("train"):
|
||||||
|
train_text.extend(manifest["tgt_text"])
|
||||||
# Generate vocab
|
# Generate vocab
|
||||||
vocab_size = '' if args.vocab_type == 'char' else str(args.vocab_size)
|
vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
|
||||||
spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size}'
|
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
|
||||||
with NamedTemporaryFile(mode='w') as f:
|
with NamedTemporaryFile(mode="w") as f:
|
||||||
for t in train_text:
|
for t in train_text:
|
||||||
f.write(t + '\n')
|
f.write(t + "\n")
|
||||||
gen_vocab(f.name, op.join(args.output_root, spm_filename_prefix),
|
gen_vocab(
|
||||||
args.vocab_type, args.vocab_size)
|
f.name,
|
||||||
|
op.join(args.output_root, spm_filename_prefix),
|
||||||
|
args.vocab_type,
|
||||||
|
args.vocab_size,
|
||||||
|
)
|
||||||
# Generate config YAML
|
# Generate config YAML
|
||||||
gen_config_yaml(args.output_root, spm_filename_prefix + '.model',
|
gen_config_yaml(
|
||||||
specaugment_policy='ld')
|
args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld"
|
||||||
|
)
|
||||||
# Clean up
|
# Clean up
|
||||||
shutil.rmtree(feature_root)
|
shutil.rmtree(feature_root)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--output-root', '-o', required=True, type=str)
|
parser.add_argument("--output-root", "-o", required=True, type=str)
|
||||||
parser.add_argument('--vocab-type', default='unigram', required=True,
|
parser.add_argument(
|
||||||
type=str, choices=['bpe', 'unigram', 'char']),
|
"--vocab-type",
|
||||||
parser.add_argument('--vocab-size', default=10000, type=int)
|
default="unigram",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
choices=["bpe", "unigram", "char"],
|
||||||
|
),
|
||||||
|
parser.add_argument("--vocab-size", default=10000, type=int)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
process(args)
|
process(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -6,29 +6,34 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
import os
|
import os
|
||||||
import os.path as op
|
import os.path as op
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Tuple
|
|
||||||
from itertools import groupby
|
from itertools import groupby
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from torch.utils.data import Dataset
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from examples.speech_to_text.data_utils import (
|
from examples.speech_to_text.data_utils import (
|
||||||
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
|
create_zip,
|
||||||
extract_fbank_features, gen_config_yaml, filter_manifest_df
|
extract_fbank_features,
|
||||||
|
filter_manifest_df,
|
||||||
|
gen_config_yaml,
|
||||||
|
gen_vocab,
|
||||||
|
get_zip_manifest,
|
||||||
|
save_df_to_tsv,
|
||||||
)
|
)
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
|
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
|
||||||
TASKS = ['asr', 'st']
|
TASKS = ["asr", "st"]
|
||||||
|
|
||||||
|
|
||||||
class MUSTC(Dataset):
|
class MUSTC(Dataset):
|
||||||
@ -37,49 +42,55 @@ class MUSTC(Dataset):
|
|||||||
waveform, sample_rate, source utterance, target utterance, speaker_id,
|
waveform, sample_rate, source utterance, target utterance, speaker_id,
|
||||||
utterance_id
|
utterance_id
|
||||||
"""
|
"""
|
||||||
SPLITS = ['train', 'dev', 'tst-COMMON', 'tst-HE']
|
|
||||||
LANGUAGES = ['de', 'es', 'fr', 'it', 'nl', 'pt', 'ro', 'ru']
|
SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
|
||||||
|
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
|
||||||
|
|
||||||
def __init__(self, root: str, lang: str, split: str) -> None:
|
def __init__(self, root: str, lang: str, split: str) -> None:
|
||||||
assert split in self.SPLITS and lang in self.LANGUAGES
|
assert split in self.SPLITS and lang in self.LANGUAGES
|
||||||
_root = op.join(root, f'en-{lang}', 'data', split)
|
_root = op.join(root, f"en-{lang}", "data", split)
|
||||||
wav_root, txt_root = op.join(_root, 'wav'), op.join(_root, 'txt')
|
wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt")
|
||||||
assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
|
assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
|
||||||
# Load audio segments
|
# Load audio segments
|
||||||
try:
|
try:
|
||||||
import yaml
|
import yaml
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print('Please install PyYAML to load YAML files for '
|
print("Please install PyYAML to load YAML files for " "the MuST-C dataset")
|
||||||
'the MuST-C dataset')
|
with open(op.join(txt_root, f"{split}.yaml")) as f:
|
||||||
with open(op.join(txt_root, f'{split}.yaml')) as f:
|
|
||||||
segments = yaml.load(f, Loader=yaml.BaseLoader)
|
segments = yaml.load(f, Loader=yaml.BaseLoader)
|
||||||
# Load source and target utterances
|
# Load source and target utterances
|
||||||
for _lang in ['en', lang]:
|
for _lang in ["en", lang]:
|
||||||
with open(op.join(txt_root, f'{split}.{_lang}')) as f:
|
with open(op.join(txt_root, f"{split}.{_lang}")) as f:
|
||||||
utterances = [r.strip() for r in f]
|
utterances = [r.strip() for r in f]
|
||||||
assert len(segments) == len(utterances)
|
assert len(segments) == len(utterances)
|
||||||
for i, u in enumerate(utterances):
|
for i, u in enumerate(utterances):
|
||||||
segments[i][_lang] = u
|
segments[i][_lang] = u
|
||||||
# Gather info
|
# Gather info
|
||||||
self.data = []
|
self.data = []
|
||||||
for wav_filename, _seg_group in groupby(segments, lambda x: x['wav']):
|
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
|
||||||
wav_path = op.join(wav_root, wav_filename)
|
wav_path = op.join(wav_root, wav_filename)
|
||||||
sample_rate = torchaudio.info(wav_path)[0].rate
|
sample_rate = torchaudio.info(wav_path)[0].rate
|
||||||
seg_group = sorted(_seg_group, key=lambda x: x['offset'])
|
seg_group = sorted(_seg_group, key=lambda x: x["offset"])
|
||||||
for i, segment in enumerate(seg_group):
|
for i, segment in enumerate(seg_group):
|
||||||
offset = int(float(segment['offset']) * sample_rate)
|
offset = int(float(segment["offset"]) * sample_rate)
|
||||||
n_frames = int(float(segment['duration']) * sample_rate)
|
n_frames = int(float(segment["duration"]) * sample_rate)
|
||||||
_id = f'{op.splitext(wav_filename)[0]}_{i}'
|
_id = f"{op.splitext(wav_filename)[0]}_{i}"
|
||||||
self.data.append(
|
self.data.append(
|
||||||
(wav_path, offset, n_frames, sample_rate, segment['en'],
|
(
|
||||||
segment[lang], segment['speaker_id'], _id)
|
wav_path,
|
||||||
|
offset,
|
||||||
|
n_frames,
|
||||||
|
sample_rate,
|
||||||
|
segment["en"],
|
||||||
|
segment[lang],
|
||||||
|
segment["speaker_id"],
|
||||||
|
_id,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
|
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
|
||||||
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = \
|
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
|
||||||
self.data[n]
|
waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
|
||||||
waveform, _ = torchaudio.load(wav_path, offset=offset,
|
|
||||||
num_frames=n_frames)
|
|
||||||
return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
|
return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@ -88,85 +99,102 @@ class MUSTC(Dataset):
|
|||||||
|
|
||||||
def process(args):
|
def process(args):
|
||||||
for lang in MUSTC.LANGUAGES:
|
for lang in MUSTC.LANGUAGES:
|
||||||
cur_root = op.join(args.data_root, f'en-{lang}')
|
cur_root = op.join(args.data_root, f"en-{lang}")
|
||||||
if not op.isdir(cur_root):
|
if not op.isdir(cur_root):
|
||||||
print(f'{cur_root} does not exist. Skipped.')
|
print(f"{cur_root} does not exist. Skipped.")
|
||||||
continue
|
continue
|
||||||
# Extract features
|
# Extract features
|
||||||
feature_root = op.join(cur_root, 'fbank80')
|
feature_root = op.join(cur_root, "fbank80")
|
||||||
os.makedirs(feature_root, exist_ok=True)
|
os.makedirs(feature_root, exist_ok=True)
|
||||||
for split in MUSTC.SPLITS:
|
for split in MUSTC.SPLITS:
|
||||||
print(f'Fetching split {split}...')
|
print(f"Fetching split {split}...")
|
||||||
dataset = MUSTC(args.data_root, lang, split)
|
dataset = MUSTC(args.data_root, lang, split)
|
||||||
print('Extracting log mel filter bank features...')
|
print("Extracting log mel filter bank features...")
|
||||||
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
|
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
|
||||||
extract_fbank_features(waveform, sample_rate,
|
extract_fbank_features(
|
||||||
op.join(feature_root, f'{utt_id}.npy'))
|
waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
|
||||||
|
)
|
||||||
# Pack features into ZIP
|
# Pack features into ZIP
|
||||||
zip_filename = 'fbank80.zip'
|
zip_filename = "fbank80.zip"
|
||||||
zip_path = op.join(cur_root, zip_filename)
|
zip_path = op.join(cur_root, zip_filename)
|
||||||
print('ZIPing features...')
|
print("ZIPing features...")
|
||||||
create_zip(feature_root, zip_path)
|
create_zip(feature_root, zip_path)
|
||||||
print('Fetching ZIP manifest...')
|
print("Fetching ZIP manifest...")
|
||||||
zip_manifest = get_zip_manifest(args.data_root,
|
zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}")
|
||||||
f'en-{lang}/{zip_filename}')
|
|
||||||
# Generate TSV manifest
|
# Generate TSV manifest
|
||||||
print('Generating manifest...')
|
print("Generating manifest...")
|
||||||
train_text = {task: [] for task in TASKS}
|
train_text = {task: [] for task in TASKS}
|
||||||
for split in MUSTC.SPLITS:
|
for split in MUSTC.SPLITS:
|
||||||
is_train_split = split.startswith('train')
|
is_train_split = split.startswith("train")
|
||||||
manifest = {c: [] for c in MANIFEST_COLUMNS}
|
manifest = {c: [] for c in MANIFEST_COLUMNS}
|
||||||
text = {task: [] for task in TASKS}
|
text = {task: [] for task in TASKS}
|
||||||
dataset = MUSTC(args.data_root, lang, split)
|
dataset = MUSTC(args.data_root, lang, split)
|
||||||
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
|
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
|
||||||
manifest['id'].append(utt_id)
|
manifest["id"].append(utt_id)
|
||||||
manifest['audio'].append(zip_manifest[utt_id])
|
manifest["audio"].append(zip_manifest[utt_id])
|
||||||
duration_ms = int(wav.size(1) / sr * 1000)
|
duration_ms = int(wav.size(1) / sr * 1000)
|
||||||
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
|
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
|
||||||
text['asr'].append(src_utt)
|
text["asr"].append(src_utt)
|
||||||
text['st'].append(tgt_utt)
|
text["st"].append(tgt_utt)
|
||||||
manifest['speaker'].append(speaker_id)
|
manifest["speaker"].append(speaker_id)
|
||||||
if is_train_split:
|
if is_train_split:
|
||||||
for task in TASKS:
|
for task in TASKS:
|
||||||
train_text[task].extend(text[task])
|
train_text[task].extend(text[task])
|
||||||
for task in TASKS:
|
for task in TASKS:
|
||||||
manifest['tgt_text'] = text[task]
|
manifest["tgt_text"] = text[task]
|
||||||
df = pd.DataFrame.from_dict(manifest)
|
df = pd.DataFrame.from_dict(manifest)
|
||||||
df = filter_manifest_df(df, is_train_split=is_train_split)
|
df = filter_manifest_df(df, is_train_split=is_train_split)
|
||||||
save_df_to_tsv(df, op.join(cur_root, f'{split}_{task}.tsv'))
|
save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv"))
|
||||||
# Generate vocab
|
# Generate vocab
|
||||||
for task in TASKS:
|
for task in TASKS:
|
||||||
vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
|
vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
|
||||||
if task == 'st':
|
if task == "st":
|
||||||
vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
|
vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
|
||||||
vocab_size_str = '' if vocab_type == 'char' else str(vocab_size)
|
vocab_size_str = "" if vocab_type == "char" else str(vocab_size)
|
||||||
spm_filename_prefix = f'spm_{vocab_type}{vocab_size_str}_{task}'
|
spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}"
|
||||||
with NamedTemporaryFile(mode='w') as f:
|
with NamedTemporaryFile(mode="w") as f:
|
||||||
for t in train_text[task]:
|
for t in train_text[task]:
|
||||||
f.write(t + '\n')
|
f.write(t + "\n")
|
||||||
gen_vocab(f.name, op.join(cur_root, spm_filename_prefix),
|
gen_vocab(
|
||||||
vocab_type, vocab_size)
|
f.name,
|
||||||
|
op.join(cur_root, spm_filename_prefix),
|
||||||
|
vocab_type,
|
||||||
|
vocab_size,
|
||||||
|
)
|
||||||
# Generate config YAML
|
# Generate config YAML
|
||||||
gen_config_yaml(cur_root, spm_filename_prefix + '.model',
|
gen_config_yaml(
|
||||||
yaml_filename=f'config_{task}.yaml',
|
cur_root,
|
||||||
specaugment_policy='lb')
|
spm_filename_prefix + ".model",
|
||||||
|
yaml_filename=f"config_{task}.yaml",
|
||||||
|
specaugment_policy="lb",
|
||||||
|
)
|
||||||
# Clean up
|
# Clean up
|
||||||
shutil.rmtree(feature_root)
|
shutil.rmtree(feature_root)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--data-root', '-d', required=True, type=str)
|
parser.add_argument("--data-root", "-d", required=True, type=str)
|
||||||
parser.add_argument('--asr-vocab-type', default='unigram', required=True,
|
parser.add_argument(
|
||||||
type=str, choices=['bpe', 'unigram', 'char']),
|
"--asr-vocab-type",
|
||||||
parser.add_argument('--st-vocab-type', default='unigram', required=True,
|
default="unigram",
|
||||||
type=str, choices=['bpe', 'unigram', 'char']),
|
required=True,
|
||||||
parser.add_argument('--asr-vocab-size', default=5000, type=int)
|
type=str,
|
||||||
parser.add_argument('--st-vocab-size', default=8000, type=int)
|
choices=["bpe", "unigram", "char"],
|
||||||
|
),
|
||||||
|
parser.add_argument(
|
||||||
|
"--st-vocab-type",
|
||||||
|
default="unigram",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
choices=["bpe", "unigram", "char"],
|
||||||
|
),
|
||||||
|
parser.add_argument("--asr-vocab-size", default=5000, type=int)
|
||||||
|
parser.add_argument("--st-vocab-size", default=8000, type=int)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
process(args)
|
process(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -12,9 +12,9 @@ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from itertools import chain
|
|
||||||
import sys
|
|
||||||
import random
|
import random
|
||||||
|
import sys
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
|
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
|
||||||
@ -22,17 +22,21 @@ from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(sys.argv[0])
|
parser = argparse.ArgumentParser(sys.argv[0])
|
||||||
parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
|
parser.add_argument(
|
||||||
help='path to system output')
|
"--sys", nargs="*", default="", metavar="FILE", help="path to system output"
|
||||||
parser.add_argument('--ref', default='', metavar='FILE',
|
)
|
||||||
help='path to references')
|
parser.add_argument("--ref", default="", metavar="FILE", help="path to references")
|
||||||
parser.add_argument('--output', default='', metavar='FILE',
|
parser.add_argument(
|
||||||
help='print outputs into a pretty format')
|
"--output",
|
||||||
|
default="",
|
||||||
|
metavar="FILE",
|
||||||
|
help="print outputs into a pretty format",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.sys:
|
if args.sys:
|
||||||
src, tgt, hypos, log_probs = load_sys(args.sys)
|
src, tgt, hypos, log_probs = load_sys(args.sys)
|
||||||
print('pairwise BLEU: %.2f' % pairwise(hypos))
|
print("pairwise BLEU: %.2f" % pairwise(hypos))
|
||||||
if args.output:
|
if args.output:
|
||||||
merge(src, tgt, hypos, log_probs, args.output)
|
merge(src, tgt, hypos, log_probs, args.output)
|
||||||
|
|
||||||
@ -58,18 +62,18 @@ def load_sys(paths):
|
|||||||
# S: source
|
# S: source
|
||||||
# T: target
|
# T: target
|
||||||
# D: detokenized system output
|
# D: detokenized system output
|
||||||
if line.startswith(('S-', 'T-', 'D-')):
|
if line.startswith(("S-", "T-", "D-")):
|
||||||
i = int(line[line.find('-')+1:line.find('\t')])
|
i = int(line[line.find("-") + 1 : line.find("\t")])
|
||||||
if line.startswith('S-'):
|
if line.startswith("S-"):
|
||||||
src[i] = line.split('\t')[1]
|
src[i] = line.split("\t")[1]
|
||||||
if line.startswith('T-'):
|
if line.startswith("T-"):
|
||||||
tgt[i] = line.split('\t')[1]
|
tgt[i] = line.split("\t")[1]
|
||||||
if line.startswith('D-'):
|
if line.startswith("D-"):
|
||||||
if i not in hypos:
|
if i not in hypos:
|
||||||
hypos[i] = []
|
hypos[i] = []
|
||||||
log_probs[i] = []
|
log_probs[i] = []
|
||||||
hypos[i].append(line.split('\t')[2])
|
hypos[i].append(line.split("\t")[2])
|
||||||
log_probs[i].append(float(line.split('\t')[1]))
|
log_probs[i].append(float(line.split("\t")[1]))
|
||||||
return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
|
return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
|
||||||
|
|
||||||
|
|
||||||
@ -79,34 +83,34 @@ def load_ref(path):
|
|||||||
src, tgt, refs = [], [], []
|
src, tgt, refs = [], [], []
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(lines):
|
while i < len(lines):
|
||||||
if lines[i].startswith('S-'):
|
if lines[i].startswith("S-"):
|
||||||
src.append(lines[i].split('\t')[1].rstrip())
|
src.append(lines[i].split("\t")[1].rstrip())
|
||||||
i += 1
|
i += 1
|
||||||
elif lines[i].startswith('T-'):
|
elif lines[i].startswith("T-"):
|
||||||
tgt.append(lines[i].split('\t')[1].rstrip())
|
tgt.append(lines[i].split("\t")[1].rstrip())
|
||||||
i += 1
|
i += 1
|
||||||
else:
|
else:
|
||||||
a = []
|
a = []
|
||||||
while i < len(lines) and lines[i].startswith('R'):
|
while i < len(lines) and lines[i].startswith("R"):
|
||||||
a.append(lines[i].split('\t')[1].rstrip())
|
a.append(lines[i].split("\t")[1].rstrip())
|
||||||
i += 1
|
i += 1
|
||||||
refs.append(a)
|
refs.append(a)
|
||||||
return src, tgt, refs
|
return src, tgt, refs
|
||||||
|
|
||||||
|
|
||||||
def merge(src, tgt, hypos, log_probs, path):
|
def merge(src, tgt, hypos, log_probs, path):
|
||||||
with open(path, 'w') as f:
|
with open(path, "w") as f:
|
||||||
for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
|
for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
|
||||||
f.write(s + '\n')
|
f.write(s + "\n")
|
||||||
f.write(t + '\n')
|
f.write(t + "\n")
|
||||||
f.write('\n')
|
f.write("\n")
|
||||||
for h, lp in zip(hs, lps):
|
for h, lp in zip(hs, lps):
|
||||||
f.write('\t%f\t%s\n' % (lp, h.strip()))
|
f.write("\t%f\t%s\n" % (lp, h.strip()))
|
||||||
f.write('------------------------------------------------------\n')
|
f.write("------------------------------------------------------\n")
|
||||||
|
|
||||||
|
|
||||||
def corpus_bleu(sys_stream, ref_streams):
|
def corpus_bleu(sys_stream, ref_streams):
|
||||||
bleu = _corpus_bleu(sys_stream, ref_streams, tokenize='none')
|
bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
|
||||||
return bleu.score
|
return bleu.score
|
||||||
|
|
||||||
|
|
||||||
@ -116,9 +120,11 @@ def sentence_bleu(hypothesis, reference):
|
|||||||
bleu.counts[i] += 1
|
bleu.counts[i] += 1
|
||||||
bleu.totals[i] += 1
|
bleu.totals[i] += 1
|
||||||
bleu = compute_bleu(
|
bleu = compute_bleu(
|
||||||
bleu.counts, bleu.totals,
|
bleu.counts,
|
||||||
bleu.sys_len, bleu.ref_len,
|
bleu.totals,
|
||||||
smooth_method='exp',
|
bleu.sys_len,
|
||||||
|
bleu.ref_len,
|
||||||
|
smooth_method="exp",
|
||||||
)
|
)
|
||||||
return bleu.score
|
return bleu.score
|
||||||
|
|
||||||
@ -150,7 +156,7 @@ def multi_ref(refs, hypos):
|
|||||||
best = [k for k in range(len(rs)) if s[k] == s[j]]
|
best = [k for k in range(len(rs)) if s[k] == s[j]]
|
||||||
a.add(random.choice(best))
|
a.add(random.choice(best))
|
||||||
ref_cnt += len(a)
|
ref_cnt += len(a)
|
||||||
print('#refs covered: %.2f' % (ref_cnt / len(refs)))
|
print("#refs covered: %.2f" % (ref_cnt / len(refs)))
|
||||||
|
|
||||||
# transpose refs and hypos
|
# transpose refs and hypos
|
||||||
refs = list(zip(*refs))
|
refs = list(zip(*refs))
|
||||||
@ -160,33 +166,32 @@ def multi_ref(refs, hypos):
|
|||||||
k = len(hypos)
|
k = len(hypos)
|
||||||
m = len(refs)
|
m = len(refs)
|
||||||
flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
|
flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
|
||||||
duplicated_refs = [
|
duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
|
||||||
[ref for ref in refs_i for _ in range(k)]
|
|
||||||
for refs_i in refs
|
|
||||||
]
|
|
||||||
loo_bleus = []
|
loo_bleus = []
|
||||||
for held_out_ref in range(m):
|
for held_out_ref in range(m):
|
||||||
remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref+1:]
|
remaining_refs = (
|
||||||
|
duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
|
||||||
|
)
|
||||||
assert len(remaining_refs) == m - 1
|
assert len(remaining_refs) == m - 1
|
||||||
loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
|
loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
|
||||||
print('average multi-reference BLEU (leave-one-out): %.2f' % np.mean(loo_bleus))
|
print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))
|
||||||
|
|
||||||
|
|
||||||
def intra_ref(refs):
|
def intra_ref(refs):
|
||||||
print('ref pairwise BLEU: %.2f' % pairwise(refs))
|
print("ref pairwise BLEU: %.2f" % pairwise(refs))
|
||||||
refs = list(zip(*refs))
|
refs = list(zip(*refs))
|
||||||
m = len(refs)
|
m = len(refs)
|
||||||
concat_h = []
|
concat_h = []
|
||||||
concat_rest = [[] for j in range(m - 1)]
|
concat_rest = [[] for j in range(m - 1)]
|
||||||
for i, h in enumerate(refs):
|
for i, h in enumerate(refs):
|
||||||
rest = refs[:i] + refs[i+1:]
|
rest = refs[:i] + refs[i + 1 :]
|
||||||
concat_h.append(h)
|
concat_h.append(h)
|
||||||
for j in range(m - 1):
|
for j in range(m - 1):
|
||||||
concat_rest[j].extend(rest[j])
|
concat_rest[j].extend(rest[j])
|
||||||
concat_h = list(chain.from_iterable(concat_h))
|
concat_h = list(chain.from_iterable(concat_h))
|
||||||
bleu = corpus_bleu(concat_h, concat_rest)
|
bleu = corpus_bleu(concat_h, concat_rest)
|
||||||
print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
|
print("multi-reference BLEU (leave-one-out): %.2f" % bleu)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -21,6 +21,6 @@ class LogSumExpMoE(torch.autograd.Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
posterior, = ctx.saved_tensors
|
(posterior,) = ctx.saved_tensors
|
||||||
grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
|
grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
|
||||||
return grad_logp, None, None
|
return grad_logp, None, None
|
||||||
|
@ -26,15 +26,15 @@ class MeanPoolGatingNetwork(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, encoder_out):
|
def forward(self, encoder_out):
|
||||||
if not (
|
if not (
|
||||||
hasattr(encoder_out, 'encoder_out')
|
hasattr(encoder_out, "encoder_out")
|
||||||
and hasattr(encoder_out, 'encoder_padding_mask')
|
and hasattr(encoder_out, "encoder_padding_mask")
|
||||||
and encoder_out.encoder_out.size(2) == self.embed_dim
|
and encoder_out.encoder_out.size(2) == self.embed_dim
|
||||||
):
|
):
|
||||||
raise ValueError('Unexpected format for encoder_out')
|
raise ValueError("Unexpected format for encoder_out")
|
||||||
|
|
||||||
# mean pooling over time
|
# mean pooling over time
|
||||||
encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
|
encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
|
||||||
encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
|
encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
|
||||||
if encoder_padding_mask is not None:
|
if encoder_padding_mask is not None:
|
||||||
encoder_out = encoder_out.clone() # required because of transpose above
|
encoder_out = encoder_out.clone() # required because of transpose above
|
||||||
encoder_out[encoder_padding_mask] = 0
|
encoder_out[encoder_padding_mask] = 0
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from fairseq import metrics, utils
|
from fairseq import metrics, utils
|
||||||
from fairseq.tasks import register_task
|
from fairseq.tasks import register_task
|
||||||
from fairseq.tasks.translation import TranslationTask
|
from fairseq.tasks.translation import TranslationTask
|
||||||
@ -13,7 +12,7 @@ from .logsumexp_moe import LogSumExpMoE
|
|||||||
from .mean_pool_gating_network import MeanPoolGatingNetwork
|
from .mean_pool_gating_network import MeanPoolGatingNetwork
|
||||||
|
|
||||||
|
|
||||||
@register_task('translation_moe')
|
@register_task("translation_moe")
|
||||||
class TranslationMoETask(TranslationTask):
|
class TranslationMoETask(TranslationTask):
|
||||||
"""
|
"""
|
||||||
Translation task for Mixture of Experts (MoE) models.
|
Translation task for Mixture of Experts (MoE) models.
|
||||||
@ -58,19 +57,19 @@ class TranslationMoETask(TranslationTask):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def __init__(self, args, src_dict, tgt_dict):
|
def __init__(self, args, src_dict, tgt_dict):
|
||||||
if args.method == 'sMoElp':
|
if args.method == "sMoElp":
|
||||||
# soft MoE with learned prior
|
# soft MoE with learned prior
|
||||||
self.uniform_prior = False
|
self.uniform_prior = False
|
||||||
self.hard_selection = False
|
self.hard_selection = False
|
||||||
elif args.method == 'sMoEup':
|
elif args.method == "sMoEup":
|
||||||
# soft MoE with uniform prior
|
# soft MoE with uniform prior
|
||||||
self.uniform_prior = True
|
self.uniform_prior = True
|
||||||
self.hard_selection = False
|
self.hard_selection = False
|
||||||
elif args.method == 'hMoElp':
|
elif args.method == "hMoElp":
|
||||||
# hard MoE with learned prior
|
# hard MoE with learned prior
|
||||||
self.uniform_prior = False
|
self.uniform_prior = False
|
||||||
self.hard_selection = True
|
self.hard_selection = True
|
||||||
elif args.method == 'hMoEup':
|
elif args.method == "hMoEup":
|
||||||
# hard MoE with uniform prior
|
# hard MoE with uniform prior
|
||||||
self.uniform_prior = True
|
self.uniform_prior = True
|
||||||
self.hard_selection = True
|
self.hard_selection = True
|
||||||
@ -78,50 +77,56 @@ class TranslationMoETask(TranslationTask):
|
|||||||
# add indicator tokens for each expert
|
# add indicator tokens for each expert
|
||||||
for i in range(args.num_experts):
|
for i in range(args.num_experts):
|
||||||
# add to both dictionaries in case we're sharing embeddings
|
# add to both dictionaries in case we're sharing embeddings
|
||||||
src_dict.add_symbol('<expert_{}>'.format(i))
|
src_dict.add_symbol("<expert_{}>".format(i))
|
||||||
tgt_dict.add_symbol('<expert_{}>'.format(i))
|
tgt_dict.add_symbol("<expert_{}>".format(i))
|
||||||
|
|
||||||
super().__init__(args, src_dict, tgt_dict)
|
super().__init__(args, src_dict, tgt_dict)
|
||||||
|
|
||||||
def build_model(self, args):
|
def build_model(self, args):
|
||||||
from fairseq import models
|
from fairseq import models
|
||||||
|
|
||||||
model = models.build_model(args, self)
|
model = models.build_model(args, self)
|
||||||
if not self.uniform_prior and not hasattr(model, 'gating_network'):
|
if not self.uniform_prior and not hasattr(model, "gating_network"):
|
||||||
if self.args.mean_pool_gating_network:
|
if self.args.mean_pool_gating_network:
|
||||||
if getattr(args, 'mean_pool_gating_network_encoder_dim', None):
|
if getattr(args, "mean_pool_gating_network_encoder_dim", None):
|
||||||
encoder_dim = args.mean_pool_gating_network_encoder_dim
|
encoder_dim = args.mean_pool_gating_network_encoder_dim
|
||||||
elif getattr(args, 'encoder_embed_dim', None):
|
elif getattr(args, "encoder_embed_dim", None):
|
||||||
# assume that encoder_embed_dim is the encoder's output dimension
|
# assume that encoder_embed_dim is the encoder's output dimension
|
||||||
encoder_dim = args.encoder_embed_dim
|
encoder_dim = args.encoder_embed_dim
|
||||||
else:
|
else:
|
||||||
raise ValueError('Must specify --mean-pool-gating-network-encoder-dim')
|
raise ValueError(
|
||||||
|
"Must specify --mean-pool-gating-network-encoder-dim"
|
||||||
|
)
|
||||||
|
|
||||||
if getattr(args, 'mean_pool_gating_network_dropout', None):
|
if getattr(args, "mean_pool_gating_network_dropout", None):
|
||||||
dropout = args.mean_pool_gating_network_dropout
|
dropout = args.mean_pool_gating_network_dropout
|
||||||
elif getattr(args, 'dropout', None):
|
elif getattr(args, "dropout", None):
|
||||||
dropout = args.dropout
|
dropout = args.dropout
|
||||||
else:
|
else:
|
||||||
raise ValueError('Must specify --mean-pool-gating-network-dropout')
|
raise ValueError("Must specify --mean-pool-gating-network-dropout")
|
||||||
|
|
||||||
model.gating_network = MeanPoolGatingNetwork(
|
model.gating_network = MeanPoolGatingNetwork(
|
||||||
encoder_dim, args.num_experts, dropout,
|
encoder_dim,
|
||||||
|
args.num_experts,
|
||||||
|
dropout,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'translation_moe task with learned prior requires the model to '
|
"translation_moe task with learned prior requires the model to "
|
||||||
'have a gating network; try using --mean-pool-gating-network'
|
"have a gating network; try using --mean-pool-gating-network"
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def expert_index(self, i):
|
def expert_index(self, i):
|
||||||
return i + self.tgt_dict.index('<expert_0>')
|
return i + self.tgt_dict.index("<expert_0>")
|
||||||
|
|
||||||
def _get_loss(self, sample, model, criterion):
|
def _get_loss(self, sample, model, criterion):
|
||||||
assert hasattr(criterion, 'compute_loss'), \
|
assert hasattr(
|
||||||
'translation_moe task requires the criterion to implement the compute_loss() method'
|
criterion, "compute_loss"
|
||||||
|
), "translation_moe task requires the criterion to implement the compute_loss() method"
|
||||||
|
|
||||||
k = self.args.num_experts
|
k = self.args.num_experts
|
||||||
bsz = sample['target'].size(0)
|
bsz = sample["target"].size(0)
|
||||||
|
|
||||||
def get_lprob_y(encoder_out, prev_output_tokens_k):
|
def get_lprob_y(encoder_out, prev_output_tokens_k):
|
||||||
net_output = model.decoder(
|
net_output = model.decoder(
|
||||||
@ -134,20 +139,22 @@ class TranslationMoETask(TranslationTask):
|
|||||||
|
|
||||||
def get_lprob_yz(winners=None):
|
def get_lprob_yz(winners=None):
|
||||||
encoder_out = model.encoder(
|
encoder_out = model.encoder(
|
||||||
src_tokens=sample['net_input']['src_tokens'],
|
src_tokens=sample["net_input"]["src_tokens"],
|
||||||
src_lengths=sample['net_input']['src_lengths'],
|
src_lengths=sample["net_input"]["src_lengths"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if winners is None:
|
if winners is None:
|
||||||
lprob_y = []
|
lprob_y = []
|
||||||
for i in range(k):
|
for i in range(k):
|
||||||
prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
|
prev_output_tokens_k = sample["net_input"][
|
||||||
|
"prev_output_tokens"
|
||||||
|
].clone()
|
||||||
assert not prev_output_tokens_k.requires_grad
|
assert not prev_output_tokens_k.requires_grad
|
||||||
prev_output_tokens_k[:, 0] = self.expert_index(i)
|
prev_output_tokens_k[:, 0] = self.expert_index(i)
|
||||||
lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
|
lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
|
||||||
lprob_y = torch.cat(lprob_y, dim=1) # -> B x K
|
lprob_y = torch.cat(lprob_y, dim=1) # -> B x K
|
||||||
else:
|
else:
|
||||||
prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
|
prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
|
||||||
prev_output_tokens_k[:, 0] = self.expert_index(winners)
|
prev_output_tokens_k[:, 0] = self.expert_index(winners)
|
||||||
lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k) # -> B
|
lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k) # -> B
|
||||||
|
|
||||||
@ -177,17 +184,21 @@ class TranslationMoETask(TranslationTask):
|
|||||||
loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
|
loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
|
||||||
|
|
||||||
loss = loss.sum()
|
loss = loss.sum()
|
||||||
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
|
sample_size = (
|
||||||
|
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
|
||||||
|
)
|
||||||
logging_output = {
|
logging_output = {
|
||||||
'loss': utils.item(loss.data),
|
"loss": utils.item(loss.data),
|
||||||
'ntokens': sample['ntokens'],
|
"ntokens": sample["ntokens"],
|
||||||
'nsentences': bsz,
|
"nsentences": bsz,
|
||||||
'sample_size': sample_size,
|
"sample_size": sample_size,
|
||||||
'posterior': prob_z_xy.float().sum(dim=0).cpu(),
|
"posterior": prob_z_xy.float().sum(dim=0).cpu(),
|
||||||
}
|
}
|
||||||
return loss, sample_size, logging_output
|
return loss, sample_size, logging_output
|
||||||
|
|
||||||
def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
|
def train_step(
|
||||||
|
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
|
||||||
|
):
|
||||||
model.train()
|
model.train()
|
||||||
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
|
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
|
||||||
if ignore_grad:
|
if ignore_grad:
|
||||||
@ -201,7 +212,15 @@ class TranslationMoETask(TranslationTask):
|
|||||||
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
|
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
|
||||||
return loss, sample_size, logging_output
|
return loss, sample_size, logging_output
|
||||||
|
|
||||||
def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None, constraints=None):
|
def inference_step(
|
||||||
|
self,
|
||||||
|
generator,
|
||||||
|
models,
|
||||||
|
sample,
|
||||||
|
prefix_tokens=None,
|
||||||
|
expert=None,
|
||||||
|
constraints=None,
|
||||||
|
):
|
||||||
expert = expert or self.args.gen_expert
|
expert = expert or self.args.gen_expert
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return generator.generate(
|
return generator.generate(
|
||||||
@ -215,6 +234,6 @@ class TranslationMoETask(TranslationTask):
|
|||||||
def reduce_metrics(self, logging_outputs, criterion):
|
def reduce_metrics(self, logging_outputs, criterion):
|
||||||
super().reduce_metrics(logging_outputs, criterion)
|
super().reduce_metrics(logging_outputs, criterion)
|
||||||
metrics.log_scalar(
|
metrics.log_scalar(
|
||||||
'posterior',
|
"posterior",
|
||||||
sum(log['posterior'] for log in logging_outputs if 'posterior' in log)
|
sum(log["posterior"] for log in logging_outputs if "posterior" in log),
|
||||||
)
|
)
|
||||||
|
@ -4,37 +4,38 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
aggregate_funcs = {
|
aggregate_funcs = {
|
||||||
'std': np.std,
|
"std": np.std,
|
||||||
'var': np.var,
|
"var": np.var,
|
||||||
'median': np.median,
|
"median": np.median,
|
||||||
'mean': np.mean,
|
"mean": np.mean,
|
||||||
'min': np.min,
|
"min": np.min,
|
||||||
'max': np.max,
|
"max": np.max,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-i', '--input_file', required=True, type=str)
|
parser.add_argument("-i", "--input_file", required=True, type=str)
|
||||||
parser.add_argument('-n', '--repeat_times', required=True, type=int)
|
parser.add_argument("-n", "--repeat_times", required=True, type=int)
|
||||||
parser.add_argument('-o', '--output_file', required=False)
|
parser.add_argument("-o", "--output_file", required=False)
|
||||||
parser.add_argument('-f', '--func', required=False, default='mean')
|
parser.add_argument("-f", "--func", required=False, default="mean")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
stream = open(args.output_file, 'w') if args.output_file else sys.stdout
|
stream = open(args.output_file, "w") if args.output_file else sys.stdout
|
||||||
|
|
||||||
segment_scores = []
|
segment_scores = []
|
||||||
for line in open(args.input_file):
|
for line in open(args.input_file):
|
||||||
segment_scores.append(float(line.strip()))
|
segment_scores.append(float(line.strip()))
|
||||||
if len(segment_scores) == args.repeat_times:
|
if len(segment_scores) == args.repeat_times:
|
||||||
stream.write('{}\n'.format(aggregate_funcs[args.func](segment_scores)))
|
stream.write("{}\n".format(aggregate_funcs[args.func](segment_scores)))
|
||||||
segment_scores = []
|
segment_scores = []
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -4,14 +4,13 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import subprocess
|
|
||||||
import tempfile
|
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
from itertools import combinations
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from itertools import combinations
|
||||||
|
|
||||||
|
|
||||||
def read_translations(path, n_repeats):
|
def read_translations(path, n_repeats):
|
||||||
@ -19,7 +18,7 @@ def read_translations(path, n_repeats):
|
|||||||
segment_translations = []
|
segment_translations = []
|
||||||
translations = defaultdict(list)
|
translations = defaultdict(list)
|
||||||
for line in open(path):
|
for line in open(path):
|
||||||
segment_translations.append(' '.join(line.split()))
|
segment_translations.append(" ".join(line.split()))
|
||||||
if len(segment_translations) == n_repeats:
|
if len(segment_translations) == n_repeats:
|
||||||
translations[segment_counter] = segment_translations
|
translations[segment_counter] = segment_translations
|
||||||
segment_translations = []
|
segment_translations = []
|
||||||
@ -30,42 +29,55 @@ def read_translations(path, n_repeats):
|
|||||||
def generate_input(translations, n_repeats):
|
def generate_input(translations, n_repeats):
|
||||||
_, ref_path = tempfile.mkstemp()
|
_, ref_path = tempfile.mkstemp()
|
||||||
_, mt_path = tempfile.mkstemp()
|
_, mt_path = tempfile.mkstemp()
|
||||||
ref_fh = open(ref_path, 'w')
|
ref_fh = open(ref_path, "w")
|
||||||
mt_fh = open(mt_path, 'w')
|
mt_fh = open(mt_path, "w")
|
||||||
for segid in sorted(translations.keys()):
|
for segid in sorted(translations.keys()):
|
||||||
assert len(translations[segid]) == n_repeats
|
assert len(translations[segid]) == n_repeats
|
||||||
indexes = combinations(range(n_repeats), 2)
|
indexes = combinations(range(n_repeats), 2)
|
||||||
for idx1, idx2 in indexes:
|
for idx1, idx2 in indexes:
|
||||||
mt_fh.write(translations[segid][idx1].strip() + '\n')
|
mt_fh.write(translations[segid][idx1].strip() + "\n")
|
||||||
ref_fh.write(translations[segid][idx2].strip() + '\n')
|
ref_fh.write(translations[segid][idx2].strip() + "\n")
|
||||||
sys.stderr.write('\nSaved translations to %s and %s' % (ref_path, mt_path))
|
sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
|
||||||
return ref_path, mt_path
|
return ref_path, mt_path
|
||||||
|
|
||||||
|
|
||||||
def run_meteor(ref_path, mt_path, metric_path, lang='en'):
|
def run_meteor(ref_path, mt_path, metric_path, lang="en"):
|
||||||
_, out_path = tempfile.mkstemp()
|
_, out_path = tempfile.mkstemp()
|
||||||
subprocess.call([
|
subprocess.call(
|
||||||
'java', '-Xmx2G', '-jar', metric_path, mt_path, ref_path,
|
[
|
||||||
'-p', '0.5 0.2 0.6 0.75', # default parameters, only changed alpha to give equal weight to P and R
|
"java",
|
||||||
'-norm',
|
"-Xmx2G",
|
||||||
'-l', lang], stdout=open(out_path, 'w'))
|
"-jar",
|
||||||
|
metric_path,
|
||||||
|
mt_path,
|
||||||
|
ref_path,
|
||||||
|
"-p",
|
||||||
|
"0.5 0.2 0.6 0.75", # default parameters, only changed alpha to give equal weight to P and R
|
||||||
|
"-norm",
|
||||||
|
"-l",
|
||||||
|
lang,
|
||||||
|
],
|
||||||
|
stdout=open(out_path, "w"),
|
||||||
|
)
|
||||||
os.remove(ref_path)
|
os.remove(ref_path)
|
||||||
os.remove(mt_path)
|
os.remove(mt_path)
|
||||||
sys.stderr.write('\nSaved Meteor output to %s' % out_path)
|
sys.stderr.write("\nSaved Meteor output to %s" % out_path)
|
||||||
return out_path
|
return out_path
|
||||||
|
|
||||||
|
|
||||||
def read_output(meteor_output_path, n_repeats):
|
def read_output(meteor_output_path, n_repeats):
|
||||||
n_combinations = math.factorial(n_repeats)/(math.factorial(2) * math.factorial(n_repeats - 2))
|
n_combinations = math.factorial(n_repeats) / (
|
||||||
|
math.factorial(2) * math.factorial(n_repeats - 2)
|
||||||
|
)
|
||||||
raw_scores = []
|
raw_scores = []
|
||||||
average_scores = []
|
average_scores = []
|
||||||
for line in open(meteor_output_path):
|
for line in open(meteor_output_path):
|
||||||
if not line.startswith('Segment '):
|
if not line.startswith("Segment "):
|
||||||
continue
|
continue
|
||||||
score = float(line.strip().split('\t')[1])
|
score = float(line.strip().split("\t")[1])
|
||||||
raw_scores.append(score)
|
raw_scores.append(score)
|
||||||
if len(raw_scores) == n_combinations:
|
if len(raw_scores) == n_combinations:
|
||||||
average_scores.append(sum(raw_scores)/n_combinations)
|
average_scores.append(sum(raw_scores) / n_combinations)
|
||||||
raw_scores = []
|
raw_scores = []
|
||||||
os.remove(meteor_output_path)
|
os.remove(meteor_output_path)
|
||||||
return average_scores
|
return average_scores
|
||||||
@ -73,25 +85,25 @@ def read_output(meteor_output_path, n_repeats):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-i', '--input')
|
parser.add_argument("-i", "--input")
|
||||||
parser.add_argument('-n', '--repeat_times', type=int)
|
parser.add_argument("-n", "--repeat_times", type=int)
|
||||||
parser.add_argument('-m', '--meteor')
|
parser.add_argument("-m", "--meteor")
|
||||||
parser.add_argument('-o', '--output')
|
parser.add_argument("-o", "--output")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
translations = read_translations(args.infile, args.repetitions)
|
translations = read_translations(args.infile, args.repetitions)
|
||||||
sys.stderr.write('\nGenerating input for Meteor...')
|
sys.stderr.write("\nGenerating input for Meteor...")
|
||||||
ref_path, mt_path = generate_input(translations, args.repetitions)
|
ref_path, mt_path = generate_input(translations, args.repetitions)
|
||||||
sys.stderr.write('\nRunning Meteor...')
|
sys.stderr.write("\nRunning Meteor...")
|
||||||
out_path = run_meteor(ref_path, mt_path, args.meteor)
|
out_path = run_meteor(ref_path, mt_path, args.meteor)
|
||||||
sys.stderr.write('\nReading output...')
|
sys.stderr.write("\nReading output...")
|
||||||
scores = read_output(out_path, args.repetitions)
|
scores = read_output(out_path, args.repetitions)
|
||||||
sys.stderr.write('\nWriting results...')
|
sys.stderr.write("\nWriting results...")
|
||||||
with open(args.output, 'w') as o:
|
with open(args.output, "w") as o:
|
||||||
for scr in scores:
|
for scr in scores:
|
||||||
o.write('{}\n'.format(scr))
|
o.write("{}\n".format(scr))
|
||||||
o.close()
|
o.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -8,21 +8,21 @@ import sys
|
|||||||
|
|
||||||
|
|
||||||
def _normalize_spaces(line):
|
def _normalize_spaces(line):
|
||||||
return ' '.join(line.split())
|
return " ".join(line.split())
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-i', '--input_file', required=True, type=str)
|
parser.add_argument("-i", "--input_file", required=True, type=str)
|
||||||
parser.add_argument('-n', '--repeat_times', required=True, type=int)
|
parser.add_argument("-n", "--repeat_times", required=True, type=int)
|
||||||
parser.add_argument('-o', '--output_file', required=False, type=str)
|
parser.add_argument("-o", "--output_file", required=False, type=str)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
stream = open(args.output_file, 'w') if args.output_file else sys.stdout
|
stream = open(args.output_file, "w") if args.output_file else sys.stdout
|
||||||
|
|
||||||
for line in open(args.input_file):
|
for line in open(args.input_file):
|
||||||
for _ in range(args.repeat_times):
|
for _ in range(args.repeat_times):
|
||||||
stream.write(_normalize_spaces(line) + '\n')
|
stream.write(_normalize_spaces(line) + "\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -8,30 +8,31 @@
|
|||||||
Helper script to pre-compute embeddings for a wav2letter++ dataset
|
Helper script to pre-compute embeddings for a wav2letter++ dataset
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
import pprint
|
import pprint
|
||||||
import glob, os, argparse
|
|
||||||
|
|
||||||
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from fairseq.models.wav2vec.wav2vec import Wav2VecModel
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import tqdm
|
import tqdm
|
||||||
except:
|
except:
|
||||||
print("Install tqdm to use --log-format=tqdm")
|
print("Install tqdm to use --log-format=tqdm")
|
||||||
|
|
||||||
from fairseq.models.wav2vec.wav2vec import Wav2VecModel
|
|
||||||
|
|
||||||
import tqdm
|
|
||||||
import soundfile as sf
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
import os.path as osp
|
|
||||||
|
|
||||||
|
|
||||||
class FilesDataset:
|
class FilesDataset:
|
||||||
def __init__(self, files, labels):
|
def __init__(self, files, labels):
|
||||||
self.files = files
|
self.files = files
|
||||||
if labels and osp.exists(labels):
|
if labels and osp.exists(labels):
|
||||||
with open(labels, 'r') as lbl_f:
|
with open(labels, "r") as lbl_f:
|
||||||
self.labels = [line.rstrip() for line in lbl_f]
|
self.labels = [line.rstrip() for line in lbl_f]
|
||||||
else:
|
else:
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
@ -50,7 +51,7 @@ class FilesDataset:
|
|||||||
if self.labels:
|
if self.labels:
|
||||||
if isinstance(self.labels, str):
|
if isinstance(self.labels, str):
|
||||||
lbl_file = osp.splitext(fname)[0] + "." + self.labels
|
lbl_file = osp.splitext(fname)[0] + "." + self.labels
|
||||||
with open(lbl_file, 'r') as lblf:
|
with open(lbl_file, "r") as lblf:
|
||||||
lbls = lblf.readline()
|
lbls = lblf.readline()
|
||||||
assert lbls is not None
|
assert lbls is not None
|
||||||
else:
|
else:
|
||||||
@ -116,24 +117,24 @@ class DatasetWriter:
|
|||||||
assert len(files) > 0
|
assert len(files) > 0
|
||||||
|
|
||||||
if self.args.shard is not None:
|
if self.args.shard is not None:
|
||||||
files = files[self.args.shard::self.args.num_shards]
|
files = files[self.args.shard :: self.args.num_shards]
|
||||||
|
|
||||||
lbls = []
|
lbls = []
|
||||||
with open(self.data_file(split), 'w') as srcf:
|
with open(self.data_file(split), "w") as srcf:
|
||||||
for line, lbl in self.iterate(files):
|
for line, lbl in self.iterate(files):
|
||||||
print(line, file=srcf)
|
print(line, file=srcf)
|
||||||
if self.args.labels:
|
if self.args.labels:
|
||||||
lbls.append(lbl + '\n')
|
lbls.append(lbl + "\n")
|
||||||
|
|
||||||
if self.args.labels:
|
if self.args.labels:
|
||||||
assert all(a is not None for a in lbls)
|
assert all(a is not None for a in lbls)
|
||||||
with open(self.lbl_file(split), 'w') as lblf:
|
with open(self.lbl_file(split), "w") as lblf:
|
||||||
lblf.writelines(lbls)
|
lblf.writelines(lbls)
|
||||||
|
|
||||||
def iterate(self, files):
|
def iterate(self, files):
|
||||||
|
|
||||||
data = self.load_data(files)
|
data = self.load_data(files)
|
||||||
for samples in tqdm.tqdm(data, total=len(files)//32):
|
for samples in tqdm.tqdm(data, total=len(files) // 32):
|
||||||
|
|
||||||
for wav, lbl in samples:
|
for wav, lbl in samples:
|
||||||
x = wav.unsqueeze(0).float().cuda()
|
x = wav.unsqueeze(0).float().cuda()
|
||||||
@ -162,7 +163,6 @@ class DatasetWriter:
|
|||||||
idx = torch.cat(result, dim=0)
|
idx = torch.cat(result, dim=0)
|
||||||
yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl
|
yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl
|
||||||
|
|
||||||
|
|
||||||
def lbl_file(self, name):
|
def lbl_file(self, name):
|
||||||
shard_part = "" if self.args.shard is None else f".{self.args.shard}"
|
shard_part = "" if self.args.shard is None else f".{self.args.shard}"
|
||||||
return osp.join(self.output_dir, f"{name}.lbl{shard_part}")
|
return osp.join(self.output_dir, f"{name}.lbl{shard_part}")
|
||||||
@ -230,7 +230,9 @@ class DatasetWriter:
|
|||||||
|
|
||||||
self.process_splits()
|
self.process_splits()
|
||||||
|
|
||||||
if hasattr(self.model.feature_extractor, "vars") and (self.args.shard is None or self.args.shard == 0):
|
if hasattr(self.model.feature_extractor, "vars") and (
|
||||||
|
self.args.shard is None or self.args.shard == 0
|
||||||
|
):
|
||||||
vars = (
|
vars = (
|
||||||
self.model.feature_extractor.vars.view(
|
self.model.feature_extractor.vars.view(
|
||||||
self.model.feature_extractor.banks,
|
self.model.feature_extractor.banks,
|
||||||
@ -248,4 +250,4 @@ if __name__ == "__main__":
|
|||||||
write_data = DatasetWriter()
|
write_data = DatasetWriter()
|
||||||
|
|
||||||
write_data()
|
write_data()
|
||||||
print("Done.")
|
print("Done.")
|
||||||
|
@ -14,13 +14,12 @@ import os
|
|||||||
from shutil import copy
|
from shutil import copy
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import soundfile as sf
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from fairseq.models.wav2vec.wav2vec import Wav2VecModel
|
from fairseq.models.wav2vec.wav2vec import Wav2VecModel
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
def read_audio(fname):
|
def read_audio(fname):
|
||||||
@ -33,7 +32,6 @@ def read_audio(fname):
|
|||||||
|
|
||||||
|
|
||||||
class PretrainedWav2VecModel(nn.Module):
|
class PretrainedWav2VecModel(nn.Module):
|
||||||
|
|
||||||
def __init__(self, fname):
|
def __init__(self, fname):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -55,32 +53,33 @@ class PretrainedWav2VecModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingWriterConfig(argparse.ArgumentParser):
|
class EmbeddingWriterConfig(argparse.ArgumentParser):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("Pre-compute embeddings for wav2letter++ datasets")
|
super().__init__("Pre-compute embeddings for wav2letter++ datasets")
|
||||||
|
|
||||||
kwargs = {"action": "store", "type": str, "required": True}
|
kwargs = {"action": "store", "type": str, "required": True}
|
||||||
|
|
||||||
self.add_argument("--input", "-i",
|
self.add_argument("--input", "-i", help="Input Directory", **kwargs)
|
||||||
help="Input Directory", **kwargs)
|
self.add_argument("--output", "-o", help="Output Directory", **kwargs)
|
||||||
self.add_argument("--output", "-o",
|
self.add_argument("--model", help="Path to model checkpoint", **kwargs)
|
||||||
help="Output Directory", **kwargs)
|
self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs)
|
||||||
self.add_argument("--model",
|
self.add_argument(
|
||||||
help="Path to model checkpoint", **kwargs)
|
"--ext", default="wav", required=False, help="Audio file extension"
|
||||||
self.add_argument("--split",
|
)
|
||||||
help="Dataset Splits", nargs='+', **kwargs)
|
|
||||||
self.add_argument("--ext", default="wav", required=False,
|
|
||||||
help="Audio file extension")
|
|
||||||
|
|
||||||
self.add_argument("--no-copy-labels", action="store_true",
|
self.add_argument(
|
||||||
help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.")
|
"--no-copy-labels",
|
||||||
self.add_argument("--use-feat", action="store_true",
|
action="store_true",
|
||||||
help="Use the feature vector ('z') instead of context vector ('c') for features")
|
help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.",
|
||||||
self.add_argument("--gpu",
|
)
|
||||||
help="GPU to use", default=0, type=int)
|
self.add_argument(
|
||||||
|
"--use-feat",
|
||||||
|
action="store_true",
|
||||||
|
help="Use the feature vector ('z') instead of context vector ('c') for features",
|
||||||
|
)
|
||||||
|
self.add_argument("--gpu", help="GPU to use", default=0, type=int)
|
||||||
|
|
||||||
|
|
||||||
class Prediction():
|
class Prediction:
|
||||||
""" Lightweight wrapper around a fairspeech embedding model """
|
""" Lightweight wrapper around a fairspeech embedding model """
|
||||||
|
|
||||||
def __init__(self, fname, gpu=0):
|
def __init__(self, fname, gpu=0):
|
||||||
@ -95,7 +94,7 @@ class Prediction():
|
|||||||
return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy()
|
return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
class H5Writer():
|
class H5Writer:
|
||||||
""" Write features as hdf5 file in wav2letter++ compatible format """
|
""" Write features as hdf5 file in wav2letter++ compatible format """
|
||||||
|
|
||||||
def __init__(self, fname):
|
def __init__(self, fname):
|
||||||
@ -112,7 +111,7 @@ class H5Writer():
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingDatasetWriter(object):
|
class EmbeddingDatasetWriter(object):
|
||||||
""" Given a model and a wav2letter++ dataset, pre-compute and store embeddings
|
"""Given a model and a wav2letter++ dataset, pre-compute and store embeddings
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_root, str :
|
input_root, str :
|
||||||
@ -123,13 +122,17 @@ class EmbeddingDatasetWriter(object):
|
|||||||
Dataset split
|
Dataset split
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input_root, output_root, split,
|
def __init__(
|
||||||
model_fname,
|
self,
|
||||||
extension="wav",
|
input_root,
|
||||||
gpu=0,
|
output_root,
|
||||||
verbose=False,
|
split,
|
||||||
use_feat=False,
|
model_fname,
|
||||||
):
|
extension="wav",
|
||||||
|
gpu=0,
|
||||||
|
verbose=False,
|
||||||
|
use_feat=False,
|
||||||
|
):
|
||||||
|
|
||||||
assert os.path.exists(model_fname)
|
assert os.path.exists(model_fname)
|
||||||
|
|
||||||
@ -143,8 +146,9 @@ class EmbeddingDatasetWriter(object):
|
|||||||
self.extension = extension
|
self.extension = extension
|
||||||
self.use_feat = use_feat
|
self.use_feat = use_feat
|
||||||
|
|
||||||
assert os.path.exists(self.input_path), \
|
assert os.path.exists(self.input_path), "Input path '{}' does not exist".format(
|
||||||
"Input path '{}' does not exist".format(self.input_path)
|
self.input_path
|
||||||
|
)
|
||||||
|
|
||||||
def _progress(self, iterable, **kwargs):
|
def _progress(self, iterable, **kwargs):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
@ -176,7 +180,11 @@ class EmbeddingDatasetWriter(object):
|
|||||||
def copy_labels(self):
|
def copy_labels(self):
|
||||||
self.require_output_path()
|
self.require_output_path()
|
||||||
|
|
||||||
labels = list(filter(lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))))
|
labels = list(
|
||||||
|
filter(
|
||||||
|
lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))
|
||||||
|
)
|
||||||
|
)
|
||||||
for fname in tqdm.tqdm(labels):
|
for fname in tqdm.tqdm(labels):
|
||||||
copy(fname, self.output_path)
|
copy(fname, self.output_path)
|
||||||
|
|
||||||
@ -191,10 +199,16 @@ class EmbeddingDatasetWriter(object):
|
|||||||
|
|
||||||
paths = self.input_fnames
|
paths = self.input_fnames
|
||||||
|
|
||||||
fnames_context = map(lambda x: os.path.join(self.output_path, x.replace("." + self.extension, ".h5context")), \
|
fnames_context = map(
|
||||||
map(os.path.basename, paths))
|
lambda x: os.path.join(
|
||||||
|
self.output_path, x.replace("." + self.extension, ".h5context")
|
||||||
|
),
|
||||||
|
map(os.path.basename, paths),
|
||||||
|
)
|
||||||
|
|
||||||
for name, target_fname in self._progress(zip(paths, fnames_context), total=len(self)):
|
for name, target_fname in self._progress(
|
||||||
|
zip(paths, fnames_context), total=len(self)
|
||||||
|
):
|
||||||
wav, sr = read_audio(name)
|
wav, sr = read_audio(name)
|
||||||
z, c = self.model(wav)
|
z, c = self.model(wav)
|
||||||
feat = z if self.use_feat else c
|
feat = z if self.use_feat else c
|
||||||
@ -204,7 +218,8 @@ class EmbeddingDatasetWriter(object):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
|
||||||
return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format(
|
return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format(
|
||||||
n_files=len(self), **self.__dict__)
|
n_files=len(self), **self.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -10,32 +10,50 @@ Data pre-processing: build vocabularies and binarize training data.
|
|||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import soundfile
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index')
|
parser.add_argument(
|
||||||
parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D',
|
"root", metavar="DIR", help="root directory containing flac files to index"
|
||||||
help='percentage of data to use as validation set (between 0 and 1)')
|
)
|
||||||
parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory')
|
parser.add_argument(
|
||||||
parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for')
|
"--valid-percent",
|
||||||
parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed')
|
default=0.01,
|
||||||
parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG',
|
type=float,
|
||||||
help='if set, path must contain this substring for a file to be included in the manifest')
|
metavar="D",
|
||||||
|
help="percentage of data to use as validation set (between 0 and 1)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dest", default=".", type=str, metavar="DIR", help="output directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ext", default="flac", type=str, metavar="EXT", help="extension to look for"
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--path-must-contain",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
metavar="FRAG",
|
||||||
|
help="if set, path must contain this substring for a file to be included in the manifest",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
assert args.valid_percent >= 0 and args.valid_percent <= 1.
|
assert args.valid_percent >= 0 and args.valid_percent <= 1.0
|
||||||
|
|
||||||
dir_path = os.path.realpath(args.root)
|
dir_path = os.path.realpath(args.root)
|
||||||
search_path = os.path.join(dir_path, '**/*.' + args.ext)
|
search_path = os.path.join(dir_path, "**/*." + args.ext)
|
||||||
rand = random.Random(args.seed)
|
rand = random.Random(args.seed)
|
||||||
|
|
||||||
with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open(
|
with open(os.path.join(args.dest, "train.tsv"), "w") as train_f, open(
|
||||||
os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f:
|
os.path.join(args.dest, "valid.tsv"), "w"
|
||||||
|
) as valid_f:
|
||||||
print(dir_path, file=train_f)
|
print(dir_path, file=train_f)
|
||||||
print(dir_path, file=valid_f)
|
print(dir_path, file=valid_f)
|
||||||
|
|
||||||
@ -47,10 +65,12 @@ def main(args):
|
|||||||
|
|
||||||
frames = soundfile.info(fname).frames
|
frames = soundfile.info(fname).frames
|
||||||
dest = train_f if rand.random() > args.valid_percent else valid_f
|
dest = train_f if rand.random() > args.valid_percent else valid_f
|
||||||
print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest)
|
print(
|
||||||
|
"{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -4,16 +4,17 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
"""isort:skip_file"""
|
"""isort:skip_file"""
|
||||||
|
|
||||||
__all__ = ['pdb']
|
__all__ = ["pdb"]
|
||||||
__version__ = '1.0.0a0'
|
__version__ = "1.0.0a0"
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# backwards compatibility to support `from fairseq.meters import AverageMeter`
|
# backwards compatibility to support `from fairseq.meters import AverageMeter`
|
||||||
from fairseq.logging import meters, metrics, progress_bar # noqa
|
from fairseq.logging import meters, metrics, progress_bar # noqa
|
||||||
sys.modules['fairseq.meters'] = meters
|
|
||||||
sys.modules['fairseq.metrics'] = metrics
|
sys.modules["fairseq.meters"] = meters
|
||||||
sys.modules['fairseq.progress_bar'] = progress_bar
|
sys.modules["fairseq.metrics"] = metrics
|
||||||
|
sys.modules["fairseq.progress_bar"] = progress_bar
|
||||||
|
|
||||||
import fairseq.criterions # noqa
|
import fairseq.criterions # noqa
|
||||||
import fairseq.models # noqa
|
import fairseq.models # noqa
|
||||||
|
@ -4,9 +4,4 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# import models/tasks to register them
|
# import models/tasks to register them
|
||||||
from . import ( # noqa
|
from . import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa
|
||||||
dummy_lm,
|
|
||||||
dummy_masked_lm,
|
|
||||||
dummy_model,
|
|
||||||
dummy_mt,
|
|
||||||
)
|
|
||||||
|
@ -7,25 +7,27 @@ import logging
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from fairseq.data import Dictionary, FairseqDataset
|
from fairseq.data import Dictionary, FairseqDataset
|
||||||
from fairseq.tasks import register_task, LegacyFairseqTask
|
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@register_task('dummy_lm')
|
@register_task("dummy_lm")
|
||||||
class DummyLMTask(LegacyFairseqTask):
|
class DummyLMTask(LegacyFairseqTask):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
"""Add task-specific arguments to the parser."""
|
"""Add task-specific arguments to the parser."""
|
||||||
parser.add_argument('--dict-size', default=49996, type=int)
|
parser.add_argument("--dict-size", default=49996, type=int)
|
||||||
parser.add_argument('--dataset-size', default=100000, type=int)
|
parser.add_argument("--dataset-size", default=100000, type=int)
|
||||||
parser.add_argument('--tokens-per-sample', default=512, type=int,
|
parser.add_argument(
|
||||||
help='max number of total tokens over all segments '
|
"--tokens-per-sample",
|
||||||
'per sample for BERT dataset')
|
default=512,
|
||||||
|
type=int,
|
||||||
|
help="max number of total tokens over all segments "
|
||||||
|
"per sample for BERT dataset",
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, args, dictionary):
|
def __init__(self, args, dictionary):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
@ -44,8 +46,8 @@ class DummyLMTask(LegacyFairseqTask):
|
|||||||
"""Setup the task. """
|
"""Setup the task. """
|
||||||
dictionary = Dictionary()
|
dictionary = Dictionary()
|
||||||
for i in range(args.dict_size):
|
for i in range(args.dict_size):
|
||||||
dictionary.add_symbol('word{}'.format(i))
|
dictionary.add_symbol("word{}".format(i))
|
||||||
logger.info('dictionary: {} types'.format(len(dictionary)))
|
logger.info("dictionary: {} types".format(len(dictionary)))
|
||||||
return cls(args, dictionary)
|
return cls(args, dictionary)
|
||||||
|
|
||||||
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
||||||
@ -59,16 +61,16 @@ class DummyLMTask(LegacyFairseqTask):
|
|||||||
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
|
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
|
||||||
self.datasets[split] = DummyDataset(
|
self.datasets[split] = DummyDataset(
|
||||||
{
|
{
|
||||||
'id': 1,
|
"id": 1,
|
||||||
'net_input': {
|
"net_input": {
|
||||||
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
|
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
|
||||||
'src_lengths': torch.full(
|
"src_lengths": torch.full(
|
||||||
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
|
(bsz,), self.args.tokens_per_sample, dtype=torch.long
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
|
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
|
||||||
'nsentences': bsz,
|
"nsentences": bsz,
|
||||||
'ntokens': bsz * self.args.tokens_per_sample,
|
"ntokens": bsz * self.args.tokens_per_sample,
|
||||||
},
|
},
|
||||||
num_items=self.args.dataset_size,
|
num_items=self.args.dataset_size,
|
||||||
item_size=self.args.tokens_per_sample,
|
item_size=self.args.tokens_per_sample,
|
||||||
@ -84,7 +86,6 @@ class DummyLMTask(LegacyFairseqTask):
|
|||||||
|
|
||||||
|
|
||||||
class DummyDataset(FairseqDataset):
|
class DummyDataset(FairseqDataset):
|
||||||
|
|
||||||
def __init__(self, batch, num_items, item_size):
|
def __init__(self, batch, num_items, item_size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
|
@ -7,32 +7,34 @@ import logging
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from fairseq.data import Dictionary, FairseqDataset
|
from fairseq.data import Dictionary, FairseqDataset
|
||||||
from fairseq.tasks import register_task, LegacyFairseqTask
|
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@register_task('dummy_masked_lm')
|
@register_task("dummy_masked_lm")
|
||||||
class DummyMaskedLMTask(LegacyFairseqTask):
|
class DummyMaskedLMTask(LegacyFairseqTask):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
"""Add task-specific arguments to the parser."""
|
"""Add task-specific arguments to the parser."""
|
||||||
parser.add_argument('--dict-size', default=49995, type=int)
|
parser.add_argument("--dict-size", default=49995, type=int)
|
||||||
parser.add_argument('--dataset-size', default=100000, type=int)
|
parser.add_argument("--dataset-size", default=100000, type=int)
|
||||||
parser.add_argument('--tokens-per-sample', default=512, type=int,
|
parser.add_argument(
|
||||||
help='max number of total tokens over all segments '
|
"--tokens-per-sample",
|
||||||
'per sample for BERT dataset')
|
default=512,
|
||||||
|
type=int,
|
||||||
|
help="max number of total tokens over all segments "
|
||||||
|
"per sample for BERT dataset",
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, args, dictionary):
|
def __init__(self, args, dictionary):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
self.dictionary = dictionary
|
self.dictionary = dictionary
|
||||||
|
|
||||||
# add mask token
|
# add mask token
|
||||||
self.mask_idx = dictionary.add_symbol('<mask>')
|
self.mask_idx = dictionary.add_symbol("<mask>")
|
||||||
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
|
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
|
||||||
|
|
||||||
mask_idx = 0
|
mask_idx = 0
|
||||||
@ -52,8 +54,8 @@ class DummyMaskedLMTask(LegacyFairseqTask):
|
|||||||
"""Setup the task. """
|
"""Setup the task. """
|
||||||
dictionary = Dictionary()
|
dictionary = Dictionary()
|
||||||
for i in range(args.dict_size):
|
for i in range(args.dict_size):
|
||||||
dictionary.add_symbol('word{}'.format(i))
|
dictionary.add_symbol("word{}".format(i))
|
||||||
logger.info('dictionary: {} types'.format(len(dictionary)))
|
logger.info("dictionary: {} types".format(len(dictionary)))
|
||||||
return cls(args, dictionary)
|
return cls(args, dictionary)
|
||||||
|
|
||||||
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
||||||
@ -67,16 +69,16 @@ class DummyMaskedLMTask(LegacyFairseqTask):
|
|||||||
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
|
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
|
||||||
self.datasets[split] = DummyDataset(
|
self.datasets[split] = DummyDataset(
|
||||||
{
|
{
|
||||||
'id': 1,
|
"id": 1,
|
||||||
'net_input': {
|
"net_input": {
|
||||||
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
|
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
|
||||||
'src_lengths': torch.full(
|
"src_lengths": torch.full(
|
||||||
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
|
(bsz,), self.args.tokens_per_sample, dtype=torch.long
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
|
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
|
||||||
'nsentences': bsz,
|
"nsentences": bsz,
|
||||||
'ntokens': bsz * self.args.tokens_per_sample,
|
"ntokens": bsz * self.args.tokens_per_sample,
|
||||||
},
|
},
|
||||||
num_items=self.args.dataset_size,
|
num_items=self.args.dataset_size,
|
||||||
item_size=self.args.tokens_per_sample,
|
item_size=self.args.tokens_per_sample,
|
||||||
@ -92,7 +94,6 @@ class DummyMaskedLMTask(LegacyFairseqTask):
|
|||||||
|
|
||||||
|
|
||||||
class DummyDataset(FairseqDataset):
|
class DummyDataset(FairseqDataset):
|
||||||
|
|
||||||
def __init__(self, batch, num_items, item_size):
|
def __init__(self, batch, num_items, item_size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from fairseq.data import Dictionary
|
from fairseq.data import Dictionary
|
||||||
from fairseq.models import (
|
from fairseq.models import (
|
||||||
FairseqDecoder,
|
FairseqDecoder,
|
||||||
@ -15,17 +14,16 @@ from fairseq.models import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_model('dummy_model')
|
@register_model("dummy_model")
|
||||||
class DummyModel(FairseqLanguageModel):
|
class DummyModel(FairseqLanguageModel):
|
||||||
|
|
||||||
def __init__(self, args, encoder):
|
def __init__(self, args, encoder):
|
||||||
super().__init__(encoder)
|
super().__init__(encoder)
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
parser.add_argument('--num-layers', type=int, default=24)
|
parser.add_argument("--num-layers", type=int, default=24)
|
||||||
parser.add_argument('--embed-dim', type=int, default=1024)
|
parser.add_argument("--embed-dim", type=int, default=1024)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_model(cls, args, task):
|
def build_model(cls, args, task):
|
||||||
@ -41,32 +39,35 @@ class DummyModel(FairseqLanguageModel):
|
|||||||
|
|
||||||
|
|
||||||
class DummyEncoder(FairseqDecoder):
|
class DummyEncoder(FairseqDecoder):
|
||||||
|
|
||||||
def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
|
def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
|
||||||
super().__init__(Dictionary())
|
super().__init__(Dictionary())
|
||||||
self.embed = nn.Embedding(
|
self.embed = nn.Embedding(
|
||||||
num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
|
num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
|
||||||
)
|
)
|
||||||
self.layers_a = nn.ModuleList([
|
self.layers_a = nn.ModuleList(
|
||||||
nn.Sequential(
|
[
|
||||||
nn.LayerNorm(embed_dim),
|
nn.Sequential(
|
||||||
nn.Linear(embed_dim, 3*embed_dim), # q, k, v input projection
|
nn.LayerNorm(embed_dim),
|
||||||
nn.Linear(3*embed_dim, embed_dim), # skip self-attention
|
nn.Linear(embed_dim, 3 * embed_dim), # q, k, v input projection
|
||||||
nn.Linear(embed_dim, embed_dim), # output projection
|
nn.Linear(3 * embed_dim, embed_dim), # skip self-attention
|
||||||
nn.Dropout(),
|
nn.Linear(embed_dim, embed_dim), # output projection
|
||||||
)
|
nn.Dropout(),
|
||||||
for i in range(num_layers)
|
)
|
||||||
])
|
for i in range(num_layers)
|
||||||
self.layers_b = nn.ModuleList([
|
]
|
||||||
nn.Sequential(
|
)
|
||||||
nn.LayerNorm(embed_dim),
|
self.layers_b = nn.ModuleList(
|
||||||
nn.Linear(embed_dim, 4*embed_dim), # FFN
|
[
|
||||||
nn.ReLU(),
|
nn.Sequential(
|
||||||
nn.Linear(4*embed_dim, embed_dim), # FFN
|
nn.LayerNorm(embed_dim),
|
||||||
nn.Dropout(0.1),
|
nn.Linear(embed_dim, 4 * embed_dim), # FFN
|
||||||
)
|
nn.ReLU(),
|
||||||
for i in range(num_layers)
|
nn.Linear(4 * embed_dim, embed_dim), # FFN
|
||||||
])
|
nn.Dropout(0.1),
|
||||||
|
)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
self.out_proj = nn.Linear(embed_dim, num_embed)
|
self.out_proj = nn.Linear(embed_dim, num_embed)
|
||||||
|
|
||||||
def forward(self, tokens, masked_tokens=None):
|
def forward(self, tokens, masked_tokens=None):
|
||||||
@ -90,6 +91,6 @@ class DummyEncoder(FairseqDecoder):
|
|||||||
return F.softmax(logits, dim=-1)
|
return F.softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture('dummy_model', 'dummy_model')
|
@register_model_architecture("dummy_model", "dummy_model")
|
||||||
def base_architecture(args):
|
def base_architecture(args):
|
||||||
pass
|
pass
|
||||||
|
@ -7,24 +7,22 @@ import logging
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from fairseq.data import Dictionary, FairseqDataset
|
from fairseq.data import Dictionary, FairseqDataset
|
||||||
from fairseq.tasks import register_task, LegacyFairseqTask
|
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@register_task('dummy_mt')
|
@register_task("dummy_mt")
|
||||||
class DummyMTTask(LegacyFairseqTask):
|
class DummyMTTask(LegacyFairseqTask):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
"""Add task-specific arguments to the parser."""
|
"""Add task-specific arguments to the parser."""
|
||||||
parser.add_argument('--dict-size', default=49996, type=int)
|
parser.add_argument("--dict-size", default=49996, type=int)
|
||||||
parser.add_argument('--dataset-size', default=100000, type=int)
|
parser.add_argument("--dataset-size", default=100000, type=int)
|
||||||
parser.add_argument('--src-len', default=30, type=int)
|
parser.add_argument("--src-len", default=30, type=int)
|
||||||
parser.add_argument('--tgt-len', default=30, type=int)
|
parser.add_argument("--tgt-len", default=30, type=int)
|
||||||
|
|
||||||
def __init__(self, args, dictionary):
|
def __init__(self, args, dictionary):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
@ -41,8 +39,8 @@ class DummyMTTask(LegacyFairseqTask):
|
|||||||
"""Setup the task. """
|
"""Setup the task. """
|
||||||
dictionary = Dictionary()
|
dictionary = Dictionary()
|
||||||
for i in range(args.dict_size):
|
for i in range(args.dict_size):
|
||||||
dictionary.add_symbol('word{}'.format(i))
|
dictionary.add_symbol("word{}".format(i))
|
||||||
logger.info('dictionary: {} types'.format(len(dictionary)))
|
logger.info("dictionary: {} types".format(len(dictionary)))
|
||||||
|
|
||||||
args.max_source_positions = args.src_len + dictionary.pad() + 2
|
args.max_source_positions = args.src_len + dictionary.pad() + 2
|
||||||
args.max_target_positions = args.tgt_len + dictionary.pad() + 2
|
args.max_target_positions = args.tgt_len + dictionary.pad() + 2
|
||||||
@ -62,17 +60,17 @@ class DummyMTTask(LegacyFairseqTask):
|
|||||||
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
|
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
|
||||||
self.datasets[split] = DummyDataset(
|
self.datasets[split] = DummyDataset(
|
||||||
{
|
{
|
||||||
'id': 1,
|
"id": 1,
|
||||||
'net_input': {
|
"net_input": {
|
||||||
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
|
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
|
||||||
'src_lengths': torch.full(
|
"src_lengths": torch.full(
|
||||||
(bsz, ), self.args.src_len, dtype=torch.long
|
(bsz,), self.args.src_len, dtype=torch.long
|
||||||
),
|
),
|
||||||
'prev_output_tokens': tgt.clone(),
|
"prev_output_tokens": tgt.clone(),
|
||||||
},
|
},
|
||||||
'target': tgt,
|
"target": tgt,
|
||||||
'nsentences': bsz,
|
"nsentences": bsz,
|
||||||
'ntokens': bsz * self.args.tgt_len,
|
"ntokens": bsz * self.args.tgt_len,
|
||||||
},
|
},
|
||||||
num_items=self.args.dataset_size,
|
num_items=self.args.dataset_size,
|
||||||
item_size=item_size,
|
item_size=item_size,
|
||||||
@ -88,7 +86,6 @@ class DummyMTTask(LegacyFairseqTask):
|
|||||||
|
|
||||||
|
|
||||||
class DummyDataset(FairseqDataset):
|
class DummyDataset(FairseqDataset):
|
||||||
|
|
||||||
def __init__(self, batch, num_items, item_size):
|
def __init__(self, batch, num_items, item_size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
|
@ -6,9 +6,10 @@
|
|||||||
import os
|
import os
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
from fairseq.tokenizer import tokenize_line
|
|
||||||
import torch
|
import torch
|
||||||
from fairseq.file_io import PathManager
|
from fairseq.file_io import PathManager
|
||||||
|
from fairseq.tokenizer import tokenize_line
|
||||||
|
|
||||||
|
|
||||||
def safe_readline(f):
|
def safe_readline(f):
|
||||||
pos = f.tell()
|
pos = f.tell()
|
||||||
|
@ -67,12 +67,14 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
|
|||||||
or is_better(val_loss, save_checkpoint.best)
|
or is_better(val_loss, save_checkpoint.best)
|
||||||
)
|
)
|
||||||
if val_loss is not None and args.keep_best_checkpoints > 0:
|
if val_loss is not None and args.keep_best_checkpoints > 0:
|
||||||
checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
|
checkpoint_conds[
|
||||||
args.best_checkpoint_metric, val_loss)] = (
|
"checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss)
|
||||||
not hasattr(save_checkpoint, "best")
|
] = not hasattr(save_checkpoint, "best") or is_better(
|
||||||
or is_better(val_loss, save_checkpoint.best)
|
val_loss, save_checkpoint.best
|
||||||
)
|
)
|
||||||
checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = not args.no_last_checkpoints
|
checkpoint_conds[
|
||||||
|
"checkpoint_last{}.pt".format(suffix)
|
||||||
|
] = not args.no_last_checkpoints
|
||||||
|
|
||||||
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
|
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
|
||||||
if hasattr(save_checkpoint, "best"):
|
if hasattr(save_checkpoint, "best"):
|
||||||
@ -112,10 +114,14 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
|
|||||||
if args.keep_best_checkpoints > 0:
|
if args.keep_best_checkpoints > 0:
|
||||||
# only keep the best N checkpoints according to validation metric
|
# only keep the best N checkpoints according to validation metric
|
||||||
checkpoints = checkpoint_paths(
|
checkpoints = checkpoint_paths(
|
||||||
args.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(args.best_checkpoint_metric))
|
args.save_dir,
|
||||||
|
pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
|
||||||
|
args.best_checkpoint_metric
|
||||||
|
),
|
||||||
|
)
|
||||||
if not args.maximize_best_checkpoint_metric:
|
if not args.maximize_best_checkpoint_metric:
|
||||||
checkpoints = checkpoints[::-1]
|
checkpoints = checkpoints[::-1]
|
||||||
for old_chk in checkpoints[args.keep_best_checkpoints:]:
|
for old_chk in checkpoints[args.keep_best_checkpoints :]:
|
||||||
if os.path.lexists(old_chk):
|
if os.path.lexists(old_chk):
|
||||||
os.remove(old_chk)
|
os.remove(old_chk)
|
||||||
|
|
||||||
@ -133,16 +139,23 @@ def load_checkpoint(args, trainer, **passthrough_args):
|
|||||||
reset_meters = args.reset_meters
|
reset_meters = args.reset_meters
|
||||||
reset_dataloader = args.reset_dataloader
|
reset_dataloader = args.reset_dataloader
|
||||||
|
|
||||||
if getattr(args, 'finetune_from_model', None) is not None \
|
if getattr(args, "finetune_from_model", None) is not None and (
|
||||||
and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader):
|
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
|
||||||
raise ValueError("--finetune-from-model can not be set together with either --reset-optimizer"
|
):
|
||||||
" or reset_lr_scheduler or reset_meters or reset_dataloader")
|
raise ValueError(
|
||||||
|
"--finetune-from-model can not be set together with either --reset-optimizer"
|
||||||
|
" or reset_lr_scheduler or reset_meters or reset_dataloader"
|
||||||
|
)
|
||||||
|
|
||||||
suffix = getattr(args, "checkpoint_suffix", "")
|
suffix = getattr(args, "checkpoint_suffix", "")
|
||||||
if args.restore_file == "checkpoint_last.pt": # default value of restore_file is 'checkpoint_last.pt'
|
if (
|
||||||
checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix))
|
args.restore_file == "checkpoint_last.pt"
|
||||||
|
): # default value of restore_file is 'checkpoint_last.pt'
|
||||||
|
checkpoint_path = os.path.join(
|
||||||
|
args.save_dir, "checkpoint_last{}.pt".format(suffix)
|
||||||
|
)
|
||||||
first_launch = not PathManager.exists(checkpoint_path)
|
first_launch = not PathManager.exists(checkpoint_path)
|
||||||
if getattr(args, 'finetune_from_model', None) is not None and first_launch:
|
if getattr(args, "finetune_from_model", None) is not None and first_launch:
|
||||||
# if there is no last checkpoint to restore, start the finetune from pretrained model
|
# if there is no last checkpoint to restore, start the finetune from pretrained model
|
||||||
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
|
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
|
||||||
if PathManager.exists(args.finetune_from_model):
|
if PathManager.exists(args.finetune_from_model):
|
||||||
@ -151,19 +164,26 @@ def load_checkpoint(args, trainer, **passthrough_args):
|
|||||||
reset_lr_scheduler = True
|
reset_lr_scheduler = True
|
||||||
reset_meters = True
|
reset_meters = True
|
||||||
reset_dataloader = True
|
reset_dataloader = True
|
||||||
logger.info(f'loading pretrained model from {checkpoint_path}: '
|
logger.info(
|
||||||
'optimizer, lr scheduler, meters, dataloader will be reset')
|
f"loading pretrained model from {checkpoint_path}: "
|
||||||
|
"optimizer, lr scheduler, meters, dataloader will be reset"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'--funetune-from-model {args.finetune_from_model} does not exist')
|
raise ValueError(
|
||||||
|
f"--funetune-from-model {args.finetune_from_model} does not exist"
|
||||||
|
)
|
||||||
elif getattr(args, "model_parallel_size", 1) > 1:
|
elif getattr(args, "model_parallel_size", 1) > 1:
|
||||||
checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
|
checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
|
||||||
else:
|
else:
|
||||||
checkpoint_path = args.restore_file
|
checkpoint_path = args.restore_file
|
||||||
|
|
||||||
if args.restore_file != "checkpoint_last.pt" and getattr(args, 'finetune_from_model', None):
|
if args.restore_file != "checkpoint_last.pt" and getattr(
|
||||||
|
args, "finetune_from_model", None
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'--finetune-from-model and --restore-file (non-default value) '
|
"--finetune-from-model and --restore-file (non-default value) "
|
||||||
'can not be specified together: ' + str(args))
|
"can not be specified together: " + str(args)
|
||||||
|
)
|
||||||
|
|
||||||
extra_state = trainer.load_checkpoint(
|
extra_state = trainer.load_checkpoint(
|
||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
@ -213,7 +233,9 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
|
|||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1):
|
def load_model_ensemble(
|
||||||
|
filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
|
||||||
|
):
|
||||||
"""Loads an ensemble of models.
|
"""Loads an ensemble of models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -222,18 +244,28 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, s
|
|||||||
were used during model training
|
were used during model training
|
||||||
task (fairseq.tasks.FairseqTask, optional): task to use for loading
|
task (fairseq.tasks.FairseqTask, optional): task to use for loading
|
||||||
"""
|
"""
|
||||||
assert not (strict and num_shards > 1), \
|
assert not (
|
||||||
"Cannot load state dict with strict=True and checkpoint shards > 1"
|
strict and num_shards > 1
|
||||||
|
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
||||||
ensemble, args, _task = load_model_ensemble_and_task(
|
ensemble, args, _task = load_model_ensemble_and_task(
|
||||||
filenames, arg_overrides, task, strict, suffix, num_shards,
|
filenames,
|
||||||
|
arg_overrides,
|
||||||
|
task,
|
||||||
|
strict,
|
||||||
|
suffix,
|
||||||
|
num_shards,
|
||||||
)
|
)
|
||||||
return ensemble, args
|
return ensemble, args
|
||||||
|
|
||||||
|
|
||||||
def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1):
|
def load_model_ensemble_and_task(
|
||||||
|
filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
|
||||||
|
):
|
||||||
from fairseq import tasks
|
from fairseq import tasks
|
||||||
assert not (strict and num_shards > 1), \
|
|
||||||
"Cannot load state dict with strict=True and checkpoint shards > 1"
|
assert not (
|
||||||
|
strict and num_shards > 1
|
||||||
|
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
||||||
ensemble = []
|
ensemble = []
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
orig_filename = filename
|
orig_filename = filename
|
||||||
@ -533,7 +565,9 @@ def verify_checkpoint_directory(save_dir: str) -> None:
|
|||||||
with open(temp_file_path, "w"):
|
with open(temp_file_path, "w"):
|
||||||
pass
|
pass
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
|
logger.warning(
|
||||||
|
"Unable to access checkpoint save directory: {}".format(save_dir)
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
os.remove(temp_file_path)
|
os.remove(temp_file_path)
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user