Apply black+isort (#1357)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1357

Reviewed By: alexeib

Differential Revision: D24377772

fbshipit-source-id: 51581af041d42d62166b33a35a1a4228b1a76f0c
Author: Myle Ott
Date: 2020-10-18 18:13:29 -07:00
Committed by: Facebook GitHub Bot
Parent: 5695cdfb2c
Commit: a48f235636

396 changed files with 15418 additions and 9810 deletions
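
The change is mechanical: isort reorders and groups imports, and black normalizes quoting, wrapping, and spacing. A minimal sketch of reproducing such a pass is below; the exact black/isort versions and configuration used for this commit are not recorded here, so output may differ slightly.

    # Hypothetical reproduction of the formatting pass (not part of this commit).
    # Assumes black and isort are installed; versions/configs are assumptions.
    import subprocess

    subprocess.run(["isort", "."], check=True)  # sort and group imports first
    subprocess.run(["black", "."], check=True)  # then normalize formatting

Running isort before black lets black make the final call on line wrapping.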

View File

@@ -20,10 +20,11 @@
 import os
 import sys

-# source code directory, relative to this file, for sphinx-autobuild
-sys.path.insert(0, os.path.abspath('..'))
-source_suffix = ['.rst']
+# source code directory, relative to this file, for sphinx-autobuild
+sys.path.insert(0, os.path.abspath(".."))
+
+source_suffix = [".rst"]

 # -- General configuration ------------------------------------------------
@@ -35,34 +36,34 @@ source_suffix = ['.rst']
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.napoleon',
-    'sphinxarg.ext',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.napoleon",
+    "sphinxarg.ext",
 ]

 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]

 # The master toctree document.
-master_doc = 'index'
+master_doc = "index"

 # General information about the project.
-project = 'fairseq'
-copyright = '2019, Facebook AI Research (FAIR)'
-author = 'Facebook AI Research (FAIR)'
-github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
+project = "fairseq"
+copyright = "2019, Facebook AI Research (FAIR)"
+author = "Facebook AI Research (FAIR)"
+github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/"

 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
 # built documents.
 #
 # The short X.Y version.
-version = '0.9.0'
+version = "0.9.0"
 # The full version, including alpha/beta/rc tags.
-release = '0.9.0'
+release = "0.9.0"

 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
@@ -74,11 +75,11 @@ language = None
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This patterns also effect to html_static_path and html_extra_path
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

 # The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
-highlight_language = 'python'
+pygments_style = "sphinx"
+highlight_language = "python"

 # If true, `todo` and `todoList` produce output, else they produce nothing.
 todo_include_todos = False
@@ -89,7 +90,7 @@ todo_include_todos = False
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"

 # Theme options are theme-specific and customize the look and feel of a theme
 # further.  For a list of options available for each theme, see the
@@ -100,11 +101,11 @@ html_theme = 'sphinx_rtd_theme'
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]

 html_context = {
-    'css_files': [
-        '_static/theme_overrides.css',  # override wide tables in RTD theme
+    "css_files": [
+        "_static/theme_overrides.css",  # override wide tables in RTD theme
     ],
 }
@@ -113,7 +114,7 @@ html_context = {
 #
 # This is required for the alabaster theme
 # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
-#html_sidebars = {
+# html_sidebars = {
 #     '**': [
 #         'about.html',
 #         'navigation.html',
@@ -121,12 +122,12 @@ html_context = {
 #         'searchbox.html',
 #         'donate.html',
 #     ]
-#}
+# }

 # Example configuration for intersphinx: refer to the Python standard library.
 intersphinx_mapping = {
-    'numpy': ('http://docs.scipy.org/doc/numpy/', None),
-    'python': ('https://docs.python.org/', None),
-    'torch': ('https://pytorch.org/docs/master/', None),
+    "numpy": ("http://docs.scipy.org/doc/numpy/", None),
+    "python": ("https://docs.python.org/", None),
+    "torch": ("https://pytorch.org/docs/master/", None),
 }

View File

@@ -3,6 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-__version__ = '0.9.0'
+__version__ = "0.9.0"

 import examples.noisychannel  # noqa

View File

@@ -7,8 +7,8 @@
 import argparse
 import fileinput
 import hashlib
-from multiprocessing import Pool
 import sys
+from multiprocessing import Pool


 def get_hashes_and_lines(raw_line):
@@ -18,12 +18,12 @@ def get_hashes_and_lines(raw_line):
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--workers', type=int, default=10)
-    parser.add_argument('files', nargs='*', help='input files')
+    parser.add_argument("--workers", type=int, default=10)
+    parser.add_argument("files", nargs="*", help="input files")
     args = parser.parse_args()

     seen = set()
-    with fileinput.input(args.files, mode='rb') as h:
+    with fileinput.input(args.files, mode="rb") as h:
         pool = Pool(args.workers)
         results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
         for i, (hash, raw_line) in enumerate(results):
@@ -37,5 +37,5 @@ def main():
     print(file=sys.stderr, flush=True)


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

View File

@@ -11,26 +11,38 @@ from tqdm import tqdm

 def main():
-    parser = argparse.ArgumentParser(description=(
-        'Extract back-translations from the stdout of fairseq-generate. '
-        'If there are multiply hypotheses for a source, we only keep the first one. '
-    ))
-    parser.add_argument('--output', required=True, help='output prefix')
-    parser.add_argument('--srclang', required=True, help='source language (extracted from H-* lines)')
-    parser.add_argument('--tgtlang', required=True, help='target language (extracted from S-* lines)')
-    parser.add_argument('--minlen', type=int, help='min length filter')
-    parser.add_argument('--maxlen', type=int, help='max length filter')
-    parser.add_argument('--ratio', type=float, help='ratio filter')
-    parser.add_argument('files', nargs='*', help='input files')
+    parser = argparse.ArgumentParser(
+        description=(
+            "Extract back-translations from the stdout of fairseq-generate. "
+            "If there are multiply hypotheses for a source, we only keep the first one. "
+        )
+    )
+    parser.add_argument("--output", required=True, help="output prefix")
+    parser.add_argument(
+        "--srclang", required=True, help="source language (extracted from H-* lines)"
+    )
+    parser.add_argument(
+        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
+    )
+    parser.add_argument("--minlen", type=int, help="min length filter")
+    parser.add_argument("--maxlen", type=int, help="max length filter")
+    parser.add_argument("--ratio", type=float, help="ratio filter")
+    parser.add_argument("files", nargs="*", help="input files")
     args = parser.parse_args()

     def validate(src, tgt):
-        srclen = len(src.split(' ')) if src != '' else 0
-        tgtlen = len(tgt.split(' ')) if tgt != '' else 0
+        srclen = len(src.split(" ")) if src != "" else 0
+        tgtlen = len(tgt.split(" ")) if tgt != "" else 0
         if (
             (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
-            or (args.maxlen is not None and (srclen > args.maxlen or tgtlen > args.maxlen))
-            or (args.ratio is not None and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio))
+            or (
+                args.maxlen is not None
+                and (srclen > args.maxlen or tgtlen > args.maxlen)
+            )
+            or (
+                args.ratio is not None
+                and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
+            )
         ):
             return False
         return True
@@ -41,19 +53,20 @@ def main():
         except IndexError:
             return default

-    with open(args.output + '.' + args.srclang, 'w') as src_h, \
-            open(args.output + '.' + args.tgtlang, 'w') as tgt_h:
+    with open(args.output + "." + args.srclang, "w") as src_h, open(
+        args.output + "." + args.tgtlang, "w"
+    ) as tgt_h:
         for line in tqdm(fileinput.input(args.files)):
-            if line.startswith('S-'):
-                tgt = safe_index(line.rstrip().split('\t'), 1, '')
-            elif line.startswith('H-'):
+            if line.startswith("S-"):
+                tgt = safe_index(line.rstrip().split("\t"), 1, "")
+            elif line.startswith("H-"):
                 if tgt is not None:
-                    src = safe_index(line.rstrip().split('\t'), 2, '')
+                    src = safe_index(line.rstrip().split("\t"), 2, "")
                     if validate(src, tgt):
                         print(src, file=src_h)
                         print(tgt, file=tgt_h)
                     tgt = None


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

View File

@@ -4,203 +4,251 @@
 # LICENSE file in the root directory of this source tree.

-import os.path as op
 import argparse
 import os
-from multiprocessing import cpu_count
+import os.path as op
 from collections import namedtuple
-from typing import Optional, List
+from multiprocessing import cpu_count
+from typing import List, Optional

 import sentencepiece as sp
-from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
-from fairseq.data.encoders.byte_utils import byte_encode
-from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
-from fairseq.data.encoders.characters import Characters
 from fairseq.data.encoders.byte_bpe import ByteBPE
+from fairseq.data.encoders.byte_utils import byte_encode
 from fairseq.data.encoders.bytes import Bytes
+from fairseq.data.encoders.characters import Characters
+from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
+from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE

-SPLITS = ['train', 'valid', 'test']
+SPLITS = ["train", "valid", "test"]


 def _convert_xml(in_path: str, out_path: str):
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
             ss = s.strip()
-            if not ss.startswith('<seg'):
+            if not ss.startswith("<seg"):
                 continue
-            ss = ss.replace('</seg>', '').split('">')
+            ss = ss.replace("</seg>", "").split('">')
             assert len(ss) == 2
-            f_o.write(ss[1].strip() + '\n')
+            f_o.write(ss[1].strip() + "\n")


 def _convert_train(in_path: str, out_path: str):
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
             ss = s.strip()
-            if ss.startswith('<'):
+            if ss.startswith("<"):
                 continue
-            f_o.write(ss.strip() + '\n')
+            f_o.write(ss.strip() + "\n")


 def _get_bytes(in_path: str, out_path: str):
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
-            f_o.write(Bytes.encode(s.strip()) + '\n')
+            f_o.write(Bytes.encode(s.strip()) + "\n")


 def _get_chars(in_path: str, out_path: str):
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
-            f_o.write(Characters.encode(s.strip()) + '\n')
+            f_o.write(Characters.encode(s.strip()) + "\n")


 def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
-    Args = namedtuple('Args', ['moses_source_lang', 'moses_target_lang',
-                               'moses_no_dash_splits', 'moses_no_escape'])
-    args = Args(moses_source_lang=src, moses_target_lang=tgt,
-                moses_no_dash_splits=False, moses_no_escape=False)
+    Args = namedtuple(
+        "Args",
+        [
+            "moses_source_lang",
+            "moses_target_lang",
+            "moses_no_dash_splits",
+            "moses_no_escape",
+        ],
+    )
+    args = Args(
+        moses_source_lang=src,
+        moses_target_lang=tgt,
+        moses_no_dash_splits=False,
+        moses_no_escape=False,
+    )
     pretokenizer = MosesTokenizer(args)
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
-            f_o.write(pretokenizer.encode(s.strip()) + '\n')
+            f_o.write(pretokenizer.encode(s.strip()) + "\n")


 def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
-    with open(out_path, 'w') as f_o:
+    with open(out_path, "w") as f_o:
         for lang in [src, tgt]:
-            with open(f'{in_path_prefix}.{lang}') as f:
+            with open(f"{in_path_prefix}.{lang}") as f:
                 for s in f:
-                    f_o.write(byte_encode(s.strip()) + '\n')
+                    f_o.write(byte_encode(s.strip()) + "\n")


 def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
     arguments = [
-        f'--input={in_path}', f'--model_prefix={model_prefix}',
-        f'--model_type=bpe', f'--vocab_size={vocab_size}',
-        '--character_coverage=1.0', '--normalization_rule_name=identity',
-        f'--num_threads={cpu_count()}'
+        f"--input={in_path}",
+        f"--model_prefix={model_prefix}",
+        f"--model_type=bpe",
+        f"--vocab_size={vocab_size}",
+        "--character_coverage=1.0",
+        "--normalization_rule_name=identity",
+        f"--num_threads={cpu_count()}",
     ]
-    sp.SentencePieceTrainer.Train(' '.join(arguments))
+    sp.SentencePieceTrainer.Train(" ".join(arguments))


 def _apply_bbpe(model_path: str, in_path: str, out_path: str):
-    Args = namedtuple('Args', ['sentencepiece_model_path'])
+    Args = namedtuple("Args", ["sentencepiece_model_path"])
     args = Args(sentencepiece_model_path=model_path)
     tokenizer = ByteBPE(args)
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
-            f_o.write(tokenizer.encode(s.strip()) + '\n')
+            f_o.write(tokenizer.encode(s.strip()) + "\n")


 def _apply_bpe(model_path: str, in_path: str, out_path: str):
-    Args = namedtuple('Args', ['sentencepiece_model'])
+    Args = namedtuple("Args", ["sentencepiece_model"])
     args = Args(sentencepiece_model=model_path)
     tokenizer = SentencepieceBPE(args)
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
-            f_o.write(tokenizer.encode(s.strip()) + '\n')
+            f_o.write(tokenizer.encode(s.strip()) + "\n")


 def _concat_files(in_paths: List[str], out_path: str):
-    with open(out_path, 'w') as f_o:
+    with open(out_path, "w") as f_o:
         for p in in_paths:
             with open(p) as f:
                 for r in f:
                     f_o.write(r)


-def preprocess_iwslt17(root: str, src: str, tgt: str, bpe_size: Optional[int],
-                       need_chars: bool, bbpe_size: Optional[int],
-                       need_bytes: bool):
+def preprocess_iwslt17(
+    root: str,
+    src: str,
+    tgt: str,
+    bpe_size: Optional[int],
+    need_chars: bool,
+    bbpe_size: Optional[int],
+    need_bytes: bool,
+):
     # extract bitext
-    in_root = op.join(root, f'{src}-{tgt}')
+    in_root = op.join(root, f"{src}-{tgt}")
     for lang in [src, tgt]:
         _convert_train(
-            op.join(in_root, f'train.tags.{src}-{tgt}.{lang}'),
-            op.join(root, f'train.{lang}')
+            op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
+            op.join(root, f"train.{lang}"),
         )
         _convert_xml(
-            op.join(in_root, f'IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml'),
-            op.join(root, f'valid.{lang}')
+            op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
+            op.join(root, f"valid.{lang}"),
         )
         _convert_xml(
-            op.join(in_root, f'IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml'),
-            op.join(root, f'test.{lang}')
+            op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
+            op.join(root, f"test.{lang}"),
         )
     # pre-tokenize
     for lang in [src, tgt]:
         for split in SPLITS:
-            pretokenize(op.join(root, f'{split}.{lang}'),
-                        op.join(root, f'{split}.moses.{lang}'), src, tgt)
+            pretokenize(
+                op.join(root, f"{split}.{lang}"),
+                op.join(root, f"{split}.moses.{lang}"),
+                src,
+                tgt,
+            )
     # tokenize with BPE vocabulary
     if bpe_size is not None:
         # learn vocabulary
-        concated_train_path = op.join(root, 'train.all')
+        concated_train_path = op.join(root, "train.all")
         _concat_files(
-            [op.join(root, 'train.moses.fr'), op.join(root, 'train.moses.en')],
-            concated_train_path
+            [op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")],
+            concated_train_path,
         )
-        bpe_model_prefix = op.join(root, f'spm_bpe{bpe_size}')
+        bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
         _get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
         os.remove(concated_train_path)
         # apply
         for lang in [src, tgt]:
             for split in SPLITS:
                 _apply_bpe(
-                    bpe_model_prefix + '.model',
-                    op.join(root, f'{split}.moses.{lang}'),
-                    op.join(root, f'{split}.moses.bpe{bpe_size}.{lang}')
+                    bpe_model_prefix + ".model",
+                    op.join(root, f"{split}.moses.{lang}"),
+                    op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
                 )
     # tokenize with bytes vocabulary
     if need_bytes:
         for lang in [src, tgt]:
             for split in SPLITS:
-                _get_bytes(op.join(root, f'{split}.moses.{lang}'),
-                           op.join(root, f'{split}.moses.bytes.{lang}'))
+                _get_bytes(
+                    op.join(root, f"{split}.moses.{lang}"),
+                    op.join(root, f"{split}.moses.bytes.{lang}"),
+                )
     # tokenize with characters vocabulary
     if need_chars:
         for lang in [src, tgt]:
             for split in SPLITS:
-                _get_chars(op.join(root, f'{split}.moses.{lang}'),
-                           op.join(root, f'{split}.moses.chars.{lang}'))
+                _get_chars(
+                    op.join(root, f"{split}.moses.{lang}"),
+                    op.join(root, f"{split}.moses.chars.{lang}"),
+                )
     # tokenize with byte-level BPE vocabulary
     if bbpe_size is not None:
         # learn vocabulary
-        bchar_path = op.join(root, 'train.bchar')
-        _convert_to_bchar(op.join(root, 'train.moses'), src, tgt, bchar_path)
-        bbpe_model_prefix = op.join(root, f'spm_bbpe{bbpe_size}')
+        bchar_path = op.join(root, "train.bchar")
+        _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
+        bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
         _get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
         os.remove(bchar_path)
         # apply
         for lang in [src, tgt]:
             for split in SPLITS:
                 _apply_bbpe(
-                    bbpe_model_prefix + '.model',
-                    op.join(root, f'{split}.moses.{lang}'),
-                    op.join(root, f'{split}.moses.bbpe{bbpe_size}.{lang}')
+                    bbpe_model_prefix + ".model",
+                    op.join(root, f"{split}.moses.{lang}"),
+                    op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
                 )


 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--root', type=str, default='data')
-    parser.add_argument('--bpe-vocab', default=None, type=int,
-                        help='Generate tokenized bitext with BPE of size K.'
-                             'Default to None (disabled).')
-    parser.add_argument('--bbpe-vocab', default=None, type=int,
-                        help='Generate tokenized bitext with BBPE of size K.'
-                             'Default to None (disabled).')
-    parser.add_argument('--byte-vocab', action='store_true',
-                        help='Generate tokenized bitext with bytes vocabulary')
-    parser.add_argument('--char-vocab', action='store_true',
-                        help='Generate tokenized bitext with chars vocabulary')
+    parser.add_argument("--root", type=str, default="data")
+    parser.add_argument(
+        "--bpe-vocab",
+        default=None,
+        type=int,
+        help="Generate tokenized bitext with BPE of size K."
+        "Default to None (disabled).",
+    )
+    parser.add_argument(
+        "--bbpe-vocab",
+        default=None,
+        type=int,
+        help="Generate tokenized bitext with BBPE of size K."
+        "Default to None (disabled).",
+    )
+    parser.add_argument(
+        "--byte-vocab",
+        action="store_true",
+        help="Generate tokenized bitext with bytes vocabulary",
+    )
+    parser.add_argument(
+        "--char-vocab",
+        action="store_true",
+        help="Generate tokenized bitext with chars vocabulary",
+    )
     args = parser.parse_args()

-    preprocess_iwslt17(args.root, 'fr', 'en', args.bpe_vocab, args.char_vocab,
-                       args.bbpe_vocab, args.byte_vocab)
+    preprocess_iwslt17(
+        args.root,
+        "fr",
+        "en",
+        args.bpe_vocab,
+        args.char_vocab,
+        args.bbpe_vocab,
+        args.byte_vocab,
+    )


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

View File

@@ -11,7 +11,7 @@
 import torch.nn as nn
 import torch.nn.functional as F

 from fairseq.models import register_model, register_model_architecture
-from fairseq.models.transformer import TransformerModel, TransformerEncoder
+from fairseq.models.transformer import TransformerEncoder, TransformerModel


 @register_model("gru_transformer")
@@ -24,9 +24,12 @@ class GRUTransformerModel(TransformerModel):
 class GRUTransformerEncoder(TransformerEncoder):
     def __init__(self, args, dictionary, embed_tokens):
         super().__init__(args, dictionary, embed_tokens)
-        self.emb_ctx = nn.GRU(input_size=embed_tokens.embedding_dim,
-                              hidden_size=embed_tokens.embedding_dim // 2,
-                              num_layers=1, bidirectional=True)
+        self.emb_ctx = nn.GRU(
+            input_size=embed_tokens.embedding_dim,
+            hidden_size=embed_tokens.embedding_dim // 2,
+            num_layers=1,
+            bidirectional=True,
+        )

     def forward_embedding(self, src_tokens):
         # embed tokens and positions

View File

@@ -16,11 +16,12 @@ def main(args):
         print(normalizer.normalize(line.rstrip()), flush=True)


-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser()
-    parser.add_argument('--lang', '-l', default='en')
-    parser.add_argument('--penn', '-p', action='store_true')
+    parser.add_argument("--lang", "-l", default="en")
+    parser.add_argument("--penn", "-p", action="store_true")
     args = parser.parse_args()

     main(args)

View File

@@ -6,12 +6,14 @@
 # LICENSE file in the root directory of this source tree.

 import sys
+
 import sacremoses


 def main(args):
     """Tokenizes, preserving tabs"""
     mt = sacremoses.MosesTokenizer(lang=args.lang)

+
     def tok(s):
         return mt.tokenize(s, return_str=True)
@@ -20,12 +22,13 @@ def main(args):
     print(*parts, sep="\t", flush=True)


-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser()
-    parser.add_argument('--lang', '-l', default='en')
-    parser.add_argument('--penn', '-p', action='store_true')
-    parser.add_argument('--fields', '-f', help="fields to tokenize")
+    parser.add_argument("--lang", "-l", default="en")
+    parser.add_argument("--penn", "-p", action="store_true")
+    parser.add_argument("--fields", "-f", help="fields to tokenize")
     args = parser.parse_args()

     main(args)

View File

@@ -3,14 +3,15 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import faiss
-import numpy as np
-import glob
 import argparse
+import glob
 from subprocess import check_call

-GB = 1024*1024*1024
+import faiss
+import numpy as np
+
+GB = 1024 * 1024 * 1024


 def call(cmd):
@@ -18,14 +19,14 @@ def call(cmd):
     check_call(cmd, shell=True)


-def get_batches(directory, lang, prefix='all_avg_pool'):
+def get_batches(directory, lang, prefix="all_avg_pool"):
     print(f"Finding in {directory}/{prefix}.{lang}*")
-    files = glob.glob(f'{directory}/{prefix}.{lang}*')
+    files = glob.glob(f"{directory}/{prefix}.{lang}*")
     emb_files = []
     txt_files = []
     for emb_fi in files:
         emb_files.append(emb_fi)
-        txt_fi = emb_fi.replace(prefix, 'sentences')
+        txt_fi = emb_fi.replace(prefix, "sentences")
         txt_files.append(txt_fi)
     return emb_files, txt_files
@@ -38,7 +39,7 @@ def load_batch(emb_file, dim):
     return embeddings


-def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'):
+def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
     sims = []
     inds = []
     xfrom = 0
@@ -53,7 +54,7 @@ def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'):
             y_batch = load_batch(y_batch_f, dim)
             neighbor_size = min(k, y_batch.shape[0])
             yto = yfrom + y_batch.shape[0]
-            print('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto))
+            print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto))
             idx = faiss.IndexFlatIP(dim)
             idx = faiss.index_cpu_to_all_gpus(idx)
             idx.add(y_batch)
@@ -86,8 +87,10 @@ def score(sim, fwd_mean, bwd_mean, margin):
     return margin(sim, (fwd_mean + bwd_mean) / 2)


-def score_candidates(sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False):
-    print(' - scoring {:d} candidates'.format(sim_mat.shape[0]))
+def score_candidates(
+    sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
+):
+    print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
     scores = np.zeros(candidate_inds.shape)
     for i in range(scores.shape[0]):
         for j in range(scores.shape[1]):
@@ -106,42 +109,50 @@ def load_text(files):
     return all_sentences


-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Mine bitext')
-    parser.add_argument('--src-lang', help='Source language')
-    parser.add_argument('--tgt-lang', help='Target language')
-    parser.add_argument('--dict-path', help='Path to dictionary file', default='dict.txt')
-    parser.add_argument('--spm-path', help='Path to SPM model file', default='sentence.bpe.model')
-    parser.add_argument('--dim', type=int, default=1024,
-                        help='Embedding dimension')
-    parser.add_argument('--mem', type=int, default=5,
-                        help='Memory in GB')
-    parser.add_argument('--src-dir', help='Source directory')
-    parser.add_argument('--tgt-dir', help='Target directory')
-    parser.add_argument('--output', help='Output path')
-    parser.add_argument('--neighborhood', type=int, default=4,
-                        help='Embedding dimension')
-    parser.add_argument('--threshold', type=float, default=1.06,
-                        help='Threshold on mined bitext')
-    parser.add_argument('--valid-size', type=int, default=2000,
-                        help='Number of sentences used for validation set')
-    parser.add_argument('--min-count', type=int, default=50000,
-                        help='Min num sentences used for each language')
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Mine bitext")
+    parser.add_argument("--src-lang", help="Source language")
+    parser.add_argument("--tgt-lang", help="Target language")
+    parser.add_argument(
+        "--dict-path", help="Path to dictionary file", default="dict.txt"
+    )
+    parser.add_argument(
+        "--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
+    )
+    parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
+    parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
+    parser.add_argument("--src-dir", help="Source directory")
+    parser.add_argument("--tgt-dir", help="Target directory")
+    parser.add_argument("--output", help="Output path")
+    parser.add_argument(
+        "--neighborhood", type=int, default=4, help="Embedding dimension"
+    )
+    parser.add_argument(
+        "--threshold", type=float, default=1.06, help="Threshold on mined bitext"
+    )
+    parser.add_argument(
+        "--valid-size",
+        type=int,
+        default=2000,
+        help="Number of sentences used for validation set",
+    )
+    parser.add_argument(
+        "--min-count",
+        type=int,
+        default=50000,
+        help="Min num sentences used for each language",
+    )
     args = parser.parse_args()

     x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
     y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
     margin = lambda a, b: a / b
     y2x_sim, y2x_ind = knnGPU_sharded(
-        y_batches_f, x_batches_f,
-        args.dim,
-        args.neighborhood,
-        direction='y2x')
+        y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
+    )
     x2y_sim, x2y_ind = knnGPU_sharded(
-        x_batches_f, y_batches_f,
-        args.dim,
-        args.neighborhood,
-        direction='x2y')
+        x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
+    )

     x2y_mean = x2y_sim.mean(axis=1)
     y2x_mean = y2x_sim.mean(axis=1)
@@ -149,8 +160,13 @@ if __name__ == '__main__':
     bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
     fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
     bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
-    indices = np.stack((np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
-                        np.concatenate((fwd_best, np.arange(y2x_ind.shape[0])))), axis=1)
+    indices = np.stack(
+        (
+            np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
+            np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
+        ),
+        axis=1,
+    )
     scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))

     x_sentences = load_text(x_sents_f)
@@ -162,20 +178,20 @@ if __name__ == '__main__':
     directory = args.output
     call(f"mkdir -p {directory}")
     src_out = open(
-        f'{directory}/all.{args.src_lang}',
-        mode='w',
-        encoding='utf-8',
-        errors='surrogateescape')
+        f"{directory}/all.{args.src_lang}",
+        mode="w",
+        encoding="utf-8",
+        errors="surrogateescape",
+    )
     tgt_out = open(
-        f'{directory}/all.{args.tgt_lang}',
-        mode='w',
-        encoding='utf-8',
-        errors='surrogateescape')
+        f"{directory}/all.{args.tgt_lang}",
+        mode="w",
+        encoding="utf-8",
+        errors="surrogateescape",
+    )
     scores_out = open(
-        f'{directory}/all.scores',
-        mode='w',
-        encoding='utf-8',
-        errors='surrogateescape')
+        f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
+    )
     count = 0
     for i in np.argsort(-scores):
         src_ind, trg_ind = indices[i]
@@ -195,20 +211,23 @@ if __name__ == '__main__':
     scores_out.close()
     print(f"Found {count} pairs for threshold={threshold}")

-    with open(f'{directory}/all.{args.src_lang}') as all_s, \
-            open(f'{directory}/all.{args.tgt_lang}') as all_t, \
-            open(f'{directory}/valid.{args.src_lang}', 'w') as valid_s, \
-            open(f'{directory}/valid.{args.tgt_lang}', 'w') as valid_t, \
-            open(f'{directory}/train.{args.src_lang}', 'w') as train_s, \
-            open(f'{directory}/train.{args.tgt_lang}', 'w') as train_t:
+    with open(f"{directory}/all.{args.src_lang}") as all_s, open(
+        f"{directory}/all.{args.tgt_lang}"
+    ) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
+        f"{directory}/valid.{args.tgt_lang}", "w"
+    ) as valid_t, open(
+        f"{directory}/train.{args.src_lang}", "w"
+    ) as train_s, open(
+        f"{directory}/train.{args.tgt_lang}", "w"
+    ) as train_t:
         count = 0
         for s_line, t_line in zip(all_s, all_t):
-            s_line = s_line.split('\t')[1]
-            t_line = t_line.split('\t')[1]
+            s_line = s_line.split("\t")[1]
+            t_line = t_line.split("\t")[1]
             if count >= args.valid_size:
                 train_s.write(s_line)
                 train_t.write(t_line)
             else:
                 valid_s.write(s_line)
                 valid_t.write(t_line)
             count += 1

View File

@@ -7,27 +7,29 @@
 Translate pre-processed data with a trained model.
 """

-import numpy as np
 import torch
 from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
 from fairseq.sequence_generator import EnsembleModel

+import numpy as np


-def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False):
+def get_avg_pool(
+    models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
+):
     model = EnsembleModel(models)

     # model.forward normally channels prev_output_tokens into the decoder
     # separately, but SequenceGenerator directly calls model.encoder
     encoder_input = {
-        k: v for k, v in sample['net_input'].items()
-        if k != 'prev_output_tokens'
+        k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
     }

     # compute the encoder output for each beam
     encoder_outs = model.forward_encoder(encoder_input)
     np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
-    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(np.float32)
+    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
+        np.float32
+    )
     encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
     if has_langtok:
         encoder_mask = encoder_mask[1:, :, :]
@@ -38,13 +40,15 @@ def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langto

 def main(args):
-    assert args.path is not None, '--path required for generation!'
-    assert not args.sampling or args.nbest == args.beam, \
-        '--sampling requires --nbest to be equal to --beam'
-    assert args.replace_unk is None or args.raw_text, \
-        '--replace-unk requires a raw text dataset (--raw-text)'
+    assert args.path is not None, "--path required for generation!"
+    assert (
+        not args.sampling or args.nbest == args.beam
+    ), "--sampling requires --nbest to be equal to --beam"
+    assert (
+        args.replace_unk is None or args.raw_text
+    ), "--replace-unk requires a raw text dataset (--raw-text)"

-    args.beam=1
+    args.beam = 1
     utils.import_user_module(args)

     if args.max_tokens is None:
@@ -58,15 +62,15 @@ def main(args):
     # Set dictionaries
     try:
-        src_dict = getattr(task, 'source_dictionary', None)
+        src_dict = getattr(task, "source_dictionary", None)
     except NotImplementedError:
         src_dict = None
     tgt_dict = task.target_dictionary

     # Load ensemble
-    print('| loading model(s) from {}'.format(args.path))
+    print("| loading model(s) from {}".format(args.path))
     models, _model_args = checkpoint_utils.load_model_ensemble(
-        args.path.split(':'),
+        args.path.split(":"),
         arg_overrides=eval(args.model_overrides),
         task=task,
     )
@@ -105,9 +109,9 @@ def main(args):
     shard_id = 0
     all_avg_pool = None
     encoder_has_langtok = (
-        hasattr(task.args, 'encoder_langtok')
+        hasattr(task.args, "encoder_langtok")
         and task.args.encoder_langtok is not None
-        and hasattr(task.args, 'lang_tok_replacing_bos_eos')
+        and hasattr(task.args, "lang_tok_replacing_bos_eos")
         and not task.args.lang_tok_replacing_bos_eos
     )
     with progress_bar.build_progress_bar(args, itr) as t:
@@ -116,34 +120,42 @@ def main(args):
                 print("Skipping None")
                 continue
             sample = utils.move_to_cuda(sample) if use_cuda else sample
-            if 'net_input' not in sample:
+            if "net_input" not in sample:
                 continue

             prefix_tokens = None
             if args.prefix_size > 0:
-                prefix_tokens = sample['target'][:, :args.prefix_size]
+                prefix_tokens = sample["target"][:, : args.prefix_size]

             with torch.no_grad():
                 avg_pool = get_avg_pool(
-                    models, sample, prefix_tokens, src_dict,
-                    args.remove_bpe,
-                    has_langtok=encoder_has_langtok)
+                    models,
+                    sample,
+                    prefix_tokens,
+                    src_dict,
+                    args.remove_bpe,
+                    has_langtok=encoder_has_langtok,
+                )
                 if all_avg_pool is not None:
                     all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
                 else:
                     all_avg_pool = avg_pool

-            if not isinstance(sample['id'], list):
-                sample_ids = sample['id'].tolist()
+            if not isinstance(sample["id"], list):
+                sample_ids = sample["id"].tolist()
             else:
-                sample_ids = sample['id']
+                sample_ids = sample["id"]
             for i, sample_id in enumerate(sample_ids):
                 # Remove padding
-                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
+                src_tokens = utils.strip_pad(
+                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
+                )

                 # Either retrieve the original sentences or regenerate them from tokens.
                 if align_dict is not None:
-                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
+                    src_str = task.dataset(args.gen_subset).src.get_original_text(
+                        sample_id
+                    )
                 else:
                     if src_dict is not None:
                         src_str = src_dict.string(src_tokens, args.remove_bpe)
@@ -152,37 +164,50 @@ def main(args):
                 if not args.quiet:
                     if src_dict is not None:
-                        print('S-{}\t{}'.format(sample_id, src_str))
+                        print("S-{}\t{}".format(sample_id, src_str))

                 source_sentences.append(f"{sample_id}\t{src_str}")

-            num_sentences += sample['nsentences']
+            num_sentences += sample["nsentences"]
             if all_avg_pool.shape[0] >= 1000000:
-                with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
-                          'w') as avg_pool_file:
+                with open(
+                    f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
+                    "w",
+                ) as avg_pool_file:
                     all_avg_pool.tofile(avg_pool_file)
-                with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
-                    sentence_file.writelines(f'{line}\n' for line in source_sentences)
+                with open(
+                    f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
+                    "w",
+                ) as sentence_file:
+                    sentence_file.writelines(f"{line}\n" for line in source_sentences)
                 all_avg_pool = None
                 source_sentences = []
                 shard_id += 1

     if all_avg_pool is not None:
-        with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
-                  'w') as avg_pool_file:
+        with open(
+            f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
+        ) as avg_pool_file:
             all_avg_pool.tofile(avg_pool_file)
-        with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
-            sentence_file.writelines(f'{line}\n' for line in source_sentences)
+        with open(
+            f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
+        ) as sentence_file:
+            sentence_file.writelines(f"{line}\n" for line in source_sentences)
     return None


 def cli_main():
     parser = options.get_generation_parser()
-    parser.add_argument('--encoder-save-dir', default='', type=str, metavar='N',
-                        help='directory to save encoder outputs')
+    parser.add_argument(
+        "--encoder-save-dir",
+        default="",
+        type=str,
+        metavar="N",
+        help="directory to save encoder outputs",
+    )
     args = options.parse_args_and_arch(parser)
     main(args)


-if __name__ == '__main__':
+if __name__ == "__main__":
     cli_main()

View File

@@ -3,10 +3,11 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import numpy as np
 import argparse
 import glob

+import numpy as np
+
 DIM = 1024
@@ -14,9 +15,13 @@ DIM = 1024
 def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
     target_ids = [tid for tid in target_embs]
     source_mat = np.stack(source_embs.values(), axis=0)
-    normalized_source_mat = source_mat / np.linalg.norm(source_mat, axis=1, keepdims=True)
+    normalized_source_mat = source_mat / np.linalg.norm(
+        source_mat, axis=1, keepdims=True
+    )
     target_mat = np.stack(target_embs.values(), axis=0)
-    normalized_target_mat = target_mat / np.linalg.norm(target_mat, axis=1, keepdims=True)
+    normalized_target_mat = target_mat / np.linalg.norm(
+        target_mat, axis=1, keepdims=True
+    )
     sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
     if return_sim_mat:
         return sim_mat
@@ -36,14 +41,14 @@ def load_embeddings(directory, LANGS):
         lang_dir = f"{directory}/{lang}"
         embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
         for embed_file in embedding_files:
-            shard_id = embed_file.split('.')[-1]
+            shard_id = embed_file.split(".")[-1]
             embeddings = np.fromfile(embed_file, dtype=np.float32)
             num_rows = embeddings.shape[0] // DIM
             embeddings = embeddings.reshape((num_rows, DIM))

-            with open(f'{lang_dir}/sentences.{lang}.{shard_id}') as sentence_file:
+            with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
                 for idx, line in enumerate(sentence_file):
-                    sentence_id, sentence = line.strip().split('\t')
+                    sentence_id, sentence = line.strip().split("\t")
                     sentence_texts[lang][sentence_id] = sentence
                     sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
@@ -55,7 +60,7 @@ def compute_accuracy(directory, LANGS):
     top_1_accuracy = {}

-    top1_str = " ".join(LANGS) + '\n'
+    top1_str = " ".join(LANGS) + "\n"
     for source_lang in LANGS:
         top_1_accuracy[source_lang] = {}
         top1_str += f"{source_lang} "
@@ -63,8 +68,8 @@ def compute_accuracy(directory, LANGS):
             top1 = 0
             top5 = 0
             neighbors_map = compute_dist(
-                sentence_embeddings[source_lang],
-                sentence_embeddings[target_lang])
+                sentence_embeddings[source_lang], sentence_embeddings[target_lang]
+            )
             for sentence_id, neighbors in neighbors_map.items():
                 if sentence_id == neighbors[0]:
                     top1 += 1
@@ -75,17 +80,13 @@ def compute_accuracy(directory, LANGS):
         top1_str += "\n"

     print(top1_str)
-    print(top1_str, file=open(f"{directory}/accuracy", 'w'))
+    print(top1_str, file=open(f"{directory}/accuracy", "w"))


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Analyze encoder outputs')
-    parser.add_argument('directory',
-                        help='Source language corpus'
-                        )
-    parser.add_argument('--langs',
-                        help='List of langs'
-                        )
+    parser = argparse.ArgumentParser(description="Analyze encoder outputs")
+    parser.add_argument("directory", help="Source language corpus")
+    parser.add_argument("--langs", help="List of langs")
     args = parser.parse_args()
-    langs = args.langs.split(',')
+    langs = args.langs.split(",")
     compute_accuracy(args.directory, langs)

View File

@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from .models import latent_multilingual_transformer  # noqa
-from .modules import latent_layers  # noqa
-from .loss import latent_depth  # noqa
-from . import multilingual_translation_latent_depth  # noqa
+from . import multilingual_translation_latent_depth  # noqa
+from .loss import latent_depth  # noqa
+from .models import latent_multilingual_transformer  # noqa
+from .modules import latent_layers  # noqa

View File

@@ -3,8 +3,9 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import torch
 import math
+
+import torch
 from torch.nn.modules.loss import _Loss
@@ -19,17 +20,16 @@ class LatentLayersKLLoss(_Loss):
         eps = 1e-7
         if prior == "uniform":
             # uniform prior
-            kl_loss = (samples * (
-                torch.log(samples + eps) - math.log(0.5)
-            )).sum(-1)
+            kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
         elif prior == "agged_posterior":
             # aggregated posterior
             y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
             agged_q = torch.sum(y_t, dim=0)
             row_norm = agged_q.sum(-1)
             normed_agg_q = agged_q / row_norm
-            kl_loss = (samples * (
-                torch.log(samples + eps) - torch.log(normed_agg_q + eps))).sum(-1)
+            kl_loss = (
+                samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
+            ).sum(-1)
         else:
             raise NotImplementedError("The specified prior is not implemented.")
@@ -37,7 +37,9 @@ class LatentLayersKLLoss(_Loss):
         kl_loss /= layer_samples[0].size()[0]
         kl_weight = min(
             self.args.sparsity_weight,
-            (update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates
+            (update_num - self.args.soft_update)
+            * self.args.sparsity_weight
+            / self.args.anneal_updates,
         )
         kl_loss *= kl_weight * sample_size
         return kl_loss
@@ -58,15 +60,17 @@ class LatentLayersSparsityLoss(_Loss):
         share_loss = 0
         global_sparsity_loss = 0
         layer_samples = torch.stack(layer_samples_list, dim=0)
-        if ((self.args.target_layers > 0 or self.args.share_weight > 0) and
-                update_num > (self.args.soft_update + self.args.anneal_updates)):
+        if (
+            self.args.target_layers > 0 or self.args.share_weight > 0
+        ) and update_num > (self.args.soft_update + self.args.anneal_updates):
             # anneal sparsity weight
             if update_num < (self.args.anneal_updates + self.args.soft_update):
                 weight_anneal = 0
             elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
                 weight_anneal = (
                     (update_num - self.args.soft_update - self.args.anneal_updates)
-                    * self.args.share_weight / self.args.anneal_updates
+                    * self.args.share_weight
+                    / self.args.anneal_updates
                 )
             else:
                 weight_anneal = 1
@@ -75,12 +79,21 @@ class LatentLayersSparsityLoss(_Loss):
             layer_utilization /= layer_samples.size()[0]
             if self.args.share_weight > 0:
                 # encouraging sharing across languages
-                share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0)
-                batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss
+                share_loss = sum(
+                    -1.0 * v * math.log(v) for v in layer_utilization if v > 0
+                )
+                batch_loss += (
+                    weight_anneal * self.args.share_weight * sample_size * share_loss
+                )
             if self.args.target_layers > 0:
                 # computed expected number of layers selected
                 expeted_layers = sum(layer_utilization)
                 # compute l2 loss wrt target number of layers
                 global_sparsity_loss = (expeted_layers - self.args.target_layers) ** 2
-                batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss
+                batch_loss += (
+                    weight_anneal
+                    * self.args.share_weight
+                    * sample_size
+                    * global_sparsity_loss
+                )
         return batch_loss

View File

@@ -3,34 +3,31 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.models import register_model, register_model_architecture
from fairseq.models.multilingual_transformer import MultilingualTransformerModel
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    base_architecture,
)

from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder


@register_model("latent_multilingual_transformer")
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
    """A variant of standard multilingual Transformer models which encoder and/or
    decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
    (https://arxiv.org/abs/2009.13102).
    """

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        if is_encoder:
            if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
                return LatentTransformerEncoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerEncoder(args, lang_dict, embed_tokens)
        else:
@@ -42,19 +39,21 @@ class LatentMultilingualTransformerModel(MultilingualTransformerModel):
            return TransformerDecoder(args, lang_dict, embed_tokens)


@register_model_architecture(
    "latent_multilingual_transformer", "latent_multilingual_transformer"
)
def latent_multilingual_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 24)
    args.share_encoders = getattr(args, "share_encoders", True)
    args.share_decoders = getattr(args, "share_decoders", True)
    args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
    args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)

    base_architecture(args)
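
The getattr chains above follow fairseq's usual architecture-registration idiom: each hyperparameter keeps a value the user already set and only falls back to a default otherwise. A minimal sketch of the pattern, with an illustrative Namespace rather than the registered architecture itself:

from argparse import Namespace


def example_architecture(args):
    # each line keeps an existing value and only fills in a default when absent
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.decoder_layers = getattr(args, "decoder_layers", 24)


args = Namespace(encoder_layers=6)  # a user override survives
example_architecture(args)
print(args.encoder_layers, args.decoder_layers)  # 6 24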

View File

@@ -7,26 +7,27 @@ from typing import Any, Dict, Optional
import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerDecoder, TransformerEncoder
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
from torch import Tensor

from ..modules.latent_layers import LayerSelect


class LatentTransformerEncoder(TransformerEncoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerEncoder.
    """

    def __init__(self, args, dictionary, embed_tokens, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.encoder_layers
        super().__init__(args, dictionary, embed_tokens)
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList(
            [self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)]
        )

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx
@@ -50,6 +51,7 @@ class LatentTransformerEncoderLayer(TransformerEncoderLayer):
        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
            parameters and sampling method.
    """

    def __init__(self, args, idx, layer_select=None):
        super().__init__(args)
        self.idx = idx
@@ -63,7 +65,10 @@ class LatentTransformerDecoder(TransformerDecoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerDecoder.
    """

    def __init__(
        self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1
    ):
        self.num_logits = num_logits
        self.num_layers = args.decoder_layers
        super().__init__(
@@ -71,16 +76,20 @@ class LatentTransformerDecoder(TransformerDecoder):
        )
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList(
            [
                self._build_decoder_layer(args, no_encoder_attn, idx)
                for idx in range(args.decoder_layers)
            ]
        )

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
        return LatentTransformerDecoderLayer(
            args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
        )

    def forward(
        self,
@@ -119,8 +128,15 @@ class LatentTransformerDecoderLayer(TransformerDecoderLayer):
            (default: False).
    """

    def __init__(
        self,
        args,
        idx,
        layer_select=None,
        no_encoder_attn=False,
        add_bias_kv=False,
        add_zero_attn=False,
    ):
        super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
        self.idx = idx
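
Outside the diff, the effect of layer_select on these layers can be pictured as gating each layer's residual branch with the sampled value. This is a rough sketch under assumed shapes, not the fairseq forward pass itself:

import torch
import torch.nn as nn


class GatedLayerStack(nn.Module):
    """Toy residual stack where each layer is weighted by a sampled gate z in [0, 1]."""

    def __init__(self, dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x, gates):
        for layer, z in zip(self.layers, gates):
            # z near 1 keeps the layer; z near 0 reduces it to a skip connection
            x = z * layer(x) + x
        return x


stack = GatedLayerStack(dim=8, num_layers=4)
out = stack(torch.randn(2, 8), gates=torch.tensor([1.0, 0.0, 1.0, 0.5]))
print(out.shape)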

View File

@@ -12,6 +12,7 @@ class LayerSelect(nn.Module):
    either (soft) weighting or (hard) selection of residual connection.
    https://arxiv.org/abs/2009.13102
    """

    def __init__(self, num_layers, num_logits, args):
        super(LayerSelect, self).__init__()
        self.args = args
@@ -27,14 +28,14 @@ class LayerSelect(nn.Module):
    @staticmethod
    def add_args(parser):
        parser.add_argument(
            "--soft-select",
            action="store_true",
            help="use soft samples in training an inference",
        )
        parser.add_argument("--sampling-tau", type=float, help="sampling temperature")

    def sample(self, logit_idx):
        """To leverage the efficiency of distributed training, samples for all
        layers are computed at once for each logit_idx. Logits are parameters
        learnt independent of each other.
@@ -43,7 +44,9 @@ class LayerSelect(nn.Module):
        """
        assert logit_idx is not None
        self.samples = self._gumbel_sigmoid(
            self.layer_logits[logit_idx, :].detach()
            if self.detach_grad
            else self.layer_logits[logit_idx, :],
            dim=-1,
            tau=self.tau,
            hard=self.hard_select,
@@ -54,10 +57,20 @@ class LayerSelect(nn.Module):
        sample = self.samples[i]
        return sample

    def _gumbel_sigmoid(
        self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
    ):
        # ~Gumbel(0,1)
        gumbels1 = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        gumbels2 = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        # Difference of two gumbels because we apply a sigmoid
        gumbels1 = (logits + gumbels1 - gumbels2) / tau
        y_soft = gumbels1.sigmoid()
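
For reference, the reparameterized sampling above can be reproduced in isolation. A condensed sketch of the same Gumbel-sigmoid trick, including the straight-through hard variant that the module's hard flag enables (a sketch under the diff's conventions, not the module itself):

import torch


def gumbel_sigmoid(logits, tau=1.0, hard=False, threshold=0.5):
    # two independent Gumbel(0, 1) noises; their difference is logistic noise
    g1 = -torch.empty_like(logits).exponential_().log()
    g2 = -torch.empty_like(logits).exponential_().log()
    y_soft = torch.sigmoid((logits + g1 - g2) / tau)
    if hard:
        # straight-through estimator: hard 0/1 forward pass, soft gradient backward
        y_hard = (y_soft > threshold).type_as(y_soft)
        return y_hard - y_soft.detach() + y_soft
    return y_soft


torch.manual_seed(0)
print(gumbel_sigmoid(torch.zeros(4)))             # relaxed samples around 0.5
print(gumbel_sigmoid(torch.zeros(4), hard=True))  # discrete 0/1 selections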

View File

@@ -5,10 +5,11 @@
from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss


@register_task("multilingual_translation_latent_depth")
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
    """A task for multiple translation with latent depth.
@@ -39,7 +40,9 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
        self.src_langs, self.tgt_langs = zip(
            *[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]
        )
        if self.training and self.encoder_latent_layer:
            assert self.args.share_encoders
        if self.training and self.decoder_latent_layer:
@@ -47,46 +50,56 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
        if training or self.encoder_latent_layer or self.decoder_latent_layer:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        self.eval_lang_pairs = self.lang_pairs
        self.model_lang_pairs = self.lang_pairs
        if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
            self.kl_loss = LatentLayersKLLoss(self.args)
            self.sparsity_loss = LatentLayersSparsityLoss(self.args)

    def _per_lang_pair_train_loss(
        self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
    ):
        src, tgt = lang_pair.split("-")
        if self.encoder_latent_layer:
            src_lang_idx = self.src_lang_idx_dict[src]
            model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
            model.models[lang_pair].encoder.layer_select.hard_select = (
                update_num > self.args.soft_update
            )
        if self.decoder_latent_layer:
            tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
            model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
            model.models[lang_pair].decoder.layer_select.hard_select = (
                update_num > self.args.soft_update
            )

        loss, sample_size, logging_output = criterion(
            model.models[lang_pair], sample[lang_pair]
        )
        if self.encoder_latent_layer:
            none_samples = sum(
                1 if x is None else 0
                for x in model.models[lang_pair].encoder.layer_select.layer_samples
            )
            if none_samples == 0 or self.args.prior != "agged_posterior":
                loss += self.kl_loss(
                    model.models[lang_pair].encoder.layer_select.layer_samples,
                    src_lang_idx,
                    update_num,
                    sample_size,
                )
        if self.decoder_latent_layer:
            none_samples = sum(
                1 if x is None else 0
                for x in model.models[lang_pair].decoder.layer_select.layer_samples
            )
            if none_samples == 0 or self.args.prior != "agged_posterior":
                loss += self.kl_loss(
                    model.models[lang_pair].decoder.layer_select.layer_samples,
                    tgt_lang_idx,
                    update_num,
                    sample_size,
                )
        if ignore_grad:
            loss *= 0
@@ -99,18 +112,31 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
        return loss, sample_size, logging_output

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        agg_loss, agg_sample_size, agg_logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
        # compute auxiliary loss from layere sparsity, based on all samples from all languages
        if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
            sparsity_loss = 0
            if self.encoder_latent_layer:
                sparsity_loss += self.sparsity_loss(
                    next(
                        iter(model.models.values())
                    ).encoder.layer_select.layer_samples,
                    update_num,
                    agg_sample_size,
                )
            if self.decoder_latent_layer:
                sparsity_loss += self.sparsity_loss(
                    next(
                        iter(model.models.values())
                    ).decoder.layer_select.layer_samples,
                    update_num,
                    agg_sample_size,
                )
            if sparsity_loss > 0:
                optimizer.backward(sparsity_loss)
        return agg_loss, agg_sample_size, agg_logging_output
@@ -123,10 +149,14 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
        if self.decoder_latent_layer:
            tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
            model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
        loss, sample_size, logging_output = criterion(
            model.models[lang_pair], sample[lang_pair]
        )
        return loss, sample_size, logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        if self.encoder_latent_layer or self.decoder_latent_layer:
            for model in models:
                if self.encoder_latent_layer:
@@ -137,15 +167,23 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
                    assert model.decoder.layer_select is not None
                    tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
                    model.decoder.set_lang_idx(tgt_lang_idx)
        return super().inference_step(
            generator, models, sample, prefix_tokens, constraints
        )

    @property
    def encoder_latent_layer(self):
        return (
            hasattr(self.args, "encoder_latent_layer")
            and self.args.encoder_latent_layer
        )

    @property
    def decoder_latent_layer(self):
        return (
            hasattr(self.args, "decoder_latent_layer")
            and self.args.decoder_latent_layer
        )

    @property
    def src_lang_idx_dict(self):
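
The hard_select assignments in this task encode a simple annealing schedule: relaxed, differentiable samples early in training, discrete layer selection after args.soft_update updates. Sketched in isolation, with assumed names rather than the task's API:

def select_mode(update_num, soft_update):
    # soft, differentiable samples while update_num <= soft_update,
    # hard 0/1 layer selection once training has warmed up
    return "hard" if update_num > soft_update else "soft"


for step in (0, 5000, 20001):
    print(step, select_mode(step, soft_update=20000))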

View File

@@ -8,37 +8,40 @@ Linformer: Self-Attention with Linear Complexity
import logging

from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import RobertaEncoder, RobertaModel

from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder

logger = logging.getLogger(__name__)


@register_model("linformer_roberta")
class LinformerModel(RobertaModel):
    @staticmethod
    def add_args(parser):
        RobertaModel.add_args(parser)

        # add args for Linformer
        parser.add_argument(
            "--compressed", type=int, help="compressed ratio of sequence length"
        )
        parser.add_argument(
            "--shared-kv-compressed",
            type=int,
            help="share compressed matrix between k and v, in each layer",
        )
        parser.add_argument(
            "--shared-layer-kv-compressed",
            type=int,
            help="share compressed matrix between k and v and across all layers",
        )
        parser.add_argument(
            "--freeze-compress",
            type=int,
            help="freeze the parameters in compressed layer",
        )

    @classmethod
    def build_model(cls, args, task):
@@ -47,7 +50,7 @@ class LinformerModel(RobertaModel):
        # make sure all arguments are present
        base_architecture(args)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        encoder = LinformerEncoder(args, task.source_dictionary)
@@ -85,47 +88,47 @@ class LinformerEncoder(RobertaEncoder):
        )


@register_model_architecture("linformer_roberta", "linformer_roberta")
def base_architecture(args):
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)

    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")

    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
    args.compressed = getattr(args, "compressed", 4)
    args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
    args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
    args.freeze_compress = getattr(args, "freeze_compress", 0)


@register_model_architecture("linformer_roberta", "linformer_roberta_base")
def linformer_roberta_base_architecture(args):
    base_architecture(args)


@register_model_architecture("linformer_roberta", "linformer_roberta_large")
def linformer_roberta_large_architecture(args):
    args.encoder_layers = getattr(args, "encoder_layers", 24)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)

    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")

    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
    args.compressed = getattr(args, "compressed", 4)
    args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
    args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)

View File

@@ -6,8 +6,8 @@
import math

import torch.nn as nn
from fairseq.modules import TransformerSentenceEncoder

from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer
@@ -117,7 +117,9 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
        qn_block_size,
    ):
        if self.shared_layer_kv_compressed == 1:
            compress_layer = nn.Linear(
                self.max_seq_len, self.max_seq_len // self.compressed
            )
            # intialize parameters for compressed layer
            nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2))
            if self.freeze_compress == 1:
@@ -139,8 +141,7 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
            max_seq_len=self.max_seq_len,
            shared_kv_compressed=self.shared_kv_compressed,
            shared_compress_layer=(
                None if self.shared_layer_kv_compressed == 0 else self.compress_layer
            ),
            freeze_compress=self.freeze_compress,
        )
@@ -156,7 +157,8 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
            if self.shared_layer_kv_compressed:
                for layer_idx in range(len(self.layers)):
                    new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format(
                        layer_idx,
                        k[len(prefix + "compress_layer.") :],
                    )
                    items_to_add[new_k] = state_dict[k]

View File

@@ -6,6 +6,7 @@
from typing import Callable

from fairseq.modules import TransformerSentenceEncoderLayer

from .multihead_linear_attention import MultiheadLinearAttention
@@ -23,7 +24,7 @@ class LinformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        export: bool = False,
        q_noise: float = 0.0,
        qn_block_size: int = 8,

View File

@@ -9,10 +9,10 @@ from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter


@with_incremental_state
@@ -65,16 +65,24 @@ class MultiheadLinearAttention(nn.Module):
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        # used for compress sequence to subsequence
        if shared_compress_layer is None:
            self.compress_seq_len = max_seq_len // compressed
            self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
            if shared_kv_compressed == 0:
                self.compress_v = nn.Linear(
                    max_seq_len, self.compress_seq_len, bias=False
                )
            self.layerwise_sharing = False
        else:
            self.compress_k = shared_compress_layer
@@ -83,7 +91,9 @@ class MultiheadLinearAttention(nn.Module):
            self.layerwise_sharing = True
        self.shared_kv_compressed = shared_kv_compressed

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
@@ -116,22 +126,28 @@ class MultiheadLinearAttention(nn.Module):
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
            if (
                not self.layerwise_sharing
            ):  # otherwise, we already initialize the parameters
                nn.init.xavier_uniform_(self.compress_k.weight, gain=1 / math.sqrt(2))
                if self.shared_kv_compressed == 0:
                    nn.init.xavier_uniform_(
                        self.compress_v.weight, gain=1 / math.sqrt(2)
                    )
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)
            if (
                not self.layerwise_sharing
            ):  # otherwise, we already initialize the parameters
                nn.init.xavier_uniform_(self.compress_k.weight)
                if self.shared_kv_compressed == 0:
                    nn.init.xavier_uniform_(self.compress_v.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
@@ -189,14 +205,26 @@ class MultiheadLinearAttention(nn.Module):
            q = self.q_proj(query)

            k_input = query.permute(1, 2, 0).contiguous()  # B * C * T
            k_input = (
                F.linear(k_input, self.compress_k.weight[:, 0:tgt_len])
                .permute(2, 0, 1)
                .contiguous()
            )
            k = self.k_proj(k_input)

            v_input = query.permute(1, 2, 0).contiguous()  # B * C * T
            if self.shared_kv_compressed == 0:
                v_input = (
                    F.linear(v_input, self.compress_v.weight[:, 0:tgt_len])
                    .permute(2, 0, 1)
                    .contiguous()
                )
            if self.shared_kv_compressed == 1:  # use shared kv compressed linear layer
                v_input = (
                    F.linear(v_input, self.compress_k.weight[:, 0:tgt_len])
                    .permute(2, 0, 1)
                    .contiguous()
                )
            v = self.v_proj(v_input)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
@@ -302,7 +330,9 @@ class MultiheadLinearAttention(nn.Module):
        )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = MultiheadLinearAttention.apply_sparse_mask(
            attn_weights, tgt_len, src_len, bsz
        )

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
@@ -385,7 +415,9 @@ class MultiheadLinearAttention(nn.Module):
    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
@@ -393,7 +425,9 @@ class MultiheadLinearAttention(nn.Module):
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention and input_buffer_k.size(
                        0
                    ) == new_order.size(0):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
@@ -428,8 +462,8 @@ class MultiheadLinearAttention(nn.Module):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)
@@ -438,9 +472,9 @@ class MultiheadLinearAttention(nn.Module):
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                        dim : 2 * dim
                    ]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")
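
The compress_k/compress_v projections above are the heart of Linformer: keys and values are projected from sequence length T down to a fixed k = max_seq_len // compressed before attention, so the score matrix is T x k rather than T x T. A shape-level sketch under assumed dimensions, not the module's full forward:

import torch
import torch.nn.functional as F

T, B, C, k = 16, 2, 8, 4                     # seq len, batch, channels, compressed len
compress_k = torch.nn.Linear(T, k, bias=False)

x = torch.randn(T, B, C)                     # T x B x C layout, as in fairseq
k_input = x.permute(1, 2, 0)                 # B x C x T
# slice the projection weight to the actual target length, as the diff does
k_input = F.linear(k_input, compress_k.weight[:, 0:T]).permute(2, 0, 1)
print(k_input.shape)                         # k x B x C: attention now costs O(T * k)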

View File

@@ -8,14 +8,16 @@
import sys

from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
from indicnlp.tokenize.indic_tokenize import trivial_tokenize


factory = IndicNormalizerFactory()
normalizer = factory.get_normalizer(
    sys.argv[1], remove_nuktas=False, nasals_mode="do_nothing"
)

for line in sys.stdin:
    normalized_line = normalizer.normalize(line.strip())
    tokenized_line = " ".join(trivial_tokenize(normalized_line, sys.argv[1]))
    print(tokenized_line)

View File

@@ -8,5 +8,6 @@ import sys

from pythainlp import word_tokenize


for line in sys.stdin:
    print(" ".join(word_tokenize(line.strip())))

View File

@@ -6,7 +6,9 @@
import fileinput

import sacrebleu


for line in fileinput.input():
    print(sacrebleu.tokenize_zh(line))

View File

@@ -6,19 +6,27 @@
import argparse
import fileinput

import sacremoses


def main():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    detok = sacremoses.MosesDetokenizer()

    for line in fileinput.input(args.files, openhook=fileinput.hook_compressed):
        print(
            detok.detokenize(line.strip().split(" "))
            .replace(" @", "")
            .replace("@ ", "")
            .replace(" =", "=")
            .replace("= ", "=")
            .replace(" ", "")
        )


if __name__ == "__main__":
    main()

View File

@@ -7,21 +7,22 @@ import math
from multiprocessing import Pool

import numpy as np
from fairseq import options
from fairseq.data import dictionary
from fairseq.scoring import bleu

from . import (
    rerank_generate,
    rerank_options,
    rerank_score_bw,
    rerank_score_lm,
    rerank_utils,
)


def score_target_hypo(
    args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
):
    print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
    gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
@@ -61,11 +62,21 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
                bitext2_score = None
                bitext2_backwards = None

            score = rerank_utils.get_score(
                a,
                b,
                c,
                target_len,
                bitext1.rescore_score[i],
                bitext2_score,
                lm_score=lm_score,
                lenpen=lenpen,
                src_len=bitext1.source_lengths[i],
                tgt_len=bitext1.target_lengths[i],
                bitext1_backwards=bitext1.backwards,
                bitext2_backwards=bitext2_backwards,
                normalize=normalize,
            )

            if score > best_score:
                best_score = score
@@ -88,8 +99,11 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
            for key in range(len(gen_keys)):
                if args.prefix_len is None:
                    assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
                        "pred and rescore hypo mismatch: i: "
                        + str(key)
                        + ", "
                        + str(hypo_lst[key])
                        + str(gen_keys[key])
                        + str(gen_output.no_bpe_hypo[key])
                    )
                    sys_tok = dict.encode_line(hypo_lst[key])
@@ -97,7 +111,9 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
                    scorer.add(ref_tok, sys_tok)
                else:
                    full_hypo = rerank_utils.get_full_from_prefix(
                        hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
                    )
                    sys_tok = dict.encode_line(full_hypo)
                    ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
                    scorer.add(ref_tok, sys_tok)
@@ -107,20 +123,31 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
        # recover the orinal ids from n best list generation
        for key in range(len(gen_output.no_bpe_target)):
            if args.prefix_len is None:
                assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
                    "pred and rescore hypo mismatch:"
                    + "i:"
                    + str(key)
                    + str(hypo_lst[key])
                    + str(gen_output.no_bpe_hypo[key])
                )
                ordered_hypos[gen_keys[key]] = hypo_lst[key]
                ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
                    gen_keys[key]
                ]
            else:
                full_hypo = rerank_utils.get_full_from_prefix(
                    hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
                )
                ordered_hypos[gen_keys[key]] = full_hypo
                ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
                    gen_keys[key]
                ]

        # write the hypos in the original order from nbest list generation
        if args.num_shards == (len(bitext1_lst)):
            with open(target_outfile, "w") as t:
                with open(hypo_outfile, "w") as h:
                    for key in range(len(ordered_hypos)):
                        t.write(ordered_targets[key])
                        h.write(ordered_hypos[key])
@@ -135,17 +162,38 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
def match_target_hypo(args, target_outfile, hypo_outfile):
    """combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
    if len(args.weight1) == 1:
        res = score_target_hypo(
            args,
            args.weight1[0],
            args.weight2[0],
            args.weight3[0],
            args.lenpen[0],
            target_outfile,
            hypo_outfile,
            True,
            args.normalize,
        )
        rerank_scores = [res]
    else:
        print("launching pool")
        with Pool(32) as p:
            rerank_scores = p.starmap(
                score_target_hypo,
                [
                    (
                        args,
                        args.weight1[i],
                        args.weight2[i],
                        args.weight3[i],
                        args.lenpen[i],
                        target_outfile,
                        hypo_outfile,
                        False,
                        args.normalize,
                    )
                    for i in range(len(args.weight1))
                ],
            )

    if len(rerank_scores) > 1:
        best_index = np.argmax(rerank_scores)
@@ -155,11 +203,22 @@ def match_target_hypo(args, target_outfile, hypo_outfile):
        print("best weight1", args.weight1[best_index])
        print("best weight2", args.weight2[best_index])
        print("best weight3", args.weight3[best_index])
        return (
            args.lenpen[best_index],
            args.weight1[best_index],
            args.weight2[best_index],
            args.weight3[best_index],
            best_score,
        )
    else:
        return (
            args.lenpen[0],
            args.weight1[0],
            args.weight2[0],
            args.weight3[0],
            rerank_scores[0],
        )


def load_score_files(args):
@@ -175,55 +234,100 @@ def load_score_files(args):
    for shard_id in shard_ids:
        using_nbest = args.nbest_list is not None
        (
            pre_gen,
            left_to_right_preprocessed_dir,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
            lm_preprocessed_dir,
        ) = rerank_utils.get_directories(
            args.data_dir_name,
            args.num_rescore,
            args.gen_subset,
            args.gen_model_name,
            shard_id,
            args.num_shards,
            args.sampling,
            args.prefix_len,
            args.target_prefix_frac,
            args.source_prefix_frac,
        )

        rerank1_is_gen = (
            args.gen_model == args.score_model1 and args.source_prefix_frac is None
        )
        rerank2_is_gen = (
            args.gen_model == args.score_model2 and args.source_prefix_frac is None
        )

        score1_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model1_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards1,
        )
        if args.score_model2 is not None:
            score2_file = rerank_utils.rescore_file_name(
                pre_gen,
                args.prefix_len,
                args.model2_name,
                target_prefix_frac=args.target_prefix_frac,
                source_prefix_frac=args.source_prefix_frac,
                backwards=args.backwards2,
            )
        if args.language_model is not None:
            lm_score_file = rerank_utils.rescore_file_name(
                pre_gen, args.prefix_len, args.lm_name, lm_file=True
            )

        # get gen output
        predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
        if using_nbest:
            print("Using predefined n-best list from interactive.py")
            predictions_bpe_file = args.nbest_list

        gen_output = rerank_utils.BitextOutputFromGen(
            predictions_bpe_file,
            bpe_symbol=args.remove_bpe,
            nbest=using_nbest,
            prefix_len=args.prefix_len,
            target_prefix_frac=args.target_prefix_frac,
        )

        if rerank1_is_gen:
            bitext1 = gen_output
        else:
            bitext1 = rerank_utils.BitextOutput(
                score1_file,
                args.backwards1,
                args.right_to_left1,
                args.remove_bpe,
                args.prefix_len,
                args.target_prefix_frac,
                args.source_prefix_frac,
            )

        if args.score_model2 is not None or args.nbest_list is not None:
            if rerank2_is_gen:
                bitext2 = gen_output
            else:
                bitext2 = rerank_utils.BitextOutput(
                    score2_file,
                    args.backwards2,
                    args.right_to_left2,
                    args.remove_bpe,
                    args.prefix_len,
                    args.target_prefix_frac,
                    args.source_prefix_frac,
                )

                assert (
                    bitext2.source_lengths == bitext1.source_lengths
                ), "source lengths for rescoring models do not match"
                assert (
                    bitext2.target_lengths == bitext1.target_lengths
                ), "target lengths for rescoring models do not match"
        else:
            if args.diff_bpe:
                assert args.score_model2 is None
@@ -232,8 +336,13 @@ def load_score_files(args):
                bitext2 = None

        if args.language_model is not None:
            lm_res1 = rerank_utils.LMOutput(
                lm_score_file,
                args.lm_dict,
                args.prefix_len,
                args.remove_bpe,
                args.target_prefix_frac,
            )
        else:
            lm_res1 = None
@@ -259,28 +368,46 @@ def rerank(args):
        shard_ids = [args.shard_id]

    for shard_id in shard_ids:
        (
            pre_gen,
            left_to_right_preprocessed_dir,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
            lm_preprocessed_dir,
        ) = rerank_utils.get_directories(
            args.data_dir_name,
            args.num_rescore,
            args.gen_subset,
            args.gen_model_name,
            shard_id,
            args.num_shards,
            args.sampling,
            args.prefix_len,
            args.target_prefix_frac,
            args.source_prefix_frac,
        )
        rerank_generate.gen_and_reprocess_nbest(args)
        rerank_score_bw.score_bw(args)
        rerank_score_lm.score_lm(args)

        if args.write_hypos is None:
            write_targets = pre_gen + "/matched_targets"
            write_hypos = pre_gen + "/matched_hypos"
        else:
            write_targets = args.write_hypos + "_targets" + args.gen_subset
            write_hypos = args.write_hypos + "_hypos" + args.gen_subset

        if args.all_shards:
            write_targets += "_all_shards"
            write_hypos += "_all_shards"

        (
            best_lenpen,
            best_weight1,
            best_weight2,
            best_weight3,
            best_score,
        ) = match_target_hypo(args, write_targets, write_hypos)

    return best_lenpen, best_weight1, best_weight2, best_weight3, best_score
@@ -291,5 +418,5 @@ def cli_main():
    rerank(args)


if __name__ == "__main__":
    cli_main()
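
match_target_hypo above is doing a grid search: one corpus-level score per (weight1, weight2, weight3, lenpen) tuple, evaluated in a process pool, with np.argmax picking the winner. A minimal sketch of that selection logic, with a toy scoring function standing in for score_target_hypo:

import numpy as np
from itertools import product


def toy_score(w1, w2, lenpen):
    # hypothetical stand-in for score_target_hypo's corpus BLEU
    return -(w1 - 0.6) ** 2 - (w2 - 0.3) ** 2 - (lenpen - 1.0) ** 2


grid = list(product([0.3, 0.6, 0.9], [0.1, 0.3], [0.5, 1.0]))
scores = [toy_score(*params) for params in grid]
best = int(np.argmax(scores))
print("best params:", grid[best], "score:", scores[best])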

View File

@@ -8,9 +8,9 @@
Generate n-best translations using a trained model.
"""

import os
import subprocess
from contextlib import redirect_stdout

from fairseq import options
from fairseq_cli import generate, preprocess

@@ -22,8 +22,12 @@ def gen_and_reprocess_nbest(args):
    if args.score_dict_dir is None:
        args.score_dict_dir = args.data
    if args.prefix_len is not None:
        assert (
            args.right_to_left1 is False
        ), "prefix length not compatible with right to left models"
        assert (
            args.right_to_left2 is False
        ), "prefix length not compatible with right to left models"

    if args.nbest_list is not None:
        assert args.score_model2 is None

@@ -35,27 +39,50 @@ def gen_and_reprocess_nbest(args):
    scorer1_src = args.source_lang
    scorer1_tgt = args.target_lang

    store_data = (
        os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name
    )
    if not os.path.exists(store_data):
        os.makedirs(store_data)

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )
    assert not (
        args.right_to_left1 and args.backwards1
    ), "backwards right to left not supported"
    assert not (
        args.right_to_left2 and args.backwards2
    ), "backwards right to left not supported"
    assert not (
        args.prefix_len is not None and args.target_prefix_frac is not None
    ), "target prefix frac and target prefix len incompatible"

    # make directory to store generation results
    if not os.path.exists(pre_gen):
        os.makedirs(pre_gen)

    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    if args.nbest_list is not None:
        rerank2_is_gen = True

@@ -70,17 +97,25 @@ def gen_and_reprocess_nbest(args):
    if not os.path.exists(backwards_preprocessed_dir):
        os.makedirs(backwards_preprocessed_dir)

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"

    using_nbest = args.nbest_list is not None

@@ -92,17 +127,29 @@ def gen_and_reprocess_nbest(args):
    if not os.path.isfile(predictions_bpe_file):
        print("STEP 1: generate predictions using the p(T|S) model with bpe")
        print(args.data)
        param1 = [
            args.data,
            "--path",
            args.gen_model,
            "--shard-id",
            str(args.shard_id),
            "--num-shards",
            str(args.num_shards),
            "--nbest",
            str(args.num_rescore),
            "--batch-size",
            str(args.batch_size),
            "--beam",
            str(args.num_rescore),
            "--batch-size",
            str(args.num_rescore),
            "--gen-subset",
            args.gen_subset,
            "--source-lang",
            args.source_lang,
            "--target-lang",
            args.target_lang,
        ]
        if args.sampling:
            param1 += ["--sampling"]
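One behavior worth noting in param1 above: --batch-size is passed twice (first str(args.batch_size), then str(args.num_rescore)), and argparse keeps the last occurrence, so the effective batch size is num_rescore. A small self-contained demonstration:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--batch-size", type=int)
args = p.parse_args(["--batch-size", "32", "--batch-size", "100"])
print(args.batch_size)  # prints 100: the last occurrence wins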
@@ -110,124 +157,229 @@ def gen_and_reprocess_nbest(args):
        input_args = options.parse_args_and_arch(gen_parser, param1)

        print(input_args)
        with open(predictions_bpe_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)

    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file,
        bpe_symbol=args.remove_bpe,
        nbest=using_nbest,
        prefix_len=args.prefix_len,
        target_prefix_frac=args.target_prefix_frac,
    )

    if args.diff_bpe:
        rerank_utils.write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            pre_gen + "/source_gen_bpe." + args.source_lang,
            pre_gen + "/target_gen_bpe." + args.target_lang,
            pre_gen + "/reference_gen_bpe." + args.target_lang,
        )
        bitext_bpe = args.rescore_bpe_code
        bpe_src_param = [
            "-c",
            bitext_bpe,
            "--input",
            pre_gen + "/source_gen_bpe." + args.source_lang,
            "--output",
            pre_gen + "/rescore_data." + args.source_lang,
        ]
        bpe_tgt_param = [
            "-c",
            bitext_bpe,
            "--input",
            pre_gen + "/target_gen_bpe." + args.target_lang,
            "--output",
            pre_gen + "/rescore_data." + args.target_lang,
        ]

        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_src_param,
            shell=False,
        )

        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_tgt_param,
            shell=False,
        )

    if (not os.path.isfile(score1_file) and not rerank1_is_gen) or (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print(
            "STEP 2: process the output of generate.py so we have clean text files with the translations"
        )

        rescore_file = "/rescore_data"
        if args.prefix_len is not None:
            prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
        if args.target_prefix_frac is not None:
            target_prefix_frac_rescore_file = (
                rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
            )
        if args.source_prefix_frac is not None:
            source_prefix_frac_rescore_file = (
                rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
            )

        if not args.right_to_left1 or not args.right_to_left2:
            if not args.diff_bpe:
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen + rescore_file + "." + args.source_lang,
                    pre_gen + rescore_file + "." + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.remove_bpe,
                )
            if args.prefix_len is not None:
                bw_rescore_file = prefix_len_rescore_file
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen + prefix_len_rescore_file + "." + args.source_lang,
                    pre_gen + prefix_len_rescore_file + "." + args.target_lang,
                    pre_gen + "/reference_file",
                    prefix_len=args.prefix_len,
                    bpe_symbol=args.remove_bpe,
                )
            elif args.target_prefix_frac is not None:
                bw_rescore_file = target_prefix_frac_rescore_file
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen
                    + target_prefix_frac_rescore_file
                    + "."
                    + args.source_lang,
                    pre_gen
                    + target_prefix_frac_rescore_file
                    + "."
                    + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.remove_bpe,
                    target_prefix_frac=args.target_prefix_frac,
                )
            else:
                bw_rescore_file = rescore_file

            if args.source_prefix_frac is not None:
                fw_rescore_file = source_prefix_frac_rescore_file
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen
                    + source_prefix_frac_rescore_file
                    + "."
                    + args.source_lang,
                    pre_gen
                    + source_prefix_frac_rescore_file
                    + "."
                    + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.remove_bpe,
                    source_prefix_frac=args.source_prefix_frac,
                )
            else:
                fw_rescore_file = rescore_file

        if args.right_to_left1 or args.right_to_left2:
            rerank_utils.write_reprocessed(
                gen_output.source,
                gen_output.hypo,
                gen_output.target,
                pre_gen + "/right_to_left_rescore_data." + args.source_lang,
                pre_gen + "/right_to_left_rescore_data." + args.target_lang,
                pre_gen + "/right_to_left_reference_file",
                right_to_left=True,
                bpe_symbol=args.remove_bpe,
            )

        print("STEP 3: binarize the translations")
        if (
            not args.right_to_left1
            or args.score_model2 is not None
            and not args.right_to_left2
            or not rerank1_is_gen
        ):
            if args.backwards1 or args.backwards2:
                if args.backwards_score_dict_dir is not None:
                    bw_dict = args.backwards_score_dict_dir
                else:
                    bw_dict = args.score_dict_dir
                bw_preprocess_param = [
                    "--source-lang",
                    scorer1_src,
                    "--target-lang",
                    scorer1_tgt,
                    "--trainpref",
                    pre_gen + bw_rescore_file,
                    "--srcdict",
                    bw_dict + "/dict." + scorer1_src + ".txt",
                    "--tgtdict",
                    bw_dict + "/dict." + scorer1_tgt + ".txt",
                    "--destdir",
                    backwards_preprocessed_dir,
                ]
                preprocess_parser = options.get_preprocessing_parser()
                input_args = preprocess_parser.parse_args(bw_preprocess_param)
                preprocess.main(input_args)

            preprocess_param = [
                "--source-lang",
                scorer1_src,
                "--target-lang",
                scorer1_tgt,
                "--trainpref",
                pre_gen + fw_rescore_file,
                "--srcdict",
                args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                "--tgtdict",
                args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                "--destdir",
                left_to_right_preprocessed_dir,
            ]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_param)
            preprocess.main(input_args)

        if args.right_to_left1 or args.right_to_left2:
            preprocess_param = [
                "--source-lang",
                scorer1_src,
                "--target-lang",
                scorer1_tgt,
                "--trainpref",
                pre_gen + "/right_to_left_rescore_data",
                "--srcdict",
                args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                "--tgtdict",
                args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                "--destdir",
                right_to_left_preprocessed_dir,
            ]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_param)
            preprocess.main(input_args)

@@ -241,5 +393,5 @@ def cli_main():
    gen_and_reprocess_nbest(args)


if __name__ == "__main__":
    cli_main()
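STEP 1 above relies on contextlib.redirect_stdout to capture everything generate.main prints into a file. A minimal sketch of the pattern, with a stand-in function instead of fairseq's generate.main:

import io
from contextlib import redirect_stdout


def fake_generate():
    # stand-in for generate.main, which prints H-/S-/T-/P- lines to stdout
    print("H-0\t-0.5231\tein Beispiel")


buf = io.StringIO()  # a file object from open(path, "w") works the same way
with redirect_stdout(buf):
    fake_generate()
print("captured:", buf.getvalue().strip())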

View File

@@ -6,14 +6,14 @@
from fairseq import options


def get_reranking_parser(default_task="translation"):
    parser = options.get_parser("Generation and reranking", default_task)
    add_reranking_args(parser)
    return parser


def get_tuning_parser(default_task="translation"):
    parser = options.get_parser("Reranking tuning", default_task)
    add_reranking_args(parser)
    add_tuning_args(parser)
    return parser

@@ -110,17 +110,40 @@ def add_reranking_args(parser):
def add_tuning_args(parser):
    group = parser.add_argument_group("Tuning")

    group.add_argument(
        "--lower-bound",
        default=[-0.7],
        nargs="+",
        type=float,
        help="lower bound of search space",
    )
    group.add_argument(
        "--upper-bound",
        default=[3],
        nargs="+",
        type=float,
        help="upper bound of search space",
    )
    group.add_argument(
        "--tune-param",
        default=["lenpen"],
        nargs="+",
        choices=["lenpen", "weight1", "weight2", "weight3"],
        help="the parameter(s) to tune",
    )
    group.add_argument(
        "--tune-subset",
        default="valid",
        choices=["valid", "test", "train"],
        help="the subset to tune on ",
    )
    group.add_argument(
        "--num-trials",
        default=1000,
        type=int,
        help="number of trials to do for random search",
    )
    group.add_argument(
        "--share-weights", action="store_true", help="share weight2 and weight 3"
    )
    return group
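The reformatted add_tuning_args above is the standard argparse argument-group pattern. A self-contained sketch, with a plain ArgumentParser standing in for fairseq's options parser:

import argparse

parser = argparse.ArgumentParser("reranking tuning (sketch)")
group = parser.add_argument_group("Tuning")
group.add_argument("--lower-bound", default=[-0.7], nargs="+", type=float)
group.add_argument("--upper-bound", default=[3], nargs="+", type=float)
group.add_argument(
    "--tune-param",
    default=["lenpen"],
    nargs="+",
    choices=["lenpen", "weight1", "weight2", "weight3"],
)

args = parser.parse_args(["--tune-param", "lenpen", "weight1"])
print(args.tune_param)  # ['lenpen', 'weight1']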

View File

@@ -3,8 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from contextlib import redirect_stdout

from fairseq import options
from fairseq_cli import generate

@@ -13,82 +13,124 @@ from . import rerank_options, rerank_utils
def score_bw(args):
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    if args.score_model2 is not None:
        if args.backwards2:
            scorer2_src = args.target_lang
            scorer2_tgt = args.source_lang
        else:
            scorer2_src = args.source_lang
            scorer2_tgt = args.target_lang

    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )

    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    if args.right_to_left1:
        rerank_data1 = right_to_left_preprocessed_dir
    elif args.backwards1:
        rerank_data1 = backwards_preprocessed_dir
    else:
        rerank_data1 = left_to_right_preprocessed_dir

    gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
    if not rerank1_is_gen and not os.path.isfile(score1_file):
        print("STEP 4: score the translations for model 1")

        model_param1 = [
            "--path",
            args.score_model1,
            "--source-lang",
            scorer1_src,
            "--target-lang",
            scorer1_tgt,
        ]
        gen_model1_param = [rerank_data1] + gen_param + model_param1

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)

        with open(score1_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)

    if (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print("STEP 4: score the translations for model 2")

        if args.right_to_left2:
            rerank_data2 = right_to_left_preprocessed_dir
        elif args.backwards2:
            rerank_data2 = backwards_preprocessed_dir
        else:
            rerank_data2 = left_to_right_preprocessed_dir

        model_param2 = [
            "--path",
            args.score_model2,
            "--source-lang",
            scorer2_src,
            "--target-lang",
            scorer2_tgt,
        ]
        gen_model2_param = [rerank_data2] + gen_param + model_param2

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)

        with open(score2_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)


def cli_main():

@@ -97,5 +139,5 @@ def cli_main():
    score_bw(args)


if __name__ == "__main__":
    cli_main()
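The scoring command line above is assembled as plain lists of strings ([rerank_data1] + gen_param + model_param1) before being handed to fairseq's generation parser. A tiny sketch of that composition; the paths and language pair below are placeholders:

rerank_data1 = "rerank_data/left_to_right_preprocessed"  # hypothetical path
gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
model_param1 = ["--path", "model1.pt", "--source-lang", "de", "--target-lang", "en"]

# concatenating the lists yields one argv-style command line
gen_model1_param = [rerank_data1] + gen_param + model_param1
print(" ".join(gen_model1_param))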

View File

@@ -12,22 +12,38 @@ from . import rerank_options, rerank_utils
def score_lm(args):
    using_nbest = args.nbest_list is not None
    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list

    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest
    )

    if args.language_model is not None:
        lm_score_file = rerank_utils.rescore_file_name(
            pre_gen, args.prefix_len, args.lm_name, lm_file=True
        )

    if args.language_model is not None and not os.path.isfile(lm_score_file):
        print("STEP 4.5: language modeling for P(T)")

@@ -38,10 +54,21 @@ def score_lm(args):
        else:
            bpe_status = "different"

        rerank_utils.lm_scoring(
            lm_preprocessed_dir,
            bpe_status,
            gen_output,
            pre_gen,
            args.lm_dict,
            args.lm_name,
            args.language_model,
            args.lm_bpe_code,
            128,
            lm_score_file,
            args.target_lang,
            args.source_lang,
            prefix_len=args.prefix_len,
        )


def cli_main():

@@ -50,5 +77,5 @@ def cli_main():
    score_lm(args)


if __name__ == "__main__":
    cli_main()
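score_lm, like the other steps, guards the expensive call behind an os.path.isfile check so an interrupted rescoring run resumes cheaply. A minimal sketch of that caching pattern (the path and payload here are made up):

import os


def maybe_compute(path, compute):
    # skip the expensive step if its output already exists on disk
    if not os.path.isfile(path):
        with open(path, "w") as f:
            f.write(compute())


maybe_compute("/tmp/lm_scores.txt", lambda: "dummy lm score\n")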

View File

@@ -5,8 +5,8 @@
import argparse
import random

import numpy as np

from fairseq import options

from . import rerank, rerank_options

@@ -14,7 +14,7 @@ from . import rerank, rerank_options
def random_search(args):
    param_values = []
    tuneable_parameters = ["lenpen", "weight1", "weight2", "weight3"]
    initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3]
    for i, elem in enumerate(initial_params):
        if type(elem) is not list:

@@ -33,51 +33,60 @@ def random_search(args):
    param_values += initial_params
    random.seed(args.seed)

    random_params = np.array(
        [
            [
                random.uniform(args.lower_bound[i], args.upper_bound[i])
                for i in range(len(args.tune_param))
            ]
            for k in range(args.num_trials)
        ]
    )
    set_params = np.array(
        [
            [initial_params[i][0] for i in range(len(tuneable_parameters))]
            for k in range(args.num_trials)
        ]
    )
    random_params = np.concatenate((random_params, set_params), 1)

    rerank_args = vars(args).copy()
    if args.nbest_list:
        rerank_args["gen_subset"] = "test"
    else:
        rerank_args["gen_subset"] = args.tune_subset

    for k in range(len(tune_parameters)):
        rerank_args[tune_parameters[k]] = list(random_params[:, k])

    if args.share_weights:
        k = tune_parameters.index("weight2")
        rerank_args["weight3"] = list(random_params[:, k])

    rerank_args = argparse.Namespace(**rerank_args)
    best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(
        rerank_args
    )
    rerank_args = vars(args).copy()
    rerank_args["lenpen"] = [best_lenpen]
    rerank_args["weight1"] = [best_weight1]
    rerank_args["weight2"] = [best_weight2]
    rerank_args["weight3"] = [best_weight3]

    # write the hypothesis from the valid set from the best trial
    if args.gen_subset != "valid":
        rerank_args["gen_subset"] = "valid"
    rerank_args = argparse.Namespace(**rerank_args)
    rerank.rerank(rerank_args)

    # test with the best hyperparameters on gen subset
    rerank_args = vars(args).copy()
    rerank_args["gen_subset"] = args.gen_subset
    rerank_args["lenpen"] = [best_lenpen]
    rerank_args["weight1"] = [best_weight1]
    rerank_args["weight2"] = [best_weight2]
    rerank_args["weight3"] = [best_weight3]
    rerank_args = argparse.Namespace(**rerank_args)
    rerank.rerank(rerank_args)

@@ -89,5 +98,5 @@ def cli_main():
    random_search(args)


if __name__ == "__main__":
    cli_main()
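The random search above draws num_trials points uniformly inside per-parameter bounds and then appends the initial settings as extra columns. A reduced, runnable sketch with made-up bounds:

import random

import numpy as np

lower_bound, upper_bound = [-0.7, 0.0], [3.0, 2.0]
num_trials = 4
random.seed(0)

random_params = np.array(
    [
        [
            random.uniform(lower_bound[i], upper_bound[i])
            for i in range(len(lower_bound))
        ]
        for _ in range(num_trials)
    ]
)
print(random_params.shape)  # (4, 2): one row per trial, one column per parameter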

View File

@@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import os
import re
import subprocess
from contextlib import redirect_stdout

from fairseq import options
from fairseq_cli import eval_lm, preprocess

@@ -20,7 +20,7 @@ def reprocess(fle):
    # per source, so the values for hypothesis_dict are lists.
    # parses output of generate.py

    with open(fle, "r") as f:
        txt = f.read()

    """reprocess generate.py output"""

@@ -45,7 +45,9 @@ def reprocess(fle):
        if line_type == "H":
            h_txt = line[j:]
            hypo = re.search(hp, h_txt)
            assert (
                hypo is not None
            ), "regular expression failed to find the hypothesis scoring"
            _, i = hypo.span()
            score = hypo.group()
            if id_num in hypothesis_dict:

@@ -56,9 +58,9 @@ def reprocess(fle):
                score_dict[id_num] = [float(score)]

        elif line_type == "S":
            source_dict[id_num] = line[j:]
        elif line_type == "T":
            target_dict[id_num] = line[j:]
        elif line_type == "P":
            pos_scores = (line[j:]).split()
            pos_scores = [float(x) for x in pos_scores]

@@ -72,7 +74,7 @@ def reprocess(fle):
def reprocess_nbest(fle):
    """reprocess interactive.py output"""
    with open(fle, "r") as f:
        txt = f.read()

    source_dict = {}

@@ -82,7 +84,7 @@ def reprocess_nbest(fle):
    pos_score_dict = {}
    lines = txt.split("\n")

    hp = re.compile(r"[-]?\d+[.]?\d+")
    j = -1

    for _i, line in enumerate(lines):
@@ -119,59 +121,76 @@ def reprocess_nbest(fle):
    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict


def write_reprocessed(
    sources,
    hypos,
    targets,
    source_outfile,
    hypo_outfile,
    target_outfile,
    right_to_left=False,
    prefix_len=None,
    bpe_symbol=None,
    target_prefix_frac=None,
    source_prefix_frac=None,
):
    """writes nbest hypothesis for rescoring"""
    assert not (
        prefix_len is not None and target_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"
    assert not (
        prefix_len is not None and source_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"
    assert not (
        target_prefix_frac is not None and source_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"

    with open(source_outfile, "w") as source_file, open(
        hypo_outfile, "w"
    ) as hypo_file, open(target_outfile, "w") as target_file:
        assert len(sources) == len(hypos), "sources and hypos list length mismatch"
        if right_to_left:
            for i in range(len(sources)):
                for j in range(len(hypos[i])):
                    if prefix_len is None:
                        hypo_file.write(make_right_to_left(hypos[i][j]) + "\n")
                    else:
                        raise NotImplementedError()
                source_file.write(make_right_to_left(sources[i]) + "\n")
                target_file.write(make_right_to_left(targets[i]) + "\n")
        else:
            for i in sorted(sources.keys()):
                for j in range(len(hypos[i])):
                    if prefix_len is not None:
                        shortened = (
                            get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)
                            + "\n"
                        )
                        hypo_file.write(shortened)
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                    elif target_prefix_frac is not None:
                        num_words, shortened, num_bpe_tokens = calc_length_from_frac(
                            hypos[i][j], target_prefix_frac, bpe_symbol
                        )
                        shortened += "\n"
                        hypo_file.write(shortened)
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                    elif source_prefix_frac is not None:
                        num_words, shortened, num_bpe_tokensn = calc_length_from_frac(
                            sources[i], source_prefix_frac, bpe_symbol
                        )
                        shortened += "\n"
                        hypo_file.write(hypos[i][j])
                        source_file.write(shortened)
                        target_file.write(targets[i])
                    else:
                        hypo_file.write(hypos[i][j])
                        source_file.write(sources[i])
                        target_file.write(targets[i])
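The three-way with open(...) reflow above is one of black's least readable outputs. On Python 3.10 and later the context managers can be parenthesized instead, a form that did not yet exist when this commit landed; a hedged sketch of that alternative, with placeholder paths:

with (
    open("src.txt", "w") as source_file,
    open("hyp.txt", "w") as hypo_file,
    open("tgt.txt", "w") as target_file,
):
    source_file.write("a source sentence\n")
    hypo_file.write("a hypothesis\n")
    target_file.write("a target sentence\n")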
def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol):

@@ -207,7 +226,9 @@ def get_prefix_from_len(sentence, bpe_symbol, prefix_len):
    if bpe_count == 0:
        return sentence[:prefix_len]
    else:
        return sentence[:prefix_len] + get_prefix_from_len(
            sentence[prefix_len:], bpe_symbol, bpe_count
        )


def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len):

@@ -225,9 +246,9 @@ def make_right_to_left(line):
def remove_bpe(line, bpe_symbol):
    line = line.replace("\n", "")
    line = (line + " ").replace(bpe_symbol, "").rstrip()
    return line + ("\n")
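remove_bpe above undoes subword-nmt style BPE: appending one space and deleting every continuation marker (for example "@@ ") merges subwords back into whole words. A quick check of the behavior:

def remove_bpe(line, bpe_symbol="@@ "):
    line = line.replace("\n", "")
    line = (line + " ").replace(bpe_symbol, "").rstrip()
    return line + "\n"


print(remove_bpe("new@@ yor@@ k is big\n"))  # prints "newyork is big"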
def remove_bpe_dict(pred_dict, bpe_symbol):

@@ -242,7 +263,7 @@ def remove_bpe_dict(pred_dict, bpe_symbol):
def parse_bleu_scoring(line):
    p = re.compile(r"(BLEU4 = )\d+[.]\d+")
    res = re.search(p, line)
    assert res is not None, line
    return float(res.group()[8:])

@@ -259,9 +280,21 @@ def get_full_from_prefix(hypo_prefix, hypos):
    raise Exception()


def get_score(
    a,
    b,
    c,
    target_len,
    bitext_score1,
    bitext_score2=None,
    lm_score=None,
    lenpen=None,
    src_len=None,
    tgt_len=None,
    bitext1_backwards=False,
    bitext2_backwards=False,
    normalize=False,
):
    if bitext1_backwards:
        bitext1_norm = src_len
    else:

@@ -275,9 +308,13 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N
        bitext2_norm = 1
        bitext_score2 = 0
    if normalize:
        score = (
            a * bitext_score1 / bitext1_norm
            + b * bitext_score2 / bitext2_norm
            + c * lm_score / src_len
        )
    else:
        score = a * bitext_score1 + b * bitext_score2 + c * lm_score

    if lenpen is not None:
        score /= (target_len) ** float(lenpen)
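get_score above is the noisy-channel combination: the weights a, b, c scale the forward model score, the optional second model score, and the language model score, with an optional length-penalty divisor at the end. A stripped-down numeric sketch (all scores below are made-up log probabilities):

def combined_score(a, b, c, target_len, s1, s2=0.0, lm=0.0, lenpen=None):
    # weighted sum of the three component scores
    score = a * s1 + b * s2 + c * lm
    if lenpen is not None:
        score /= target_len ** float(lenpen)
    return score


print(combined_score(1.0, 0.5, 0.3, target_len=12, s1=-4.2, s2=-5.1, lm=-30.7))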
@@ -286,8 +323,16 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N
class BitextOutput(object):
    def __init__(
        self,
        output_file,
        backwards,
        right_to_left,
        bpe_symbol,
        prefix_len=None,
        target_prefix_frac=None,
        source_prefix_frac=None,
    ):
        """process output from rescoring"""
        source, hypo, score, target, pos_score = reprocess(output_file)
        if backwards:

@@ -296,7 +341,9 @@ class BitextOutput(object):
            self.hypo_fracs = target_prefix_frac

        # remove length penalty so we can use raw scores
        score, num_bpe_tokens = get_score_from_pos(
            pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards
        )
        source_lengths = {}
        target_lengths = {}

@@ -341,7 +388,9 @@ class BitextOutput(object):
                score[i] = float(score[i][0])
                pos_score[i] = pos_score[i][0]
            else:
                assert (
                    len(hypo[i]) == 1
                ), "expected only one hypothesis per source sentence"
                source[i] = remove_bpe(source[i], bpe_symbol)
                target[i] = remove_bpe(target[i], bpe_symbol)
                hypo[i] = remove_bpe(hypo[i][0], bpe_symbol)

@@ -360,11 +409,26 @@ class BitextOutput(object):
class BitextOutputFromGen(object):
    def __init__(
        self,
        predictions_bpe_file,
        bpe_symbol=None,
        nbest=False,
        prefix_len=None,
        target_prefix_frac=None,
    ):
        if nbest:
            (
                pred_source,
                pred_hypo,
                pred_score,
                pred_target,
                pred_pos_score,
            ) = reprocess_nbest(predictions_bpe_file)
        else:
            pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(
                predictions_bpe_file
            )

        assert len(pred_source) == len(pred_hypo)
        assert len(pred_source) == len(pred_score)

@@ -372,8 +436,9 @@ class BitextOutputFromGen(object):
        assert len(pred_source) == len(pred_pos_score)

        # remove length penalty so we can use raw scores
        pred_score, num_bpe_tokens = get_score_from_pos(
            pred_pos_score, prefix_len, pred_hypo, bpe_symbol, target_prefix_frac, False
        )

        self.source = pred_source
        self.target = pred_target

@@ -414,7 +479,9 @@ class BitextOutputFromGen(object):
            index += 1


def get_score_from_pos(
    pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards
):
    score_dict = {}
    num_bpe_tokens_dict = {}
    assert prefix_len is None or hypo_frac is None

@@ -423,11 +490,15 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f
        num_bpe_tokens_dict[key] = []
        for i in range(len(pos_score_dict[key])):
            if prefix_len is not None and not backwards:
                num_bpe_tokens = get_num_bpe_tokens_from_len(
                    hypo_dict[key][i], bpe_symbol, prefix_len
                )
                score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens]))
                num_bpe_tokens_dict[key].append(num_bpe_tokens)
            elif hypo_frac is not None:
                num_words, shortened, hypo_prefix_len = calc_length_from_frac(
                    hypo_dict[key][i], hypo_frac, bpe_symbol
                )
                score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len]))
                num_bpe_tokens_dict[key].append(hypo_prefix_len)
            else:
@@ -437,10 +508,26 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f
class LMOutput(object):
    def __init__(
        self,
        lm_score_file,
        lm_dict=None,
        prefix_len=None,
        bpe_symbol=None,
        target_prefix_frac=None,
    ):
        (
            lm_sentences,
            lm_sen_scores,
            lm_sen_pos_scores,
            lm_no_bpe_sentences,
            lm_bpe_tokens,
        ) = parse_lm(
            lm_score_file,
            prefix_len=prefix_len,
            bpe_symbol=bpe_symbol,
            target_prefix_frac=target_prefix_frac,
        )

        self.sentences = lm_sentences
        self.score = lm_sen_scores

@@ -452,7 +539,7 @@ class LMOutput(object):
def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
    """parse output of eval_lm"""
    with open(input_file, "r") as f:
        text = f.readlines()
        text = text[7:]
        cleaned_text = text[:-2]

@@ -467,20 +554,23 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No
        if tokens[0].isdigit():
            line_id = int(tokens[0])
            scores = [float(x[1:-1]) for x in tokens[2::2]]
            sentences[line_id] = " ".join(tokens[1::2][:-1]) + "\n"
            if bpe_symbol is not None:
                # exclude <eos> symbol to match output from generate.py
                bpe_sen = " ".join(tokens[1::2][:-1]) + "\n"
                no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol)
                no_bpe_sentences[line_id] = no_bpe_sen

            if prefix_len is not None:
                num_bpe_tokens = get_num_bpe_tokens_from_len(
                    bpe_sen, bpe_symbol, prefix_len
                )
                sen_scores[line_id] = sum(scores[:num_bpe_tokens])
                num_bpe_tokens_dict[line_id] = num_bpe_tokens
            elif target_prefix_frac is not None:
                num_words, shortened, target_prefix_len = calc_length_from_frac(
                    bpe_sen, target_prefix_frac, bpe_symbol
                )
                sen_scores[line_id] = sum(scores[:target_prefix_len])
                num_bpe_tokens_dict[line_id] = target_prefix_len
            else:
@ -492,160 +582,269 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No
return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict
def get_directories(data_dir_name, num_rescore, gen_subset, def get_directories(
fw_name, shard_id, num_shards, data_dir_name,
sampling=False, prefix_len=None, num_rescore,
target_prefix_frac=None, source_prefix_frac=None): gen_subset,
nbest_file_id = "nbest_" + str(num_rescore) + \ fw_name,
"_subset_" + gen_subset + \ shard_id,
"_fw_name_" + fw_name + \ num_shards,
"_shard_" + str(shard_id) + \ sampling=False,
"_of_" + str(num_shards) prefix_len=None,
target_prefix_frac=None,
source_prefix_frac=None,
):
nbest_file_id = (
"nbest_"
+ str(num_rescore)
+ "_subset_"
+ gen_subset
+ "_fw_name_"
+ fw_name
+ "_shard_"
+ str(shard_id)
+ "_of_"
+ str(num_shards)
)
if sampling: if sampling:
nbest_file_id += "_sampling" nbest_file_id += "_sampling"
# the directory containing all information for this nbest list # the directory containing all information for this nbest list
pre_gen = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+data_dir_name+"/"+nbest_file_id pre_gen = (
os.path.join(os.path.dirname(__file__))
+ "/rerank_data/"
+ data_dir_name
+ "/"
+ nbest_file_id
)
# the directory to store the preprocessed nbest list, for left to right rescoring # the directory to store the preprocessed nbest list, for left to right rescoring
left_to_right_preprocessed_dir = pre_gen+"/left_to_right_preprocessed" left_to_right_preprocessed_dir = pre_gen + "/left_to_right_preprocessed"
if source_prefix_frac is not None: if source_prefix_frac is not None:
left_to_right_preprocessed_dir = left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac) left_to_right_preprocessed_dir = (
left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
)
# the directory to store the preprocessed nbest list, for right to left rescoring # the directory to store the preprocessed nbest list, for right to left rescoring
right_to_left_preprocessed_dir = pre_gen+"/right_to_left_preprocessed" right_to_left_preprocessed_dir = pre_gen + "/right_to_left_preprocessed"
# the directory to store the preprocessed nbest list, for backwards rescoring # the directory to store the preprocessed nbest list, for backwards rescoring
backwards_preprocessed_dir = pre_gen+"/backwards" backwards_preprocessed_dir = pre_gen + "/backwards"
if target_prefix_frac is not None: if target_prefix_frac is not None:
backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_frac"+str(target_prefix_frac) backwards_preprocessed_dir = (
backwards_preprocessed_dir + "/prefix_frac" + str(target_prefix_frac)
)
elif prefix_len is not None: elif prefix_len is not None:
backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_"+str(prefix_len) backwards_preprocessed_dir = (
backwards_preprocessed_dir + "/prefix_" + str(prefix_len)
)
# the directory to store the preprocessed nbest list, for rescoring with P(T) # the directory to store the preprocessed nbest list, for rescoring with P(T)
lm_preprocessed_dir = pre_gen+"/lm_preprocessed" lm_preprocessed_dir = pre_gen + "/lm_preprocessed"
return pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ return (
backwards_preprocessed_dir, lm_preprocessed_dir pre_gen,
left_to_right_preprocessed_dir,
right_to_left_preprocessed_dir,
backwards_preprocessed_dir,
lm_preprocessed_dir,
)
def lm_scoring(preprocess_directory, bpe_status, gen_output, pre_gen, def lm_scoring(
cur_lm_dict, cur_lm_name, cur_language_model, cur_lm_bpe_code, preprocess_directory,
batch_size, lm_score_file, target_lang, source_lang, prefix_len=None): bpe_status,
gen_output,
pre_gen,
cur_lm_dict,
cur_lm_name,
cur_language_model,
cur_lm_bpe_code,
batch_size,
lm_score_file,
target_lang,
source_lang,
prefix_len=None,
):
if prefix_len is not None: if prefix_len is not None:
assert bpe_status == "different", "bpe status must be different to use prefix len" assert (
bpe_status == "different"
), "bpe status must be different to use prefix len"
if bpe_status == "no bpe": if bpe_status == "no bpe":
# run lm on output without bpe # run lm on output without bpe
write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo, write_reprocessed(
gen_output.no_bpe_target, pre_gen+"/rescore_data_no_bpe.de", gen_output.no_bpe_source,
pre_gen+"/rescore_data_no_bpe.en", pre_gen+"/reference_file_no_bpe") gen_output.no_bpe_hypo,
gen_output.no_bpe_target,
pre_gen + "/rescore_data_no_bpe.de",
pre_gen + "/rescore_data_no_bpe.en",
pre_gen + "/reference_file_no_bpe",
)
preprocess_lm_param = ["--only-source", preprocess_lm_param = [
"--trainpref", pre_gen+"/rescore_data_no_bpe."+target_lang, "--only-source",
"--srcdict", cur_lm_dict, "--trainpref",
"--destdir", preprocess_directory] pre_gen + "/rescore_data_no_bpe." + target_lang,
"--srcdict",
cur_lm_dict,
"--destdir",
preprocess_directory,
]
preprocess_parser = options.get_preprocessing_parser() preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param) input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args) preprocess.main(input_args)
eval_lm_param = [preprocess_directory, eval_lm_param = [
"--path", cur_language_model, preprocess_directory,
"--output-word-probs", "--path",
"--batch-size", str(batch_size), cur_language_model,
"--max-tokens", "1024", "--output-word-probs",
"--sample-break-mode", "eos", "--batch-size",
"--gen-subset", "train"] str(batch_size),
"--max-tokens",
"1024",
"--sample-break-mode",
"eos",
"--gen-subset",
"train",
]
eval_lm_parser = options.get_eval_lm_parser() eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param) input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, 'w') as f: with open(lm_score_file, "w") as f:
with redirect_stdout(f): with redirect_stdout(f):
eval_lm.main(input_args) eval_lm.main(input_args)
elif bpe_status == "shared": elif bpe_status == "shared":
preprocess_lm_param = ["--only-source", preprocess_lm_param = [
"--trainpref", pre_gen+"/rescore_data."+target_lang, "--only-source",
"--srcdict", cur_lm_dict, "--trainpref",
"--destdir", preprocess_directory] pre_gen + "/rescore_data." + target_lang,
preprocess_parser = options.get_preprocessing_parser() "--srcdict",
input_args = preprocess_parser.parse_args(preprocess_lm_param) cur_lm_dict,
preprocess.main(input_args) "--destdir",
preprocess_directory,
]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
eval_lm_param = [preprocess_directory, eval_lm_param = [
"--path", cur_language_model, preprocess_directory,
"--output-word-probs", "--path",
"--batch-size", str(batch_size), cur_language_model,
"--sample-break-mode", "eos", "--output-word-probs",
"--gen-subset", "train"] "--batch-size",
str(batch_size),
"--sample-break-mode",
"eos",
"--gen-subset",
"train",
]
eval_lm_parser = options.get_eval_lm_parser() eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param) input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, 'w') as f: with open(lm_score_file, "w") as f:
with redirect_stdout(f): with redirect_stdout(f):
eval_lm.main(input_args) eval_lm.main(input_args)
elif bpe_status == "different": elif bpe_status == "different":
rescore_file = pre_gen+"/rescore_data_no_bpe" rescore_file = pre_gen + "/rescore_data_no_bpe"
rescore_bpe = pre_gen+"/rescore_data_new_bpe" rescore_bpe = pre_gen + "/rescore_data_new_bpe"
rescore_file += "." rescore_file += "."
rescore_bpe += "." rescore_bpe += "."
write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo, write_reprocessed(
gen_output.no_bpe_target, rescore_file+source_lang, gen_output.no_bpe_source,
rescore_file+target_lang, pre_gen+"/reference_file_no_bpe", gen_output.no_bpe_hypo,
bpe_symbol=None) gen_output.no_bpe_target,
rescore_file + source_lang,
rescore_file + target_lang,
pre_gen + "/reference_file_no_bpe",
bpe_symbol=None,
)
# apply LM bpe to nbest list # apply LM bpe to nbest list
bpe_src_param = ["-c", cur_lm_bpe_code, bpe_src_param = [
"--input", rescore_file+target_lang, "-c",
"--output", rescore_bpe+target_lang] cur_lm_bpe_code,
subprocess.call(["python", "--input",
os.path.join(os.path.dirname(__file__), rescore_file + target_lang,
"subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param, "--output",
shell=False) rescore_bpe + target_lang,
]
subprocess.call(
[
"python",
os.path.join(
os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
),
]
+ bpe_src_param,
shell=False,
)
# uncomment to use fastbpe instead of subword-nmt bpe # uncomment to use fastbpe instead of subword-nmt bpe
# bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code] # bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
# subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False) # subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)
preprocess_dir = preprocess_directory preprocess_dir = preprocess_directory
preprocess_lm_param = ["--only-source", preprocess_lm_param = [
"--trainpref", rescore_bpe+target_lang, "--only-source",
"--srcdict", cur_lm_dict, "--trainpref",
"--destdir", preprocess_dir] rescore_bpe + target_lang,
"--srcdict",
cur_lm_dict,
"--destdir",
preprocess_dir,
]
preprocess_parser = options.get_preprocessing_parser() preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param) input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args) preprocess.main(input_args)
eval_lm_param = [preprocess_dir, eval_lm_param = [
"--path", cur_language_model, preprocess_dir,
"--output-word-probs", "--path",
"--batch-size", str(batch_size), cur_language_model,
"--max-tokens", "1024", "--output-word-probs",
"--sample-break-mode", "eos", "--batch-size",
"--gen-subset", "train"] str(batch_size),
"--max-tokens",
"1024",
"--sample-break-mode",
"eos",
"--gen-subset",
"train",
]
eval_lm_parser = options.get_eval_lm_parser() eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param) input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, 'w') as f: with open(lm_score_file, "w") as f:
with redirect_stdout(f): with redirect_stdout(f):
eval_lm.main(input_args) eval_lm.main(input_args)
def rescore_file_name(nbest_dir, prefix_len, scorer_name, lm_file=False, def rescore_file_name(
target_prefix_frac=None, source_prefix_frac=None, backwards=None): nbest_dir,
prefix_len,
scorer_name,
lm_file=False,
target_prefix_frac=None,
source_prefix_frac=None,
backwards=None,
):
if lm_file: if lm_file:
score_file = nbest_dir+"/lm_score_translations_model_"+scorer_name+".txt" score_file = nbest_dir + "/lm_score_translations_model_" + scorer_name + ".txt"
else: else:
score_file = nbest_dir+"/"+scorer_name+"_score_translations.txt" score_file = nbest_dir + "/" + scorer_name + "_score_translations.txt"
if backwards: if backwards:
if prefix_len is not None: if prefix_len is not None:
score_file += "prefix_len"+str(prefix_len) score_file += "prefix_len" + str(prefix_len)
elif target_prefix_frac is not None: elif target_prefix_frac is not None:
score_file += "target_prefix_frac"+str(target_prefix_frac) score_file += "target_prefix_frac" + str(target_prefix_frac)
else: else:
if source_prefix_frac is not None: if source_prefix_frac is not None:
score_file += "source_prefix_frac"+str(source_prefix_frac) score_file += "source_prefix_frac" + str(source_prefix_frac)
return score_file return score_file
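
For orientation, `rescore_file_name` only composes a path string from its flags; a quick sketch of what it returns (directory and scorer names invented):

    # e.g. for a language-model score file with no prefix options:
    rescore_file_name("/tmp/nbest", prefix_len=None, scorer_name="lm1", lm_file=True)
    # -> "/tmp/nbest/lm_score_translations_model_lm1.txt"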

View File

@ -13,57 +13,66 @@ logging.getLogger().setLevel(logging.INFO)
def main():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--en2fr", required=True, help="path to en2fr model")
    parser.add_argument(
        "--fr2en", required=True, help="path to fr2en mixture of experts model"
    )
    parser.add_argument(
        "--user-dir", help="path to fairseq examples/translation_moe/src directory"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        default=10,
        help="(keep at 10 unless using a different model)",
    )
    parser.add_argument(
        "files",
        nargs="*",
        default=["-"],
        help='input files to paraphrase; "-" for stdin',
    )
    args = parser.parse_args()

    if args.user_dir is None:
        args.user_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),  # examples/
            "translation_moe",
            "src",
        )
    if os.path.exists(args.user_dir):
        logging.info("found user_dir:" + args.user_dir)
    else:
        raise RuntimeError(
            "cannot find fairseq examples/translation_moe/src "
            "(tried looking here: {})".format(args.user_dir)
        )

    logging.info("loading en2fr model from:" + args.en2fr)
    en2fr = TransformerModel.from_pretrained(
        model_name_or_path=args.en2fr,
        tokenizer="moses",
        bpe="sentencepiece",
    ).eval()

    logging.info("loading fr2en model from:" + args.fr2en)
    fr2en = TransformerModel.from_pretrained(
        model_name_or_path=args.fr2en,
        tokenizer="moses",
        bpe="sentencepiece",
        user_dir=args.user_dir,
        task="translation_moe",
    ).eval()

    def gen_paraphrases(en):
        fr = en2fr.translate(en)
        return [
            fr2en.translate(fr, inference_step_args={"expert": i})
            for i in range(args.num_experts)
        ]

    logging.info("Type the input sentence and press return:")
    for line in fileinput.input(args.files):
        line = line.strip()
        if len(line) == 0:
@ -72,5 +81,5 @@ def main():
            print(paraphrase)


if __name__ == "__main__":
    main()
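
Once the two hub models above are loaded, the inner helper is the whole API surface; a minimal sketch of the intended call pattern (input sentence invented):

    # each mixture-of-experts expert decodes the same French pivot differently,
    # yielding args.num_experts distinct English paraphrases
    paraphrases = gen_paraphrases("The weather is nice today.")
    for p in paraphrases:
        print(p)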

View File

@ -4,9 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import re
import sys


class OOVIndexError(IndexError):
@ -25,8 +25,8 @@ class OOVIndexError(IndexError):
def replace_oovs(source_in, target_in, target_out):
    """Replaces <unk-N> tokens in the target text with the corresponding word in
    the source text.
    """
    oov_re = re.compile("^<unk-([0-9]+)>$")

View File

@ -10,8 +10,8 @@ from itertools import zip_longest
def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
    """Replaces out-of-vocabulary words in source and target text with <unk-N>,
    where N is the position of the word in the source sequence.
    """

    def format_unk(pos):
        return "<unk-{}>".format(pos)

View File

@ -8,19 +8,17 @@ from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from fairseq import metrics, utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
    DEFAULT_MAX_SOURCE_POSITIONS,
    DEFAULT_MAX_TARGET_POSITIONS,
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
)
from torch import Tensor
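
Worth noting for readers unfamiliar with isort: names inside a parenthesized import are ordered case-sensitively (ASCII order), which is why the ALL_CAPS constants now precede the CamelCase classes and `base_architecture` comes last. The same ordering is reproducible with plain `sorted`:

    names = ["TransformerModel", "base_architecture", "DEFAULT_MAX_SOURCE_POSITIONS"]
    print(sorted(names))
    # ['DEFAULT_MAX_SOURCE_POSITIONS', 'TransformerModel', 'base_architecture']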

View File

@ -8,40 +8,44 @@ import os
import numpy as np
import torch
from fairseq.data import (
    Dictionary,
    IdDataset,
    ListDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    RawLabelDataset,
    RightPadDataset,
    SortDataset,
    data_utils,
    encoders,
)
from fairseq.tasks import LegacyFairseqTask, register_task


@register_task("commonsense_qa")
class CommonsenseQATask(LegacyFairseqTask):
    """Task to finetune RoBERTa for Commonsense QA."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            default=None,
            help="add token at the beginning of each batch item",
        )
        parser.add_argument("--num-classes", type=int, default=5)

    def __init__(self, args, vocab):
        super().__init__(args)
        self.vocab = vocab
        self.mask = vocab.add_symbol("<mask>")

        self.bpe = encoders.build_bpe(args)
@ -53,20 +57,24 @@ class CommonsenseQATask(LegacyFairseqTask):
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert (
            args.criterion == "sentence_ranking"
        ), "Must set --criterion=sentence_ranking"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
@ -77,16 +85,18 @@ class CommonsenseQATask(LegacyFairseqTask):
            if self.bpe is not None:
                s = self.bpe.encode(s)
            tokens = self.vocab.encode_line(
                s,
                append_eos=True,
                add_if_not_exist=False,
            ).long()
            if append_bos and self.args.init_token is not None:
                tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
            return tokens

        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        src_tokens = [[] for i in range(self.args.num_classes)]
        src_lengths = [[] for i in range(self.args.num_classes)]
@ -95,20 +105,23 @@ class CommonsenseQATask(LegacyFairseqTask):
        with open(data_path) as h:
            for line in h:
                example = json.loads(line.strip())
                if "answerKey" in example:
                    label = ord(example["answerKey"]) - ord("A")
                    labels.append(label)
                question = example["question"]["stem"]
                assert len(example["question"]["choices"]) == self.args.num_classes
                # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
                question = "Q: " + question
                question_toks = binarize(question, append_bos=True)
                for i, choice in enumerate(example["question"]["choices"]):
                    src = "A: " + choice["text"]
                    src_bin = torch.cat([question_toks, binarize(src)])
                    src_tokens[i].append(src_bin)
                    src_lengths[i].append(len(src_bin))
        assert all(
            len(src_tokens[0]) == len(src_tokens[i])
            for i in range(self.args.num_classes)
        )
        assert len(src_tokens[0]) == len(src_lengths[0])
        assert len(labels) == 0 or len(labels) == len(src_tokens[0])
@ -118,24 +131,26 @@ class CommonsenseQATask(LegacyFairseqTask):
            src_lengths[i] = ListDataset(src_lengths[i])

        dataset = {
            "id": IdDataset(),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens[0], reduce=True),
        }

        for i in range(self.args.num_classes):
            dataset.update(
                {
                    "net_input{}".format(i + 1): {
                        "src_tokens": RightPadDataset(
                            src_tokens[i],
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        "src_lengths": src_lengths[i],
                    }
                }
            )

        if len(labels) > 0:
            dataset.update({"target": RawLabelDataset(labels)})

        dataset = NestedDictionaryDataset(
            dataset,
@ -149,17 +164,18 @@ class CommonsenseQATask(LegacyFairseqTask):
            sort_order=[np.random.permutation(len(dataset))],
        )

        print("| Loaded {} with {} samples".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_model(self, args):
        from fairseq import models

        model = models.build_model(args, self)
        model.register_classification_head(
            "sentence_classification_head",
            num_classes=1,
        )
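
For reference, the fields the loader above reads from each `<split>.jsonl` line look roughly like this (values invented; `choices` must have `--num-classes` entries and `answerKey` may be absent for unlabeled test data):

    example = {
        "answerKey": "B",
        "question": {
            "stem": "Where would I not want a fox?",
            "choices": [
                {"text": "hen house"},
                {"text": "england"},
                {"text": "mountains"},
                {"text": "english hunt"},
                {"text": "california"},
            ],
        },
    }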

View File

@ -8,7 +8,6 @@
import argparse
import contextlib
import sys
from collections import Counter
from multiprocessing import Pool
@ -26,23 +25,23 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-json",
        help="path to encoder.json",
    )
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help="path to vocab.bpe",
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=["-"],
        help="input files to filter/encode",
    )
    parser.add_argument(
        "--outputs",
        nargs="+",
        default=["-"],
        help="path to save encoded outputs",
    )
    parser.add_argument(
parser.add_argument( parser.add_argument(
@ -53,18 +52,21 @@ def main():
parser.add_argument("--workers", type=int, default=20) parser.add_argument("--workers", type=int, default=20)
args = parser.parse_args() args = parser.parse_args()
assert len(args.inputs) == len(args.outputs), \ assert len(args.inputs) == len(
"number of input and output paths should match" args.outputs
), "number of input and output paths should match"
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
inputs = [ inputs = [
stack.enter_context(open(input, "r", encoding="utf-8")) stack.enter_context(open(input, "r", encoding="utf-8"))
if input != "-" else sys.stdin if input != "-"
else sys.stdin
for input in args.inputs for input in args.inputs
] ]
outputs = [ outputs = [
stack.enter_context(open(output, "w", encoding="utf-8")) stack.enter_context(open(output, "w", encoding="utf-8"))
if output != "-" else sys.stdout if output != "-"
else sys.stdout
for output in args.outputs for output in args.outputs
] ]
@ -87,7 +89,6 @@ def main():
class MultiprocessingEncoder(object):
    def __init__(self, args):
        self.args = args
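
The reformatted list comprehensions above use an inline conditional to substitute the standard streams for "-"; the same idea as a standalone helper (a sketch, not part of the script):

    import sys

    def open_or_std(path, mode="r"):
        # "-" means stdin/stdout, matching the script's convention
        if path == "-":
            return sys.stdin if "r" in mode else sys.stdout
        return open(path, mode, encoding="utf-8")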

View File

@ -25,7 +25,7 @@ def get_examples(data_dir, set_type):
    examples = []
    levels = ["middle", "high"]
    set_type_c = set_type.split("-")
    if len(set_type_c) == 2:
        levels = [set_type_c[1]]
        set_type = set_type_c[0]
@ -33,13 +33,13 @@ def get_examples(data_dir, set_type):
        cur_dir = os.path.join(data_dir, set_type, level)
        for filename in os.listdir(cur_dir):
            cur_path = os.path.join(cur_dir, filename)
            with open(cur_path, "r") as f:
                cur_data = json.load(f)
                answers = cur_data["answers"]
                options = cur_data["options"]
                questions = cur_data["questions"]
                context = cur_data["article"].replace("\n", " ")
                context = re.sub(r"\s+", " ", context)
                for i in range(len(answers)):
                    label = ord(answers[i]) - ord("A")
                    qa_list = []
@ -50,7 +50,7 @@ def get_examples(data_dir, set_type):
qa_cat = question.replace("_", option) qa_cat = question.replace("_", option)
else: else:
qa_cat = " ".join([question, option]) qa_cat = " ".join([question, option])
qa_cat = re.sub(r'\s+', ' ', qa_cat) qa_cat = re.sub(r"\s+", " ", qa_cat)
qa_list.append(qa_cat) qa_list.append(qa_cat)
examples.append(InputExample(context, qa_list, label)) examples.append(InputExample(context, qa_list, label))
@ -64,11 +64,11 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        help="input directory for downloaded RACE dataset",
    )
    parser.add_argument(
        "--output-dir",
        help="output directory for extracted data",
    )
    args = parser.parse_args()
@ -77,17 +77,20 @@ def main():
for set_type in ["train", "dev", "test-middle", "test-high"]: for set_type in ["train", "dev", "test-middle", "test-high"]:
examples = get_examples(args.input_dir, set_type) examples = get_examples(args.input_dir, set_type)
qa_file_paths = [os.path.join(args.output_dir, set_type + ".input" + str(i + 1)) for i in range(4)] qa_file_paths = [
qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths] os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
for i in range(4)
]
qa_files = [open(qa_file_path, "w") for qa_file_path in qa_file_paths]
outf_context_path = os.path.join(args.output_dir, set_type + ".input0") outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
outf_label_path = os.path.join(args.output_dir, set_type + ".label") outf_label_path = os.path.join(args.output_dir, set_type + ".label")
outf_context = open(outf_context_path, 'w') outf_context = open(outf_context_path, "w")
outf_label = open(outf_label_path, 'w') outf_label = open(outf_label_path, "w")
for example in examples: for example in examples:
outf_context.write(example.paragraph + '\n') outf_context.write(example.paragraph + "\n")
for i in range(4): for i in range(4):
qa_files[i].write(example.qa_list[i] + '\n') qa_files[i].write(example.qa_list[i] + "\n")
outf_label.write(str(example.label) + '\n') outf_label.write(str(example.label) + "\n")
for f in qa_files: for f in qa_files:
f.close() f.close()
@ -95,5 +98,5 @@ def main():
        outf_context.close()


if __name__ == "__main__":
    main()
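
The writer loop above emits one context stream, four question+option streams, and a label stream per split; a sketch of the resulting layout (helper name invented):

    import os

    def expected_outputs(output_dir, set_type):
        return (
            [os.path.join(output_dir, set_type + ".input0")]  # contexts
            + [
                os.path.join(output_dir, set_type + ".input" + str(i + 1))
                for i in range(4)  # one stream per answer option
            ]
            + [os.path.join(output_dir, set_type + ".label")]  # gold indices
        )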

View File

@ -7,19 +7,17 @@ import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data import encoders


@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
    def __init__(self, args, task):
        super().__init__(args, task)
        if self.args.save_predictions is not None:
            self.prediction_h = open(self.args.save_predictions, "w")
        else:
            self.prediction_h = None
        self.bpe = encoders.build_bpe(args)
@ -32,12 +30,16 @@ class WSCCriterion(LegacyFairseqCriterion):
    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
        parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
        parser.add_argument(
            "--wsc-cross-entropy",
            action="store_true",
            help="use cross entropy formulation instead of margin loss",
        )
        parser.add_argument(
            "--save-predictions", metavar="FILE", help="file to save predictions to"
        )

    def get_masked_input(self, tokens, mask):
        masked_tokens = tokens.clone()
@ -60,27 +62,26 @@ class WSCCriterion(LegacyFairseqCriterion):
            )
        else:
            return (
                -query_lprobs
                + self.args.wsc_margin_alpha
                * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
            ).sum()

    def forward(self, model, sample, reduce=True):
        # compute loss and accuracy
        loss, nloss = 0.0, 0
        ncorrect, nqueries = 0, 0

        for i, label in enumerate(sample["labels"]):
            query_lprobs = self.get_lprobs(
                model,
                sample["query_tokens"][i].unsqueeze(0),
                sample["query_masks"][i].unsqueeze(0),
            )
            cand_lprobs = self.get_lprobs(
                model,
                sample["candidate_tokens"][i],
                sample["candidate_masks"][i],
            )

            pred = (query_lprobs >= cand_lprobs).all().item()
@ -95,72 +96,72 @@ class WSCCriterion(LegacyFairseqCriterion):
                nloss += 1
                loss += self.get_loss(query_lprobs, cand_lprobs)

            id = sample["id"][i].item()
            if self.prediction_h is not None:
                print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)

        if nloss == 0:
            loss = torch.tensor(0.0, requires_grad=True)

        sample_size = nqueries if nqueries > 0 else 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": nqueries,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        agg_output = {
            "loss": loss_sum / sample_size / math.log(2),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
        nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
        if nqueries > 0:
            agg_output["accuracy"] = ncorrect / float(nqueries)

        return agg_output


@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
    def forward(self, model, sample, reduce=True):
        # compute loss and accuracy
        query_lprobs = self.get_lprobs(
            model,
            sample["query_tokens"],
            sample["query_masks"],
        )
        cand_lprobs = self.get_lprobs(
            model,
            sample["candidate_tokens"],
            sample["candidate_masks"],
        )
        pred = query_lprobs >= cand_lprobs
        loss = self.get_loss(query_lprobs, cand_lprobs)

        sample_size = sample["query_tokens"].size(0)
        ncorrect = pred.sum().item()
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": sample_size,
        }
        return loss, sample_size, logging_output
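
The newly wrapped margin loss in `get_loss` reads more clearly on toy numbers; a minimal sketch with alpha=1.0 and beta=0.0 (the defaults declared above):

    import torch

    query_lprobs = torch.tensor([-1.2])
    cand_lprobs = torch.tensor([-1.5, -0.9])
    alpha, beta = 1.0, 0.0
    loss = (
        -query_lprobs
        + alpha * (cand_lprobs - query_lprobs + beta).clamp(min=0)
    ).sum()
    # only the candidate scoring above the query (-0.9 > -1.2) adds a
    # margin term; loss = (1.2 + 0.0) + (1.2 + 0.3) = 2.7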

View File

@ -10,47 +10,51 @@ import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import (
    Dictionary,
    IdDataset,
    ListDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    PadDataset,
    SortDataset,
    data_utils,
    encoders,
)
from fairseq.tasks import LegacyFairseqTask, register_task

from . import wsc_utils


@register_task("wsc")
class WSCTask(LegacyFairseqTask):
    """Task to finetune RoBERTa for Winograd Schemas."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            default=None,
            help="add token at the beginning of each batch item",
        )

    def __init__(self, args, vocab):
        super().__init__(args)
        self.vocab = vocab
        self.mask = vocab.add_symbol("<mask>")

        self.bpe = encoders.build_bpe(args)
        self.tokenizer = encoders.build_tokenizer(args)

        # hack to handle GPT-2 BPE, which includes leading spaces
        if args.bpe == "gpt2":
            self.leading_space = True
            self.trailing_space = False
        else:
@ -65,16 +69,16 @@ class WSCTask(LegacyFairseqTask):
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert args.criterion == "wsc", "Must set --criterion=wsc"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)
@ -84,7 +88,9 @@ class WSCTask(LegacyFairseqTask):
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=append_eos,
            add_if_not_exist=False,
        ).long()
        if self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
@ -98,19 +104,21 @@ class WSCTask(LegacyFairseqTask):
        mask = torch.zeros_like(toks, dtype=torch.bool)
        mask_start = len(self.binarize(prefix))
        mask_size = len(self.binarize(leading_space + txt))
        mask[mask_start : mask_start + mask_size] = 1
        return toks, mask

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        query_tokens = []
        query_masks = []
@ -121,13 +129,15 @@ class WSCTask(LegacyFairseqTask):
        labels = []

        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
            prefix = sentence[: pronoun_span.start].text
            suffix = sentence[pronoun_span.end :].text_with_ws

            # spaCy spans include trailing spaces, but we need to know about
            # leading spaces for the GPT-2 BPE
            leading_space = (
                " " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else ""
            )
            trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""

            # get noun phrases, excluding pronouns and anything overlapping with the query
            cand_spans = wsc_utils.filter_noun_chunks(
@ -152,7 +162,11 @@ class WSCTask(LegacyFairseqTask):
            cand_toks, cand_masks = [], []
            for cand_span in cand_spans:
                toks, mask = self.binarize_with_mask(
                    cand_span.text,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                cand_toks.append(toks)
                cand_masks.append(mask)
@ -176,17 +190,17 @@ class WSCTask(LegacyFairseqTask):
        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
        candidate_masks = ListDataset(candidate_masks, candidate_lengths)

        labels = ListDataset(labels, [1] * len(labels))

        dataset = {
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "labels": labels,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
@ -210,9 +224,9 @@ class WSCTask(LegacyFairseqTask):
    def build_dataset_for_inference(self, sample_json):
        with tempfile.NamedTemporaryFile(buffering=0) as h:
            h.write((json.dumps(sample_json) + "\n").encode("utf-8"))
            dataset = self.load_dataset(
                "disambiguate_pronoun",
                data_path=h.name,
                return_only=True,
            )
@ -239,19 +253,19 @@ class WSCTask(LegacyFairseqTask):
            return scores

        cand_lprobs = get_lprobs(
            sample["candidate_tokens"][0],
            sample["candidate_masks"][0],
        )
        if sample["query_tokens"][0] is not None:
            query_lprobs = get_lprobs(
                sample["query_tokens"][0].unsqueeze(0),
                sample["query_masks"][0].unsqueeze(0),
            )
            return (query_lprobs >= cand_lprobs).all().item() == 1
        else:
            best_idx = cand_lprobs.argmax().item()
            full_cand = sample["candidate_tokens"][0][best_idx]
            mask = sample["candidate_masks"][0][best_idx]
            toks = full_cand[mask.bool()]
            return self.bpe.decode(self.source_dictionary.string(toks)).strip()
@ -264,7 +278,7 @@ class WSCTask(LegacyFairseqTask):
        return self.vocab


@register_task("winogrande")
class WinograndeTask(WSCTask):
    """
    Task for WinoGrande dataset. Efficient implementation for Winograd schema
@ -273,24 +287,26 @@ class WinograndeTask(WSCTask):
    @classmethod
    def setup_task(cls, args, **kwargs):
        assert args.criterion == "winogrande", "Must set --criterion=winogrande"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        query_tokens = []
        query_masks = []
@ -299,19 +315,23 @@ class WinograndeTask(WSCTask):
        candidate_masks = []
        candidate_lengths = []

        itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test"))

        for sample in itr:
            sentence, pronoun_span, query, cand_text = sample
            prefix = sentence[: pronoun_span[0]].rstrip()
            suffix = sentence[pronoun_span[1] :]

            leading_space = " " if sentence[: pronoun_span[0]].endswith(" ") else ""
            trailing_space = ""

            if query is not None:
                query_toks, query_mask = self.binarize_with_mask(
                    query,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                query_len = len(query_toks)
            else:
@ -322,7 +342,11 @@ class WinograndeTask(WSCTask):
            query_lengths.append(query_len)

            cand_toks, cand_mask = self.binarize_with_mask(
                cand_text,
                prefix,
                suffix,
                leading_space,
                trailing_space,
            )

            candidate_tokens.append(cand_toks)
@ -342,17 +366,19 @@ class WinograndeTask(WSCTask):
        query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = get_pad_dataset_fn(
            candidate_tokens, candidate_lengths, self.vocab.pad()
        )
        candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)

        dataset = {
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(

View File

@ -3,48 +3,48 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
from functools import lru_cache


def convert_sentence_to_json(sentence):
    if "_" in sentence:
        prefix, rest = sentence.split("_", 1)
        query, rest = rest.split("_", 1)
        query_index = len(prefix.rstrip().split(" "))
    else:
        query, query_index = None, None

    prefix, rest = sentence.split("[", 1)
    pronoun, rest = rest.split("]", 1)
    pronoun_index = len(prefix.rstrip().split(" "))

    sentence = sentence.replace("_", "").replace("[", "").replace("]", "")

    return {
        "idx": 0,
        "text": sentence,
        "target": {
            "span1_index": query_index,
            "span1_text": query,
            "span2_index": pronoun_index,
            "span2_text": pronoun,
        },
    }


def extended_noun_chunks(sentence):
    noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
    np_start, cur_np = 0, "NONE"
    for i, token in enumerate(sentence):
        np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
        if np_type != cur_np:
            if cur_np != "NONE":
                noun_chunks.add((np_start, i))
            if np_type != "NONE":
                np_start = i
            cur_np = np_type
    if cur_np != "NONE":
        noun_chunks.add((np_start, len(sentence)))
    return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
@ -61,14 +61,14 @@ def find_token(sentence, start_pos):
def find_span(sentence, search_text, start=0):
    search_text = search_text.lower()
    for tok in sentence[start:]:
        remainder = sentence[tok.i :].text.lower()
        if remainder.startswith(search_text):
            len_to_consume = len(search_text)
            start_idx = tok.idx
            for next_tok in sentence[tok.i :]:
                end_idx = next_tok.idx + len(next_tok.text)
                if end_idx - start_idx == len_to_consume:
                    span = sentence[tok.i : next_tok.i + 1]
                    return span
    return None
@ -76,13 +76,15 @@ def find_span(sentence, search_text, start=0):
@lru_cache(maxsize=1)
def get_detokenizer():
    from sacremoses import MosesDetokenizer

    detok = MosesDetokenizer(lang="en")
    return detok


@lru_cache(maxsize=1)
def get_spacy_nlp():
    import en_core_web_lg

    nlp = en_core_web_lg.load()
    return nlp
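
`lru_cache(maxsize=1)` on a zero-argument function is being used here as a lazy singleton: the body runs once and every later call returns the same cached object. In isolation:

    from functools import lru_cache

    @lru_cache(maxsize=1)
    def get_resource():
        return object()  # built on first call only

    assert get_resource() is get_resource()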
@ -95,45 +97,45 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
        for line in fin:
            sample = json.loads(line.strip())

            if positive_only and "label" in sample and not sample["label"]:
                # only consider examples where the query is correct
                continue

            target = sample["target"]

            # clean up the query
            query = target["span1_text"]
            if query is not None:
                if "\n" in query:
                    continue
                if query.endswith(".") or query.endswith(","):
                    query = query[:-1]

            # split tokens
            tokens = sample["text"].split(" ")

            def strip_pronoun(x):
                return x.rstrip('.,"')

            # find the pronoun
            pronoun_idx = target["span2_index"]
            pronoun = strip_pronoun(target["span2_text"])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned
                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
                    pronoun_idx += 1
                else:
                    raise Exception("Misaligned pronoun!")
            assert strip_pronoun(tokens[pronoun_idx]) == pronoun

            # split tokens before and after the pronoun
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1 :]

            # the GPT BPE attaches leading spaces to tokens, so we keep track
            # of whether we need spaces before or after the pronoun
            leading_space = " " if pronoun_idx > 0 else ""
            trailing_space = " " if len(after) > 0 else ""

            # detokenize
            before = detok.detokenize(before, return_str=True)
@ -142,14 +144,14 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
            # hack: when the pronoun ends in a period (or comma), move the
            # punctuation to the "after" part
            if pronoun.endswith(".") or pronoun.endswith(","):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a comma or period, remove
            # the trailing space
            if after.startswith(".") or after.startswith(","):
                trailing_space = ""

            # parse sentence with spacy
            sentence = nlp(before + leading_space + pronoun + trailing_space + after)
@ -164,13 +166,13 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
                # convert to format where pronoun is surrounded by "[]" and
                # query is surrounded by "_"
                query_span = find_span(sentence, query)
                query_with_ws = "_{}_{}".format(
                    query_span.text,
                    (" " if query_span.text_with_ws.endswith(" ") else ""),
                )
                pronoun_with_ws = "[{}]{}".format(
                    pronoun_span.text,
                    (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
                )
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
@ -179,41 +181,45 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[: first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end : second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end :].text
                )
                yield sentence, sample.get("label", None)
            else:
                yield sentence, pronoun_span, query, sample.get("label", None)


def winogrande_jsonl_iterator(input_fname, eval=False):
    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            sentence, option1, option2 = (
                sample["sentence"],
                sample["option1"],
                sample["option2"],
            )

            pronoun_span = (sentence.index("_"), sentence.index("_") + 1)

            if eval:
                query, cand = option1, option2
            else:
                query = option1 if sample["answer"] == "1" else option2
                cand = option2 if sample["answer"] == "1" else option1

            yield sentence, pronoun_span, query, cand
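
For reference, one WinoGrande jsonl line as the iterator above expects it (field names taken from the code; the content is invented); the "_" marks the pronoun slot:

    sample = {
        "sentence": "The trophy doesn't fit in the suitcase because _ is too big.",
        "option1": "the trophy",
        "option2": "the suitcase",
        "answer": "1",  # ignored when eval=True
    }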
def filter_noun_chunks(
    chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
    if exclude_pronouns:
        chunks = [
            np
            for np in chunks
            if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
        ]

    if exclude_query is not None:
@ -224,9 +230,8 @@ def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact
            found = False
            for excl in excl_txt:
                if (
                    not exact_match and (lower_chunk in excl or excl in lower_chunk)
                ) or lower_chunk == excl:
                    found = True
                    break
            if not found:

View File

@ -3,4 +3,4 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import criterions, eval, models  # noqa

View File

@ -6,6 +6,7 @@
import importlib
import os


for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        criterion_name = file[: file.find(".py")]


@ -3,21 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from examples.simultaneous_translation.utils.latency import LatencyTraining
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
)


@register_criterion("latency_augmented_label_smoothed_cross_entropy")
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
    LabelSmoothedCrossEntropyCriterion
):
    def __init__(self, args, task):
        super().__init__(args, task)
        self.eps = args.label_smoothing
@ -40,7 +36,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
    def add_args(parser):
        super(
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
        ).add_args(parser)
        """Add criterion-specific arguments to the parser."""
        # fmt: off
@ -69,7 +65,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
        # Get latency loss
        latency_loss = self.latency_train.loss(
            attn_list, source_padding_mask, target_padding_mask
        )

        loss += latency_loss


@ -5,16 +5,20 @@
import importlib
import os

from fairseq import registry

build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry(
    "--agent-type"
)

DEFAULT_EOS = "</s>"
GET = 0
SEND = 1

for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("agents." + module)


@ -3,14 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
from functools import partial
from multiprocessing.pool import ThreadPool as Pool

from . import DEFAULT_EOS, GET, SEND


class Agent(object):
    "an agent needs to follow this pattern"

    def __init__(self, *args, **kwargs):
        pass
@ -40,26 +42,26 @@ class Agent(object):
            with Pool(10) as p:
                p.map(
                    partial(self._decode_one, session),
                    [sent_id for sent_id in range(low, high + 1)],
                )
        else:
            for sent_id in range(low, high + 1):
                self._decode_one(session, sent_id)

        print(f"Finished {low} to {high} in {time.time() - t0}s")

    def _decode_one(self, session, sent_id):
        action = {}
        self.reset()
        states = self.init_states()
        while action.get("value", None) != DEFAULT_EOS:
            # take an action
            action = self.policy(states)

            if action["key"] == GET:
                new_states = session.get_src(sent_id, action["value"])
                states = self.update_states(states, new_states)

            elif action["key"] == SEND:
                session.send_hypo(sent_id, action["value"])

        print(" ".join(states["tokens"]["tgt"]))


@ -3,11 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

from fairseq import checkpoint_utils, tasks, utils

from . import DEFAULT_EOS, GET, SEND
from .agent import Agent


class SimulTransAgent(Agent):
@ -51,13 +53,15 @@ class SimulTransAgent(Agent):
        raise NotImplementedError

    def load_model(self, args):
        args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
        utils.import_user_module(args)
        filename = args.model_path
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))

        state = checkpoint_utils.load_checkpoint_to_cpu(
            filename, json.loads(args.model_overrides)
        )

        saved_args = state["args"]
        saved_args.data = args.data_bin
@ -79,7 +83,7 @@ class SimulTransAgent(Agent):
            "steps": {"src": 0, "tgt": 0},
            "finished": False,
            "finish_read": False,
            "model_states": {},
        }

    def update_states(self, states, new_state):
@ -115,38 +119,38 @@ class SimulTransAgent(Agent):
    def write_action(self, states):
        token, index = self.model.predict_from_states(states)

        if (
            index == self.dict["tgt"].eos()
            or len(states["tokens"]["tgt"]) > self.max_len
        ):
            # Finish this sentence if the model predicts EOS
            states["finished"] = True
            end_idx_last_full_word = self._target_length(states)

        else:
            states["tokens"]["tgt"] += [token]
            end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
                states["tokens"]["tgt"]
            )
            self._append_indices(states, [index], "tgt")

        if end_idx_last_full_word > states["steps"]["tgt"]:
            # Only send detokenized full words to the server
            word = self.word_splitter["tgt"].merge(
                states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
            )
            states["steps"]["tgt"] = end_idx_last_full_word
            states["segments"]["tgt"] += [word]

            return {"key": SEND, "value": word}
        else:
            return None

    def read_action(self, states):
        return {"key": GET, "value": None}

    def finish_action(self):
        return {"key": SEND, "value": DEFAULT_EOS}

    def reset(self):
        pass
@ -160,4 +164,4 @@ class SimulTransAgent(Agent):
            states["indices"][key] += new_indices

    def _target_length(self, states):
        return len(states["tokens"]["tgt"])


@ -3,10 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import DEFAULT_EOS, GET, register_agent
from .simul_trans_agent import SimulTransAgent
from .word_splitter import SPLITTER_DICT


@register_agent("simul_trans_text")
@ -15,11 +14,11 @@ class SimulTransTextAgent(SimulTransAgent):
        self.word_splitter = {}
        self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type](
            getattr(args, f"src_splitter_path")
        )
        self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type](
            getattr(args, f"tgt_splitter_path")
        )

    def load_dictionary(self, task):
        self.dict = {}
@ -37,12 +36,16 @@ class SimulTransTextAgent(SimulTransAgent):
            tokens = self.word_splitter["src"].split(new_word)
            # Get indices from the dictionary
            # You can change to your own dictionary
            indices = (
                self.dict["src"]
                .encode_line(
                    tokens,
                    line_tokenizer=lambda x: x,
                    add_if_not_exist=False,
                    append_eos=False,
                )
                .tolist()
            )
        else:
            tokens = [new_word]
            indices = [self.dict["src"].eos()]
@ -61,11 +64,11 @@ class SimulTransTextAgent(SimulTransAgent):
        # At least one word is read
        if len(states["tokens"]["src"]) == 0:
            return {"key": GET, "value": None}

        # Only request a new word if there are no buffered tokens
        if len(states["tokens"]["src"]) <= states["steps"]["src"]:
            return {"key": GET, "value": None}

        return None


@ -40,6 +40,7 @@ class BPEWordSplitter(object):
    def __init__(self, model_path):
        super().__init__()
        from subword_nmt.apply_bpe import BPE

        with open(model_path) as f:
            self.model = BPE(f)
@ -48,7 +49,7 @@ class BPEWordSplitter(object):
    def end_idx_last_full_word(self, tokens):
        # Begin-of-word indices
        bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != "@@"]

        if len(bow_indices) < 2:
            return 0
@ -63,6 +64,7 @@ class SentencePieceModelWordSplitter(object):
    def __init__(self, model_path):
        super().__init__()
        import sentencepiece as spm

        self.model = spm.SentencePieceProcessor()
        self.model.Load(model_path)
@ -71,7 +73,7 @@ class SentencePieceModelWordSplitter(object):
    def end_idx_last_full_word(self, tokens):
        # Begin-of-word indices
        bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"]

        if len(bow_indices) < 2:
            return 0
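A small worked example of the begin-of-word computation above, using SentencePiece's "\u2581" word-boundary marker (the token sequence is made up):

tokens = ["\u2581he", "llo", "\u2581wor", "ld", "\u2581aga"]
bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"]
print(bow_indices)  # [0, 2, 4]: the last word starts at index 4, so only
                    # tokens[:4] ("hello world") form complete words to send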


@ -3,19 +3,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import requests
from scorers import build_scorer


class SimulSTEvaluationService(object):
    DEFAULT_HOSTNAME = "localhost"
    DEFAULT_PORT = 12321

    def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
        self.hostname = hostname
        self.port = port
        self.base_url = f"http://{self.hostname}:{self.port}"

    def __enter__(self):
        self.new_session()
@ -25,56 +26,53 @@ class SimulSTEvaluationService(object):
    def new_session(self):
        # start eval session
        url = f"{self.base_url}"

        try:
            _ = requests.post(url)
        except Exception as e:
            print(f"Failed to start an evaluation session: {e}")

        print("Evaluation session started.")
        return self

    def get_scores(self):
        # end eval session
        url = f"{self.base_url}/result"
        try:
            r = requests.get(url)
            print("Scores: {}".format(r.json()))
            print("Evaluation session finished.")
        except Exception as e:
            print(f"Failed to end an evaluation session: {e}")

    def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
        url = f"{self.base_url}/src"
        params = {"sent_id": sent_id}
        if extra_params is not None:
            for key in extra_params.keys():
                params[key] = extra_params[key]
        try:
            r = requests.get(url, params=params)
        except Exception as e:
            print(f"Failed to request a source segment: {e}")
        return r.json()

    def send_hypo(self, sent_id: int, hypo: str) -> None:
        url = f"{self.base_url}/hypo"
        params = {"sent_id": sent_id}

        try:
            requests.put(url, params=params, data=hypo.encode("utf-8"))
        except Exception as e:
            print(f"Failed to send a translated segment: {e}")

    def corpus_info(self):
        url = f"{self.base_url}"
        try:
            r = requests.get(url)
        except Exception as e:
            print(f"Failed to request corpus information: {e}")
        return r.json()
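Taken together, these methods wrap the server's four routes; a hypothetical session, assuming the evaluation server (further below) is already running on the default host and port:

client = SimulSTEvaluationService()  # defaults to localhost:12321
with client:  # POST "/" opens a fresh session
    segment = client.get_src(sent_id=0)  # GET "/src": fetch the next source segment
    client.send_hypo(sent_id=0, hypo="hello ")  # PUT "/hypo": submit translated text
    client.get_scores()  # GET "/result": print the final scores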


@ -3,20 +3,21 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json

import torch
from examples.simultaneous_translation.utils.latency import LatencyInference

LATENCY_METRICS = [
    "differentiable_average_lagging",
    "average_lagging",
    "average_proportion",
]


class LatencyScorer:
    def __init__(self, start_from_zero=True):
        self.recorder = []
        self.scores = {}
@ -26,10 +27,7 @@ class LatencyScorer():
    def update_reorder(self, list_of_dict):
        self.recorder = []
        for info in list_of_dict:
            delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]]

            delays = torch.LongTensor(delays).unsqueeze(0)
            src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
@ -59,7 +57,7 @@ if __name__ == "__main__":
    scorer = LatencyInference()
    recorder = []
    with open(args.input, "r") as f:
        for line in f:
            info = json.loads(line)
@ -74,7 +72,7 @@ if __name__ == "__main__":
    average_results = {}

    for metric in LATENCY_METRICS:
        average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len(
            recorder
        )
        print(f"{metric}: {average_results[metric]}")


@ -5,37 +5,48 @@
import argparse

from agents import build_agent
from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
from fairseq.registry import REGISTRIES

DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
    )
    parser.add_argument(
        "--port", type=int, default=DEFAULT_PORT, help="server port number"
    )
    parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
    parser.add_argument("--scorer-type", default="text", help="Scorer type")
    parser.add_argument(
        "--start-idx",
        type=int,
        default=0,
        help="Start index of the sentence to evaluate",
    )
    parser.add_argument(
        "--end-idx",
        type=int,
        default=float("inf"),
        help="End index of the sentence to evaluate",
    )
    parser.add_argument(
        "--scores", action="store_true", help="Request scores from server"
    )
    parser.add_argument("--reset-server", action="store_true", help="Reset the server")
    parser.add_argument(
        "--num-threads", type=int, default=10, help="Number of threads used by agent"
    )
    parser.add_argument(
        "--local", action="store_true", default=False, help="Local evaluation"
    )

    args, _ = parser.parse_known_args()


@ -5,15 +5,15 @@
import importlib
import os

from fairseq import registry

(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry(
    "--scorer-type"
)

for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("scorers." + module)


@ -3,16 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from collections import defaultdict

from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
from vizseq.scorers.bleu import BLEUScorer
from vizseq.scorers.meteor import METEORScorer
from vizseq.scorers.ter import TERScorer

DEFAULT_EOS = "</s>"


class SimulScorer(object):
@ -23,7 +24,7 @@ class SimulScorer(object):
            self.output_files = {
                "text": os.path.join(args.output, "text"),
                "delay": os.path.join(args.output, "delay"),
                "scores": os.path.join(args.output, "scores"),
            }
        else:
            self.output_files = None
@ -52,14 +53,7 @@ class SimulScorer(object):
    def recv_hyp(self, sent_id, list_of_tokens):
        for token in list_of_tokens:
            self.translations[sent_id].append((token, self.steps[sent_id]))

    def reset(self):
        self.steps = defaultdict(int)
@ -76,8 +70,9 @@ class SimulScorer(object):
            delays += [[t[1] for t in self.translations[i]]]

        bleu_score = BLEUScorer(
            sent_level=False,
            corpus_level=True,
            extra_args={"bleu_tokenizer": self.tokenizer},
        ).score(translations, [self.data["tgt"]])

        ter_score = TERScorer(sent_level=False, corpus_level=True).score(
@ -92,16 +87,16 @@ class SimulScorer(object):
                {"src_len": src_len, "delays": delay}
                for src_len, delay in zip(self.src_lengths(), delays)
            ],
            start_from_zero=False,
        )

        scores = {
            "BLEU": bleu_score[0],
            "TER": ter_score[0],
            "METEOR": meteor_score[0],
            "DAL": latency_score["differentiable_average_lagging"],
            "AL": latency_score["average_lagging"],
            "AP": latency_score["average_proportion"],
        }

        if self.output_files is not None:
@ -109,9 +104,9 @@ class SimulScorer(object):
                os.makedirs(self.output_dir, exist_ok=True)
                self.write_results_to_file(translations, delays, scores)
            except BaseException as be:
                print(f"Failed to write results to {self.output_dir}.")
                print(be)
                print("Skip writing predictions")

        return scores
@ -125,12 +120,8 @@ class SimulScorer(object):
        with open(self.output_files["delay"], "w") as f:
            for i, delay in enumerate(delays):
                f.write(
                    json.dumps({"src_len": self.src_lengths()[i], "delays": delay})
                    + "\n"
                )

        with open(self.output_files["scores"], "w") as f:
@ -163,7 +154,7 @@ class SimulScorer(object):
                list_to_return.append(
                    {
                        "path": item["input"]["path"].strip(),
                        "length": item["input"]["length_ms"],
                    }
                )
        return list_to_return


@ -3,8 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import register_scorer
from .scorer import SimulScorer


@register_scorer("text")
@ -13,7 +13,7 @@ class SimulTextScorer(SimulScorer):
        super().__init__(args)
        self.data = {
            "src": self._load_text_file(args.src_file, split=True),
            "tgt": self._load_text_file(args.tgt_file, split=False),
        }

    def send_src(self, sent_id, *args):
@ -21,7 +21,7 @@ class SimulTextScorer(SimulScorer):
            dict_to_return = {
                "sent_id": sent_id,
                "segment_id": self.steps[sent_id],
                "segment": self.eos,
            }
            # Consider EOS
            self.steps[sent_id] = len(self.data["src"][sent_id]) + 1
@ -29,7 +29,7 @@ class SimulTextScorer(SimulScorer):
            dict_to_return = {
                "sent_id": sent_id,
                "segment_id": self.steps[sent_id],
                "segment": self.data["src"][sent_id][self.steps[sent_id]],
            }

            self.steps[sent_id] += 1


@ -3,12 +3,14 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import argparse import argparse
import sys
import json import json
from tornado import web, ioloop import sys
from scorers import build_scorer
DEFAULT_HOSTNAME = 'localhost' from scorers import build_scorer
from tornado import ioloop, web
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321 DEFAULT_PORT = 12321
@ -34,10 +36,10 @@ class ResultHandler(ScorerHandler):
class SourceHandler(ScorerHandler): class SourceHandler(ScorerHandler):
def get(self): def get(self):
sent_id = int(self.get_argument('sent_id')) sent_id = int(self.get_argument("sent_id"))
segment_size = None segment_size = None
if "segment_size" in self.request.arguments: if "segment_size" in self.request.arguments:
string = self.get_argument('segment_size') string = self.get_argument("segment_size")
if len(string) > 0: if len(string) > 0:
segment_size = int(string) segment_size = int(string)
@ -48,8 +50,8 @@ class SourceHandler(ScorerHandler):
class HypothesisHandler(ScorerHandler): class HypothesisHandler(ScorerHandler):
def put(self): def put(self):
sent_id = int(self.get_argument('sent_id')) sent_id = int(self.get_argument("sent_id"))
list_of_tokens = self.request.body.decode('utf-8').strip().split() list_of_tokens = self.request.body.decode("utf-8").strip().split()
self.scorer.recv_hyp(sent_id, list_of_tokens) self.scorer.recv_hyp(sent_id, list_of_tokens)
@ -67,18 +69,21 @@ def add_args():
def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False): def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
app = web.Application([ app = web.Application(
(r'/result', ResultHandler, dict(scorer=scorer)), [
(r'/src', SourceHandler, dict(scorer=scorer)), (r"/result", ResultHandler, dict(scorer=scorer)),
(r'/hypo', HypothesisHandler, dict(scorer=scorer)), (r"/src", SourceHandler, dict(scorer=scorer)),
(r'/', EvalSessionHandler, dict(scorer=scorer)), (r"/hypo", HypothesisHandler, dict(scorer=scorer)),
], debug=debug) (r"/", EvalSessionHandler, dict(scorer=scorer)),
],
debug=debug,
)
app.listen(port, max_buffer_size=1024 ** 3) app.listen(port, max_buffer_size=1024 ** 3)
sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n") sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
ioloop.IOLoop.current().start() ioloop.IOLoop.current().start()
if __name__ == '__main__': if __name__ == "__main__":
args = add_args() args = add_args()
scorer = build_scorer(args) scorer = build_scorer(args)
start_server(scorer, args.hostname, args.port, args.debug) start_server(scorer, args.hostname, args.port, args.debug)
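A hypothetical smoke test for the four routes registered above, assuming the server is running with the default hostname and port:

import requests

base = "http://localhost:12321"
requests.post(base)  # "/" starts a new evaluation session
print(requests.get(f"{base}/src", params={"sent_id": 0}).json())  # next source segment
requests.put(f"{base}/hypo", params={"sent_id": 0}, data="hello ".encode("utf-8"))
print(requests.get(f"{base}/result").json())  # corpus-level scores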


@ -6,7 +6,10 @@
import importlib
import os

for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        model_name = file[: file.find(".py")]
        importlib.import_module(
            "examples.simultaneous_translation.models." + model_name
        )


@ -6,42 +6,34 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
    TransformerMonotonicDecoderLayer,
    TransformerMonotonicEncoderLayer,
)
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
    transformer_iwslt_de_en,
    transformer_vaswani_wmt_en_de_big,
    transformer_vaswani_wmt_en_fr_big,
)

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)


@register_model("transformer_monotonic")
class TransformerMonotonicModel(TransformerModel):
    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@ -62,26 +54,17 @@ class TransformerMonotonicModel(TransformerModel):
            )

            tgt_indices = tensor(
                [[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
            )
        else:
            src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
            tgt_indices = states["indices"]["tgt"]

        return src_indices, None, tgt_indices

    def predict_from_states(self, states):
        decoder_states = self.decoder.output_layer(states["decoder_features"])
        lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)

        index = lprobs.argmax(dim=-1)
@ -90,25 +73,24 @@ class TransformerMonotonicModel(TransformerModel):
        return token, index[0, 0].item()

    def decision_from_states(self, states):
        """
        This function takes a states dictionary as input and gives the agent
        a decision on whether to read a token from the server. The decoder
        states are also calculated here so we can directly generate a target
        token without recomputing everything
        """
        self.eval()

        if len(states["tokens"]["src"]) == 0:
            return 0

        src_indices, src_lengths, tgt_indices = self._indices_from_states(states)

        # Update encoder states if needed
        if (
            "encoder_states" not in states
            or states["encoder_states"][0].size(1) <= states["steps"]["src"]
        ):
            encoder_out_dict = self.encoder(src_indices, src_lengths)
            states["encoder_states"] = encoder_out_dict
@ -136,16 +118,14 @@ class TransformerMonotonicModel(TransformerModel):

class TransformerMonotonicEncoder(TransformerEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        self.dictionary = dictionary
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [TransformerMonotonicEncoderLayer(args) for i in range(args.encoder_layers)]
        )


class TransformerMonotonicDecoder(TransformerDecoder):
@ -166,19 +146,24 @@ class TransformerMonotonicDecoder(TransformerDecoder):
        self.dictionary = dictionary
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerMonotonicDecoderLayer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ]
        )

    def pre_attention(
        self, prev_output_tokens, encoder_out_dict, incremental_state=None
    ):
        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
@ -216,8 +201,7 @@ class TransformerMonotonicDecoder(TransformerDecoder):
        return x

    def extract_features(
        self, prev_output_tokens, encoder_out, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only return features.
@ -228,14 +212,8 @@ class TransformerMonotonicDecoder(TransformerDecoder):
            - a dictionary with any model-specific outputs
        """
        # incremental_state = None
        (x, encoder_outs, encoder_padding_mask) = self.pre_attention(
            prev_output_tokens, encoder_out, incremental_state
        )
        attn = None
        inner_states = [x]
@ -250,7 +228,8 @@ class TransformerMonotonicDecoder(TransformerDecoder):
                encoder_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                self_attn_mask=self.buffered_future_mask(x)
                if incremental_state is None
                else None,
            )

            inner_states.append(x)
@ -261,38 +240,30 @@ class TransformerMonotonicDecoder(TransformerDecoder):
                step_list.append(curr_steps)

                if incremental_state.get("online", False):
                    p_choose = (
                        attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
                    )

                    new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)

                    if (new_steps >= incremental_state["steps"]["src"]).any():
                        # We need to prune the last self_attn saved_state
                        # if the model decides not to read;
                        # otherwise there will be duplicated saved_state
                        for j in range(i + 1):
                            self.layers[j].prune_incremental_state(incremental_state)

                        return x, {"action": 0}

        if incremental_state is not None and not incremental_state.get("online", False):
            # This branch is for fast evaluation
            fastest_step = (
                torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
            )

            if "fastest_step" in incremental_state:
                incremental_state["fastest_step"] = torch.cat(
                    [incremental_state["fastest_step"], fastest_step], dim=1
                )
            else:
                incremental_state["fastest_step"] = fastest_step
@ -310,25 +281,19 @@ class TransformerMonotonicDecoder(TransformerDecoder):
    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        if "fastest_step" in incremental_state:
            incremental_state["fastest_step"] = incremental_state[
                "fastest_step"
            ].index_select(0, new_order)


@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_rchitecture(args):
    base_architecture(args)
    args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)


@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
    transformer_iwslt_de_en(args)
@ -337,24 +302,21 @@ def transformer_monotonic_iwslt_de_en(args):

# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
    transformer_vaswani_wmt_en_de_big(args)


@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
    transformer_vaswani_wmt_en_fr_big(args)


@register_model_architecture(
    "transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
    transformer_iwslt_de_en(args)


@ -7,14 +7,18 @@ import importlib
import os

from fairseq import registry

(
    build_monotonic_attention,
    register_monotonic_attention,
    MONOTONIC_ATTENTION_REGISTRY,
    _,
) = registry.setup_registry("--simul-type")

for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        model_name = file[: file.find(".py")]
        importlib.import_module(
            "examples.simultaneous_translation.modules." + model_name
        )


@ -4,22 +4,19 @@
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from examples.simultaneous_translation.utils.functions import (
    exclusive_cumprod,
    lengths_to_mask,
)
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules import MultiheadAttention
from fairseq.utils import convert_padding_direction

from . import register_monotonic_attention
@ -28,6 +25,7 @@ class MonotonicAttention(nn.Module):
    """
    Abstract class of monotonic attentions
    """

    def __init__(self, args):
        self.eps = args.attention_eps
        self.mass_preservation = args.mass_preservation
@ -38,7 +36,8 @@ class MonotonicAttention(nn.Module):
        self.energy_bias_init = args.energy_bias_init
        self.energy_bias = (
            nn.Parameter(self.energy_bias_init * torch.ones([1]))
            if args.energy_bias is True
            else 0
        )

    @staticmethod
@ -90,7 +89,7 @@ class MonotonicAttention(nn.Module):
        if key_padding_mask is not None:
            attn_energy = attn_energy.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
                float("-inf"),
            )

        return attn_energy
@ -131,10 +130,7 @@ class MonotonicAttention(nn.Module):
            alpha_i = (
                p_choose[:, i]
                * cumprod_1mp[:, i]
                * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
            ).clamp(0, 1.0)
            previous_attn.append(alpha_i.unsqueeze(1))
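For reference, the loop above computes the expected monotonic alignment with the parallel cumulative-sum/cumulative-product trick of Raffel et al. (2017), which this module appears to follow: writing $p_{i,j}$ for the probability of selecting source step $j$ at target step $i$,

$$\alpha_{i,j} = p_{i,j} \prod_{k=1}^{j-1}\bigl(1 - p_{i,k}\bigr) \sum_{k=1}^{j} \frac{\alpha_{i-1,k}}{\prod_{l=1}^{k-1}\bigl(1 - p_{i,l}\bigr)},$$

with the denominator clamped away from zero (the `cumprod_1mp_clamp` tensor) for numerical stability.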
@ -170,8 +166,7 @@ class MonotonicAttention(nn.Module):
# prev_monotonic_step: bsz, num_heads # prev_monotonic_step: bsz, num_heads
bsz = bsz_num_heads // self.num_heads bsz = bsz_num_heads // self.num_heads
prev_monotonic_step = monotonic_cache.get( prev_monotonic_step = monotonic_cache.get(
"step", "step", p_choose.new_zeros([bsz, self.num_heads]).long()
p_choose.new_zeros([bsz, self.num_heads]).long()
) )
bsz, num_heads = prev_monotonic_step.size() bsz, num_heads = prev_monotonic_step.size()
assert num_heads == self.num_heads assert num_heads == self.num_heads
@ -181,8 +176,7 @@ class MonotonicAttention(nn.Module):
p_choose = p_choose.view(bsz, num_heads, src_len) p_choose = p_choose.view(bsz, num_heads, src_len)
if key_padding_mask is not None: if key_padding_mask is not None:
src_lengths = src_len - \ src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long()
key_padding_mask.sum(dim=1, keepdim=True).long()
else: else:
src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len
@ -197,10 +191,7 @@ class MonotonicAttention(nn.Module):
# left_pad_source = True: # left_pad_source = True:
step_offset = key_padding_mask.sum(dim=-1, keepdim=True) step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
max_steps = ( max_steps = src_lengths - 1 if self.mass_preservation else src_lengths
src_lengths - 1 if self.mass_preservation
else src_lengths
)
# finish_read: bsz, num_heads # finish_read: bsz, num_heads
finish_read = new_monotonic_step.eq(max_steps) finish_read = new_monotonic_step.eq(max_steps)
@ -210,11 +201,11 @@ class MonotonicAttention(nn.Module):
# only choose the p at monotonic steps # only choose the p at monotonic steps
# p_choose_i: bsz , self.num_heads # p_choose_i: bsz , self.num_heads
p_choose_i = ( p_choose_i = (
p_choose p_choose.gather(
.gather(
2, 2,
(step_offset + new_monotonic_step).unsqueeze(2) (step_offset + new_monotonic_step)
.clamp(0, src_len - 1) .unsqueeze(2)
.clamp(0, src_len - 1),
) )
).squeeze(2) ).squeeze(2)
@ -239,21 +230,17 @@ class MonotonicAttention(nn.Module):
# alpha: bsz * num_heads, 1, src_len # alpha: bsz * num_heads, 1, src_len
# new_monotonic_step: bsz, num_heads # new_monotonic_step: bsz, num_heads
alpha = ( alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter(
p_choose 1,
.new_zeros([bsz * self.num_heads, src_len]) (step_offset + new_monotonic_step)
.scatter( .view(bsz * self.num_heads, 1)
1, .clamp(0, src_len - 1),
(step_offset + new_monotonic_step).view(bsz * 1,
self.num_heads, 1).clamp(0, src_len - 1),
1
)
) )
if not self.mass_preservation: if not self.mass_preservation:
alpha = alpha.masked_fill( alpha = alpha.masked_fill(
(new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0
0
) )
alpha = alpha.unsqueeze(1) alpha = alpha.unsqueeze(1)
@ -266,8 +253,14 @@ class MonotonicAttention(nn.Module):
raise NotImplementedError raise NotImplementedError
def forward( def forward(
self, query, key, value, self,
key_padding_mask=None, incremental_state=None, *args, **kwargs, query,
key,
value,
key_padding_mask=None,
incremental_state=None,
*args,
**kwargs,
): ):
tgt_len, bsz, embed_dim = query.size() tgt_len, bsz, embed_dim = query.size()
@ -280,25 +273,24 @@ class MonotonicAttention(nn.Module):
# expected alignment alpha # expected alignment alpha
# bsz * self.num_heads, tgt_len, src_len # bsz * self.num_heads, tgt_len, src_len
if incremental_state is not None: if incremental_state is not None:
alpha = self.expected_alignment_infer(p_choose, key_padding_mask, incremental_state) alpha = self.expected_alignment_infer(
p_choose, key_padding_mask, incremental_state
)
else: else:
alpha = self.expected_alignment_train(p_choose, key_padding_mask) alpha = self.expected_alignment_train(p_choose, key_padding_mask)
# expected attention beta # expected attention beta
# bsz * self.num_heads, tgt_len, src_len # bsz * self.num_heads, tgt_len, src_len
beta = self.expected_attention(alpha, query, key, value, key_padding_mask, incremental_state) beta = self.expected_attention(
alpha, query, key, value, key_padding_mask, incremental_state
)
attn_weights = beta attn_weights = beta
v_proj = self.v_proj_output(value) v_proj = self.v_proj_output(value)
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj) attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
attn = ( attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn
.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn = self.out_proj(attn) attn = self.out_proj(attn)
@ -318,26 +310,32 @@ class MonotonicAttention(nn.Module):
self._set_monotonic_buffer(incremental_state, input_buffer) self._set_monotonic_buffer(incremental_state, input_buffer)
def _get_monotonic_buffer(self, incremental_state): def _get_monotonic_buffer(self, incremental_state):
return utils.get_incremental_state( return (
self, utils.get_incremental_state(
incremental_state, self,
'monotonic', incremental_state,
) or {} "monotonic",
)
or {}
)
def _set_monotonic_buffer(self, incremental_state, buffer): def _set_monotonic_buffer(self, incremental_state, buffer):
utils.set_incremental_state( utils.set_incremental_state(
self, self,
incremental_state, incremental_state,
'monotonic', "monotonic",
buffer, buffer,
) )
def get_pointer(self, incremental_state): def get_pointer(self, incremental_state):
return utils.get_incremental_state( return (
self, utils.get_incremental_state(
incremental_state, self,
'monotonic', incremental_state,
) or {} "monotonic",
)
or {}
)
def get_fastest_pointer(self, incremental_state): def get_fastest_pointer(self, incremental_state):
return self.get_pointer(incremental_state)["step"].max(0)[0] return self.get_pointer(incremental_state)["step"].max(0)[0]
@ -354,23 +352,22 @@ class MonotonicAttention(nn.Module):
utils.set_incremental_state( utils.set_incremental_state(
self, self,
incremental_state, incremental_state,
'monotonic', "monotonic",
{"step": buffer}, {"step": buffer},
) )
@register_monotonic_attention("hard_aligned") @register_monotonic_attention("hard_aligned")
class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention): class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
def __init__(self, args): def __init__(self, args):
MultiheadAttention.__init__( MultiheadAttention.__init__(
self, self,
embed_dim=args.decoder_embed_dim, embed_dim=args.decoder_embed_dim,
num_heads=args.decoder_attention_heads, num_heads=args.decoder_attention_heads,
kdim=getattr(args, 'encoder_embed_dim', None), kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, 'encoder_embed_dim', None), vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout, dropout=args.attention_dropout,
encoder_decoder_attention=True encoder_decoder_attention=True,
) )
MonotonicAttention.__init__(self, args) MonotonicAttention.__init__(self, args)
@ -395,21 +392,33 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
bsz = query.size(1) bsz = query.size(1)
q = self.q_in_proj[name](query) q = self.q_in_proj[name](query)
q *= self.scaling q *= self.scaling
q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) q = (
q.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else: else:
q = None q = None
if key is not None: if key is not None:
bsz = key.size(1) bsz = key.size(1)
k = self.k_in_proj[name](key) k = self.k_in_proj[name](key)
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else: else:
k = None k = None
if value is not None: if value is not None:
bsz = value.size(1) bsz = value.size(1)
v = self.v_in_proj[name](value) v = self.v_in_proj[name](value)
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else: else:
v = None v = None
@ -441,8 +450,7 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
if self.training: if self.training:
# add noise here to encourage discretness # add noise here to encourage discretness
noise = ( noise = (
torch torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
.normal(self.noise_mean, self.noise_var, attn_energy.size())
.type_as(attn_energy) .type_as(attn_energy)
.to(attn_energy.device) .to(attn_energy.device)
) )
@ -454,9 +462,9 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
return p_choose.view(-1, tgt_len, src_len) return p_choose.view(-1, tgt_len, src_len)
def expected_attention(self, alpha, *args): def expected_attention(self, alpha, *args):
''' """
For MMA-H, beta = alpha For MMA-H, beta = alpha
''' """
return alpha return alpha
def v_proj_output(self, value): def v_proj_output(self, value):
@ -479,13 +487,19 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
if self.qkv_same_dim: if self.qkv_same_dim:
# Empirically observed the convergence to be much better with # Empirically observed the convergence to be much better with
# the scaled initialization # the scaled initialization
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)) self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
nn.init.xavier_uniform_(
self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
else: else:
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight) nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight) nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
def expected_attention(self, alpha, query, key, value, key_padding_mask, incremental_state): def expected_attention(
self, alpha, query, key, value, key_padding_mask, incremental_state
):
# monotonic attention, we will calculate milk here # monotonic attention, we will calculate milk here
bsz_x_num_heads, tgt_len, src_len = alpha.size() bsz_x_num_heads, tgt_len, src_len = alpha.size()
bsz = int(bsz_x_num_heads / self.num_heads) bsz = int(bsz_x_num_heads / self.num_heads)
@ -507,9 +521,10 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
step_offset = key_padding_mask.sum(dim=-1, keepdim=True) step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
monotonic_step += step_offset monotonic_step += step_offset
mask = lengths_to_mask( mask = lengths_to_mask(
monotonic_step.view(-1), soft_energy.size(2), 1).unsqueeze(1) monotonic_step.view(-1), soft_energy.size(2), 1
).unsqueeze(1)
soft_energy = soft_energy.masked_fill(~ mask.bool(), float('-inf')) soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy) exp_soft_energy = torch.exp(soft_energy)
exp_soft_energy_sum = exp_soft_energy.sum(dim=2) exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
@ -524,14 +539,20 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
if key_padding_mask is not None: if key_padding_mask is not None:
if key_padding_mask.any(): if key_padding_mask.any():
exp_soft_energy_cumsum = ( exp_soft_energy_cumsum = (
exp_soft_energy_cumsum.view(-1, self.num_heads, tgt_len, src_len) exp_soft_energy_cumsum.view(
.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps) -1, self.num_heads, tgt_len, src_len
)
.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
)
.view(-1, tgt_len, src_len) .view(-1, tgt_len, src_len)
) )
inner_items = alpha / exp_soft_energy_cumsum inner_items = alpha / exp_soft_energy_cumsum
beta = exp_soft_energy * torch.cumsum(inner_items.flip(dims=[2]), dim=2).flip(dims=[2]) beta = exp_soft_energy * torch.cumsum(
inner_items.flip(dims=[2]), dim=2
).flip(dims=[2])
beta = self.dropout_module(beta) beta = self.dropout_module(beta)
@ -547,7 +568,9 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
self.q_in_proj["soft"] = self.q_in_proj["monotonic"] self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
self.k_in_proj["soft"] = self.k_in_proj["monotonic"] self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
self.waitk_lagging = args.waitk_lagging self.waitk_lagging = args.waitk_lagging
assert self.waitk_lagging > 0, f"Lagging has to be larger than 0, got {self.waitk_lagging}." assert (
self.waitk_lagging > 0
), f"Lagging has to be larger than 0, got {self.waitk_lagging}."
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
@ -556,10 +579,13 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
MonotonicMultiheadAttentionWaitk, MonotonicMultiheadAttentionWaitk,
).add_args(parser) ).add_args(parser)
parser.add_argument('--waitk-lagging', type=int, required=True, parser.add_argument(
help='Wait k lagging') "--waitk-lagging", type=int, required=True, help="Wait k lagging"
)
def p_choose(self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None): def p_choose(
self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None
):
""" """
query: bsz, tgt_len query: bsz, tgt_len
key: bsz, src_len key: bsz, src_len
@ -574,16 +600,22 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any(): if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
# Left pad source # Left pad source
# add -1 to the end # add -1 to the end
p_choose = p_choose.masked_fill(key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1) p_choose = p_choose.masked_fill(
p_choose = convert_padding_direction(p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True) key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1
)
p_choose = convert_padding_direction(
p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True
)
p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query) p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)
# remove -1 # remove -1
p_choose[p_choose.eq(-1)] = 0 p_choose[p_choose.eq(-1)] = 0
# Extend to each head # Extend to each head
p_choose = ( p_choose = (
p_choose.contiguous().unsqueeze(1) p_choose.contiguous()
.expand(-1, self.num_heads, -1, -1).contiguous() .unsqueeze(1)
.expand(-1, self.num_heads, -1, -1)
.contiguous()
.view(-1, tgt_len, src_len) .view(-1, tgt_len, src_len)
) )
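
The p_choose built above encodes the wait-k reading policy: target step i may first attend source step i + k - 1, clamped to the last source position. A hypothetical standalone illustration (not the fairseq implementation):

import torch

def waitk_p_choose(tgt_len, src_len, k):
    # 1 exactly on the (clamped) diagonal offset by k - 1, 0 elsewhere.
    p = torch.zeros(tgt_len, src_len)
    for i in range(tgt_len):
        p[i, min(i + k - 1, src_len - 1)] = 1.0
    return p

print(waitk_p_choose(3, 5, k=2))
# tensor([[0., 1., 0., 0., 0.],
#         [0., 0., 1., 0., 0.],
#         [0., 0., 0., 1., 0.]])
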


@ -3,37 +3,32 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from fairseq.modules import ( from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer
LayerNorm,
TransformerEncoderLayer,
TransformerDecoderLayer
)
from . import build_monotonic_attention from . import build_monotonic_attention
class TransformerMonotonicEncoderLayer(TransformerEncoderLayer): class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
def forward(self, x, encoder_padding_mask): def forward(self, x, encoder_padding_mask):
seq_len, _, _ = x.size() seq_len, _, _ = x.size()
attn_mask = x.new_ones([seq_len, seq_len]).triu(1) attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
attn_mask = attn_mask.masked_fill(attn_mask.bool(), float('-inf')) attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
return super().forward(x, encoder_padding_mask, attn_mask) return super().forward(x, encoder_padding_mask, attn_mask)
class TransformerMonotonicDecoderLayer(TransformerDecoderLayer): class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
def __init__(
def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False): self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__( super().__init__(
args, args,
no_encoder_attn=True, no_encoder_attn=True,
add_bias_kv=add_bias_kv, add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn add_zero_attn=add_zero_attn,
) )
self.encoder_attn = build_monotonic_attention(args) self.encoder_attn = build_monotonic_attention(args)
self.encoder_attn_layer_norm = LayerNorm( self.encoder_attn_layer_norm = LayerNorm(
self.embed_dim, self.embed_dim, export=getattr(args, "char_inputs", False)
export=getattr(args, 'char_inputs', False)
) )
def prune_incremental_state(self, incremental_state): def prune_incremental_state(self, incremental_state):
@ -46,12 +41,8 @@ class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
input_buffer = {} input_buffer = {}
break break
module._set_input_buffer(incremental_state, input_buffer) module._set_input_buffer(incremental_state, input_buffer)
prune(self.self_attn) prune(self.self_attn)
def get_steps(self, incremental_state): def get_steps(self, incremental_state):
return ( return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0)
self.encoder_attn
._get_monotonic_buffer(
incremental_state
).get("step", 0)
)


@ -9,6 +9,6 @@ import os
# automatically import any Python files in the criterions/ directory # automatically import any Python files in the criterions/ directory
for file in os.listdir(os.path.dirname(__file__)): for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'): if file.endswith(".py") and not file.startswith("_"):
module = file[:file.find('.py')] module = file[: file.find(".py")]
importlib.import_module('examples.simultaneous_translation.utils.' + module) importlib.import_module("examples.simultaneous_translation.utils." + module)
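
This __init__.py loop is the usual fairseq plugin idiom: importing every module in the package runs its register_* decorators as a side effect. A minimal sketch of the pattern, with a hypothetical registry rather than fairseq's:

REGISTRY = {}

def register(name):
    # Decorator that records the class in the registry at import time.
    def wrapper(cls):
        REGISTRY[name] = cls
        return cls
    return wrapper

@register("my_metric")
class MyMetric:
    pass

assert "my_metric" in REGISTRY  # populated simply by importing this module
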


@ -16,7 +16,9 @@ def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
tensor_size = list(tensor.size()) tensor_size = list(tensor.size())
tensor_size[dim] = 1 tensor_size[dim] = 1
return_tensor = safe_cumprod( return_tensor = safe_cumprod(
torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim), dim=dim, eps=eps torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim),
dim=dim,
eps=eps,
) )
if dim == 0: if dim == 0:
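
exclusive_cumprod shifts the cumulative product right by one position with an implicit leading 1, so [a, b, c] maps to [1, a, a*b]. A worked example producing the same result:

import torch

x = torch.tensor([2.0, 3.0, 4.0])
inclusive = torch.cumprod(x, dim=0)                              # [ 2.,  6., 24.]
exclusive = torch.cumprod(torch.cat([torch.ones(1), x])[:-1], dim=0)
print(exclusive)                                                 # [1., 2., 6.]
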
@ -132,12 +134,14 @@ def moving_sum(x, start_idx: int, end_idx: int):
# batch_size, 1, src_len # batch_size, 1, src_len
moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1]) moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1])
moving_sum = torch.nn.functional.conv1d( moving_sum = (
x, torch.nn.functional.conv1d(
moving_sum_weight, x, moving_sum_weight, padding=start_idx + end_idx - 1
padding=start_idx + end_idx - 1 )
).squeeze(1).t() .squeeze(1)
moving_sum = moving_sum[end_idx: -start_idx] .t()
)
moving_sum = moving_sum[end_idx:-start_idx]
assert src_len == moving_sum.size(0) assert src_len == moving_sum.size(0)
assert batch_size == moving_sum.size(1) assert batch_size == moving_sum.size(1)
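
The conv1d above realizes a windowed sum: output[i] sums x over the window [i - start_idx + 1, i + end_idx - 1], zero-padded at the edges. A pure-Python illustration of those semantics:

x = [1.0, 2.0, 3.0, 4.0, 5.0]
start_idx, end_idx = 2, 1
out = [
    sum(x[max(0, i - start_idx + 1): i + end_idx])  # window [i-1, i] here
    for i in range(len(x))
]
print(out)  # [1.0, 3.0, 5.0, 7.0, 9.0]
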


@ -18,7 +18,7 @@ class LatencyMetric(object):
src_lens, src_lens,
target_padding_mask=None, target_padding_mask=None,
batch_first: bool = False, batch_first: bool = False,
start_from_zero: bool = True start_from_zero: bool = True,
): ):
assert len(delays.size()) == 2 assert len(delays.size()) == 2
assert len(src_lens.size()) == 2 assert len(src_lens.size()) == 2
@ -59,11 +59,7 @@ class LatencyMetric(object):
start_from_zero: bool = True, start_from_zero: bool = True,
): ):
delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric( delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
delays, delays, src_lens, target_padding_mask, batch_first, start_from_zero
src_lens,
target_padding_mask,
batch_first,
start_from_zero
) )
return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask) return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)
@ -89,10 +85,13 @@ class AverageProportion(LatencyMetric):
AP = 1 / (|x||y|) sum_i^|y| delays_i AP = 1 / (|x||y|) sum_i^|y| delays_i
""" """
@staticmethod @staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
if target_padding_mask is not None: if target_padding_mask is not None:
AP = torch.sum(delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True) AP = torch.sum(
delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True
)
else: else:
AP = torch.sum(delays, dim=0, keepdim=True) AP = torch.sum(delays, dim=0, keepdim=True)
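
A worked example of the Average Proportion formula from the docstring, AP = 1/(|x||y|) * sum_i delays_i, with scalar inputs:

# |x| = 4 source tokens, |y| = 2 target tokens; delays_i = source tokens
# read before emitting target token i.
delays, src_len = [2, 4], 4
ap = sum(delays) / (src_len * len(delays))
assert ap == 0.75
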
@ -116,14 +115,24 @@ class AverageLagging(LatencyMetric):
gamma = |y| / |x| gamma = |y| / |x|
tau = argmin_i(delays_i = |x|) tau = argmin_i(delays_i = |x|)
""" """
@staticmethod @staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
# tau = argmin_i(delays_i = |x|) # tau = argmin_i(delays_i = |x|)
tgt_len, bsz = delays.size() tgt_len, bsz = delays.size()
lagging_padding_mask = delays >= src_lens lagging_padding_mask = delays >= src_lens
lagging_padding_mask = torch.nn.functional.pad(lagging_padding_mask.t(), (1, 0)).t()[:-1, :] lagging_padding_mask = torch.nn.functional.pad(
lagging_padding_mask.t(), (1, 0)
).t()[:-1, :]
gamma = tgt_lens / src_lens gamma = tgt_lens / src_lens
lagging = delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma lagging = (
delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
lagging.masked_fill_(lagging_padding_mask, 0) lagging.masked_fill_(lagging_padding_mask, 0)
tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True) tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
AL = lagging.sum(dim=0, keepdim=True) / tau AL = lagging.sum(dim=0, keepdim=True) / tau
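
A worked example of the Average Lagging computation above, with zero-based target index i, gamma = |y|/|x|, and the cutoff tau at the first step that has read the whole source (the shifted padding mask keeps that first step in the sum):

delays, src_len = [1, 3, 3], 3
gamma = len(delays) / src_len    # |y| / |x| = 1.0
terms = []
for i, d in enumerate(delays):
    terms.append(d - i / gamma)  # lagging_i = delays_i - i / gamma
    if d >= src_len:             # tau: first step with the full source read
        break
al = sum(terms) / len(terms)
assert al == 1.5                 # (1 + 2) / 2
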
@ -149,6 +158,7 @@ class DifferentiableAverageLagging(LatencyMetric):
2. max(delays_i, delays'_{i-1} + 1 / gamma) 2. max(delays_i, delays'_{i-1} + 1 / gamma)
""" """
@staticmethod @staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
tgt_len, bsz = delays.size() tgt_len, bsz = delays.size()
@ -163,13 +173,18 @@ class DifferentiableAverageLagging(LatencyMetric):
new_delays[i] = torch.cat( new_delays[i] = torch.cat(
[ [
new_delays[i - 1].unsqueeze(0) + 1 / gamma, new_delays[i - 1].unsqueeze(0) + 1 / gamma,
delays[i].unsqueeze(0) delays[i].unsqueeze(0),
], ],
dim=0 dim=0,
).max(dim=0)[0] ).max(dim=0)[0]
DAL = ( DAL = (
new_delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma new_delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
) )
if target_padding_mask is not None: if target_padding_mask is not None:
DAL = DAL.masked_fill(target_padding_mask, 0) DAL = DAL.masked_fill(target_padding_mask, 0)
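
A sketch of the DAL recurrence above: delays are first forced to grow by at least 1/gamma per target step, then averaged like AL but over all positions:

delays, src_len = [1.0, 1.0, 3.0], 3
gamma = len(delays) / src_len    # 1.0
new_delays = [delays[0]]
for d in delays[1:]:
    new_delays.append(max(d, new_delays[-1] + 1 / gamma))
# new_delays == [1.0, 2.0, 3.0]
dal = sum(nd - i / gamma for i, nd in enumerate(new_delays)) / len(delays)
assert dal == 1.0
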
@ -186,7 +201,7 @@ class LatencyMetricVariance(LatencyMetric):
src_lens, src_lens,
target_padding_mask=None, target_padding_mask=None,
batch_first: bool = True, batch_first: bool = True,
start_from_zero: bool = True start_from_zero: bool = True,
): ):
assert batch_first assert batch_first
assert len(delays.size()) == 3 assert len(delays.size()) == 3
@ -256,25 +271,21 @@ class LatencyInference(object):
src_lens = src_lens src_lens = src_lens
delays = ( delays = monotonic_step.view(
monotonic_step monotonic_step.size(0), -1, monotonic_step.size(-1)
.view(monotonic_step.size(0), -1, monotonic_step.size(-1)) ).max(dim=1)[0]
.max(dim=1)[0]
)
delays = ( delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as(
delays.masked_fill(delays >= src_lens, 0) delays
+ (src_lens - 1) ).masked_fill(delays < src_lens, 0)
.expand_as(delays)
.masked_fill(delays < src_lens, 0)
)
return_dict = {} return_dict = {}
for key, func in self.metric_calculator.items(): for key, func in self.metric_calculator.items():
return_dict[key] = func( return_dict[key] = func(
delays.float(), src_lens.float(), delays.float(),
src_lens.float(),
target_padding_mask=None, target_padding_mask=None,
batch_first=True, batch_first=True,
start_from_zero=True start_from_zero=True,
).t() ).t()
return return_dict return return_dict
@ -282,8 +293,13 @@ class LatencyInference(object):
class LatencyTraining(object): class LatencyTraining(object):
def __init__( def __init__(
self, avg_weight, var_weight, avg_type, var_type, self,
stay_on_last_token, average_method, avg_weight,
var_weight,
avg_type,
var_type,
stay_on_last_token,
average_method,
): ):
self.avg_weight = avg_weight self.avg_weight = avg_weight
self.var_weight = var_weight self.var_weight = var_weight
@ -319,17 +335,12 @@ class LatencyTraining(object):
attention = attention.view(-1, tgt_len, src_len) attention = attention.view(-1, tgt_len, src_len)
if not self.stay_on_last_token: if not self.stay_on_last_token:
residual_attention = \ residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
1 - attention[:, :, :-1].sum(dim=2, keepdim=True) attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2)
attention = torch.cat(
[attention[:, :, :-1], residual_attention],
dim=2
)
# bsz * num_heads_x_num_layers, tgt_len, src_len for MMA # bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
steps = ( steps = (
torch torch.arange(1, 1 + src_len)
.arange(1, 1 + src_len)
.unsqueeze(0) .unsqueeze(0)
.unsqueeze(1) .unsqueeze(1)
.expand_as(attention) .expand_as(attention)
@ -355,15 +366,12 @@ class LatencyTraining(object):
src_lens = src_lens.view(-1, 1) src_lens = src_lens.view(-1, 1)
# bsz * num_heads_num_layers, tgt_len, src_len # bsz * num_heads_num_layers, tgt_len, src_len
expected_delays = (steps * attention).sum(dim=2).view( expected_delays = (
bsz, num_heads_x_layers, tgt_len (steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len)
) )
if target_padding_mask is not None: if target_padding_mask is not None:
expected_delays.masked_fill_( expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0)
target_padding_mask.unsqueeze(1),
0
)
return expected_delays, src_lens return expected_delays, src_lens
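
With source steps numbered 1..src_len as above, the expected delay of a target step is its attention-weighted mean source position, sum_j j * a_ij. A scalar check:

attention = [0.1, 0.2, 0.7]   # one target step over src_len = 3
steps = [1, 2, 3]
expected_delay = sum(s * a for s, a in zip(steps, attention))
assert abs(expected_delay - 2.6) < 1e-9
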
@ -371,8 +379,7 @@ class LatencyTraining(object):
bsz, num_heads_x_layers, tgt_len = expected_delays.size() bsz, num_heads_x_layers, tgt_len = expected_delays.size()
target_padding_mask = ( target_padding_mask = (
target_padding_mask target_padding_mask.unsqueeze(1)
.unsqueeze(1)
.expand_as(expected_delays) .expand_as(expected_delays)
.contiguous() .contiguous()
.view(-1, tgt_len) .view(-1, tgt_len)
@ -396,8 +403,11 @@ class LatencyTraining(object):
if self.avg_weight > 0.0: if self.avg_weight > 0.0:
if self.avg_type in self.metric_calculator: if self.avg_type in self.metric_calculator:
average_delays = self.metric_calculator[self.avg_type]( average_delays = self.metric_calculator[self.avg_type](
expected_delays, src_lens, target_padding_mask, expected_delays,
batch_first=True, start_from_zero=False src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
) )
else: else:
raise RuntimeError(f"{self.avg_type} is not supported.") raise RuntimeError(f"{self.avg_type} is not supported.")
@ -408,12 +418,17 @@ class LatencyTraining(object):
return 0.0 return 0.0
def var_loss(self, expected_delays, src_lens, target_padding_mask): def var_loss(self, expected_delays, src_lens, target_padding_mask):
src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[:, :1] src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[
:, :1
]
if self.var_weight > 0.0: if self.var_weight > 0.0:
if self.var_type in self.variance_calculator: if self.var_type in self.variance_calculator:
variance_delays = self.variance_calculator[self.var_type]( variance_delays = self.variance_calculator[self.var_type](
expected_delays, src_lens, target_padding_mask, expected_delays,
batch_first=True, start_from_zero=False src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
) )
else: else:
raise RuntimeError(f"{self.var_type} is not supported.") raise RuntimeError(f"{self.var_type} is not supported.")


@ -1 +1 @@
from . import tasks, criterions, models # noqa from . import criterions, models, tasks # noqa


@ -6,9 +6,9 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
from examples.speech_recognition.data.replabels import pack_replabels
from fairseq import utils from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.criterions import FairseqCriterion, register_criterion
from examples.speech_recognition.data.replabels import pack_replabels
@register_criterion("asg_loss") @register_criterion("asg_loss")


@ -5,6 +5,7 @@
from .asr_dataset import AsrDataset from .asr_dataset import AsrDataset
__all__ = [ __all__ = [
'AsrDataset', "AsrDataset",
] ]


@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import os import os
import numpy as np import numpy as np
from fairseq.data import FairseqDataset from fairseq.data import FairseqDataset
@ -30,16 +31,22 @@ class AsrDataset(FairseqDataset):
""" """
def __init__( def __init__(
self, aud_paths, aud_durations_ms, tgt, self,
tgt_dict, ids, speakers, aud_paths,
num_mel_bins=80, frame_length=25.0, frame_shift=10.0 aud_durations_ms,
tgt,
tgt_dict,
ids,
speakers,
num_mel_bins=80,
frame_length=25.0,
frame_shift=10.0,
): ):
assert frame_length > 0 assert frame_length > 0
assert frame_shift > 0 assert frame_shift > 0
assert all(x > frame_length for x in aud_durations_ms) assert all(x > frame_length for x in aud_durations_ms)
self.frame_sizes = [ self.frame_sizes = [
int(1 + (d - frame_length) / frame_shift) int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
for d in aud_durations_ms
] ]
assert len(aud_paths) > 0 assert len(aud_paths) > 0
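
The frame-size formula above is the standard sliding-window count. For example, a 1000 ms utterance with a 25 ms window and a 10 ms shift yields:

duration_ms, frame_length, frame_shift = 1000, 25.0, 10.0
n_frames = int(1 + (duration_ms - frame_length) / frame_shift)
assert n_frames == 98  # int(98.5)
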
@ -57,13 +64,17 @@ class AsrDataset(FairseqDataset):
self.frame_shift = frame_shift self.frame_shift = frame_shift
self.s2s_collater = Seq2SeqCollater( self.s2s_collater = Seq2SeqCollater(
0, 1, pad_index=self.tgt_dict.pad(), 0,
eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True 1,
pad_index=self.tgt_dict.pad(),
eos_index=self.tgt_dict.eos(),
move_eos_to_beginning=True,
) )
def __getitem__(self, index): def __getitem__(self, index):
import torchaudio import torchaudio
import torchaudio.compliance.kaldi as kaldi import torchaudio.compliance.kaldi as kaldi
tgt_item = self.tgt[index] if self.tgt is not None else None tgt_item = self.tgt[index] if self.tgt is not None else None
path = self.aud_paths[index] path = self.aud_paths[index]
@ -74,7 +85,7 @@ class AsrDataset(FairseqDataset):
sound, sound,
num_mel_bins=self.num_mel_bins, num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length, frame_length=self.frame_length,
frame_shift=self.frame_shift frame_shift=self.frame_shift,
) )
output_cmvn = data_utils.apply_mv_norm(output) output_cmvn = data_utils.apply_mv_norm(output)


@ -12,18 +12,18 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import numpy as np
import torch import torch
from fairseq.data import data_utils as fairseq_data_utils from fairseq.data import data_utils as fairseq_data_utils
class Seq2SeqCollater(object): class Seq2SeqCollater(object):
""" """
Implements collate function mainly for seq2seq tasks Implements collate function mainly for seq2seq tasks
This expects each sample to contain feature (src_tokens) and This expects each sample to contain feature (src_tokens) and
targets. targets.
This collator is also used for aligned training task. This collator is also used for aligned training task.
""" """
def __init__( def __init__(


@ -6,52 +6,74 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
from collections import namedtuple
import concurrent.futures
from itertools import chain
import argparse import argparse
import os import concurrent.futures
import json import json
import sentencepiece as spm
import multiprocessing import multiprocessing
import os
from collections import namedtuple
from itertools import chain
import sentencepiece as spm
from fairseq.data import Dictionary from fairseq.data import Dictionary
MILLISECONDS_TO_SECONDS = 0.001 MILLISECONDS_TO_SECONDS = 0.001
def process_sample(aud_path, lable, utt_id, sp, tgt_dict): def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
import torchaudio import torchaudio
input = {} input = {}
output = {} output = {}
si, ei = torchaudio.info(aud_path) si, ei = torchaudio.info(aud_path)
input["length_ms"] = int(si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS) input["length_ms"] = int(
si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
)
input["path"] = aud_path input["path"] = aud_path
token = " ".join(sp.EncodeAsPieces(lable)) token = " ".join(sp.EncodeAsPieces(lable))
ids = tgt_dict.encode_line(token, append_eos=False) ids = tgt_dict.encode_line(token, append_eos=False)
output["text"] = lable output["text"] = lable
output["token"] = token output["token"] = token
output["tokenid"] = ', '.join(map(str, [t.tolist() for t in ids])) output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
return {utt_id: {"input": input, "output": output}} return {utt_id: {"input": input, "output": output}}
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--audio-dirs", nargs="+", default=['-'], required=True, parser.add_argument(
help="input directories with audio files") "--audio-dirs",
parser.add_argument("--labels", required=True, nargs="+",
help="aggregated input labels with format <ID LABEL> per line", default=["-"],
type=argparse.FileType('r', encoding='UTF-8')) required=True,
parser.add_argument("--spm-model", required=True, help="input directories with audio files",
help="sentencepiece model to use for encoding", )
type=argparse.FileType('r', encoding='UTF-8')) parser.add_argument(
parser.add_argument("--dictionary", required=True, "--labels",
help="file to load fairseq dictionary from", required=True,
type=argparse.FileType('r', encoding='UTF-8')) help="aggregated input labels with format <ID LABEL> per line",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument(
"--spm-model",
required=True,
help="sentencepiece model to use for encoding",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument(
"--dictionary",
required=True,
help="file to load fairseq dictionary from",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav") parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
parser.add_argument("--output", required=True, type=argparse.FileType('w'), parser.add_argument(
help="path to save json output") "--output",
required=True,
type=argparse.FileType("w"),
help="path to save json output",
)
args = parser.parse_args() args = parser.parse_args()
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
@ -64,15 +86,17 @@ def main():
(utt_id, label) = line.split(" ", 1) (utt_id, label) = line.split(" ", 1)
labels[utt_id] = label labels[utt_id] = label
if len(labels) == 0: if len(labels) == 0:
raise Exception('No labels found in ', args.labels_path) raise Exception("No labels found in ", args.labels_path)
Sample = namedtuple('Sample', 'aud_path utt_id') Sample = namedtuple("Sample", "aud_path utt_id")
samples = [] samples = []
for path, _, files in chain.from_iterable(os.walk(path) for path in args.audio_dirs): for path, _, files in chain.from_iterable(
os.walk(path) for path in args.audio_dirs
):
for f in files: for f in files:
if f.endswith(args.audio_format): if f.endswith(args.audio_format):
if len(os.path.splitext(f)) != 2: if len(os.path.splitext(f)) != 2:
raise Exception('Expect <utt_id.extension> file name. Got: ', f) raise Exception("Expect <utt_id.extension> file name. Got: ", f)
utt_id = os.path.splitext(f)[0] utt_id = os.path.splitext(f)[0]
if utt_id not in labels: if utt_id not in labels:
continue continue
@ -81,12 +105,17 @@ def main():
utts = {} utts = {}
num_cpu = multiprocessing.cpu_count() num_cpu = multiprocessing.cpu_count()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
future_to_sample = {executor.submit(process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict): s for s in samples} future_to_sample = {
executor.submit(
process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
): s
for s in samples
}
for future in concurrent.futures.as_completed(future_to_sample): for future in concurrent.futures.as_completed(future_to_sample):
try: try:
data = future.result() data = future.result()
except Exception as exc: except Exception as exc:
print('generated an exception: ', exc) print("generated an exception: ", exc)
else: else:
utts.update(data) utts.update(data)
json.dump({"utts": utts}, args.output, indent=4) json.dump({"utts": utts}, args.output, indent=4)


@ -8,17 +8,17 @@
Run inference for pre-processed data with a trained model. Run inference for pre-processed data with a trained model.
""" """
import editdistance
import logging import logging
import math import math
import os import os
import sys import sys
import editdistance
import numpy as np import numpy as np
import torch import torch
from fairseq import checkpoint_utils, options, progress_bar, utils, tasks from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.data.data_utils import post_process from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter
logging.basicConfig() logging.basicConfig()
@ -52,10 +52,12 @@ output units",
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level" "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
) )
parser.add_argument( parser.add_argument(
"--w2l-decoder", choices=["viterbi", "kenlm", "fairseqlm"], help="use a w2l decoder" "--w2l-decoder",
choices=["viterbi", "kenlm", "fairseqlm"],
help="use a w2l decoder",
) )
parser.add_argument("--lexicon", help="lexicon for w2l decoder") parser.add_argument("--lexicon", help="lexicon for w2l decoder")
parser.add_argument("--unit-lm", action='store_true', help="if using a unit lm") parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder") parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
parser.add_argument("--beam-threshold", type=float, default=25.0) parser.add_argument("--beam-threshold", type=float, default=25.0)
parser.add_argument("--beam-size-token", type=float, default=100) parser.add_argument("--beam-size-token", type=float, default=100)
@ -87,10 +89,10 @@ def check_args(args):
# assert args.path is not None, "--path required for generation!" # assert args.path is not None, "--path required for generation!"
# assert args.results_path is not None, "--results_path required for generation!" # assert args.results_path is not None, "--results_path required for generation!"
assert ( assert (
not args.sampling or args.nbest == args.beam not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam" ), "--sampling requires --nbest to be equal to --beam"
assert ( assert (
args.replace_unk is None or args.raw_text args.replace_unk is None or args.raw_text
), "--replace-unk requires a raw text dataset (--raw-text)" ), "--replace-unk requires a raw text dataset (--raw-text)"
@ -110,7 +112,7 @@ def get_dataset_itr(args, task, models):
def process_predictions( def process_predictions(
args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
): ):
for hypo in hypos[: min(len(hypos), args.nbest)]: for hypo in hypos[: min(len(hypos), args.nbest)]:
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu()) hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
@ -122,16 +124,25 @@ def process_predictions(
if res_files is not None: if res_files is not None:
print( print(
"{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"] "{} ({}-{})".format(hyp_pieces, speaker, id),
file=res_files["hypo.units"],
)
print(
"{} ({}-{})".format(hyp_words, speaker, id),
file=res_files["hypo.words"],
) )
print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"])
tgt_pieces = tgt_dict.string(target_tokens) tgt_pieces = tgt_dict.string(target_tokens)
tgt_words = post_process(tgt_pieces, args.remove_bpe) tgt_words = post_process(tgt_pieces, args.remove_bpe)
if res_files is not None: if res_files is not None:
print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"]) print(
print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]) "{} ({}-{})".format(tgt_pieces, speaker, id),
file=res_files["ref.units"],
)
print(
"{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
)
# only score top hypothesis # only score top hypothesis
if not args.quiet: if not args.quiet:
logger.debug("HYPO:" + hyp_words) logger.debug("HYPO:" + hyp_words)
@ -146,7 +157,7 @@ def process_predictions(
def prepare_result_files(args): def prepare_result_files(args):
def get_res_file(file_prefix): def get_res_file(file_prefix):
if args.num_shards > 1: if args.num_shards > 1:
file_prefix = f'{args.shard_id}_{file_prefix}' file_prefix = f"{args.shard_id}_{file_prefix}"
path = os.path.join( path = os.path.join(
args.results_path, args.results_path,
"{}-{}-{}.txt".format( "{}-{}-{}.txt".format(
@ -166,15 +177,17 @@ def prepare_result_files(args):
} }
def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=None, model_state=None): def load_models_and_criterions(
filenames, data_path, arg_overrides=None, task=None, model_state=None
):
models = [] models = []
criterions = [] criterions = []
if arg_overrides is None: if arg_overrides is None:
arg_overrides = {} arg_overrides = {}
arg_overrides['wer_args'] = None arg_overrides["wer_args"] = None
arg_overrides['data'] = data_path arg_overrides["data"] = data_path
if filenames is None: if filenames is None:
assert model_state is not None assert model_state is not None
@ -205,8 +218,7 @@ def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=No
def optimize_models(args, use_cuda, models): def optimize_models(args, use_cuda, models):
"""Optimize ensemble for generation """Optimize ensemble for generation"""
"""
for model in models: for model in models:
model.make_generation_fast_( model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
@ -229,7 +241,7 @@ class ExistingEmissionsDecoder(object):
emissions = np.stack(self.emissions[ids]) emissions = np.stack(self.emissions[ids])
except: except:
print([x.shape for x in self.emissions[ids]]) print([x.shape for x in self.emissions[ids]])
raise Exception('invalid sizes') raise Exception("invalid sizes")
emissions = torch.from_numpy(emissions) emissions = torch.from_numpy(emissions)
return self.decoder.decode(emissions) return self.decoder.decode(emissions)
@ -300,7 +312,9 @@ def main(args, task=None, model_state=None):
return W2lFairseqLMDecoder(args, task.target_dictionary) return W2lFairseqLMDecoder(args, task.target_dictionary)
else: else:
print('only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment') print(
"only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
)
# please do not touch this unless you test both generate.py and infer.py with audio_pretraining task # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
generator = build_generator(args) generator = build_generator(args)
@ -361,7 +375,11 @@ def main(args, task=None, model_state=None):
encoder_out = models[0](**sample["net_input"]) encoder_out = models[0](**sample["net_input"])
feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy() feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
for i, id in enumerate(sample["id"]): for i, id in enumerate(sample["id"]):
padding = encoder_out["encoder_padding_mask"][i].cpu().numpy() if encoder_out["encoder_padding_mask"] is not None else None padding = (
encoder_out["encoder_padding_mask"][i].cpu().numpy()
if encoder_out["encoder_padding_mask"] is not None
else None
)
features[id.item()] = (feat[i], padding) features[id.item()] = (feat[i], padding)
continue continue
hypos = task.inference_step(generator, models, sample, prefix_tokens) hypos = task.inference_step(generator, models, sample, prefix_tokens)
@ -372,20 +390,31 @@ def main(args, task=None, model_state=None):
speaker = None speaker = None
# id = task.dataset(args.gen_subset).ids[int(sample_id)] # id = task.dataset(args.gen_subset).ids[int(sample_id)]
id = sample_id id = sample_id
toks = sample["target"][i, :] if 'target_label' not in sample else sample["target_label"][i, :] toks = (
target_tokens = ( sample["target"][i, :]
utils.strip_pad(toks, tgt_dict.pad()).int().cpu() if "target_label" not in sample
else sample["target_label"][i, :]
) )
target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
# Process top predictions # Process top predictions
errs, length = process_predictions( errs, length = process_predictions(
args, hypos[i], None, tgt_dict, target_tokens, res_files, speaker, id args,
hypos[i],
None,
tgt_dict,
target_tokens,
res_files,
speaker,
id,
) )
errs_t += errs errs_t += errs
lengths_t += length lengths_t += length
wps_meter.update(num_generated_tokens) wps_meter.update(num_generated_tokens)
t.log({"wps": round(wps_meter.avg)}) t.log({"wps": round(wps_meter.avg)})
num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel() num_sentences += (
sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
)
wer = None wer = None
if args.dump_emissions: if args.dump_emissions:
@ -413,7 +442,7 @@ def main(args, task=None, model_state=None):
gen_timer.sum, gen_timer.sum,
num_sentences / gen_timer.sum, num_sentences / gen_timer.sum,
1.0 / gen_timer.avg, 1.0 / gen_timer.avg,
) )
) )
logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam)) logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
return task, wer return task, wer
@ -424,6 +453,7 @@ def make_parser():
parser = add_asr_eval_argument(parser) parser = add_asr_eval_argument(parser)
return parser return parser
def cli_main(): def cli_main():
parser = make_parser() parser = make_parser()
args = options.parse_args_and_arch(parser) args = options.parse_args_and_arch(parser)


@ -1,7 +1,8 @@
import importlib import importlib
import os import os
for file in os.listdir(os.path.dirname(__file__)): for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'): if file.endswith(".py") and not file.startswith("_"):
model_name = file[:file.find('.py')] model_name = file[: file.find(".py")]
importlib.import_module('examples.speech_recognition.models.' + model_name) importlib.import_module("examples.speech_recognition.models." + model_name)


@ -9,18 +9,22 @@ from collections.abc import Iterable
import torch import torch
import torch.nn as nn import torch.nn as nn
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq import utils from fairseq import utils
from fairseq.models import ( from fairseq.models import (
FairseqEncoder, FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqEncoderModel, FairseqEncoderModel,
FairseqIncrementalDecoder, FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
register_model, register_model,
register_model_architecture, register_model_architecture,
) )
from fairseq.modules import LinearizedConvolution from fairseq.modules import (
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask LinearizedConvolution,
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VGGBlock TransformerDecoderLayer,
TransformerEncoderLayer,
VGGBlock,
)
@register_model("asr_vggtransformer") @register_model("asr_vggtransformer")
@ -29,6 +33,7 @@ class VGGTransformerModel(FairseqEncoderDecoderModel):
Transformers with convolutional context for ASR Transformers with convolutional context for ASR
https://arxiv.org/abs/1904.11660 https://arxiv.org/abs/1904.11660
""" """
def __init__(self, encoder, decoder): def __init__(self, encoder, decoder):
super().__init__(encoder, decoder) super().__init__(encoder, decoder)
@ -602,18 +607,22 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
if conv_config[-1][0] != transformer_config[0][0]: if conv_config[-1][0] != transformer_config[0][0]:
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0])) self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
self.layers.append(TransformerDecoderLayer( self.layers.append(
prepare_transformer_decoder_params(*transformer_config[0]) TransformerDecoderLayer(
)) prepare_transformer_decoder_params(*transformer_config[0])
)
)
for i in range(1, len(transformer_config)): for i in range(1, len(transformer_config)):
if transformer_config[i - 1][0] != transformer_config[i][0]: if transformer_config[i - 1][0] != transformer_config[i][0]:
self.layers.append( self.layers.append(
Linear(transformer_config[i - 1][0], transformer_config[i][0]) Linear(transformer_config[i - 1][0], transformer_config[i][0])
) )
self.layers.append(TransformerDecoderLayer( self.layers.append(
prepare_transformer_decoder_params(*transformer_config[i]) TransformerDecoderLayer(
)) prepare_transformer_decoder_params(*transformer_config[i])
)
)
self.fc_out = Linear(transformer_config[-1][0], vocab_size) self.fc_out = Linear(transformer_config[-1][0], vocab_size)
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
@ -713,6 +722,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
x = x.transpose(0, 1) x = x.transpose(0, 1)
return x return x
@register_model("asr_vggtransformer_encoder") @register_model("asr_vggtransformer_encoder")
class VGGTransformerEncoderModel(FairseqEncoderModel): class VGGTransformerEncoderModel(FairseqEncoderModel):
def __init__(self, encoder): def __init__(self, encoder):


@ -10,7 +10,6 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq.models import ( from fairseq.models import (
FairseqEncoder, FairseqEncoder,
FairseqEncoderModel, FairseqEncoderModel,


@ -1,7 +1,8 @@
import importlib import importlib
import os import os
for file in os.listdir(os.path.dirname(__file__)): for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'): if file.endswith(".py") and not file.startswith("_"):
task_name = file[:file.find('.py')] task_name = file[: file.find(".py")]
importlib.import_module('examples.speech_recognition.tasks.' + task_name) importlib.import_module("examples.speech_recognition.tasks." + task_name)


@ -9,10 +9,10 @@ import re
import sys import sys
import torch import torch
from fairseq.data import Dictionary
from fairseq.tasks import register_task, LegacyFairseqTask
from examples.speech_recognition.data import AsrDataset from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol from examples.speech_recognition.data.replabels import replabel_symbol
from fairseq.data import Dictionary
from fairseq.tasks import LegacyFairseqTask, register_task
def get_asr_dataset_from_json(data_json_path, tgt_dict): def get_asr_dataset_from_json(data_json_path, tgt_dict):
@ -78,10 +78,20 @@ class SpeechRecognitionTask(LegacyFairseqTask):
parser.add_argument( parser.add_argument(
"--silence-token", default="\u2581", help="token for silence (used by w2l)" "--silence-token", default="\u2581", help="token for silence (used by w2l)"
) )
parser.add_argument('--max-source-positions', default=sys.maxsize, type=int, metavar='N', parser.add_argument(
help='max number of frames in the source sequence') "--max-source-positions",
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', default=sys.maxsize,
help='max number of tokens in the target sequence') type=int,
metavar="N",
help="max number of frames in the source sequence",
)
parser.add_argument(
"--max-target-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)
def __init__(self, args, tgt_dict): def __init__(self, args, tgt_dict):
super().__init__(args) super().__init__(args)


@ -9,16 +9,18 @@
Wav2letter decoders. Wav2letter decoders.
""" """
from collections import namedtuple, deque
import gc import gc
import itertools as it import itertools as it
import numpy as np
import torch
import os.path as osp import os.path as osp
import warnings import warnings
from collections import deque, namedtuple
import numpy as np
import torch
from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks from fairseq import tasks
from fairseq.utils import apply_to_sample from fairseq.utils import apply_to_sample
from examples.speech_recognition.data.replabels import unpack_replabels
try: try:
from wav2letter.common import create_word_dict, load_words from wav2letter.common import create_word_dict, load_words


@ -4,66 +4,76 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from multiprocessing import cpu_count import csv
import os import os
import os.path as op import os.path as op
from glob import glob
import zipfile import zipfile
import csv
from functools import reduce from functools import reduce
from typing import Dict, Any, List from glob import glob
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank from multiprocessing import cpu_count
from typing import Any, Dict, List
import sentencepiece as sp
from tqdm import tqdm
import numpy as np import numpy as np
import sentencepiece as sp
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
from tqdm import tqdm
UNK_TOKEN, UNK_TOKEN_ID = '<unk>', 3
BOS_TOKEN, BOS_TOKEN_ID = '<s>', 0 UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
EOS_TOKEN, EOS_TOKEN_ID = '</s>', 2 BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
PAD_TOKEN, PAD_TOKEN_ID = '<pad>', 1 EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
def gen_vocab( def gen_vocab(
input_path: str, output_path_prefix: str, model_type='bpe', input_path: str,
vocab_size=1000, output_path_prefix: str,
model_type="bpe",
vocab_size=1000,
): ):
# Train SentencePiece Model # Train SentencePiece Model
arguments = [ arguments = [
f'--input={input_path}', f"--input={input_path}",
f'--model_prefix={output_path_prefix}', f"--model_prefix={output_path_prefix}",
f'--model_type={model_type}', f"--model_type={model_type}",
f'--vocab_size={vocab_size}', f"--vocab_size={vocab_size}",
'--character_coverage=1.0', "--character_coverage=1.0",
f'--num_threads={cpu_count()}', f"--num_threads={cpu_count()}",
f'--unk_id={UNK_TOKEN_ID}', f"--unk_id={UNK_TOKEN_ID}",
f'--bos_id={BOS_TOKEN_ID}', f"--bos_id={BOS_TOKEN_ID}",
f'--eos_id={EOS_TOKEN_ID}', f"--eos_id={EOS_TOKEN_ID}",
f'--pad_id={PAD_TOKEN_ID}' f"--pad_id={PAD_TOKEN_ID}",
] ]
sp.SentencePieceTrainer.Train(' '.join(arguments)) sp.SentencePieceTrainer.Train(" ".join(arguments))
# Export fairseq dictionary # Export fairseq dictionary
spm = sp.SentencePieceProcessor() spm = sp.SentencePieceProcessor()
spm.Load(output_path_prefix + '.model') spm.Load(output_path_prefix + ".model")
vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())} vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
assert vocab.get(UNK_TOKEN_ID) == UNK_TOKEN and \ assert (
vocab.get(PAD_TOKEN_ID) == PAD_TOKEN and \ vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
vocab.get(BOS_TOKEN_ID) == BOS_TOKEN and \ and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
vocab.get(EOS_TOKEN_ID) == EOS_TOKEN and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
)
vocab = { vocab = {
i: s for i, s in vocab.items() i: s
for i, s in vocab.items()
if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN} if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
} }
with open(output_path_prefix + '.txt', 'w') as f_out: with open(output_path_prefix + ".txt", "w") as f_out:
for _, s in sorted(vocab.items(), key=lambda x: x[0]): for _, s in sorted(vocab.items(), key=lambda x: x[0]):
f_out.write(f'{s} 1\n') f_out.write(f"{s} 1\n")
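
A hypothetical invocation of gen_vocab (file names are illustrative): it trains the SentencePiece model on the transcript file and writes both spm_bpe_1000.model and the fairseq-style dictionary spm_bpe_1000.txt:

gen_vocab("train_text.txt", "spm_bpe_1000", model_type="bpe", vocab_size=1000)
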
def extract_fbank_features(waveform, sample_rate, output_path=None, def extract_fbank_features(
n_mel_bins=80, apply_utterance_cmvn=True, waveform,
overwrite=False): sample_rate,
output_path=None,
n_mel_bins=80,
apply_utterance_cmvn=True,
overwrite=False,
):
if output_path is not None and op.exists(output_path) and not overwrite: if output_path is not None and op.exists(output_path) and not overwrite:
return return
@ -74,8 +84,10 @@ def extract_fbank_features(waveform, sample_rate, output_path=None,
if features is None: if features is None:
features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins) features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
if features is None: if features is None:
raise ImportError('Please install pyKaldi or torchaudio to enable ' raise ImportError(
'online filterbank feature extraction') "Please install pyKaldi or torchaudio to enable "
"online filterbank feature extraction"
)
if apply_utterance_cmvn: if apply_utterance_cmvn:
cmvn = UtteranceCMVN(norm_means=True, norm_vars=True) cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
@ -89,8 +101,8 @@ def extract_fbank_features(waveform, sample_rate, output_path=None,
def create_zip(data_root, zip_path): def create_zip(data_root, zip_path):
cwd = os.path.abspath(os.curdir) cwd = os.path.abspath(os.curdir)
os.chdir(data_root) os.chdir(data_root)
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_STORED) as f: with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
for filename in tqdm(glob('*.npy')): for filename in tqdm(glob("*.npy")):
f.write(filename) f.write(filename)
os.chdir(cwd) os.chdir(cwd)
@ -101,69 +113,80 @@ def is_npy_data(data: bytes) -> bool:
def get_zip_manifest(zip_root, zip_filename): def get_zip_manifest(zip_root, zip_filename):
zip_path = op.join(zip_root, zip_filename) zip_path = op.join(zip_root, zip_filename)
with zipfile.ZipFile(zip_path, mode='r') as f: with zipfile.ZipFile(zip_path, mode="r") as f:
info = f.infolist() info = f.infolist()
manifest = {} manifest = {}
for i in tqdm(info): for i in tqdm(info):
utt_id = op.splitext(i.filename)[0] utt_id = op.splitext(i.filename)[0]
offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
manifest[utt_id] = f'{zip_filename}:{offset}:{file_size}' manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}"
with open(zip_path, 'rb') as f: with open(zip_path, "rb") as f:
f.seek(offset) f.seek(offset)
data = f.read(file_size) data = f.read(file_size)
assert len(data) > 1 and is_npy_data(data) assert len(data) > 1 and is_npy_data(data)
return manifest return manifest
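
The "+ 30" above is the fixed size of a ZIP local file header; with ZIP_STORED entries and an empty extra field, the raw bytes of each stored .npy follow immediately after the header plus the file name, so they can be read back with a plain seek. A sketch with a hypothetical archive name:

import zipfile

zip_path = "fbank80.zip"  # hypothetical archive of stored .npy features
with zipfile.ZipFile(zip_path) as zf:
    info = zf.infolist()[0]
    offset = info.header_offset + 30 + len(info.filename)

with open(zip_path, "rb") as f:
    f.seek(offset)
    data = f.read(info.file_size)  # the entry's raw (uncompressed) bytes
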
def gen_config_yaml(data_root, spm_filename, yaml_filename='config.yaml', def gen_config_yaml(
specaugment_policy='lb'): data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb"
assert specaugment_policy in {'lb', 'ld'} ):
assert specaugment_policy in {"lb", "ld"}
data_root = op.abspath(data_root) data_root = op.abspath(data_root)
writer = S2TDataConfigWriter(op.join(data_root, yaml_filename)) writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
writer.set_audio_root(op.abspath(data_root)) writer.set_audio_root(op.abspath(data_root))
writer.set_vocab_filename(spm_filename.replace(".model", ".txt")) writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
writer.set_input_channels(1) writer.set_input_channels(1)
writer.set_input_feat_per_channel(80) writer.set_input_feat_per_channel(80)
if specaugment_policy == 'lb': if specaugment_policy == "lb":
writer.set_specaugment_lb_policy() writer.set_specaugment_lb_policy()
else: else:
writer.set_specaugment_ld_policy() writer.set_specaugment_ld_policy()
writer.set_bpe_tokenizer( writer.set_bpe_tokenizer(
{'bpe': 'sentencepiece', {
'sentencepiece_model': op.join(data_root, spm_filename)} "bpe": "sentencepiece",
"sentencepiece_model": op.join(data_root, spm_filename),
}
) )
writer.set_feature_transforms('_train', ['specaugment']) writer.set_feature_transforms("_train", ["specaugment"])
writer.flush() writer.flush()
def save_df_to_tsv(dataframe, path): def save_df_to_tsv(dataframe, path):
dataframe.to_csv(path, sep="\t", header=True, index=False, encoding="utf-8", dataframe.to_csv(
escapechar='\\', quoting=csv.QUOTE_NONE) path,
sep="\t",
header=True,
index=False,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
)
def filter_manifest_df(df, is_train_split=False, extra_filters=None, def filter_manifest_df(
min_n_frames=5, max_n_frames=3000): df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
filters = { filters = {
'no speech': df['audio'] == '', "no speech": df["audio"] == "",
f'short speech (<{min_n_frames} frames)': df['n_frames'] < min_n_frames, f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
'empty sentence': df['tgt_text'] == '', "empty sentence": df["tgt_text"] == "",
} }
if is_train_split: if is_train_split:
filters[f'long speech (>{max_n_frames} frames)'] = \ filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
df['n_frames'] > max_n_frames
if extra_filters is not None: if extra_filters is not None:
filters.update(extra_filters) filters.update(extra_filters)
invalid = reduce(lambda x, y: x | y, filters.values()) invalid = reduce(lambda x, y: x | y, filters.values())
valid = ~invalid valid = ~invalid
print( print(
'| ' + ', '.join(f'{n}: {f.sum()}' for n, f in filters.items()) + "| "
f', total {invalid.sum()} filtered, {valid.sum()} remained.' + ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
+ f", total {invalid.sum()} filtered, {valid.sum()} remained."
) )
return df[valid] return df[valid]
class S2TDataConfigWriter(object): class S2TDataConfigWriter(object):
DEFAULT_VOCAB_FILENAME = 'dict.txt' DEFAULT_VOCAB_FILENAME = "dict.txt"
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80 DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
DEFAULT_INPUT_CHANNELS = 1 DEFAULT_INPUT_CHANNELS = 1
@ -171,48 +194,69 @@ class S2TDataConfigWriter(object):
try: try:
import yaml import yaml
except ImportError: except ImportError:
print('Please install PyYAML to load YAML files for S2T data config') print("Please install PyYAML to load YAML files for S2T data config")
self.yaml = yaml self.yaml = yaml
self.yaml_path = yaml_path self.yaml_path = yaml_path
self.config = {} self.config = {}
def flush(self): def flush(self):
with open(self.yaml_path, 'w') as f: with open(self.yaml_path, "w") as f:
self.yaml.dump(self.config, f) self.yaml.dump(self.config, f)
def set_audio_root(self, audio_root=''): def set_audio_root(self, audio_root=""):
self.config['audio_root'] = audio_root self.config["audio_root"] = audio_root
def set_vocab_filename(self, vocab_filename='dict.txt'): def set_vocab_filename(self, vocab_filename="dict.txt"):
self.config['vocab_filename'] = vocab_filename self.config["vocab_filename"] = vocab_filename
def set_specaugment(self, time_wrap_w: int, freq_mask_n: int, def set_specaugment(
freq_mask_f: int, time_mask_n: int, time_mask_t: int, self,
time_mask_p: float): time_wrap_w: int,
self.config['specaugment'] = { freq_mask_n: int,
'time_wrap_W': time_wrap_w, 'freq_mask_N': freq_mask_n, freq_mask_f: int,
'freq_mask_F': freq_mask_f, 'time_mask_N': time_mask_n, time_mask_n: int,
'time_mask_T': time_mask_t, 'time_mask_p': time_mask_p, time_mask_t: int,
time_mask_p: float,
):
self.config["specaugment"] = {
"time_wrap_W": time_wrap_w,
"freq_mask_N": freq_mask_n,
"freq_mask_F": freq_mask_f,
"time_mask_N": time_mask_n,
"time_mask_T": time_mask_t,
"time_mask_p": time_mask_p,
} }
def set_specaugment_lb_policy(self): def set_specaugment_lb_policy(self):
self.set_specaugment(time_wrap_w=0, freq_mask_n=1, freq_mask_f=27, self.set_specaugment(
time_mask_n=1, time_mask_t=100, time_mask_p=1.0) time_wrap_w=0,
freq_mask_n=1,
freq_mask_f=27,
time_mask_n=1,
time_mask_t=100,
time_mask_p=1.0,
)
def set_specaugment_ld_policy(self): def set_specaugment_ld_policy(self):
self.set_specaugment(time_wrap_w=0, freq_mask_n=2, freq_mask_f=27, self.set_specaugment(
time_mask_n=2, time_mask_t=100, time_mask_p=1.0) time_wrap_w=0,
freq_mask_n=2,
freq_mask_f=27,
time_mask_n=2,
time_mask_t=100,
time_mask_p=1.0,
)
def set_input_channels(self, input_channels=1): def set_input_channels(self, input_channels=1):
self.config['input_channels'] = input_channels self.config["input_channels"] = input_channels
def set_input_feat_per_channel(self, input_feat_per_channel=80): def set_input_feat_per_channel(self, input_feat_per_channel=80):
self.config['input_feat_per_channel'] = input_feat_per_channel self.config["input_feat_per_channel"] = input_feat_per_channel
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]): def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
self.config['bpe_tokenizer'] = bpe_tokenizer self.config["bpe_tokenizer"] = bpe_tokenizer
def set_feature_transforms(self, split, transforms: List[str]): def set_feature_transforms(self, split, transforms: List[str]):
if 'transforms' not in self.config: if "transforms" not in self.config:
self.config['transforms'] = {} self.config["transforms"] = {}
self.config['transforms'][split] = transforms self.config["transforms"][split] = transforms
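A hedged usage sketch of the S2TDataConfigWriter defined above (file names and paths are made up; PyYAML must actually be importable, since the except ImportError branch only prints before self.yaml = yaml runs). "time_wrap" appears to be a long-standing misspelling of SpecAugment's time-warp parameter W, so the key is reproduced verbatim:

writer = S2TDataConfigWriter("config_asr.yaml")  # hypothetical output path
writer.set_audio_root("/data/covost/fr")         # made-up data root
writer.set_vocab_filename("dict.txt")
writer.set_input_channels(1)
writer.set_input_feat_per_channel(80)
writer.set_specaugment_lb_policy()               # the "LB" policy wired above
writer.set_feature_transforms("train", ["specaugment"])
writer.flush()                                   # dump the accumulated dict to YAML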
@ -5,30 +5,35 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import argparse import argparse
import csv
import logging import logging
from tempfile import NamedTemporaryFile
import os import os
import os.path as op import os.path as op
import shutil import shutil
from typing import Tuple, Optional from tempfile import NamedTemporaryFile
import csv from typing import Optional, Tuple
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
from torch import Tensor
from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
extract_fbank_features, gen_config_yaml, filter_manifest_df
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker'] MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
class CoVoST(Dataset): class CoVoST(Dataset):
@ -44,40 +49,82 @@ class CoVoST(Dataset):
found at root path. (default: ``False``). found at root path. (default: ``False``).
""" """
CV_URL_TEMPLATE = "https://voice-prod-bundler-ee1969a6ce8178826482b88" \ CV_URL_TEMPLATE = (
"e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz" "https://voice-prod-bundler-ee1969a6ce8178826482b88"
COVOST_URL_TEMPLATE = "https://dl.fbaipublicfiles.com/covost/" \ "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz" )
COVOST_URL_TEMPLATE = (
"https://dl.fbaipublicfiles.com/covost/"
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
)
VERSIONS = {2} VERSIONS = {2}
SPLITS = ['train', 'dev', 'test'] SPLITS = ["train", "dev", "test"]
CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"} CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}
XX_EN_LANGUAGES = { XX_EN_LANGUAGES = {
1: ['fr', 'de', 'nl', 'ru', 'es', 'it', 'tr', 'fa', 'sv-SE', 'mn', 1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
'zh-CN'], 2: [
2: ['fr', 'de', 'es', 'ca', 'it', 'ru', 'zh-CN', 'pt', 'fa', 'et', 'mn', "fr",
'nl', 'tr', 'ar', 'sv-SE', 'lv', 'sl', 'ta', 'ja', 'id', 'cy'] "de",
"es",
"ca",
"it",
"ru",
"zh-CN",
"pt",
"fa",
"et",
"mn",
"nl",
"tr",
"ar",
"sv-SE",
"lv",
"sl",
"ta",
"ja",
"id",
"cy",
],
} }
EN_XX_LANGUAGES = { EN_XX_LANGUAGES = {
1: [], 1: [],
2: ['de', 'tr', 'fa', 'sv-SE', 'mn', 'zh-CN', 'cy', 'ca', 'sl', 'et', 2: [
'id', "de",
'ar', 'ta', 'lv', 'ja'] "tr",
"fa",
"sv-SE",
"mn",
"zh-CN",
"cy",
"ca",
"sl",
"et",
"id",
"ar",
"ta",
"lv",
"ja",
],
} }
def __init__( def __init__(
self, root: str, split: str, source_language: str, self,
target_language: Optional[str] = None, version: int = 2, root: str,
download: bool = False split: str,
source_language: str,
target_language: Optional[str] = None,
version: int = 2,
download: bool = False,
) -> None: ) -> None:
assert version in self.VERSIONS and split in self.SPLITS assert version in self.VERSIONS and split in self.SPLITS
assert source_language is not None assert source_language is not None
self.no_translation = (target_language is None) self.no_translation = target_language is None
if not self.no_translation: if not self.no_translation:
assert 'en' in {source_language, target_language} assert "en" in {source_language, target_language}
if source_language == 'en': if source_language == "en":
assert target_language in self.EN_XX_LANGUAGES[version] assert target_language in self.EN_XX_LANGUAGES[version]
else: else:
assert source_language in self.XX_EN_LANGUAGES[version] assert source_language in self.XX_EN_LANGUAGES[version]
@ -85,51 +132,60 @@ class CoVoST(Dataset):
# Hack here so that we can get "split" column from CoVoST TSV. # Hack here so that we can get "split" column from CoVoST TSV.
# Note that we use CoVoST train split for ASR which is an extension # Note that we use CoVoST train split for ASR which is an extension
# to Common Voice train split. # to Common Voice train split.
target_language = 'de' if source_language == 'en' else 'en' target_language = "de" if source_language == "en" else "en"
self.root = os.path.join(root, 'raw') self.root = os.path.join(root, "raw")
os.makedirs(self.root, exist_ok=True) os.makedirs(self.root, exist_ok=True)
cv_url = self.CV_URL_TEMPLATE.format(ver=self.CV_VERSION_ID[version], cv_url = self.CV_URL_TEMPLATE.format(
lang=source_language) ver=self.CV_VERSION_ID[version], lang=source_language
)
cv_archive = os.path.join(self.root, os.path.basename(cv_url)) cv_archive = os.path.join(self.root, os.path.basename(cv_url))
if download: if download:
if not os.path.isfile(cv_archive): if not os.path.isfile(cv_archive):
download_url(cv_url, self.root, hash_value=None) download_url(cv_url, self.root, hash_value=None)
extract_archive(cv_archive) extract_archive(cv_archive)
covost_url = self.COVOST_URL_TEMPLATE.format(src_lang=source_language, covost_url = self.COVOST_URL_TEMPLATE.format(
tgt_lang=target_language) src_lang=source_language, tgt_lang=target_language
)
covost_archive = os.path.join(self.root, os.path.basename(covost_url)) covost_archive = os.path.join(self.root, os.path.basename(covost_url))
if download: if download:
if not os.path.isfile(covost_archive): if not os.path.isfile(covost_archive):
download_url(covost_url, self.root, hash_value=None) download_url(covost_url, self.root, hash_value=None)
extract_archive(covost_archive) extract_archive(covost_archive)
cv_tsv = self.load_from_tsv(os.path.join(self.root, 'validated.tsv')) cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv"))
covost_tsv = self.load_from_tsv( covost_tsv = self.load_from_tsv(
os.path.join(self.root, os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", ""))
os.path.basename(covost_url).replace('.tar.gz', ''))
) )
df = pd.merge(left=cv_tsv[['path', 'sentence', 'client_id']], df = pd.merge(
right=covost_tsv[['path', 'translation', 'split']], left=cv_tsv[["path", "sentence", "client_id"]],
how='inner', on='path') right=covost_tsv[["path", "translation", "split"]],
if split == 'train': how="inner",
df = df[(df['split'] == split) | (df['split'] == f'{split}_covost')] on="path",
)
if split == "train":
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
else: else:
df = df[df['split'] == split] df = df[df["split"] == split]
self.data = df.to_dict(orient='index').items() self.data = df.to_dict(orient="index").items()
self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])] self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]
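The Common Voice TSV and the CoVoST translation TSV are joined on the clip path and then filtered by split; a toy version of that merge with fabricated rows (pandas assumed):

import pandas as pd

cv_tsv = pd.DataFrame({
    "path": ["a.mp3", "b.mp3"],
    "sentence": ["bonjour", "merci"],
    "client_id": ["spk1", "spk2"],
})
covost_tsv = pd.DataFrame({
    "path": ["a.mp3", "b.mp3"],
    "translation": ["hello", "thanks"],
    "split": ["train", "train_covost"],
})
df = pd.merge(
    left=cv_tsv[["path", "sentence", "client_id"]],
    right=covost_tsv[["path", "translation", "split"]],
    how="inner",
    on="path",
)
split = "train"
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
print(len(df))  # 2 -- the train split plus its CoVoST extension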
@classmethod @classmethod
def load_from_tsv(cls, path: str): def load_from_tsv(cls, path: str):
return pd.read_csv( return pd.read_csv(
path, sep='\t', header=0, encoding='utf-8', escapechar='\\', path,
quoting=csv.QUOTE_NONE, na_filter=False sep="\t",
header=0,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
na_filter=False,
) )
def __getitem__( def __getitem__(
self, n: int self, n: int
) -> Tuple[Tensor, int, str, str, Optional[str], str, str]: ) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
@ -141,12 +197,12 @@ class CoVoST(Dataset):
sample_id)`` sample_id)``
""" """
data = self.data[n] data = self.data[n]
path = os.path.join(self.root, 'clips', data['path']) path = os.path.join(self.root, "clips", data["path"])
waveform, sample_rate = torchaudio.load(path) waveform, sample_rate = torchaudio.load(path)
sentence = data['sentence'] sentence = data["sentence"]
translation = None if self.no_translation else data['translation'] translation = None if self.no_translation else data["translation"]
speaker_id = data['client_id'] speaker_id = data["client_id"]
_id = data['path'].replace('.mp3', '') _id = data["path"].replace(".mp3", "")
return waveform, sample_rate, sentence, translation, speaker_id, _id return waveform, sample_rate, sentence, translation, speaker_id, _id
def __len__(self) -> int: def __len__(self) -> int:
@ -157,76 +213,82 @@ def process(args):
root = op.join(args.data_root, args.src_lang) root = op.join(args.data_root, args.src_lang)
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
# Extract features # Extract features
feature_root = op.join(root, 'fbank80') feature_root = op.join(root, "fbank80")
os.makedirs(feature_root, exist_ok=True) os.makedirs(feature_root, exist_ok=True)
for split in CoVoST.SPLITS: for split in CoVoST.SPLITS:
print(f'Fetching split {split}...') print(f"Fetching split {split}...")
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True)
download=True) print("Extracting log mel filter bank features...")
print('Extracting log mel filter bank features...')
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(waveform, sample_rate, extract_fbank_features(
op.join(feature_root, f'{utt_id}.npy')) waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
)
# Pack features into ZIP # Pack features into ZIP
zip_filename = 'fbank80.zip' zip_filename = "fbank80.zip"
zip_path = op.join(root, zip_filename) zip_path = op.join(root, zip_filename)
print('ZIPing features...') print("ZIPing features...")
create_zip(feature_root, zip_path) create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...') print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.data_root, zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}")
f'{args.src_lang}/{zip_filename}')
# Generate TSV manifest # Generate TSV manifest
print('Generating manifest...') print("Generating manifest...")
train_text = [] train_text = []
task = f'asr_{args.src_lang}' task = f"asr_{args.src_lang}"
if args.tgt_lang is not None: if args.tgt_lang is not None:
task = f'st_{args.src_lang}_{args.tgt_lang}' task = f"st_{args.src_lang}_{args.tgt_lang}"
for split in CoVoST.SPLITS: for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS} manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang) dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest['id'].append(utt_id) manifest["id"].append(utt_id)
manifest['audio'].append(zip_manifest[utt_id]) manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000) duration_ms = int(wav.size(1) / sr * 1000)
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10)) manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest['tgt_text'].append( manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
src_utt if args.tgt_lang is None else tgt_utt manifest["speaker"].append(speaker_id)
) is_train_split = split.startswith("train")
manifest['speaker'].append(speaker_id)
is_train_split = split.startswith('train')
if is_train_split: if is_train_split:
train_text.extend(manifest['tgt_text']) train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest) df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split) df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, op.join(root, f'{split}_{task}.tsv')) save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv"))
# Generate vocab # Generate vocab
vocab_size_str = '' if args.vocab_type == 'char' else str(args.vocab_size) vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size_str}_{task}' spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
with NamedTemporaryFile(mode='w') as f: with NamedTemporaryFile(mode="w") as f:
for t in train_text: for t in train_text:
f.write(t + '\n') f.write(t + "\n")
gen_vocab(f.name, op.join(root, spm_filename_prefix), gen_vocab(
args.vocab_type, args.vocab_size) f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size
)
# Generate config YAML # Generate config YAML
gen_config_yaml(root, spm_filename_prefix + '.model', gen_config_yaml(
yaml_filename=f'config_{task}.yaml', root,
specaugment_policy='lb') spm_filename_prefix + ".model",
yaml_filename=f"config_{task}.yaml",
specaugment_policy="lb",
)
# Clean up # Clean up
shutil.rmtree(feature_root) shutil.rmtree(feature_root)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--data-root', '-d', required=True, type=str) parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument('--vocab-type', default='unigram', required=True, parser.add_argument(
type=str, choices=['bpe', 'unigram', 'char']), "--vocab-type",
parser.add_argument('--vocab-size', default=1000, type=int) default="unigram",
parser.add_argument('--src-lang', '-s', required=True, type=str) required=True,
parser.add_argument('--tgt-lang', '-t', type=str) type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=1000, type=int)
parser.add_argument("--src-lang", "-s", required=True, type=str)
parser.add_argument("--tgt-lang", "-t", type=str)
args = parser.parse_args() args = parser.parse_args()
process(args) process(args)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
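The n_frames value written to the manifest follows the usual 25 ms analysis window with a 10 ms hop for filter-bank frames; a quick arithmetic check with illustrative numbers:

# 3.2 s of 16 kHz audio -> 51,200 samples
num_samples, sample_rate = 51_200, 16_000
duration_ms = int(num_samples / sample_rate * 1000)  # 3200 ms
n_frames = int(1 + (duration_ms - 25) / 10)          # 318 frames
print(duration_ms, n_frames)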
@ -6,91 +6,114 @@
import argparse import argparse
import logging import logging
from tempfile import NamedTemporaryFile
import os import os
import shutil
import os.path as op import os.path as op
import shutil
from tempfile import NamedTemporaryFile
from tqdm import tqdm
from torchaudio.datasets import LIBRISPEECH
import pandas as pd import pandas as pd
from examples.speech_to_text.data_utils import ( from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv, create_zip,
extract_fbank_features, gen_config_yaml extract_fbank_features,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
) )
from torchaudio.datasets import LIBRISPEECH
from tqdm import tqdm
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
SPLITS = ['train-clean-100', 'train-clean-360', 'train-other-500', 'dev-clean', SPLITS = [
'dev-other', 'test-clean', 'test-other'] "train-clean-100",
"train-clean-360",
"train-other-500",
"dev-clean",
"dev-other",
"test-clean",
"test-other",
]
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker'] MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
def process(args): def process(args):
os.makedirs(args.output_root, exist_ok=True) os.makedirs(args.output_root, exist_ok=True)
# Extract features # Extract features
feature_root = op.join(args.output_root, 'fbank80') feature_root = op.join(args.output_root, "fbank80")
os.makedirs(feature_root, exist_ok=True) os.makedirs(feature_root, exist_ok=True)
for split in SPLITS: for split in SPLITS:
print(f'Fetching split {split}...') print(f"Fetching split {split}...")
dataset = LIBRISPEECH(args.output_root, url=split, download=True) dataset = LIBRISPEECH(args.output_root, url=split, download=True)
print('Extracting log mel filter bank features...') print("Extracting log mel filter bank features...")
for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset): for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
sample_id = f'{spk_id}-{chapter_id}-{utt_id}' sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
extract_fbank_features(wav, sample_rate, extract_fbank_features(
op.join(feature_root, f'{sample_id}.npy')) wav, sample_rate, op.join(feature_root, f"{sample_id}.npy")
)
# Pack features into ZIP # Pack features into ZIP
zip_filename = 'fbank80.zip' zip_filename = "fbank80.zip"
zip_path = op.join(args.output_root, zip_filename) zip_path = op.join(args.output_root, zip_filename)
print('ZIPing features...') print("ZIPing features...")
create_zip(feature_root, zip_path) create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...') print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.output_root, zip_filename) zip_manifest = get_zip_manifest(args.output_root, zip_filename)
# Generate TSV manifest # Generate TSV manifest
print('Generating manifest...') print("Generating manifest...")
train_text = [] train_text = []
for split in SPLITS: for split in SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS} manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = LIBRISPEECH(args.output_root, url=split) dataset = LIBRISPEECH(args.output_root, url=split)
for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset): for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
sample_id = f'{spk_id}-{chapter_id}-{utt_id}' sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
manifest['id'].append(sample_id) manifest["id"].append(sample_id)
manifest['audio'].append(zip_manifest[sample_id]) manifest["audio"].append(zip_manifest[sample_id])
duration_ms = int(wav.size(1) / sample_rate * 1000) duration_ms = int(wav.size(1) / sample_rate * 1000)
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10)) manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest['tgt_text'].append(utt) manifest["tgt_text"].append(utt)
manifest['speaker'].append(spk_id) manifest["speaker"].append(spk_id)
save_df_to_tsv(pd.DataFrame.from_dict(manifest), save_df_to_tsv(
op.join(args.output_root, f'{split}.tsv')) pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv")
if split.startswith('train'): )
train_text.extend(manifest['tgt_text']) if split.startswith("train"):
train_text.extend(manifest["tgt_text"])
# Generate vocab # Generate vocab
vocab_size = '' if args.vocab_type == 'char' else str(args.vocab_size) vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size}' spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
with NamedTemporaryFile(mode='w') as f: with NamedTemporaryFile(mode="w") as f:
for t in train_text: for t in train_text:
f.write(t + '\n') f.write(t + "\n")
gen_vocab(f.name, op.join(args.output_root, spm_filename_prefix), gen_vocab(
args.vocab_type, args.vocab_size) f.name,
op.join(args.output_root, spm_filename_prefix),
args.vocab_type,
args.vocab_size,
)
# Generate config YAML # Generate config YAML
gen_config_yaml(args.output_root, spm_filename_prefix + '.model', gen_config_yaml(
specaugment_policy='ld') args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld"
)
# Clean up # Clean up
shutil.rmtree(feature_root) shutil.rmtree(feature_root)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--output-root', '-o', required=True, type=str) parser.add_argument("--output-root", "-o", required=True, type=str)
parser.add_argument('--vocab-type', default='unigram', required=True, parser.add_argument(
type=str, choices=['bpe', 'unigram', 'char']), "--vocab-type",
parser.add_argument('--vocab-size', default=10000, type=int) default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=10000, type=int)
args = parser.parse_args() args = parser.parse_args()
process(args) process(args)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
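Each manifest is a dict of parallel lists that becomes one TSV row per utterance. A minimal stand-in for the DataFrame-to-TSV step (pandas assumed; the sample values, including the zip-entry form of the audio field, are illustrative):

import pandas as pd

MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
manifest = {c: [] for c in MANIFEST_COLUMNS}
manifest["id"].append("84-121123-0000")
manifest["audio"].append("fbank80.zip:1024:40960")  # illustrative zip entry
manifest["n_frames"].append(318)
manifest["tgt_text"].append("go do you hear")
manifest["speaker"].append("84")

df = pd.DataFrame.from_dict(manifest)
# roughly what save_df_to_tsv does: tab-separated, no index column
df.to_csv("train-clean-100.tsv", sep="\t", index=False)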
@ -6,29 +6,34 @@
import argparse import argparse
import logging import logging
from tempfile import NamedTemporaryFile
import os import os
import os.path as op import os.path as op
import shutil import shutil
from typing import Tuple
from itertools import groupby from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple
from tqdm import tqdm
import pandas as pd import pandas as pd
from torch.utils.data import Dataset
import torchaudio import torchaudio
from torch import Tensor
from examples.speech_to_text.data_utils import ( from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv, create_zip,
extract_fbank_features, gen_config_yaml, filter_manifest_df extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
) )
from torch import Tensor
from torch.utils.data import Dataset
from tqdm import tqdm
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker'] MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
TASKS = ['asr', 'st'] TASKS = ["asr", "st"]
class MUSTC(Dataset): class MUSTC(Dataset):
@ -37,49 +42,55 @@ class MUSTC(Dataset):
waveform, sample_rate, source utterance, target utterance, speaker_id, waveform, sample_rate, source utterance, target utterance, speaker_id,
utterance_id utterance_id
""" """
SPLITS = ['train', 'dev', 'tst-COMMON', 'tst-HE']
LANGUAGES = ['de', 'es', 'fr', 'it', 'nl', 'pt', 'ro', 'ru'] SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
def __init__(self, root: str, lang: str, split: str) -> None: def __init__(self, root: str, lang: str, split: str) -> None:
assert split in self.SPLITS and lang in self.LANGUAGES assert split in self.SPLITS and lang in self.LANGUAGES
_root = op.join(root, f'en-{lang}', 'data', split) _root = op.join(root, f"en-{lang}", "data", split)
wav_root, txt_root = op.join(_root, 'wav'), op.join(_root, 'txt') wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt")
assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root) assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
# Load audio segments # Load audio segments
try: try:
import yaml import yaml
except ImportError: except ImportError:
print('Please install PyYAML to load YAML files for ' print("Please install PyYAML to load YAML files for " "the MuST-C dataset")
'the MuST-C dataset') with open(op.join(txt_root, f"{split}.yaml")) as f:
with open(op.join(txt_root, f'{split}.yaml')) as f:
segments = yaml.load(f, Loader=yaml.BaseLoader) segments = yaml.load(f, Loader=yaml.BaseLoader)
# Load source and target utterances # Load source and target utterances
for _lang in ['en', lang]: for _lang in ["en", lang]:
with open(op.join(txt_root, f'{split}.{_lang}')) as f: with open(op.join(txt_root, f"{split}.{_lang}")) as f:
utterances = [r.strip() for r in f] utterances = [r.strip() for r in f]
assert len(segments) == len(utterances) assert len(segments) == len(utterances)
for i, u in enumerate(utterances): for i, u in enumerate(utterances):
segments[i][_lang] = u segments[i][_lang] = u
# Gather info # Gather info
self.data = [] self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x['wav']): for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = op.join(wav_root, wav_filename) wav_path = op.join(wav_root, wav_filename)
sample_rate = torchaudio.info(wav_path)[0].rate sample_rate = torchaudio.info(wav_path)[0].rate
seg_group = sorted(_seg_group, key=lambda x: x['offset']) seg_group = sorted(_seg_group, key=lambda x: x["offset"])
for i, segment in enumerate(seg_group): for i, segment in enumerate(seg_group):
offset = int(float(segment['offset']) * sample_rate) offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment['duration']) * sample_rate) n_frames = int(float(segment["duration"]) * sample_rate)
_id = f'{op.splitext(wav_filename)[0]}_{i}' _id = f"{op.splitext(wav_filename)[0]}_{i}"
self.data.append( self.data.append(
(wav_path, offset, n_frames, sample_rate, segment['en'], (
segment[lang], segment['speaker_id'], _id) wav_path,
offset,
n_frames,
sample_rate,
segment["en"],
segment[lang],
segment["speaker_id"],
_id,
)
) )
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]: def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = \ wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
self.data[n] waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
waveform, _ = torchaudio.load(wav_path, offset=offset,
num_frames=n_frames)
return waveform, sr, src_utt, tgt_utt, spk_id, utt_id return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
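Segments in the split YAML carry offset and duration in seconds (as strings); they are converted to sample counts before slicing the wav, as in this illustrative computation:

segment = {"offset": "12.34", "duration": "2.50", "wav": "ted_1.wav"}  # toy values
sample_rate = 16_000
offset = int(float(segment["offset"]) * sample_rate)      # 197440 samples in
n_frames = int(float(segment["duration"]) * sample_rate)  # 40000 samples long
print(offset, n_frames)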
def __len__(self) -> int: def __len__(self) -> int:
@ -88,85 +99,102 @@ class MUSTC(Dataset):
def process(args): def process(args):
for lang in MUSTC.LANGUAGES: for lang in MUSTC.LANGUAGES:
cur_root = op.join(args.data_root, f'en-{lang}') cur_root = op.join(args.data_root, f"en-{lang}")
if not op.isdir(cur_root): if not op.isdir(cur_root):
print(f'{cur_root} does not exist. Skipped.') print(f"{cur_root} does not exist. Skipped.")
continue continue
# Extract features # Extract features
feature_root = op.join(cur_root, 'fbank80') feature_root = op.join(cur_root, "fbank80")
os.makedirs(feature_root, exist_ok=True) os.makedirs(feature_root, exist_ok=True)
for split in MUSTC.SPLITS: for split in MUSTC.SPLITS:
print(f'Fetching split {split}...') print(f"Fetching split {split}...")
dataset = MUSTC(args.data_root, lang, split) dataset = MUSTC(args.data_root, lang, split)
print('Extracting log mel filter bank features...') print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(waveform, sample_rate, extract_fbank_features(
op.join(feature_root, f'{utt_id}.npy')) waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
)
# Pack features into ZIP # Pack features into ZIP
zip_filename = 'fbank80.zip' zip_filename = "fbank80.zip"
zip_path = op.join(cur_root, zip_filename) zip_path = op.join(cur_root, zip_filename)
print('ZIPing features...') print("ZIPing features...")
create_zip(feature_root, zip_path) create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...') print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.data_root, zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}")
f'en-{lang}/{zip_filename}')
# Generate TSV manifest # Generate TSV manifest
print('Generating manifest...') print("Generating manifest...")
train_text = {task: [] for task in TASKS} train_text = {task: [] for task in TASKS}
for split in MUSTC.SPLITS: for split in MUSTC.SPLITS:
is_train_split = split.startswith('train') is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS} manifest = {c: [] for c in MANIFEST_COLUMNS}
text = {task: [] for task in TASKS} text = {task: [] for task in TASKS}
dataset = MUSTC(args.data_root, lang, split) dataset = MUSTC(args.data_root, lang, split)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest['id'].append(utt_id) manifest["id"].append(utt_id)
manifest['audio'].append(zip_manifest[utt_id]) manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000) duration_ms = int(wav.size(1) / sr * 1000)
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10)) manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
text['asr'].append(src_utt) text["asr"].append(src_utt)
text['st'].append(tgt_utt) text["st"].append(tgt_utt)
manifest['speaker'].append(speaker_id) manifest["speaker"].append(speaker_id)
if is_train_split: if is_train_split:
for task in TASKS: for task in TASKS:
train_text[task].extend(text[task]) train_text[task].extend(text[task])
for task in TASKS: for task in TASKS:
manifest['tgt_text'] = text[task] manifest["tgt_text"] = text[task]
df = pd.DataFrame.from_dict(manifest) df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split) df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, op.join(cur_root, f'{split}_{task}.tsv')) save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv"))
# Generate vocab # Generate vocab
for task in TASKS: for task in TASKS:
vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
if task == 'st': if task == "st":
vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
vocab_size_str = '' if vocab_type == 'char' else str(vocab_size) vocab_size_str = "" if vocab_type == "char" else str(vocab_size)
spm_filename_prefix = f'spm_{vocab_type}{vocab_size_str}_{task}' spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}"
with NamedTemporaryFile(mode='w') as f: with NamedTemporaryFile(mode="w") as f:
for t in train_text[task]: for t in train_text[task]:
f.write(t + '\n') f.write(t + "\n")
gen_vocab(f.name, op.join(cur_root, spm_filename_prefix), gen_vocab(
vocab_type, vocab_size) f.name,
op.join(cur_root, spm_filename_prefix),
vocab_type,
vocab_size,
)
# Generate config YAML # Generate config YAML
gen_config_yaml(cur_root, spm_filename_prefix + '.model', gen_config_yaml(
yaml_filename=f'config_{task}.yaml', cur_root,
specaugment_policy='lb') spm_filename_prefix + ".model",
yaml_filename=f"config_{task}.yaml",
specaugment_policy="lb",
)
# Clean up # Clean up
shutil.rmtree(feature_root) shutil.rmtree(feature_root)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--data-root', '-d', required=True, type=str) parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument('--asr-vocab-type', default='unigram', required=True, parser.add_argument(
type=str, choices=['bpe', 'unigram', 'char']), "--asr-vocab-type",
parser.add_argument('--st-vocab-type', default='unigram', required=True, default="unigram",
type=str, choices=['bpe', 'unigram', 'char']), required=True,
parser.add_argument('--asr-vocab-size', default=5000, type=int) type=str,
parser.add_argument('--st-vocab-size', default=8000, type=int) choices=["bpe", "unigram", "char"],
),
parser.add_argument(
"--st-vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--asr-vocab-size", default=5000, type=int)
parser.add_argument("--st-vocab-size", default=8000, type=int)
args = parser.parse_args() args = parser.parse_args()
process(args) process(args)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
@ -12,9 +12,9 @@ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
""" """
import argparse import argparse
from itertools import chain
import sys
import random import random
import sys
from itertools import chain
import numpy as np import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
@ -22,17 +22,21 @@ from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
def main(): def main():
parser = argparse.ArgumentParser(sys.argv[0]) parser = argparse.ArgumentParser(sys.argv[0])
parser.add_argument('--sys', nargs='*', default='', metavar='FILE', parser.add_argument(
help='path to system output') "--sys", nargs="*", default="", metavar="FILE", help="path to system output"
parser.add_argument('--ref', default='', metavar='FILE', )
help='path to references') parser.add_argument("--ref", default="", metavar="FILE", help="path to references")
parser.add_argument('--output', default='', metavar='FILE', parser.add_argument(
help='print outputs into a pretty format') "--output",
default="",
metavar="FILE",
help="print outputs into a pretty format",
)
args = parser.parse_args() args = parser.parse_args()
if args.sys: if args.sys:
src, tgt, hypos, log_probs = load_sys(args.sys) src, tgt, hypos, log_probs = load_sys(args.sys)
print('pairwise BLEU: %.2f' % pairwise(hypos)) print("pairwise BLEU: %.2f" % pairwise(hypos))
if args.output: if args.output:
merge(src, tgt, hypos, log_probs, args.output) merge(src, tgt, hypos, log_probs, args.output)
@ -58,18 +62,18 @@ def load_sys(paths):
# S: source # S: source
# T: target # T: target
# D: detokenized system output # D: detokenized system output
if line.startswith(('S-', 'T-', 'D-')): if line.startswith(("S-", "T-", "D-")):
i = int(line[line.find('-')+1:line.find('\t')]) i = int(line[line.find("-") + 1 : line.find("\t")])
if line.startswith('S-'): if line.startswith("S-"):
src[i] = line.split('\t')[1] src[i] = line.split("\t")[1]
if line.startswith('T-'): if line.startswith("T-"):
tgt[i] = line.split('\t')[1] tgt[i] = line.split("\t")[1]
if line.startswith('D-'): if line.startswith("D-"):
if i not in hypos: if i not in hypos:
hypos[i] = [] hypos[i] = []
log_probs[i] = [] log_probs[i] = []
hypos[i].append(line.split('\t')[2]) hypos[i].append(line.split("\t")[2])
log_probs[i].append(float(line.split('\t')[1])) log_probs[i].append(float(line.split("\t")[1]))
return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs) return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
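A toy run of the S-/T-/D- parsing above on a fabricated fairseq-generate log (sentence index 0, two hypotheses; scores are made up):

lines = [
    "S-0\tquelle heure est-il ?",
    "T-0\twhat time is it ?",
    "D-0\t-0.4123\twhat time is it ?",
    "D-0\t-0.9876\twhat is the time ?",
]
src, tgt, hypos, log_probs = {}, {}, {}, {}
for line in lines:
    if line.startswith(("S-", "T-", "D-")):
        i = int(line[line.find("-") + 1 : line.find("\t")])
        if line.startswith("S-"):
            src[i] = line.split("\t")[1]
        if line.startswith("T-"):
            tgt[i] = line.split("\t")[1]
        if line.startswith("D-"):
            hypos.setdefault(i, []).append(line.split("\t")[2])
            log_probs.setdefault(i, []).append(float(line.split("\t")[1]))
print(hypos[0])  # ['what time is it ?', 'what is the time ?']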
@ -79,34 +83,34 @@ def load_ref(path):
src, tgt, refs = [], [], [] src, tgt, refs = [], [], []
i = 0 i = 0
while i < len(lines): while i < len(lines):
if lines[i].startswith('S-'): if lines[i].startswith("S-"):
src.append(lines[i].split('\t')[1].rstrip()) src.append(lines[i].split("\t")[1].rstrip())
i += 1 i += 1
elif lines[i].startswith('T-'): elif lines[i].startswith("T-"):
tgt.append(lines[i].split('\t')[1].rstrip()) tgt.append(lines[i].split("\t")[1].rstrip())
i += 1 i += 1
else: else:
a = [] a = []
while i < len(lines) and lines[i].startswith('R'): while i < len(lines) and lines[i].startswith("R"):
a.append(lines[i].split('\t')[1].rstrip()) a.append(lines[i].split("\t")[1].rstrip())
i += 1 i += 1
refs.append(a) refs.append(a)
return src, tgt, refs return src, tgt, refs
def merge(src, tgt, hypos, log_probs, path): def merge(src, tgt, hypos, log_probs, path):
with open(path, 'w') as f: with open(path, "w") as f:
for s, t, hs, lps in zip(src, tgt, hypos, log_probs): for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
f.write(s + '\n') f.write(s + "\n")
f.write(t + '\n') f.write(t + "\n")
f.write('\n') f.write("\n")
for h, lp in zip(hs, lps): for h, lp in zip(hs, lps):
f.write('\t%f\t%s\n' % (lp, h.strip())) f.write("\t%f\t%s\n" % (lp, h.strip()))
f.write('------------------------------------------------------\n') f.write("------------------------------------------------------\n")
def corpus_bleu(sys_stream, ref_streams): def corpus_bleu(sys_stream, ref_streams):
bleu = _corpus_bleu(sys_stream, ref_streams, tokenize='none') bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
return bleu.score return bleu.score
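Because tokenize="none" is passed, both streams must already be tokenized strings; a minimal call, assuming a sacrebleu 1.x-style corpus_bleu signature as imported above:

from sacrebleu import corpus_bleu as _corpus_bleu

sys_stream = ["the cat sat on the mat", "a quick brown fox"]
ref_streams = [
    ["the cat sat on the mat", "the quick brown fox"],  # reference set 1
    ["the cat is on the mat", "a quick brown fox"],     # reference set 2
]
print("%.2f" % _corpus_bleu(sys_stream, ref_streams, tokenize="none").score)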
@ -116,9 +120,11 @@ def sentence_bleu(hypothesis, reference):
bleu.counts[i] += 1 bleu.counts[i] += 1
bleu.totals[i] += 1 bleu.totals[i] += 1
bleu = compute_bleu( bleu = compute_bleu(
bleu.counts, bleu.totals, bleu.counts,
bleu.sys_len, bleu.ref_len, bleu.totals,
smooth_method='exp', bleu.sys_len,
bleu.ref_len,
smooth_method="exp",
) )
return bleu.score return bleu.score
@ -150,7 +156,7 @@ def multi_ref(refs, hypos):
best = [k for k in range(len(rs)) if s[k] == s[j]] best = [k for k in range(len(rs)) if s[k] == s[j]]
a.add(random.choice(best)) a.add(random.choice(best))
ref_cnt += len(a) ref_cnt += len(a)
print('#refs covered: %.2f' % (ref_cnt / len(refs))) print("#refs covered: %.2f" % (ref_cnt / len(refs)))
# transpose refs and hypos # transpose refs and hypos
refs = list(zip(*refs)) refs = list(zip(*refs))
@ -160,33 +166,32 @@ def multi_ref(refs, hypos):
k = len(hypos) k = len(hypos)
m = len(refs) m = len(refs)
flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)] flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
duplicated_refs = [ duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
[ref for ref in refs_i for _ in range(k)]
for refs_i in refs
]
loo_bleus = [] loo_bleus = []
for held_out_ref in range(m): for held_out_ref in range(m):
remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref+1:] remaining_refs = (
duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
)
assert len(remaining_refs) == m - 1 assert len(remaining_refs) == m - 1
loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs)) loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
print('average multi-reference BLEU (leave-one-out): %.2f' % np.mean(loo_bleus)) print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))
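The duplication step repeats every reference k times so it stays aligned with the flattened hypothesis stream, which is grouped by source sentence; a tiny ordering check:

k = 3                        # hypotheses per source sentence
refs_i = ["ref A", "ref B"]  # one reference stream over two source sentences
duplicated = [ref for ref in refs_i for _ in range(k)]
print(duplicated)  # ['ref A', 'ref A', 'ref A', 'ref B', 'ref B', 'ref B']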
def intra_ref(refs): def intra_ref(refs):
print('ref pairwise BLEU: %.2f' % pairwise(refs)) print("ref pairwise BLEU: %.2f" % pairwise(refs))
refs = list(zip(*refs)) refs = list(zip(*refs))
m = len(refs) m = len(refs)
concat_h = [] concat_h = []
concat_rest = [[] for j in range(m - 1)] concat_rest = [[] for j in range(m - 1)]
for i, h in enumerate(refs): for i, h in enumerate(refs):
rest = refs[:i] + refs[i+1:] rest = refs[:i] + refs[i + 1 :]
concat_h.append(h) concat_h.append(h)
for j in range(m - 1): for j in range(m - 1):
concat_rest[j].extend(rest[j]) concat_rest[j].extend(rest[j])
concat_h = list(chain.from_iterable(concat_h)) concat_h = list(chain.from_iterable(concat_h))
bleu = corpus_bleu(concat_h, concat_rest) bleu = corpus_bleu(concat_h, concat_rest)
print('multi-reference BLEU (leave-one-out): %.2f' % bleu) print("multi-reference BLEU (leave-one-out): %.2f" % bleu)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
@ -21,6 +21,6 @@ class LogSumExpMoE(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
posterior, = ctx.saved_tensors (posterior,) = ctx.saved_tensors
grad_logp = grad_output.unsqueeze(ctx.dim) * posterior grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
return grad_logp, None, None return grad_logp, None, None
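The backward above routes the incoming gradient to each expert in proportion to a posterior saved at forward time, instead of differentiating through the posterior itself. A sketch of how such a Function pairs with a forward pass; the forward body here is an assumption for illustration, since only backward() appears in this diff:

import torch

class LogSumExpMoESketch(torch.autograd.Function):
    """Log-sum-exp over experts with a detached posterior in backward."""

    @staticmethod
    def forward(ctx, logp, posterior, dim):
        ctx.save_for_backward(posterior)
        ctx.dim = dim
        return torch.logsumexp(logp, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        (posterior,) = ctx.saved_tensors
        grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
        return grad_logp, None, None  # no grads for posterior or dim

lprob_yz = torch.randn(4, 3, requires_grad=True)     # B x K expert log-likelihoods
prob_z_xy = torch.softmax(lprob_yz.detach(), dim=1)  # posterior treated as constant
loss = -LogSumExpMoESketch.apply(lprob_yz, prob_z_xy, 1).sum()
loss.backward()
print(lprob_yz.grad.shape)  # torch.Size([4, 3])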
@ -26,15 +26,15 @@ class MeanPoolGatingNetwork(torch.nn.Module):
def forward(self, encoder_out): def forward(self, encoder_out):
if not ( if not (
hasattr(encoder_out, 'encoder_out') hasattr(encoder_out, "encoder_out")
and hasattr(encoder_out, 'encoder_padding_mask') and hasattr(encoder_out, "encoder_padding_mask")
and encoder_out.encoder_out.size(2) == self.embed_dim and encoder_out.encoder_out.size(2) == self.embed_dim
): ):
raise ValueError('Unexpected format for encoder_out') raise ValueError("Unexpected format for encoder_out")
# mean pooling over time # mean pooling over time
encoder_padding_mask = encoder_out.encoder_padding_mask # B x T encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
if encoder_padding_mask is not None: if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0 encoder_out[encoder_padding_mask] = 0
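Zeroing padded timesteps is the first half of masked mean pooling; the rest of forward() is not shown in this hunk, so the division by valid lengths below is an assumption about how the mean is completed:

import torch

encoder_out = torch.randn(2, 5, 8)  # B x T x C (after the transpose above)
encoder_padding_mask = torch.tensor(
    [[False, False, False, True, True],
     [False, False, False, False, False]]
)  # B x T, True at padded steps
encoder_out = encoder_out.clone()
encoder_out[encoder_padding_mask] = 0
lengths = (~encoder_padding_mask).sum(dim=1, keepdim=True)  # valid steps per sentence
pooled = encoder_out.sum(dim=1) / lengths                   # B x C mean over time
print(pooled.shape)  # torch.Size([2, 8])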
@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
from fairseq import metrics, utils from fairseq import metrics, utils
from fairseq.tasks import register_task from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask from fairseq.tasks.translation import TranslationTask
@ -13,7 +12,7 @@ from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork from .mean_pool_gating_network import MeanPoolGatingNetwork
@register_task('translation_moe') @register_task("translation_moe")
class TranslationMoETask(TranslationTask): class TranslationMoETask(TranslationTask):
""" """
Translation task for Mixture of Experts (MoE) models. Translation task for Mixture of Experts (MoE) models.
@ -58,19 +57,19 @@ class TranslationMoETask(TranslationTask):
# fmt: on # fmt: on
def __init__(self, args, src_dict, tgt_dict): def __init__(self, args, src_dict, tgt_dict):
if args.method == 'sMoElp': if args.method == "sMoElp":
# soft MoE with learned prior # soft MoE with learned prior
self.uniform_prior = False self.uniform_prior = False
self.hard_selection = False self.hard_selection = False
elif args.method == 'sMoEup': elif args.method == "sMoEup":
# soft MoE with uniform prior # soft MoE with uniform prior
self.uniform_prior = True self.uniform_prior = True
self.hard_selection = False self.hard_selection = False
elif args.method == 'hMoElp': elif args.method == "hMoElp":
# hard MoE with learned prior # hard MoE with learned prior
self.uniform_prior = False self.uniform_prior = False
self.hard_selection = True self.hard_selection = True
elif args.method == 'hMoEup': elif args.method == "hMoEup":
# hard MoE with uniform prior # hard MoE with uniform prior
self.uniform_prior = True self.uniform_prior = True
self.hard_selection = True self.hard_selection = True
@ -78,50 +77,56 @@ class TranslationMoETask(TranslationTask):
# add indicator tokens for each expert # add indicator tokens for each expert
for i in range(args.num_experts): for i in range(args.num_experts):
# add to both dictionaries in case we're sharing embeddings # add to both dictionaries in case we're sharing embeddings
src_dict.add_symbol('<expert_{}>'.format(i)) src_dict.add_symbol("<expert_{}>".format(i))
tgt_dict.add_symbol('<expert_{}>'.format(i)) tgt_dict.add_symbol("<expert_{}>".format(i))
super().__init__(args, src_dict, tgt_dict) super().__init__(args, src_dict, tgt_dict)
def build_model(self, args): def build_model(self, args):
from fairseq import models from fairseq import models
model = models.build_model(args, self) model = models.build_model(args, self)
if not self.uniform_prior and not hasattr(model, 'gating_network'): if not self.uniform_prior and not hasattr(model, "gating_network"):
if self.args.mean_pool_gating_network: if self.args.mean_pool_gating_network:
if getattr(args, 'mean_pool_gating_network_encoder_dim', None): if getattr(args, "mean_pool_gating_network_encoder_dim", None):
encoder_dim = args.mean_pool_gating_network_encoder_dim encoder_dim = args.mean_pool_gating_network_encoder_dim
elif getattr(args, 'encoder_embed_dim', None): elif getattr(args, "encoder_embed_dim", None):
# assume that encoder_embed_dim is the encoder's output dimension # assume that encoder_embed_dim is the encoder's output dimension
encoder_dim = args.encoder_embed_dim encoder_dim = args.encoder_embed_dim
else: else:
raise ValueError('Must specify --mean-pool-gating-network-encoder-dim') raise ValueError(
"Must specify --mean-pool-gating-network-encoder-dim"
)
if getattr(args, 'mean_pool_gating_network_dropout', None): if getattr(args, "mean_pool_gating_network_dropout", None):
dropout = args.mean_pool_gating_network_dropout dropout = args.mean_pool_gating_network_dropout
elif getattr(args, 'dropout', None): elif getattr(args, "dropout", None):
dropout = args.dropout dropout = args.dropout
else: else:
raise ValueError('Must specify --mean-pool-gating-network-dropout') raise ValueError("Must specify --mean-pool-gating-network-dropout")
model.gating_network = MeanPoolGatingNetwork( model.gating_network = MeanPoolGatingNetwork(
encoder_dim, args.num_experts, dropout, encoder_dim,
args.num_experts,
dropout,
) )
else: else:
raise ValueError( raise ValueError(
'translation_moe task with learned prior requires the model to ' "translation_moe task with learned prior requires the model to "
'have a gating network; try using --mean-pool-gating-network' "have a gating network; try using --mean-pool-gating-network"
) )
return model return model
def expert_index(self, i): def expert_index(self, i):
return i + self.tgt_dict.index('<expert_0>') return i + self.tgt_dict.index("<expert_0>")
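Experts are selected by overwriting the first decoder input token with the expert's indicator symbol; a toy illustration of that routing (the dictionary index 32000 for "<expert_0>" is made up):

import torch

expert_0_index = 32000  # hypothetical index of "<expert_0>" in the dictionary

def expert_index(i):
    return i + expert_0_index

prev_output_tokens = torch.tensor([[2, 15, 99], [2, 7, 42]])  # BOS-led batch
for i in range(2):  # k = 2 experts
    prev_output_tokens_k = prev_output_tokens.clone()
    prev_output_tokens_k[:, 0] = expert_index(i)  # route the batch to expert i
    print(prev_output_tokens_k[:, 0].tolist())    # [32000, 32000], then [32001, 32001]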
def _get_loss(self, sample, model, criterion): def _get_loss(self, sample, model, criterion):
assert hasattr(criterion, 'compute_loss'), \ assert hasattr(
'translation_moe task requires the criterion to implement the compute_loss() method' criterion, "compute_loss"
), "translation_moe task requires the criterion to implement the compute_loss() method"
k = self.args.num_experts k = self.args.num_experts
bsz = sample['target'].size(0) bsz = sample["target"].size(0)
def get_lprob_y(encoder_out, prev_output_tokens_k): def get_lprob_y(encoder_out, prev_output_tokens_k):
net_output = model.decoder( net_output = model.decoder(
@ -134,20 +139,22 @@ class TranslationMoETask(TranslationTask):
def get_lprob_yz(winners=None): def get_lprob_yz(winners=None):
encoder_out = model.encoder( encoder_out = model.encoder(
src_tokens=sample['net_input']['src_tokens'], src_tokens=sample["net_input"]["src_tokens"],
src_lengths=sample['net_input']['src_lengths'], src_lengths=sample["net_input"]["src_lengths"],
) )
if winners is None: if winners is None:
lprob_y = [] lprob_y = []
for i in range(k): for i in range(k):
prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone() prev_output_tokens_k = sample["net_input"][
"prev_output_tokens"
].clone()
assert not prev_output_tokens_k.requires_grad assert not prev_output_tokens_k.requires_grad
prev_output_tokens_k[:, 0] = self.expert_index(i) prev_output_tokens_k[:, 0] = self.expert_index(i)
lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k)) lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
lprob_y = torch.cat(lprob_y, dim=1) # -> B x K lprob_y = torch.cat(lprob_y, dim=1) # -> B x K
else: else:
prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone() prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
prev_output_tokens_k[:, 0] = self.expert_index(winners) prev_output_tokens_k[:, 0] = self.expert_index(winners)
lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k) # -> B lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k) # -> B
@ -177,17 +184,21 @@ class TranslationMoETask(TranslationTask):
loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1) loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
loss = loss.sum() loss = loss.sum()
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] sample_size = (
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
)
logging_output = { logging_output = {
'loss': utils.item(loss.data), "loss": utils.item(loss.data),
'ntokens': sample['ntokens'], "ntokens": sample["ntokens"],
'nsentences': bsz, "nsentences": bsz,
'sample_size': sample_size, "sample_size": sample_size,
'posterior': prob_z_xy.float().sum(dim=0).cpu(), "posterior": prob_z_xy.float().sum(dim=0).cpu(),
} }
return loss, sample_size, logging_output return loss, sample_size, logging_output
def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False): def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train() model.train()
loss, sample_size, logging_output = self._get_loss(sample, model, criterion) loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
if ignore_grad: if ignore_grad:
@ -201,7 +212,15 @@ class TranslationMoETask(TranslationTask):
loss, sample_size, logging_output = self._get_loss(sample, model, criterion) loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
return loss, sample_size, logging_output return loss, sample_size, logging_output
def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None, constraints=None): def inference_step(
self,
generator,
models,
sample,
prefix_tokens=None,
expert=None,
constraints=None,
):
expert = expert or self.args.gen_expert expert = expert or self.args.gen_expert
with torch.no_grad(): with torch.no_grad():
return generator.generate( return generator.generate(
@ -215,6 +234,6 @@ class TranslationMoETask(TranslationTask):
def reduce_metrics(self, logging_outputs, criterion): def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion) super().reduce_metrics(logging_outputs, criterion)
metrics.log_scalar( metrics.log_scalar(
'posterior', "posterior",
sum(log['posterior'] for log in logging_outputs if 'posterior' in log) sum(log["posterior"] for log in logging_outputs if "posterior" in log),
) )
@ -4,37 +4,38 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import argparse import argparse
import numpy as np
import sys import sys
import numpy as np
aggregate_funcs = { aggregate_funcs = {
'std': np.std, "std": np.std,
'var': np.var, "var": np.var,
'median': np.median, "median": np.median,
'mean': np.mean, "mean": np.mean,
'min': np.min, "min": np.min,
'max': np.max, "max": np.max,
} }
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_file', required=True, type=str) parser.add_argument("-i", "--input_file", required=True, type=str)
parser.add_argument('-n', '--repeat_times', required=True, type=int) parser.add_argument("-n", "--repeat_times", required=True, type=int)
parser.add_argument('-o', '--output_file', required=False) parser.add_argument("-o", "--output_file", required=False)
parser.add_argument('-f', '--func', required=False, default='mean') parser.add_argument("-f", "--func", required=False, default="mean")
args = parser.parse_args() args = parser.parse_args()
stream = open(args.output_file, 'w') if args.output_file else sys.stdout stream = open(args.output_file, "w") if args.output_file else sys.stdout
segment_scores = [] segment_scores = []
for line in open(args.input_file): for line in open(args.input_file):
segment_scores.append(float(line.strip())) segment_scores.append(float(line.strip()))
if len(segment_scores) == args.repeat_times: if len(segment_scores) == args.repeat_times:
stream.write('{}\n'.format(aggregate_funcs[args.func](segment_scores))) stream.write("{}\n".format(aggregate_funcs[args.func](segment_scores)))
segment_scores = [] segment_scores = []
if __name__ == '__main__': if __name__ == "__main__":
main() main()
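The loop collapses every repeat_times consecutive scores into one aggregate; exercised on a toy stream (numbers are made up):

import numpy as np

repeat_times, func = 3, np.mean
segment_scores = []
for score in [1.0, 2.0, 3.0, 4.0, 6.0, 8.0]:
    segment_scores.append(score)
    if len(segment_scores) == repeat_times:
        print("{}".format(func(segment_scores)))  # 2.0, then 6.0
        segment_scores = []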
@ -4,14 +4,13 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import argparse import argparse
import os
import sys
import subprocess
import tempfile
import math import math
import os
from itertools import combinations import subprocess
import sys
import tempfile
from collections import defaultdict from collections import defaultdict
from itertools import combinations
def read_translations(path, n_repeats): def read_translations(path, n_repeats):
@ -19,7 +18,7 @@ def read_translations(path, n_repeats):
segment_translations = [] segment_translations = []
translations = defaultdict(list) translations = defaultdict(list)
for line in open(path): for line in open(path):
segment_translations.append(' '.join(line.split())) segment_translations.append(" ".join(line.split()))
if len(segment_translations) == n_repeats: if len(segment_translations) == n_repeats:
translations[segment_counter] = segment_translations translations[segment_counter] = segment_translations
segment_translations = [] segment_translations = []
@ -30,42 +29,55 @@ def read_translations(path, n_repeats):
def generate_input(translations, n_repeats): def generate_input(translations, n_repeats):
_, ref_path = tempfile.mkstemp() _, ref_path = tempfile.mkstemp()
_, mt_path = tempfile.mkstemp() _, mt_path = tempfile.mkstemp()
ref_fh = open(ref_path, 'w') ref_fh = open(ref_path, "w")
mt_fh = open(mt_path, 'w') mt_fh = open(mt_path, "w")
for segid in sorted(translations.keys()): for segid in sorted(translations.keys()):
assert len(translations[segid]) == n_repeats assert len(translations[segid]) == n_repeats
indexes = combinations(range(n_repeats), 2) indexes = combinations(range(n_repeats), 2)
for idx1, idx2 in indexes: for idx1, idx2 in indexes:
mt_fh.write(translations[segid][idx1].strip() + '\n') mt_fh.write(translations[segid][idx1].strip() + "\n")
ref_fh.write(translations[segid][idx2].strip() + '\n') ref_fh.write(translations[segid][idx2].strip() + "\n")
sys.stderr.write('\nSaved translations to %s and %s' % (ref_path, mt_path)) sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
return ref_path, mt_path return ref_path, mt_path
def run_meteor(ref_path, mt_path, metric_path, lang='en'): def run_meteor(ref_path, mt_path, metric_path, lang="en"):
_, out_path = tempfile.mkstemp() _, out_path = tempfile.mkstemp()
subprocess.call([ subprocess.call(
'java', '-Xmx2G', '-jar', metric_path, mt_path, ref_path, [
'-p', '0.5 0.2 0.6 0.75', # default parameters, only changed alpha to give equal weight to P and R "java",
'-norm', "-Xmx2G",
'-l', lang], stdout=open(out_path, 'w')) "-jar",
metric_path,
mt_path,
ref_path,
"-p",
"0.5 0.2 0.6 0.75", # default parameters, only changed alpha to give equal weight to P and R
"-norm",
"-l",
lang,
],
stdout=open(out_path, "w"),
)
os.remove(ref_path) os.remove(ref_path)
os.remove(mt_path) os.remove(mt_path)
sys.stderr.write('\nSaved Meteor output to %s' % out_path) sys.stderr.write("\nSaved Meteor output to %s" % out_path)
return out_path return out_path
def read_output(meteor_output_path, n_repeats): def read_output(meteor_output_path, n_repeats):
n_combinations = math.factorial(n_repeats)/(math.factorial(2) * math.factorial(n_repeats - 2)) n_combinations = math.factorial(n_repeats) / (
math.factorial(2) * math.factorial(n_repeats - 2)
)
raw_scores = [] raw_scores = []
average_scores = [] average_scores = []
for line in open(meteor_output_path): for line in open(meteor_output_path):
if not line.startswith('Segment '): if not line.startswith("Segment "):
continue continue
score = float(line.strip().split('\t')[1]) score = float(line.strip().split("\t")[1])
raw_scores.append(score) raw_scores.append(score)
if len(raw_scores) == n_combinations: if len(raw_scores) == n_combinations:
average_scores.append(sum(raw_scores)/n_combinations) average_scores.append(sum(raw_scores) / n_combinations)
raw_scores = [] raw_scores = []
os.remove(meteor_output_path) os.remove(meteor_output_path)
return average_scores return average_scores
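read_output expects exactly C(n, 2) pairwise Meteor scores per segment. Spelled out for n_repeats = 4; note the division produces a float, which still compares equal to the integer len(raw_scores):

import math

n_repeats = 4
n_combinations = math.factorial(n_repeats) / (
    math.factorial(2) * math.factorial(n_repeats - 2)
)
print(n_combinations)  # 6.0, i.e. 4-choose-2 hypothesis pairs per segment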
@ -73,25 +85,25 @@ def read_output(meteor_output_path, n_repeats):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input') parser.add_argument("-i", "--input")
parser.add_argument('-n', '--repeat_times', type=int) parser.add_argument("-n", "--repeat_times", type=int)
parser.add_argument('-m', '--meteor') parser.add_argument("-m", "--meteor")
parser.add_argument('-o', '--output') parser.add_argument("-o", "--output")
args = parser.parse_args() args = parser.parse_args()
translations = read_translations(args.input, args.repeat_times) translations = read_translations(args.input, args.repeat_times)
sys.stderr.write('\nGenerating input for Meteor...') sys.stderr.write("\nGenerating input for Meteor...")
ref_path, mt_path = generate_input(translations, args.repeat_times) ref_path, mt_path = generate_input(translations, args.repeat_times)
sys.stderr.write('\nRunning Meteor...') sys.stderr.write("\nRunning Meteor...")
out_path = run_meteor(ref_path, mt_path, args.meteor) out_path = run_meteor(ref_path, mt_path, args.meteor)
sys.stderr.write('\nReading output...') sys.stderr.write("\nReading output...")
scores = read_output(out_path, args.repeat_times) scores = read_output(out_path, args.repeat_times)
sys.stderr.write('\nWriting results...') sys.stderr.write("\nWriting results...")
with open(args.output, 'w') as o: with open(args.output, "w") as o:
for scr in scores: for scr in scores:
o.write('{}\n'.format(scr)) o.write("{}\n".format(scr))
o.close() o.close()
if __name__ == '__main__': if __name__ == "__main__":
main() main()
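Taken together, the helpers in this file score a system by agreement among repeated sampled translations: read_output averages the C(n, 2) pairwise METEOR scores per segment, so for n = 10 repeats each reported score averages 45 pairwise values. A minimal sketch of driving the helpers directly, assuming only the functions defined above; the file names and the Meteor jar path are illustrative:

    # Paths are illustrative; the function signatures are the ones defined above.
    translations = read_translations("sampled_translations.txt", 10)
    ref_path, mt_path = generate_input(translations, 10)
    out_path = run_meteor(ref_path, mt_path, "meteor-1.5.jar")
    scores = read_output(out_path, 10)  # each score averages C(10, 2) = 45 pairwise values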
View File
@@ -8,21 +8,21 @@ import sys

 def _normalize_spaces(line):
-    return ' '.join(line.split())
+    return " ".join(line.split())


 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-i', '--input_file', required=True, type=str)
-    parser.add_argument('-n', '--repeat_times', required=True, type=int)
-    parser.add_argument('-o', '--output_file', required=False, type=str)
+    parser.add_argument("-i", "--input_file", required=True, type=str)
+    parser.add_argument("-n", "--repeat_times", required=True, type=int)
+    parser.add_argument("-o", "--output_file", required=False, type=str)
     args = parser.parse_args()
-    stream = open(args.output_file, 'w') if args.output_file else sys.stdout
+    stream = open(args.output_file, "w") if args.output_file else sys.stdout

     for line in open(args.input_file):
         for _ in range(args.repeat_times):
-            stream.write(_normalize_spaces(line) + '\n')
+            stream.write(_normalize_spaces(line) + "\n")


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
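This companion script prepares input for the sampling step above: each source line is whitespace-normalized and emitted repeat_times times, so the sampled hypotheses line up segment by segment downstream. The normalization itself is plain whitespace collapsing:

    # Behavior of the helper defined above: runs of spaces/tabs/newlines collapse to single spaces.
    assert _normalize_spaces("a \t b   c\n") == "a b c"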
View File
@@ -8,30 +8,31 @@
 Helper script to pre-compute embeddings for a wav2letter++ dataset
 """

-import pprint
-import glob, os, argparse
-
-import torch
-from torch import nn
+import argparse
+import glob
+import os
+import os.path as osp
+import pprint
+
+import soundfile as sf
+import torch
+import tqdm
+from fairseq.models.wav2vec.wav2vec import Wav2VecModel
+from torch import nn
+from torch.utils.data import DataLoader
+

 try:
     import tqdm
 except:
     print("Install tqdm to use --log-format=tqdm")

-from fairseq.models.wav2vec.wav2vec import Wav2VecModel
-
-import tqdm
-import soundfile as sf
-from torch.utils.data import DataLoader
-import os.path as osp
-

 class FilesDataset:
     def __init__(self, files, labels):
         self.files = files
         if labels and osp.exists(labels):
-            with open(labels, 'r') as lbl_f:
+            with open(labels, "r") as lbl_f:
                 self.labels = [line.rstrip() for line in lbl_f]
         else:
             self.labels = labels
@@ -50,7 +51,7 @@ class FilesDataset:
         if self.labels:
             if isinstance(self.labels, str):
                 lbl_file = osp.splitext(fname)[0] + "." + self.labels
-                with open(lbl_file, 'r') as lblf:
+                with open(lbl_file, "r") as lblf:
                     lbls = lblf.readline()
                     assert lbls is not None
             else:
@@ -116,24 +117,24 @@ class DatasetWriter:

         assert len(files) > 0

         if self.args.shard is not None:
-            files = files[self.args.shard::self.args.num_shards]
+            files = files[self.args.shard :: self.args.num_shards]

         lbls = []
-        with open(self.data_file(split), 'w') as srcf:
+        with open(self.data_file(split), "w") as srcf:
             for line, lbl in self.iterate(files):
                 print(line, file=srcf)
                 if self.args.labels:
-                    lbls.append(lbl + '\n')
+                    lbls.append(lbl + "\n")

         if self.args.labels:
             assert all(a is not None for a in lbls)
-            with open(self.lbl_file(split), 'w') as lblf:
+            with open(self.lbl_file(split), "w") as lblf:
                 lblf.writelines(lbls)

     def iterate(self, files):
         data = self.load_data(files)
-        for samples in tqdm.tqdm(data, total=len(files)//32):
+        for samples in tqdm.tqdm(data, total=len(files) // 32):
             for wav, lbl in samples:
                 x = wav.unsqueeze(0).float().cuda()
@@ -162,7 +163,6 @@ class DatasetWriter:
             idx = torch.cat(result, dim=0)

             yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl

-
     def lbl_file(self, name):
         shard_part = "" if self.args.shard is None else f".{self.args.shard}"
         return osp.join(self.output_dir, f"{name}.lbl{shard_part}")
@@ -230,7 +230,9 @@ class DatasetWriter:

         self.process_splits()

-        if hasattr(self.model.feature_extractor, "vars") and (self.args.shard is None or self.args.shard == 0):
+        if hasattr(self.model.feature_extractor, "vars") and (
+            self.args.shard is None or self.args.shard == 0
+        ):
             vars = (
                 self.model.feature_extractor.vars.view(
                     self.model.feature_extractor.banks,
@@ -248,4 +250,4 @@ if __name__ == "__main__":
     write_data = DatasetWriter()
     write_data()

     print("Done.")
View File
@@ -14,13 +14,12 @@ import os
 from shutil import copy

 import h5py
-import soundfile as sf
 import numpy as np
+import soundfile as sf
 import torch
-from torch import nn
 import tqdm
-
 from fairseq.models.wav2vec.wav2vec import Wav2VecModel
+from torch import nn


 def read_audio(fname):
@@ -33,7 +32,6 @@ def read_audio(fname):

 class PretrainedWav2VecModel(nn.Module):
-
     def __init__(self, fname):
         super().__init__()
@@ -55,32 +53,33 @@ class PretrainedWav2VecModel(nn.Module):

 class EmbeddingWriterConfig(argparse.ArgumentParser):
     def __init__(self):
         super().__init__("Pre-compute embeddings for wav2letter++ datasets")

         kwargs = {"action": "store", "type": str, "required": True}

-        self.add_argument("--input", "-i",
-                          help="Input Directory", **kwargs)
-        self.add_argument("--output", "-o",
-                          help="Output Directory", **kwargs)
-        self.add_argument("--model",
-                          help="Path to model checkpoint", **kwargs)
-        self.add_argument("--split",
-                          help="Dataset Splits", nargs='+', **kwargs)
-        self.add_argument("--ext", default="wav", required=False,
-                          help="Audio file extension")
-
-        self.add_argument("--no-copy-labels", action="store_true",
-                          help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.")
-        self.add_argument("--use-feat", action="store_true",
-                          help="Use the feature vector ('z') instead of context vector ('c') for features")
-        self.add_argument("--gpu",
-                          help="GPU to use", default=0, type=int)
+        self.add_argument("--input", "-i", help="Input Directory", **kwargs)
+        self.add_argument("--output", "-o", help="Output Directory", **kwargs)
+        self.add_argument("--model", help="Path to model checkpoint", **kwargs)
+        self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs)
+        self.add_argument(
+            "--ext", default="wav", required=False, help="Audio file extension"
+        )
+
+        self.add_argument(
+            "--no-copy-labels",
+            action="store_true",
+            help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.",
+        )
+        self.add_argument(
+            "--use-feat",
+            action="store_true",
+            help="Use the feature vector ('z') instead of context vector ('c') for features",
+        )
+        self.add_argument("--gpu", help="GPU to use", default=0, type=int)


-class Prediction():
+class Prediction:
     """ Lightweight wrapper around a fairspeech embedding model """

     def __init__(self, fname, gpu=0):
@@ -95,7 +94,7 @@ class Prediction():
         return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy()


-class H5Writer():
+class H5Writer:
     """ Write features as hdf5 file in wav2letter++ compatible format """

     def __init__(self, fname):
@@ -112,7 +111,7 @@ class H5Writer():

 class EmbeddingDatasetWriter(object):
-    """ Given a model and a wav2letter++ dataset, pre-compute and store embeddings
+    """Given a model and a wav2letter++ dataset, pre-compute and store embeddings

     Args:
         input_root, str :
@@ -123,13 +122,17 @@ class EmbeddingDatasetWriter(object):
             Dataset split
     """

-    def __init__(self, input_root, output_root, split,
-                 model_fname,
-                 extension="wav",
-                 gpu=0,
-                 verbose=False,
-                 use_feat=False,
-                 ):
+    def __init__(
+        self,
+        input_root,
+        output_root,
+        split,
+        model_fname,
+        extension="wav",
+        gpu=0,
+        verbose=False,
+        use_feat=False,
+    ):

         assert os.path.exists(model_fname)
@@ -143,8 +146,9 @@ class EmbeddingDatasetWriter(object):
         self.extension = extension
         self.use_feat = use_feat

-        assert os.path.exists(self.input_path), \
-            "Input path '{}' does not exist".format(self.input_path)
+        assert os.path.exists(self.input_path), "Input path '{}' does not exist".format(
+            self.input_path
+        )

     def _progress(self, iterable, **kwargs):
         if self.verbose:
@@ -176,7 +180,11 @@ class EmbeddingDatasetWriter(object):
     def copy_labels(self):
         self.require_output_path()

-        labels = list(filter(lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))))
+        labels = list(
+            filter(
+                lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))
+            )
+        )
         for fname in tqdm.tqdm(labels):
             copy(fname, self.output_path)
@@ -191,10 +199,16 @@ class EmbeddingDatasetWriter(object):

         paths = self.input_fnames

-        fnames_context = map(lambda x: os.path.join(self.output_path, x.replace("." + self.extension, ".h5context")), \
-                             map(os.path.basename, paths))
+        fnames_context = map(
+            lambda x: os.path.join(
+                self.output_path, x.replace("." + self.extension, ".h5context")
+            ),
+            map(os.path.basename, paths),
+        )

-        for name, target_fname in self._progress(zip(paths, fnames_context), total=len(self)):
+        for name, target_fname in self._progress(
+            zip(paths, fnames_context), total=len(self)
+        ):
             wav, sr = read_audio(name)
             z, c = self.model(wav)
             feat = z if self.use_feat else c
@@ -204,7 +218,8 @@ class EmbeddingDatasetWriter(object):

     def __repr__(self):
         return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format(
-            n_files=len(self), **self.__dict__)
+            n_files=len(self), **self.__dict__
+        )


 if __name__ == "__main__":
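For orientation, a minimal sketch of using the writer class directly rather than through the CLI entry point; every path is illustrative, and only the constructor, copy_labels, and __repr__ shown above are assumed:

    writer = EmbeddingDatasetWriter(
        input_root="/data/wav2letter",          # hypothetical dataset root
        output_root="/data/wav2letter-emb",     # hypothetical output root
        split="train",
        model_fname="/checkpoints/wav2vec.pt",  # hypothetical checkpoint path
        use_feat=False,  # False -> context vectors ('c'), True -> feature vectors ('z')
    )
    print(writer)         # summary string produced by the __repr__ above
    writer.copy_labels()  # copies non-audio files (e.g. labels) alongside the embeddings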
View File
@@ -10,32 +10,50 @@ Data pre-processing: build vocabularies and binarize training data.

 import argparse
 import glob
 import os
-import soundfile
 import random
+
+import soundfile


 def get_parser():
     parser = argparse.ArgumentParser()
-    parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index')
-    parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D',
-                        help='percentage of data to use as validation set (between 0 and 1)')
-    parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory')
-    parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for')
-    parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed')
-    parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG',
-                        help='if set, path must contain this substring for a file to be included in the manifest')
+    parser.add_argument(
+        "root", metavar="DIR", help="root directory containing flac files to index"
+    )
+    parser.add_argument(
+        "--valid-percent",
+        default=0.01,
+        type=float,
+        metavar="D",
+        help="percentage of data to use as validation set (between 0 and 1)",
+    )
+    parser.add_argument(
+        "--dest", default=".", type=str, metavar="DIR", help="output directory"
+    )
+    parser.add_argument(
+        "--ext", default="flac", type=str, metavar="EXT", help="extension to look for"
+    )
+    parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed")
+    parser.add_argument(
+        "--path-must-contain",
+        default=None,
+        type=str,
+        metavar="FRAG",
+        help="if set, path must contain this substring for a file to be included in the manifest",
+    )
     return parser


 def main(args):
-    assert args.valid_percent >= 0 and args.valid_percent <= 1.
+    assert args.valid_percent >= 0 and args.valid_percent <= 1.0

     dir_path = os.path.realpath(args.root)
-    search_path = os.path.join(dir_path, '**/*.' + args.ext)
+    search_path = os.path.join(dir_path, "**/*." + args.ext)
     rand = random.Random(args.seed)

-    with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open(
-            os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f:
+    with open(os.path.join(args.dest, "train.tsv"), "w") as train_f, open(
+        os.path.join(args.dest, "valid.tsv"), "w"
+    ) as valid_f:
         print(dir_path, file=train_f)
         print(dir_path, file=valid_f)
@@ -47,10 +65,12 @@ def main(args):
             frames = soundfile.info(fname).frames
             dest = train_f if rand.random() > args.valid_percent else valid_f
-            print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest)
+            print(
+                "{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest
+            )


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = get_parser()
     args = parser.parse_args()
     main(args)
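The manifest format written here is minimal: the first line of each .tsv is the resolved root directory, and every subsequent line is a tab-separated relative path and frame count. A sketch of invoking the script programmatically, with illustrative paths and only the flags defined in get_parser above:

    parser = get_parser()
    args = parser.parse_args(
        ["/datasets/LibriSpeech", "--dest", "manifests", "--ext", "flac"]
    )
    main(args)  # writes manifests/train.tsv and manifests/valid.tsv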
View File
@@ -4,16 +4,17 @@
 # LICENSE file in the root directory of this source tree.
 """isort:skip_file"""

-__all__ = ['pdb']
-__version__ = '1.0.0a0'
+__all__ = ["pdb"]
+__version__ = "1.0.0a0"

 import sys

 # backwards compatibility to support `from fairseq.meters import AverageMeter`
 from fairseq.logging import meters, metrics, progress_bar  # noqa

-sys.modules['fairseq.meters'] = meters
-sys.modules['fairseq.metrics'] = metrics
-sys.modules['fairseq.progress_bar'] = progress_bar
+sys.modules["fairseq.meters"] = meters
+sys.modules["fairseq.metrics"] = metrics
+sys.modules["fairseq.progress_bar"] = progress_bar

 import fairseq.criterions  # noqa
 import fairseq.models  # noqa
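The sys.modules aliases are what keep pre-reorganization import paths alive: once fairseq has been imported, the old top-level module names resolve to the new fairseq.logging submodules, exactly as the comment in the file promises:

    import fairseq  # installs the aliases above as a side effect
    from fairseq.meters import AverageMeter  # actually fairseq.logging.meters.AverageMeter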
View File
@@ -4,9 +4,4 @@
 # LICENSE file in the root directory of this source tree.

 # import models/tasks to register them
-from . import (  # noqa
-    dummy_lm,
-    dummy_masked_lm,
-    dummy_model,
-    dummy_mt,
-)
+from . import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt  # noqa
View File
@@ -7,25 +7,27 @@ import logging

 import numpy as np
 import torch
 from fairseq.data import Dictionary, FairseqDataset
-from fairseq.tasks import register_task, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask, register_task


 logger = logging.getLogger(__name__)


-@register_task('dummy_lm')
+@register_task("dummy_lm")
 class DummyLMTask(LegacyFairseqTask):
     @staticmethod
     def add_args(parser):
         """Add task-specific arguments to the parser."""
-        parser.add_argument('--dict-size', default=49996, type=int)
-        parser.add_argument('--dataset-size', default=100000, type=int)
-        parser.add_argument('--tokens-per-sample', default=512, type=int,
-                            help='max number of total tokens over all segments '
-                                 'per sample for BERT dataset')
+        parser.add_argument("--dict-size", default=49996, type=int)
+        parser.add_argument("--dataset-size", default=100000, type=int)
+        parser.add_argument(
+            "--tokens-per-sample",
+            default=512,
+            type=int,
+            help="max number of total tokens over all segments "
+            "per sample for BERT dataset",
+        )

     def __init__(self, args, dictionary):
         super().__init__(args)
@@ -44,8 +46,8 @@ class DummyLMTask(LegacyFairseqTask):
         """Setup the task. """
         dictionary = Dictionary()
         for i in range(args.dict_size):
-            dictionary.add_symbol('word{}'.format(i))
-        logger.info('dictionary: {} types'.format(len(dictionary)))
+            dictionary.add_symbol("word{}".format(i))
+        logger.info("dictionary: {} types".format(len(dictionary)))
         return cls(args, dictionary)

     def load_dataset(self, split, epoch=1, combine=False, **kwargs):
@@ -59,16 +61,16 @@ class DummyLMTask(LegacyFairseqTask):
         bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
         self.datasets[split] = DummyDataset(
             {
-                'id': 1,
-                'net_input': {
-                    'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
-                    'src_lengths': torch.full(
-                        (bsz, ), self.args.tokens_per_sample, dtype=torch.long
+                "id": 1,
+                "net_input": {
+                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                    "src_lengths": torch.full(
+                        (bsz,), self.args.tokens_per_sample, dtype=torch.long
                     ),
                 },
-                'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
-                'nsentences': bsz,
-                'ntokens': bsz * self.args.tokens_per_sample,
+                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
+                "nsentences": bsz,
+                "ntokens": bsz * self.args.tokens_per_sample,
             },
             num_items=self.args.dataset_size,
             item_size=self.args.tokens_per_sample,
@@ -84,7 +86,6 @@ class DummyLMTask(LegacyFairseqTask):

 class DummyDataset(FairseqDataset):
-
     def __init__(self, batch, num_items, item_size):
         super().__init__()
         self.batch = batch
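The batch construction above also fixes the effective batch size: every batch is identical, sized so that bsz * tokens_per_sample stays within --max-tokens, which lets the trainer be benchmarked without any real data loading. With illustrative settings:

    max_tokens, tokens_per_sample = 65536, 512
    bsz = max(1, max_tokens // tokens_per_sample)  # 128 identical rows per dummy batch
    assert bsz == 128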
View File
@@ -7,32 +7,34 @@ import logging

 import numpy as np
 import torch
 from fairseq.data import Dictionary, FairseqDataset
-from fairseq.tasks import register_task, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask, register_task


 logger = logging.getLogger(__name__)


-@register_task('dummy_masked_lm')
+@register_task("dummy_masked_lm")
 class DummyMaskedLMTask(LegacyFairseqTask):
     @staticmethod
     def add_args(parser):
         """Add task-specific arguments to the parser."""
-        parser.add_argument('--dict-size', default=49995, type=int)
-        parser.add_argument('--dataset-size', default=100000, type=int)
-        parser.add_argument('--tokens-per-sample', default=512, type=int,
-                            help='max number of total tokens over all segments '
-                                 'per sample for BERT dataset')
+        parser.add_argument("--dict-size", default=49995, type=int)
+        parser.add_argument("--dataset-size", default=100000, type=int)
+        parser.add_argument(
+            "--tokens-per-sample",
+            default=512,
+            type=int,
+            help="max number of total tokens over all segments "
+            "per sample for BERT dataset",
+        )

     def __init__(self, args, dictionary):
         super().__init__(args)
         self.dictionary = dictionary

         # add mask token
-        self.mask_idx = dictionary.add_symbol('<mask>')
+        self.mask_idx = dictionary.add_symbol("<mask>")
         dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

         mask_idx = 0
@@ -52,8 +54,8 @@ class DummyMaskedLMTask(LegacyFairseqTask):
         """Setup the task. """
         dictionary = Dictionary()
         for i in range(args.dict_size):
-            dictionary.add_symbol('word{}'.format(i))
-        logger.info('dictionary: {} types'.format(len(dictionary)))
+            dictionary.add_symbol("word{}".format(i))
+        logger.info("dictionary: {} types".format(len(dictionary)))
         return cls(args, dictionary)

     def load_dataset(self, split, epoch=1, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs): def load_dataset(self, split, epoch=1, combine=False, **kwargs):
@ -67,16 +69,16 @@ class DummyMaskedLMTask(LegacyFairseqTask):
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
self.datasets[split] = DummyDataset( self.datasets[split] = DummyDataset(
{ {
'id': 1, "id": 1,
'net_input': { "net_input": {
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]), "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
'src_lengths': torch.full( "src_lengths": torch.full(
(bsz, ), self.args.tokens_per_sample, dtype=torch.long (bsz,), self.args.tokens_per_sample, dtype=torch.long
), ),
}, },
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]), "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
'nsentences': bsz, "nsentences": bsz,
'ntokens': bsz * self.args.tokens_per_sample, "ntokens": bsz * self.args.tokens_per_sample,
}, },
num_items=self.args.dataset_size, num_items=self.args.dataset_size,
item_size=self.args.tokens_per_sample, item_size=self.args.tokens_per_sample,
@@ -92,7 +94,6 @@ class DummyMaskedLMTask(LegacyFairseqTask):

 class DummyDataset(FairseqDataset):
-
     def __init__(self, batch, num_items, item_size):
         super().__init__()
         self.batch = batch
View File
@@ -5,7 +5,6 @@

 import torch.nn as nn
 import torch.nn.functional as F
-
 from fairseq.data import Dictionary
 from fairseq.models import (
     FairseqDecoder,
@@ -15,17 +14,16 @@ from fairseq.models import (
 )


-@register_model('dummy_model')
+@register_model("dummy_model")
 class DummyModel(FairseqLanguageModel):
-
     def __init__(self, args, encoder):
         super().__init__(encoder)
         self.args = args

     @staticmethod
     def add_args(parser):
-        parser.add_argument('--num-layers', type=int, default=24)
-        parser.add_argument('--embed-dim', type=int, default=1024)
+        parser.add_argument("--num-layers", type=int, default=24)
+        parser.add_argument("--embed-dim", type=int, default=1024)

     @classmethod
     def build_model(cls, args, task):
@@ -41,32 +39,35 @@ class DummyModel(FairseqLanguageModel):
 class DummyEncoder(FairseqDecoder):
     def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
         super().__init__(Dictionary())
         self.embed = nn.Embedding(
             num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
         )
-        self.layers_a = nn.ModuleList([
-            nn.Sequential(
-                nn.LayerNorm(embed_dim),
-                nn.Linear(embed_dim, 3*embed_dim),  # q, k, v input projection
-                nn.Linear(3*embed_dim, embed_dim),  # skip self-attention
-                nn.Linear(embed_dim, embed_dim),  # output projection
-                nn.Dropout(),
-            )
-            for i in range(num_layers)
-        ])
-        self.layers_b = nn.ModuleList([
-            nn.Sequential(
-                nn.LayerNorm(embed_dim),
-                nn.Linear(embed_dim, 4*embed_dim),  # FFN
-                nn.ReLU(),
-                nn.Linear(4*embed_dim, embed_dim),  # FFN
-                nn.Dropout(0.1),
-            )
-            for i in range(num_layers)
-        ])
+        self.layers_a = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.LayerNorm(embed_dim),
+                    nn.Linear(embed_dim, 3 * embed_dim),  # q, k, v input projection
+                    nn.Linear(3 * embed_dim, embed_dim),  # skip self-attention
+                    nn.Linear(embed_dim, embed_dim),  # output projection
+                    nn.Dropout(),
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.layers_b = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.LayerNorm(embed_dim),
+                    nn.Linear(embed_dim, 4 * embed_dim),  # FFN
+                    nn.ReLU(),
+                    nn.Linear(4 * embed_dim, embed_dim),  # FFN
+                    nn.Dropout(0.1),
+                )
+                for i in range(num_layers)
+            ]
+        )
         self.out_proj = nn.Linear(embed_dim, num_embed)

     def forward(self, tokens, masked_tokens=None):
@@ -90,6 +91,6 @@ class DummyEncoder(FairseqDecoder):
         return F.softmax(logits, dim=-1)


-@register_model_architecture('dummy_model', 'dummy_model')
+@register_model_architecture("dummy_model", "dummy_model")
 def base_architecture(args):
     pass
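The two module lists mimic a Transformer block's cost profile: layers_a stands in for attention (a d-to-3d q/k/v projection, a 3d-to-d projection in place of actual self-attention, and a d-to-d output projection) while layers_b is the 4d feed-forward block. A back-of-the-envelope weight count at the defaults (d = 1024, 24 layers), ignoring biases and LayerNorm:

    d = 1024
    per_layer_a = d * 3 * d + 3 * d * d + d * d  # 7 * d**2, about 7.3M weights
    per_layer_b = d * 4 * d + 4 * d * d          # 8 * d**2, about 8.4M weights
    total = 24 * (per_layer_a + per_layer_b)     # roughly 377M weights, plus embeddings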
View File
@@ -7,24 +7,22 @@ import logging

 import numpy as np
 import torch
-
 from fairseq.data import Dictionary, FairseqDataset
-from fairseq.tasks import register_task, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask, register_task


 logger = logging.getLogger(__name__)


-@register_task('dummy_mt')
+@register_task("dummy_mt")
 class DummyMTTask(LegacyFairseqTask):
-
     @staticmethod
     def add_args(parser):
         """Add task-specific arguments to the parser."""
-        parser.add_argument('--dict-size', default=49996, type=int)
-        parser.add_argument('--dataset-size', default=100000, type=int)
-        parser.add_argument('--src-len', default=30, type=int)
-        parser.add_argument('--tgt-len', default=30, type=int)
+        parser.add_argument("--dict-size", default=49996, type=int)
+        parser.add_argument("--dataset-size", default=100000, type=int)
+        parser.add_argument("--src-len", default=30, type=int)
+        parser.add_argument("--tgt-len", default=30, type=int)

     def __init__(self, args, dictionary):
         super().__init__(args)
@@ -41,8 +39,8 @@ class DummyMTTask(LegacyFairseqTask):
         """Setup the task. """
         dictionary = Dictionary()
         for i in range(args.dict_size):
-            dictionary.add_symbol('word{}'.format(i))
-        logger.info('dictionary: {} types'.format(len(dictionary)))
+            dictionary.add_symbol("word{}".format(i))
+        logger.info("dictionary: {} types".format(len(dictionary)))

         args.max_source_positions = args.src_len + dictionary.pad() + 2
         args.max_target_positions = args.tgt_len + dictionary.pad() + 2
@@ -62,17 +60,17 @@ class DummyMTTask(LegacyFairseqTask):
         tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
         self.datasets[split] = DummyDataset(
             {
-                'id': 1,
-                'net_input': {
-                    'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
-                    'src_lengths': torch.full(
-                        (bsz, ), self.args.src_len, dtype=torch.long
+                "id": 1,
+                "net_input": {
+                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                    "src_lengths": torch.full(
+                        (bsz,), self.args.src_len, dtype=torch.long
                     ),
-                    'prev_output_tokens': tgt.clone(),
+                    "prev_output_tokens": tgt.clone(),
                 },
-                'target': tgt,
-                'nsentences': bsz,
-                'ntokens': bsz * self.args.tgt_len,
+                "target": tgt,
+                "nsentences": bsz,
+                "ntokens": bsz * self.args.tgt_len,
             },
             num_items=self.args.dataset_size,
             item_size=item_size,
@@ -88,7 +86,6 @@ class DummyMTTask(LegacyFairseqTask):

 class DummyDataset(FairseqDataset):
-
     def __init__(self, batch, num_items, item_size):
         super().__init__()
         self.batch = batch
View File
@@ -6,9 +6,10 @@

 import os
 from collections import Counter
-from fairseq.tokenizer import tokenize_line

 import torch
+
 from fairseq.file_io import PathManager
+from fairseq.tokenizer import tokenize_line


 def safe_readline(f):
     pos = f.tell()
View File
@@ -67,12 +67,14 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
         or is_better(val_loss, save_checkpoint.best)
     )
     if val_loss is not None and args.keep_best_checkpoints > 0:
-        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
-            args.best_checkpoint_metric, val_loss)] = (
-            not hasattr(save_checkpoint, "best")
-            or is_better(val_loss, save_checkpoint.best)
-        )
-    checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = not args.no_last_checkpoints
+        checkpoint_conds[
+            "checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss)
+        ] = not hasattr(save_checkpoint, "best") or is_better(
+            val_loss, save_checkpoint.best
+        )
+    checkpoint_conds[
+        "checkpoint_last{}.pt".format(suffix)
+    ] = not args.no_last_checkpoints

     extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
     if hasattr(save_checkpoint, "best"):
@@ -112,10 +114,14 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     if args.keep_best_checkpoints > 0:
         # only keep the best N checkpoints according to validation metric
         checkpoints = checkpoint_paths(
-            args.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(args.best_checkpoint_metric))
+            args.save_dir,
+            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
+                args.best_checkpoint_metric
+            ),
+        )
         if not args.maximize_best_checkpoint_metric:
             checkpoints = checkpoints[::-1]
-        for old_chk in checkpoints[args.keep_best_checkpoints:]:
+        for old_chk in checkpoints[args.keep_best_checkpoints :]:
             if os.path.lexists(old_chk):
                 os.remove(old_chk)
@@ -133,16 +139,23 @@ def load_checkpoint(args, trainer, **passthrough_args):
     reset_meters = args.reset_meters
     reset_dataloader = args.reset_dataloader

-    if getattr(args, 'finetune_from_model', None) is not None \
-       and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader):
-        raise ValueError("--finetune-from-model can not be set together with either --reset-optimizer"
-                         " or reset_lr_scheduler or reset_meters or reset_dataloader")
+    if getattr(args, "finetune_from_model", None) is not None and (
+        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
+    ):
+        raise ValueError(
+            "--finetune-from-model can not be set together with either --reset-optimizer"
+            " or reset_lr_scheduler or reset_meters or reset_dataloader"
+        )

     suffix = getattr(args, "checkpoint_suffix", "")
-    if args.restore_file == "checkpoint_last.pt":  # default value of restore_file is 'checkpoint_last.pt'
-        checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix))
+    if (
+        args.restore_file == "checkpoint_last.pt"
+    ):  # default value of restore_file is 'checkpoint_last.pt'
+        checkpoint_path = os.path.join(
+            args.save_dir, "checkpoint_last{}.pt".format(suffix)
+        )
         first_launch = not PathManager.exists(checkpoint_path)
-        if getattr(args, 'finetune_from_model', None) is not None and first_launch:
+        if getattr(args, "finetune_from_model", None) is not None and first_launch:
             # if there is no last checkpoint to restore, start the finetune from pretrained model
             # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
             if PathManager.exists(args.finetune_from_model):
@@ -151,19 +164,26 @@ def load_checkpoint(args, trainer, **passthrough_args):
                 reset_lr_scheduler = True
                 reset_meters = True
                 reset_dataloader = True
-                logger.info(f'loading pretrained model from {checkpoint_path}: '
-                            'optimizer, lr scheduler, meters, dataloader will be reset')
+                logger.info(
+                    f"loading pretrained model from {checkpoint_path}: "
+                    "optimizer, lr scheduler, meters, dataloader will be reset"
+                )
             else:
-                raise ValueError(f'--funetune-from-model {args.finetune_from_model} does not exist')
+                raise ValueError(
+                    f"--finetune-from-model {args.finetune_from_model} does not exist"
+                )
     elif getattr(args, "model_parallel_size", 1) > 1:
         checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
     else:
         checkpoint_path = args.restore_file

-    if args.restore_file != "checkpoint_last.pt" and getattr(args, 'finetune_from_model', None):
+    if args.restore_file != "checkpoint_last.pt" and getattr(
+        args, "finetune_from_model", None
+    ):
         raise ValueError(
-            '--finetune-from-model and --restore-file (non-default value) '
-            'can not be specified together: ' + str(args))
+            "--finetune-from-model and --restore-file (non-default value) "
+            "can not be specified together: " + str(args)
+        )

     extra_state = trainer.load_checkpoint(
         checkpoint_path,
@@ -213,7 +233,9 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
     return state


-def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1):
+def load_model_ensemble(
+    filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
+):
     """Loads an ensemble of models.

     Args:
@@ -222,18 +244,28 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, s
             were used during model training
         task (fairseq.tasks.FairseqTask, optional): task to use for loading
     """
-    assert not (strict and num_shards > 1), \
-        "Cannot load state dict with strict=True and checkpoint shards > 1"
+    assert not (
+        strict and num_shards > 1
+    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
     ensemble, args, _task = load_model_ensemble_and_task(
-        filenames, arg_overrides, task, strict, suffix, num_shards,
+        filenames,
+        arg_overrides,
+        task,
+        strict,
+        suffix,
+        num_shards,
     )
     return ensemble, args


-def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1):
+def load_model_ensemble_and_task(
+    filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
+):
     from fairseq import tasks
-    assert not (strict and num_shards > 1), \
-        "Cannot load state dict with strict=True and checkpoint shards > 1"
+
+    assert not (
+        strict and num_shards > 1
+    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
     ensemble = []
     for filename in filenames:
         orig_filename = filename
@@ -533,7 +565,9 @@ def verify_checkpoint_directory(save_dir: str) -> None:
         with open(temp_file_path, "w"):
             pass
     except OSError as e:
-        logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
+        logger.warning(
+            "Unable to access checkpoint save directory: {}".format(save_dir)
+        )
         raise e
     else:
         os.remove(temp_file_path)
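A minimal sketch of calling the ensemble loader above; the checkpoint paths and the arg_overrides key are illustrative, not prescribed by this file:

    models, args = load_model_ensemble(
        ["checkpoints/run1/checkpoint_best.pt", "checkpoints/run2/checkpoint_best.pt"],
        arg_overrides={"data": "/path/to/data-bin"},  # override args stored in the checkpoint
    )
    for model in models:
        model.eval()  # ensembles are typically used for inference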
Some files were not shown because too many files have changed in this diff.