Apply black+isort (#1357)
Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1357

Reviewed By: alexeib

Differential Revision: D24377772

fbshipit-source-id: 51581af041d42d62166b33a35a1a4228b1a76f0c
parent 5695cdfb2c · commit a48f235636
docs/conf.py (57 lines changed)
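
This commit is a pure formatting pass over the tree. As a rough sketch (not part of the commit itself), a pass like it can usually be reproduced with the two tools named in the title; this assumes `black` and `isort` are installed and that isort runs with its black-compatible profile:

# Sketch only: re-run the kind of formatting pass this commit applied.
# Assumes `pip install black isort`; "--profile black" keeps isort's
# import style consistent with black's output.
import subprocess


def format_tree(root: str = ".") -> None:
    subprocess.check_call(["isort", "--profile", "black", root])  # sort imports
    subprocess.check_call(["black", root])  # apply code style


if __name__ == "__main__":
    format_tree()

Everything below is the commit's diff: lines prefixed with "-" show the code before the change, "+" after.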
@@ -20,10 +20,11 @@
 import os
 import sys
 
-# source code directory, relative to this file, for sphinx-autobuild
-sys.path.insert(0, os.path.abspath('..'))
-
-source_suffix = ['.rst']
+# source code directory, relative to this file, for sphinx-autobuild
+sys.path.insert(0, os.path.abspath(".."))
+
+source_suffix = [".rst"]
 
 # -- General configuration ------------------------------------------------
 
@@ -35,34 +36,34 @@ source_suffix = ['.rst']
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.napoleon',
-    'sphinxarg.ext',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.napoleon",
+    "sphinxarg.ext",
 ]
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # The master toctree document.
-master_doc = 'index'
+master_doc = "index"
 
 # General information about the project.
-project = 'fairseq'
-copyright = '2019, Facebook AI Research (FAIR)'
-author = 'Facebook AI Research (FAIR)'
+project = "fairseq"
+copyright = "2019, Facebook AI Research (FAIR)"
+author = "Facebook AI Research (FAIR)"
 
-github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
+github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/"
 
 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
 # built documents.
 #
 # The short X.Y version.
-version = '0.9.0'
+version = "0.9.0"
 # The full version, including alpha/beta/rc tags.
-release = '0.9.0'
+release = "0.9.0"
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
@@ -74,11 +75,11 @@ language = None
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This patterns also effect to html_static_path and html_extra_path
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
 
 # The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
-highlight_language = 'python'
+pygments_style = "sphinx"
+highlight_language = "python"
 
 # If true, `todo` and `todoList` produce output, else they produce nothing.
 todo_include_todos = False
@@ -89,7 +90,7 @@ todo_include_todos = False
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"
 
 # Theme options are theme-specific and customize the look and feel of a theme
 # further.  For a list of options available for each theme, see the
@@ -100,11 +101,11 @@ html_theme = 'sphinx_rtd_theme'
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
 
 html_context = {
-    'css_files': [
-        '_static/theme_overrides.css',  # override wide tables in RTD theme
+    "css_files": [
+        "_static/theme_overrides.css",  # override wide tables in RTD theme
     ],
 }
 
@@ -113,7 +114,7 @@ html_context = {
 #
 # This is required for the alabaster theme
 # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
-#html_sidebars = {
+# html_sidebars = {
 #     '**': [
 #         'about.html',
 #         'navigation.html',
@@ -121,12 +122,12 @@ html_context = {
 #         'searchbox.html',
 #         'donate.html',
 #     ]
-#}
+# }
 
 
 # Example configuration for intersphinx: refer to the Python standard library.
 intersphinx_mapping = {
-    'numpy': ('http://docs.scipy.org/doc/numpy/', None),
-    'python': ('https://docs.python.org/', None),
-    'torch': ('https://pytorch.org/docs/master/', None),
+    "numpy": ("http://docs.scipy.org/doc/numpy/", None),
+    "python": ("https://docs.python.org/", None),
+    "torch": ("https://pytorch.org/docs/master/", None),
 }
 
@@ -3,6 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-__version__ = '0.9.0'
+__version__ = "0.9.0"
 
 import examples.noisychannel  # noqa
 
@@ -7,8 +7,8 @@
 import argparse
 import fileinput
 import hashlib
-from multiprocessing import Pool
 import sys
+from multiprocessing import Pool
 
 
 def get_hashes_and_lines(raw_line):
@@ -18,12 +18,12 @@ def get_hashes_and_lines(raw_line):
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--workers', type=int, default=10)
-    parser.add_argument('files', nargs='*', help='input files')
+    parser.add_argument("--workers", type=int, default=10)
+    parser.add_argument("files", nargs="*", help="input files")
     args = parser.parse_args()
 
     seen = set()
-    with fileinput.input(args.files, mode='rb') as h:
+    with fileinput.input(args.files, mode="rb") as h:
         pool = Pool(args.workers)
         results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
         for i, (hash, raw_line) in enumerate(results):
@@ -37,5 +37,5 @@ def main():
     print(file=sys.stderr, flush=True)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
 
@@ -11,26 +11,38 @@ from tqdm import tqdm
 
 
 def main():
-    parser = argparse.ArgumentParser(description=(
-        'Extract back-translations from the stdout of fairseq-generate. '
-        'If there are multiply hypotheses for a source, we only keep the first one. '
-    ))
-    parser.add_argument('--output', required=True, help='output prefix')
-    parser.add_argument('--srclang', required=True, help='source language (extracted from H-* lines)')
-    parser.add_argument('--tgtlang', required=True, help='target language (extracted from S-* lines)')
-    parser.add_argument('--minlen', type=int, help='min length filter')
-    parser.add_argument('--maxlen', type=int, help='max length filter')
-    parser.add_argument('--ratio', type=float, help='ratio filter')
-    parser.add_argument('files', nargs='*', help='input files')
+    parser = argparse.ArgumentParser(
+        description=(
+            "Extract back-translations from the stdout of fairseq-generate. "
+            "If there are multiply hypotheses for a source, we only keep the first one. "
+        )
+    )
+    parser.add_argument("--output", required=True, help="output prefix")
+    parser.add_argument(
+        "--srclang", required=True, help="source language (extracted from H-* lines)"
+    )
+    parser.add_argument(
+        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
+    )
+    parser.add_argument("--minlen", type=int, help="min length filter")
+    parser.add_argument("--maxlen", type=int, help="max length filter")
+    parser.add_argument("--ratio", type=float, help="ratio filter")
+    parser.add_argument("files", nargs="*", help="input files")
     args = parser.parse_args()
 
     def validate(src, tgt):
-        srclen = len(src.split(' ')) if src != '' else 0
-        tgtlen = len(tgt.split(' ')) if tgt != '' else 0
+        srclen = len(src.split(" ")) if src != "" else 0
+        tgtlen = len(tgt.split(" ")) if tgt != "" else 0
         if (
             (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
-            or (args.maxlen is not None and (srclen > args.maxlen or tgtlen > args.maxlen))
-            or (args.ratio is not None and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio))
+            or (
+                args.maxlen is not None
+                and (srclen > args.maxlen or tgtlen > args.maxlen)
+            )
+            or (
+                args.ratio is not None
+                and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
+            )
         ):
             return False
         return True
@@ -41,19 +53,20 @@ def main():
        except IndexError:
            return default
 
-    with open(args.output + '.' + args.srclang, 'w') as src_h, \
-            open(args.output + '.' + args.tgtlang, 'w') as tgt_h:
+    with open(args.output + "." + args.srclang, "w") as src_h, open(
+        args.output + "." + args.tgtlang, "w"
+    ) as tgt_h:
        for line in tqdm(fileinput.input(args.files)):
-            if line.startswith('S-'):
-                tgt = safe_index(line.rstrip().split('\t'), 1, '')
-            elif line.startswith('H-'):
+            if line.startswith("S-"):
+                tgt = safe_index(line.rstrip().split("\t"), 1, "")
+            elif line.startswith("H-"):
                if tgt is not None:
-                    src = safe_index(line.rstrip().split('\t'), 2, '')
+                    src = safe_index(line.rstrip().split("\t"), 2, "")
                    if validate(src, tgt):
                        print(src, file=src_h)
                        print(tgt, file=tgt_h)
                    tgt = None
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
 
@@ -4,203 +4,251 @@
 # LICENSE file in the root directory of this source tree.
 
 
-import os.path as op
 import argparse
 import os
-from multiprocessing import cpu_count
+import os.path as op
 from collections import namedtuple
-from typing import Optional, List
+from multiprocessing import cpu_count
+from typing import List, Optional
 
 import sentencepiece as sp
 
-from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
-from fairseq.data.encoders.byte_utils import byte_encode
-from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
-from fairseq.data.encoders.characters import Characters
 from fairseq.data.encoders.byte_bpe import ByteBPE
+from fairseq.data.encoders.byte_utils import byte_encode
 from fairseq.data.encoders.bytes import Bytes
+from fairseq.data.encoders.characters import Characters
+from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
+from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
 
 
-SPLITS = ['train', 'valid', 'test']
+SPLITS = ["train", "valid", "test"]
 
 
 def _convert_xml(in_path: str, out_path: str):
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
             ss = s.strip()
-            if not ss.startswith('<seg'):
+            if not ss.startswith("<seg"):
                 continue
-            ss = ss.replace('</seg>', '').split('">')
+            ss = ss.replace("</seg>", "").split('">')
             assert len(ss) == 2
-            f_o.write(ss[1].strip() + '\n')
+            f_o.write(ss[1].strip() + "\n")
 
 
 def _convert_train(in_path: str, out_path: str):
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
             ss = s.strip()
-            if ss.startswith('<'):
+            if ss.startswith("<"):
                 continue
-            f_o.write(ss.strip() + '\n')
+            f_o.write(ss.strip() + "\n")
 
 
 def _get_bytes(in_path: str, out_path: str):
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
-            f_o.write(Bytes.encode(s.strip()) + '\n')
+            f_o.write(Bytes.encode(s.strip()) + "\n")
 
 
 def _get_chars(in_path: str, out_path: str):
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
-            f_o.write(Characters.encode(s.strip()) + '\n')
+            f_o.write(Characters.encode(s.strip()) + "\n")
 
 
 def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
-    Args = namedtuple('Args', ['moses_source_lang', 'moses_target_lang',
-                               'moses_no_dash_splits', 'moses_no_escape'])
-    args = Args(moses_source_lang=src, moses_target_lang=tgt,
-                moses_no_dash_splits=False, moses_no_escape=False)
+    Args = namedtuple(
+        "Args",
+        [
+            "moses_source_lang",
+            "moses_target_lang",
+            "moses_no_dash_splits",
+            "moses_no_escape",
+        ],
+    )
+    args = Args(
+        moses_source_lang=src,
+        moses_target_lang=tgt,
+        moses_no_dash_splits=False,
+        moses_no_escape=False,
+    )
     pretokenizer = MosesTokenizer(args)
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
-            f_o.write(pretokenizer.encode(s.strip()) + '\n')
+            f_o.write(pretokenizer.encode(s.strip()) + "\n")
 
 
 def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
-    with open(out_path, 'w') as f_o:
+    with open(out_path, "w") as f_o:
         for lang in [src, tgt]:
-            with open(f'{in_path_prefix}.{lang}') as f:
+            with open(f"{in_path_prefix}.{lang}") as f:
                 for s in f:
-                    f_o.write(byte_encode(s.strip()) + '\n')
+                    f_o.write(byte_encode(s.strip()) + "\n")
 
 
 def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
     arguments = [
-        f'--input={in_path}', f'--model_prefix={model_prefix}',
-        f'--model_type=bpe', f'--vocab_size={vocab_size}',
-        '--character_coverage=1.0', '--normalization_rule_name=identity',
-        f'--num_threads={cpu_count()}'
+        f"--input={in_path}",
+        f"--model_prefix={model_prefix}",
+        f"--model_type=bpe",
+        f"--vocab_size={vocab_size}",
+        "--character_coverage=1.0",
+        "--normalization_rule_name=identity",
+        f"--num_threads={cpu_count()}",
     ]
-    sp.SentencePieceTrainer.Train(' '.join(arguments))
+    sp.SentencePieceTrainer.Train(" ".join(arguments))
 
 
 def _apply_bbpe(model_path: str, in_path: str, out_path: str):
-    Args = namedtuple('Args', ['sentencepiece_model_path'])
+    Args = namedtuple("Args", ["sentencepiece_model_path"])
     args = Args(sentencepiece_model_path=model_path)
     tokenizer = ByteBPE(args)
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
-            f_o.write(tokenizer.encode(s.strip()) + '\n')
+            f_o.write(tokenizer.encode(s.strip()) + "\n")
 
 
 def _apply_bpe(model_path: str, in_path: str, out_path: str):
-    Args = namedtuple('Args', ['sentencepiece_model'])
+    Args = namedtuple("Args", ["sentencepiece_model"])
     args = Args(sentencepiece_model=model_path)
     tokenizer = SentencepieceBPE(args)
-    with open(in_path) as f, open(out_path, 'w') as f_o:
+    with open(in_path) as f, open(out_path, "w") as f_o:
         for s in f:
-            f_o.write(tokenizer.encode(s.strip()) + '\n')
+            f_o.write(tokenizer.encode(s.strip()) + "\n")
 
 
 def _concat_files(in_paths: List[str], out_path: str):
-    with open(out_path, 'w') as f_o:
+    with open(out_path, "w") as f_o:
         for p in in_paths:
             with open(p) as f:
                 for r in f:
                     f_o.write(r)
 
 
-def preprocess_iwslt17(root: str, src: str, tgt: str, bpe_size: Optional[int],
-                       need_chars: bool, bbpe_size: Optional[int],
-                       need_bytes: bool):
+def preprocess_iwslt17(
+    root: str,
+    src: str,
+    tgt: str,
+    bpe_size: Optional[int],
+    need_chars: bool,
+    bbpe_size: Optional[int],
+    need_bytes: bool,
+):
     # extract bitext
-    in_root = op.join(root, f'{src}-{tgt}')
+    in_root = op.join(root, f"{src}-{tgt}")
     for lang in [src, tgt]:
         _convert_train(
-            op.join(in_root, f'train.tags.{src}-{tgt}.{lang}'),
-            op.join(root, f'train.{lang}')
+            op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
+            op.join(root, f"train.{lang}"),
         )
         _convert_xml(
-            op.join(in_root, f'IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml'),
-            op.join(root, f'valid.{lang}')
+            op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
+            op.join(root, f"valid.{lang}"),
         )
         _convert_xml(
-            op.join(in_root, f'IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml'),
-            op.join(root, f'test.{lang}')
+            op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
+            op.join(root, f"test.{lang}"),
         )
     # pre-tokenize
     for lang in [src, tgt]:
         for split in SPLITS:
-            pretokenize(op.join(root, f'{split}.{lang}'),
-                        op.join(root, f'{split}.moses.{lang}'), src, tgt)
+            pretokenize(
+                op.join(root, f"{split}.{lang}"),
+                op.join(root, f"{split}.moses.{lang}"),
+                src,
+                tgt,
+            )
     # tokenize with BPE vocabulary
     if bpe_size is not None:
         # learn vocabulary
-        concated_train_path = op.join(root, 'train.all')
+        concated_train_path = op.join(root, "train.all")
         _concat_files(
-            [op.join(root, 'train.moses.fr'), op.join(root, 'train.moses.en')],
-            concated_train_path
+            [op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")],
+            concated_train_path,
         )
-        bpe_model_prefix = op.join(root, f'spm_bpe{bpe_size}')
+        bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
         _get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
         os.remove(concated_train_path)
         # apply
         for lang in [src, tgt]:
             for split in SPLITS:
                 _apply_bpe(
-                    bpe_model_prefix + '.model',
-                    op.join(root, f'{split}.moses.{lang}'),
-                    op.join(root, f'{split}.moses.bpe{bpe_size}.{lang}')
+                    bpe_model_prefix + ".model",
+                    op.join(root, f"{split}.moses.{lang}"),
+                    op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
                 )
     # tokenize with bytes vocabulary
     if need_bytes:
         for lang in [src, tgt]:
             for split in SPLITS:
-                _get_bytes(op.join(root, f'{split}.moses.{lang}'),
-                           op.join(root, f'{split}.moses.bytes.{lang}'))
+                _get_bytes(
+                    op.join(root, f"{split}.moses.{lang}"),
+                    op.join(root, f"{split}.moses.bytes.{lang}"),
+                )
     # tokenize with characters vocabulary
     if need_chars:
         for lang in [src, tgt]:
             for split in SPLITS:
-                _get_chars(op.join(root, f'{split}.moses.{lang}'),
-                           op.join(root, f'{split}.moses.chars.{lang}'))
+                _get_chars(
+                    op.join(root, f"{split}.moses.{lang}"),
+                    op.join(root, f"{split}.moses.chars.{lang}"),
+                )
     # tokenize with byte-level BPE vocabulary
     if bbpe_size is not None:
         # learn vocabulary
-        bchar_path = op.join(root, 'train.bchar')
-        _convert_to_bchar(op.join(root, 'train.moses'), src, tgt, bchar_path)
-        bbpe_model_prefix = op.join(root, f'spm_bbpe{bbpe_size}')
+        bchar_path = op.join(root, "train.bchar")
+        _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
+        bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
         _get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
         os.remove(bchar_path)
         # apply
         for lang in [src, tgt]:
             for split in SPLITS:
                 _apply_bbpe(
-                    bbpe_model_prefix + '.model',
-                    op.join(root, f'{split}.moses.{lang}'),
-                    op.join(root, f'{split}.moses.bbpe{bbpe_size}.{lang}')
+                    bbpe_model_prefix + ".model",
+                    op.join(root, f"{split}.moses.{lang}"),
+                    op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
                 )
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--root', type=str, default='data')
-    parser.add_argument('--bpe-vocab', default=None, type=int,
-                        help='Generate tokenized bitext with BPE of size K.'
-                             'Default to None (disabled).')
-    parser.add_argument('--bbpe-vocab', default=None, type=int,
-                        help='Generate tokenized bitext with BBPE of size K.'
-                             'Default to None (disabled).')
-    parser.add_argument('--byte-vocab', action='store_true',
-                        help='Generate tokenized bitext with bytes vocabulary')
-    parser.add_argument('--char-vocab', action='store_true',
-                        help='Generate tokenized bitext with chars vocabulary')
+    parser.add_argument("--root", type=str, default="data")
+    parser.add_argument(
+        "--bpe-vocab",
+        default=None,
+        type=int,
+        help="Generate tokenized bitext with BPE of size K."
+        "Default to None (disabled).",
+    )
+    parser.add_argument(
+        "--bbpe-vocab",
+        default=None,
+        type=int,
+        help="Generate tokenized bitext with BBPE of size K."
+        "Default to None (disabled).",
+    )
+    parser.add_argument(
+        "--byte-vocab",
+        action="store_true",
+        help="Generate tokenized bitext with bytes vocabulary",
+    )
+    parser.add_argument(
+        "--char-vocab",
+        action="store_true",
+        help="Generate tokenized bitext with chars vocabulary",
+    )
     args = parser.parse_args()
 
-    preprocess_iwslt17(args.root, 'fr', 'en', args.bpe_vocab, args.char_vocab,
-                       args.bbpe_vocab, args.byte_vocab)
+    preprocess_iwslt17(
+        args.root,
+        "fr",
+        "en",
+        args.bpe_vocab,
+        args.char_vocab,
+        args.bbpe_vocab,
+        args.byte_vocab,
+    )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
 
@@ -11,7 +11,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from fairseq.models import register_model, register_model_architecture
-from fairseq.models.transformer import TransformerModel, TransformerEncoder
+from fairseq.models.transformer import TransformerEncoder, TransformerModel
 
 
 @register_model("gru_transformer")
@@ -24,9 +24,12 @@ class GRUTransformerModel(TransformerModel):
 class GRUTransformerEncoder(TransformerEncoder):
     def __init__(self, args, dictionary, embed_tokens):
         super().__init__(args, dictionary, embed_tokens)
-        self.emb_ctx = nn.GRU(input_size=embed_tokens.embedding_dim,
-                              hidden_size=embed_tokens.embedding_dim // 2,
-                              num_layers=1, bidirectional=True)
+        self.emb_ctx = nn.GRU(
+            input_size=embed_tokens.embedding_dim,
+            hidden_size=embed_tokens.embedding_dim // 2,
+            num_layers=1,
+            bidirectional=True,
+        )
 
     def forward_embedding(self, src_tokens):
         # embed tokens and positions
 
@@ -16,11 +16,12 @@ def main(args):
         print(normalizer.normalize(line.rstrip()), flush=True)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser()
-    parser.add_argument('--lang', '-l', default='en')
-    parser.add_argument('--penn', '-p', action='store_true')
+    parser.add_argument("--lang", "-l", default="en")
+    parser.add_argument("--penn", "-p", action="store_true")
     args = parser.parse_args()
 
     main(args)
 
@@ -6,12 +6,14 @@
 # LICENSE file in the root directory of this source tree.
 
 import sys
+
 import sacremoses
 
 
 def main(args):
     """Tokenizes, preserving tabs"""
     mt = sacremoses.MosesTokenizer(lang=args.lang)
 
     def tok(s):
         return mt.tokenize(s, return_str=True)
 
@@ -20,12 +22,13 @@ def main(args):
         print(*parts, sep="\t", flush=True)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser()
-    parser.add_argument('--lang', '-l', default='en')
-    parser.add_argument('--penn', '-p', action='store_true')
-    parser.add_argument('--fields', '-f', help="fields to tokenize")
+    parser.add_argument("--lang", "-l", default="en")
+    parser.add_argument("--penn", "-p", action="store_true")
+    parser.add_argument("--fields", "-f", help="fields to tokenize")
     args = parser.parse_args()
 
     main(args)
 
@@ -3,14 +3,15 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import faiss
-import numpy as np
-import glob
 import argparse
+import glob
 from subprocess import check_call
 
-GB = 1024*1024*1024
+import faiss
+import numpy as np
+
+GB = 1024 * 1024 * 1024
 
 
 def call(cmd):
@@ -18,14 +19,14 @@ def call(cmd):
     check_call(cmd, shell=True)
 
 
-def get_batches(directory, lang, prefix='all_avg_pool'):
+def get_batches(directory, lang, prefix="all_avg_pool"):
     print(f"Finding in {directory}/{prefix}.{lang}*")
-    files = glob.glob(f'{directory}/{prefix}.{lang}*')
+    files = glob.glob(f"{directory}/{prefix}.{lang}*")
     emb_files = []
     txt_files = []
     for emb_fi in files:
         emb_files.append(emb_fi)
-        txt_fi = emb_fi.replace(prefix, 'sentences')
+        txt_fi = emb_fi.replace(prefix, "sentences")
         txt_files.append(txt_fi)
     return emb_files, txt_files
 
@@ -38,7 +39,7 @@ def load_batch(emb_file, dim):
     return embeddings
 
 
-def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'):
+def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
     sims = []
     inds = []
     xfrom = 0
@@ -53,7 +54,7 @@ def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'):
             y_batch = load_batch(y_batch_f, dim)
             neighbor_size = min(k, y_batch.shape[0])
             yto = yfrom + y_batch.shape[0]
-            print('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto))
+            print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto))
             idx = faiss.IndexFlatIP(dim)
             idx = faiss.index_cpu_to_all_gpus(idx)
             idx.add(y_batch)
@@ -86,8 +87,10 @@ def score(sim, fwd_mean, bwd_mean, margin):
     return margin(sim, (fwd_mean + bwd_mean) / 2)
 
 
-def score_candidates(sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False):
-    print(' - scoring {:d} candidates'.format(sim_mat.shape[0]))
+def score_candidates(
+    sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
+):
+    print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
     scores = np.zeros(candidate_inds.shape)
     for i in range(scores.shape[0]):
         for j in range(scores.shape[1]):
@@ -106,42 +109,50 @@ def load_text(files):
     return all_sentences
 
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Mine bitext')
-    parser.add_argument('--src-lang', help='Source language')
-    parser.add_argument('--tgt-lang', help='Target language')
-    parser.add_argument('--dict-path', help='Path to dictionary file', default='dict.txt')
-    parser.add_argument('--spm-path', help='Path to SPM model file', default='sentence.bpe.model')
-    parser.add_argument('--dim', type=int, default=1024,
-                        help='Embedding dimension')
-    parser.add_argument('--mem', type=int, default=5,
-                        help='Memory in GB')
-    parser.add_argument('--src-dir', help='Source directory')
-    parser.add_argument('--tgt-dir', help='Target directory')
-    parser.add_argument('--output', help='Output path')
-    parser.add_argument('--neighborhood', type=int, default=4,
-                        help='Embedding dimension')
-    parser.add_argument('--threshold', type=float, default=1.06,
-                        help='Threshold on mined bitext')
-    parser.add_argument('--valid-size', type=int, default=2000,
-                        help='Number of sentences used for validation set')
-    parser.add_argument('--min-count', type=int, default=50000,
-                        help='Min num sentences used for each language')
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Mine bitext")
+    parser.add_argument("--src-lang", help="Source language")
+    parser.add_argument("--tgt-lang", help="Target language")
+    parser.add_argument(
+        "--dict-path", help="Path to dictionary file", default="dict.txt"
+    )
+    parser.add_argument(
+        "--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
+    )
+    parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
+    parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
+    parser.add_argument("--src-dir", help="Source directory")
+    parser.add_argument("--tgt-dir", help="Target directory")
+    parser.add_argument("--output", help="Output path")
+    parser.add_argument(
+        "--neighborhood", type=int, default=4, help="Embedding dimension"
+    )
+    parser.add_argument(
+        "--threshold", type=float, default=1.06, help="Threshold on mined bitext"
+    )
+    parser.add_argument(
+        "--valid-size",
+        type=int,
+        default=2000,
+        help="Number of sentences used for validation set",
+    )
+    parser.add_argument(
+        "--min-count",
+        type=int,
+        default=50000,
+        help="Min num sentences used for each language",
+    )
     args = parser.parse_args()
 
     x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
     y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
     margin = lambda a, b: a / b
     y2x_sim, y2x_ind = knnGPU_sharded(
-        y_batches_f, x_batches_f,
-        args.dim,
-        args.neighborhood,
-        direction='y2x')
+        y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
+    )
     x2y_sim, x2y_ind = knnGPU_sharded(
-        x_batches_f, y_batches_f,
-        args.dim,
-        args.neighborhood,
-        direction='x2y')
+        x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
+    )
 
     x2y_mean = x2y_sim.mean(axis=1)
     y2x_mean = y2x_sim.mean(axis=1)
@@ -149,8 +160,13 @@ if __name__ == '__main__':
     bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
     fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
     bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
-    indices = np.stack((np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
-                        np.concatenate((fwd_best, np.arange(y2x_ind.shape[0])))), axis=1)
+    indices = np.stack(
+        (
+            np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
+            np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
+        ),
+        axis=1,
+    )
     scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
 
     x_sentences = load_text(x_sents_f)
@@ -162,20 +178,20 @@ if __name__ == '__main__':
     directory = args.output
     call(f"mkdir -p {directory}")
     src_out = open(
-        f'{directory}/all.{args.src_lang}',
-        mode='w',
-        encoding='utf-8',
-        errors='surrogateescape')
+        f"{directory}/all.{args.src_lang}",
+        mode="w",
+        encoding="utf-8",
+        errors="surrogateescape",
+    )
     tgt_out = open(
-        f'{directory}/all.{args.tgt_lang}',
-        mode='w',
-        encoding='utf-8',
-        errors='surrogateescape')
+        f"{directory}/all.{args.tgt_lang}",
+        mode="w",
+        encoding="utf-8",
+        errors="surrogateescape",
+    )
     scores_out = open(
-        f'{directory}/all.scores',
-        mode='w',
-        encoding='utf-8',
-        errors='surrogateescape')
+        f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
+    )
     count = 0
     for i in np.argsort(-scores):
         src_ind, trg_ind = indices[i]
@@ -195,20 +211,23 @@ if __name__ == '__main__':
     scores_out.close()
 
     print(f"Found {count} pairs for threshold={threshold}")
-    with open(f'{directory}/all.{args.src_lang}') as all_s, \
-            open(f'{directory}/all.{args.tgt_lang}') as all_t, \
-            open(f'{directory}/valid.{args.src_lang}', 'w') as valid_s, \
-            open(f'{directory}/valid.{args.tgt_lang}', 'w') as valid_t, \
-            open(f'{directory}/train.{args.src_lang}', 'w') as train_s, \
-            open(f'{directory}/train.{args.tgt_lang}', 'w') as train_t:
-        count = 0
-        for s_line, t_line in zip(all_s, all_t):
-            s_line = s_line.split('\t')[1]
-            t_line = t_line.split('\t')[1]
-            if count >= args.valid_size:
-                train_s.write(s_line)
-                train_t.write(t_line)
-            else:
-                valid_s.write(s_line)
-                valid_t.write(t_line)
-            count += 1
+    with open(f"{directory}/all.{args.src_lang}") as all_s, open(
+        f"{directory}/all.{args.tgt_lang}"
+    ) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
+        f"{directory}/valid.{args.tgt_lang}", "w"
+    ) as valid_t, open(
+        f"{directory}/train.{args.src_lang}", "w"
+    ) as train_s, open(
+        f"{directory}/train.{args.tgt_lang}", "w"
+    ) as train_t:
+        count = 0
+        for s_line, t_line in zip(all_s, all_t):
+            s_line = s_line.split("\t")[1]
+            t_line = t_line.split("\t")[1]
+            if count >= args.valid_size:
+                train_s.write(s_line)
+                train_t.write(t_line)
+            else:
+                valid_s.write(s_line)
+                valid_t.write(t_line)
+            count += 1
 
@@ -7,27 +7,29 @@
 Translate pre-processed data with a trained model.
 """
 
+import numpy as np
 import torch
 
 from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
 from fairseq.sequence_generator import EnsembleModel
-import numpy as np
 
 
-def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False):
+def get_avg_pool(
+    models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
+):
     model = EnsembleModel(models)
 
     # model.forward normally channels prev_output_tokens into the decoder
     # separately, but SequenceGenerator directly calls model.encoder
     encoder_input = {
-        k: v for k, v in sample['net_input'].items()
-        if k != 'prev_output_tokens'
+        k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
     }
 
     # compute the encoder output for each beam
     encoder_outs = model.forward_encoder(encoder_input)
     np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
-    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(np.float32)
+    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
+        np.float32
+    )
     encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
     if has_langtok:
         encoder_mask = encoder_mask[1:, :, :]
@@ -38,13 +40,15 @@ def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langto
 
 
 def main(args):
-    assert args.path is not None, '--path required for generation!'
-    assert not args.sampling or args.nbest == args.beam, \
-        '--sampling requires --nbest to be equal to --beam'
-    assert args.replace_unk is None or args.raw_text, \
-        '--replace-unk requires a raw text dataset (--raw-text)'
+    assert args.path is not None, "--path required for generation!"
+    assert (
+        not args.sampling or args.nbest == args.beam
+    ), "--sampling requires --nbest to be equal to --beam"
+    assert (
+        args.replace_unk is None or args.raw_text
+    ), "--replace-unk requires a raw text dataset (--raw-text)"
 
-    args.beam=1
+    args.beam = 1
     utils.import_user_module(args)
 
     if args.max_tokens is None:
@@ -58,15 +62,15 @@ def main(args):
 
     # Set dictionaries
     try:
-        src_dict = getattr(task, 'source_dictionary', None)
+        src_dict = getattr(task, "source_dictionary", None)
     except NotImplementedError:
         src_dict = None
     tgt_dict = task.target_dictionary
 
     # Load ensemble
-    print('| loading model(s) from {}'.format(args.path))
+    print("| loading model(s) from {}".format(args.path))
     models, _model_args = checkpoint_utils.load_model_ensemble(
-        args.path.split(':'),
+        args.path.split(":"),
         arg_overrides=eval(args.model_overrides),
         task=task,
     )
@@ -105,9 +109,9 @@ def main(args):
     shard_id = 0
     all_avg_pool = None
     encoder_has_langtok = (
-        hasattr(task.args, 'encoder_langtok')
+        hasattr(task.args, "encoder_langtok")
         and task.args.encoder_langtok is not None
-        and hasattr(task.args, 'lang_tok_replacing_bos_eos')
+        and hasattr(task.args, "lang_tok_replacing_bos_eos")
         and not task.args.lang_tok_replacing_bos_eos
     )
     with progress_bar.build_progress_bar(args, itr) as t:
@@ -116,34 +120,42 @@ def main(args):
                 print("Skipping None")
                 continue
             sample = utils.move_to_cuda(sample) if use_cuda else sample
-            if 'net_input' not in sample:
+            if "net_input" not in sample:
                 continue
 
             prefix_tokens = None
             if args.prefix_size > 0:
-                prefix_tokens = sample['target'][:, :args.prefix_size]
+                prefix_tokens = sample["target"][:, : args.prefix_size]
 
             with torch.no_grad():
                 avg_pool = get_avg_pool(
-                    models, sample, prefix_tokens, src_dict,
-                    args.remove_bpe,
-                    has_langtok=encoder_has_langtok)
+                    models,
+                    sample,
+                    prefix_tokens,
+                    src_dict,
+                    args.remove_bpe,
+                    has_langtok=encoder_has_langtok,
+                )
                 if all_avg_pool is not None:
                     all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
                 else:
                     all_avg_pool = avg_pool
 
-            if not isinstance(sample['id'], list):
-                sample_ids = sample['id'].tolist()
+            if not isinstance(sample["id"], list):
+                sample_ids = sample["id"].tolist()
             else:
-                sample_ids = sample['id']
+                sample_ids = sample["id"]
             for i, sample_id in enumerate(sample_ids):
                 # Remove padding
-                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
+                src_tokens = utils.strip_pad(
+                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
+                )
 
                 # Either retrieve the original sentences or regenerate them from tokens.
                 if align_dict is not None:
-                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
+                    src_str = task.dataset(args.gen_subset).src.get_original_text(
+                        sample_id
+                    )
                 else:
                     if src_dict is not None:
                         src_str = src_dict.string(src_tokens, args.remove_bpe)
@@ -152,37 +164,50 @@ def main(args):
 
                 if not args.quiet:
                     if src_dict is not None:
-                        print('S-{}\t{}'.format(sample_id, src_str))
+                        print("S-{}\t{}".format(sample_id, src_str))
 
                 source_sentences.append(f"{sample_id}\t{src_str}")
 
-            num_sentences += sample['nsentences']
+            num_sentences += sample["nsentences"]
             if all_avg_pool.shape[0] >= 1000000:
-                with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
-                          'w') as avg_pool_file:
+                with open(
+                    f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
+                    "w",
+                ) as avg_pool_file:
                     all_avg_pool.tofile(avg_pool_file)
-                with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
-                    sentence_file.writelines(f'{line}\n' for line in source_sentences)
+                with open(
+                    f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
+                    "w",
+                ) as sentence_file:
+                    sentence_file.writelines(f"{line}\n" for line in source_sentences)
                 all_avg_pool = None
                 source_sentences = []
                 shard_id += 1
 
     if all_avg_pool is not None:
-        with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
-                  'w') as avg_pool_file:
+        with open(
+            f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
+        ) as avg_pool_file:
             all_avg_pool.tofile(avg_pool_file)
-        with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
-            sentence_file.writelines(f'{line}\n' for line in source_sentences)
+        with open(
+            f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
+        ) as sentence_file:
+            sentence_file.writelines(f"{line}\n" for line in source_sentences)
     return None
 
 
 def cli_main():
     parser = options.get_generation_parser()
-    parser.add_argument('--encoder-save-dir', default='', type=str, metavar='N',
-                        help='directory to save encoder outputs')
+    parser.add_argument(
+        "--encoder-save-dir",
+        default="",
+        type=str,
+        metavar="N",
+        help="directory to save encoder outputs",
+    )
     args = options.parse_args_and_arch(parser)
     main(args)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     cli_main()
 
@@ -3,10 +3,11 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import numpy as np
 import argparse
 import glob
 
+import numpy as np
+
 
 DIM = 1024
 
@@ -14,9 +15,13 @@ DIM = 1024
 def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
     target_ids = [tid for tid in target_embs]
     source_mat = np.stack(source_embs.values(), axis=0)
-    normalized_source_mat = source_mat / np.linalg.norm(source_mat, axis=1, keepdims=True)
+    normalized_source_mat = source_mat / np.linalg.norm(
+        source_mat, axis=1, keepdims=True
+    )
    target_mat = np.stack(target_embs.values(), axis=0)
-    normalized_target_mat = target_mat / np.linalg.norm(target_mat, axis=1, keepdims=True)
+    normalized_target_mat = target_mat / np.linalg.norm(
+        target_mat, axis=1, keepdims=True
+    )
     sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
     if return_sim_mat:
         return sim_mat
@@ -36,14 +41,14 @@ def load_embeddings(directory, LANGS):
         lang_dir = f"{directory}/{lang}"
         embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
         for embed_file in embedding_files:
-            shard_id = embed_file.split('.')[-1]
+            shard_id = embed_file.split(".")[-1]
             embeddings = np.fromfile(embed_file, dtype=np.float32)
             num_rows = embeddings.shape[0] // DIM
             embeddings = embeddings.reshape((num_rows, DIM))
 
-            with open(f'{lang_dir}/sentences.{lang}.{shard_id}') as sentence_file:
+            with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
                 for idx, line in enumerate(sentence_file):
-                    sentence_id, sentence = line.strip().split('\t')
+                    sentence_id, sentence = line.strip().split("\t")
                     sentence_texts[lang][sentence_id] = sentence
                     sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
 
@@ -55,7 +60,7 @@ def compute_accuracy(directory, LANGS):
 
     top_1_accuracy = {}
 
-    top1_str = " ".join(LANGS) + '\n'
+    top1_str = " ".join(LANGS) + "\n"
     for source_lang in LANGS:
         top_1_accuracy[source_lang] = {}
         top1_str += f"{source_lang} "
@@ -63,8 +68,8 @@ def compute_accuracy(directory, LANGS):
             top1 = 0
             top5 = 0
             neighbors_map = compute_dist(
-                sentence_embeddings[source_lang],
-                sentence_embeddings[target_lang])
+                sentence_embeddings[source_lang], sentence_embeddings[target_lang]
+            )
             for sentence_id, neighbors in neighbors_map.items():
                 if sentence_id == neighbors[0]:
                     top1 += 1
@@ -75,17 +80,13 @@ def compute_accuracy(directory, LANGS):
         top1_str += "\n"
 
     print(top1_str)
-    print(top1_str, file=open(f"{directory}/accuracy", 'w'))
+    print(top1_str, file=open(f"{directory}/accuracy", "w"))
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Analyze encoder outputs')
-    parser.add_argument('directory',
-                        help='Source language corpus'
-                        )
-    parser.add_argument('--langs',
-                        help='List of langs'
-                        )
+    parser = argparse.ArgumentParser(description="Analyze encoder outputs")
+    parser.add_argument("directory", help="Source language corpus")
+    parser.add_argument("--langs", help="List of langs")
     args = parser.parse_args()
-    langs = args.langs.split(',')
+    langs = args.langs.split(",")
     compute_accuracy(args.directory, langs)
 
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .models import latent_multilingual_transformer  # noqa
-from .modules import latent_layers  # noqa
-from .loss import latent_depth  # noqa
-from . import multilingual_translation_latent_depth  # noqa
+from . import multilingual_translation_latent_depth  # noqa
+from .loss import latent_depth  # noqa
+from .models import latent_multilingual_transformer  # noqa
+from .modules import latent_layers  # noqa
 
@@ -3,8 +3,9 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torch
 import math
+
+import torch
 from torch.nn.modules.loss import _Loss
 
 
@@ -19,17 +20,16 @@ class LatentLayersKLLoss(_Loss):
         eps = 1e-7
         if prior == "uniform":
             # uniform prior
-            kl_loss = (samples * (
-                torch.log(samples + eps) - math.log(0.5)
-            )).sum(-1)
+            kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
         elif prior == "agged_posterior":
             # aggregated posterior
             y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
             agged_q = torch.sum(y_t, dim=0)
             row_norm = agged_q.sum(-1)
             normed_agg_q = agged_q / row_norm
-            kl_loss = (samples * (
-                torch.log(samples + eps) - torch.log(normed_agg_q + eps))).sum(-1)
+            kl_loss = (
+                samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
+            ).sum(-1)
         else:
             raise NotImplementedError("The specified prior is not implemented.")
 
@@ -37,7 +37,9 @@ class LatentLayersKLLoss(_Loss):
         kl_loss /= layer_samples[0].size()[0]
         kl_weight = min(
             self.args.sparsity_weight,
-            (update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates
+            (update_num - self.args.soft_update)
+            * self.args.sparsity_weight
+            / self.args.anneal_updates,
         )
         kl_loss *= kl_weight * sample_size
         return kl_loss
@@ -58,15 +60,17 @@ class LatentLayersSparsityLoss(_Loss):
         share_loss = 0
         global_sparsity_loss = 0
         layer_samples = torch.stack(layer_samples_list, dim=0)
-        if ((self.args.target_layers > 0 or self.args.share_weight > 0) and
-                update_num > (self.args.soft_update + self.args.anneal_updates)):
+        if (
+            self.args.target_layers > 0 or self.args.share_weight > 0
+        ) and update_num > (self.args.soft_update + self.args.anneal_updates):
             # anneal sparsity weight
             if update_num < (self.args.anneal_updates + self.args.soft_update):
                 weight_anneal = 0
             elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
                 weight_anneal = (
                     (update_num - self.args.soft_update - self.args.anneal_updates)
-                    * self.args.share_weight / self.args.anneal_updates
+                    * self.args.share_weight
+                    / self.args.anneal_updates
                 )
             else:
                 weight_anneal = 1
@@ -75,12 +79,21 @@ class LatentLayersSparsityLoss(_Loss):
             layer_utilization /= layer_samples.size()[0]
             if self.args.share_weight > 0:
                 # encouraging sharing across languages
-                share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0)
-                batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss
+                share_loss = sum(
+                    -1.0 * v * math.log(v) for v in layer_utilization if v > 0
+                )
+                batch_loss += (
+                    weight_anneal * self.args.share_weight * sample_size * share_loss
+                )
             if self.args.target_layers > 0:
                 # computed expected number of layers selected
                 expeted_layers = sum(layer_utilization)
                 # compute l2 loss wrt target number of layers
                 global_sparsity_loss = (expeted_layers - self.args.target_layers) ** 2
-                batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss
+                batch_loss += (
+                    weight_anneal
+                    * self.args.share_weight
+                    * sample_size
+                    * global_sparsity_loss
+                )
         return batch_loss
 
@@ -3,34 +3,31 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from fairseq.models import (
-    register_model,
-    register_model_architecture,
-)
-from fairseq.models.transformer import (
-    base_architecture,
-    TransformerEncoder,
-    TransformerDecoder,
-)
+from fairseq.models import register_model, register_model_architecture
 from fairseq.models.multilingual_transformer import MultilingualTransformerModel
-
-from .latent_transformer import (
-    LatentTransformerEncoder,
-    LatentTransformerDecoder,
+from fairseq.models.transformer import (
+    TransformerDecoder,
+    TransformerEncoder,
+    base_architecture,
 )
 
+from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder
 
-@register_model('latent_multilingual_transformer')
+
+@register_model("latent_multilingual_transformer")
 class LatentMultilingualTransformerModel(MultilingualTransformerModel):
     """A variant of standard multilingual Transformer models which encoder and/or
     decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
     (https://arxiv.org/abs/2009.13102).
     """
 
     @classmethod
     def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
         if is_encoder:
             if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
-                return LatentTransformerEncoder(args, lang_dict, embed_tokens, num_logits=len(langs))
+                return LatentTransformerEncoder(
+                    args, lang_dict, embed_tokens, num_logits=len(langs)
+                )
             else:
                 return TransformerEncoder(args, lang_dict, embed_tokens)
         else:
@@ -42,19 +39,21 @@ class LatentMultilingualTransformerModel(MultilingualTransformerModel):
             return TransformerDecoder(args, lang_dict, embed_tokens)
 
 
-@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer')
+@register_model_architecture(
+    "latent_multilingual_transformer", "latent_multilingual_transformer"
+)
 def latent_multilingual_architecture(args):
-    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
-    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
-    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
-    args.encoder_layers = getattr(args, 'encoder_layers', 12)
-    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
-    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
-    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
-    args.decoder_layers = getattr(args, 'decoder_layers', 24)
-    args.share_encoders = getattr(args, 'share_encoders', True)
-    args.share_decoders = getattr(args, 'share_decoders', True)
-    args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True)
-    args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True)
+    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
+    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
+    args.encoder_layers = getattr(args, "encoder_layers", 12)
+    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
+    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
+    args.decoder_layers = getattr(args, "decoder_layers", 24)
+    args.share_encoders = getattr(args, "share_encoders", True)
+    args.share_decoders = getattr(args, "share_decoders", True)
+    args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
+    args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)
 
     base_architecture(args)
 
@@ -7,26 +7,27 @@ from typing import Any, Dict, Optional
 
 import torch.nn as nn
 from fairseq.models.fairseq_encoder import EncoderOut
-from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
-from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
-from ..modules.latent_layers import LayerSelect
+from fairseq.models.transformer import TransformerDecoder, TransformerEncoder
+from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
 from torch import Tensor
 
+from ..modules.latent_layers import LayerSelect
+
 
 class LatentTransformerEncoder(TransformerEncoder):
     """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
     TransformerEncoder.
     """
 
     def __init__(self, args, dictionary, embed_tokens, num_logits=1):
         self.num_logits = num_logits
         self.num_layers = args.encoder_layers
         super().__init__(args, dictionary, embed_tokens)
         self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
         self.lang_idx = None
-        self.layers = nn.ModuleList([
-            self._build_encoder_layer(args, idx)
-            for idx in range(args.encoder_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)]
+        )
 
     def set_lang_idx(self, lang_idx):
         self.lang_idx = lang_idx
@@ -50,6 +51,7 @@ class LatentTransformerEncoderLayer(TransformerEncoderLayer):
         layer_select (LayerSelect, optional): instance of LayerSelect module with logits
             parameters and sampling method.
     """
+
     def __init__(self, args, idx, layer_select=None):
         super().__init__(args)
         self.idx = idx
@@ -63,7 +65,10 @@ class LatentTransformerDecoder(TransformerDecoder):
     """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
     TransformerDecoder.
     """
-    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1):
+
+    def __init__(
+        self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1
+    ):
         self.num_logits = num_logits
         self.num_layers = args.decoder_layers
         super().__init__(
@@ -71,16 +76,20 @@ class LatentTransformerDecoder(TransformerDecoder):
         )
         self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
         self.lang_idx = None
-        self.layers = nn.ModuleList([
-            self._build_decoder_layer(args, no_encoder_attn, idx)
-            for idx in range(args.decoder_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                self._build_decoder_layer(args, no_encoder_attn, idx)
+                for idx in range(args.decoder_layers)
+            ]
+        )
 
     def set_lang_idx(self, lang_idx):
         self.lang_idx = lang_idx
 
     def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
-        return LatentTransformerDecoderLayer(args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn)
+        return LatentTransformerDecoderLayer(
+            args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
+        )
 
     def forward(
         self,
@@ -119,8 +128,15 @@ class LatentTransformerDecoderLayer(TransformerDecoderLayer):
             (default: False).
 
     """
+
     def __init__(
-        self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
+        self,
+        args,
+        idx,
+        layer_select=None,
+        no_encoder_attn=False,
+        add_bias_kv=False,
+        add_zero_attn=False,
     ):
         super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
         self.idx = idx
 
@@ -12,6 +12,7 @@ class LayerSelect(nn.Module):
    either (soft) weighting or (hard) selection of residual connection.
    https://arxiv.org/abs/2009.13102
    """

    def __init__(self, num_layers, num_logits, args):
        super(LayerSelect, self).__init__()
        self.args = args
@@ -27,14 +28,14 @@ class LayerSelect(nn.Module):
    @staticmethod
    def add_args(parser):
        parser.add_argument(
            '--soft-select',
            action='store_true',
            help='use soft samples in training and inference'
            "--soft-select",
            action="store_true",
            help="use soft samples in training and inference",
        )
        parser.add_argument('--sampling-tau', type=float, help='sampling temperature')
        parser.add_argument("--sampling-tau", type=float, help="sampling temperature")

    def sample(self, logit_idx):
        """ To leverage the efficiency of distributed training, samples for all
        """To leverage the efficiency of distributed training, samples for all
        layers are computed at once for each logit_idx. Logits are parameters
        learnt independently of each other.

@@ -43,7 +44,9 @@ class LayerSelect(nn.Module):
        """
        assert logit_idx is not None
        self.samples = self._gumbel_sigmoid(
            self.layer_logits[logit_idx, :].detach() if self.detach_grad else self.layer_logits[logit_idx, :],
            self.layer_logits[logit_idx, :].detach()
            if self.detach_grad
            else self.layer_logits[logit_idx, :],
            dim=-1,
            tau=self.tau,
            hard=self.hard_select,
@@ -54,10 +57,20 @@ class LayerSelect(nn.Module):
        sample = self.samples[i]
        return sample

    def _gumbel_sigmoid(self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5):
    def _gumbel_sigmoid(
        self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
    ):
        # ~Gumbel(0,1)
        gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels1 = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        gumbels2 = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        # Difference of two gumbels because we apply a sigmoid
        gumbels1 = (logits + gumbels1 - gumbels2) / tau
        y_soft = gumbels1.sigmoid()
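
The _gumbel_sigmoid above draws two Gumbel(0, 1) samples and pushes their scaled difference through a sigmoid, which yields a reparameterized relaxation of a Bernoulli gate. A self-contained sketch of the trick, assuming a straight-through estimator for the hard case (illustrative, not the diff's exact method):

    # Minimal Gumbel-sigmoid sketch; straight-through estimator assumed for hard=True.
    import torch

    def gumbel_sigmoid(logits, tau=1.0, hard=False, threshold=0.5):
        # Two independent Gumbel(0, 1) samples; their difference is Logistic(0, 1).
        g1 = -torch.empty_like(logits).exponential_().log()
        g2 = -torch.empty_like(logits).exponential_().log()
        y_soft = torch.sigmoid((logits + g1 - g2) / tau)
        if hard:
            y_hard = (y_soft > threshold).float()
            # Forward pass uses hard 0/1 values; gradients flow through y_soft.
            return y_hard - y_soft.detach() + y_soft
        return y_soft

Lower tau sharpens the relaxation toward a hard 0/1 gate.
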
@@ -5,10 +5,11 @@

from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss


@register_task('multilingual_translation_latent_depth')
@register_task("multilingual_translation_latent_depth")
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
    """A task for multilingual translation with latent depth.

@@ -39,7 +40,9 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):

    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
        self.src_langs, self.tgt_langs = zip(*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs])
        self.src_langs, self.tgt_langs = zip(
            *[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]
        )
        if self.training and self.encoder_latent_layer:
            assert self.args.share_encoders
        if self.training and self.decoder_latent_layer:
@@ -47,46 +50,56 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
        if training or self.encoder_latent_layer or self.decoder_latent_layer:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        self.eval_lang_pairs = self.lang_pairs
        self.model_lang_pairs = self.lang_pairs
        if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
            self.kl_loss = LatentLayersKLLoss(self.args)
            self.sparsity_loss = LatentLayersSparsityLoss(self.args)

    def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad):
    def _per_lang_pair_train_loss(
        self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
    ):
        src, tgt = lang_pair.split("-")
        if self.encoder_latent_layer:
            src_lang_idx = self.src_lang_idx_dict[src]
            model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
            model.models[lang_pair].encoder.layer_select.hard_select = update_num > self.args.soft_update
            model.models[lang_pair].encoder.layer_select.hard_select = (
                update_num > self.args.soft_update
            )
        if self.decoder_latent_layer:
            tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
            model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
            model.models[lang_pair].decoder.layer_select.hard_select = update_num > self.args.soft_update
            model.models[lang_pair].decoder.layer_select.hard_select = (
                update_num > self.args.soft_update
            )

        loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
        loss, sample_size, logging_output = criterion(
            model.models[lang_pair], sample[lang_pair]
        )
        if self.encoder_latent_layer:
            none_samples = sum(
                1 if x is None else 0 for x in model.models[lang_pair].encoder.layer_select.layer_samples
                1 if x is None else 0
                for x in model.models[lang_pair].encoder.layer_select.layer_samples
            )
            if none_samples == 0 or self.args.prior != "agged_posterior":
                loss += self.kl_loss(
                    model.models[lang_pair].encoder.layer_select.layer_samples,
                    src_lang_idx,
                    update_num,
                    sample_size
                    sample_size,
                )
        if self.decoder_latent_layer:
            none_samples = sum(
                1 if x is None else 0 for x in model.models[lang_pair].decoder.layer_select.layer_samples
                1 if x is None else 0
                for x in model.models[lang_pair].decoder.layer_select.layer_samples
            )
            if none_samples == 0 or self.args.prior != "agged_posterior":
                loss += self.kl_loss(
                    model.models[lang_pair].decoder.layer_select.layer_samples,
                    tgt_lang_idx,
                    update_num,
                    sample_size
                    sample_size,
                )
        if ignore_grad:
            loss *= 0
@@ -99,18 +112,31 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):

        return loss, sample_size, logging_output

    def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        agg_loss, agg_sample_size, agg_logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad)
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
        # compute auxiliary loss from layer sparsity, based on all samples from all languages
        if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
            sparsity_loss = 0
            if self.encoder_latent_layer:
                sparsity_loss += self.sparsity_loss(
                    next(iter(model.models.values())).encoder.layer_select.layer_samples, update_num, agg_sample_size)
                    next(
                        iter(model.models.values())
                    ).encoder.layer_select.layer_samples,
                    update_num,
                    agg_sample_size,
                )
            if self.decoder_latent_layer:
                sparsity_loss += self.sparsity_loss(
                    next(iter(model.models.values())).decoder.layer_select.layer_samples, update_num, agg_sample_size)
                    next(
                        iter(model.models.values())
                    ).decoder.layer_select.layer_samples,
                    update_num,
                    agg_sample_size,
                )
            if sparsity_loss > 0:
                optimizer.backward(sparsity_loss)
        return agg_loss, agg_sample_size, agg_logging_output
@@ -123,10 +149,14 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
        if self.decoder_latent_layer:
            tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
            model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
        loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
        loss, sample_size, logging_output = criterion(
            model.models[lang_pair], sample[lang_pair]
        )
        return loss, sample_size, logging_output

    def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        if self.encoder_latent_layer or self.decoder_latent_layer:
            for model in models:
                if self.encoder_latent_layer:
@@ -137,15 +167,23 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
                    assert model.decoder.layer_select is not None
                    tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
                    model.decoder.set_lang_idx(tgt_lang_idx)
        return super().inference_step(generator, models, sample, prefix_tokens, constraints)
        return super().inference_step(
            generator, models, sample, prefix_tokens, constraints
        )

    @property
    def encoder_latent_layer(self):
        return hasattr(self.args, "encoder_latent_layer") and self.args.encoder_latent_layer
        return (
            hasattr(self.args, "encoder_latent_layer")
            and self.args.encoder_latent_layer
        )

    @property
    def decoder_latent_layer(self):
        return hasattr(self.args, "decoder_latent_layer") and self.args.decoder_latent_layer
        return (
            hasattr(self.args, "decoder_latent_layer")
            and self.args.decoder_latent_layer
        )

    @property
    def src_lang_idx_dict(self):
@@ -8,37 +8,40 @@ Linformer: Self-Attention with Linear Complexity

import logging

from fairseq.models import (
    register_model,
    register_model_architecture,
)
from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import RobertaEncoder, RobertaModel

from fairseq.models.roberta import (
    RobertaModel,
    RobertaEncoder,
)
from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder


logger = logging.getLogger(__name__)


@register_model('linformer_roberta')
@register_model("linformer_roberta")
class LinformerModel(RobertaModel):

    @staticmethod
    def add_args(parser):
        RobertaModel.add_args(parser)

        # add args for Linformer
        parser.add_argument('--compressed', type=int,
                            help='compressed ratio of sequence length')
        parser.add_argument('--shared-kv-compressed', type=int,
                            help='share compressed matrix between k and v, in each layer')
        parser.add_argument('--shared-layer-kv-compressed', type=int,
                            help='share compressed matrix between k and v and across all layers')
        parser.add_argument('--freeze-compress', type=int,
                            help='freeze the parameters in compressed layer')
        parser.add_argument(
            "--compressed", type=int, help="compressed ratio of sequence length"
        )
        parser.add_argument(
            "--shared-kv-compressed",
            type=int,
            help="share compressed matrix between k and v, in each layer",
        )
        parser.add_argument(
            "--shared-layer-kv-compressed",
            type=int,
            help="share compressed matrix between k and v and across all layers",
        )
        parser.add_argument(
            "--freeze-compress",
            type=int,
            help="freeze the parameters in compressed layer",
        )

    @classmethod
    def build_model(cls, args, task):
@@ -47,7 +50,7 @@ class LinformerModel(RobertaModel):
        # make sure all arguments are present
        base_architecture(args)

        if not hasattr(args, 'max_positions'):
        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        encoder = LinformerEncoder(args, task.source_dictionary)
@@ -85,47 +88,47 @@ class LinformerEncoder(RobertaEncoder):
        )


@register_model_architecture('linformer_roberta', 'linformer_roberta')
@register_model_architecture("linformer_roberta", "linformer_roberta")
def base_architecture(args):
    args.encoder_layers = getattr(args, 'encoder_layers', 12)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)

    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
    args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")

    args.dropout = getattr(args, 'dropout', 0.1)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
    args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
    args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
    args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None)
    args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0)
    args.compressed = getattr(args, 'compressed', 4)
    args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0)
    args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0)
    args.freeze_compress = getattr(args, 'freeze_compress', 0)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
    args.compressed = getattr(args, "compressed", 4)
    args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
    args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
    args.freeze_compress = getattr(args, "freeze_compress", 0)


@register_model_architecture('linformer_roberta', 'linformer_roberta_base')
@register_model_architecture("linformer_roberta", "linformer_roberta_base")
def linformer_roberta_base_architecture(args):
    base_architecture(args)


@register_model_architecture('linformer_roberta', 'linformer_roberta_large')
@register_model_architecture("linformer_roberta", "linformer_roberta_large")
def linformer_roberta_large_architecture(args):
    args.encoder_layers = getattr(args, 'encoder_layers', 24)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
    args.encoder_layers = getattr(args, "encoder_layers", 24)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)

    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
    args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")

    args.dropout = getattr(args, 'dropout', 0.1)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
    args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
    args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
    args.compressed = getattr(args, 'compressed', 4)
    args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0)
    args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
    args.compressed = getattr(args, "compressed", 4)
    args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
    args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
@@ -6,8 +6,8 @@
import math

import torch.nn as nn

from fairseq.modules import TransformerSentenceEncoder

from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer


@@ -117,7 +117,9 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
        qn_block_size,
    ):
        if self.shared_layer_kv_compressed == 1:
            compress_layer = nn.Linear(self.max_seq_len, self.max_seq_len // self.compressed)
            compress_layer = nn.Linear(
                self.max_seq_len, self.max_seq_len // self.compressed
            )
            # initialize parameters for compressed layer
            nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2))
            if self.freeze_compress == 1:
@@ -139,8 +141,7 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
            max_seq_len=self.max_seq_len,
            shared_kv_compressed=self.shared_kv_compressed,
            shared_compress_layer=(
                None if self.shared_layer_kv_compressed == 0
                else self.compress_layer
                None if self.shared_layer_kv_compressed == 0 else self.compress_layer
            ),
            freeze_compress=self.freeze_compress,
        )
@@ -156,7 +157,8 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
            if self.shared_layer_kv_compressed:
                for layer_idx in range(len(self.layers)):
                    new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format(
                        layer_idx, k[len(prefix + 'compress_layer.'):],
                        layer_idx,
                        k[len(prefix + "compress_layer.") :],
                    )
                    items_to_add[new_k] = state_dict[k]

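
The shared compress layer registered above is a learned linear map over the sequence-length dimension; projecting keys and values to a fixed shorter length is what gives Linformer its linear attention cost. A hedged sketch with illustrative shapes (T x B x C input, as in fairseq encoders):

    # Sketch of Linformer-style length compression; shapes are illustrative.
    import torch
    import torch.nn as nn

    T, B, C, ratio = 128, 2, 64, 4
    x = torch.randn(T, B, C)                         # time x batch x channels
    compress = nn.Linear(T, T // ratio, bias=False)  # learned map over length

    x_bct = x.permute(1, 2, 0)          # B x C x T
    x_short = compress(x_bct)           # B x C x (T // ratio)
    k_input = x_short.permute(2, 0, 1)  # (T // ratio) x B x C
    assert k_input.shape == (T // ratio, B, C)

Attention against the compressed keys/values then costs O(T * T/ratio) per layer instead of O(T^2).
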
@@ -6,6 +6,7 @@
from typing import Callable

from fairseq.modules import TransformerSentenceEncoderLayer

from .multihead_linear_attention import MultiheadLinearAttention


@@ -23,7 +24,7 @@ class LinformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = 'relu',
        activation_fn: str = "relu",
        export: bool = False,
        q_noise: float = 0.0,
        qn_block_size: int = 8,
@@ -9,10 +9,10 @@ from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor, nn
from torch.nn import Parameter
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter


@with_incremental_state
@@ -65,16 +65,24 @@ class MultiheadLinearAttention(nn.Module):
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        # used to compress the sequence to a shorter subsequence
        if shared_compress_layer is None:
            self.compress_seq_len = max_seq_len // compressed
            self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
            if shared_kv_compressed == 0:
                self.compress_v = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
                self.compress_v = nn.Linear(
                    max_seq_len, self.compress_seq_len, bias=False
                )
            self.layerwise_sharing = False
        else:
            self.compress_k = shared_compress_layer
@@ -83,7 +91,9 @@ class MultiheadLinearAttention(nn.Module):
            self.layerwise_sharing = True
        self.shared_kv_compressed = shared_kv_compressed

        self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
@@ -116,22 +126,28 @@ class MultiheadLinearAttention(nn.Module):
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
            if not self.layerwise_sharing: # otherwise, we already initialize the parameters
                nn.init.xavier_uniform_(self.compress_k.weight, gain=1/math.sqrt(2))
            if (
                not self.layerwise_sharing
            ):  # otherwise, we already initialize the parameters
                nn.init.xavier_uniform_(self.compress_k.weight, gain=1 / math.sqrt(2))
                if self.shared_kv_compressed == 0:
                    nn.init.xavier_uniform_(self.compress_v.weight, gain=1/math.sqrt(2))
                    nn.init.xavier_uniform_(
                        self.compress_v.weight, gain=1 / math.sqrt(2)
                    )
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)
            if not self.layerwise_sharing: # otherwise, we already initialize the parameters
            if (
                not self.layerwise_sharing
            ):  # otherwise, we already initialize the parameters
                nn.init.xavier_uniform_(self.compress_k.weight)
                if self.shared_kv_compressed == 0:
                    nn.init.xavier_uniform_(self.compress_v.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
@@ -189,14 +205,26 @@ class MultiheadLinearAttention(nn.Module):
            q = self.q_proj(query)

            k_input = query.permute(1, 2, 0).contiguous()  # B * C * T
            k_input = F.linear(k_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
            k_input = (
                F.linear(k_input, self.compress_k.weight[:, 0:tgt_len])
                .permute(2, 0, 1)
                .contiguous()
            )
            k = self.k_proj(k_input)

            v_input = query.permute(1, 2, 0).contiguous()  # B * C * T
            if self.shared_kv_compressed == 0:
                v_input = F.linear(v_input, self.compress_v.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
                v_input = (
                    F.linear(v_input, self.compress_v.weight[:, 0:tgt_len])
                    .permute(2, 0, 1)
                    .contiguous()
                )
            if self.shared_kv_compressed == 1:  # use shared kv compressed linear layer
                v_input = F.linear(v_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
                v_input = (
                    F.linear(v_input, self.compress_k.weight[:, 0:tgt_len])
                    .permute(2, 0, 1)
                    .contiguous()
                )
            v = self.v_proj(v_input)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
@@ -302,7 +330,9 @@ class MultiheadLinearAttention(nn.Module):
            )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = MultiheadLinearAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
        attn_weights = MultiheadLinearAttention.apply_sparse_mask(
            attn_weights, tgt_len, src_len, bsz
        )

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

||||
@ -385,7 +415,9 @@ class MultiheadLinearAttention(nn.Module):
|
||||
|
||||
@torch.jit.export
|
||||
def reorder_incremental_state(
|
||||
self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
|
||||
self,
|
||||
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
||||
new_order: Tensor,
|
||||
):
|
||||
"""Reorder buffered internal state (for incremental generation)."""
|
||||
input_buffer = self._get_input_buffer(incremental_state)
|
||||
@ -393,7 +425,9 @@ class MultiheadLinearAttention(nn.Module):
|
||||
for k in input_buffer.keys():
|
||||
input_buffer_k = input_buffer[k]
|
||||
if input_buffer_k is not None:
|
||||
if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(0):
|
||||
if self.encoder_decoder_attention and input_buffer_k.size(
|
||||
0
|
||||
) == new_order.size(0):
|
||||
break
|
||||
input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
||||
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
|
||||
@ -428,8 +462,8 @@ class MultiheadLinearAttention(nn.Module):
|
||||
# in_proj_weight used to be q + k + v with same dimensions
|
||||
dim = int(state_dict[k].shape[0] / 3)
|
||||
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
||||
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim:2 * dim]
|
||||
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim:]
|
||||
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
|
||||
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
|
||||
|
||||
keys_to_remove.append(k)
|
||||
|
||||
@ -438,9 +472,9 @@ class MultiheadLinearAttention(nn.Module):
|
||||
dim = int(state_dict[k].shape[0] / 3)
|
||||
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
||||
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
|
||||
dim:2 * dim
|
||||
dim : 2 * dim
|
||||
]
|
||||
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim:]
|
||||
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
|
||||
|
||||
keys_to_remove.append(prefix + "in_proj_bias")
|
||||
|
||||
|
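
The upgrade path above splits a fused in_proj parameter into separate q/k/v projections when loading old checkpoints. A minimal sketch of that slicing, assuming the three projections are equally sized:

    # Sketch: splitting a fused (3*dim x dim) in_proj weight into q/k/v.
    import torch

    dim = 4
    in_proj_weight = torch.randn(3 * dim, dim)
    q_w = in_proj_weight[:dim]           # rows 0 .. dim-1
    k_w = in_proj_weight[dim : 2 * dim]  # rows dim .. 2*dim-1
    v_w = in_proj_weight[2 * dim :]      # rows 2*dim .. 3*dim-1
    assert q_w.shape == k_w.shape == v_w.shape == (dim, dim)
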
@@ -8,14 +8,16 @@

import sys

from indicnlp.tokenize.indic_tokenize import trivial_tokenize
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
from indicnlp.tokenize.indic_tokenize import trivial_tokenize

factory=IndicNormalizerFactory()
normalizer=factory.get_normalizer(sys.argv[1],remove_nuktas=False,nasals_mode='do_nothing')

factory = IndicNormalizerFactory()
normalizer = factory.get_normalizer(
    sys.argv[1], remove_nuktas=False, nasals_mode="do_nothing"
)

for line in sys.stdin:
    normalized_line=normalizer.normalize(line.strip())
    tokenized_line=' '.join(trivial_tokenize(normalized_line, sys.argv[1]))
    normalized_line = normalizer.normalize(line.strip())
    tokenized_line = " ".join(trivial_tokenize(normalized_line, sys.argv[1]))
    print(tokenized_line)

@@ -8,5 +8,6 @@ import sys

from pythainlp import word_tokenize


for line in sys.stdin:
    print(" ".join(word_tokenize(line.strip())))
@@ -6,7 +6,9 @@


import fileinput

import sacrebleu


for line in fileinput.input():
    print(sacrebleu.tokenize_zh(line))
@@ -6,19 +6,27 @@

import argparse
import fileinput

import sacremoses


def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('files', nargs='*', help='input files')
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    detok = sacremoses.MosesDetokenizer()

    for line in fileinput.input(args.files, openhook=fileinput.hook_compressed):
        print(detok.detokenize(line.strip().split(' ')).replace(' @', '').replace('@ ', '').replace(' =', '=').replace('= ', '=').replace(' – ', '–'))
        print(
            detok.detokenize(line.strip().split(" "))
            .replace(" @", "")
            .replace("@ ", "")
            .replace(" =", "=")
            .replace("= ", "=")
            .replace(" – ", "–")
        )


if __name__ == '__main__':
if __name__ == "__main__":
    main()
@@ -7,21 +7,22 @@ import math
from multiprocessing import Pool

import numpy as np

from fairseq import options
from fairseq.data import dictionary
from fairseq.scoring import bleu

from . import (
    rerank_generate,
    rerank_options,
    rerank_score_bw,
    rerank_score_lm,
    rerank_options,
    rerank_utils,
)


def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize):
def score_target_hypo(
    args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
):

    print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
    gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
@@ -61,11 +62,21 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
            bitext2_score = None
            bitext2_backwards = None

            score = rerank_utils.get_score(a, b, c, target_len,
                                           bitext1.rescore_score[i], bitext2_score, lm_score=lm_score,
                                           lenpen=lenpen, src_len=bitext1.source_lengths[i],
                                           tgt_len=bitext1.target_lengths[i], bitext1_backwards=bitext1.backwards,
                                           bitext2_backwards=bitext2_backwards, normalize=normalize)
            score = rerank_utils.get_score(
                a,
                b,
                c,
                target_len,
                bitext1.rescore_score[i],
                bitext2_score,
                lm_score=lm_score,
                lenpen=lenpen,
                src_len=bitext1.source_lengths[i],
                tgt_len=bitext1.target_lengths[i],
                bitext1_backwards=bitext1.backwards,
                bitext2_backwards=bitext2_backwards,
                normalize=normalize,
            )

            if score > best_score:
                best_score = score
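
rerank_utils.get_score combines the forward model, the backward (channel) model, and the language model under a length penalty; the exact formula lives in rerank_utils. A hedged sketch of the usual noisy-channel combination (function and argument names here are illustrative):

    # Illustrative noisy-channel rerank score; not rerank_utils.get_score itself.
    def combined_score(fw_score, bw_score, lm_score, a, b, c, tgt_len, lenpen):
        # fw_score: log p(T|S), bw_score: log p(S|T), lm_score: log p(T)
        return (a * fw_score + b * bw_score + c * lm_score) / (tgt_len ** lenpen)

The grid over weight1/weight2/weight3 and lenpen in match_target_hypo below simply searches for the combination that maximizes corpus BLEU.
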
@@ -88,8 +99,11 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
    for key in range(len(gen_keys)):
        if args.prefix_len is None:
            assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
                "pred and rescore hypo mismatch: i: " + str(key) + ", "
                + str(hypo_lst[key]) + str(gen_keys[key])
                "pred and rescore hypo mismatch: i: "
                + str(key)
                + ", "
                + str(hypo_lst[key])
                + str(gen_keys[key])
                + str(gen_output.no_bpe_hypo[key])
            )
            sys_tok = dict.encode_line(hypo_lst[key])
@@ -97,7 +111,9 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
            scorer.add(ref_tok, sys_tok)

        else:
            full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
            full_hypo = rerank_utils.get_full_from_prefix(
                hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
            )
            sys_tok = dict.encode_line(full_hypo)
            ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
            scorer.add(ref_tok, sys_tok)
@@ -107,20 +123,31 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
    # recover the original ids from n best list generation
    for key in range(len(gen_output.no_bpe_target)):
        if args.prefix_len is None:
            assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \
                "pred and rescore hypo mismatch:"+"i:"+str(key)+str(hypo_lst[key]) + str(gen_output.no_bpe_hypo[key])
            assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
                "pred and rescore hypo mismatch:"
                + "i:"
                + str(key)
                + str(hypo_lst[key])
                + str(gen_output.no_bpe_hypo[key])
            )
            ordered_hypos[gen_keys[key]] = hypo_lst[key]
            ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
            ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
                gen_keys[key]
            ]

        else:
            full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
            full_hypo = rerank_utils.get_full_from_prefix(
                hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
            )
            ordered_hypos[gen_keys[key]] = full_hypo
            ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
            ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
                gen_keys[key]
            ]

    # write the hypos in the original order from nbest list generation
    if args.num_shards == (len(bitext1_lst)):
        with open(target_outfile, 'w') as t:
            with open(hypo_outfile, 'w') as h:
        with open(target_outfile, "w") as t:
            with open(hypo_outfile, "w") as h:
                for key in range(len(ordered_hypos)):
                    t.write(ordered_targets[key])
                    h.write(ordered_hypos[key])
@@ -135,17 +162,38 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
def match_target_hypo(args, target_outfile, hypo_outfile):
    """combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
    if len(args.weight1) == 1:
        res = score_target_hypo(args, args.weight1[0], args.weight2[0],
                                args.weight3[0], args.lenpen[0], target_outfile,
                                hypo_outfile, True, args.normalize)
        res = score_target_hypo(
            args,
            args.weight1[0],
            args.weight2[0],
            args.weight3[0],
            args.lenpen[0],
            target_outfile,
            hypo_outfile,
            True,
            args.normalize,
        )
        rerank_scores = [res]
    else:
        print("launching pool")
        with Pool(32) as p:
            rerank_scores = p.starmap(score_target_hypo,
                                      [(args, args.weight1[i], args.weight2[i], args.weight3[i],
                                        args.lenpen[i], target_outfile, hypo_outfile,
                                        False, args.normalize) for i in range(len(args.weight1))])
            rerank_scores = p.starmap(
                score_target_hypo,
                [
                    (
                        args,
                        args.weight1[i],
                        args.weight2[i],
                        args.weight3[i],
                        args.lenpen[i],
                        target_outfile,
                        hypo_outfile,
                        False,
                        args.normalize,
                    )
                    for i in range(len(args.weight1))
                ],
            )

    if len(rerank_scores) > 1:
        best_index = np.argmax(rerank_scores)
@@ -155,11 +203,22 @@ def match_target_hypo(args, target_outfile, hypo_outfile):
        print("best weight1", args.weight1[best_index])
        print("best weight2", args.weight2[best_index])
        print("best weight3", args.weight3[best_index])
        return args.lenpen[best_index], args.weight1[best_index], \
            args.weight2[best_index], args.weight3[best_index], best_score
        return (
            args.lenpen[best_index],
            args.weight1[best_index],
            args.weight2[best_index],
            args.weight3[best_index],
            best_score,
        )

    else:
        return args.lenpen[0], args.weight1[0], args.weight2[0], args.weight3[0], rerank_scores[0]
        return (
            args.lenpen[0],
            args.weight1[0],
            args.weight2[0],
            args.weight3[0],
            rerank_scores[0],
        )


def load_score_files(args):
@ -175,55 +234,100 @@ def load_score_files(args):
|
||||
|
||||
for shard_id in shard_ids:
|
||||
using_nbest = args.nbest_list is not None
|
||||
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
|
||||
backwards_preprocessed_dir, lm_preprocessed_dir = \
|
||||
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
|
||||
args.gen_model_name, shard_id, args.num_shards, args.sampling,
|
||||
args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
|
||||
(
|
||||
pre_gen,
|
||||
left_to_right_preprocessed_dir,
|
||||
right_to_left_preprocessed_dir,
|
||||
backwards_preprocessed_dir,
|
||||
lm_preprocessed_dir,
|
||||
) = rerank_utils.get_directories(
|
||||
args.data_dir_name,
|
||||
args.num_rescore,
|
||||
args.gen_subset,
|
||||
args.gen_model_name,
|
||||
shard_id,
|
||||
args.num_shards,
|
||||
args.sampling,
|
||||
args.prefix_len,
|
||||
args.target_prefix_frac,
|
||||
args.source_prefix_frac,
|
||||
)
|
||||
|
||||
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
||||
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
||||
rerank1_is_gen = (
|
||||
args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
||||
)
|
||||
rerank2_is_gen = (
|
||||
args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
||||
)
|
||||
|
||||
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
|
||||
target_prefix_frac=args.target_prefix_frac,
|
||||
source_prefix_frac=args.source_prefix_frac,
|
||||
backwards=args.backwards1)
|
||||
score1_file = rerank_utils.rescore_file_name(
|
||||
pre_gen,
|
||||
args.prefix_len,
|
||||
args.model1_name,
|
||||
target_prefix_frac=args.target_prefix_frac,
|
||||
source_prefix_frac=args.source_prefix_frac,
|
||||
backwards=args.backwards1,
|
||||
)
|
||||
if args.score_model2 is not None:
|
||||
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
|
||||
target_prefix_frac=args.target_prefix_frac,
|
||||
source_prefix_frac=args.source_prefix_frac,
|
||||
backwards=args.backwards2)
|
||||
score2_file = rerank_utils.rescore_file_name(
|
||||
pre_gen,
|
||||
args.prefix_len,
|
||||
args.model2_name,
|
||||
target_prefix_frac=args.target_prefix_frac,
|
||||
source_prefix_frac=args.source_prefix_frac,
|
||||
backwards=args.backwards2,
|
||||
)
|
||||
if args.language_model is not None:
|
||||
lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
|
||||
lm_score_file = rerank_utils.rescore_file_name(
|
||||
pre_gen, args.prefix_len, args.lm_name, lm_file=True
|
||||
)
|
||||
|
||||
# get gen output
|
||||
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
|
||||
predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
|
||||
if using_nbest:
|
||||
print("Using predefined n-best list from interactive.py")
|
||||
predictions_bpe_file = args.nbest_list
|
||||
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
|
||||
nbest=using_nbest, prefix_len=args.prefix_len,
|
||||
target_prefix_frac=args.target_prefix_frac)
|
||||
gen_output = rerank_utils.BitextOutputFromGen(
|
||||
predictions_bpe_file,
|
||||
bpe_symbol=args.remove_bpe,
|
||||
nbest=using_nbest,
|
||||
prefix_len=args.prefix_len,
|
||||
target_prefix_frac=args.target_prefix_frac,
|
||||
)
|
||||
|
||||
if rerank1_is_gen:
|
||||
bitext1 = gen_output
|
||||
else:
|
||||
bitext1 = rerank_utils.BitextOutput(score1_file, args.backwards1, args.right_to_left1,
|
||||
args.remove_bpe, args.prefix_len, args.target_prefix_frac,
|
||||
args.source_prefix_frac)
|
||||
bitext1 = rerank_utils.BitextOutput(
|
||||
score1_file,
|
||||
args.backwards1,
|
||||
args.right_to_left1,
|
||||
args.remove_bpe,
|
||||
args.prefix_len,
|
||||
args.target_prefix_frac,
|
||||
args.source_prefix_frac,
|
||||
)
|
||||
|
||||
if args.score_model2 is not None or args.nbest_list is not None:
|
||||
if rerank2_is_gen:
|
||||
bitext2 = gen_output
|
||||
else:
|
||||
bitext2 = rerank_utils.BitextOutput(score2_file, args.backwards2, args.right_to_left2,
|
||||
args.remove_bpe, args.prefix_len, args.target_prefix_frac,
|
||||
args.source_prefix_frac)
|
||||
bitext2 = rerank_utils.BitextOutput(
|
||||
score2_file,
|
||||
args.backwards2,
|
||||
args.right_to_left2,
|
||||
args.remove_bpe,
|
||||
args.prefix_len,
|
||||
args.target_prefix_frac,
|
||||
args.source_prefix_frac,
|
||||
)
|
||||
|
||||
assert bitext2.source_lengths == bitext1.source_lengths, \
|
||||
"source lengths for rescoring models do not match"
|
||||
assert bitext2.target_lengths == bitext1.target_lengths, \
|
||||
"target lengths for rescoring models do not match"
|
||||
assert (
|
||||
bitext2.source_lengths == bitext1.source_lengths
|
||||
), "source lengths for rescoring models do not match"
|
||||
assert (
|
||||
bitext2.target_lengths == bitext1.target_lengths
|
||||
), "target lengths for rescoring models do not match"
|
||||
else:
|
||||
if args.diff_bpe:
|
||||
assert args.score_model2 is None
|
||||
@ -232,8 +336,13 @@ def load_score_files(args):
|
||||
bitext2 = None
|
||||
|
||||
if args.language_model is not None:
|
||||
lm_res1 = rerank_utils.LMOutput(lm_score_file, args.lm_dict, args.prefix_len,
|
||||
args.remove_bpe, args.target_prefix_frac)
|
||||
lm_res1 = rerank_utils.LMOutput(
|
||||
lm_score_file,
|
||||
args.lm_dict,
|
||||
args.prefix_len,
|
||||
args.remove_bpe,
|
||||
args.target_prefix_frac,
|
||||
)
|
||||
else:
|
||||
lm_res1 = None
|
||||
|
||||
@ -259,28 +368,46 @@ def rerank(args):
|
||||
shard_ids = [args.shard_id]
|
||||
|
||||
for shard_id in shard_ids:
|
||||
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
|
||||
backwards_preprocessed_dir, lm_preprocessed_dir = \
|
||||
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
|
||||
args.gen_model_name, shard_id, args.num_shards, args.sampling,
|
||||
args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
|
||||
(
|
||||
pre_gen,
|
||||
left_to_right_preprocessed_dir,
|
||||
right_to_left_preprocessed_dir,
|
||||
backwards_preprocessed_dir,
|
||||
lm_preprocessed_dir,
|
||||
) = rerank_utils.get_directories(
|
||||
args.data_dir_name,
|
||||
args.num_rescore,
|
||||
args.gen_subset,
|
||||
args.gen_model_name,
|
||||
shard_id,
|
||||
args.num_shards,
|
||||
args.sampling,
|
||||
args.prefix_len,
|
||||
args.target_prefix_frac,
|
||||
args.source_prefix_frac,
|
||||
)
|
||||
rerank_generate.gen_and_reprocess_nbest(args)
|
||||
rerank_score_bw.score_bw(args)
|
||||
rerank_score_lm.score_lm(args)
|
||||
|
||||
if args.write_hypos is None:
|
||||
write_targets = pre_gen+"/matched_targets"
|
||||
write_hypos = pre_gen+"/matched_hypos"
|
||||
write_targets = pre_gen + "/matched_targets"
|
||||
write_hypos = pre_gen + "/matched_hypos"
|
||||
else:
|
||||
write_targets = args.write_hypos+"_targets" + args.gen_subset
|
||||
write_hypos = args.write_hypos+"_hypos" + args.gen_subset
|
||||
write_targets = args.write_hypos + "_targets" + args.gen_subset
|
||||
write_hypos = args.write_hypos + "_hypos" + args.gen_subset
|
||||
|
||||
if args.all_shards:
|
||||
write_targets += "_all_shards"
|
||||
write_hypos += "_all_shards"
|
||||
|
||||
best_lenpen, best_weight1, best_weight2, best_weight3, best_score = \
|
||||
match_target_hypo(args, write_targets, write_hypos)
|
||||
(
|
||||
best_lenpen,
|
||||
best_weight1,
|
||||
best_weight2,
|
||||
best_weight3,
|
||||
best_score,
|
||||
) = match_target_hypo(args, write_targets, write_hypos)
|
||||
|
||||
return best_lenpen, best_weight1, best_weight2, best_weight3, best_score
|
||||
|
||||
@@ -291,5 +418,5 @@ def cli_main():
    rerank(args)


if __name__ == '__main__':
if __name__ == "__main__":
    cli_main()
@@ -8,9 +8,9 @@
Generate n-best translations using a trained model.
"""

from contextlib import redirect_stdout
import os
import subprocess
from contextlib import redirect_stdout

from fairseq import options
from fairseq_cli import generate, preprocess
@@ -22,8 +22,12 @@ def gen_and_reprocess_nbest(args):
    if args.score_dict_dir is None:
        args.score_dict_dir = args.data
    if args.prefix_len is not None:
        assert args.right_to_left1 is False, "prefix length not compatible with right to left models"
        assert args.right_to_left2 is False, "prefix length not compatible with right to left models"
        assert (
            args.right_to_left1 is False
        ), "prefix length not compatible with right to left models"
        assert (
            args.right_to_left2 is False
        ), "prefix length not compatible with right to left models"

    if args.nbest_list is not None:
        assert args.score_model2 is None
@ -35,27 +39,50 @@ def gen_and_reprocess_nbest(args):
|
||||
scorer1_src = args.source_lang
|
||||
scorer1_tgt = args.target_lang
|
||||
|
||||
store_data = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+args.data_dir_name
|
||||
store_data = (
|
||||
os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name
|
||||
)
|
||||
if not os.path.exists(store_data):
|
||||
os.makedirs(store_data)
|
||||
|
||||
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
|
||||
backwards_preprocessed_dir, lm_preprocessed_dir = \
|
||||
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
|
||||
args.gen_model_name, args.shard_id, args.num_shards,
|
||||
args.sampling, args.prefix_len, args.target_prefix_frac,
|
||||
args.source_prefix_frac)
|
||||
assert not (args.right_to_left1 and args.backwards1), "backwards right to left not supported"
|
||||
assert not (args.right_to_left2 and args.backwards2), "backwards right to left not supported"
|
||||
assert not (args.prefix_len is not None and args.target_prefix_frac is not None), \
|
||||
"target prefix frac and target prefix len incompatible"
|
||||
(
|
||||
pre_gen,
|
||||
left_to_right_preprocessed_dir,
|
||||
right_to_left_preprocessed_dir,
|
||||
backwards_preprocessed_dir,
|
||||
lm_preprocessed_dir,
|
||||
) = rerank_utils.get_directories(
|
||||
args.data_dir_name,
|
||||
args.num_rescore,
|
||||
args.gen_subset,
|
||||
args.gen_model_name,
|
||||
args.shard_id,
|
||||
args.num_shards,
|
||||
args.sampling,
|
||||
args.prefix_len,
|
||||
args.target_prefix_frac,
|
||||
args.source_prefix_frac,
|
||||
)
|
||||
assert not (
|
||||
args.right_to_left1 and args.backwards1
|
||||
), "backwards right to left not supported"
|
||||
assert not (
|
||||
args.right_to_left2 and args.backwards2
|
||||
), "backwards right to left not supported"
|
||||
assert not (
|
||||
args.prefix_len is not None and args.target_prefix_frac is not None
|
||||
), "target prefix frac and target prefix len incompatible"
|
||||
|
||||
# make directory to store generation results
|
||||
if not os.path.exists(pre_gen):
|
||||
os.makedirs(pre_gen)
|
||||
|
||||
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
||||
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
||||
rerank1_is_gen = (
|
||||
args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
||||
)
|
||||
rerank2_is_gen = (
|
||||
args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
||||
)
|
||||
|
||||
if args.nbest_list is not None:
|
||||
rerank2_is_gen = True
|
||||
@ -70,17 +97,25 @@ def gen_and_reprocess_nbest(args):
|
||||
if not os.path.exists(backwards_preprocessed_dir):
|
||||
os.makedirs(backwards_preprocessed_dir)
|
||||
|
||||
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
|
||||
target_prefix_frac=args.target_prefix_frac,
|
||||
source_prefix_frac=args.source_prefix_frac,
|
||||
backwards=args.backwards1)
|
||||
score1_file = rerank_utils.rescore_file_name(
|
||||
pre_gen,
|
||||
args.prefix_len,
|
||||
args.model1_name,
|
||||
target_prefix_frac=args.target_prefix_frac,
|
||||
source_prefix_frac=args.source_prefix_frac,
|
||||
backwards=args.backwards1,
|
||||
)
|
||||
if args.score_model2 is not None:
|
||||
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
|
||||
target_prefix_frac=args.target_prefix_frac,
|
||||
source_prefix_frac=args.source_prefix_frac,
|
||||
backwards=args.backwards2)
|
||||
score2_file = rerank_utils.rescore_file_name(
|
||||
pre_gen,
|
||||
args.prefix_len,
|
||||
args.model2_name,
|
||||
target_prefix_frac=args.target_prefix_frac,
|
||||
source_prefix_frac=args.source_prefix_frac,
|
||||
backwards=args.backwards2,
|
||||
)
|
||||
|
||||
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
|
||||
predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
|
||||
|
||||
using_nbest = args.nbest_list is not None
|
||||
|
||||
@ -92,17 +127,29 @@ def gen_and_reprocess_nbest(args):
|
||||
if not os.path.isfile(predictions_bpe_file):
|
||||
print("STEP 1: generate predictions using the p(T|S) model with bpe")
|
||||
print(args.data)
|
||||
param1 = [args.data,
|
||||
"--path", args.gen_model,
|
||||
"--shard-id", str(args.shard_id),
|
||||
"--num-shards", str(args.num_shards),
|
||||
"--nbest", str(args.num_rescore),
|
||||
"--batch-size", str(args.batch_size),
|
||||
"--beam", str(args.num_rescore),
|
||||
"--batch-size", str(args.num_rescore),
|
||||
"--gen-subset", args.gen_subset,
|
||||
"--source-lang", args.source_lang,
|
||||
"--target-lang", args.target_lang]
|
||||
param1 = [
|
||||
args.data,
|
||||
"--path",
|
||||
args.gen_model,
|
||||
"--shard-id",
|
||||
str(args.shard_id),
|
||||
"--num-shards",
|
||||
str(args.num_shards),
|
||||
"--nbest",
|
||||
str(args.num_rescore),
|
||||
"--batch-size",
|
||||
str(args.batch_size),
|
||||
"--beam",
|
||||
str(args.num_rescore),
|
||||
"--batch-size",
|
||||
str(args.num_rescore),
|
||||
"--gen-subset",
|
||||
args.gen_subset,
|
||||
"--source-lang",
|
||||
args.source_lang,
|
||||
"--target-lang",
|
||||
args.target_lang,
|
||||
]
|
||||
if args.sampling:
|
||||
param1 += ["--sampling"]
|
||||
|
||||
@ -110,124 +157,229 @@ def gen_and_reprocess_nbest(args):
        input_args = options.parse_args_and_arch(gen_parser, param1)

        print(input_args)
        with open(predictions_bpe_file, 'w') as f:
        with open(predictions_bpe_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)

    gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
                                                  nbest=using_nbest, prefix_len=args.prefix_len,
                                                  target_prefix_frac=args.target_prefix_frac)
    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file,
        bpe_symbol=args.remove_bpe,
        nbest=using_nbest,
        prefix_len=args.prefix_len,
        target_prefix_frac=args.target_prefix_frac,
    )

    if args.diff_bpe:
        rerank_utils.write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
                                       gen_output.no_bpe_target, pre_gen+"/source_gen_bpe."+args.source_lang,
                                       pre_gen+"/target_gen_bpe."+args.target_lang,
                                       pre_gen+"/reference_gen_bpe."+args.target_lang)
        rerank_utils.write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            pre_gen + "/source_gen_bpe." + args.source_lang,
            pre_gen + "/target_gen_bpe." + args.target_lang,
            pre_gen + "/reference_gen_bpe." + args.target_lang,
        )
        bitext_bpe = args.rescore_bpe_code
        bpe_src_param = ["-c", bitext_bpe,
                         "--input", pre_gen+"/source_gen_bpe."+args.source_lang,
                         "--output", pre_gen+"/rescore_data."+args.source_lang]
        bpe_tgt_param = ["-c", bitext_bpe,
                         "--input", pre_gen+"/target_gen_bpe."+args.target_lang,
                         "--output", pre_gen+"/rescore_data."+args.target_lang]
        bpe_src_param = [
            "-c",
            bitext_bpe,
            "--input",
            pre_gen + "/source_gen_bpe." + args.source_lang,
            "--output",
            pre_gen + "/rescore_data." + args.source_lang,
        ]
        bpe_tgt_param = [
            "-c",
            bitext_bpe,
            "--input",
            pre_gen + "/target_gen_bpe." + args.target_lang,
            "--output",
            pre_gen + "/rescore_data." + args.target_lang,
        ]

        subprocess.call(["python",
                         os.path.join(os.path.dirname(__file__),
                                      "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
                        shell=False)
        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_src_param,
            shell=False,
        )

        subprocess.call(["python",
                         os.path.join(os.path.dirname(__file__),
                                      "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_tgt_param,
                        shell=False)
        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_tgt_param,
            shell=False,
        )

    if (not os.path.isfile(score1_file) and not rerank1_is_gen) or \
            (args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen):
        print("STEP 2: process the output of generate.py so we have clean text files with the translations")
    if (not os.path.isfile(score1_file) and not rerank1_is_gen) or (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print(
            "STEP 2: process the output of generate.py so we have clean text files with the translations"
        )

        rescore_file = "/rescore_data"
        if args.prefix_len is not None:
            prefix_len_rescore_file = rescore_file + "prefix"+str(args.prefix_len)
            prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
        if args.target_prefix_frac is not None:
            target_prefix_frac_rescore_file = rescore_file + "target_prefix_frac"+str(args.target_prefix_frac)
            target_prefix_frac_rescore_file = (
                rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
            )
        if args.source_prefix_frac is not None:
            source_prefix_frac_rescore_file = rescore_file + "source_prefix_frac"+str(args.source_prefix_frac)
            source_prefix_frac_rescore_file = (
                rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
            )

        if not args.right_to_left1 or not args.right_to_left2:
            if not args.diff_bpe:
                rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                               pre_gen+rescore_file+"."+args.source_lang,
                                               pre_gen+rescore_file+"."+args.target_lang,
                                               pre_gen+"/reference_file", bpe_symbol=args.remove_bpe)
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen + rescore_file + "." + args.source_lang,
                    pre_gen + rescore_file + "." + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.remove_bpe,
                )
            if args.prefix_len is not None:
                bw_rescore_file = prefix_len_rescore_file
                rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                               pre_gen+prefix_len_rescore_file+"."+args.source_lang,
                                               pre_gen+prefix_len_rescore_file+"."+args.target_lang,
                                               pre_gen+"/reference_file", prefix_len=args.prefix_len,
                                               bpe_symbol=args.remove_bpe)
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen + prefix_len_rescore_file + "." + args.source_lang,
                    pre_gen + prefix_len_rescore_file + "." + args.target_lang,
                    pre_gen + "/reference_file",
                    prefix_len=args.prefix_len,
                    bpe_symbol=args.remove_bpe,
                )
            elif args.target_prefix_frac is not None:
                bw_rescore_file = target_prefix_frac_rescore_file
                rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                               pre_gen+target_prefix_frac_rescore_file+"."+args.source_lang,
                                               pre_gen+target_prefix_frac_rescore_file+"."+args.target_lang,
                                               pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
                                               target_prefix_frac=args.target_prefix_frac)
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen
                    + target_prefix_frac_rescore_file
                    + "."
                    + args.source_lang,
                    pre_gen
                    + target_prefix_frac_rescore_file
                    + "."
                    + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.remove_bpe,
                    target_prefix_frac=args.target_prefix_frac,
                )
            else:
                bw_rescore_file = rescore_file

            if args.source_prefix_frac is not None:
                fw_rescore_file = source_prefix_frac_rescore_file
                rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                               pre_gen+source_prefix_frac_rescore_file+"."+args.source_lang,
                                               pre_gen+source_prefix_frac_rescore_file+"."+args.target_lang,
                                               pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
                                               source_prefix_frac=args.source_prefix_frac)
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen
                    + source_prefix_frac_rescore_file
                    + "."
                    + args.source_lang,
                    pre_gen
                    + source_prefix_frac_rescore_file
                    + "."
                    + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.remove_bpe,
                    source_prefix_frac=args.source_prefix_frac,
                )
            else:
                fw_rescore_file = rescore_file

        if args.right_to_left1 or args.right_to_left2:
            rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                           pre_gen+"/right_to_left_rescore_data."+args.source_lang,
                                           pre_gen+"/right_to_left_rescore_data."+args.target_lang,
                                           pre_gen+"/right_to_left_reference_file",
                                           right_to_left=True, bpe_symbol=args.remove_bpe)
            rerank_utils.write_reprocessed(
                gen_output.source,
                gen_output.hypo,
                gen_output.target,
                pre_gen + "/right_to_left_rescore_data." + args.source_lang,
                pre_gen + "/right_to_left_rescore_data." + args.target_lang,
                pre_gen + "/right_to_left_reference_file",
                right_to_left=True,
                bpe_symbol=args.remove_bpe,
            )

    print("STEP 3: binarize the translations")
    if not args.right_to_left1 or args.score_model2 is not None and not args.right_to_left2 or not rerank1_is_gen:
    if (
        not args.right_to_left1
        or args.score_model2 is not None
        and not args.right_to_left2
        or not rerank1_is_gen
    ):

        if args.backwards1 or args.backwards2:
            if args.backwards_score_dict_dir is not None:
                bw_dict = args.backwards_score_dict_dir
            else:
                bw_dict = args.score_dict_dir
            bw_preprocess_param = ["--source-lang", scorer1_src,
                                   "--target-lang", scorer1_tgt,
                                   "--trainpref", pre_gen+bw_rescore_file,
                                   "--srcdict", bw_dict + "/dict." + scorer1_src + ".txt",
                                   "--tgtdict", bw_dict + "/dict." + scorer1_tgt + ".txt",
                                   "--destdir", backwards_preprocessed_dir]
            bw_preprocess_param = [
                "--source-lang",
                scorer1_src,
                "--target-lang",
                scorer1_tgt,
                "--trainpref",
                pre_gen + bw_rescore_file,
                "--srcdict",
                bw_dict + "/dict." + scorer1_src + ".txt",
                "--tgtdict",
                bw_dict + "/dict." + scorer1_tgt + ".txt",
                "--destdir",
                backwards_preprocessed_dir,
            ]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(bw_preprocess_param)
            preprocess.main(input_args)

        preprocess_param = ["--source-lang", scorer1_src,
                            "--target-lang", scorer1_tgt,
                            "--trainpref", pre_gen+fw_rescore_file,
                            "--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
                            "--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
                            "--destdir", left_to_right_preprocessed_dir]
        preprocess_param = [
            "--source-lang",
            scorer1_src,
            "--target-lang",
            scorer1_tgt,
            "--trainpref",
            pre_gen + fw_rescore_file,
            "--srcdict",
            args.score_dict_dir + "/dict." + scorer1_src + ".txt",
            "--tgtdict",
            args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
            "--destdir",
            left_to_right_preprocessed_dir,
        ]
        preprocess_parser = options.get_preprocessing_parser()
        input_args = preprocess_parser.parse_args(preprocess_param)
        preprocess.main(input_args)

    if args.right_to_left1 or args.right_to_left2:
        preprocess_param = ["--source-lang", scorer1_src,
                            "--target-lang", scorer1_tgt,
                            "--trainpref", pre_gen+"/right_to_left_rescore_data",
                            "--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
                            "--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
                            "--destdir", right_to_left_preprocessed_dir]
        preprocess_param = [
            "--source-lang",
            scorer1_src,
            "--target-lang",
            scorer1_tgt,
            "--trainpref",
            pre_gen + "/right_to_left_rescore_data",
            "--srcdict",
            args.score_dict_dir + "/dict." + scorer1_src + ".txt",
            "--tgtdict",
            args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
            "--destdir",
            right_to_left_preprocessed_dir,
        ]
        preprocess_parser = options.get_preprocessing_parser()
        input_args = preprocess_parser.parse_args(preprocess_param)
        preprocess.main(input_args)

@ -241,5 +393,5 @@ def cli_main():
    gen_and_reprocess_nbest(args)


if __name__ == '__main__':
if __name__ == "__main__":
    cli_main()
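Note: the capture pattern in the hunk above (open a file, redirect stdout, call generate.main) is how every scoring step in these scripts collects fairseq's printed output. A minimal standalone sketch of the same idea, with a hypothetical writer function standing in for generate.main:

import os
from contextlib import redirect_stdout

def noisy_step():
    # stands in for generate.main(input_args), which prints hypotheses to stdout
    print("H-0\t-0.5\tein Beispiel")

out_file = "/tmp/generate_output_bpe.txt"  # hypothetical path
with open(out_file, "w") as f:
    with redirect_stdout(f):
        noisy_step()

assert os.path.getsize(out_file) > 0  # the printed line landed in the file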
@ -6,14 +6,14 @@
from fairseq import options


def get_reranking_parser(default_task='translation'):
    parser = options.get_parser('Generation and reranking', default_task)
def get_reranking_parser(default_task="translation"):
    parser = options.get_parser("Generation and reranking", default_task)
    add_reranking_args(parser)
    return parser


def get_tuning_parser(default_task='translation'):
    parser = options.get_parser('Reranking tuning', default_task)
def get_tuning_parser(default_task="translation"):
    parser = options.get_parser("Reranking tuning", default_task)
    add_reranking_args(parser)
    add_tuning_args(parser)
    return parser
@ -110,17 +110,40 @@ def add_reranking_args(parser):
def add_tuning_args(parser):
    group = parser.add_argument_group("Tuning")

    group.add_argument('--lower-bound', default=[-0.7], nargs='+', type=float,
                       help='lower bound of search space')
    group.add_argument('--upper-bound', default=[3], nargs='+', type=float,
                       help='upper bound of search space')
    group.add_argument('--tune-param', default=['lenpen'], nargs='+',
                       choices=['lenpen', 'weight1', 'weight2', 'weight3'],
                       help='the parameter(s) to tune')
    group.add_argument('--tune-subset', default='valid', choices=['valid', 'test', 'train'],
                       help='the subset to tune on ')
    group.add_argument('--num-trials', default=1000, type=int,
                       help='number of trials to do for random search')
    group.add_argument('--share-weights', action='store_true',
                       help='share weight2 and weight 3')
    group.add_argument(
        "--lower-bound",
        default=[-0.7],
        nargs="+",
        type=float,
        help="lower bound of search space",
    )
    group.add_argument(
        "--upper-bound",
        default=[3],
        nargs="+",
        type=float,
        help="upper bound of search space",
    )
    group.add_argument(
        "--tune-param",
        default=["lenpen"],
        nargs="+",
        choices=["lenpen", "weight1", "weight2", "weight3"],
        help="the parameter(s) to tune",
    )
    group.add_argument(
        "--tune-subset",
        default="valid",
        choices=["valid", "test", "train"],
        help="the subset to tune on ",
    )
    group.add_argument(
        "--num-trials",
        default=1000,
        type=int,
        help="number of trials to do for random search",
    )
    group.add_argument(
        "--share-weights", action="store_true", help="share weight2 and weight 3"
    )
    return group
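Note: add_tuning_args is the standard argparse argument-group pattern: create a named group on the parser, then hang related flags off it. A self-contained sketch (only --lower-bound is taken from the hunk above; the other flags are omitted for brevity):

import argparse

parser = argparse.ArgumentParser("Reranking tuning")
group = parser.add_argument_group("Tuning")
group.add_argument(
    "--lower-bound",
    default=[-0.7],
    nargs="+",
    type=float,
    help="lower bound of search space",
)
args = parser.parse_args(["--lower-bound", "-1.0", "0.0"])
print(args.lower_bound)  # [-1.0, 0.0]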
@ -3,8 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import redirect_stdout
import os
from contextlib import redirect_stdout

from fairseq import options
from fairseq_cli import generate
@ -13,82 +13,124 @@ from . import rerank_options, rerank_utils


def score_bw(args):
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    if args.score_model2 is not None:
        if args.backwards2:
            scorer2_src = args.target_lang
            scorer2_tgt = args.source_lang
        else:
            scorer1_src = args.source_lang
            scorer1_tgt = args.target_lang
            scorer2_src = args.source_lang
            scorer2_tgt = args.target_lang

    if args.score_model2 is not None:
        if args.backwards2:
            scorer2_src = args.target_lang
            scorer2_tgt = args.source_lang
        else:
            scorer2_src = args.source_lang
            scorer2_tgt = args.target_lang
    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
    rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
        backwards_preprocessed_dir, lm_preprocessed_dir = \
        rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
                                     args.gen_model_name, args.shard_id, args.num_shards,
                                     args.sampling, args.prefix_len, args.target_prefix_frac,
                                     args.source_prefix_frac)
    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )

    score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
                                                 target_prefix_frac=args.target_prefix_frac,
                                                 source_prefix_frac=args.source_prefix_frac,
                                                 backwards=args.backwards1)
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
                                                     target_prefix_frac=args.target_prefix_frac,
                                                     source_prefix_frac=args.source_prefix_frac,
                                                     backwards=args.backwards2)
    if args.right_to_left1:
        rerank_data1 = right_to_left_preprocessed_dir
    elif args.backwards1:
        rerank_data1 = backwards_preprocessed_dir
    else:
        rerank_data1 = left_to_right_preprocessed_dir

    if args.right_to_left1:
        rerank_data1 = right_to_left_preprocessed_dir
    elif args.backwards1:
        rerank_data1 = backwards_preprocessed_dir
    gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
    if not rerank1_is_gen and not os.path.isfile(score1_file):
        print("STEP 4: score the translations for model 1")

        model_param1 = [
            "--path",
            args.score_model1,
            "--source-lang",
            scorer1_src,
            "--target-lang",
            scorer1_tgt,
        ]
        gen_model1_param = [rerank_data1] + gen_param + model_param1

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)

        with open(score1_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)

    if (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print("STEP 4: score the translations for model 2")

        if args.right_to_left2:
            rerank_data2 = right_to_left_preprocessed_dir
        elif args.backwards2:
            rerank_data2 = backwards_preprocessed_dir
        else:
            rerank_data1 = left_to_right_preprocessed_dir
            rerank_data2 = left_to_right_preprocessed_dir

    gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
    if not rerank1_is_gen and not os.path.isfile(score1_file):
        print("STEP 4: score the translations for model 1")
        model_param2 = [
            "--path",
            args.score_model2,
            "--source-lang",
            scorer2_src,
            "--target-lang",
            scorer2_tgt,
        ]
        gen_model2_param = [rerank_data2] + gen_param + model_param2

        model_param1 = ["--path", args.score_model1, "--source-lang", scorer1_src, "--target-lang", scorer1_tgt]
        gen_model1_param = [rerank_data1] + gen_param + model_param1
        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)

        with open(score1_file, 'w') as f:
            with redirect_stdout(f):
                generate.main(input_args)

    if args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen:
        print("STEP 4: score the translations for model 2")

        if args.right_to_left2:
            rerank_data2 = right_to_left_preprocessed_dir
        elif args.backwards2:
            rerank_data2 = backwards_preprocessed_dir
        else:
            rerank_data2 = left_to_right_preprocessed_dir

        model_param2 = ["--path", args.score_model2, "--source-lang", scorer2_src, "--target-lang", scorer2_tgt]
        gen_model2_param = [rerank_data2] + gen_param + model_param2

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)

        with open(score2_file, 'w') as f:
            with redirect_stdout(f):
                generate.main(input_args)
        with open(score2_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)


def cli_main():
@ -97,5 +139,5 @@ def cli_main():
    score_bw(args)


if __name__ == '__main__':
if __name__ == "__main__":
    cli_main()
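Note: the import hunk above is pure isort reordering. Within a section, isort's default layout puts plain "import x" statements before "from x import y" statements, each sub-block alphabetized, and keeps third-party imports (fairseq) in their own section after the standard library:

# old order (as removed above):
#   from contextlib import redirect_stdout
#   import os
# isort output (as added above):
import os
from contextlib import redirect_stdout

from fairseq import options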
@ -12,22 +12,38 @@ from . import rerank_options, rerank_utils

def score_lm(args):
    using_nbest = args.nbest_list is not None
    pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
        backwards_preprocessed_dir, lm_preprocessed_dir = \
        rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
                                     args.gen_model_name, args.shard_id, args.num_shards,
                                     args.sampling, args.prefix_len, args.target_prefix_frac,
                                     args.source_prefix_frac)
    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list

    gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest)
    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest
    )

    if args.language_model is not None:
        lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
        lm_score_file = rerank_utils.rescore_file_name(
            pre_gen, args.prefix_len, args.lm_name, lm_file=True
        )

    if args.language_model is not None and not os.path.isfile(lm_score_file):
        print("STEP 4.5: language modeling for P(T)")
@ -38,10 +54,21 @@ def score_lm(args):
        else:
            bpe_status = "different"

        rerank_utils.lm_scoring(lm_preprocessed_dir, bpe_status, gen_output, pre_gen,
                                args.lm_dict, args.lm_name, args.language_model,
                                args.lm_bpe_code, 128, lm_score_file, args.target_lang,
                                args.source_lang, prefix_len=args.prefix_len)
        rerank_utils.lm_scoring(
            lm_preprocessed_dir,
            bpe_status,
            gen_output,
            pre_gen,
            args.lm_dict,
            args.lm_name,
            args.language_model,
            args.lm_bpe_code,
            128,
            lm_score_file,
            args.target_lang,
            args.source_lang,
            prefix_len=args.prefix_len,
        )


def cli_main():
@ -50,5 +77,5 @@ def cli_main():
    score_lm(args)


if __name__ == '__main__':
if __name__ == "__main__":
    cli_main()
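Note: black rewrites the backslash-continued five-target unpacking of rerank_utils.get_directories into a parenthesized tuple target, as seen in the hunks above. The same transformation on a toy function (names hypothetical):

def get_dirs():
    return "pre_gen", "l2r", "r2l", "bw", "lm"

# old style:
#   pre_gen, l2r, r2l, bw, lm = \
#       get_dirs()
(
    pre_gen,
    l2r,
    r2l,
    bw,
    lm,
) = get_dirs()
print(pre_gen, lm)  # pre_gen lm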
@ -5,8 +5,8 @@

import argparse
import random
import numpy as np

import numpy as np
from fairseq import options

from . import rerank, rerank_options
@ -14,7 +14,7 @@ from . import rerank, rerank_options

def random_search(args):
    param_values = []
    tuneable_parameters = ['lenpen', 'weight1', 'weight2', 'weight3']
    tuneable_parameters = ["lenpen", "weight1", "weight2", "weight3"]
    initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3]
    for i, elem in enumerate(initial_params):
        if type(elem) is not list:
@ -33,51 +33,60 @@ def random_search(args):
    param_values += initial_params
    random.seed(args.seed)

    random_params = np.array([
        [random.uniform(args.lower_bound[i], args.upper_bound[i]) for i in range(len(args.tune_param))]
        for k in range(args.num_trials)
    ])
    set_params = np.array([
        [initial_params[i][0] for i in range(len(tuneable_parameters))]
        for k in range(args.num_trials)
    ])
    random_params = np.array(
        [
            [
                random.uniform(args.lower_bound[i], args.upper_bound[i])
                for i in range(len(args.tune_param))
            ]
            for k in range(args.num_trials)
        ]
    )
    set_params = np.array(
        [
            [initial_params[i][0] for i in range(len(tuneable_parameters))]
            for k in range(args.num_trials)
        ]
    )
    random_params = np.concatenate((random_params, set_params), 1)

    rerank_args = vars(args).copy()
    if args.nbest_list:
        rerank_args['gen_subset'] = 'test'
        rerank_args["gen_subset"] = "test"
    else:
        rerank_args['gen_subset'] = args.tune_subset
        rerank_args["gen_subset"] = args.tune_subset

    for k in range(len(tune_parameters)):
        rerank_args[tune_parameters[k]] = list(random_params[:, k])

    if args.share_weights:
        k = tune_parameters.index('weight2')
        rerank_args['weight3'] = list(random_params[:, k])
        k = tune_parameters.index("weight2")
        rerank_args["weight3"] = list(random_params[:, k])

    rerank_args = argparse.Namespace(**rerank_args)
    best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(rerank_args)
    best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(
        rerank_args
    )
    rerank_args = vars(args).copy()
    rerank_args['lenpen'] = [best_lenpen]
    rerank_args['weight1'] = [best_weight1]
    rerank_args['weight2'] = [best_weight2]
    rerank_args['weight3'] = [best_weight3]
    rerank_args["lenpen"] = [best_lenpen]
    rerank_args["weight1"] = [best_weight1]
    rerank_args["weight2"] = [best_weight2]
    rerank_args["weight3"] = [best_weight3]

    # write the hypothesis from the valid set from the best trial

    if args.gen_subset != "valid":
        rerank_args['gen_subset'] = "valid"
        rerank_args["gen_subset"] = "valid"
    rerank_args = argparse.Namespace(**rerank_args)
    rerank.rerank(rerank_args)

    # test with the best hyperparameters on gen subset
    rerank_args = vars(args).copy()
    rerank_args['gen_subset'] = args.gen_subset
    rerank_args['lenpen'] = [best_lenpen]
    rerank_args['weight1'] = [best_weight1]
    rerank_args['weight2'] = [best_weight2]
    rerank_args['weight3'] = [best_weight3]
    rerank_args["gen_subset"] = args.gen_subset
    rerank_args["lenpen"] = [best_lenpen]
    rerank_args["weight1"] = [best_weight1]
    rerank_args["weight2"] = [best_weight2]
    rerank_args["weight3"] = [best_weight3]
    rerank_args = argparse.Namespace(**rerank_args)
    rerank.rerank(rerank_args)

@ -89,5 +98,5 @@ def cli_main():
    random_search(args)


if __name__ == '__main__':
if __name__ == "__main__":
    cli_main()
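Note: the tuning loop above samples num_trials points uniformly and independently inside per-parameter [lower, upper] bounds, then appends the fixed initial values as extra columns. A compact re-statement of the sampling alone (bounds and trial count are made up):

import random
import numpy as np

random.seed(1)
lower_bound = [-0.7]  # one entry per tuned parameter
upper_bound = [3.0]
num_trials = 4

random_params = np.array(
    [
        [
            random.uniform(lower_bound[i], upper_bound[i])
            for i in range(len(lower_bound))
        ]
        for _ in range(num_trials)
    ]
)
print(random_params.shape)  # (4, 1): one row per trial, one column per parameter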
@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import redirect_stdout
import math
import os
import re
import subprocess
from contextlib import redirect_stdout

from fairseq import options
from fairseq_cli import eval_lm, preprocess
@ -20,7 +20,7 @@ def reprocess(fle):
    # per source, so the values for hypothesis_dict are lists.
    # parses output of generate.py

    with open(fle, 'r') as f:
    with open(fle, "r") as f:
        txt = f.read()

    """reprocess generate.py output"""
@ -45,7 +45,9 @@ def reprocess(fle):
        if line_type == "H":
            h_txt = line[j:]
            hypo = re.search(hp, h_txt)
            assert hypo is not None, ("regular expression failed to find the hypothesis scoring")
            assert (
                hypo is not None
            ), "regular expression failed to find the hypothesis scoring"
            _, i = hypo.span()
            score = hypo.group()
            if id_num in hypothesis_dict:
@ -56,9 +58,9 @@ def reprocess(fle):
                score_dict[id_num] = [float(score)]

        elif line_type == "S":
            source_dict[id_num] = (line[j:])
            source_dict[id_num] = line[j:]
        elif line_type == "T":
            target_dict[id_num] = (line[j:])
            target_dict[id_num] = line[j:]
        elif line_type == "P":
            pos_scores = (line[j:]).split()
            pos_scores = [float(x) for x in pos_scores]
@ -72,7 +74,7 @@ def reprocess(fle):

def reprocess_nbest(fle):
    """reprocess interactive.py output"""
    with open(fle, 'r') as f:
    with open(fle, "r") as f:
        txt = f.read()

    source_dict = {}
@ -82,7 +84,7 @@ def reprocess_nbest(fle):
    pos_score_dict = {}
    lines = txt.split("\n")

    hp = re.compile(r'[-]?\d+[.]?\d+')
    hp = re.compile(r"[-]?\d+[.]?\d+")
    j = -1

    for _i, line in enumerate(lines):
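Note: reprocess keys each generate.py line on its leading tag (H/S/T/P) and pulls the hypothesis score out with the float regex seen above. A toy parse of one hypothesis line (the sample line is made up, and the parse is simplified relative to the real function):

import re

line = "H-12\t-0.4617\tein Haus"        # made-up generate.py hypothesis line
hp = re.compile(r"[-]?\d+[.]?\d+")      # same float pattern as in reprocess
line_type, rest = line.split("-", 1)
id_num, tail = rest.split("\t", 1)
score = float(hp.search(tail).group())  # search after the id, as the code does
print(line_type, int(id_num), score)    # H 12 -0.4617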
@ -119,59 +121,76 @@ def reprocess_nbest(fle):
    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict


def write_reprocessed(sources, hypos, targets, source_outfile,
                      hypo_outfile, target_outfile, right_to_left=False,
                      prefix_len=None, bpe_symbol=None,
                      target_prefix_frac=None, source_prefix_frac=None):
def write_reprocessed(
    sources,
    hypos,
    targets,
    source_outfile,
    hypo_outfile,
    target_outfile,
    right_to_left=False,
    prefix_len=None,
    bpe_symbol=None,
    target_prefix_frac=None,
    source_prefix_frac=None,
):

    """writes nbest hypothesis for rescoring"""
    assert not (prefix_len is not None and target_prefix_frac is not None), \
        "in writing reprocessed, only one type of prefix may be used"
    assert not (prefix_len is not None and source_prefix_frac is not None), \
        "in writing reprocessed, only one type of prefix may be used"
    assert not (target_prefix_frac is not None and source_prefix_frac is not None), \
        "in writing reprocessed, only one type of prefix may be used"
    assert not (
        prefix_len is not None and target_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"
    assert not (
        prefix_len is not None and source_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"
    assert not (
        target_prefix_frac is not None and source_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"

    with open(source_outfile, 'w') as source_file, \
            open(hypo_outfile, 'w') as hypo_file, \
            open(target_outfile, 'w') as target_file:
    with open(source_outfile, "w") as source_file, open(
        hypo_outfile, "w"
    ) as hypo_file, open(target_outfile, "w") as target_file:

        assert len(sources) == len(hypos), "sources and hypos list length mismatch"
        if right_to_left:
            for i in range(len(sources)):
                for j in range(len(hypos[i])):
                    if prefix_len is None:
                        hypo_file.write(make_right_to_left(hypos[i][j])+"\n")
                    else:
                        raise NotImplementedError()
                source_file.write(make_right_to_left(sources[i])+"\n")
                target_file.write(make_right_to_left(targets[i])+"\n")
                for j in range(len(hypos[i])):
                    if prefix_len is None:
                        hypo_file.write(make_right_to_left(hypos[i][j]) + "\n")
                    else:
                        raise NotImplementedError()
                source_file.write(make_right_to_left(sources[i]) + "\n")
                target_file.write(make_right_to_left(targets[i]) + "\n")
        else:
            for i in sorted(sources.keys()):
                for j in range(len(hypos[i])):
                    if prefix_len is not None:
                        shortened = get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)+"\n"
                        hypo_file.write(shortened)
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                    elif target_prefix_frac is not None:
                        num_words, shortened, num_bpe_tokens = \
                            calc_length_from_frac(hypos[i][j], target_prefix_frac, bpe_symbol)
                        shortened += "\n"
                        hypo_file.write(shortened)
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                    elif source_prefix_frac is not None:
                        num_words, shortened, num_bpe_tokensn = \
                            calc_length_from_frac(sources[i], source_prefix_frac, bpe_symbol)
                        shortened += "\n"
                        hypo_file.write(hypos[i][j])
                        source_file.write(shortened)
                        target_file.write(targets[i])
                    else:
                        hypo_file.write(hypos[i][j])
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                for j in range(len(hypos[i])):
                    if prefix_len is not None:
                        shortened = (
                            get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)
                            + "\n"
                        )
                        hypo_file.write(shortened)
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                    elif target_prefix_frac is not None:
                        num_words, shortened, num_bpe_tokens = calc_length_from_frac(
                            hypos[i][j], target_prefix_frac, bpe_symbol
                        )
                        shortened += "\n"
                        hypo_file.write(shortened)
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                    elif source_prefix_frac is not None:
                        num_words, shortened, num_bpe_tokensn = calc_length_from_frac(
                            sources[i], source_prefix_frac, bpe_symbol
                        )
                        shortened += "\n"
                        hypo_file.write(hypos[i][j])
                        source_file.write(shortened)
                        target_file.write(targets[i])
                    else:
                        hypo_file.write(hypos[i][j])
                        source_file.write(sources[i])
                        target_file.write(targets[i])


def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol):
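Note: the prefix-fraction branches above delegate to calc_length_from_frac, whose contract is "keep the first ceil(frac * num_words) words". A simplified sketch of that rule (it ignores BPE merges, which the real helper accounts for):

import math

def prefix_by_frac(sentence, frac):
    # keep the first ceil(frac * num_words) words of the sentence
    words = sentence.split()
    num_words = math.ceil(len(words) * frac)
    return num_words, " ".join(words[:num_words])

print(prefix_by_frac("a b c d e", 0.5))  # (3, 'a b c')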
@ -207,7 +226,9 @@ def get_prefix_from_len(sentence, bpe_symbol, prefix_len):
    if bpe_count == 0:
        return sentence[:prefix_len]
    else:
        return sentence[:prefix_len]+get_prefix_from_len(sentence[prefix_len:], bpe_symbol, bpe_count)
        return sentence[:prefix_len] + get_prefix_from_len(
            sentence[prefix_len:], bpe_symbol, bpe_count
        )


def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len):
@ -225,9 +246,9 @@ def make_right_to_left(line):


def remove_bpe(line, bpe_symbol):
    line = line.replace("\n", '')
    line = (line + ' ').replace(bpe_symbol, '').rstrip()
    return line+("\n")
    line = line.replace("\n", "")
    line = (line + " ").replace(bpe_symbol, "").rstrip()
    return line + ("\n")


def remove_bpe_dict(pred_dict, bpe_symbol):
@ -242,7 +263,7 @@ def remove_bpe_dict(pred_dict, bpe_symbol):


def parse_bleu_scoring(line):
    p = re.compile(r'(BLEU4 = )\d+[.]\d+')
    p = re.compile(r"(BLEU4 = )\d+[.]\d+")
    res = re.search(p, line)
    assert res is not None, line
    return float(res.group()[8:])
@ -259,9 +280,21 @@ def get_full_from_prefix(hypo_prefix, hypos):
    raise Exception()


def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=None,
              lenpen=None, src_len=None, tgt_len=None, bitext1_backwards=False,
              bitext2_backwards=False, normalize=False):
def get_score(
    a,
    b,
    c,
    target_len,
    bitext_score1,
    bitext_score2=None,
    lm_score=None,
    lenpen=None,
    src_len=None,
    tgt_len=None,
    bitext1_backwards=False,
    bitext2_backwards=False,
    normalize=False,
):
    if bitext1_backwards:
        bitext1_norm = src_len
    else:
@ -275,9 +308,13 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N
        bitext2_norm = 1
        bitext_score2 = 0
    if normalize:
        score = a*bitext_score1/bitext1_norm + b*bitext_score2/bitext2_norm+c*lm_score/src_len
        score = (
            a * bitext_score1 / bitext1_norm
            + b * bitext_score2 / bitext2_norm
            + c * lm_score / src_len
        )
    else:
        score = a*bitext_score1 + b*bitext_score2+c*lm_score
        score = a * bitext_score1 + b * bitext_score2 + c * lm_score

    if lenpen is not None:
        score /= (target_len) ** float(lenpen)
@ -286,8 +323,16 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N


class BitextOutput(object):
    def __init__(self, output_file, backwards, right_to_left, bpe_symbol,
                 prefix_len=None, target_prefix_frac=None, source_prefix_frac=None):
    def __init__(
        self,
        output_file,
        backwards,
        right_to_left,
        bpe_symbol,
        prefix_len=None,
        target_prefix_frac=None,
        source_prefix_frac=None,
    ):
        """process output from rescoring"""
        source, hypo, score, target, pos_score = reprocess(output_file)
        if backwards:
@ -296,7 +341,9 @@ class BitextOutput(object):
            self.hypo_fracs = target_prefix_frac

        # remove length penalty so we can use raw scores
        score, num_bpe_tokens = get_score_from_pos(pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards)
        score, num_bpe_tokens = get_score_from_pos(
            pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards
        )
        source_lengths = {}
        target_lengths = {}

@ -341,7 +388,9 @@ class BitextOutput(object):
                score[i] = float(score[i][0])
                pos_score[i] = pos_score[i][0]
            else:
                assert len(hypo[i]) == 1, "expected only one hypothesis per source sentence"
                assert (
                    len(hypo[i]) == 1
                ), "expected only one hypothesis per source sentence"
                source[i] = remove_bpe(source[i], bpe_symbol)
                target[i] = remove_bpe(target[i], bpe_symbol)
                hypo[i] = remove_bpe(hypo[i][0], bpe_symbol)
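Note: get_score above is a plain linear combination, optionally normalized by length and then discounted by a length penalty. Numerically, with made-up weights a=1.0, b=0.5, c=0.3 and normalize=False:

a, b, c = 1.0, 0.5, 0.3                      # hypothetical tuned weights
bitext_score1, bitext_score2, lm_score = -10.0, -12.0, -8.0
target_len, lenpen = 20, 1.0

score = a * bitext_score1 + b * bitext_score2 + c * lm_score
score /= target_len ** float(lenpen)
print(score)  # (-10.0 - 6.0 - 2.4) / 20 ≈ -0.92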
@ -360,11 +409,26 @@ class BitextOutput(object):


class BitextOutputFromGen(object):
    def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_len=None, target_prefix_frac=None):
    def __init__(
        self,
        predictions_bpe_file,
        bpe_symbol=None,
        nbest=False,
        prefix_len=None,
        target_prefix_frac=None,
    ):
        if nbest:
            pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess_nbest(predictions_bpe_file)
            (
                pred_source,
                pred_hypo,
                pred_score,
                pred_target,
                pred_pos_score,
            ) = reprocess_nbest(predictions_bpe_file)
        else:
            pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(predictions_bpe_file)
            pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(
                predictions_bpe_file
            )

        assert len(pred_source) == len(pred_hypo)
        assert len(pred_source) == len(pred_score)
@ -372,8 +436,9 @@ class BitextOutputFromGen(object):
        assert len(pred_source) == len(pred_pos_score)

        # remove length penalty so we can use raw scores
        pred_score, num_bpe_tokens = get_score_from_pos(pred_pos_score, prefix_len, pred_hypo,
                                                        bpe_symbol, target_prefix_frac, False)
        pred_score, num_bpe_tokens = get_score_from_pos(
            pred_pos_score, prefix_len, pred_hypo, bpe_symbol, target_prefix_frac, False
        )

        self.source = pred_source
        self.target = pred_target
@ -414,7 +479,9 @@ class BitextOutputFromGen(object):
            index += 1


def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards):
def get_score_from_pos(
    pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards
):
    score_dict = {}
    num_bpe_tokens_dict = {}
    assert prefix_len is None or hypo_frac is None
@ -423,11 +490,15 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f
        num_bpe_tokens_dict[key] = []
        for i in range(len(pos_score_dict[key])):
            if prefix_len is not None and not backwards:
                num_bpe_tokens = get_num_bpe_tokens_from_len(hypo_dict[key][i], bpe_symbol, prefix_len)
                num_bpe_tokens = get_num_bpe_tokens_from_len(
                    hypo_dict[key][i], bpe_symbol, prefix_len
                )
                score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens]))
                num_bpe_tokens_dict[key].append(num_bpe_tokens)
            elif hypo_frac is not None:
                num_words, shortened, hypo_prefix_len = calc_length_from_frac(hypo_dict[key][i], hypo_frac, bpe_symbol)
                num_words, shortened, hypo_prefix_len = calc_length_from_frac(
                    hypo_dict[key][i], hypo_frac, bpe_symbol
                )
                score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len]))
                num_bpe_tokens_dict[key].append(hypo_prefix_len)
            else:
@ -437,10 +508,26 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f


class LMOutput(object):
    def __init__(self, lm_score_file, lm_dict=None, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
        lm_sentences, lm_sen_scores, lm_sen_pos_scores, lm_no_bpe_sentences, lm_bpe_tokens = \
            parse_lm(lm_score_file, prefix_len=prefix_len,
                     bpe_symbol=bpe_symbol, target_prefix_frac=target_prefix_frac)
    def __init__(
        self,
        lm_score_file,
        lm_dict=None,
        prefix_len=None,
        bpe_symbol=None,
        target_prefix_frac=None,
    ):
        (
            lm_sentences,
            lm_sen_scores,
            lm_sen_pos_scores,
            lm_no_bpe_sentences,
            lm_bpe_tokens,
        ) = parse_lm(
            lm_score_file,
            prefix_len=prefix_len,
            bpe_symbol=bpe_symbol,
            target_prefix_frac=target_prefix_frac,
        )

        self.sentences = lm_sentences
        self.score = lm_sen_scores
@ -452,7 +539,7 @@ class LMOutput(object):

def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
    """parse output of eval_lm"""
    with open(input_file, 'r') as f:
    with open(input_file, "r") as f:
        text = f.readlines()
        text = text[7:]
        cleaned_text = text[:-2]
@ -467,20 +554,23 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No
        if tokens[0].isdigit():
            line_id = int(tokens[0])
            scores = [float(x[1:-1]) for x in tokens[2::2]]
            sentences[line_id] = " ".join(tokens[1::2][:-1])+"\n"
            sentences[line_id] = " ".join(tokens[1::2][:-1]) + "\n"
            if bpe_symbol is not None:
                # exclude <eos> symbol to match output from generate.py
                bpe_sen = " ".join(tokens[1::2][:-1])+"\n"
                bpe_sen = " ".join(tokens[1::2][:-1]) + "\n"
                no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol)
                no_bpe_sentences[line_id] = no_bpe_sen

            if prefix_len is not None:
                num_bpe_tokens = get_num_bpe_tokens_from_len(bpe_sen, bpe_symbol, prefix_len)
                num_bpe_tokens = get_num_bpe_tokens_from_len(
                    bpe_sen, bpe_symbol, prefix_len
                )
                sen_scores[line_id] = sum(scores[:num_bpe_tokens])
                num_bpe_tokens_dict[line_id] = num_bpe_tokens
            elif target_prefix_frac is not None:
                num_words, shortened, target_prefix_len = calc_length_from_frac(bpe_sen, target_prefix_frac,
                                                                                bpe_symbol)
                num_words, shortened, target_prefix_len = calc_length_from_frac(
                    bpe_sen, target_prefix_frac, bpe_symbol
                )
                sen_scores[line_id] = sum(scores[:target_prefix_len])
                num_bpe_tokens_dict[line_id] = target_prefix_len
            else:
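Note: parse_lm above leans on eval_lm's --output-word-probs layout, where each line interleaves words and bracketed per-word log-probabilities. The slicing idiom from the hunk, applied to a made-up line:

line = "0 ein [-1.20] Haus [-0.80] <eos> [-0.10]"  # made-up eval_lm output line
tokens = line.split()
line_id = int(tokens[0])
scores = [float(x[1:-1]) for x in tokens[2::2]]    # strip the square brackets
sentence = " ".join(tokens[1::2][:-1])             # drop <eos> to match generate.py
print(line_id, sentence, sum(scores))              # 0 ein Haus ≈ -2.1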
@ -492,160 +582,269 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No
    return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict


def get_directories(data_dir_name, num_rescore, gen_subset,
                    fw_name, shard_id, num_shards,
                    sampling=False, prefix_len=None,
                    target_prefix_frac=None, source_prefix_frac=None):
    nbest_file_id = "nbest_" + str(num_rescore) + \
                    "_subset_" + gen_subset + \
                    "_fw_name_" + fw_name + \
                    "_shard_" + str(shard_id) + \
                    "_of_" + str(num_shards)
def get_directories(
    data_dir_name,
    num_rescore,
    gen_subset,
    fw_name,
    shard_id,
    num_shards,
    sampling=False,
    prefix_len=None,
    target_prefix_frac=None,
    source_prefix_frac=None,
):
    nbest_file_id = (
        "nbest_"
        + str(num_rescore)
        + "_subset_"
        + gen_subset
        + "_fw_name_"
        + fw_name
        + "_shard_"
        + str(shard_id)
        + "_of_"
        + str(num_shards)
    )

    if sampling:
        nbest_file_id += "_sampling"

    # the directory containing all information for this nbest list
    pre_gen = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+data_dir_name+"/"+nbest_file_id
    pre_gen = (
        os.path.join(os.path.dirname(__file__))
        + "/rerank_data/"
        + data_dir_name
        + "/"
        + nbest_file_id
    )
    # the directory to store the preprocessed nbest list, for left to right rescoring
    left_to_right_preprocessed_dir = pre_gen+"/left_to_right_preprocessed"
    left_to_right_preprocessed_dir = pre_gen + "/left_to_right_preprocessed"
    if source_prefix_frac is not None:
        left_to_right_preprocessed_dir = left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
        left_to_right_preprocessed_dir = (
            left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
        )
    # the directory to store the preprocessed nbest list, for right to left rescoring
    right_to_left_preprocessed_dir = pre_gen+"/right_to_left_preprocessed"
    right_to_left_preprocessed_dir = pre_gen + "/right_to_left_preprocessed"
    # the directory to store the preprocessed nbest list, for backwards rescoring
    backwards_preprocessed_dir = pre_gen+"/backwards"
    backwards_preprocessed_dir = pre_gen + "/backwards"
    if target_prefix_frac is not None:
        backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_frac"+str(target_prefix_frac)
        backwards_preprocessed_dir = (
            backwards_preprocessed_dir + "/prefix_frac" + str(target_prefix_frac)
        )
    elif prefix_len is not None:
        backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_"+str(prefix_len)
        backwards_preprocessed_dir = (
            backwards_preprocessed_dir + "/prefix_" + str(prefix_len)
        )

    # the directory to store the preprocessed nbest list, for rescoring with P(T)
    lm_preprocessed_dir = pre_gen+"/lm_preprocessed"
    lm_preprocessed_dir = pre_gen + "/lm_preprocessed"

    return pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
        backwards_preprocessed_dir, lm_preprocessed_dir
    return (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    )


def lm_scoring(preprocess_directory, bpe_status, gen_output, pre_gen,
               cur_lm_dict, cur_lm_name, cur_language_model, cur_lm_bpe_code,
               batch_size, lm_score_file, target_lang, source_lang, prefix_len=None):
def lm_scoring(
    preprocess_directory,
    bpe_status,
    gen_output,
    pre_gen,
    cur_lm_dict,
    cur_lm_name,
    cur_language_model,
    cur_lm_bpe_code,
    batch_size,
    lm_score_file,
    target_lang,
    source_lang,
    prefix_len=None,
):
    if prefix_len is not None:
        assert bpe_status == "different", "bpe status must be different to use prefix len"
        assert (
            bpe_status == "different"
        ), "bpe status must be different to use prefix len"
    if bpe_status == "no bpe":
        # run lm on output without bpe
        write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
                          gen_output.no_bpe_target, pre_gen+"/rescore_data_no_bpe.de",
                          pre_gen+"/rescore_data_no_bpe.en", pre_gen+"/reference_file_no_bpe")
        write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            pre_gen + "/rescore_data_no_bpe.de",
            pre_gen + "/rescore_data_no_bpe.en",
            pre_gen + "/reference_file_no_bpe",
        )

        preprocess_lm_param = ["--only-source",
                               "--trainpref", pre_gen+"/rescore_data_no_bpe."+target_lang,
                               "--srcdict", cur_lm_dict,
                               "--destdir", preprocess_directory]
        preprocess_lm_param = [
            "--only-source",
            "--trainpref",
            pre_gen + "/rescore_data_no_bpe." + target_lang,
            "--srcdict",
            cur_lm_dict,
            "--destdir",
            preprocess_directory,
        ]
        preprocess_parser = options.get_preprocessing_parser()
        input_args = preprocess_parser.parse_args(preprocess_lm_param)
        preprocess.main(input_args)

        eval_lm_param = [preprocess_directory,
                         "--path", cur_language_model,
                         "--output-word-probs",
                         "--batch-size", str(batch_size),
                         "--max-tokens", "1024",
                         "--sample-break-mode", "eos",
                         "--gen-subset", "train"]
        eval_lm_param = [
            preprocess_directory,
            "--path",
            cur_language_model,
            "--output-word-probs",
            "--batch-size",
            str(batch_size),
            "--max-tokens",
            "1024",
            "--sample-break-mode",
            "eos",
            "--gen-subset",
            "train",
        ]

        eval_lm_parser = options.get_eval_lm_parser()
        input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)

        with open(lm_score_file, 'w') as f:
        with open(lm_score_file, "w") as f:
            with redirect_stdout(f):
                eval_lm.main(input_args)

    elif bpe_status == "shared":
        preprocess_lm_param = ["--only-source",
                               "--trainpref", pre_gen+"/rescore_data."+target_lang,
                               "--srcdict", cur_lm_dict,
                               "--destdir", preprocess_directory]
        preprocess_parser = options.get_preprocessing_parser()
        input_args = preprocess_parser.parse_args(preprocess_lm_param)
        preprocess.main(input_args)
        preprocess_lm_param = [
            "--only-source",
            "--trainpref",
            pre_gen + "/rescore_data." + target_lang,
            "--srcdict",
            cur_lm_dict,
            "--destdir",
            preprocess_directory,
        ]
        preprocess_parser = options.get_preprocessing_parser()
        input_args = preprocess_parser.parse_args(preprocess_lm_param)
        preprocess.main(input_args)

        eval_lm_param = [preprocess_directory,
                         "--path", cur_language_model,
                         "--output-word-probs",
                         "--batch-size", str(batch_size),
                         "--sample-break-mode", "eos",
                         "--gen-subset", "train"]
        eval_lm_param = [
            preprocess_directory,
            "--path",
            cur_language_model,
            "--output-word-probs",
            "--batch-size",
            str(batch_size),
            "--sample-break-mode",
            "eos",
            "--gen-subset",
            "train",
        ]

        eval_lm_parser = options.get_eval_lm_parser()
        input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
        eval_lm_parser = options.get_eval_lm_parser()
        input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)

        with open(lm_score_file, 'w') as f:
            with redirect_stdout(f):
                eval_lm.main(input_args)
        with open(lm_score_file, "w") as f:
            with redirect_stdout(f):
                eval_lm.main(input_args)

    elif bpe_status == "different":
        rescore_file = pre_gen+"/rescore_data_no_bpe"
        rescore_bpe = pre_gen+"/rescore_data_new_bpe"
        rescore_file = pre_gen + "/rescore_data_no_bpe"
        rescore_bpe = pre_gen + "/rescore_data_new_bpe"

        rescore_file += "."
        rescore_bpe += "."

        write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
                          gen_output.no_bpe_target, rescore_file+source_lang,
                          rescore_file+target_lang, pre_gen+"/reference_file_no_bpe",
                          bpe_symbol=None)
        write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            rescore_file + source_lang,
            rescore_file + target_lang,
            pre_gen + "/reference_file_no_bpe",
            bpe_symbol=None,
        )

        # apply LM bpe to nbest list
        bpe_src_param = ["-c", cur_lm_bpe_code,
                         "--input", rescore_file+target_lang,
                         "--output", rescore_bpe+target_lang]
        subprocess.call(["python",
                         os.path.join(os.path.dirname(__file__),
                                      "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
                        shell=False)
        bpe_src_param = [
            "-c",
            cur_lm_bpe_code,
            "--input",
            rescore_file + target_lang,
            "--output",
            rescore_bpe + target_lang,
        ]
        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_src_param,
            shell=False,
        )
        # uncomment to use fastbpe instead of subword-nmt bpe
        # bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
        # subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)

        preprocess_dir = preprocess_directory

        preprocess_lm_param = ["--only-source",
                               "--trainpref", rescore_bpe+target_lang,
                               "--srcdict", cur_lm_dict,
                               "--destdir", preprocess_dir]
        preprocess_lm_param = [
            "--only-source",
            "--trainpref",
            rescore_bpe + target_lang,
            "--srcdict",
            cur_lm_dict,
            "--destdir",
            preprocess_dir,
        ]
        preprocess_parser = options.get_preprocessing_parser()
        input_args = preprocess_parser.parse_args(preprocess_lm_param)
        preprocess.main(input_args)

        eval_lm_param = [preprocess_dir,
                         "--path", cur_language_model,
                         "--output-word-probs",
                         "--batch-size", str(batch_size),
                         "--max-tokens", "1024",
                         "--sample-break-mode", "eos",
                         "--gen-subset", "train"]
        eval_lm_param = [
            preprocess_dir,
            "--path",
            cur_language_model,
            "--output-word-probs",
            "--batch-size",
            str(batch_size),
            "--max-tokens",
            "1024",
            "--sample-break-mode",
            "eos",
            "--gen-subset",
            "train",
        ]

        eval_lm_parser = options.get_eval_lm_parser()
        input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)

        with open(lm_score_file, 'w') as f:
        with open(lm_score_file, "w") as f:
            with redirect_stdout(f):
                eval_lm.main(input_args)


def rescore_file_name(nbest_dir, prefix_len, scorer_name, lm_file=False,
                      target_prefix_frac=None, source_prefix_frac=None, backwards=None):
def rescore_file_name(
    nbest_dir,
    prefix_len,
    scorer_name,
    lm_file=False,
    target_prefix_frac=None,
    source_prefix_frac=None,
    backwards=None,
):
    if lm_file:
        score_file = nbest_dir+"/lm_score_translations_model_"+scorer_name+".txt"
        score_file = nbest_dir + "/lm_score_translations_model_" + scorer_name + ".txt"
    else:
        score_file = nbest_dir+"/"+scorer_name+"_score_translations.txt"
        score_file = nbest_dir + "/" + scorer_name + "_score_translations.txt"
    if backwards:
        if prefix_len is not None:
            score_file += "prefix_len"+str(prefix_len)
            score_file += "prefix_len" + str(prefix_len)
        elif target_prefix_frac is not None:
            score_file += "target_prefix_frac"+str(target_prefix_frac)
            score_file += "target_prefix_frac" + str(target_prefix_frac)
    else:
        if source_prefix_frac is not None:
            score_file += "source_prefix_frac"+str(source_prefix_frac)
            score_file += "source_prefix_frac" + str(source_prefix_frac)
    return score_file
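Note: lm_scoring shells out to subword-nmt's apply_bpe.py with the -c/--input/--output convention used throughout these scripts. The bare call shape, outside fairseq (all paths are hypothetical, and it assumes a subword-nmt checkout next to the calling script):

import os
import subprocess

bpe_src_param = [
    "-c", "/tmp/lm.bpe.code",                  # hypothetical BPE codes file
    "--input", "/tmp/rescore_data_no_bpe.en",  # hypothetical detokenized text
    "--output", "/tmp/rescore_data_new_bpe.en",
]
subprocess.call(
    ["python", os.path.join("subword-nmt", "subword_nmt", "apply_bpe.py")]
    + bpe_src_param,
    shell=False,
)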
@ -13,57 +13,66 @@ logging.getLogger().setLevel(logging.INFO)


def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--en2fr', required=True,
                        help='path to en2fr model')
    parser.add_argument('--fr2en', required=True,
                        help='path to fr2en mixture of experts model')
    parser.add_argument('--user-dir',
                        help='path to fairseq examples/translation_moe/src directory')
    parser.add_argument('--num-experts', type=int, default=10,
                        help='(keep at 10 unless using a different model)')
    parser.add_argument('files', nargs='*', default=['-'],
                        help='input files to paraphrase; "-" for stdin')
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--en2fr", required=True, help="path to en2fr model")
    parser.add_argument(
        "--fr2en", required=True, help="path to fr2en mixture of experts model"
    )
    parser.add_argument(
        "--user-dir", help="path to fairseq examples/translation_moe/src directory"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        default=10,
        help="(keep at 10 unless using a different model)",
    )
    parser.add_argument(
        "files",
        nargs="*",
        default=["-"],
        help='input files to paraphrase; "-" for stdin',
    )
    args = parser.parse_args()

    if args.user_dir is None:
        args.user_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),  # examples/
            'translation_moe',
            'src',
            "translation_moe",
            "src",
        )
        if os.path.exists(args.user_dir):
            logging.info('found user_dir:' + args.user_dir)
            logging.info("found user_dir:" + args.user_dir)
        else:
            raise RuntimeError(
                'cannot find fairseq examples/translation_moe/src '
                '(tried looking here: {})'.format(args.user_dir)
                "cannot find fairseq examples/translation_moe/src "
                "(tried looking here: {})".format(args.user_dir)
            )

    logging.info('loading en2fr model from:' + args.en2fr)
    logging.info("loading en2fr model from:" + args.en2fr)
    en2fr = TransformerModel.from_pretrained(
        model_name_or_path=args.en2fr,
        tokenizer='moses',
        bpe='sentencepiece',
        tokenizer="moses",
        bpe="sentencepiece",
    ).eval()

    logging.info('loading fr2en model from:' + args.fr2en)
    logging.info("loading fr2en model from:" + args.fr2en)
    fr2en = TransformerModel.from_pretrained(
        model_name_or_path=args.fr2en,
        tokenizer='moses',
        bpe='sentencepiece',
        tokenizer="moses",
        bpe="sentencepiece",
        user_dir=args.user_dir,
        task='translation_moe',
        task="translation_moe",
    ).eval()

    def gen_paraphrases(en):
        fr = en2fr.translate(en)
        return [
            fr2en.translate(fr, inference_step_args={'expert': i})
            fr2en.translate(fr, inference_step_args={"expert": i})
            for i in range(args.num_experts)
        ]

    logging.info('Type the input sentence and press return:')
    logging.info("Type the input sentence and press return:")
    for line in fileinput.input(args.files):
        line = line.strip()
        if len(line) == 0:
@ -72,5 +81,5 @@ def main():
            print(paraphrase)


if __name__ == '__main__':
if __name__ == "__main__":
    main()
@ -4,9 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
import re
import argparse
import re
import sys


class OOVIndexError(IndexError):
@ -25,8 +25,8 @@ class OOVIndexError(IndexError):

def replace_oovs(source_in, target_in, target_out):
    """Replaces <unk-N> tokens in the target text with the corresponding word in
the source text.
"""
    the source text.
    """

    oov_re = re.compile("^<unk-([0-9]+)>$")

@ -10,8 +10,8 @@ from itertools import zip_longest

def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
    """Replaces out-of-vocabulary words in source and target text with <unk-N>,
where N in is the position of the word in the source sequence.
"""
    where N in is the position of the word in the source sequence.
    """

    def format_unk(pos):
        return "<unk-{}>".format(pos)

@ -8,19 +8,17 @@ from typing import Any, Dict, Optional

import torch
import torch.nn as nn

from fairseq import utils, metrics
from fairseq import metrics, utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
    TransformerModel,
    TransformerDecoder,
    TransformerEncoder,
    base_architecture,
    DEFAULT_MAX_SOURCE_POSITIONS,
    DEFAULT_MAX_TARGET_POSITIONS,
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
)

from torch import Tensor


@ -8,40 +8,44 @@ import os

import numpy as np
import torch

from fairseq.data import (
    data_utils,
    Dictionary,
    encoders,
    IdDataset,
    ListDataset,
    NestedDictionaryDataset,
    NumSamplesDataset,
    NumelDataset,
    NumSamplesDataset,
    RawLabelDataset,
    RightPadDataset,
    SortDataset,
    data_utils,
    encoders,
)
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.tasks import LegacyFairseqTask, register_task


@register_task('commonsense_qa')
@register_task("commonsense_qa")
class CommonsenseQATask(LegacyFairseqTask):
    """Task to finetune RoBERTa for Commonsense QA."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('data', metavar='DIR',
                            help='path to data directory; we load <split>.jsonl')
        parser.add_argument('--init-token', type=int, default=None,
                            help='add token at the beginning of each batch item')
        parser.add_argument('--num-classes', type=int, default=5)
        parser.add_argument(
            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            default=None,
            help="add token at the beginning of each batch item",
        )
        parser.add_argument("--num-classes", type=int, default=5)

    def __init__(self, args, vocab):
        super().__init__(args)
        self.vocab = vocab
        self.mask = vocab.add_symbol('<mask>')
        self.mask = vocab.add_symbol("<mask>")

        self.bpe = encoders.build_bpe(args)

@ -53,20 +57,24 @@ class CommonsenseQATask(LegacyFairseqTask):
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol('<mask>')
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert args.criterion == 'sentence_ranking', 'Must set --criterion=sentence_ranking'
        assert (
            args.criterion == "sentence_ranking"
        ), "Must set --criterion=sentence_ranking"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
        print('| dictionary: {} types'.format(len(vocab)))
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
@ -77,16 +85,18 @@ class CommonsenseQATask(LegacyFairseqTask):
            if self.bpe is not None:
                s = self.bpe.encode(s)
            tokens = self.vocab.encode_line(
                s, append_eos=True, add_if_not_exist=False,
                s,
                append_eos=True,
                add_if_not_exist=False,
            ).long()
            if append_bos and self.args.init_token is not None:
                tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
            return tokens

        if data_path is None:
            data_path = os.path.join(self.args.data, split + '.jsonl')
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError('Cannot find data: {}'.format(data_path))
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        src_tokens = [[] for i in range(self.args.num_classes)]
        src_lengths = [[] for i in range(self.args.num_classes)]
@ -95,20 +105,23 @@ class CommonsenseQATask(LegacyFairseqTask):
        with open(data_path) as h:
            for line in h:
                example = json.loads(line.strip())
                if 'answerKey' in example:
                    label = ord(example['answerKey']) - ord('A')
                if "answerKey" in example:
                    label = ord(example["answerKey"]) - ord("A")
                    labels.append(label)
                question = example['question']['stem']
                assert len(example['question']['choices']) == self.args.num_classes
                question = example["question"]["stem"]
                assert len(example["question"]["choices"]) == self.args.num_classes
                # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
                question = 'Q: ' + question
                question = "Q: " + question
                question_toks = binarize(question, append_bos=True)
                for i, choice in enumerate(example['question']['choices']):
                    src = 'A: ' + choice['text']
                for i, choice in enumerate(example["question"]["choices"]):
                    src = "A: " + choice["text"]
                    src_bin = torch.cat([question_toks, binarize(src)])
                    src_tokens[i].append(src_bin)
                    src_lengths[i].append(len(src_bin))
        assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes))
        assert all(
            len(src_tokens[0]) == len(src_tokens[i])
            for i in range(self.args.num_classes)
        )
        assert len(src_tokens[0]) == len(src_lengths[0])
        assert len(labels) == 0 or len(labels) == len(src_tokens[0])

@ -118,24 +131,26 @@ class CommonsenseQATask(LegacyFairseqTask):
            src_lengths[i] = ListDataset(src_lengths[i])

        dataset = {
            'id': IdDataset(),
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens[0], reduce=True),
            "id": IdDataset(),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens[0], reduce=True),
        }

        for i in range(self.args.num_classes):
            dataset.update({
                'net_input{}'.format(i + 1): {
                    'src_tokens': RightPadDataset(
                        src_tokens[i],
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'src_lengths': src_lengths[i],
            dataset.update(
                {
                    "net_input{}".format(i + 1): {
                        "src_tokens": RightPadDataset(
                            src_tokens[i],
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        "src_lengths": src_lengths[i],
                    }
                }
            })
            )

        if len(labels) > 0:
            dataset.update({'target': RawLabelDataset(labels)})
            dataset.update({"target": RawLabelDataset(labels)})

        dataset = NestedDictionaryDataset(
            dataset,
@ -149,17 +164,18 @@ class CommonsenseQATask(LegacyFairseqTask):
            sort_order=[np.random.permutation(len(dataset))],
        )

        print('| Loaded {} with {} samples'.format(split, len(dataset)))
        print("| Loaded {} with {} samples".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_model(self, args):
        from fairseq import models

        model = models.build_model(args, self)

        model.register_classification_head(
            'sentence_classification_head',
            "sentence_classification_head",
            num_classes=1,
        )

@ -8,7 +8,6 @@
import argparse
import contextlib
import sys

from collections import Counter
from multiprocessing import Pool

@ -26,23 +25,23 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-json",
        help='path to encoder.json',
        help="path to encoder.json",
    )
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help='path to vocab.bpe',
        help="path to vocab.bpe",
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=['-'],
        default=["-"],
        help="input files to filter/encode",
    )
    parser.add_argument(
        "--outputs",
        nargs="+",
        default=['-'],
        default=["-"],
        help="path to save encoded outputs",
    )
    parser.add_argument(
@ -53,18 +52,21 @@ def main():
    parser.add_argument("--workers", type=int, default=20)
    args = parser.parse_args()

    assert len(args.inputs) == len(args.outputs), \
        "number of input and output paths should match"
    assert len(args.inputs) == len(
        args.outputs
    ), "number of input and output paths should match"

    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(input, "r", encoding="utf-8"))
            if input != "-" else sys.stdin
            if input != "-"
            else sys.stdin
            for input in args.inputs
        ]
        outputs = [
            stack.enter_context(open(output, "w", encoding="utf-8"))
            if output != "-" else sys.stdout
            if output != "-"
            else sys.stdout
            for output in args.outputs
        ]

@ -87,7 +89,6 @@ def main():


class MultiprocessingEncoder(object):

    def __init__(self, args):
        self.args = args

@ -25,7 +25,7 @@ def get_examples(data_dir, set_type):
    examples = []

    levels = ["middle", "high"]
    set_type_c = set_type.split('-')
    set_type_c = set_type.split("-")
    if len(set_type_c) == 2:
        levels = [set_type_c[1]]
        set_type = set_type_c[0]
@ -33,13 +33,13 @@ def get_examples(data_dir, set_type):
        cur_dir = os.path.join(data_dir, set_type, level)
        for filename in os.listdir(cur_dir):
            cur_path = os.path.join(cur_dir, filename)
            with open(cur_path, 'r') as f:
            with open(cur_path, "r") as f:
                cur_data = json.load(f)
                answers = cur_data["answers"]
                options = cur_data["options"]
                questions = cur_data["questions"]
                context = cur_data["article"].replace("\n", " ")
                context = re.sub(r'\s+', ' ', context)
                context = re.sub(r"\s+", " ", context)
                for i in range(len(answers)):
                    label = ord(answers[i]) - ord("A")
                    qa_list = []
@ -50,7 +50,7 @@ def get_examples(data_dir, set_type):
                            qa_cat = question.replace("_", option)
                        else:
                            qa_cat = " ".join([question, option])
                        qa_cat = re.sub(r'\s+', ' ', qa_cat)
                        qa_cat = re.sub(r"\s+", " ", qa_cat)
                        qa_list.append(qa_cat)
                    examples.append(InputExample(context, qa_list, label))

@ -64,11 +64,11 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        help='input directory for downloaded RACE dataset',
        help="input directory for downloaded RACE dataset",
    )
    parser.add_argument(
        "--output-dir",
        help='output directory for extracted data',
        help="output directory for extracted data",
    )
    args = parser.parse_args()

@ -77,17 +77,20 @@ def main():

    for set_type in ["train", "dev", "test-middle", "test-high"]:
        examples = get_examples(args.input_dir, set_type)
        qa_file_paths = [os.path.join(args.output_dir, set_type + ".input" + str(i + 1)) for i in range(4)]
        qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths]
        qa_file_paths = [
            os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
            for i in range(4)
        ]
        qa_files = [open(qa_file_path, "w") for qa_file_path in qa_file_paths]
        outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
        outf_label_path = os.path.join(args.output_dir, set_type + ".label")
        outf_context = open(outf_context_path, 'w')
        outf_label = open(outf_label_path, 'w')
        outf_context = open(outf_context_path, "w")
        outf_label = open(outf_label_path, "w")
        for example in examples:
            outf_context.write(example.paragraph + '\n')
            outf_context.write(example.paragraph + "\n")
            for i in range(4):
                qa_files[i].write(example.qa_list[i] + '\n')
            outf_label.write(str(example.label) + '\n')
                qa_files[i].write(example.qa_list[i] + "\n")
            outf_label.write(str(example.label) + "\n")

        for f in qa_files:
            f.close()
@ -95,5 +98,5 @@ def main():
        outf_context.close()


if __name__ == '__main__':
if __name__ == "__main__":
    main()

@ -7,19 +7,17 @@ import math

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.data import encoders
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data import encoders


@register_criterion('wsc')
@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):

    def __init__(self, args, task):
        super().__init__(args, task)
        if self.args.save_predictions is not None:
            self.prediction_h = open(self.args.save_predictions, 'w')
            self.prediction_h = open(self.args.save_predictions, "w")
        else:
            self.prediction_h = None
        self.bpe = encoders.build_bpe(args)
@ -32,12 +30,16 @@ class WSCCriterion(LegacyFairseqCriterion):
    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0)
        parser.add_argument('--wsc-margin-beta', type=float, metavar='B', default=0.0)
        parser.add_argument('--wsc-cross-entropy', action='store_true',
                            help='use cross entropy formulation instead of margin loss')
        parser.add_argument('--save-predictions', metavar='FILE',
                            help='file to save predictions to')
        parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
        parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
        parser.add_argument(
            "--wsc-cross-entropy",
            action="store_true",
            help="use cross entropy formulation instead of margin loss",
        )
        parser.add_argument(
            "--save-predictions", metavar="FILE", help="file to save predictions to"
        )

    def get_masked_input(self, tokens, mask):
        masked_tokens = tokens.clone()
@ -60,27 +62,26 @@ class WSCCriterion(LegacyFairseqCriterion):
            )
        else:
            return (
                - query_lprobs
                + self.args.wsc_margin_alpha * (
                    cand_lprobs - query_lprobs + self.args.wsc_margin_beta
                ).clamp(min=0)
                -query_lprobs
                + self.args.wsc_margin_alpha
                * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
            ).sum()

    def forward(self, model, sample, reduce=True):
        # compute loss and accuracy
        loss, nloss = 0., 0
        loss, nloss = 0.0, 0
        ncorrect, nqueries = 0, 0

        for i, label in enumerate(sample['labels']):
        for i, label in enumerate(sample["labels"]):
            query_lprobs = self.get_lprobs(
                model,
                sample['query_tokens'][i].unsqueeze(0),
                sample['query_masks'][i].unsqueeze(0),
                sample["query_tokens"][i].unsqueeze(0),
                sample["query_masks"][i].unsqueeze(0),
            )
            cand_lprobs = self.get_lprobs(
                model,
                sample['candidate_tokens'][i],
                sample['candidate_masks'][i],
                sample["candidate_tokens"][i],
                sample["candidate_masks"][i],
            )

            pred = (query_lprobs >= cand_lprobs).all().item()
@ -95,72 +96,72 @@ class WSCCriterion(LegacyFairseqCriterion):
                nloss += 1
                loss += self.get_loss(query_lprobs, cand_lprobs)

            id = sample['id'][i].item()
            id = sample["id"][i].item()
            if self.prediction_h is not None:
                print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h)
                print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)

        if nloss == 0:
            loss = torch.tensor(0.0, requires_grad=True)

        sample_size = nqueries if nqueries > 0 else 1
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
            'ncorrect': ncorrect,
            'nqueries': nqueries,
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": nqueries,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        agg_output = {
            'loss': loss_sum / sample_size / math.log(2),
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
            "loss": loss_sum / sample_size / math.log(2),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
        nqueries = sum(log.get('nqueries', 0) for log in logging_outputs)
        ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
        nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
        if nqueries > 0:
            agg_output['accuracy'] = ncorrect / float(nqueries)
            agg_output["accuracy"] = ncorrect / float(nqueries)

        return agg_output


@register_criterion('winogrande')
@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
    def forward(self, model, sample, reduce=True):
        # compute loss and accuracy
        query_lprobs = self.get_lprobs(
            model,
            sample['query_tokens'],
            sample['query_masks'],
            sample["query_tokens"],
            sample["query_masks"],
        )
        cand_lprobs = self.get_lprobs(
            model,
            sample['candidate_tokens'],
            sample['candidate_masks'],
            sample["candidate_tokens"],
            sample["candidate_masks"],
        )
        pred = query_lprobs >= cand_lprobs
        loss = self.get_loss(query_lprobs, cand_lprobs)

        sample_size = sample['query_tokens'].size(0)
        sample_size = sample["query_tokens"].size(0)
        ncorrect = pred.sum().item()
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
            'ncorrect': ncorrect,
            'nqueries': sample_size,
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": sample_size,
        }
        return loss, sample_size, logging_output

@ -10,47 +10,51 @@ import tempfile
import numpy as np
import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.data import (
    data_utils,
    Dictionary,
    encoders,
    IdDataset,
    ListDataset,
    NestedDictionaryDataset,
    NumSamplesDataset,
    NumelDataset,
    NumSamplesDataset,
    PadDataset,
    SortDataset,
    data_utils,
    encoders,
)
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.tasks import LegacyFairseqTask, register_task

from . import wsc_utils


@register_task('wsc')
@register_task("wsc")
class WSCTask(LegacyFairseqTask):
    """Task to finetune RoBERTa for Winograd Schemas."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('data', metavar='DIR',
                            help='path to data directory; we load <split>.jsonl')
        parser.add_argument('--init-token', type=int, default=None,
                            help='add token at the beginning of each batch item')
        parser.add_argument(
            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            default=None,
            help="add token at the beginning of each batch item",
        )

    def __init__(self, args, vocab):
        super().__init__(args)
        self.vocab = vocab
        self.mask = vocab.add_symbol('<mask>')
        self.mask = vocab.add_symbol("<mask>")

        self.bpe = encoders.build_bpe(args)
        self.tokenizer = encoders.build_tokenizer(args)

        # hack to handle GPT-2 BPE, which includes leading spaces
        if args.bpe == 'gpt2':
        if args.bpe == "gpt2":
            self.leading_space = True
            self.trailing_space = False
        else:
@ -65,16 +69,16 @@ class WSCTask(LegacyFairseqTask):
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol('<mask>')
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert args.criterion == 'wsc', 'Must set --criterion=wsc'
        assert args.criterion == "wsc", "Must set --criterion=wsc"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
        print('| dictionary: {} types'.format(len(vocab)))
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

@ -84,7 +88,9 @@ class WSCTask(LegacyFairseqTask):
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s, append_eos=append_eos, add_if_not_exist=False,
            s,
            append_eos=append_eos,
            add_if_not_exist=False,
        ).long()
        if self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
@ -98,19 +104,21 @@ class WSCTask(LegacyFairseqTask):
        mask = torch.zeros_like(toks, dtype=torch.bool)
        mask_start = len(self.binarize(prefix))
        mask_size = len(self.binarize(leading_space + txt))
        mask[mask_start:mask_start + mask_size] = 1
        mask[mask_start : mask_start + mask_size] = 1
        return toks, mask

    def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + '.jsonl')
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError('Cannot find data: {}'.format(data_path))
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        query_tokens = []
        query_masks = []
@ -121,13 +129,15 @@ class WSCTask(LegacyFairseqTask):
        labels = []

        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
            prefix = sentence[:pronoun_span.start].text
            suffix = sentence[pronoun_span.end:].text_with_ws
            prefix = sentence[: pronoun_span.start].text
            suffix = sentence[pronoun_span.end :].text_with_ws

            # spaCy spans include trailing spaces, but we need to know about
            # leading spaces for the GPT-2 BPE
            leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
            trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''
            leading_space = (
                " " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else ""
            )
            trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""

            # get noun phrases, excluding pronouns and anything overlapping with the query
            cand_spans = wsc_utils.filter_noun_chunks(
@ -152,7 +162,11 @@ class WSCTask(LegacyFairseqTask):
            cand_toks, cand_masks = [], []
            for cand_span in cand_spans:
                toks, mask = self.binarize_with_mask(
                    cand_span.text, prefix, suffix, leading_space, trailing_space,
                    cand_span.text,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                cand_toks.append(toks)
                cand_masks.append(mask)
@ -176,17 +190,17 @@ class WSCTask(LegacyFairseqTask):
        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
        candidate_masks = ListDataset(candidate_masks, candidate_lengths)

        labels = ListDataset(labels, [1]*len(labels))
        labels = ListDataset(labels, [1] * len(labels))

        dataset = {
            'id': IdDataset(),
            'query_tokens': query_tokens,
            'query_masks': query_masks,
            'candidate_tokens': candidate_tokens,
            'candidate_masks': candidate_masks,
            'labels': labels,
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(query_tokens, reduce=True),
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "labels": labels,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
@ -210,9 +224,9 @@ class WSCTask(LegacyFairseqTask):

    def build_dataset_for_inference(self, sample_json):
        with tempfile.NamedTemporaryFile(buffering=0) as h:
            h.write((json.dumps(sample_json) + '\n').encode('utf-8'))
            h.write((json.dumps(sample_json) + "\n").encode("utf-8"))
            dataset = self.load_dataset(
                'disambiguate_pronoun',
                "disambiguate_pronoun",
                data_path=h.name,
                return_only=True,
            )
@ -239,19 +253,19 @@ class WSCTask(LegacyFairseqTask):
            return scores

        cand_lprobs = get_lprobs(
            sample['candidate_tokens'][0],
            sample['candidate_masks'][0],
            sample["candidate_tokens"][0],
            sample["candidate_masks"][0],
        )
        if sample['query_tokens'][0] is not None:
        if sample["query_tokens"][0] is not None:
            query_lprobs = get_lprobs(
                sample['query_tokens'][0].unsqueeze(0),
                sample['query_masks'][0].unsqueeze(0),
                sample["query_tokens"][0].unsqueeze(0),
                sample["query_masks"][0].unsqueeze(0),
            )
            return (query_lprobs >= cand_lprobs).all().item() == 1
        else:
            best_idx = cand_lprobs.argmax().item()
            full_cand = sample['candidate_tokens'][0][best_idx]
            mask = sample['candidate_masks'][0][best_idx]
            full_cand = sample["candidate_tokens"][0][best_idx]
            mask = sample["candidate_masks"][0][best_idx]
            toks = full_cand[mask.bool()]
            return self.bpe.decode(self.source_dictionary.string(toks)).strip()

@ -264,7 +278,7 @@ class WSCTask(LegacyFairseqTask):
        return self.vocab


@register_task('winogrande')
@register_task("winogrande")
class WinograndeTask(WSCTask):
    """
    Task for WinoGrande dataset. Efficient implementation for Winograd schema
@ -273,24 +287,26 @@ class WinograndeTask(WSCTask):

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert args.criterion == 'winogrande', 'Must set --criterion=winogrande'
        assert args.criterion == "winogrande", "Must set --criterion=winogrande"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
        print('| dictionary: {} types'.format(len(vocab)))
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + '.jsonl')
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError('Cannot find data: {}'.format(data_path))
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        query_tokens = []
        query_masks = []
@ -299,19 +315,23 @@ class WinograndeTask(WSCTask):
        candidate_masks = []
        candidate_lengths = []

        itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == 'test'))
        itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test"))

        for sample in itr:
            sentence, pronoun_span, query, cand_text = sample
            prefix = sentence[:pronoun_span[0]].rstrip()
            suffix = sentence[pronoun_span[1]:]
            prefix = sentence[: pronoun_span[0]].rstrip()
            suffix = sentence[pronoun_span[1] :]

            leading_space = ' ' if sentence[:pronoun_span[0]].endswith(' ') else ''
            trailing_space = ''
            leading_space = " " if sentence[: pronoun_span[0]].endswith(" ") else ""
            trailing_space = ""

            if query is not None:
                query_toks, query_mask = self.binarize_with_mask(
                    query, prefix, suffix, leading_space, trailing_space,
                    query,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                query_len = len(query_toks)
            else:
@ -322,7 +342,11 @@ class WinograndeTask(WSCTask):
            query_lengths.append(query_len)

            cand_toks, cand_mask = self.binarize_with_mask(
                cand_text, prefix, suffix, leading_space, trailing_space,
                cand_text,
                prefix,
                suffix,
                leading_space,
                trailing_space,
            )

            candidate_tokens.append(cand_toks)
@ -342,17 +366,19 @@ class WinograndeTask(WSCTask):
        query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = get_pad_dataset_fn(candidate_tokens, candidate_lengths, self.vocab.pad())
        candidate_tokens = get_pad_dataset_fn(
            candidate_tokens, candidate_lengths, self.vocab.pad()
        )
        candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)

        dataset = {
            'id': IdDataset(),
            'query_tokens': query_tokens,
            'query_masks': query_masks,
            'candidate_tokens': candidate_tokens,
            'candidate_masks': candidate_masks,
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(query_tokens, reduce=True),
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(

@ -3,48 +3,48 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache
import json
from functools import lru_cache


def convert_sentence_to_json(sentence):
    if '_' in sentence:
        prefix, rest = sentence.split('_', 1)
        query, rest = rest.split('_', 1)
        query_index = len(prefix.rstrip().split(' '))
    if "_" in sentence:
        prefix, rest = sentence.split("_", 1)
        query, rest = rest.split("_", 1)
        query_index = len(prefix.rstrip().split(" "))
    else:
        query, query_index = None, None

    prefix, rest = sentence.split('[', 1)
    pronoun, rest = rest.split(']', 1)
    pronoun_index = len(prefix.rstrip().split(' '))
    prefix, rest = sentence.split("[", 1)
    pronoun, rest = rest.split("]", 1)
    pronoun_index = len(prefix.rstrip().split(" "))

    sentence = sentence.replace('_', '').replace('[', '').replace(']', '')
    sentence = sentence.replace("_", "").replace("[", "").replace("]", "")

    return {
        'idx': 0,
        'text': sentence,
        'target': {
            'span1_index': query_index,
            'span1_text': query,
            'span2_index': pronoun_index,
            'span2_text': pronoun,
        "idx": 0,
        "text": sentence,
        "target": {
            "span1_index": query_index,
            "span1_text": query,
            "span2_index": pronoun_index,
            "span2_text": pronoun,
        },
    }


def extended_noun_chunks(sentence):
    noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
    np_start, cur_np = 0, 'NONE'
    np_start, cur_np = 0, "NONE"
    for i, token in enumerate(sentence):
        np_type = token.pos_ if token.pos_ in {'NOUN', 'PROPN'} else 'NONE'
        np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
        if np_type != cur_np:
            if cur_np != 'NONE':
            if cur_np != "NONE":
                noun_chunks.add((np_start, i))
            if np_type != 'NONE':
            if np_type != "NONE":
                np_start = i
            cur_np = np_type
    if cur_np != 'NONE':
    if cur_np != "NONE":
        noun_chunks.add((np_start, len(sentence)))
    return [sentence[s:e] for (s, e) in sorted(noun_chunks)]

@ -61,14 +61,14 @@ def find_token(sentence, start_pos):
def find_span(sentence, search_text, start=0):
    search_text = search_text.lower()
    for tok in sentence[start:]:
        remainder = sentence[tok.i:].text.lower()
        remainder = sentence[tok.i :].text.lower()
        if remainder.startswith(search_text):
            len_to_consume = len(search_text)
            start_idx = tok.idx
            for next_tok in sentence[tok.i:]:
            for next_tok in sentence[tok.i :]:
                end_idx = next_tok.idx + len(next_tok.text)
                if end_idx - start_idx == len_to_consume:
                    span = sentence[tok.i:next_tok.i + 1]
                    span = sentence[tok.i : next_tok.i + 1]
                    return span
    return None

@ -76,13 +76,15 @@ def find_span(sentence, search_text, start=0):
@lru_cache(maxsize=1)
def get_detokenizer():
    from sacremoses import MosesDetokenizer
    detok = MosesDetokenizer(lang='en')

    detok = MosesDetokenizer(lang="en")
    return detok


@lru_cache(maxsize=1)
def get_spacy_nlp():
    import en_core_web_lg

    nlp = en_core_web_lg.load()
    return nlp

@ -95,45 +97,45 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
        for line in fin:
            sample = json.loads(line.strip())

            if positive_only and 'label' in sample and not sample['label']:
            if positive_only and "label" in sample and not sample["label"]:
                # only consider examples where the query is correct
                continue

            target = sample['target']
            target = sample["target"]

            # clean up the query
            query = target['span1_text']
            query = target["span1_text"]
            if query is not None:
                if '\n' in query:
                if "\n" in query:
                    continue
                if query.endswith('.') or query.endswith(','):
                if query.endswith(".") or query.endswith(","):
                    query = query[:-1]

            # split tokens
            tokens = sample['text'].split(' ')
            tokens = sample["text"].split(" ")

            def strip_pronoun(x):
                return x.rstrip('.,"')

            # find the pronoun
            pronoun_idx = target['span2_index']
            pronoun = strip_pronoun(target['span2_text'])
            pronoun_idx = target["span2_index"]
            pronoun = strip_pronoun(target["span2_text"])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned
                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
                    pronoun_idx += 1
                else:
                    raise Exception('Misaligned pronoun!')
                    raise Exception("Misaligned pronoun!")
            assert strip_pronoun(tokens[pronoun_idx]) == pronoun

            # split tokens before and after the pronoun
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1:]
            after = tokens[pronoun_idx + 1 :]

            # the GPT BPE attaches leading spaces to tokens, so we keep track
            # of whether we need spaces before or after the pronoun
            leading_space = ' ' if pronoun_idx > 0 else ''
            trailing_space = ' ' if len(after) > 0 else ''
            leading_space = " " if pronoun_idx > 0 else ""
            trailing_space = " " if len(after) > 0 else ""

            # detokenize
            before = detok.detokenize(before, return_str=True)
@ -142,14 +144,14 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):

            # hack: when the pronoun ends in a period (or comma), move the
            # punctuation to the "after" part
            if pronoun.endswith('.') or pronoun.endswith(','):
            if pronoun.endswith(".") or pronoun.endswith(","):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a comma or period, remove
            # the trailing space
            if after.startswith('.') or after.startswith(','):
                trailing_space = ''
            if after.startswith(".") or after.startswith(","):
                trailing_space = ""

            # parse sentence with spacy
            sentence = nlp(before + leading_space + pronoun + trailing_space + after)
@ -164,13 +166,13 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
                # convert to format where pronoun is surrounded by "[]" and
                # query is surrounded by "_"
                query_span = find_span(sentence, query)
                query_with_ws = '_{}_{}'.format(
                query_with_ws = "_{}_{}".format(
                    query_span.text,
                    (' ' if query_span.text_with_ws.endswith(' ') else '')
                    (" " if query_span.text_with_ws.endswith(" ") else ""),
                )
                pronoun_with_ws = '[{}]{}'.format(
                pronoun_with_ws = "[{}]{}".format(
                    pronoun_span.text,
                    (' ' if pronoun_span.text_with_ws.endswith(' ') else '')
                    (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
                )
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
@ -179,41 +181,45 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[:first[0].start].text_with_ws
                    sentence[: first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end:second[0].start].text_with_ws
                    + sentence[first[0].end : second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end:].text
                    + sentence[second[0].end :].text
                )
                yield sentence, sample.get('label', None)
                yield sentence, sample.get("label", None)
            else:
                yield sentence, pronoun_span, query, sample.get('label', None)
                yield sentence, pronoun_span, query, sample.get("label", None)


def winogrande_jsonl_iterator(input_fname, eval=False):
    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            sentence, option1, option2 = sample['sentence'], sample['option1'],\
                sample['option2']
            sentence, option1, option2 = (
                sample["sentence"],
                sample["option1"],
                sample["option2"],
            )

            pronoun_span = (sentence.index('_'), sentence.index('_') + 1)
            pronoun_span = (sentence.index("_"), sentence.index("_") + 1)

            if eval:
                query, cand = option1, option2
            else:
                query = option1 if sample['answer'] == '1' else option2
                cand = option2 if sample['answer'] == '1' else option1
                query = option1 if sample["answer"] == "1" else option2
                cand = option2 if sample["answer"] == "1" else option1
            yield sentence, pronoun_span, query, cand


def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact_match=False):
def filter_noun_chunks(
    chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
    if exclude_pronouns:
        chunks = [
            np for np in chunks if (
                np.lemma_ != '-PRON-'
                and not all(tok.pos_ == 'PRON' for tok in np)
            )
            np
            for np in chunks
            if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
        ]

    if exclude_query is not None:
@ -224,9 +230,8 @@ def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact
            found = False
            for excl in excl_txt:
                if (
                    (not exact_match and (lower_chunk in excl or excl in lower_chunk))
                    or lower_chunk == excl
                ):
                    not exact_match and (lower_chunk in excl or excl in lower_chunk)
                ) or lower_chunk == excl:
                    found = True
                    break
            if not found:

@ -3,4 +3,4 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import criterions, models, eval  # noqa
from . import criterions, eval, models  # noqa

@ -6,6 +6,7 @@
import importlib
import os


for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        criterion_name = file[: file.find(".py")]

@ -3,21 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from examples.simultaneous_translation.utils.latency import LatencyTraining
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion
)

from examples.simultaneous_translation.utils.latency import (
    LatencyTraining
    LabelSmoothedCrossEntropyCriterion,
)


@register_criterion('latency_augmented_label_smoothed_cross_entropy')
@register_criterion("latency_augmented_label_smoothed_cross_entropy")
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
    LabelSmoothedCrossEntropyCriterion
):

    def __init__(self, args, task):
        super().__init__(args, task)
        self.eps = args.label_smoothing
@ -40,7 +36,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
    def add_args(parser):
        super(
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
        ).add_args(parser)
        """Add criterion-specific arguments to the parser."""
        # fmt: off
@ -69,7 +65,8 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(

        # Get latency loss
        latency_loss = self.latency_train.loss(
            attn_list, source_padding_mask, target_padding_mask)
            attn_list, source_padding_mask, target_padding_mask
        )

        loss += latency_loss

@ -5,16 +5,20 @@

import importlib
import os

from fairseq import registry

build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry('--agent-type')

build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry(
    "--agent-type"
)


DEFAULT_EOS = '</s>'
DEFAULT_EOS = "</s>"
GET = 0
SEND = 1

for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        module = file[:file.find('.py')]
        importlib.import_module('agents.' + module)
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("agents." + module)

@ -3,14 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import GET, SEND, DEFAULT_EOS
import time
from multiprocessing.pool import ThreadPool as Pool
from functools import partial
from multiprocessing.pool import ThreadPool as Pool

from . import DEFAULT_EOS, GET, SEND


class Agent(object):
    "an agent needs to follow this pattern"

    def __init__(self, *args, **kwargs):
        pass

@ -40,26 +42,26 @@ class Agent(object):
            with Pool(10) as p:
                p.map(
                    partial(self._decode_one, session),
                    [sent_id for sent_id in range(low, high + 1)]
                    [sent_id for sent_id in range(low, high + 1)],
                )
        else:
            for sent_id in range(low, high + 1):
                self._decode_one(session, sent_id)

        print(f'Finished {low} to {high} in {time.time() - t0}s')
        print(f"Finished {low} to {high} in {time.time() - t0}s")

    def _decode_one(self, session, sent_id):
        action = {}
        self.reset()
        states = self.init_states()
        while action.get('value', None) != DEFAULT_EOS:
        while action.get("value", None) != DEFAULT_EOS:
            # take an action
            action = self.policy(states)

            if action['key'] == GET:
            if action["key"] == GET:
                new_states = session.get_src(sent_id, action["value"])
                states = self.update_states(states, new_states)

            elif action['key'] == SEND:
                session.send_hypo(sent_id, action['value'])
            elif action["key"] == SEND:
                session.send_hypo(sent_id, action["value"])
        print(" ".join(states["tokens"]["tgt"]))

@ -3,11 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . agent import Agent
from . import DEFAULT_EOS, GET, SEND
from fairseq import checkpoint_utils, utils, tasks
import os
import json
import os

from fairseq import checkpoint_utils, tasks, utils

from . import DEFAULT_EOS, GET, SEND
from .agent import Agent


class SimulTransAgent(Agent):
@ -51,13 +53,15 @@ class SimulTransAgent(Agent):
        raise NotImplementedError

    def load_model(self, args):
        args.user_dir = os.path.join(os.path.dirname(__file__), '..', '..')
        args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
        utils.import_user_module(args)
        filename = args.model_path
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))

        state = checkpoint_utils.load_checkpoint_to_cpu(filename, json.loads(args.model_overrides))
        state = checkpoint_utils.load_checkpoint_to_cpu(
            filename, json.loads(args.model_overrides)
        )

        saved_args = state["args"]
        saved_args.data = args.data_bin
@ -79,7 +83,7 @@ class SimulTransAgent(Agent):
            "steps": {"src": 0, "tgt": 0},
            "finished": False,
            "finish_read": False,
            "model_states": {}
            "model_states": {},
        }

    def update_states(self, states, new_state):
@ -115,38 +119,38 @@ class SimulTransAgent(Agent):
    def write_action(self, states):
        token, index = self.model.predict_from_states(states)

        if index == self.dict["tgt"].eos() or len(states["tokens"]["tgt"]) > self.max_len:
        if (
            index == self.dict["tgt"].eos()
            or len(states["tokens"]["tgt"]) > self.max_len
        ):
            # Finish this sentence is predict EOS
            states["finished"] = True
            end_idx_last_full_word = self._target_length(states)

        else:
            states["tokens"]["tgt"] += [token]
            end_idx_last_full_word = (
                self.word_splitter["tgt"]
                .end_idx_last_full_word(states["tokens"]["tgt"])
            end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
                states["tokens"]["tgt"]
            )
            self._append_indices(states, [index], "tgt")

        if end_idx_last_full_word > states["steps"]["tgt"]:
            # Only sent detokenized full words to the server
            word = self.word_splitter["tgt"].merge(
                states["tokens"]["tgt"][
                    states["steps"]["tgt"]: end_idx_last_full_word
                ]
                states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
            )
            states["steps"]["tgt"] = end_idx_last_full_word
            states["segments"]["tgt"] += [word]

            return {'key': SEND, 'value': word}
            return {"key": SEND, "value": word}
        else:
            return None

    def read_action(self, states):
        return {'key': GET, 'value': None}
        return {"key": GET, "value": None}

    def finish_action(self):
        return {'key': SEND, 'value': DEFAULT_EOS}
        return {"key": SEND, "value": DEFAULT_EOS}

    def reset(self):
        pass
@ -160,4 +164,4 @@ class SimulTransAgent(Agent):
        states["indices"][key] += new_indices

    def _target_length(self, states):
        return len(states["tokens"]['tgt'])
        return len(states["tokens"]["tgt"])

@ -3,10 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . simul_trans_agent import SimulTransAgent
from . import DEFAULT_EOS, GET
from . import register_agent
from . word_splitter import SPLITTER_DICT
from . import DEFAULT_EOS, GET, register_agent
from .simul_trans_agent import SimulTransAgent
from .word_splitter import SPLITTER_DICT


@register_agent("simul_trans_text")
@ -15,11 +14,11 @@ class SimulTransTextAgent(SimulTransAgent):
        self.word_splitter = {}

        self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type](
                getattr(args, f"src_splitter_path")
            )
            getattr(args, f"src_splitter_path")
        )
        self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type](
                getattr(args, f"tgt_splitter_path")
            )
            getattr(args, f"tgt_splitter_path")
        )

    def load_dictionary(self, task):
        self.dict = {}
@ -37,12 +36,16 @@ class SimulTransTextAgent(SimulTransAgent):
            tokens = self.word_splitter["src"].split(new_word)
            # Get indices from dictionary
            # You can change to you own dictionary
            indices = self.dict["src"].encode_line(
                tokens,
                line_tokenizer=lambda x: x,
                add_if_not_exist=False,
                append_eos=False
            ).tolist()
            indices = (
                self.dict["src"]
                .encode_line(
                    tokens,
                    line_tokenizer=lambda x: x,
                    add_if_not_exist=False,
                    append_eos=False,
                )
                .tolist()
            )
        else:
            tokens = [new_word]
            indices = [self.dict["src"].eos()]
@ -61,11 +64,11 @@ class SimulTransTextAgent(SimulTransAgent):

        # At leat one word is read
        if len(states["tokens"]["src"]) == 0:
            return {'key': GET, 'value': None}
            return {"key": GET, "value": None}

        # Only request new word if there is no buffered tokens
        if len(states["tokens"]["src"]) <= states["steps"]["src"]:
            return {'key': GET, 'value': None}
            return {"key": GET, "value": None}

        return None

@ -40,6 +40,7 @@ class BPEWordSplitter(object):
    def __init__(self, model_path):
        super().__init__()
        from subword_nmt.apply_bpe import BPE

        with open(model_path) as f:
            self.model = BPE(f)

@ -48,7 +49,7 @@ class BPEWordSplitter(object):

    def end_idx_last_full_word(self, tokens):
        # Begin of word indices
        bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != '@@']
        bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != "@@"]

        if len(bow_indices) < 2:
            return 0
@ -63,6 +64,7 @@ class SentencePieceModelWordSplitter(object):
    def __init__(self, model_path):
        super().__init__()
        import sentencepiece as spm

        self.model = spm.SentencePieceProcessor()
        self.model.Load(model_path)

@ -71,7 +73,7 @@ class SentencePieceModelWordSplitter(object):

    def end_idx_last_full_word(self, tokens):
        # Begin of word indices
        bow_indices = [i for i, t in enumerate(tokens) if t[0] == '\u2581']
        bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"]

        if len(bow_indices) < 2:
            return 0

@@ -3,19 +3,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import requests
 from typing import Optional

+import requests
 from scorers import build_scorer


 class SimulSTEvaluationService(object):
-    DEFAULT_HOSTNAME = 'localhost'
+    DEFAULT_HOSTNAME = "localhost"
     DEFAULT_PORT = 12321

     def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
         self.hostname = hostname
         self.port = port
-        self.base_url = f'http://{self.hostname}:{self.port}'
+        self.base_url = f"http://{self.hostname}:{self.port}"

     def __enter__(self):
         self.new_session()
@@ -25,56 +26,53 @@ class SimulSTEvaluationService(object):

     def new_session(self):
         # start eval session
-        url = f'{self.base_url}'
+        url = f"{self.base_url}"

         try:
             _ = requests.post(url)
         except Exception as e:
-            print(f'Failed to start an evaluation session: {e}')
+            print(f"Failed to start an evaluation session: {e}")

-        print('Evaluation session started.')
+        print("Evaluation session started.")
         return self

     def get_scores(self):
         # end eval session
-        url = f'{self.base_url}/result'
+        url = f"{self.base_url}/result"
         try:
             r = requests.get(url)
-            print('Scores: {}'.format(r.json()))
-            print('Evaluation session finished.')
+            print("Scores: {}".format(r.json()))
+            print("Evaluation session finished.")
         except Exception as e:
-            print(f'Failed to end an evaluation session: {e}')
+            print(f"Failed to end an evaluation session: {e}")

     def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
-        url = f'{self.base_url}/src'
+        url = f"{self.base_url}/src"
         params = {"sent_id": sent_id}
         if extra_params is not None:
             for key in extra_params.keys():
                 params[key] = extra_params[key]
         try:
-            r = requests.get(
-                url,
-                params=params
-            )
+            r = requests.get(url, params=params)
         except Exception as e:
-            print(f'Failed to request a source segment: {e}')
+            print(f"Failed to request a source segment: {e}")
         return r.json()

     def send_hypo(self, sent_id: int, hypo: str) -> None:
-        url = f'{self.base_url}/hypo'
+        url = f"{self.base_url}/hypo"
         params = {"sent_id": sent_id}

         try:
             requests.put(url, params=params, data=hypo.encode("utf-8"))
         except Exception as e:
-            print(f'Failed to send a translated segment: {e}')
+            print(f"Failed to send a translated segment: {e}")

     def corpus_info(self):
-        url = f'{self.base_url}'
+        url = f"{self.base_url}"
         try:
             r = requests.get(url)
         except Exception as e:
-            print(f'Failed to request corpus information: {e}')
+            print(f"Failed to request corpus information: {e}")

         return r.json()

@@ -3,20 +3,21 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from examples.simultaneous_translation.utils.latency import LatencyInference
 import argparse
-import torch
 import json

+import torch
+from examples.simultaneous_translation.utils.latency import LatencyInference


 LATENCY_METRICS = [
-    'differentiable_average_lagging',
-    'average_lagging',
-    'average_proportion',
+    "differentiable_average_lagging",
+    "average_lagging",
+    "average_proportion",
 ]


-class LatencyScorer():
+class LatencyScorer:
     def __init__(self, start_from_zero=True):
         self.recorder = []
         self.scores = {}
@@ -26,10 +27,7 @@ class LatencyScorer():
     def update_reorder(self, list_of_dict):
         self.recorder = []
         for info in list_of_dict:
-            delays = [
-                int(x) - int(not self.start_from_zero)
-                for x in info["delays"]
-            ]
+            delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]]
             delays = torch.LongTensor(delays).unsqueeze(0)
             src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)

@@ -59,7 +57,7 @@ if __name__ == "__main__":

     scorer = LatencyInference()
     recorder = []
-    with open(args.input, 'r') as f:
+    with open(args.input, "r") as f:
         for line in f:
             info = json.loads(line)

@@ -74,7 +72,7 @@ if __name__ == "__main__":
     average_results = {}

     for metric in LATENCY_METRICS:
-        average_results[metric] = sum(
-            [x[metric][0, 0].item() for x in recorder]
-        ) / len(recorder)
+        average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len(
+            recorder
+        )
         print(f"{metric}: {average_results[metric]}")

@@ -5,37 +5,48 @@

 import argparse

+from agents import build_agent
 from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
 from fairseq.registry import REGISTRIES
-from agents import build_agent

-DEFAULT_HOSTNAME = 'localhost'
+
+DEFAULT_HOSTNAME = "localhost"
 DEFAULT_PORT = 12321


 def get_args():
     parser = argparse.ArgumentParser()

-    parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
-                        help='server hostname')
-    parser.add_argument('--port', type=int, default=DEFAULT_PORT,
-                        help='server port number')
-    parser.add_argument('--agent-type', default='simul_trans_text',
-                        help='Agent type')
-    parser.add_argument('--scorer-type', default='text',
-                        help='Scorer type')
-    parser.add_argument('--start-idx', type=int, default=0,
-                        help='Start index of the sentence to evaluate')
-    parser.add_argument('--end-idx', type=int, default=float('inf'),
-                        help='End index of the sentence to evaluate')
-    parser.add_argument('--scores', action="store_true",
-                        help='Request scores from server')
-    parser.add_argument('--reset-server', action="store_true",
-                        help='Reset the server')
-    parser.add_argument('--num-threads', type=int, default=10,
-                        help='Number of threads used by agent')
-    parser.add_argument('--local', action="store_true", default=False,
-                        help='Local evaluation')
+    parser.add_argument(
+        "--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
+    )
+    parser.add_argument(
+        "--port", type=int, default=DEFAULT_PORT, help="server port number"
+    )
+    parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
+    parser.add_argument("--scorer-type", default="text", help="Scorer type")
+    parser.add_argument(
+        "--start-idx",
+        type=int,
+        default=0,
+        help="Start index of the sentence to evaluate",
+    )
+    parser.add_argument(
+        "--end-idx",
+        type=int,
+        default=float("inf"),
+        help="End index of the sentence to evaluate",
+    )
+    parser.add_argument(
+        "--scores", action="store_true", help="Request scores from server"
+    )
+    parser.add_argument("--reset-server", action="store_true", help="Reset the server")
+    parser.add_argument(
+        "--num-threads", type=int, default=10, help="Number of threads used by agent"
+    )
+    parser.add_argument(
+        "--local", action="store_true", default=False, help="Local evaluation"
+    )

     args, _ = parser.parse_known_args()

@@ -5,15 +5,15 @@

 import importlib
 import os

 from fairseq import registry
-(
-    build_scorer,
-    register_scorer,
-    SCORER_REGISTRIES,
-    _
-) = registry.setup_registry('--scorer-type')


+(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry(
+    "--scorer-type"
+)

 for file in os.listdir(os.path.dirname(__file__)):
-    if file.endswith('.py') and not file.startswith('_'):
-        module = file[:file.find('.py')]
-        importlib.import_module('scorers.' + module)
+    if file.endswith(".py") and not file.startswith("_"):
+        module = file[: file.find(".py")]
+        importlib.import_module("scorers." + module)

@@ -3,16 +3,17 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from vizseq.scorers.bleu import BLEUScorer
-from vizseq.scorers.ter import TERScorer
-from vizseq.scorers.meteor import METEORScorer
-from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
-from collections import defaultdict
 import json
 import os
+from collections import defaultdict

+from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
+from vizseq.scorers.bleu import BLEUScorer
+from vizseq.scorers.meteor import METEORScorer
+from vizseq.scorers.ter import TERScorer


-DEFAULT_EOS = '</s>'
+DEFAULT_EOS = "</s>"


 class SimulScorer(object):
@@ -23,7 +24,7 @@ class SimulScorer(object):
             self.output_files = {
                 "text": os.path.join(args.output, "text"),
                 "delay": os.path.join(args.output, "delay"),
-                "scores": os.path.join(args.output, "scores")
+                "scores": os.path.join(args.output, "scores"),
             }
         else:
             self.output_files = None
@@ -52,14 +53,7 @@ class SimulScorer(object):

     def recv_hyp(self, sent_id, list_of_tokens):
         for token in list_of_tokens:
-            self.translations[
-                sent_id
-            ].append(
-                (
-                    token,
-                    self.steps[sent_id]
-                )
-            )
+            self.translations[sent_id].append((token, self.steps[sent_id]))

     def reset(self):
         self.steps = defaultdict(int)
@@ -76,8 +70,9 @@ class SimulScorer(object):
             delays += [[t[1] for t in self.translations[i]]]

         bleu_score = BLEUScorer(
-            sent_level=False, corpus_level=True,
-            extra_args={'bleu_tokenizer': self.tokenizer}
+            sent_level=False,
+            corpus_level=True,
+            extra_args={"bleu_tokenizer": self.tokenizer},
         ).score(translations, [self.data["tgt"]])

         ter_score = TERScorer(sent_level=False, corpus_level=True).score(
@@ -92,16 +87,16 @@ class SimulScorer(object):
                 {"src_len": src_len, "delays": delay}
                 for src_len, delay in zip(self.src_lengths(), delays)
             ],
-            start_from_zero=False
+            start_from_zero=False,
         )

         scores = {
-            'BLEU': bleu_score[0],
-            'TER': ter_score[0],
-            'METEOR': meteor_score[0],
-            'DAL': latency_score['differentiable_average_lagging'],
-            'AL': latency_score['average_lagging'],
-            'AP': latency_score['average_proportion'],
+            "BLEU": bleu_score[0],
+            "TER": ter_score[0],
+            "METEOR": meteor_score[0],
+            "DAL": latency_score["differentiable_average_lagging"],
+            "AL": latency_score["average_lagging"],
+            "AP": latency_score["average_proportion"],
         }

         if self.output_files is not None:
@@ -109,9 +104,9 @@ class SimulScorer(object):
                 os.makedirs(self.output_dir, exist_ok=True)
                 self.write_results_to_file(translations, delays, scores)
             except BaseException as be:
-                print(f'Failed to write results to {self.output_dir}.')
+                print(f"Failed to write results to {self.output_dir}.")
                 print(be)
-                print('Skip writing predictions')
+                print("Skip writing predictions")

         return scores

@@ -125,12 +120,8 @@ class SimulScorer(object):
         with open(self.output_files["delay"], "w") as f:
             for i, delay in enumerate(delays):
                 f.write(
-                    json.dumps(
-                        {
-                            "src_len": self.src_lengths()[i],
-                            "delays": delay
-                        }
-                    ) + "\n"
+                    json.dumps({"src_len": self.src_lengths()[i], "delays": delay})
+                    + "\n"
                 )

         with open(self.output_files["scores"], "w") as f:
@@ -163,7 +154,7 @@ class SimulScorer(object):
                 list_to_return.append(
                     {
                         "path": item["input"]["path"].strip(),
-                        "length": item["input"]["length_ms"]
+                        "length": item["input"]["length_ms"],
                     }
                 )
         return list_to_return

@@ -3,8 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from . scorer import SimulScorer
 from . import register_scorer
+from .scorer import SimulScorer


 @register_scorer("text")
@@ -13,7 +13,7 @@ class SimulTextScorer(SimulScorer):
         super().__init__(args)
         self.data = {
             "src": self._load_text_file(args.src_file, split=True),
-            "tgt": self._load_text_file(args.tgt_file, split=False)
+            "tgt": self._load_text_file(args.tgt_file, split=False),
         }

     def send_src(self, sent_id, *args):
@@ -21,7 +21,7 @@ class SimulTextScorer(SimulScorer):
             dict_to_return = {
                 "sent_id": sent_id,
                 "segment_id": self.steps[sent_id],
-                "segment": self.eos
+                "segment": self.eos,
             }
             # Consider EOS
             self.steps[sent_id] = len(self.data["src"][sent_id]) + 1
@@ -29,7 +29,7 @@ class SimulTextScorer(SimulScorer):
             dict_to_return = {
                 "sent_id": sent_id,
                 "segment_id": self.steps[sent_id],
-                "segment": self.data["src"][sent_id][self.steps[sent_id]]
+                "segment": self.data["src"][sent_id][self.steps[sent_id]],
             }

             self.steps[sent_id] += 1

@@ -3,12 +3,14 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
-import sys
 import json
-from tornado import web, ioloop
-from scorers import build_scorer
+import sys

-DEFAULT_HOSTNAME = 'localhost'
+from scorers import build_scorer
+from tornado import ioloop, web


+DEFAULT_HOSTNAME = "localhost"
 DEFAULT_PORT = 12321


@@ -34,10 +36,10 @@ class ResultHandler(ScorerHandler):

 class SourceHandler(ScorerHandler):
     def get(self):
-        sent_id = int(self.get_argument('sent_id'))
+        sent_id = int(self.get_argument("sent_id"))
         segment_size = None
         if "segment_size" in self.request.arguments:
-            string = self.get_argument('segment_size')
+            string = self.get_argument("segment_size")
             if len(string) > 0:
                 segment_size = int(string)

@@ -48,8 +50,8 @@ class SourceHandler(ScorerHandler):

 class HypothesisHandler(ScorerHandler):
     def put(self):
-        sent_id = int(self.get_argument('sent_id'))
-        list_of_tokens = self.request.body.decode('utf-8').strip().split()
+        sent_id = int(self.get_argument("sent_id"))
+        list_of_tokens = self.request.body.decode("utf-8").strip().split()
         self.scorer.recv_hyp(sent_id, list_of_tokens)


@@ -67,18 +69,21 @@ def add_args():


 def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
-    app = web.Application([
-        (r'/result', ResultHandler, dict(scorer=scorer)),
-        (r'/src', SourceHandler, dict(scorer=scorer)),
-        (r'/hypo', HypothesisHandler, dict(scorer=scorer)),
-        (r'/', EvalSessionHandler, dict(scorer=scorer)),
-    ], debug=debug)
+    app = web.Application(
+        [
+            (r"/result", ResultHandler, dict(scorer=scorer)),
+            (r"/src", SourceHandler, dict(scorer=scorer)),
+            (r"/hypo", HypothesisHandler, dict(scorer=scorer)),
+            (r"/", EvalSessionHandler, dict(scorer=scorer)),
+        ],
+        debug=debug,
+    )
     app.listen(port, max_buffer_size=1024 ** 3)
     sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
     ioloop.IOLoop.current().start()


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = add_args()
     scorer = build_scorer(args)
     start_server(scorer, args.hostname, args.port, args.debug)

@@ -6,7 +6,10 @@
 import importlib
 import os


 for file in os.listdir(os.path.dirname(__file__)):
-    if file.endswith('.py') and not file.startswith('_'):
-        model_name = file[:file.find('.py')]
-        importlib.import_module('examples.simultaneous_translation.models.' + model_name)
+    if file.endswith(".py") and not file.startswith("_"):
+        model_name = file[: file.find(".py")]
+        importlib.import_module(
+            "examples.simultaneous_translation.models." + model_name
+        )

@@ -6,42 +6,34 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from fairseq.models import (
-    register_model,
-    register_model_architecture,
+from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
+    TransformerMonotonicDecoderLayer,
+    TransformerMonotonicEncoderLayer,
 )
-
-
+from fairseq.models import register_model, register_model_architecture
 from fairseq.models.transformer import (
-    TransformerModel,
-    TransformerEncoder,
     TransformerDecoder,
+    TransformerEncoder,
+    TransformerModel,
     base_architecture,
     transformer_iwslt_de_en,
     transformer_vaswani_wmt_en_de_big,
 )

-
-from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
-    TransformerMonotonicDecoderLayer,
-    TransformerMonotonicEncoderLayer
-)

 DEFAULT_MAX_SOURCE_POSITIONS = 1024
 DEFAULT_MAX_TARGET_POSITIONS = 1024


-@register_model('transformer_unidirectional')
+@register_model("transformer_unidirectional")
 class TransformerUnidirectionalModel(TransformerModel):
     @classmethod
     def build_encoder(cls, args, src_dict, embed_tokens):
         return TransformerMonotonicEncoder(args, src_dict, embed_tokens)


-@register_model('transformer_monotonic')
+@register_model("transformer_monotonic")
 class TransformerMonotonicModel(TransformerModel):
-
     @classmethod
     def build_encoder(cls, args, src_dict, embed_tokens):
         return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@@ -62,26 +54,17 @@ class TransformerMonotonicModel(TransformerModel):
             )

             tgt_indices = tensor(
-                [
-                    [self.decoder.dictionary.eos()]
-                    + states["indices"]["tgt"]
-                ]
+                [[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
             )
         else:
-            src_indices = states["indices"]["src"][: 1 +
-                                                   states["steps"]["src"]]
+            src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
             tgt_indices = states["indices"]["tgt"]

         return src_indices, None, tgt_indices

     def predict_from_states(self, states):
-        decoder_states = self.decoder.output_layer(
-            states["decoder_features"]
-        )
-        lprobs = self.get_normalized_probs(
-            [decoder_states[:, -1:]],
-            log_probs=True
-        )
+        decoder_states = self.decoder.output_layer(states["decoder_features"])
+        lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)

         index = lprobs.argmax(dim=-1)

@@ -90,25 +73,24 @@ class TransformerMonotonicModel(TransformerModel):
         return token, index[0, 0].item()

     def decision_from_states(self, states):
-        '''
+        """
         This funcion take states dictionary as input, and gives the agent
         a decision of whether read a token from server. Moreover, the decoder
         states are also calculated here so we can directly generate a target
         token without recompute every thing
-        '''
+        """

         self.eval()

         if len(states["tokens"]["src"]) == 0:
             return 0

-        src_indices, src_lengths, tgt_indices = self._indices_from_states(
-            states)
+        src_indices, src_lengths, tgt_indices = self._indices_from_states(states)

         # Update encoder states if needed
         if (
-            "encoder_states" not in states or
-            states["encoder_states"][0].size(1) <= states["steps"]["src"]
+            "encoder_states" not in states
+            or states["encoder_states"][0].size(1) <= states["steps"]["src"]
         ):
             encoder_out_dict = self.encoder(src_indices, src_lengths)
             states["encoder_states"] = encoder_out_dict
@@ -136,16 +118,14 @@ class TransformerMonotonicModel(TransformerModel):


 class TransformerMonotonicEncoder(TransformerEncoder):
-
     def __init__(self, args, dictionary, embed_tokens):
         super().__init__(args, dictionary, embed_tokens)

         self.dictionary = dictionary
         self.layers = nn.ModuleList([])
-        self.layers.extend([
-            TransformerMonotonicEncoderLayer(args)
-            for i in range(args.encoder_layers)
-        ])
+        self.layers.extend(
+            [TransformerMonotonicEncoderLayer(args) for i in range(args.encoder_layers)]
+        )


 class TransformerMonotonicDecoder(TransformerDecoder):
@@ -166,19 +146,24 @@ class TransformerMonotonicDecoder(TransformerDecoder):

         self.dictionary = dictionary
         self.layers = nn.ModuleList([])
-        self.layers.extend([
-            TransformerMonotonicDecoderLayer(args, no_encoder_attn)
-            for _ in range(args.decoder_layers)
-        ])
+        self.layers.extend(
+            [
+                TransformerMonotonicDecoderLayer(args, no_encoder_attn)
+                for _ in range(args.decoder_layers)
+            ]
+        )

     def pre_attention(
-        self, prev_output_tokens, encoder_out_dict,
-        incremental_state=None
+        self, prev_output_tokens, encoder_out_dict, incremental_state=None
     ):
-        positions = self.embed_positions(
-            prev_output_tokens,
-            incremental_state=incremental_state,
-        ) if self.embed_positions is not None else None
+        positions = (
+            self.embed_positions(
+                prev_output_tokens,
+                incremental_state=incremental_state,
+            )
+            if self.embed_positions is not None
+            else None
+        )

         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
@@ -216,8 +201,7 @@ class TransformerMonotonicDecoder(TransformerDecoder):
         return x

     def extract_features(
-        self, prev_output_tokens, encoder_out,
-        incremental_state=None, **unused
+        self, prev_output_tokens, encoder_out, incremental_state=None, **unused
     ):
         """
         Similar to *forward* but only return features.
@@ -228,14 +212,8 @@ class TransformerMonotonicDecoder(TransformerDecoder):
            - a dictionary with any model-specific outputs
         """
         # incremental_state = None
-        (
-            x,
-            encoder_outs,
-            encoder_padding_mask
-        ) = self.pre_attention(
-            prev_output_tokens,
-            encoder_out,
-            incremental_state
+        (x, encoder_outs, encoder_padding_mask) = self.pre_attention(
+            prev_output_tokens, encoder_out, incremental_state
         )
         attn = None
         inner_states = [x]
@@ -250,7 +228,8 @@ class TransformerMonotonicDecoder(TransformerDecoder):
                 encoder_padding_mask=encoder_padding_mask,
                 incremental_state=incremental_state,
                 self_attn_mask=self.buffered_future_mask(x)
-                if incremental_state is None else None,
+                if incremental_state is None
+                else None,
             )

             inner_states.append(x)
@@ -261,38 +240,30 @@ class TransformerMonotonicDecoder(TransformerDecoder):
                 step_list.append(curr_steps)

                 if incremental_state.get("online", False):
-                    p_choose = attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())

-                    new_steps = (
-                        curr_steps
-                        + (p_choose < 0.5).t().type_as(curr_steps)
+                    p_choose = (
+                        attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
                     )

+                    new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)
+
                     if (new_steps >= incremental_state["steps"]["src"]).any():
                         # We need to prune the last self_attn saved_state
                         # if model decide not to read
                         # otherwise there will be duplicated saved_state
                         for j in range(i + 1):
-                            self.layers[j].prune_incremental_state(
-                                incremental_state)
+                            self.layers[j].prune_incremental_state(incremental_state)

                         return x, {"action": 0}

-        if (
-            incremental_state is not None
-            and not incremental_state.get("online", False)
-        ):
+        if incremental_state is not None and not incremental_state.get("online", False):
             # Here is for fast evaluation
-            fastest_step = torch.max(
-                torch.cat(step_list, dim=1),
-                dim=1,
-                keepdim=True
-            )[0] + 1
+            fastest_step = (
+                torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
+            )

             if "fastest_step" in incremental_state:
                 incremental_state["fastest_step"] = torch.cat(
-                    [incremental_state["fastest_step"], fastest_step],
-                    dim=1
+                    [incremental_state["fastest_step"], fastest_step], dim=1
                 )
             else:
                 incremental_state["fastest_step"] = fastest_step
@@ -310,25 +281,19 @@ class TransformerMonotonicDecoder(TransformerDecoder):
     def reorder_incremental_state(self, incremental_state, new_order):
         super().reorder_incremental_state(incremental_state, new_order)
         if "fastest_step" in incremental_state:
-            incremental_state["fastest_step"] = (
-                incremental_state["fastest_step"]
-                .index_select(0, new_order)
-            )
+            incremental_state["fastest_step"] = incremental_state[
+                "fastest_step"
+            ].index_select(0, new_order)


-@register_model_architecture(
-    'transformer_monotonic',
-    'transformer_monotonic'
-)
+@register_model_architecture("transformer_monotonic", "transformer_monotonic")
 def base_monotonic_rchitecture(args):
     base_architecture(args)
-    args.encoder_unidirectional = getattr(
-        args, 'encoder_unidirectional', False)
+    args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)


 @register_model_architecture(
-    'transformer_monotonic',
-    'transformer_monotonic_iwslt_de_en'
+    "transformer_monotonic", "transformer_monotonic_iwslt_de_en"
 )
 def transformer_monotonic_iwslt_de_en(args):
     transformer_iwslt_de_en(args)
@@ -337,24 +302,21 @@ def transformer_monotonic_iwslt_de_en(args):

 # parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
 @register_model_architecture(
-    'transformer_monotonic',
-    'transformer_monotonic_vaswani_wmt_en_de_big'
+    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
 )
 def transformer_monotonic_vaswani_wmt_en_de_big(args):
     transformer_vaswani_wmt_en_de_big(args)


 @register_model_architecture(
-    'transformer_monotonic',
-    'transformer_monotonic_vaswani_wmt_en_fr_big'
+    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
 )
 def transformer_monotonic_vaswani_wmt_en_fr_big(args):
     transformer_monotonic_vaswani_wmt_en_fr_big(args)


 @register_model_architecture(
-    'transformer_unidirectional',
-    'transformer_unidirectional_iwslt_de_en'
+    "transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
 )
 def transformer_unidirectional_iwslt_de_en(args):
     transformer_iwslt_de_en(args)

@@ -7,14 +7,18 @@ import importlib
 import os

 from fairseq import registry


 (
     build_monotonic_attention,
     register_monotonic_attention,
     MONOTONIC_ATTENTION_REGISTRY,
-    _
-) = registry.setup_registry('--simul-type')
+    _,
+) = registry.setup_registry("--simul-type")

 for file in os.listdir(os.path.dirname(__file__)):
-    if file.endswith('.py') and not file.startswith('_'):
-        model_name = file[:file.find('.py')]
-        importlib.import_module('examples.simultaneous_translation.modules.' + model_name)
+    if file.endswith(".py") and not file.startswith("_"):
+        model_name = file[: file.find(".py")]
+        importlib.import_module(
+            "examples.simultaneous_translation.modules." + model_name
+        )

@@ -4,22 +4,19 @@
 # LICENSE file in the root directory of this source tree.

 import math

 import torch
-import torch.nn.functional as F
 import torch.nn as nn
-
-from fairseq import utils
-
-from fairseq.modules import MultiheadAttention
-
+import torch.nn.functional as F
 from examples.simultaneous_translation.utils.functions import (
     exclusive_cumprod,
-    lengths_to_mask
+    lengths_to_mask,
 )


+from fairseq import utils
+from fairseq.incremental_decoding_utils import with_incremental_state
+from fairseq.modules import MultiheadAttention
+from fairseq.utils import convert_padding_direction

 from . import register_monotonic_attention


@@ -28,6 +25,7 @@ class MonotonicAttention(nn.Module):
     """
     Abstract class of monotonic attentions
     """
+
     def __init__(self, args):
         self.eps = args.attention_eps
         self.mass_preservation = args.mass_preservation
@@ -38,7 +36,8 @@ class MonotonicAttention(nn.Module):
         self.energy_bias_init = args.energy_bias_init
         self.energy_bias = (
             nn.Parameter(self.energy_bias_init * torch.ones([1]))
-            if args.energy_bias is True else 0
+            if args.energy_bias is True
+            else 0
         )

     @staticmethod
@@ -90,7 +89,7 @@ class MonotonicAttention(nn.Module):
         if key_padding_mask is not None:
             attn_energy = attn_energy.masked_fill(
                 key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
-                float('-inf'),
+                float("-inf"),
             )

         return attn_energy
@@ -131,10 +130,7 @@ class MonotonicAttention(nn.Module):
             alpha_i = (
                 p_choose[:, i]
                 * cumprod_1mp[:, i]
-                * torch.cumsum(
-                    previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i],
-                    dim=1
-                )
+                * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
             ).clamp(0, 1.0)
             previous_attn.append(alpha_i.unsqueeze(1))

@@ -170,8 +166,7 @@ class MonotonicAttention(nn.Module):
         # prev_monotonic_step: bsz, num_heads
         bsz = bsz_num_heads // self.num_heads
         prev_monotonic_step = monotonic_cache.get(
-            "step",
-            p_choose.new_zeros([bsz, self.num_heads]).long()
+            "step", p_choose.new_zeros([bsz, self.num_heads]).long()
         )
         bsz, num_heads = prev_monotonic_step.size()
         assert num_heads == self.num_heads
@@ -181,8 +176,7 @@ class MonotonicAttention(nn.Module):
         p_choose = p_choose.view(bsz, num_heads, src_len)

         if key_padding_mask is not None:
-            src_lengths = src_len - \
-                key_padding_mask.sum(dim=1, keepdim=True).long()
+            src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long()
         else:
             src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len

@@ -197,10 +191,7 @@ class MonotonicAttention(nn.Module):
            # left_pad_source = True:
             step_offset = key_padding_mask.sum(dim=-1, keepdim=True)

-        max_steps = (
-            src_lengths - 1 if self.mass_preservation
-            else src_lengths
-        )
+        max_steps = src_lengths - 1 if self.mass_preservation else src_lengths

         # finish_read: bsz, num_heads
         finish_read = new_monotonic_step.eq(max_steps)
@@ -210,11 +201,11 @@ class MonotonicAttention(nn.Module):
             # only choose the p at monotonic steps
             # p_choose_i: bsz , self.num_heads
             p_choose_i = (
-                p_choose
-                .gather(
+                p_choose.gather(
                     2,
-                    (step_offset + new_monotonic_step).unsqueeze(2)
-                    .clamp(0, src_len - 1)
+                    (step_offset + new_monotonic_step)
+                    .unsqueeze(2)
+                    .clamp(0, src_len - 1),
                 )
             ).squeeze(2)

@@ -239,21 +230,17 @@ class MonotonicAttention(nn.Module):

         # alpha: bsz * num_heads, 1, src_len
         # new_monotonic_step: bsz, num_heads
-        alpha = (
-            p_choose
-            .new_zeros([bsz * self.num_heads, src_len])
-            .scatter(
-                1,
-                (step_offset + new_monotonic_step).view(bsz *
-                                                        self.num_heads, 1).clamp(0, src_len - 1),
-                1
-            )
+        alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter(
+            1,
+            (step_offset + new_monotonic_step)
+            .view(bsz * self.num_heads, 1)
+            .clamp(0, src_len - 1),
+            1,
         )

         if not self.mass_preservation:
             alpha = alpha.masked_fill(
-                (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1),
-                0
+                (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0
             )

         alpha = alpha.unsqueeze(1)
@@ -266,8 +253,14 @@ class MonotonicAttention(nn.Module):
         raise NotImplementedError

     def forward(
-        self, query, key, value,
-        key_padding_mask=None, incremental_state=None, *args, **kwargs,
+        self,
+        query,
+        key,
+        value,
+        key_padding_mask=None,
+        incremental_state=None,
+        *args,
+        **kwargs,
     ):

         tgt_len, bsz, embed_dim = query.size()
@@ -280,25 +273,24 @@ class MonotonicAttention(nn.Module):
         # expected alignment alpha
         # bsz * self.num_heads, tgt_len, src_len
         if incremental_state is not None:
-            alpha = self.expected_alignment_infer(p_choose, key_padding_mask, incremental_state)
+            alpha = self.expected_alignment_infer(
+                p_choose, key_padding_mask, incremental_state
+            )
         else:
             alpha = self.expected_alignment_train(p_choose, key_padding_mask)

         # expected attention beta
         # bsz * self.num_heads, tgt_len, src_len
-        beta = self.expected_attention(alpha, query, key, value, key_padding_mask, incremental_state)
+        beta = self.expected_attention(
+            alpha, query, key, value, key_padding_mask, incremental_state
+        )

         attn_weights = beta

         v_proj = self.v_proj_output(value)
         attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)

-        attn = (
-            attn
-            .transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
+        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

         attn = self.out_proj(attn)

@@ -318,26 +310,32 @@ class MonotonicAttention(nn.Module):
         self._set_monotonic_buffer(incremental_state, input_buffer)

     def _get_monotonic_buffer(self, incremental_state):
-        return utils.get_incremental_state(
-            self,
-            incremental_state,
-            'monotonic',
-        ) or {}
+        return (
+            utils.get_incremental_state(
+                self,
+                incremental_state,
+                "monotonic",
+            )
+            or {}
+        )

     def _set_monotonic_buffer(self, incremental_state, buffer):
         utils.set_incremental_state(
             self,
             incremental_state,
-            'monotonic',
+            "monotonic",
             buffer,
         )

     def get_pointer(self, incremental_state):
-        return utils.get_incremental_state(
-            self,
-            incremental_state,
-            'monotonic',
-        ) or {}
+        return (
+            utils.get_incremental_state(
+                self,
+                incremental_state,
+                "monotonic",
+            )
+            or {}
+        )

     def get_fastest_pointer(self, incremental_state):
         return self.get_pointer(incremental_state)["step"].max(0)[0]
@@ -354,23 +352,22 @@ class MonotonicAttention(nn.Module):
         utils.set_incremental_state(
             self,
             incremental_state,
-            'monotonic',
+            "monotonic",
             {"step": buffer},
         )


 @register_monotonic_attention("hard_aligned")
 class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
-
     def __init__(self, args):
         MultiheadAttention.__init__(
             self,
             embed_dim=args.decoder_embed_dim,
             num_heads=args.decoder_attention_heads,
-            kdim=getattr(args, 'encoder_embed_dim', None),
-            vdim=getattr(args, 'encoder_embed_dim', None),
+            kdim=getattr(args, "encoder_embed_dim", None),
+            vdim=getattr(args, "encoder_embed_dim", None),
             dropout=args.attention_dropout,
-            encoder_decoder_attention=True
+            encoder_decoder_attention=True,
         )

         MonotonicAttention.__init__(self, args)
@@ -395,21 +392,33 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
             bsz = query.size(1)
             q = self.q_in_proj[name](query)
             q *= self.scaling
-            q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+            q = (
+                q.contiguous()
+                .view(-1, bsz * self.num_heads, self.head_dim)
+                .transpose(0, 1)
+            )
         else:
             q = None

         if key is not None:
             bsz = key.size(1)
             k = self.k_in_proj[name](key)
-            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+            k = (
+                k.contiguous()
+                .view(-1, bsz * self.num_heads, self.head_dim)
+                .transpose(0, 1)
+            )
         else:
             k = None

         if value is not None:
             bsz = value.size(1)
             v = self.v_in_proj[name](value)
-            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+            v = (
+                v.contiguous()
+                .view(-1, bsz * self.num_heads, self.head_dim)
+                .transpose(0, 1)
+            )
         else:
             v = None

@@ -441,8 +450,7 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
         if self.training:
             # add noise here to encourage discretness
             noise = (
-                torch
-                .normal(self.noise_mean, self.noise_var, attn_energy.size())
+                torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
                 .type_as(attn_energy)
                 .to(attn_energy.device)
             )
@@ -454,9 +462,9 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
         return p_choose.view(-1, tgt_len, src_len)

     def expected_attention(self, alpha, *args):
-        '''
+        """
         For MMA-H, beta = alpha
-        '''
+        """
         return alpha

     def v_proj_output(self, value):
@@ -479,13 +487,19 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
         if self.qkv_same_dim:
             # Empirically observed the convergence to be much better with
             # the scaled initialization
-            nn.init.xavier_uniform_(self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2))
-            nn.init.xavier_uniform_(self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2))
+            nn.init.xavier_uniform_(
+                self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
+            )
+            nn.init.xavier_uniform_(
+                self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
+            )
         else:
             nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
             nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)

-    def expected_attention(self, alpha, query, key, value, key_padding_mask, incremental_state):
+    def expected_attention(
+        self, alpha, query, key, value, key_padding_mask, incremental_state
+    ):
         # monotonic attention, we will calculate milk here
         bsz_x_num_heads, tgt_len, src_len = alpha.size()
         bsz = int(bsz_x_num_heads / self.num_heads)
@@ -507,9 +521,10 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
             step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
             monotonic_step += step_offset
             mask = lengths_to_mask(
-                monotonic_step.view(-1), soft_energy.size(2), 1).unsqueeze(1)
+                monotonic_step.view(-1), soft_energy.size(2), 1
+            ).unsqueeze(1)

-            soft_energy = soft_energy.masked_fill(~ mask.bool(), float('-inf'))
+            soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
             soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
             exp_soft_energy = torch.exp(soft_energy)
             exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
@@ -524,14 +539,20 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
             if key_padding_mask is not None:
                 if key_padding_mask.any():
                     exp_soft_energy_cumsum = (
-                        exp_soft_energy_cumsum.view(-1, self.num_heads, tgt_len, src_len)
-                        .masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps)
+                        exp_soft_energy_cumsum.view(
+                            -1, self.num_heads, tgt_len, src_len
+                        )
+                        .masked_fill(
+                            key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
+                        )
                         .view(-1, tgt_len, src_len)
                     )

             inner_items = alpha / exp_soft_energy_cumsum

-            beta = exp_soft_energy * torch.cumsum(inner_items.flip(dims=[2]), dim=2).flip(dims=[2])
+            beta = exp_soft_energy * torch.cumsum(
+                inner_items.flip(dims=[2]), dim=2
+            ).flip(dims=[2])

             beta = self.dropout_module(beta)

@@ -547,7 +568,9 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
         self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
         self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
         self.waitk_lagging = args.waitk_lagging
-        assert self.waitk_lagging > 0, f"Lagging has to been larger than 0, get {self.waitk_lagging}."
+        assert (
+            self.waitk_lagging > 0
+        ), f"Lagging has to been larger than 0, get {self.waitk_lagging}."

     @staticmethod
     def add_args(parser):
@@ -556,10 +579,13 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
             MonotonicMultiheadAttentionWaitk,
         ).add_args(parser)

-        parser.add_argument('--waitk-lagging', type=int, required=True,
-                            help='Wait k lagging')
+        parser.add_argument(
+            "--waitk-lagging", type=int, required=True, help="Wait k lagging"
+        )

-    def p_choose(self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None):
+    def p_choose(
+        self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None
+    ):
         """
         query: bsz, tgt_len
         key: bsz, src_len
@@ -574,16 +600,22 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
         if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
             # Left pad source
             # add -1 to the end
-            p_choose = p_choose.masked_fill(key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1)
-            p_choose = convert_padding_direction(p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True)
+            p_choose = p_choose.masked_fill(
+                key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1
+            )
+            p_choose = convert_padding_direction(
+                p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True
+            )
             p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)
             # remove -1
             p_choose[p_choose.eq(-1)] = 0

         # Extend to each head
         p_choose = (
-            p_choose.contiguous().unsqueeze(1)
-            .expand(-1, self.num_heads, -1, -1).contiguous()
+            p_choose.contiguous()
+            .unsqueeze(1)
+            .expand(-1, self.num_heads, -1, -1)
+            .contiguous()
             .view(-1, tgt_len, src_len)
         )

@@ -3,37 +3,32 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from fairseq.modules import (
-    LayerNorm,
-    TransformerEncoderLayer,
-    TransformerDecoderLayer
-)
+from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer

 from . import build_monotonic_attention


 class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
-
     def forward(self, x, encoder_padding_mask):
         seq_len, _, _ = x.size()
         attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
-        attn_mask = attn_mask.masked_fill(attn_mask.bool(), float('-inf'))
+        attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
         return super().forward(x, encoder_padding_mask, attn_mask)


 class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
-
-    def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
+    def __init__(
+        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
+    ):
         super().__init__(
             args,
             no_encoder_attn=True,
             add_bias_kv=add_bias_kv,
-            add_zero_attn=add_zero_attn
+            add_zero_attn=add_zero_attn,
         )
         self.encoder_attn = build_monotonic_attention(args)
         self.encoder_attn_layer_norm = LayerNorm(
-            self.embed_dim,
-            export=getattr(args, 'char_inputs', False)
+            self.embed_dim, export=getattr(args, "char_inputs", False)
         )

     def prune_incremental_state(self, incremental_state):
@@ -46,12 +41,8 @@ class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
                     input_buffer = {}
                     break
             module._set_input_buffer(incremental_state, input_buffer)

         prune(self.self_attn)

     def get_steps(self, incremental_state):
-        return (
-            self.encoder_attn
-            ._get_monotonic_buffer(
-                incremental_state
-            ).get("step", 0)
-        )
+        return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0)

@@ -9,6 +9,6 @@ import os

 # automatically import any Python files in the criterions/ directory
 for file in os.listdir(os.path.dirname(__file__)):
-    if file.endswith('.py') and not file.startswith('_'):
-        module = file[:file.find('.py')]
-        importlib.import_module('examples.simultaneous_translation.utils.' + module)
+    if file.endswith(".py") and not file.startswith("_"):
+        module = file[: file.find(".py")]
+        importlib.import_module("examples.simultaneous_translation.utils." + module)

@@ -16,7 +16,9 @@ def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
     tensor_size = list(tensor.size())
     tensor_size[dim] = 1
     return_tensor = safe_cumprod(
-        torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim), dim=dim, eps=eps
+        torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim),
+        dim=dim,
+        eps=eps,
     )

     if dim == 0:
@@ -132,12 +134,14 @@ def moving_sum(x, start_idx: int, end_idx: int):
     # batch_size, 1, src_len
     moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1])

-    moving_sum = torch.nn.functional.conv1d(
-        x,
-        moving_sum_weight,
-        padding=start_idx + end_idx - 1
-    ).squeeze(1).t()
-    moving_sum = moving_sum[end_idx: -start_idx]
+    moving_sum = (
+        torch.nn.functional.conv1d(
+            x, moving_sum_weight, padding=start_idx + end_idx - 1
+        )
+        .squeeze(1)
+        .t()
+    )
+    moving_sum = moving_sum[end_idx:-start_idx]

     assert src_len == moving_sum.size(0)
     assert batch_size == moving_sum.size(1)

@@ -18,7 +18,7 @@ class LatencyMetric(object):
         src_lens,
         target_padding_mask=None,
         batch_first: bool = False,
-        start_from_zero: bool = True
+        start_from_zero: bool = True,
     ):
         assert len(delays.size()) == 2
         assert len(src_lens.size()) == 2
@@ -59,11 +59,7 @@ class LatencyMetric(object):
         start_from_zero: bool = True,
     ):
         delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
-            delays,
-            src_lens,
-            target_padding_mask,
-            batch_first,
-            start_from_zero
+            delays, src_lens, target_padding_mask, batch_first, start_from_zero
         )
         return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)

@@ -89,10 +85,13 @@ class AverageProportion(LatencyMetric):

    AP = 1 / (|x||y]) sum_i^|Y| deleys_i
    """
+
    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        if target_padding_mask is not None:
-            AP = torch.sum(delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True)
+            AP = torch.sum(
+                delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True
+            )
        else:
            AP = torch.sum(delays, dim=0, keepdim=True)

@@ -116,14 +115,24 @@ class AverageLagging(LatencyMetric):
    gamma = |y| / |x|
    tau = argmin_i(delays_i = |x|)
    """
+
    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        # tau = argmin_i(delays_i = |x|)
        tgt_len, bsz = delays.size()
        lagging_padding_mask = delays >= src_lens
-        lagging_padding_mask = torch.nn.functional.pad(lagging_padding_mask.t(), (1, 0)).t()[:-1, :]
+        lagging_padding_mask = torch.nn.functional.pad(
+            lagging_padding_mask.t(), (1, 0)
+        ).t()[:-1, :]
        gamma = tgt_lens / src_lens
-        lagging = delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma
+        lagging = (
+            delays
+            - torch.arange(delays.size(0))
+            .unsqueeze(1)
+            .type_as(delays)
+            .expand_as(delays)
+            / gamma
+        )
        lagging.masked_fill_(lagging_padding_mask, 0)
        tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
        AL = lagging.sum(dim=0, keepdim=True) / tau
@@ -149,6 +158,7 @@ class DifferentiableAverageLagging(LatencyMetric):
    2. max(delays_i, delays'_{i-1} + 1 / gamma)

    """
+
    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        tgt_len, bsz = delays.size()
@@ -163,13 +173,18 @@ class DifferentiableAverageLagging(LatencyMetric):
                new_delays[i] = torch.cat(
                    [
                        new_delays[i - 1].unsqueeze(0) + 1 / gamma,
-                        delays[i].unsqueeze(0)
+                        delays[i].unsqueeze(0),
                    ],
-                    dim=0
+                    dim=0,
                ).max(dim=0)[0]

        DAL = (
-            new_delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma
+            new_delays
+            - torch.arange(delays.size(0))
+            .unsqueeze(1)
+            .type_as(delays)
+            .expand_as(delays)
+            / gamma
        )
        if target_padding_mask is not None:
            DAL = DAL.masked_fill(target_padding_mask, 0)
@@ -186,7 +201,7 @@ class LatencyMetricVariance(LatencyMetric):
        src_lens,
        target_padding_mask=None,
        batch_first: bool = True,
-        start_from_zero: bool = True
+        start_from_zero: bool = True,
    ):
        assert batch_first
        assert len(delays.size()) == 3
@@ -256,25 +271,21 @@ class LatencyInference(object):

        src_lens = src_lens

-        delays = (
-            monotonic_step
-            .view(monotonic_step.size(0), -1, monotonic_step.size(-1))
-            .max(dim=1)[0]
-        )
+        delays = monotonic_step.view(
+            monotonic_step.size(0), -1, monotonic_step.size(-1)
+        ).max(dim=1)[0]

-        delays = (
-            delays.masked_fill(delays >= src_lens, 0)
-            + (src_lens - 1)
-            .expand_as(delays)
-            .masked_fill(delays < src_lens, 0)
-        )
+        delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as(
+            delays
+        ).masked_fill(delays < src_lens, 0)
        return_dict = {}
        for key, func in self.metric_calculator.items():
            return_dict[key] = func(
-                delays.float(), src_lens.float(),
+                delays.float(),
+                src_lens.float(),
                target_padding_mask=None,
                batch_first=True,
-                start_from_zero=True
+                start_from_zero=True,
            ).t()

        return return_dict
@@ -282,8 +293,13 @@ class LatencyInference(object):

 class LatencyTraining(object):
     def __init__(
-        self, avg_weight, var_weight, avg_type, var_type,
-        stay_on_last_token, average_method,
+        self,
+        avg_weight,
+        var_weight,
+        avg_type,
+        var_type,
+        stay_on_last_token,
+        average_method,
     ):
         self.avg_weight = avg_weight
         self.var_weight = var_weight
@@ -319,17 +335,12 @@ class LatencyTraining(object):
         attention = attention.view(-1, tgt_len, src_len)

         if not self.stay_on_last_token:
-            residual_attention = \
-                1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
-            attention = torch.cat(
-                [attention[:, :, :-1], residual_attention],
-                dim=2
-            )
+            residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
+            attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2)

         # bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
         steps = (
-            torch
-            .arange(1, 1 + src_len)
+            torch.arange(1, 1 + src_len)
             .unsqueeze(0)
             .unsqueeze(1)
             .expand_as(attention)
@@ -355,15 +366,12 @@ class LatencyTraining(object):
             src_lens = src_lens.view(-1, 1)

         # bsz * num_heads_num_layers, tgt_len, src_len
-        expected_delays = (steps * attention).sum(dim=2).view(
-            bsz, num_heads_x_layers, tgt_len
+        expected_delays = (
+            (steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len)
         )

         if target_padding_mask is not None:
-            expected_delays.masked_fill_(
-                target_padding_mask.unsqueeze(1),
-                0
-            )
+            expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0)

         return expected_delays, src_lens

@@ -371,8 +379,7 @@ class LatencyTraining(object):

         bsz, num_heads_x_layers, tgt_len = expected_delays.size()
         target_padding_mask = (
-            target_padding_mask
-            .unsqueeze(1)
+            target_padding_mask.unsqueeze(1)
             .expand_as(expected_delays)
             .contiguous()
             .view(-1, tgt_len)
@@ -396,8 +403,11 @@ class LatencyTraining(object):
         if self.avg_weight > 0.0:
             if self.avg_type in self.metric_calculator:
                 average_delays = self.metric_calculator[self.avg_type](
-                    expected_delays, src_lens, target_padding_mask,
-                    batch_first=True, start_from_zero=False
+                    expected_delays,
+                    src_lens,
+                    target_padding_mask,
+                    batch_first=True,
+                    start_from_zero=False,
                 )
             else:
                 raise RuntimeError(f"{self.avg_type} is not supported.")
@@ -408,12 +418,17 @@ class LatencyTraining(object):
             return 0.0

     def var_loss(self, expected_delays, src_lens, target_padding_mask):
-        src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[:, :1]
+        src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[
+            :, :1
+        ]
         if self.var_weight > 0.0:
             if self.var_type in self.variance_calculator:
                 variance_delays = self.variance_calculator[self.var_type](
-                    expected_delays, src_lens, target_padding_mask,
-                    batch_first=True, start_from_zero=False
+                    expected_delays,
+                    src_lens,
+                    target_padding_mask,
+                    batch_first=True,
+                    start_from_zero=False,
                 )
             else:
                 raise RuntimeError(f"{self.var_type} is not supported.")

@ -1 +1 @@
from . import tasks, criterions, models # noqa
from . import criterions, models, tasks # noqa

@ -6,9 +6,9 @@
# LICENSE file in the root directory of this source tree.

import torch
from examples.speech_recognition.data.replabels import pack_replabels
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
from examples.speech_recognition.data.replabels import pack_replabels


@register_criterion("asg_loss")

@ -5,6 +5,7 @@

from .asr_dataset import AsrDataset


__all__ = [
'AsrDataset',
"AsrDataset",
]

@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import os

import numpy as np
from fairseq.data import FairseqDataset

@ -30,16 +31,22 @@ class AsrDataset(FairseqDataset):
"""

def __init__(
self, aud_paths, aud_durations_ms, tgt,
tgt_dict, ids, speakers,
num_mel_bins=80, frame_length=25.0, frame_shift=10.0
self,
aud_paths,
aud_durations_ms,
tgt,
tgt_dict,
ids,
speakers,
num_mel_bins=80,
frame_length=25.0,
frame_shift=10.0,
):
assert frame_length > 0
assert frame_shift > 0
assert all(x > frame_length for x in aud_durations_ms)
self.frame_sizes = [
int(1 + (d - frame_length) / frame_shift)
for d in aud_durations_ms
int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
]

assert len(aud_paths) > 0
@ -57,13 +64,17 @@ class AsrDataset(FairseqDataset):
self.frame_shift = frame_shift

self.s2s_collater = Seq2SeqCollater(
0, 1, pad_index=self.tgt_dict.pad(),
eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True
0,
1,
pad_index=self.tgt_dict.pad(),
eos_index=self.tgt_dict.eos(),
move_eos_to_beginning=True,
)

def __getitem__(self, index):
import torchaudio
import torchaudio.compliance.kaldi as kaldi

tgt_item = self.tgt[index] if self.tgt is not None else None

path = self.aud_paths[index]
@ -74,7 +85,7 @@ class AsrDataset(FairseqDataset):
sound,
num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length,
frame_shift=self.frame_shift
frame_shift=self.frame_shift,
)
output_cmvn = data_utils.apply_mv_norm(output)

@ -12,18 +12,18 @@


from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np

import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils


class Seq2SeqCollater(object):
"""
Implements collate function mainly for seq2seq tasks
This expects each sample to contain feature (src_tokens) and
targets.
This collator is also used for aligned training task.
Implements collate function mainly for seq2seq tasks
This expects each sample to contain feature (src_tokens) and
targets.
This collator is also used for aligned training task.
"""

def __init__(

@ -6,52 +6,74 @@

from __future__ import absolute_import, division, print_function, unicode_literals

from collections import namedtuple
import concurrent.futures
from itertools import chain
import argparse
import os
import concurrent.futures
import json
import sentencepiece as spm
import multiprocessing
import os
from collections import namedtuple
from itertools import chain

import sentencepiece as spm
from fairseq.data import Dictionary


MILLISECONDS_TO_SECONDS = 0.001


def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
import torchaudio

input = {}
output = {}
si, ei = torchaudio.info(aud_path)
input["length_ms"] = int(si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS)
input["length_ms"] = int(
si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
)
input["path"] = aud_path

token = " ".join(sp.EncodeAsPieces(lable))
ids = tgt_dict.encode_line(token, append_eos=False)
output["text"] = lable
output["token"] = token
output["tokenid"] = ', '.join(map(str, [t.tolist() for t in ids]))
output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
return {utt_id: {"input": input, "output": output}}


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--audio-dirs", nargs="+", default=['-'], required=True,
help="input directories with audio files")
parser.add_argument("--labels", required=True,
help="aggregated input labels with format <ID LABEL> per line",
type=argparse.FileType('r', encoding='UTF-8'))
parser.add_argument("--spm-model", required=True,
help="sentencepiece model to use for encoding",
type=argparse.FileType('r', encoding='UTF-8'))
parser.add_argument("--dictionary", required=True,
help="file to load fairseq dictionary from",
type=argparse.FileType('r', encoding='UTF-8'))
parser.add_argument(
"--audio-dirs",
nargs="+",
default=["-"],
required=True,
help="input directories with audio files",
)
parser.add_argument(
"--labels",
required=True,
help="aggregated input labels with format <ID LABEL> per line",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument(
"--spm-model",
required=True,
help="sentencepiece model to use for encoding",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument(
"--dictionary",
required=True,
help="file to load fairseq dictionary from",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
parser.add_argument("--output", required=True, type=argparse.FileType('w'),
help="path to save json output")
parser.add_argument(
"--output",
required=True,
type=argparse.FileType("w"),
help="path to save json output",
)
args = parser.parse_args()

sp = spm.SentencePieceProcessor()
@ -64,15 +86,17 @@ def main():
(utt_id, label) = line.split(" ", 1)
labels[utt_id] = label
if len(labels) == 0:
raise Exception('No labels found in ', args.labels_path)
raise Exception("No labels found in ", args.labels_path)

Sample = namedtuple('Sample', 'aud_path utt_id')
Sample = namedtuple("Sample", "aud_path utt_id")
samples = []
for path, _, files in chain.from_iterable(os.walk(path) for path in args.audio_dirs):
for path, _, files in chain.from_iterable(
os.walk(path) for path in args.audio_dirs
):
for f in files:
if f.endswith(args.audio_format):
if len(os.path.splitext(f)) != 2:
raise Exception('Expect <utt_id.extension> file name. Got: ', f)
raise Exception("Expect <utt_id.extension> file name. Got: ", f)
utt_id = os.path.splitext(f)[0]
if utt_id not in labels:
continue
@ -81,12 +105,17 @@ def main():
utts = {}
num_cpu = multiprocessing.cpu_count()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
future_to_sample = {executor.submit(process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict): s for s in samples}
future_to_sample = {
executor.submit(
process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
): s
for s in samples
}
for future in concurrent.futures.as_completed(future_to_sample):
try:
data = future.result()
except Exception as exc:
print('generated an exception: ', exc)
print("generated an exception: ", exc)
else:
utts.update(data)
json.dump({"utts": utts}, args.output, indent=4)

@ -8,17 +8,17 @@
Run inference for pre-processed data with a trained model.
"""

import editdistance
import logging
import math
import os
import sys

import editdistance
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, utils, tasks
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter


logging.basicConfig()
@ -52,10 +52,12 @@ output units",
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
)
parser.add_argument(
"--w2l-decoder", choices=["viterbi", "kenlm", "fairseqlm"], help="use a w2l decoder"
"--w2l-decoder",
choices=["viterbi", "kenlm", "fairseqlm"],
help="use a w2l decoder",
)
parser.add_argument("--lexicon", help="lexicon for w2l decoder")
parser.add_argument("--unit-lm", action='store_true', help="if using a unit lm")
parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
parser.add_argument("--beam-threshold", type=float, default=25.0)
parser.add_argument("--beam-size-token", type=float, default=100)
@ -87,10 +89,10 @@ def check_args(args):
# assert args.path is not None, "--path required for generation!"
# assert args.results_path is not None, "--results_path required for generation!"
assert (
not args.sampling or args.nbest == args.beam
not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
args.replace_unk is None or args.raw_text
args.replace_unk is None or args.raw_text
), "--replace-unk requires a raw text dataset (--raw-text)"


@ -110,7 +112,7 @@ def get_dataset_itr(args, task, models):


def process_predictions(
args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
for hypo in hypos[: min(len(hypos), args.nbest)]:
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
@ -122,16 +124,25 @@ def process_predictions(

if res_files is not None:
print(
"{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"]
"{} ({}-{})".format(hyp_pieces, speaker, id),
file=res_files["hypo.units"],
)
print(
"{} ({}-{})".format(hyp_words, speaker, id),
file=res_files["hypo.words"],
)
print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"])

tgt_pieces = tgt_dict.string(target_tokens)
tgt_words = post_process(tgt_pieces, args.remove_bpe)

if res_files is not None:
print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"])
print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"])
print(
"{} ({}-{})".format(tgt_pieces, speaker, id),
file=res_files["ref.units"],
)
print(
"{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
)
# only score top hypothesis
if not args.quiet:
logger.debug("HYPO:" + hyp_words)
@ -146,7 +157,7 @@ def process_predictions(
def prepare_result_files(args):
def get_res_file(file_prefix):
if args.num_shards > 1:
file_prefix = f'{args.shard_id}_{file_prefix}'
file_prefix = f"{args.shard_id}_{file_prefix}"
path = os.path.join(
args.results_path,
"{}-{}-{}.txt".format(
@ -166,15 +177,17 @@ def prepare_result_files(args):
}


def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=None, model_state=None):
def load_models_and_criterions(
filenames, data_path, arg_overrides=None, task=None, model_state=None
):
models = []
criterions = []

if arg_overrides is None:
arg_overrides = {}

arg_overrides['wer_args'] = None
arg_overrides['data'] = data_path
arg_overrides["wer_args"] = None
arg_overrides["data"] = data_path

if filenames is None:
assert model_state is not None
@ -205,8 +218,7 @@ def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=No


def optimize_models(args, use_cuda, models):
"""Optimize ensemble for generation
"""
"""Optimize ensemble for generation"""
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
@ -229,7 +241,7 @@ class ExistingEmissionsDecoder(object):
emissions = np.stack(self.emissions[ids])
except:
print([x.shape for x in self.emissions[ids]])
raise Exception('invalid sizes')
raise Exception("invalid sizes")
emissions = torch.from_numpy(emissions)
return self.decoder.decode(emissions)

@ -300,7 +312,9 @@ def main(args, task=None, model_state=None):

return W2lFairseqLMDecoder(args, task.target_dictionary)
else:
print('only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment')
print(
"only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
)

# please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
generator = build_generator(args)
@ -361,7 +375,11 @@ def main(args, task=None, model_state=None):
encoder_out = models[0](**sample["net_input"])
feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
for i, id in enumerate(sample["id"]):
padding = encoder_out["encoder_padding_mask"][i].cpu().numpy() if encoder_out["encoder_padding_mask"] is not None else None
padding = (
encoder_out["encoder_padding_mask"][i].cpu().numpy()
if encoder_out["encoder_padding_mask"] is not None
else None
)
features[id.item()] = (feat[i], padding)
continue
hypos = task.inference_step(generator, models, sample, prefix_tokens)
@ -372,20 +390,31 @@ def main(args, task=None, model_state=None):
speaker = None
# id = task.dataset(args.gen_subset).ids[int(sample_id)]
id = sample_id
toks = sample["target"][i, :] if 'target_label' not in sample else sample["target_label"][i, :]
target_tokens = (
utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
toks = (
sample["target"][i, :]
if "target_label" not in sample
else sample["target_label"][i, :]
)
target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
# Process top predictions
errs, length = process_predictions(
args, hypos[i], None, tgt_dict, target_tokens, res_files, speaker, id
args,
hypos[i],
None,
tgt_dict,
target_tokens,
res_files,
speaker,
id,
)
errs_t += errs
lengths_t += length

wps_meter.update(num_generated_tokens)
t.log({"wps": round(wps_meter.avg)})
num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
num_sentences += (
sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
)

wer = None
if args.dump_emissions:
@ -413,7 +442,7 @@ def main(args, task=None, model_state=None):
gen_timer.sum,
num_sentences / gen_timer.sum,
1.0 / gen_timer.avg,
)
)
)
logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
return task, wer
@ -424,6 +453,7 @@ def make_parser():
parser = add_asr_eval_argument(parser)
return parser


def cli_main():
parser = make_parser()
args = options.parse_args_and_arch(parser)

@ -1,7 +1,8 @@
import importlib
import os


for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
model_name = file[:file.find('.py')]
importlib.import_module('examples.speech_recognition.models.' + model_name)
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module("examples.speech_recognition.models." + model_name)

@ -9,18 +9,22 @@ from collections.abc import Iterable

import torch
import torch.nn as nn
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqEncoderModel,
FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules import LinearizedConvolution
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VGGBlock
from fairseq.modules import (
LinearizedConvolution,
TransformerDecoderLayer,
TransformerEncoderLayer,
VGGBlock,
)


@register_model("asr_vggtransformer")
@ -29,6 +33,7 @@ class VGGTransformerModel(FairseqEncoderDecoderModel):
Transformers with convolutional context for ASR
https://arxiv.org/abs/1904.11660
"""

def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)

@ -602,18 +607,22 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.layers = nn.ModuleList()
if conv_config[-1][0] != transformer_config[0][0]:
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
self.layers.append(TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[0])
))
self.layers.append(
TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[0])
)
)

for i in range(1, len(transformer_config)):
if transformer_config[i - 1][0] != transformer_config[i][0]:
self.layers.append(
Linear(transformer_config[i - 1][0], transformer_config[i][0])
)
self.layers.append(TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[i])
))
self.layers.append(
TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[i])
)
)
self.fc_out = Linear(transformer_config[-1][0], vocab_size)

def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
@ -713,6 +722,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
x = x.transpose(0, 1)
return x


@register_model("asr_vggtransformer_encoder")
class VGGTransformerEncoderModel(FairseqEncoderModel):
def __init__(self, encoder):

@ -10,7 +10,6 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,

@ -1,7 +1,8 @@
import importlib
import os


for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
task_name = file[:file.find('.py')]
importlib.import_module('examples.speech_recognition.tasks.' + task_name)
if file.endswith(".py") and not file.startswith("_"):
task_name = file[: file.find(".py")]
importlib.import_module("examples.speech_recognition.tasks." + task_name)

@ -9,10 +9,10 @@ import re
import sys

import torch
from fairseq.data import Dictionary
from fairseq.tasks import register_task, LegacyFairseqTask
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol
from fairseq.data import Dictionary
from fairseq.tasks import LegacyFairseqTask, register_task


def get_asr_dataset_from_json(data_json_path, tgt_dict):
@ -78,10 +78,20 @@ class SpeechRecognitionTask(LegacyFairseqTask):
parser.add_argument(
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
)
parser.add_argument('--max-source-positions', default=sys.maxsize, type=int, metavar='N',
help='max number of frames in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
parser.add_argument(
"--max-source-positions",
default=sys.maxsize,
type=int,
metavar="N",
help="max number of frames in the source sequence",
)
parser.add_argument(
"--max-target-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)

def __init__(self, args, tgt_dict):
super().__init__(args)

@ -9,16 +9,18 @@
Wav2letter decoders.
"""

from collections import namedtuple, deque
import gc
import itertools as it
import numpy as np
import torch
import os.path as osp
import warnings
from collections import deque, namedtuple

import numpy as np
import torch
from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks
from fairseq.utils import apply_to_sample
from examples.speech_recognition.data.replabels import unpack_replabels


try:
from wav2letter.common import create_word_dict, load_words

@ -4,66 +4,76 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from multiprocessing import cpu_count
import csv
import os
import os.path as op
from glob import glob
import zipfile
import csv
from functools import reduce
from typing import Dict, Any, List
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
from glob import glob
from multiprocessing import cpu_count
from typing import Any, Dict, List

import sentencepiece as sp
from tqdm import tqdm
import numpy as np

import sentencepiece as sp
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
from tqdm import tqdm

UNK_TOKEN, UNK_TOKEN_ID = '<unk>', 3
BOS_TOKEN, BOS_TOKEN_ID = '<s>', 0
EOS_TOKEN, EOS_TOKEN_ID = '</s>', 2
PAD_TOKEN, PAD_TOKEN_ID = '<pad>', 1

UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1


def gen_vocab(
input_path: str, output_path_prefix: str, model_type='bpe',
vocab_size=1000,
input_path: str,
output_path_prefix: str,
model_type="bpe",
vocab_size=1000,
):
# Train SentencePiece Model
arguments = [
f'--input={input_path}',
f'--model_prefix={output_path_prefix}',
f'--model_type={model_type}',
f'--vocab_size={vocab_size}',
'--character_coverage=1.0',
f'--num_threads={cpu_count()}',
f'--unk_id={UNK_TOKEN_ID}',
f'--bos_id={BOS_TOKEN_ID}',
f'--eos_id={EOS_TOKEN_ID}',
f'--pad_id={PAD_TOKEN_ID}'
f"--input={input_path}",
f"--model_prefix={output_path_prefix}",
f"--model_type={model_type}",
f"--vocab_size={vocab_size}",
"--character_coverage=1.0",
f"--num_threads={cpu_count()}",
f"--unk_id={UNK_TOKEN_ID}",
f"--bos_id={BOS_TOKEN_ID}",
f"--eos_id={EOS_TOKEN_ID}",
f"--pad_id={PAD_TOKEN_ID}",
]
sp.SentencePieceTrainer.Train(' '.join(arguments))
sp.SentencePieceTrainer.Train(" ".join(arguments))
# Export fairseq dictionary
spm = sp.SentencePieceProcessor()
spm.Load(output_path_prefix + '.model')
spm.Load(output_path_prefix + ".model")
vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
assert vocab.get(UNK_TOKEN_ID) == UNK_TOKEN and \
vocab.get(PAD_TOKEN_ID) == PAD_TOKEN and \
vocab.get(BOS_TOKEN_ID) == BOS_TOKEN and \
vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
assert (
vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
)
vocab = {
i: s for i, s in vocab.items()
i: s
for i, s in vocab.items()
if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
}
with open(output_path_prefix + '.txt', 'w') as f_out:
with open(output_path_prefix + ".txt", "w") as f_out:
for _, s in sorted(vocab.items(), key=lambda x: x[0]):
f_out.write(f'{s} 1\n')
f_out.write(f"{s} 1\n")


def extract_fbank_features(waveform, sample_rate, output_path=None,
n_mel_bins=80, apply_utterance_cmvn=True,
overwrite=False):
def extract_fbank_features(
waveform,
sample_rate,
output_path=None,
n_mel_bins=80,
apply_utterance_cmvn=True,
overwrite=False,
):
if output_path is not None and op.exists(output_path) and not overwrite:
return

@ -74,8 +84,10 @@ def extract_fbank_features(waveform, sample_rate, output_path=None,
if features is None:
features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
if features is None:
raise ImportError('Please install pyKaldi or torchaudio to enable '
'online filterbank feature extraction')
raise ImportError(
"Please install pyKaldi or torchaudio to enable "
"online filterbank feature extraction"
)

if apply_utterance_cmvn:
cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
@ -89,8 +101,8 @@ def extract_fbank_features(waveform, sample_rate, output_path=None,
def create_zip(data_root, zip_path):
cwd = os.path.abspath(os.curdir)
os.chdir(data_root)
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_STORED) as f:
for filename in tqdm(glob('*.npy')):
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
for filename in tqdm(glob("*.npy")):
f.write(filename)
os.chdir(cwd)

@ -101,69 +113,80 @@ def is_npy_data(data: bytes) -> bool:

def get_zip_manifest(zip_root, zip_filename):
zip_path = op.join(zip_root, zip_filename)
with zipfile.ZipFile(zip_path, mode='r') as f:
with zipfile.ZipFile(zip_path, mode="r") as f:
info = f.infolist()
manifest = {}
for i in tqdm(info):
utt_id = op.splitext(i.filename)[0]
offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
manifest[utt_id] = f'{zip_filename}:{offset}:{file_size}'
with open(zip_path, 'rb') as f:
manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}"
with open(zip_path, "rb") as f:
f.seek(offset)
data = f.read(file_size)
assert len(data) > 1 and is_npy_data(data)
return manifest


def gen_config_yaml(data_root, spm_filename, yaml_filename='config.yaml',
specaugment_policy='lb'):
assert specaugment_policy in {'lb', 'ld'}
def gen_config_yaml(
data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb"
):
assert specaugment_policy in {"lb", "ld"}
data_root = op.abspath(data_root)
writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
writer.set_audio_root(op.abspath(data_root))
writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
writer.set_input_channels(1)
writer.set_input_feat_per_channel(80)
if specaugment_policy == 'lb':
if specaugment_policy == "lb":
writer.set_specaugment_lb_policy()
else:
writer.set_specaugment_ld_policy()
writer.set_bpe_tokenizer(
{'bpe': 'sentencepiece',
'sentencepiece_model': op.join(data_root, spm_filename)}
{
"bpe": "sentencepiece",
"sentencepiece_model": op.join(data_root, spm_filename),
}
)
writer.set_feature_transforms('_train', ['specaugment'])
writer.set_feature_transforms("_train", ["specaugment"])
writer.flush()


def save_df_to_tsv(dataframe, path):
dataframe.to_csv(path, sep="\t", header=True, index=False, encoding="utf-8",
escapechar='\\', quoting=csv.QUOTE_NONE)
dataframe.to_csv(
path,
sep="\t",
header=True,
index=False,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
)


def filter_manifest_df(df, is_train_split=False, extra_filters=None,
min_n_frames=5, max_n_frames=3000):
def filter_manifest_df(
df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
filters = {
'no speech': df['audio'] == '',
f'short speech (<{min_n_frames} frames)': df['n_frames'] < min_n_frames,
'empty sentence': df['tgt_text'] == '',
"no speech": df["audio"] == "",
f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
"empty sentence": df["tgt_text"] == "",
}
if is_train_split:
filters[f'long speech (>{max_n_frames} frames)'] = \
df['n_frames'] > max_n_frames
filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
if extra_filters is not None:
filters.update(extra_filters)
invalid = reduce(lambda x, y: x | y, filters.values())
valid = ~invalid
print(
'| ' + ', '.join(f'{n}: {f.sum()}' for n, f in filters.items()) +
f', total {invalid.sum()} filtered, {valid.sum()} remained.'
"| "
+ ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
+ f", total {invalid.sum()} filtered, {valid.sum()} remained."
)
return df[valid]


class S2TDataConfigWriter(object):
DEFAULT_VOCAB_FILENAME = 'dict.txt'
DEFAULT_VOCAB_FILENAME = "dict.txt"
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
DEFAULT_INPUT_CHANNELS = 1

@ -171,48 +194,69 @@ class S2TDataConfigWriter(object):
try:
import yaml
except ImportError:
print('Please install PyYAML to load YAML files for S2T data config')
print("Please install PyYAML to load YAML files for S2T data config")
self.yaml = yaml
self.yaml_path = yaml_path
self.config = {}

def flush(self):
with open(self.yaml_path, 'w') as f:
with open(self.yaml_path, "w") as f:
self.yaml.dump(self.config, f)

def set_audio_root(self, audio_root=''):
self.config['audio_root'] = audio_root
def set_audio_root(self, audio_root=""):
self.config["audio_root"] = audio_root

def set_vocab_filename(self, vocab_filename='dict.txt'):
self.config['vocab_filename'] = vocab_filename
def set_vocab_filename(self, vocab_filename="dict.txt"):
self.config["vocab_filename"] = vocab_filename

def set_specaugment(self, time_wrap_w: int, freq_mask_n: int,
freq_mask_f: int, time_mask_n: int, time_mask_t: int,
time_mask_p: float):
self.config['specaugment'] = {
'time_wrap_W': time_wrap_w, 'freq_mask_N': freq_mask_n,
'freq_mask_F': freq_mask_f, 'time_mask_N': time_mask_n,
'time_mask_T': time_mask_t, 'time_mask_p': time_mask_p,
def set_specaugment(
self,
time_wrap_w: int,
freq_mask_n: int,
freq_mask_f: int,
time_mask_n: int,
time_mask_t: int,
time_mask_p: float,
):
self.config["specaugment"] = {
"time_wrap_W": time_wrap_w,
"freq_mask_N": freq_mask_n,
"freq_mask_F": freq_mask_f,
"time_mask_N": time_mask_n,
"time_mask_T": time_mask_t,
"time_mask_p": time_mask_p,
}

def set_specaugment_lb_policy(self):
self.set_specaugment(time_wrap_w=0, freq_mask_n=1, freq_mask_f=27,
time_mask_n=1, time_mask_t=100, time_mask_p=1.0)
self.set_specaugment(
time_wrap_w=0,
freq_mask_n=1,
freq_mask_f=27,
time_mask_n=1,
time_mask_t=100,
time_mask_p=1.0,
)

def set_specaugment_ld_policy(self):
self.set_specaugment(time_wrap_w=0, freq_mask_n=2, freq_mask_f=27,
time_mask_n=2, time_mask_t=100, time_mask_p=1.0)
self.set_specaugment(
time_wrap_w=0,
freq_mask_n=2,
freq_mask_f=27,
time_mask_n=2,
time_mask_t=100,
time_mask_p=1.0,
)

def set_input_channels(self, input_channels=1):
self.config['input_channels'] = input_channels
self.config["input_channels"] = input_channels

def set_input_feat_per_channel(self, input_feat_per_channel=80):
self.config['input_feat_per_channel'] = input_feat_per_channel
self.config["input_feat_per_channel"] = input_feat_per_channel

def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
self.config['bpe_tokenizer'] = bpe_tokenizer
self.config["bpe_tokenizer"] = bpe_tokenizer

def set_feature_transforms(self, split, transforms: List[str]):
if 'transforms' not in self.config:
self.config['transforms'] = {}
self.config['transforms'][split] = transforms
if "transforms" not in self.config:
self.config["transforms"] = {}
self.config["transforms"][split] = transforms

@ -5,30 +5,35 @@
# LICENSE file in the root directory of this source tree.

import argparse
import csv
import logging
from tempfile import NamedTemporaryFile
import os
import os.path as op
import shutil
from typing import Tuple, Optional
import csv
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple

import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
from torch import Tensor

from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
extract_fbank_features, gen_config_yaml, filter_manifest_df
)

log = logging.getLogger(__name__)


MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]


class CoVoST(Dataset):
@ -44,40 +49,82 @@ class CoVoST(Dataset):
found at root path. (default: ``False``).
"""

CV_URL_TEMPLATE = "https://voice-prod-bundler-ee1969a6ce8178826482b88" \
"e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
COVOST_URL_TEMPLATE = "https://dl.fbaipublicfiles.com/covost/" \
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
CV_URL_TEMPLATE = (
"https://voice-prod-bundler-ee1969a6ce8178826482b88"
"e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
)
COVOST_URL_TEMPLATE = (
"https://dl.fbaipublicfiles.com/covost/"
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
)

VERSIONS = {2}
SPLITS = ['train', 'dev', 'test']
SPLITS = ["train", "dev", "test"]

CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}

XX_EN_LANGUAGES = {
1: ['fr', 'de', 'nl', 'ru', 'es', 'it', 'tr', 'fa', 'sv-SE', 'mn',
'zh-CN'],
2: ['fr', 'de', 'es', 'ca', 'it', 'ru', 'zh-CN', 'pt', 'fa', 'et', 'mn',
'nl', 'tr', 'ar', 'sv-SE', 'lv', 'sl', 'ta', 'ja', 'id', 'cy']
1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
2: [
"fr",
"de",
"es",
"ca",
"it",
"ru",
"zh-CN",
"pt",
"fa",
"et",
"mn",
"nl",
"tr",
"ar",
"sv-SE",
"lv",
"sl",
"ta",
"ja",
"id",
"cy",
],
}
EN_XX_LANGUAGES = {
1: [],
2: ['de', 'tr', 'fa', 'sv-SE', 'mn', 'zh-CN', 'cy', 'ca', 'sl', 'et',
'id',
'ar', 'ta', 'lv', 'ja']
2: [
"de",
"tr",
"fa",
"sv-SE",
"mn",
"zh-CN",
"cy",
"ca",
"sl",
"et",
"id",
"ar",
"ta",
"lv",
"ja",
],
}

def __init__(
self, root: str, split: str, source_language: str,
target_language: Optional[str] = None, version: int = 2,
download: bool = False
self,
root: str,
split: str,
source_language: str,
target_language: Optional[str] = None,
version: int = 2,
download: bool = False,
) -> None:
assert version in self.VERSIONS and split in self.SPLITS
assert source_language is not None
self.no_translation = (target_language is None)
self.no_translation = target_language is None
if not self.no_translation:
assert 'en' in {source_language, target_language}
if source_language == 'en':
assert "en" in {source_language, target_language}
if source_language == "en":
assert target_language in self.EN_XX_LANGUAGES[version]
else:
assert source_language in self.XX_EN_LANGUAGES[version]
@ -85,51 +132,60 @@ class CoVoST(Dataset):
# Hack here so that we can get "split" column from CoVoST TSV.
# Note that we use CoVoST train split for ASR which is an extension
# to Common Voice train split.
target_language = 'de' if source_language == 'en' else 'en'
target_language = "de" if source_language == "en" else "en"

self.root = os.path.join(root, 'raw')
self.root = os.path.join(root, "raw")
os.makedirs(self.root, exist_ok=True)

cv_url = self.CV_URL_TEMPLATE.format(ver=self.CV_VERSION_ID[version],
lang=source_language)
cv_url = self.CV_URL_TEMPLATE.format(
ver=self.CV_VERSION_ID[version], lang=source_language
)
cv_archive = os.path.join(self.root, os.path.basename(cv_url))
if download:
if not os.path.isfile(cv_archive):
download_url(cv_url, self.root, hash_value=None)
extract_archive(cv_archive)

covost_url = self.COVOST_URL_TEMPLATE.format(src_lang=source_language,
tgt_lang=target_language)
covost_url = self.COVOST_URL_TEMPLATE.format(
src_lang=source_language, tgt_lang=target_language
)
covost_archive = os.path.join(self.root, os.path.basename(covost_url))
if download:
if not os.path.isfile(covost_archive):
download_url(covost_url, self.root, hash_value=None)
extract_archive(covost_archive)

cv_tsv = self.load_from_tsv(os.path.join(self.root, 'validated.tsv'))
cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv"))
covost_tsv = self.load_from_tsv(
os.path.join(self.root,
os.path.basename(covost_url).replace('.tar.gz', ''))
os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", ""))
)
df = pd.merge(left=cv_tsv[['path', 'sentence', 'client_id']],
right=covost_tsv[['path', 'translation', 'split']],
how='inner', on='path')
if split == 'train':
df = df[(df['split'] == split) | (df['split'] == f'{split}_covost')]
df = pd.merge(
left=cv_tsv[["path", "sentence", "client_id"]],
right=covost_tsv[["path", "translation", "split"]],
how="inner",
on="path",
)
if split == "train":
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
else:
df = df[df['split'] == split]
self.data = df.to_dict(orient='index').items()
df = df[df["split"] == split]
self.data = df.to_dict(orient="index").items()
self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]

@classmethod
def load_from_tsv(cls, path: str):
return pd.read_csv(
path, sep='\t', header=0, encoding='utf-8', escapechar='\\',
quoting=csv.QUOTE_NONE, na_filter=False
path,
sep="\t",
header=0,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
na_filter=False,
)

def __getitem__(
self, n: int
self, n: int
) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset.

@ -141,12 +197,12 @@ class CoVoST(Dataset):
sample_id)``
"""
data = self.data[n]
path = os.path.join(self.root, 'clips', data['path'])
path = os.path.join(self.root, "clips", data["path"])
waveform, sample_rate = torchaudio.load(path)
sentence = data['sentence']
translation = None if self.no_translation else data['translation']
speaker_id = data['client_id']
_id = data['path'].replace('.mp3', '')
sentence = data["sentence"]
translation = None if self.no_translation else data["translation"]
speaker_id = data["client_id"]
_id = data["path"].replace(".mp3", "")
return waveform, sample_rate, sentence, translation, speaker_id, _id

def __len__(self) -> int:
@ -157,76 +213,82 @@ def process(args):
root = op.join(args.data_root, args.src_lang)
os.makedirs(root, exist_ok=True)
# Extract features
feature_root = op.join(root, 'fbank80')
feature_root = op.join(root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in CoVoST.SPLITS:
print(f'Fetching split {split}...')
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang,
download=True)
print('Extracting log mel filter bank features...')
print(f"Fetching split {split}...")
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True)
print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(waveform, sample_rate,
op.join(feature_root, f'{utt_id}.npy'))
extract_fbank_features(
waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
)
# Pack features into ZIP
zip_filename = 'fbank80.zip'
zip_filename = "fbank80.zip"
zip_path = op.join(root, zip_filename)
print('ZIPing features...')
print("ZIPing features...")
create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...')
zip_manifest = get_zip_manifest(args.data_root,
f'{args.src_lang}/{zip_filename}')
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}")
# Generate TSV manifest
print('Generating manifest...')
print("Generating manifest...")
train_text = []
task = f'asr_{args.src_lang}'
task = f"asr_{args.src_lang}"
if args.tgt_lang is not None:
task = f'st_{args.src_lang}_{args.tgt_lang}'
task = f"st_{args.src_lang}_{args.tgt_lang}"
for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest['id'].append(utt_id)
manifest['audio'].append(zip_manifest[utt_id])
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
manifest['tgt_text'].append(
src_utt if args.tgt_lang is None else tgt_utt
)
manifest['speaker'].append(speaker_id)
is_train_split = split.startswith('train')
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
manifest["speaker"].append(speaker_id)
is_train_split = split.startswith("train")
if is_train_split:
train_text.extend(manifest['tgt_text'])
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, op.join(root, f'{split}_{task}.tsv'))
save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv"))
# Generate vocab
vocab_size_str = '' if args.vocab_type == 'char' else str(args.vocab_size)
spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size_str}_{task}'
with NamedTemporaryFile(mode='w') as f:
vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + '\n')
gen_vocab(f.name, op.join(root, spm_filename_prefix),
args.vocab_type, args.vocab_size)
f.write(t + "\n")
gen_vocab(
f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size
)
# Generate config YAML
gen_config_yaml(root, spm_filename_prefix + '.model',
yaml_filename=f'config_{task}.yaml',
specaugment_policy='lb')
gen_config_yaml(
root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{task}.yaml",
specaugment_policy="lb",
)
# Clean up
shutil.rmtree(feature_root)


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data-root', '-d', required=True, type=str)
parser.add_argument('--vocab-type', default='unigram', required=True,
type=str, choices=['bpe', 'unigram', 'char']),
parser.add_argument('--vocab-size', default=1000, type=int)
parser.add_argument('--src-lang', '-s', required=True, type=str)
parser.add_argument('--tgt-lang', '-t', type=str)
parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=1000, type=int)
parser.add_argument("--src-lang", "-s", required=True, type=str)
parser.add_argument("--tgt-lang", "-t", type=str)
args = parser.parse_args()

process(args)


if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -6,91 +6,114 @@

import argparse
import logging
from tempfile import NamedTemporaryFile
import os
import shutil
import os.path as op
import shutil
from tempfile import NamedTemporaryFile

from tqdm import tqdm
from torchaudio.datasets import LIBRISPEECH
import pandas as pd

from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
extract_fbank_features, gen_config_yaml
create_zip,
extract_fbank_features,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torchaudio.datasets import LIBRISPEECH
from tqdm import tqdm


log = logging.getLogger(__name__)

SPLITS = ['train-clean-100', 'train-clean-360', 'train-other-500', 'dev-clean',
'dev-other', 'test-clean', 'test-other']
SPLITS = [
"train-clean-100",
"train-clean-360",
"train-other-500",
"dev-clean",
"dev-other",
"test-clean",
"test-other",
]

MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]


def process(args):
os.makedirs(args.output_root, exist_ok=True)
# Extract features
feature_root = op.join(args.output_root, 'fbank80')
feature_root = op.join(args.output_root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in SPLITS:
print(f'Fetching split {split}...')
print(f"Fetching split {split}...")
dataset = LIBRISPEECH(args.output_root, url=split, download=True)
print('Extracting log mel filter bank features...')
print("Extracting log mel filter bank features...")
for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
extract_fbank_features(wav, sample_rate,
op.join(feature_root, f'{sample_id}.npy'))
sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
extract_fbank_features(
wav, sample_rate, op.join(feature_root, f"{sample_id}.npy")
)
# Pack features into ZIP
zip_filename = 'fbank80.zip'
zip_filename = "fbank80.zip"
zip_path = op.join(args.output_root, zip_filename)
print('ZIPing features...')
print("ZIPing features...")
create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...')
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.output_root, zip_filename)
# Generate TSV manifest
print('Generating manifest...')
print("Generating manifest...")
train_text = []
for split in SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = LIBRISPEECH(args.output_root, url=split)
for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
manifest['id'].append(sample_id)
manifest['audio'].append(zip_manifest[sample_id])
sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
manifest["id"].append(sample_id)
manifest["audio"].append(zip_manifest[sample_id])
duration_ms = int(wav.size(1) / sample_rate * 1000)
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
manifest['tgt_text'].append(utt)
manifest['speaker'].append(spk_id)
save_df_to_tsv(pd.DataFrame.from_dict(manifest),
op.join(args.output_root, f'{split}.tsv'))
if split.startswith('train'):
train_text.extend(manifest['tgt_text'])
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest["tgt_text"].append(utt)
manifest["speaker"].append(spk_id)
save_df_to_tsv(
pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv")
)
if split.startswith("train"):
train_text.extend(manifest["tgt_text"])
# Generate vocab
vocab_size = '' if args.vocab_type == 'char' else str(args.vocab_size)
spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size}'
with NamedTemporaryFile(mode='w') as f:
vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + '\n')
gen_vocab(f.name, op.join(args.output_root, spm_filename_prefix),
args.vocab_type, args.vocab_size)
f.write(t + "\n")
gen_vocab(
f.name,
op.join(args.output_root, spm_filename_prefix),
args.vocab_type,
args.vocab_size,
)
# Generate config YAML
gen_config_yaml(args.output_root, spm_filename_prefix + '.model',
specaugment_policy='ld')
gen_config_yaml(
args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld"
)
# Clean up
shutil.rmtree(feature_root)


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--output-root', '-o', required=True, type=str)
parser.add_argument('--vocab-type', default='unigram', required=True,
type=str, choices=['bpe', 'unigram', 'char']),
parser.add_argument('--vocab-size', default=10000, type=int)
parser.add_argument("--output-root", "-o", required=True, type=str)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=10000, type=int)
args = parser.parse_args()

process(args)


if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -6,29 +6,34 @@

import argparse
import logging
from tempfile import NamedTemporaryFile
import os
import os.path as op
import shutil
from typing import Tuple
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple

from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
from torch import Tensor

from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
extract_fbank_features, gen_config_yaml, filter_manifest_df
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from tqdm import tqdm


log = logging.getLogger(__name__)


MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
TASKS = ['asr', 'st']
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
TASKS = ["asr", "st"]


class MUSTC(Dataset):
@ -37,49 +42,55 @@ class MUSTC(Dataset):
waveform, sample_rate, source utterance, target utterance, speaker_id,
utterance_id
"""
SPLITS = ['train', 'dev', 'tst-COMMON', 'tst-HE']
LANGUAGES = ['de', 'es', 'fr', 'it', 'nl', 'pt', 'ro', 'ru']

SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]

def __init__(self, root: str, lang: str, split: str) -> None:
assert split in self.SPLITS and lang in self.LANGUAGES
_root = op.join(root, f'en-{lang}', 'data', split)
wav_root, txt_root = op.join(_root, 'wav'), op.join(_root, 'txt')
_root = op.join(root, f"en-{lang}", "data", split)
wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt")
assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
# Load audio segments
try:
import yaml
except ImportError:
print('Please install PyYAML to load YAML files for '
'the MuST-C dataset')
with open(op.join(txt_root, f'{split}.yaml')) as f:
print("Please install PyYAML to load YAML files for " "the MuST-C dataset")
with open(op.join(txt_root, f"{split}.yaml")) as f:
segments = yaml.load(f, Loader=yaml.BaseLoader)
# Load source and target utterances
for _lang in ['en', lang]:
with open(op.join(txt_root, f'{split}.{_lang}')) as f:
for _lang in ["en", lang]:
with open(op.join(txt_root, f"{split}.{_lang}")) as f:
utterances = [r.strip() for r in f]
assert len(segments) == len(utterances)
for i, u in enumerate(utterances):
segments[i][_lang] = u
# Gather info
self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x['wav']):
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = op.join(wav_root, wav_filename)
sample_rate = torchaudio.info(wav_path)[0].rate
seg_group = sorted(_seg_group, key=lambda x: x['offset'])
seg_group = sorted(_seg_group, key=lambda x: x["offset"])
for i, segment in enumerate(seg_group):
offset = int(float(segment['offset']) * sample_rate)
n_frames = int(float(segment['duration']) * sample_rate)
_id = f'{op.splitext(wav_filename)[0]}_{i}'
offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment["duration"]) * sample_rate)
_id = f"{op.splitext(wav_filename)[0]}_{i}"
self.data.append(
(wav_path, offset, n_frames, sample_rate, segment['en'],
segment[lang], segment['speaker_id'], _id)
(
wav_path,
offset,
n_frames,
sample_rate,
segment["en"],
segment[lang],
segment["speaker_id"],
_id,
)
)

def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
|
||||
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = \
|
||||
self.data[n]
|
||||
waveform, _ = torchaudio.load(wav_path, offset=offset,
|
||||
num_frames=n_frames)
|
||||
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
|
||||
waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
|
||||
return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
|
||||
|
||||
def __len__(self) -> int:
|
||||
@ -88,85 +99,102 @@ class MUSTC(Dataset):
|
||||
|
||||
def process(args):
|
||||
for lang in MUSTC.LANGUAGES:
|
||||
cur_root = op.join(args.data_root, f'en-{lang}')
|
||||
cur_root = op.join(args.data_root, f"en-{lang}")
|
||||
if not op.isdir(cur_root):
|
||||
print(f'{cur_root} does not exist. Skipped.')
|
||||
print(f"{cur_root} does not exist. Skipped.")
|
||||
continue
|
||||
# Extract features
|
||||
feature_root = op.join(cur_root, 'fbank80')
|
||||
feature_root = op.join(cur_root, "fbank80")
|
||||
os.makedirs(feature_root, exist_ok=True)
|
||||
for split in MUSTC.SPLITS:
|
||||
print(f'Fetching split {split}...')
|
||||
print(f"Fetching split {split}...")
|
||||
dataset = MUSTC(args.data_root, lang, split)
|
||||
print('Extracting log mel filter bank features...')
|
||||
print("Extracting log mel filter bank features...")
|
||||
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
|
||||
extract_fbank_features(waveform, sample_rate,
|
||||
op.join(feature_root, f'{utt_id}.npy'))
|
||||
extract_fbank_features(
|
||||
waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
|
||||
)
|
||||
# Pack features into ZIP
|
||||
zip_filename = 'fbank80.zip'
|
||||
zip_filename = "fbank80.zip"
|
||||
zip_path = op.join(cur_root, zip_filename)
|
||||
print('ZIPing features...')
|
||||
print("ZIPing features...")
|
||||
create_zip(feature_root, zip_path)
|
||||
print('Fetching ZIP manifest...')
|
||||
zip_manifest = get_zip_manifest(args.data_root,
|
||||
f'en-{lang}/{zip_filename}')
|
||||
print("Fetching ZIP manifest...")
|
||||
zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}")
|
||||
# Generate TSV manifest
|
||||
print('Generating manifest...')
|
||||
print("Generating manifest...")
|
||||
train_text = {task: [] for task in TASKS}
|
||||
for split in MUSTC.SPLITS:
|
||||
is_train_split = split.startswith('train')
|
||||
is_train_split = split.startswith("train")
|
||||
manifest = {c: [] for c in MANIFEST_COLUMNS}
|
||||
text = {task: [] for task in TASKS}
|
||||
dataset = MUSTC(args.data_root, lang, split)
|
||||
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
|
||||
manifest['id'].append(utt_id)
|
||||
manifest['audio'].append(zip_manifest[utt_id])
|
||||
manifest["id"].append(utt_id)
|
||||
manifest["audio"].append(zip_manifest[utt_id])
|
||||
duration_ms = int(wav.size(1) / sr * 1000)
|
||||
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
|
||||
text['asr'].append(src_utt)
|
||||
text['st'].append(tgt_utt)
|
||||
manifest['speaker'].append(speaker_id)
|
||||
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
|
||||
text["asr"].append(src_utt)
|
||||
text["st"].append(tgt_utt)
|
||||
manifest["speaker"].append(speaker_id)
|
||||
if is_train_split:
|
||||
for task in TASKS:
|
||||
train_text[task].extend(text[task])
|
||||
for task in TASKS:
|
||||
manifest['tgt_text'] = text[task]
|
||||
manifest["tgt_text"] = text[task]
|
||||
df = pd.DataFrame.from_dict(manifest)
|
||||
df = filter_manifest_df(df, is_train_split=is_train_split)
|
||||
save_df_to_tsv(df, op.join(cur_root, f'{split}_{task}.tsv'))
|
||||
save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv"))
|
||||
# Generate vocab
|
||||
for task in TASKS:
|
||||
vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
|
||||
if task == 'st':
|
||||
if task == "st":
|
||||
vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
|
||||
vocab_size_str = '' if vocab_type == 'char' else str(vocab_size)
|
||||
spm_filename_prefix = f'spm_{vocab_type}{vocab_size_str}_{task}'
|
||||
with NamedTemporaryFile(mode='w') as f:
|
||||
vocab_size_str = "" if vocab_type == "char" else str(vocab_size)
|
||||
spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}"
|
||||
with NamedTemporaryFile(mode="w") as f:
|
||||
for t in train_text[task]:
|
||||
f.write(t + '\n')
|
||||
gen_vocab(f.name, op.join(cur_root, spm_filename_prefix),
|
||||
vocab_type, vocab_size)
|
||||
f.write(t + "\n")
|
||||
gen_vocab(
|
||||
f.name,
|
||||
op.join(cur_root, spm_filename_prefix),
|
||||
vocab_type,
|
||||
vocab_size,
|
||||
)
|
||||
# Generate config YAML
|
||||
gen_config_yaml(cur_root, spm_filename_prefix + '.model',
|
||||
yaml_filename=f'config_{task}.yaml',
|
||||
specaugment_policy='lb')
|
||||
gen_config_yaml(
|
||||
cur_root,
|
||||
spm_filename_prefix + ".model",
|
||||
yaml_filename=f"config_{task}.yaml",
|
||||
specaugment_policy="lb",
|
||||
)
|
||||
# Clean up
|
||||
shutil.rmtree(feature_root)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data-root', '-d', required=True, type=str)
|
||||
parser.add_argument('--asr-vocab-type', default='unigram', required=True,
|
||||
type=str, choices=['bpe', 'unigram', 'char']),
|
||||
parser.add_argument('--st-vocab-type', default='unigram', required=True,
|
||||
type=str, choices=['bpe', 'unigram', 'char']),
|
||||
parser.add_argument('--asr-vocab-size', default=5000, type=int)
|
||||
parser.add_argument('--st-vocab-size', default=8000, type=int)
|
||||
parser.add_argument("--data-root", "-d", required=True, type=str)
|
||||
parser.add_argument(
|
||||
"--asr-vocab-type",
|
||||
default="unigram",
|
||||
required=True,
|
||||
type=str,
|
||||
choices=["bpe", "unigram", "char"],
|
||||
),
|
||||
parser.add_argument(
|
||||
"--st-vocab-type",
|
||||
default="unigram",
|
||||
required=True,
|
||||
type=str,
|
||||
choices=["bpe", "unigram", "char"],
|
||||
),
|
||||
parser.add_argument("--asr-vocab-size", default=5000, type=int)
|
||||
parser.add_argument("--st-vocab-size", default=8000, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
process(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
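An illustrative aside, not part of this commit: itertools.groupby only merges adjacent items, so the grouping in MUSTC.__init__ relies on the segment YAML listing all segments of a wav file consecutively, which appears to hold for the MuST-C release. A toy demonstration:

from itertools import groupby

segments = [  # toy stand-in for the parsed YAML
    {"wav": "a.wav", "offset": "1.0"},
    {"wav": "a.wav", "offset": "0.0"},
    {"wav": "b.wav", "offset": "2.0"},
]
for wav, group in groupby(segments, lambda x: x["wav"]):
    # adjacent entries with the same "wav" form one group
    print(wav, sorted(group, key=lambda x: x["offset"]))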
@ -12,9 +12,9 @@ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
"""

import argparse
from itertools import chain
import sys
import random
import sys
from itertools import chain

import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
@ -22,17 +22,21 @@ from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu

def main():
parser = argparse.ArgumentParser(sys.argv[0])
parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
help='path to system output')
parser.add_argument('--ref', default='', metavar='FILE',
help='path to references')
parser.add_argument('--output', default='', metavar='FILE',
help='print outputs into a pretty format')
parser.add_argument(
"--sys", nargs="*", default="", metavar="FILE", help="path to system output"
)
parser.add_argument("--ref", default="", metavar="FILE", help="path to references")
parser.add_argument(
"--output",
default="",
metavar="FILE",
help="print outputs into a pretty format",
)
args = parser.parse_args()

if args.sys:
src, tgt, hypos, log_probs = load_sys(args.sys)
print('pairwise BLEU: %.2f' % pairwise(hypos))
print("pairwise BLEU: %.2f" % pairwise(hypos))
if args.output:
merge(src, tgt, hypos, log_probs, args.output)

@ -58,18 +62,18 @@ def load_sys(paths):
# S: source
# T: target
# D: detokenized system output
if line.startswith(('S-', 'T-', 'D-')):
i = int(line[line.find('-')+1:line.find('\t')])
if line.startswith('S-'):
src[i] = line.split('\t')[1]
if line.startswith('T-'):
tgt[i] = line.split('\t')[1]
if line.startswith('D-'):
if line.startswith(("S-", "T-", "D-")):
i = int(line[line.find("-") + 1 : line.find("\t")])
if line.startswith("S-"):
src[i] = line.split("\t")[1]
if line.startswith("T-"):
tgt[i] = line.split("\t")[1]
if line.startswith("D-"):
if i not in hypos:
hypos[i] = []
log_probs[i] = []
hypos[i].append(line.split('\t')[2])
log_probs[i].append(float(line.split('\t')[1]))
hypos[i].append(line.split("\t")[2])
log_probs[i].append(float(line.split("\t")[1]))
return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)


@ -79,34 +83,34 @@ def load_ref(path):
src, tgt, refs = [], [], []
i = 0
while i < len(lines):
if lines[i].startswith('S-'):
src.append(lines[i].split('\t')[1].rstrip())
if lines[i].startswith("S-"):
src.append(lines[i].split("\t")[1].rstrip())
i += 1
elif lines[i].startswith('T-'):
tgt.append(lines[i].split('\t')[1].rstrip())
elif lines[i].startswith("T-"):
tgt.append(lines[i].split("\t")[1].rstrip())
i += 1
else:
a = []
while i < len(lines) and lines[i].startswith('R'):
a.append(lines[i].split('\t')[1].rstrip())
while i < len(lines) and lines[i].startswith("R"):
a.append(lines[i].split("\t")[1].rstrip())
i += 1
refs.append(a)
return src, tgt, refs


def merge(src, tgt, hypos, log_probs, path):
with open(path, 'w') as f:
with open(path, "w") as f:
for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
f.write(s + '\n')
f.write(t + '\n')
f.write('\n')
f.write(s + "\n")
f.write(t + "\n")
f.write("\n")
for h, lp in zip(hs, lps):
f.write('\t%f\t%s\n' % (lp, h.strip()))
f.write('------------------------------------------------------\n')
f.write("\t%f\t%s\n" % (lp, h.strip()))
f.write("------------------------------------------------------\n")


def corpus_bleu(sys_stream, ref_streams):
bleu = _corpus_bleu(sys_stream, ref_streams, tokenize='none')
bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
return bleu.score


@ -116,9 +120,11 @@ def sentence_bleu(hypothesis, reference):
bleu.counts[i] += 1
bleu.totals[i] += 1
bleu = compute_bleu(
bleu.counts, bleu.totals,
bleu.sys_len, bleu.ref_len,
smooth_method='exp',
bleu.counts,
bleu.totals,
bleu.sys_len,
bleu.ref_len,
smooth_method="exp",
)
return bleu.score

@ -150,7 +156,7 @@ def multi_ref(refs, hypos):
best = [k for k in range(len(rs)) if s[k] == s[j]]
a.add(random.choice(best))
ref_cnt += len(a)
print('#refs covered: %.2f' % (ref_cnt / len(refs)))
print("#refs covered: %.2f" % (ref_cnt / len(refs)))

# transpose refs and hypos
refs = list(zip(*refs))
@ -160,33 +166,32 @@ def multi_ref(refs, hypos):
k = len(hypos)
m = len(refs)
flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
duplicated_refs = [
[ref for ref in refs_i for _ in range(k)]
for refs_i in refs
]
duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
loo_bleus = []
for held_out_ref in range(m):
remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref+1:]
remaining_refs = (
duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
)
assert len(remaining_refs) == m - 1
loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
print('average multi-reference BLEU (leave-one-out): %.2f' % np.mean(loo_bleus))
print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))


def intra_ref(refs):
print('ref pairwise BLEU: %.2f' % pairwise(refs))
print("ref pairwise BLEU: %.2f" % pairwise(refs))
refs = list(zip(*refs))
m = len(refs)
concat_h = []
concat_rest = [[] for j in range(m - 1)]
for i, h in enumerate(refs):
rest = refs[:i] + refs[i+1:]
rest = refs[:i] + refs[i + 1 :]
concat_h.append(h)
for j in range(m - 1):
concat_rest[j].extend(rest[j])
concat_h = list(chain.from_iterable(concat_h))
bleu = corpus_bleu(concat_h, concat_rest)
print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
print("multi-reference BLEU (leave-one-out): %.2f" % bleu)


if __name__ == '__main__':
if __name__ == "__main__":
main()
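An illustrative aside, not part of this commit: the pairwise BLEU printed above scores hypotheses of the same source against each other; the exact pairing inside pairwise() is not shown in this diff, so the sketch below is an assumed variant that pairs every ordered pair of distinct hypotheses.

def pairwise_streams(hypo_sets):
    # hypo_sets: one list of K hypotheses per source sentence (assumed layout)
    sys_stream, ref_stream = [], []
    for hypos in hypo_sets:
        k = len(hypos)
        for i in range(k):
            for j in range(k):
                if i != j:
                    sys_stream.append(hypos[i])
                    ref_stream.append(hypos[j])
    return sys_stream, ref_stream  # feed to corpus_bleu(sys_stream, [ref_stream])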
@ -21,6 +21,6 @@ class LogSumExpMoE(torch.autograd.Function):

@staticmethod
def backward(ctx, grad_output):
posterior, = ctx.saved_tensors
(posterior,) = ctx.saved_tensors
grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
return grad_logp, None, None
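An illustrative aside, not part of this commit: the backward rule above reuses a saved posterior as the gradient, which is exactly the gradient of a log-marginal over experts. A quick numerical check of that identity with plain autograd (shapes are hypothetical):

import torch

lprobs = torch.randn(4, 3, requires_grad=True)   # log p(y|z,x), B x K
prior = torch.softmax(torch.randn(4, 3), dim=1)  # p(z|x), B x K
out = torch.logsumexp(lprobs + prior.log(), dim=1)
out.sum().backward()
posterior = torch.softmax(lprobs + prior.log(), dim=1)
assert torch.allclose(lprobs.grad, posterior, atol=1e-6)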
@ -26,15 +26,15 @@ class MeanPoolGatingNetwork(torch.nn.Module):

def forward(self, encoder_out):
if not (
hasattr(encoder_out, 'encoder_out')
and hasattr(encoder_out, 'encoder_padding_mask')
hasattr(encoder_out, "encoder_out")
and hasattr(encoder_out, "encoder_padding_mask")
and encoder_out.encoder_out.size(2) == self.embed_dim
):
raise ValueError('Unexpected format for encoder_out')
raise ValueError("Unexpected format for encoder_out")

# mean pooling over time
encoder_padding_mask = encoder_out.encoder_padding_mask  # B x T
encoder_out = encoder_out.encoder_out.transpose(0, 1)  # B x T x C
encoder_out = encoder_out.encoder_out.transpose(0, 1)  # B x T x C
if encoder_padding_mask is not None:
encoder_out = encoder_out.clone()  # required because of transpose above
encoder_out[encoder_padding_mask] = 0
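An illustrative aside, not part of this commit: zeroing padded positions before averaging yields a correct masked mean once divided by the true lengths; a minimal sketch with hypothetical shapes:

import torch

enc = torch.randn(5, 2, 4)                 # T x B x C encoder output
pad = torch.zeros(2, 5, dtype=torch.bool)  # B x T padding mask
pad[0, 3:] = True                          # first sentence: 3 real steps
x = enc.transpose(0, 1).clone()            # B x T x C (clone after transpose)
x[pad] = 0                                 # exclude padding from the sum
mean = x.sum(dim=1) / (~pad).sum(dim=1, keepdim=True)  # B x C masked mean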
@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

import torch

from fairseq import metrics, utils
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask
@ -13,7 +12,7 @@ from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork


@register_task('translation_moe')
@register_task("translation_moe")
class TranslationMoETask(TranslationTask):
"""
Translation task for Mixture of Experts (MoE) models.
@ -58,19 +57,19 @@ class TranslationMoETask(TranslationTask):
# fmt: on

def __init__(self, args, src_dict, tgt_dict):
if args.method == 'sMoElp':
if args.method == "sMoElp":
# soft MoE with learned prior
self.uniform_prior = False
self.hard_selection = False
elif args.method == 'sMoEup':
elif args.method == "sMoEup":
# soft MoE with uniform prior
self.uniform_prior = True
self.hard_selection = False
elif args.method == 'hMoElp':
elif args.method == "hMoElp":
# hard MoE with learned prior
self.uniform_prior = False
self.hard_selection = True
elif args.method == 'hMoEup':
elif args.method == "hMoEup":
# hard MoE with uniform prior
self.uniform_prior = True
self.hard_selection = True
@ -78,50 +77,56 @@ class TranslationMoETask(TranslationTask):
# add indicator tokens for each expert
for i in range(args.num_experts):
# add to both dictionaries in case we're sharing embeddings
src_dict.add_symbol('<expert_{}>'.format(i))
tgt_dict.add_symbol('<expert_{}>'.format(i))
src_dict.add_symbol("<expert_{}>".format(i))
tgt_dict.add_symbol("<expert_{}>".format(i))

super().__init__(args, src_dict, tgt_dict)

def build_model(self, args):
from fairseq import models

model = models.build_model(args, self)
if not self.uniform_prior and not hasattr(model, 'gating_network'):
if not self.uniform_prior and not hasattr(model, "gating_network"):
if self.args.mean_pool_gating_network:
if getattr(args, 'mean_pool_gating_network_encoder_dim', None):
if getattr(args, "mean_pool_gating_network_encoder_dim", None):
encoder_dim = args.mean_pool_gating_network_encoder_dim
elif getattr(args, 'encoder_embed_dim', None):
elif getattr(args, "encoder_embed_dim", None):
# assume that encoder_embed_dim is the encoder's output dimension
encoder_dim = args.encoder_embed_dim
else:
raise ValueError('Must specify --mean-pool-gating-network-encoder-dim')
raise ValueError(
"Must specify --mean-pool-gating-network-encoder-dim"
)

if getattr(args, 'mean_pool_gating_network_dropout', None):
if getattr(args, "mean_pool_gating_network_dropout", None):
dropout = args.mean_pool_gating_network_dropout
elif getattr(args, 'dropout', None):
elif getattr(args, "dropout", None):
dropout = args.dropout
else:
raise ValueError('Must specify --mean-pool-gating-network-dropout')
raise ValueError("Must specify --mean-pool-gating-network-dropout")

model.gating_network = MeanPoolGatingNetwork(
encoder_dim, args.num_experts, dropout,
encoder_dim,
args.num_experts,
dropout,
)
else:
raise ValueError(
'translation_moe task with learned prior requires the model to '
'have a gating network; try using --mean-pool-gating-network'
"translation_moe task with learned prior requires the model to "
"have a gating network; try using --mean-pool-gating-network"
)
return model

def expert_index(self, i):
return i + self.tgt_dict.index('<expert_0>')
return i + self.tgt_dict.index("<expert_0>")

def _get_loss(self, sample, model, criterion):
assert hasattr(criterion, 'compute_loss'), \
'translation_moe task requires the criterion to implement the compute_loss() method'
assert hasattr(
criterion, "compute_loss"
), "translation_moe task requires the criterion to implement the compute_loss() method"

k = self.args.num_experts
bsz = sample['target'].size(0)
bsz = sample["target"].size(0)

def get_lprob_y(encoder_out, prev_output_tokens_k):
net_output = model.decoder(
@ -134,20 +139,22 @@ class TranslationMoETask(TranslationTask):

def get_lprob_yz(winners=None):
encoder_out = model.encoder(
src_tokens=sample['net_input']['src_tokens'],
src_lengths=sample['net_input']['src_lengths'],
src_tokens=sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
)

if winners is None:
lprob_y = []
for i in range(k):
prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
prev_output_tokens_k = sample["net_input"][
"prev_output_tokens"
].clone()
assert not prev_output_tokens_k.requires_grad
prev_output_tokens_k[:, 0] = self.expert_index(i)
lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
lprob_y = torch.cat(lprob_y, dim=1)  # -> B x K
else:
prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
prev_output_tokens_k[:, 0] = self.expert_index(winners)
lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k)  # -> B

@ -177,17 +184,21 @@ class TranslationMoETask(TranslationTask):
loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)

loss = loss.sum()
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
sample_size = (
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
)
logging_output = {
'loss': utils.item(loss.data),
'ntokens': sample['ntokens'],
'nsentences': bsz,
'sample_size': sample_size,
'posterior': prob_z_xy.float().sum(dim=0).cpu(),
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": bsz,
"sample_size": sample_size,
"posterior": prob_z_xy.float().sum(dim=0).cpu(),
}
return loss, sample_size, logging_output

def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train()
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
if ignore_grad:
@ -201,7 +212,15 @@ class TranslationMoETask(TranslationTask):
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
return loss, sample_size, logging_output

def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None, constraints=None):
def inference_step(
self,
generator,
models,
sample,
prefix_tokens=None,
expert=None,
constraints=None,
):
expert = expert or self.args.gen_expert
with torch.no_grad():
return generator.generate(
@ -215,6 +234,6 @@ class TranslationMoETask(TranslationTask):
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
metrics.log_scalar(
'posterior',
sum(log['posterior'] for log in logging_outputs if 'posterior' in log)
"posterior",
sum(log["posterior"] for log in logging_outputs if "posterior" in log),
)
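An illustrative aside, not part of this commit: expert_index() maps expert i to the dictionary slot of its indicator symbol, and _get_loss routes a batch to expert i by overwriting the first decoder input token on a clone. In toy form (all indices hypothetical):

import torch

expert0 = 1000                          # assumed index of <expert_0>
prev = torch.randint(4, 999, (2, 6))    # hypothetical prev_output_tokens, B x T
prev_k = prev.clone()                   # clone so the sample is not mutated
prev_k[:, 0] = expert0 + 1              # condition the whole batch on expert 1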
@ -4,37 +4,38 @@
# LICENSE file in the root directory of this source tree.

import argparse
import numpy as np
import sys

import numpy as np


aggregate_funcs = {
'std': np.std,
'var': np.var,
'median': np.median,
'mean': np.mean,
'min': np.min,
'max': np.max,
"std": np.std,
"var": np.var,
"median": np.median,
"mean": np.mean,
"min": np.min,
"max": np.max,
}


def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_file', required=True, type=str)
parser.add_argument('-n', '--repeat_times', required=True, type=int)
parser.add_argument('-o', '--output_file', required=False)
parser.add_argument('-f', '--func', required=False, default='mean')
parser.add_argument("-i", "--input_file", required=True, type=str)
parser.add_argument("-n", "--repeat_times", required=True, type=int)
parser.add_argument("-o", "--output_file", required=False)
parser.add_argument("-f", "--func", required=False, default="mean")
args = parser.parse_args()

stream = open(args.output_file, 'w') if args.output_file else sys.stdout
stream = open(args.output_file, "w") if args.output_file else sys.stdout

segment_scores = []
for line in open(args.input_file):
segment_scores.append(float(line.strip()))
if len(segment_scores) == args.repeat_times:
stream.write('{}\n'.format(aggregate_funcs[args.func](segment_scores)))
stream.write("{}\n".format(aggregate_funcs[args.func](segment_scores)))
segment_scores = []


if __name__ == '__main__':
if __name__ == "__main__":
main()
@ -4,14 +4,13 @@
# LICENSE file in the root directory of this source tree.

import argparse
import os
import sys
import subprocess
import tempfile
import math

from itertools import combinations
import os
import subprocess
import sys
import tempfile
from collections import defaultdict
from itertools import combinations


def read_translations(path, n_repeats):
@ -19,7 +18,7 @@ def read_translations(path, n_repeats):
segment_translations = []
translations = defaultdict(list)
for line in open(path):
segment_translations.append(' '.join(line.split()))
segment_translations.append(" ".join(line.split()))
if len(segment_translations) == n_repeats:
translations[segment_counter] = segment_translations
segment_translations = []
@ -30,42 +29,55 @@ def read_translations(path, n_repeats):
def generate_input(translations, n_repeats):
_, ref_path = tempfile.mkstemp()
_, mt_path = tempfile.mkstemp()
ref_fh = open(ref_path, 'w')
mt_fh = open(mt_path, 'w')
ref_fh = open(ref_path, "w")
mt_fh = open(mt_path, "w")
for segid in sorted(translations.keys()):
assert len(translations[segid]) == n_repeats
indexes = combinations(range(n_repeats), 2)
for idx1, idx2 in indexes:
mt_fh.write(translations[segid][idx1].strip() + '\n')
ref_fh.write(translations[segid][idx2].strip() + '\n')
sys.stderr.write('\nSaved translations to %s and %s' % (ref_path, mt_path))
mt_fh.write(translations[segid][idx1].strip() + "\n")
ref_fh.write(translations[segid][idx2].strip() + "\n")
sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
return ref_path, mt_path


def run_meteor(ref_path, mt_path, metric_path, lang='en'):
def run_meteor(ref_path, mt_path, metric_path, lang="en"):
_, out_path = tempfile.mkstemp()
subprocess.call([
'java', '-Xmx2G', '-jar', metric_path, mt_path, ref_path,
'-p', '0.5 0.2 0.6 0.75',  # default parameters, only changed alpha to give equal weight to P and R
'-norm',
'-l', lang], stdout=open(out_path, 'w'))
subprocess.call(
[
"java",
"-Xmx2G",
"-jar",
metric_path,
mt_path,
ref_path,
"-p",
"0.5 0.2 0.6 0.75",  # default parameters, only changed alpha to give equal weight to P and R
"-norm",
"-l",
lang,
],
stdout=open(out_path, "w"),
)
os.remove(ref_path)
os.remove(mt_path)
sys.stderr.write('\nSaved Meteor output to %s' % out_path)
sys.stderr.write("\nSaved Meteor output to %s" % out_path)
return out_path


def read_output(meteor_output_path, n_repeats):
n_combinations = math.factorial(n_repeats)/(math.factorial(2) * math.factorial(n_repeats - 2))
n_combinations = math.factorial(n_repeats) / (
math.factorial(2) * math.factorial(n_repeats - 2)
)
raw_scores = []
average_scores = []
for line in open(meteor_output_path):
if not line.startswith('Segment '):
if not line.startswith("Segment "):
continue
score = float(line.strip().split('\t')[1])
score = float(line.strip().split("\t")[1])
raw_scores.append(score)
if len(raw_scores) == n_combinations:
average_scores.append(sum(raw_scores)/n_combinations)
average_scores.append(sum(raw_scores) / n_combinations)
raw_scores = []
os.remove(meteor_output_path)
return average_scores
@ -73,25 +85,25 @@ def read_output(meteor_output_path, n_repeats):

def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input')
parser.add_argument('-n', '--repeat_times', type=int)
parser.add_argument('-m', '--meteor')
parser.add_argument('-o', '--output')
parser.add_argument("-i", "--input")
parser.add_argument("-n", "--repeat_times", type=int)
parser.add_argument("-m", "--meteor")
parser.add_argument("-o", "--output")
args = parser.parse_args()

translations = read_translations(args.infile, args.repetitions)
sys.stderr.write('\nGenerating input for Meteor...')
sys.stderr.write("\nGenerating input for Meteor...")
ref_path, mt_path = generate_input(translations, args.repetitions)
sys.stderr.write('\nRunning Meteor...')
sys.stderr.write("\nRunning Meteor...")
out_path = run_meteor(ref_path, mt_path, args.meteor)
sys.stderr.write('\nReading output...')
sys.stderr.write("\nReading output...")
scores = read_output(out_path, args.repetitions)
sys.stderr.write('\nWriting results...')
with open(args.output, 'w') as o:
sys.stderr.write("\nWriting results...")
with open(args.output, "w") as o:
for scr in scores:
o.write('{}\n'.format(scr))
o.write("{}\n".format(scr))
o.close()


if __name__ == '__main__':
if __name__ == "__main__":
main()
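An illustrative aside, not part of this commit: read_output() expects one Meteor score per unordered pair of hypotheses, i.e. n-choose-2 lines per segment; the factorial expression above is just the binomial coefficient:

import math

n = 5
n_combinations = math.factorial(n) / (math.factorial(2) * math.factorial(n - 2))
assert n_combinations == math.comb(n, 2) == 10  # math.comb needs Python >= 3.8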
@ -8,21 +8,21 @@ import sys


def _normalize_spaces(line):
return ' '.join(line.split())
return " ".join(line.split())


def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_file', required=True, type=str)
parser.add_argument('-n', '--repeat_times', required=True, type=int)
parser.add_argument('-o', '--output_file', required=False, type=str)
parser.add_argument("-i", "--input_file", required=True, type=str)
parser.add_argument("-n", "--repeat_times", required=True, type=int)
parser.add_argument("-o", "--output_file", required=False, type=str)
args = parser.parse_args()
stream = open(args.output_file, 'w') if args.output_file else sys.stdout
stream = open(args.output_file, "w") if args.output_file else sys.stdout

for line in open(args.input_file):
for _ in range(args.repeat_times):
stream.write(_normalize_spaces(line) + '\n')
stream.write(_normalize_spaces(line) + "\n")


if __name__ == '__main__':
if __name__ == "__main__":
main()
@ -8,30 +8,31 @@
Helper script to pre-compute embeddings for a wav2letter++ dataset
"""

import argparse
import glob
import os
import os.path as osp
import pprint
import glob, os, argparse

import soundfile as sf
import torch
import tqdm
from fairseq.models.wav2vec.wav2vec import Wav2VecModel
from torch import nn
from torch.utils.data import DataLoader


try:
import tqdm
except:
print("Install tqdm to use --log-format=tqdm")

from fairseq.models.wav2vec.wav2vec import Wav2VecModel

import tqdm
import soundfile as sf
from torch.utils.data import DataLoader
import os.path as osp


class FilesDataset:
def __init__(self, files, labels):
self.files = files
if labels and osp.exists(labels):
with open(labels, 'r') as lbl_f:
with open(labels, "r") as lbl_f:
self.labels = [line.rstrip() for line in lbl_f]
else:
self.labels = labels
@ -50,7 +51,7 @@ class FilesDataset:
if self.labels:
if isinstance(self.labels, str):
lbl_file = osp.splitext(fname)[0] + "." + self.labels
with open(lbl_file, 'r') as lblf:
with open(lbl_file, "r") as lblf:
lbls = lblf.readline()
assert lbls is not None
else:
@ -116,24 +117,24 @@ class DatasetWriter:
assert len(files) > 0

if self.args.shard is not None:
files = files[self.args.shard::self.args.num_shards]
files = files[self.args.shard :: self.args.num_shards]

lbls = []
with open(self.data_file(split), 'w') as srcf:
with open(self.data_file(split), "w") as srcf:
for line, lbl in self.iterate(files):
print(line, file=srcf)
if self.args.labels:
lbls.append(lbl + '\n')
lbls.append(lbl + "\n")

if self.args.labels:
assert all(a is not None for a in lbls)
with open(self.lbl_file(split), 'w') as lblf:
with open(self.lbl_file(split), "w") as lblf:
lblf.writelines(lbls)

def iterate(self, files):

data = self.load_data(files)
for samples in tqdm.tqdm(data, total=len(files)//32):
for samples in tqdm.tqdm(data, total=len(files) // 32):

for wav, lbl in samples:
x = wav.unsqueeze(0).float().cuda()
@ -162,7 +163,6 @@ class DatasetWriter:
idx = torch.cat(result, dim=0)
yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl

def lbl_file(self, name):
shard_part = "" if self.args.shard is None else f".{self.args.shard}"
return osp.join(self.output_dir, f"{name}.lbl{shard_part}")
@ -230,7 +230,9 @@ class DatasetWriter:

self.process_splits()

if hasattr(self.model.feature_extractor, "vars") and (self.args.shard is None or self.args.shard == 0):
if hasattr(self.model.feature_extractor, "vars") and (
self.args.shard is None or self.args.shard == 0
):
vars = (
self.model.feature_extractor.vars.view(
self.model.feature_extractor.banks,
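An illustrative aside, not part of this commit: the files[shard::num_shards] slice above deals the file list round-robin across shards, producing disjoint, near-equal shards:

files = [f"utt{i}.wav" for i in range(7)]  # hypothetical file list
num_shards = 3
shards = [files[s::num_shards] for s in range(num_shards)]
assert sorted(sum(shards, [])) == sorted(files)  # disjoint and complete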
@ -14,13 +14,12 @@ import os
from shutil import copy

import h5py
import soundfile as sf
import numpy as np
import soundfile as sf
import torch
from torch import nn
import tqdm

from fairseq.models.wav2vec.wav2vec import Wav2VecModel
from torch import nn


def read_audio(fname):
@ -33,7 +32,6 @@ def read_audio(fname):


class PretrainedWav2VecModel(nn.Module):

def __init__(self, fname):
super().__init__()

@ -55,32 +53,33 @@ class PretrainedWav2VecModel(nn.Module):


class EmbeddingWriterConfig(argparse.ArgumentParser):

def __init__(self):
super().__init__("Pre-compute embeddings for wav2letter++ datasets")

kwargs = {"action": "store", "type": str, "required": True}

self.add_argument("--input", "-i",
help="Input Directory", **kwargs)
self.add_argument("--output", "-o",
help="Output Directory", **kwargs)
self.add_argument("--model",
help="Path to model checkpoint", **kwargs)
self.add_argument("--split",
help="Dataset Splits", nargs='+', **kwargs)
self.add_argument("--ext", default="wav", required=False,
help="Audio file extension")
self.add_argument("--input", "-i", help="Input Directory", **kwargs)
self.add_argument("--output", "-o", help="Output Directory", **kwargs)
self.add_argument("--model", help="Path to model checkpoint", **kwargs)
self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs)
self.add_argument(
"--ext", default="wav", required=False, help="Audio file extension"
)

self.add_argument("--no-copy-labels", action="store_true",
help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.")
self.add_argument("--use-feat", action="store_true",
help="Use the feature vector ('z') instead of context vector ('c') for features")
self.add_argument("--gpu",
help="GPU to use", default=0, type=int)
self.add_argument(
"--no-copy-labels",
action="store_true",
help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.",
)
self.add_argument(
"--use-feat",
action="store_true",
help="Use the feature vector ('z') instead of context vector ('c') for features",
)
self.add_argument("--gpu", help="GPU to use", default=0, type=int)


class Prediction():
class Prediction:
""" Lightweight wrapper around a fairspeech embedding model """

def __init__(self, fname, gpu=0):
@ -95,7 +94,7 @@ class Prediction():
return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy()


class H5Writer():
class H5Writer:
""" Write features as hdf5 file in wav2letter++ compatible format """

def __init__(self, fname):
@ -112,7 +111,7 @@ class H5Writer():


class EmbeddingDatasetWriter(object):
""" Given a model and a wav2letter++ dataset, pre-compute and store embeddings
"""Given a model and a wav2letter++ dataset, pre-compute and store embeddings

Args:
input_root, str :
@ -123,13 +122,17 @@ class EmbeddingDatasetWriter(object):
Dataset split
"""

def __init__(self, input_root, output_root, split,
model_fname,
extension="wav",
gpu=0,
verbose=False,
use_feat=False,
):
def __init__(
self,
input_root,
output_root,
split,
model_fname,
extension="wav",
gpu=0,
verbose=False,
use_feat=False,
):

assert os.path.exists(model_fname)

@ -143,8 +146,9 @@ class EmbeddingDatasetWriter(object):
self.extension = extension
self.use_feat = use_feat

assert os.path.exists(self.input_path), \
"Input path '{}' does not exist".format(self.input_path)
assert os.path.exists(self.input_path), "Input path '{}' does not exist".format(
self.input_path
)

def _progress(self, iterable, **kwargs):
if self.verbose:
@ -176,7 +180,11 @@ class EmbeddingDatasetWriter(object):
def copy_labels(self):
self.require_output_path()

labels = list(filter(lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))))
labels = list(
filter(
lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))
)
)
for fname in tqdm.tqdm(labels):
copy(fname, self.output_path)

@ -191,10 +199,16 @@ class EmbeddingDatasetWriter(object):

paths = self.input_fnames

fnames_context = map(lambda x: os.path.join(self.output_path, x.replace("." + self.extension, ".h5context")), \
map(os.path.basename, paths))
fnames_context = map(
lambda x: os.path.join(
self.output_path, x.replace("." + self.extension, ".h5context")
),
map(os.path.basename, paths),
)

for name, target_fname in self._progress(zip(paths, fnames_context), total=len(self)):
for name, target_fname in self._progress(
zip(paths, fnames_context), total=len(self)
):
wav, sr = read_audio(name)
z, c = self.model(wav)
feat = z if self.use_feat else c
@ -204,7 +218,8 @@ class EmbeddingDatasetWriter(object):
def __repr__(self):

return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format(
n_files=len(self), **self.__dict__)
n_files=len(self), **self.__dict__
)


if __name__ == "__main__":
@ -10,32 +10,50 @@ Data pre-processing: build vocabularies and binarize training data.
import argparse
import glob
import os
import soundfile
import random

import soundfile


def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index')
parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D',
help='percentage of data to use as validation set (between 0 and 1)')
parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory')
parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for')
parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed')
parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG',
help='if set, path must contain this substring for a file to be included in the manifest')
parser.add_argument(
"root", metavar="DIR", help="root directory containing flac files to index"
)
parser.add_argument(
"--valid-percent",
default=0.01,
type=float,
metavar="D",
help="percentage of data to use as validation set (between 0 and 1)",
)
parser.add_argument(
"--dest", default=".", type=str, metavar="DIR", help="output directory"
)
parser.add_argument(
"--ext", default="flac", type=str, metavar="EXT", help="extension to look for"
)
parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed")
parser.add_argument(
"--path-must-contain",
default=None,
type=str,
metavar="FRAG",
help="if set, path must contain this substring for a file to be included in the manifest",
)
return parser


def main(args):
assert args.valid_percent >= 0 and args.valid_percent <= 1.
assert args.valid_percent >= 0 and args.valid_percent <= 1.0

dir_path = os.path.realpath(args.root)
search_path = os.path.join(dir_path, '**/*.' + args.ext)
search_path = os.path.join(dir_path, "**/*." + args.ext)
rand = random.Random(args.seed)

with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open(
os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f:
with open(os.path.join(args.dest, "train.tsv"), "w") as train_f, open(
os.path.join(args.dest, "valid.tsv"), "w"
) as valid_f:
print(dir_path, file=train_f)
print(dir_path, file=valid_f)

@ -47,10 +65,12 @@ def main(args):

frames = soundfile.info(fname).frames
dest = train_f if rand.random() > args.valid_percent else valid_f
print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest)
print(
"{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest
)


if __name__ == '__main__':
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
main(args)
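An illustrative aside, not part of this commit: seeding random.Random makes the train/valid assignment reproducible, and each file lands in the validation set independently with probability roughly valid_percent:

import random

rand = random.Random(42)  # same seed -> same split on every run
valid_percent = 0.01
dest = ["valid" if rand.random() <= valid_percent else "train"
        for _ in range(1000)]
print(dest.count("valid"))  # on the order of 10 of 1000 files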
@ -4,16 +4,17 @@
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

__all__ = ['pdb']
__version__ = '1.0.0a0'
__all__ = ["pdb"]
__version__ = "1.0.0a0"

import sys

# backwards compatibility to support `from fairseq.meters import AverageMeter`
from fairseq.logging import meters, metrics, progress_bar  # noqa
sys.modules['fairseq.meters'] = meters
sys.modules['fairseq.metrics'] = metrics
sys.modules['fairseq.progress_bar'] = progress_bar

sys.modules["fairseq.meters"] = meters
sys.modules["fairseq.metrics"] = metrics
sys.modules["fairseq.progress_bar"] = progress_bar

import fairseq.criterions  # noqa
import fairseq.models  # noqa
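An illustrative aside, not part of this commit: registering a module object under a legacy dotted name in sys.modules lets old import paths keep resolving after a move; a self-contained sketch with hypothetical names:

import sys
import types

pkg = types.ModuleType("pkg")              # hypothetical parent package
meters = types.ModuleType("pkg.new_meters")
meters.AverageMeter = object               # stand-in attribute
sys.modules["pkg"] = pkg
sys.modules["pkg.meters"] = meters         # legacy alias

from pkg.meters import AverageMeter        # resolves via the alias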
@ -4,9 +4,4 @@
# LICENSE file in the root directory of this source tree.

# import models/tasks to register them
from . import ( # noqa
dummy_lm,
dummy_masked_lm,
dummy_model,
dummy_mt,
)
from . import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt  # noqa
@ -7,25 +7,27 @@ import logging

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task('dummy_lm')
@register_task("dummy_lm")
class DummyLMTask(LegacyFairseqTask):

@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('--dict-size', default=49996, type=int)
parser.add_argument('--dataset-size', default=100000, type=int)
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments '
'per sample for BERT dataset')
parser.add_argument("--dict-size", default=49996, type=int)
parser.add_argument("--dataset-size", default=100000, type=int)
parser.add_argument(
"--tokens-per-sample",
default=512,
type=int,
help="max number of total tokens over all segments "
"per sample for BERT dataset",
)

def __init__(self, args, dictionary):
super().__init__(args)
@ -44,8 +46,8 @@ class DummyLMTask(LegacyFairseqTask):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol('word{}'.format(i))
logger.info('dictionary: {} types'.format(len(dictionary)))
dictionary.add_symbol("word{}".format(i))
logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)

def load_dataset(self, split, epoch=1, combine=False, **kwargs):
@ -59,16 +61,16 @@ class DummyLMTask(LegacyFairseqTask):
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
self.datasets[split] = DummyDataset(
{
'id': 1,
'net_input': {
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
'src_lengths': torch.full(
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.args.tokens_per_sample, dtype=torch.long
),
},
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
'nsentences': bsz,
'ntokens': bsz * self.args.tokens_per_sample,
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
"nsentences": bsz,
"ntokens": bsz * self.args.tokens_per_sample,
},
num_items=self.args.dataset_size,
item_size=self.args.tokens_per_sample,
@ -84,7 +86,6 @@ class DummyLMTask(LegacyFairseqTask):


class DummyDataset(FairseqDataset):

def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
@ -7,32 +7,34 @@ import logging

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task('dummy_masked_lm')
@register_task("dummy_masked_lm")
class DummyMaskedLMTask(LegacyFairseqTask):

@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('--dict-size', default=49995, type=int)
parser.add_argument('--dataset-size', default=100000, type=int)
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments '
'per sample for BERT dataset')
parser.add_argument("--dict-size", default=49995, type=int)
parser.add_argument("--dataset-size", default=100000, type=int)
parser.add_argument(
"--tokens-per-sample",
default=512,
type=int,
help="max number of total tokens over all segments "
"per sample for BERT dataset",
)

def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary

# add mask token
self.mask_idx = dictionary.add_symbol('<mask>')
self.mask_idx = dictionary.add_symbol("<mask>")
dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

mask_idx = 0
@ -52,8 +54,8 @@ class DummyMaskedLMTask(LegacyFairseqTask):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol('word{}'.format(i))
logger.info('dictionary: {} types'.format(len(dictionary)))
dictionary.add_symbol("word{}".format(i))
logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)

def load_dataset(self, split, epoch=1, combine=False, **kwargs):
@ -67,16 +69,16 @@ class DummyMaskedLMTask(LegacyFairseqTask):
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
self.datasets[split] = DummyDataset(
{
'id': 1,
'net_input': {
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
'src_lengths': torch.full(
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.args.tokens_per_sample, dtype=torch.long
),
},
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
'nsentences': bsz,
'ntokens': bsz * self.args.tokens_per_sample,
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
"nsentences": bsz,
"ntokens": bsz * self.args.tokens_per_sample,
},
num_items=self.args.dataset_size,
item_size=self.args.tokens_per_sample,
@ -92,7 +94,6 @@ class DummyMaskedLMTask(LegacyFairseqTask):


class DummyDataset(FairseqDataset):

def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
||||
|
@ -5,7 +5,6 @@
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.models import (
|
||||
FairseqDecoder,
|
||||
@ -15,17 +14,16 @@ from fairseq.models import (
|
||||
)
|
||||
|
||||
|
||||
@register_model('dummy_model')
|
||||
@register_model("dummy_model")
|
||||
class DummyModel(FairseqLanguageModel):
|
||||
|
||||
def __init__(self, args, encoder):
|
||||
super().__init__(encoder)
|
||||
self.args = args
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
parser.add_argument('--num-layers', type=int, default=24)
|
||||
parser.add_argument('--embed-dim', type=int, default=1024)
|
||||
parser.add_argument("--num-layers", type=int, default=24)
|
||||
parser.add_argument("--embed-dim", type=int, default=1024)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
@@ -41,32 +39,35 @@ class DummyModel(FairseqLanguageModel):


 class DummyEncoder(FairseqDecoder):
-
     def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
         super().__init__(Dictionary())
         self.embed = nn.Embedding(
             num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
         )
-        self.layers_a = nn.ModuleList([
-            nn.Sequential(
-                nn.LayerNorm(embed_dim),
-                nn.Linear(embed_dim, 3*embed_dim),  # q, k, v input projection
-                nn.Linear(3*embed_dim, embed_dim),  # skip self-attention
-                nn.Linear(embed_dim, embed_dim),  # output projection
-                nn.Dropout(),
-            )
-            for i in range(num_layers)
-        ])
-        self.layers_b = nn.ModuleList([
-            nn.Sequential(
-                nn.LayerNorm(embed_dim),
-                nn.Linear(embed_dim, 4*embed_dim),  # FFN
-                nn.ReLU(),
-                nn.Linear(4*embed_dim, embed_dim),  # FFN
-                nn.Dropout(0.1),
-            )
-            for i in range(num_layers)
-        ])
+        self.layers_a = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.LayerNorm(embed_dim),
+                    nn.Linear(embed_dim, 3 * embed_dim),  # q, k, v input projection
+                    nn.Linear(3 * embed_dim, embed_dim),  # skip self-attention
+                    nn.Linear(embed_dim, embed_dim),  # output projection
+                    nn.Dropout(),
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.layers_b = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.LayerNorm(embed_dim),
+                    nn.Linear(embed_dim, 4 * embed_dim),  # FFN
+                    nn.ReLU(),
+                    nn.Linear(4 * embed_dim, embed_dim),  # FFN
+                    nn.Dropout(0.1),
+                )
+                for i in range(num_layers)
+            ]
+        )
         self.out_proj = nn.Linear(embed_dim, num_embed)

     def forward(self, tokens, masked_tokens=None):
@@ -90,6 +91,6 @@ class DummyEncoder(FairseqDecoder):
         return F.softmax(logits, dim=-1)


-@register_model_architecture('dummy_model', 'dummy_model')
+@register_model_architecture("dummy_model", "dummy_model")
 def base_architecture(args):
     pass
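The hunks above bundle three of black's signature rewrites: single-quoted strings are normalized to double quotes, binary operators gain surrounding spaces (`3*embed_dim` becomes `3 * embed_dim`), and a bracketed expression that overflows the line limit is exploded one element per line. A minimal sketch of reproducing this programmatically, assuming black's default settings (the repository's own configuration is not shown in this excerpt):

import black

# format_str() and FileMode are part of black's public Python API.
# The snippet only needs to parse, so the undefined names are fine.
src = "x = module_list([layer(norm(dim), linear(dim, 3*dim), linear(3*dim, dim)) for i in range(n)])"
print(black.format_str(src, mode=black.FileMode()))
# The call is exploded across lines and `3*dim` becomes `3 * dim`,
# the same shape as the DummyEncoder hunk above.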
@@ -7,24 +7,22 @@ import logging

 import numpy as np
 import torch
-
 from fairseq.data import Dictionary, FairseqDataset
-from fairseq.tasks import register_task, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask, register_task


 logger = logging.getLogger(__name__)


-@register_task('dummy_mt')
+@register_task("dummy_mt")
 class DummyMTTask(LegacyFairseqTask):
-
     @staticmethod
     def add_args(parser):
         """Add task-specific arguments to the parser."""
-        parser.add_argument('--dict-size', default=49996, type=int)
-        parser.add_argument('--dataset-size', default=100000, type=int)
-        parser.add_argument('--src-len', default=30, type=int)
-        parser.add_argument('--tgt-len', default=30, type=int)
+        parser.add_argument("--dict-size", default=49996, type=int)
+        parser.add_argument("--dataset-size", default=100000, type=int)
+        parser.add_argument("--src-len", default=30, type=int)
+        parser.add_argument("--tgt-len", default=30, type=int)

     def __init__(self, args, dictionary):
         super().__init__(args)
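One detail worth noting in the import hunk above: isort also reorders the names inside a single `from` import. With its default `order_by_type` behavior, class-like names sort ahead of plain functions, which is why `LegacyFairseqTask` moves in front of `register_task`. A small sketch using isort's Python API (isort >= 5, where `isort.code()` returns the sorted source):

import isort

messy = "from fairseq.tasks import register_task, LegacyFairseqTask\n"
print(isort.code(messy), end="")
# With default settings this prints:
# from fairseq.tasks import LegacyFairseqTask, register_task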
@@ -41,8 +39,8 @@ class DummyMTTask(LegacyFairseqTask):
         """Setup the task. """
         dictionary = Dictionary()
         for i in range(args.dict_size):
-            dictionary.add_symbol('word{}'.format(i))
-        logger.info('dictionary: {} types'.format(len(dictionary)))
+            dictionary.add_symbol("word{}".format(i))
+        logger.info("dictionary: {} types".format(len(dictionary)))

         args.max_source_positions = args.src_len + dictionary.pad() + 2
         args.max_target_positions = args.tgt_len + dictionary.pad() + 2
@@ -62,17 +60,17 @@ class DummyMTTask(LegacyFairseqTask):
         tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
         self.datasets[split] = DummyDataset(
             {
-                'id': 1,
-                'net_input': {
-                    'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
-                    'src_lengths': torch.full(
-                        (bsz, ), self.args.src_len, dtype=torch.long
+                "id": 1,
+                "net_input": {
+                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                    "src_lengths": torch.full(
+                        (bsz,), self.args.src_len, dtype=torch.long
                     ),
-                    'prev_output_tokens': tgt.clone(),
+                    "prev_output_tokens": tgt.clone(),
                 },
-                'target': tgt,
-                'nsentences': bsz,
-                'ntokens': bsz * self.args.tgt_len,
+                "target": tgt,
+                "nsentences": bsz,
+                "ntokens": bsz * self.args.tgt_len,
             },
             num_items=self.args.dataset_size,
             item_size=item_size,
@@ -88,7 +86,6 @@ class DummyMTTask(LegacyFairseqTask):


 class DummyDataset(FairseqDataset):
-
     def __init__(self, batch, num_items, item_size):
         super().__init__()
         self.batch = batch
@@ -6,9 +6,10 @@
 import os
 from collections import Counter

-from fairseq.tokenizer import tokenize_line
 import torch
 from fairseq.file_io import PathManager
+from fairseq.tokenizer import tokenize_line
+

 def safe_readline(f):
     pos = f.tell()
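The hunk above shows isort's section handling rather than black's: standard-library imports stay on top, separated by a blank line, and the `fairseq.*` imports are gathered into a contiguous alphabetized block after `torch`; black independently guarantees two blank lines before the top-level `def`. The resulting layout, sketched in isolation (this assumes `torch` and `fairseq` are importable; the repo's exact isort section configuration is not shown in this excerpt):

# 1. standard library, alphabetized
import os
from collections import Counter

# 2. installed packages, with the project's own fairseq modules
#    sorted into an alphabetical block after torch
import torch
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line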
@@ -67,12 +67,14 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
         or is_better(val_loss, save_checkpoint.best)
     )
     if val_loss is not None and args.keep_best_checkpoints > 0:
-        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
-            args.best_checkpoint_metric, val_loss)] = (
-                not hasattr(save_checkpoint, "best")
-                or is_better(val_loss, save_checkpoint.best)
+        checkpoint_conds[
+            "checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss)
+        ] = not hasattr(save_checkpoint, "best") or is_better(
+            val_loss, save_checkpoint.best
         )
-    checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = not args.no_last_checkpoints
+    checkpoint_conds[
+        "checkpoint_last{}.pt".format(suffix)
+    ] = not args.no_last_checkpoints

     extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
     if hasattr(save_checkpoint, "best"):
@@ -112,10 +114,14 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     if args.keep_best_checkpoints > 0:
         # only keep the best N checkpoints according to validation metric
         checkpoints = checkpoint_paths(
-            args.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(args.best_checkpoint_metric))
+            args.save_dir,
+            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
+                args.best_checkpoint_metric
+            ),
+        )
         if not args.maximize_best_checkpoint_metric:
             checkpoints = checkpoints[::-1]
-        for old_chk in checkpoints[args.keep_best_checkpoints:]:
+        for old_chk in checkpoints[args.keep_best_checkpoints :]:
             if os.path.lexists(old_chk):
                 os.remove(old_chk)
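The last change in this hunk is black's PEP 8 slice rule: when a slice bound is anything more complex than a bare name or literal, the colon is treated like a binary operator and spaced on both sides. In isolation:

from types import SimpleNamespace

args = SimpleNamespace(keep_best_checkpoints=2)
checkpoints = ["a.pt", "b.pt", "c.pt"]
keep = 2

print(checkpoints[keep:])                         # bare name: no spaces
print(checkpoints[args.keep_best_checkpoints :])  # attribute access: spaced colon
# `args.keep_best_checkpoints` is an attribute lookup, which is why the
# diff gains the space in `checkpoints[args.keep_best_checkpoints :]`.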
@@ -133,16 +139,23 @@ def load_checkpoint(args, trainer, **passthrough_args):
     reset_meters = args.reset_meters
     reset_dataloader = args.reset_dataloader

-    if getattr(args, 'finetune_from_model', None) is not None \
-       and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader):
-        raise ValueError("--finetune-from-model can not be set together with either --reset-optimizer"
-                         " or reset_lr_scheduler or reset_meters or reset_dataloader")
+    if getattr(args, "finetune_from_model", None) is not None and (
+        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
+    ):
+        raise ValueError(
+            "--finetune-from-model can not be set together with either --reset-optimizer"
+            " or reset_lr_scheduler or reset_meters or reset_dataloader"
+        )

     suffix = getattr(args, "checkpoint_suffix", "")
-    if args.restore_file == "checkpoint_last.pt":  # default value of restore_file is 'checkpoint_last.pt'
-        checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix))
+    if (
+        args.restore_file == "checkpoint_last.pt"
+    ):  # default value of restore_file is 'checkpoint_last.pt'
+        checkpoint_path = os.path.join(
+            args.save_dir, "checkpoint_last{}.pt".format(suffix)
+        )
         first_launch = not PathManager.exists(checkpoint_path)
-        if getattr(args, 'finetune_from_model', None) is not None and first_launch:
+        if getattr(args, "finetune_from_model", None) is not None and first_launch:
             # if there is no last checkpoint to restore, start the finetune from pretrained model
             # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
             if PathManager.exists(args.finetune_from_model):
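The pattern above repeats throughout this file: black eliminates backslash continuations by re-wrapping a long condition inside the `if (...)` parentheses, one clause per line, and when a short test only overflows because of its trailing comment, it wraps the bare test in parentheses so the comment can stay put. A runnable before/after sketch (`args` here is just a stand-in for fairseq's parsed CLI namespace):

from types import SimpleNamespace

args = SimpleNamespace(finetune_from_model=None)
reset_optimizer = reset_meters = False

# The style black removes: a backslash continuation.
if getattr(args, "finetune_from_model", None) is not None \
        and (reset_optimizer or reset_meters):
    print("conflicting flags")

# The style black produces: the condition re-wrapped in parentheses,
# one clause per line, with no continuation character.
if getattr(args, "finetune_from_model", None) is not None and (
    reset_optimizer or reset_meters
):
    print("conflicting flags")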
@@ -151,19 +164,26 @@ def load_checkpoint(args, trainer, **passthrough_args):
                 reset_lr_scheduler = True
                 reset_meters = True
                 reset_dataloader = True
-                logger.info(f'loading pretrained model from {checkpoint_path}: '
-                            'optimizer, lr scheduler, meters, dataloader will be reset')
+                logger.info(
+                    f"loading pretrained model from {checkpoint_path}: "
+                    "optimizer, lr scheduler, meters, dataloader will be reset"
+                )
             else:
-                raise ValueError(f'--funetune-from-model {args.finetune_from_model} does not exist')
+                raise ValueError(
+                    f"--funetune-from-model {args.finetune_from_model} does not exist"
+                )
     elif getattr(args, "model_parallel_size", 1) > 1:
         checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
     else:
         checkpoint_path = args.restore_file

-    if args.restore_file != "checkpoint_last.pt" and getattr(args, 'finetune_from_model', None):
+    if args.restore_file != "checkpoint_last.pt" and getattr(
+        args, "finetune_from_model", None
+    ):
         raise ValueError(
-            '--finetune-from-model and --restore-file (non-default value) '
-            'can not be specified together: ' + str(args))
+            "--finetune-from-model and --restore-file (non-default value) "
+            "can not be specified together: " + str(args)
+        )

     extra_state = trainer.load_checkpoint(
         checkpoint_path,
@@ -213,7 +233,9 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
     return state


-def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1):
+def load_model_ensemble(
+    filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
+):
     """Loads an ensemble of models.

     Args:
@@ -222,18 +244,28 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, s
             were used during model training
         task (fairseq.tasks.FairseqTask, optional): task to use for loading
     """
-    assert not (strict and num_shards > 1), \
-        "Cannot load state dict with strict=True and checkpoint shards > 1"
+    assert not (
+        strict and num_shards > 1
+    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
     ensemble, args, _task = load_model_ensemble_and_task(
-        filenames, arg_overrides, task, strict, suffix, num_shards,
+        filenames,
+        arg_overrides,
+        task,
+        strict,
+        suffix,
+        num_shards,
     )
     return ensemble, args


-def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1):
+def load_model_ensemble_and_task(
+    filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
+):
     from fairseq import tasks
-    assert not (strict and num_shards > 1), \
-        "Cannot load state dict with strict=True and checkpoint shards > 1"
+
+    assert not (
+        strict and num_shards > 1
+    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
     ensemble = []
     for filename in filenames:
         orig_filename = filename
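The exploded call to `load_model_ensemble_and_task` above is black's "magic trailing comma" at work: because the original call already ended with a trailing comma, black keeps it expanded with one argument per line instead of packing the arguments onto one line. A self-contained sketch (`combine` is a hypothetical stand-in):

def combine(*parts):
    return sum(parts)

# The trailing comma after the last argument tells black to keep this
# call exploded, even though it would fit on a single line.
result = combine(
    1,
    2,
    3,
)
print(result)

# Without that trailing comma, black collapses the call back to:
# result = combine(1, 2, 3)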
@@ -533,7 +565,9 @@ def verify_checkpoint_directory(save_dir: str) -> None:
         with open(temp_file_path, "w"):
             pass
     except OSError as e:
-        logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
+        logger.warning(
+            "Unable to access checkpoint save directory: {}".format(save_dir)
+        )
         raise e
     else:
         os.remove(temp_file_path)
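Changes of this kind are mechanical, so a diff like this one is normally regenerated rather than reviewed line by line. A minimal sketch of applying both tools to a checkout from Python, equivalent to running `isort .` followed by `black .` at the repository root (any project-specific configuration, which this excerpt does not show, would be picked up from the repo's own config files):

import subprocess

# isort runs first so that black gets the final say on line wrapping.
for tool in (["isort", "."], ["black", "."]):
    subprocess.run(tool, check=True)  # each formatter rewrites files in place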
Some files were not shown because too many files have changed in this diff.