Apply black+isort (#1357)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1357

Reviewed By: alexeib

Differential Revision: D24377772

fbshipit-source-id: 51581af041d42d62166b33a35a1a4228b1a76f0c
Myle Ott 2020-10-18 18:13:29 -07:00 committed by Facebook GitHub Bot
parent 5695cdfb2c
commit a48f235636
396 changed files with 15418 additions and 9810 deletions

View File

@@ -20,10 +20,11 @@
import os
import sys
# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath('..'))
source_suffix = ['.rst']
# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath(".."))
source_suffix = [".rst"]
# -- General configuration ------------------------------------------------
@@ -35,34 +36,34 @@ source_suffix = ['.rst']
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.intersphinx',
'sphinx.ext.viewcode',
'sphinx.ext.napoleon',
'sphinxarg.ext',
"sphinx.ext.autodoc",
"sphinx.ext.intersphinx",
"sphinx.ext.viewcode",
"sphinx.ext.napoleon",
"sphinxarg.ext",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# General information about the project.
project = 'fairseq'
copyright = '2019, Facebook AI Research (FAIR)'
author = 'Facebook AI Research (FAIR)'
project = "fairseq"
copyright = "2019, Facebook AI Research (FAIR)"
author = "Facebook AI Research (FAIR)"
github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = '0.9.0'
version = "0.9.0"
# The full version, including alpha/beta/rc tags.
release = '0.9.0'
release = "0.9.0"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -74,11 +75,11 @@ language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# These patterns also affect html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
highlight_language = 'python'
pygments_style = "sphinx"
highlight_language = "python"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
@@ -89,7 +90,7 @@ todo_include_todos = False
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
@@ -100,11 +101,11 @@ html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]
html_context = {
'css_files': [
'_static/theme_overrides.css', # override wide tables in RTD theme
"css_files": [
"_static/theme_overrides.css", # override wide tables in RTD theme
],
}
@@ -113,7 +114,7 @@ html_context = {
#
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
#html_sidebars = {
# html_sidebars = {
# '**': [
# 'about.html',
# 'navigation.html',
@@ -121,12 +122,12 @@ html_context = {
# 'searchbox.html',
# 'donate.html',
# ]
#}
# }
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'python': ('https://docs.python.org/', None),
'torch': ('https://pytorch.org/docs/master/', None),
"numpy": ("http://docs.scipy.org/doc/numpy/", None),
"python": ("https://docs.python.org/", None),
"torch": ("https://pytorch.org/docs/master/", None),
}

View File

@@ -3,6 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
__version__ = '0.9.0'
__version__ = "0.9.0"
import examples.noisychannel # noqa

View File

@@ -7,8 +7,8 @@
import argparse
import fileinput
import hashlib
from multiprocessing import Pool
import sys
from multiprocessing import Pool
def get_hashes_and_lines(raw_line):
@@ -18,12 +18,12 @@ def get_hashes_and_lines(raw_line):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--workers', type=int, default=10)
parser.add_argument('files', nargs='*', help='input files')
parser.add_argument("--workers", type=int, default=10)
parser.add_argument("files", nargs="*", help="input files")
args = parser.parse_args()
seen = set()
with fileinput.input(args.files, mode='rb') as h:
with fileinput.input(args.files, mode="rb") as h:
pool = Pool(args.workers)
results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
for i, (hash, raw_line) in enumerate(results):
@@ -37,5 +37,5 @@ def main():
print(file=sys.stderr, flush=True)
if __name__ == '__main__':
if __name__ == "__main__":
main()
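The body of get_hashes_and_lines is elided by the hunk above; a minimal sketch of what it presumably does, assuming an MD5 digest over the raw bytes (main() then keeps only the first line seen per digest):

import hashlib

def get_hashes_and_lines(raw_line):
    # digest the raw bytes so identical lines map to the same key
    return hashlib.md5(raw_line).hexdigest(), raw_line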

View File

@@ -11,26 +11,38 @@ from tqdm import tqdm
def main():
parser = argparse.ArgumentParser(description=(
'Extract back-translations from the stdout of fairseq-generate. '
'If there are multiple hypotheses for a source, we only keep the first one. '
))
parser.add_argument('--output', required=True, help='output prefix')
parser.add_argument('--srclang', required=True, help='source language (extracted from H-* lines)')
parser.add_argument('--tgtlang', required=True, help='target language (extracted from S-* lines)')
parser.add_argument('--minlen', type=int, help='min length filter')
parser.add_argument('--maxlen', type=int, help='max length filter')
parser.add_argument('--ratio', type=float, help='ratio filter')
parser.add_argument('files', nargs='*', help='input files')
parser = argparse.ArgumentParser(
description=(
"Extract back-translations from the stdout of fairseq-generate. "
"If there are multiply hypotheses for a source, we only keep the first one. "
)
)
parser.add_argument("--output", required=True, help="output prefix")
parser.add_argument(
"--srclang", required=True, help="source language (extracted from H-* lines)"
)
parser.add_argument(
"--tgtlang", required=True, help="target language (extracted from S-* lines)"
)
parser.add_argument("--minlen", type=int, help="min length filter")
parser.add_argument("--maxlen", type=int, help="max length filter")
parser.add_argument("--ratio", type=float, help="ratio filter")
parser.add_argument("files", nargs="*", help="input files")
args = parser.parse_args()
def validate(src, tgt):
srclen = len(src.split(' ')) if src != '' else 0
tgtlen = len(tgt.split(' ')) if tgt != '' else 0
srclen = len(src.split(" ")) if src != "" else 0
tgtlen = len(tgt.split(" ")) if tgt != "" else 0
if (
(args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
or (args.maxlen is not None and (srclen > args.maxlen or tgtlen > args.maxlen))
or (args.ratio is not None and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio))
or (
args.maxlen is not None
and (srclen > args.maxlen or tgtlen > args.maxlen)
)
or (
args.ratio is not None
and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
)
):
return False
return True
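A quick worked reading of the filter above (settings hypothetical): with --ratio 1.5, a 10-token source paired with a 20-token target gives max(20, 10) / min(20, 10) = 2.0 > 1.5, so validate() rejects the pair; 10 tokens against 14 gives 1.4 and the pair is kept.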
@@ -41,19 +53,20 @@ def main():
except IndexError:
return default
with open(args.output + '.' + args.srclang, 'w') as src_h, \
open(args.output + '.' + args.tgtlang, 'w') as tgt_h:
with open(args.output + "." + args.srclang, "w") as src_h, open(
args.output + "." + args.tgtlang, "w"
) as tgt_h:
for line in tqdm(fileinput.input(args.files)):
if line.startswith('S-'):
tgt = safe_index(line.rstrip().split('\t'), 1, '')
elif line.startswith('H-'):
if line.startswith("S-"):
tgt = safe_index(line.rstrip().split("\t"), 1, "")
elif line.startswith("H-"):
if tgt is not None:
src = safe_index(line.rstrip().split('\t'), 2, '')
src = safe_index(line.rstrip().split("\t"), 2, "")
if validate(src, tgt):
print(src, file=src_h)
print(tgt, file=tgt_h)
tgt = None
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -4,203 +4,251 @@
# LICENSE file in the root directory of this source tree.
import os.path as op
import argparse
import os
from multiprocessing import cpu_count
import os.path as op
from collections import namedtuple
from typing import Optional, List
from multiprocessing import cpu_count
from typing import List, Optional
import sentencepiece as sp
from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
from fairseq.data.encoders.byte_utils import byte_encode
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
from fairseq.data.encoders.characters import Characters
from fairseq.data.encoders.byte_bpe import ByteBPE
from fairseq.data.encoders.byte_utils import byte_encode
from fairseq.data.encoders.bytes import Bytes
from fairseq.data.encoders.characters import Characters
from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
SPLITS = ['train', 'valid', 'test']
SPLITS = ["train", "valid", "test"]
def _convert_xml(in_path: str, out_path: str):
with open(in_path) as f, open(out_path, 'w') as f_o:
with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
ss = s.strip()
if not ss.startswith('<seg'):
if not ss.startswith("<seg"):
continue
ss = ss.replace('</seg>', '').split('">')
ss = ss.replace("</seg>", "").split('">')
assert len(ss) == 2
f_o.write(ss[1].strip() + '\n')
f_o.write(ss[1].strip() + "\n")
def _convert_train(in_path: str, out_path: str):
with open(in_path) as f, open(out_path, 'w') as f_o:
with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
ss = s.strip()
if ss.startswith('<'):
if ss.startswith("<"):
continue
f_o.write(ss.strip() + '\n')
f_o.write(ss.strip() + "\n")
def _get_bytes(in_path: str, out_path: str):
with open(in_path) as f, open(out_path, 'w') as f_o:
with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
f_o.write(Bytes.encode(s.strip()) + '\n')
f_o.write(Bytes.encode(s.strip()) + "\n")
def _get_chars(in_path: str, out_path: str):
with open(in_path) as f, open(out_path, 'w') as f_o:
with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
f_o.write(Characters.encode(s.strip()) + '\n')
f_o.write(Characters.encode(s.strip()) + "\n")
def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
Args = namedtuple('Args', ['moses_source_lang', 'moses_target_lang',
'moses_no_dash_splits', 'moses_no_escape'])
args = Args(moses_source_lang=src, moses_target_lang=tgt,
moses_no_dash_splits=False, moses_no_escape=False)
Args = namedtuple(
"Args",
[
"moses_source_lang",
"moses_target_lang",
"moses_no_dash_splits",
"moses_no_escape",
],
)
args = Args(
moses_source_lang=src,
moses_target_lang=tgt,
moses_no_dash_splits=False,
moses_no_escape=False,
)
pretokenizer = MosesTokenizer(args)
with open(in_path) as f, open(out_path, 'w') as f_o:
with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
f_o.write(pretokenizer.encode(s.strip()) + '\n')
f_o.write(pretokenizer.encode(s.strip()) + "\n")
def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
with open(out_path, 'w') as f_o:
with open(out_path, "w") as f_o:
for lang in [src, tgt]:
with open(f'{in_path_prefix}.{lang}') as f:
with open(f"{in_path_prefix}.{lang}") as f:
for s in f:
f_o.write(byte_encode(s.strip()) + '\n')
f_o.write(byte_encode(s.strip()) + "\n")
def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
arguments = [
f'--input={in_path}', f'--model_prefix={model_prefix}',
f'--model_type=bpe', f'--vocab_size={vocab_size}',
'--character_coverage=1.0', '--normalization_rule_name=identity',
f'--num_threads={cpu_count()}'
f"--input={in_path}",
f"--model_prefix={model_prefix}",
f"--model_type=bpe",
f"--vocab_size={vocab_size}",
"--character_coverage=1.0",
"--normalization_rule_name=identity",
f"--num_threads={cpu_count()}",
]
sp.SentencePieceTrainer.Train(' '.join(arguments))
sp.SentencePieceTrainer.Train(" ".join(arguments))
def _apply_bbpe(model_path: str, in_path: str, out_path: str):
Args = namedtuple('Args', ['sentencepiece_model_path'])
Args = namedtuple("Args", ["sentencepiece_model_path"])
args = Args(sentencepiece_model_path=model_path)
tokenizer = ByteBPE(args)
with open(in_path) as f, open(out_path, 'w') as f_o:
with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
f_o.write(tokenizer.encode(s.strip()) + '\n')
f_o.write(tokenizer.encode(s.strip()) + "\n")
def _apply_bpe(model_path: str, in_path: str, out_path: str):
Args = namedtuple('Args', ['sentencepiece_model'])
Args = namedtuple("Args", ["sentencepiece_model"])
args = Args(sentencepiece_model=model_path)
tokenizer = SentencepieceBPE(args)
with open(in_path) as f, open(out_path, 'w') as f_o:
with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
f_o.write(tokenizer.encode(s.strip()) + '\n')
f_o.write(tokenizer.encode(s.strip()) + "\n")
def _concat_files(in_paths: List[str], out_path: str):
with open(out_path, 'w') as f_o:
with open(out_path, "w") as f_o:
for p in in_paths:
with open(p) as f:
for r in f:
f_o.write(r)
def preprocess_iwslt17(root: str, src: str, tgt: str, bpe_size: Optional[int],
need_chars: bool, bbpe_size: Optional[int],
need_bytes: bool):
def preprocess_iwslt17(
root: str,
src: str,
tgt: str,
bpe_size: Optional[int],
need_chars: bool,
bbpe_size: Optional[int],
need_bytes: bool,
):
# extract bitext
in_root = op.join(root, f'{src}-{tgt}')
in_root = op.join(root, f"{src}-{tgt}")
for lang in [src, tgt]:
_convert_train(
op.join(in_root, f'train.tags.{src}-{tgt}.{lang}'),
op.join(root, f'train.{lang}')
op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
op.join(root, f"train.{lang}"),
)
_convert_xml(
op.join(in_root, f'IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml'),
op.join(root, f'valid.{lang}')
op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
op.join(root, f"valid.{lang}"),
)
_convert_xml(
op.join(in_root, f'IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml'),
op.join(root, f'test.{lang}')
op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
op.join(root, f"test.{lang}"),
)
# pre-tokenize
for lang in [src, tgt]:
for split in SPLITS:
pretokenize(op.join(root, f'{split}.{lang}'),
op.join(root, f'{split}.moses.{lang}'), src, tgt)
pretokenize(
op.join(root, f"{split}.{lang}"),
op.join(root, f"{split}.moses.{lang}"),
src,
tgt,
)
# tokenize with BPE vocabulary
if bpe_size is not None:
# learn vocabulary
concated_train_path = op.join(root, 'train.all')
concated_train_path = op.join(root, "train.all")
_concat_files(
[op.join(root, 'train.moses.fr'), op.join(root, 'train.moses.en')],
concated_train_path
[op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")],
concated_train_path,
)
bpe_model_prefix = op.join(root, f'spm_bpe{bpe_size}')
bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
_get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
os.remove(concated_train_path)
# apply
for lang in [src, tgt]:
for split in SPLITS:
_apply_bpe(
bpe_model_prefix + '.model',
op.join(root, f'{split}.moses.{lang}'),
op.join(root, f'{split}.moses.bpe{bpe_size}.{lang}')
bpe_model_prefix + ".model",
op.join(root, f"{split}.moses.{lang}"),
op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
)
# tokenize with bytes vocabulary
if need_bytes:
for lang in [src, tgt]:
for split in SPLITS:
_get_bytes(op.join(root, f'{split}.moses.{lang}'),
op.join(root, f'{split}.moses.bytes.{lang}'))
_get_bytes(
op.join(root, f"{split}.moses.{lang}"),
op.join(root, f"{split}.moses.bytes.{lang}"),
)
# tokenize with characters vocabulary
if need_chars:
for lang in [src, tgt]:
for split in SPLITS:
_get_chars(op.join(root, f'{split}.moses.{lang}'),
op.join(root, f'{split}.moses.chars.{lang}'))
_get_chars(
op.join(root, f"{split}.moses.{lang}"),
op.join(root, f"{split}.moses.chars.{lang}"),
)
# tokenize with byte-level BPE vocabulary
if bbpe_size is not None:
# learn vocabulary
bchar_path = op.join(root, 'train.bchar')
_convert_to_bchar(op.join(root, 'train.moses'), src, tgt, bchar_path)
bbpe_model_prefix = op.join(root, f'spm_bbpe{bbpe_size}')
bchar_path = op.join(root, "train.bchar")
_convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
_get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
os.remove(bchar_path)
# apply
for lang in [src, tgt]:
for split in SPLITS:
_apply_bbpe(
bbpe_model_prefix + '.model',
op.join(root, f'{split}.moses.{lang}'),
op.join(root, f'{split}.moses.bbpe{bbpe_size}.{lang}')
bbpe_model_prefix + ".model",
op.join(root, f"{split}.moses.{lang}"),
op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, default='data')
parser.add_argument('--bpe-vocab', default=None, type=int,
help='Generate tokenized bitext with BPE of size K. '
'Defaults to None (disabled).')
parser.add_argument('--bbpe-vocab', default=None, type=int,
help='Generate tokenized bitext with BBPE of size K. '
'Defaults to None (disabled).')
parser.add_argument('--byte-vocab', action='store_true',
help='Generate tokenized bitext with bytes vocabulary')
parser.add_argument('--char-vocab', action='store_true',
help='Generate tokenized bitext with chars vocabulary')
parser.add_argument("--root", type=str, default="data")
parser.add_argument(
"--bpe-vocab",
default=None,
type=int,
help="Generate tokenized bitext with BPE of size K."
"Default to None (disabled).",
)
parser.add_argument(
"--bbpe-vocab",
default=None,
type=int,
help="Generate tokenized bitext with BBPE of size K."
"Default to None (disabled).",
)
parser.add_argument(
"--byte-vocab",
action="store_true",
help="Generate tokenized bitext with bytes vocabulary",
)
parser.add_argument(
"--char-vocab",
action="store_true",
help="Generate tokenized bitext with chars vocabulary",
)
args = parser.parse_args()
preprocess_iwslt17(args.root, 'fr', 'en', args.bpe_vocab, args.char_vocab,
args.bbpe_vocab, args.byte_vocab)
preprocess_iwslt17(
args.root,
"fr",
"en",
args.bpe_vocab,
args.char_vocab,
args.bbpe_vocab,
args.byte_vocab,
)
if __name__ == '__main__':
if __name__ == "__main__":
main()
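For orientation, a hedged sketch of calling the entry point above directly (the module name and vocabulary sizes are hypothetical; the IWSLT17 fr-en data is assumed to be extracted under data/):

from get_bitext import preprocess_iwslt17  # module name assumed

preprocess_iwslt17(
    "data", "fr", "en",
    bpe_size=16000,   # writes {split}.moses.bpe16000.{lang}
    need_chars=True,  # writes {split}.moses.chars.{lang}
    bbpe_size=2048,   # writes {split}.moses.bbpe2048.{lang}
    need_bytes=True,  # writes {split}.moses.bytes.{lang}
)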

View File

@@ -11,7 +11,7 @@
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerModel, TransformerEncoder
from fairseq.models.transformer import TransformerEncoder, TransformerModel
@register_model("gru_transformer")
@@ -24,9 +24,12 @@ class GRUTransformerModel(TransformerModel):
class GRUTransformerEncoder(TransformerEncoder):
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
self.emb_ctx = nn.GRU(input_size=embed_tokens.embedding_dim,
hidden_size=embed_tokens.embedding_dim // 2,
num_layers=1, bidirectional=True)
self.emb_ctx = nn.GRU(
input_size=embed_tokens.embedding_dim,
hidden_size=embed_tokens.embedding_dim // 2,
num_layers=1,
bidirectional=True,
)
def forward_embedding(self, src_tokens):
# embed tokens and positions

View File

@@ -16,11 +16,12 @@ def main(args):
print(normalizer.normalize(line.rstrip()), flush=True)
if __name__ == '__main__':
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--lang', '-l', default='en')
parser.add_argument('--penn', '-p', action='store_true')
parser.add_argument("--lang", "-l", default="en")
parser.add_argument("--penn", "-p", action="store_true")
args = parser.parse_args()
main(args)

View File

@@ -6,12 +6,14 @@
# LICENSE file in the root directory of this source tree.
import sys
import sacremoses
def main(args):
"""Tokenizes, preserving tabs"""
mt = sacremoses.MosesTokenizer(lang=args.lang)
def tok(s):
return mt.tokenize(s, return_str=True)
@@ -20,12 +22,13 @@ def main(args):
print(*parts, sep="\t", flush=True)
if __name__ == '__main__':
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--lang', '-l', default='en')
parser.add_argument('--penn', '-p', action='store_true')
parser.add_argument('--fields', '-f', help="fields to tokenize")
parser.add_argument("--lang", "-l", default="en")
parser.add_argument("--penn", "-p", action="store_true")
parser.add_argument("--fields", "-f", help="fields to tokenize")
args = parser.parse_args()
main(args)
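A quick illustration of the tab-preserving behavior above (example strings hypothetical): each tab-separated field is presumably passed through the Moses tokenizer on its own and the fields are re-joined with sep="\t", so a line like "Hello, world!\tmeta field" comes out as "Hello , world !\tmeta field" rather than having the tab swallowed by the tokenizer.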

View File

@@ -3,14 +3,15 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import faiss
import numpy as np
import glob
import argparse
import glob
from subprocess import check_call
import faiss
import numpy as np
GB = 1024*1024*1024
GB = 1024 * 1024 * 1024
def call(cmd):
@@ -18,14 +19,14 @@ def call(cmd):
check_call(cmd, shell=True)
def get_batches(directory, lang, prefix='all_avg_pool'):
def get_batches(directory, lang, prefix="all_avg_pool"):
print(f"Finding in {directory}/{prefix}.{lang}*")
files = glob.glob(f'{directory}/{prefix}.{lang}*')
files = glob.glob(f"{directory}/{prefix}.{lang}*")
emb_files = []
txt_files = []
for emb_fi in files:
emb_files.append(emb_fi)
txt_fi = emb_fi.replace(prefix, 'sentences')
txt_fi = emb_fi.replace(prefix, "sentences")
txt_files.append(txt_fi)
return emb_files, txt_files
@@ -38,7 +39,7 @@ def load_batch(emb_file, dim):
return embeddings
def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'):
def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
sims = []
inds = []
xfrom = 0
@@ -53,7 +54,7 @@ def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'):
y_batch = load_batch(y_batch_f, dim)
neighbor_size = min(k, y_batch.shape[0])
yto = yfrom + y_batch.shape[0]
print('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto))
print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto))
idx = faiss.IndexFlatIP(dim)
idx = faiss.index_cpu_to_all_gpus(idx)
idx.add(y_batch)
@@ -86,8 +87,10 @@ def score(sim, fwd_mean, bwd_mean, margin):
return margin(sim, (fwd_mean + bwd_mean) / 2)
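score() implements the ratio-margin criterion from margin-based bitext mining: the raw similarity is divided by the average of the mean similarities to the candidate's forward and backward neighborhoods, which penalizes "hub" sentences that are close to everything. A small numeric sketch (values hypothetical) with the same margin defined further down:

margin = lambda a, b: a / b
sim, fwd_mean, bwd_mean = 0.82, 0.60, 0.70
print(margin(sim, (fwd_mean + bwd_mean) / 2))  # 0.82 / 0.65 ~= 1.26, above the 1.06 default threshold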
def score_candidates(sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False):
print(' - scoring {:d} candidates'.format(sim_mat.shape[0]))
def score_candidates(
sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
):
print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
scores = np.zeros(candidate_inds.shape)
for i in range(scores.shape[0]):
for j in range(scores.shape[1]):
@@ -106,42 +109,50 @@ def load_text(files):
return all_sentences
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Mine bitext')
parser.add_argument('--src-lang', help='Source language')
parser.add_argument('--tgt-lang', help='Target language')
parser.add_argument('--dict-path', help='Path to dictionary file', default='dict.txt')
parser.add_argument('--spm-path', help='Path to SPM model file', default='sentence.bpe.model')
parser.add_argument('--dim', type=int, default=1024,
help='Embedding dimension')
parser.add_argument('--mem', type=int, default=5,
help='Memory in GB')
parser.add_argument('--src-dir', help='Source directory')
parser.add_argument('--tgt-dir', help='Target directory')
parser.add_argument('--output', help='Output path')
parser.add_argument('--neighborhood', type=int, default=4,
help='Neighborhood size for kNN search')
parser.add_argument('--threshold', type=float, default=1.06,
help='Threshold on mined bitext')
parser.add_argument('--valid-size', type=int, default=2000,
help='Number of sentences used for validation set')
parser.add_argument('--min-count', type=int, default=50000,
help='Min num sentences used for each language')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Mine bitext")
parser.add_argument("--src-lang", help="Source language")
parser.add_argument("--tgt-lang", help="Target language")
parser.add_argument(
"--dict-path", help="Path to dictionary file", default="dict.txt"
)
parser.add_argument(
"--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
)
parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
parser.add_argument("--src-dir", help="Source directory")
parser.add_argument("--tgt-dir", help="Target directory")
parser.add_argument("--output", help="Output path")
parser.add_argument(
"--neighborhood", type=int, default=4, help="Embedding dimension"
)
parser.add_argument(
"--threshold", type=float, default=1.06, help="Threshold on mined bitext"
)
parser.add_argument(
"--valid-size",
type=int,
default=2000,
help="Number of sentences used for validation set",
)
parser.add_argument(
"--min-count",
type=int,
default=50000,
help="Min num sentences used for each language",
)
args = parser.parse_args()
x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
margin = lambda a, b: a / b
y2x_sim, y2x_ind = knnGPU_sharded(
y_batches_f, x_batches_f,
args.dim,
args.neighborhood,
direction='y2x')
y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
)
x2y_sim, x2y_ind = knnGPU_sharded(
x_batches_f, y_batches_f,
args.dim,
args.neighborhood,
direction='x2y')
x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
)
x2y_mean = x2y_sim.mean(axis=1)
y2x_mean = y2x_sim.mean(axis=1)
@@ -149,8 +160,13 @@ if __name__ == '__main__':
bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
indices = np.stack((np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
np.concatenate((fwd_best, np.arange(y2x_ind.shape[0])))), axis=1)
indices = np.stack(
(
np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
),
axis=1,
)
scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
x_sentences = load_text(x_sents_f)
@@ -162,20 +178,20 @@ if __name__ == '__main__':
directory = args.output
call(f"mkdir -p {directory}")
src_out = open(
f'{directory}/all.{args.src_lang}',
mode='w',
encoding='utf-8',
errors='surrogateescape')
f"{directory}/all.{args.src_lang}",
mode="w",
encoding="utf-8",
errors="surrogateescape",
)
tgt_out = open(
f'{directory}/all.{args.tgt_lang}',
mode='w',
encoding='utf-8',
errors='surrogateescape')
f"{directory}/all.{args.tgt_lang}",
mode="w",
encoding="utf-8",
errors="surrogateescape",
)
scores_out = open(
f'{directory}/all.scores',
mode='w',
encoding='utf-8',
errors='surrogateescape')
f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
)
count = 0
for i in np.argsort(-scores):
src_ind, trg_ind = indices[i]
@@ -195,20 +211,23 @@ if __name__ == '__main__':
scores_out.close()
print(f"Found {count} pairs for threshold={threshold}")
with open(f'{directory}/all.{args.src_lang}') as all_s, \
open(f'{directory}/all.{args.tgt_lang}') as all_t, \
open(f'{directory}/valid.{args.src_lang}', 'w') as valid_s, \
open(f'{directory}/valid.{args.tgt_lang}', 'w') as valid_t, \
open(f'{directory}/train.{args.src_lang}', 'w') as train_s, \
open(f'{directory}/train.{args.tgt_lang}', 'w') as train_t:
count = 0
for s_line, t_line in zip(all_s, all_t):
s_line = s_line.split('\t')[1]
t_line = t_line.split('\t')[1]
if count >= args.valid_size:
train_s.write(s_line)
train_t.write(t_line)
else:
valid_s.write(s_line)
valid_t.write(t_line)
count += 1
with open(f"{directory}/all.{args.src_lang}") as all_s, open(
f"{directory}/all.{args.tgt_lang}"
) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
f"{directory}/valid.{args.tgt_lang}", "w"
) as valid_t, open(
f"{directory}/train.{args.src_lang}", "w"
) as train_s, open(
f"{directory}/train.{args.tgt_lang}", "w"
) as train_t:
count = 0
for s_line, t_line in zip(all_s, all_t):
s_line = s_line.split("\t")[1]
t_line = t_line.split("\t")[1]
if count >= args.valid_size:
train_s.write(s_line)
train_t.write(t_line)
else:
valid_s.write(s_line)
valid_t.write(t_line)
count += 1

View File

@@ -7,27 +7,29 @@
Translate pre-processed data with a trained model.
"""
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.sequence_generator import EnsembleModel
import numpy as np
def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False):
def get_avg_pool(
models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
):
model = EnsembleModel(models)
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample['net_input'].items()
if k != 'prev_output_tokens'
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
# compute the encoder output for each beam
encoder_outs = model.forward_encoder(encoder_input)
np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(np.float32)
encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
np.float32
)
encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
if has_langtok:
encoder_mask = encoder_mask[1:, :, :]
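The pooling step that follows is cut off by the hunk; presumably it averages the encoder states over the non-padded positions. A sketch of masked mean pooling under that assumption:

import numpy as np

def masked_average(np_encoder_outs, encoder_mask):
    # np_encoder_outs: (seq_len, batch, dim); encoder_mask: (seq_len, batch, 1)
    masked = np_encoder_outs * encoder_mask      # zero out padding
    return masked.sum(axis=0) / encoder_mask.sum(axis=0)  # (batch, dim)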
@@ -38,13 +40,15 @@ def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langto
def main(args):
assert args.path is not None, '--path required for generation!'
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert args.replace_unk is None or args.raw_text, \
'--replace-unk requires a raw text dataset (--raw-text)'
assert args.path is not None, "--path required for generation!"
assert (
not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
args.replace_unk is None or args.raw_text
), "--replace-unk requires a raw text dataset (--raw-text)"
args.beam=1
args.beam = 1
utils.import_user_module(args)
if args.max_tokens is None:
@@ -58,15 +62,15 @@
# Set dictionaries
try:
src_dict = getattr(task, 'source_dictionary', None)
src_dict = getattr(task, "source_dictionary", None)
except NotImplementedError:
src_dict = None
tgt_dict = task.target_dictionary
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
print("| loading model(s) from {}".format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(':'),
args.path.split(":"),
arg_overrides=eval(args.model_overrides),
task=task,
)
@@ -105,9 +109,9 @@ def main(args):
shard_id = 0
all_avg_pool = None
encoder_has_langtok = (
hasattr(task.args, 'encoder_langtok')
hasattr(task.args, "encoder_langtok")
and task.args.encoder_langtok is not None
and hasattr(task.args, 'lang_tok_replacing_bos_eos')
and hasattr(task.args, "lang_tok_replacing_bos_eos")
and not task.args.lang_tok_replacing_bos_eos
)
with progress_bar.build_progress_bar(args, itr) as t:
@@ -116,34 +120,42 @@
print("Skipping None")
continue
sample = utils.move_to_cuda(sample) if use_cuda else sample
if 'net_input' not in sample:
if "net_input" not in sample:
continue
prefix_tokens = None
if args.prefix_size > 0:
prefix_tokens = sample['target'][:, :args.prefix_size]
prefix_tokens = sample["target"][:, : args.prefix_size]
with torch.no_grad():
avg_pool = get_avg_pool(
models, sample, prefix_tokens, src_dict,
args.remove_bpe,
has_langtok=encoder_has_langtok)
models,
sample,
prefix_tokens,
src_dict,
args.remove_bpe,
has_langtok=encoder_has_langtok,
)
if all_avg_pool is not None:
all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
else:
all_avg_pool = avg_pool
if not isinstance(sample['id'], list):
sample_ids = sample['id'].tolist()
if not isinstance(sample["id"], list):
sample_ids = sample["id"].tolist()
else:
sample_ids = sample['id']
sample_ids = sample["id"]
for i, sample_id in enumerate(sample_ids):
# Remove padding
src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
src_tokens = utils.strip_pad(
sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
)
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
src_str = task.dataset(args.gen_subset).src.get_original_text(
sample_id
)
else:
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
@@ -152,37 +164,50 @@
if not args.quiet:
if src_dict is not None:
print('S-{}\t{}'.format(sample_id, src_str))
print("S-{}\t{}".format(sample_id, src_str))
source_sentences.append(f"{sample_id}\t{src_str}")
num_sentences += sample['nsentences']
num_sentences += sample["nsentences"]
if all_avg_pool.shape[0] >= 1000000:
with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
'w') as avg_pool_file:
with open(
f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
"w",
) as avg_pool_file:
all_avg_pool.tofile(avg_pool_file)
with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
sentence_file.writelines(f'{line}\n' for line in source_sentences)
with open(
f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
"w",
) as sentence_file:
sentence_file.writelines(f"{line}\n" for line in source_sentences)
all_avg_pool = None
source_sentences = []
shard_id += 1
if all_avg_pool is not None:
with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
'w') as avg_pool_file:
with open(
f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
) as avg_pool_file:
all_avg_pool.tofile(avg_pool_file)
with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
sentence_file.writelines(f'{line}\n' for line in source_sentences)
with open(
f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
) as sentence_file:
sentence_file.writelines(f"{line}\n" for line in source_sentences)
return None
def cli_main():
parser = options.get_generation_parser()
parser.add_argument('--encoder-save-dir', default='', type=str, metavar='N',
help='directory to save encoder outputs')
parser.add_argument(
"--encoder-save-dir",
default="",
type=str,
metavar="N",
help="directory to save encoder outputs",
)
args = options.parse_args_and_arch(parser)
main(args)
if __name__ == '__main__':
if __name__ == "__main__":
cli_main()

View File

@@ -3,10 +3,11 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import argparse
import glob
import numpy as np
DIM = 1024
@@ -14,9 +15,13 @@ DIM = 1024
def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
target_ids = [tid for tid in target_embs]
source_mat = np.stack(source_embs.values(), axis=0)
normalized_source_mat = source_mat / np.linalg.norm(source_mat, axis=1, keepdims=True)
normalized_source_mat = source_mat / np.linalg.norm(
source_mat, axis=1, keepdims=True
)
target_mat = np.stack(target_embs.values(), axis=0)
normalized_target_mat = target_mat / np.linalg.norm(target_mat, axis=1, keepdims=True)
normalized_target_mat = target_mat / np.linalg.norm(
target_mat, axis=1, keepdims=True
)
sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
if return_sim_mat:
return sim_mat
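Since both matrices are row-normalized, the dot product above is exactly cosine similarity, and the k-nearest-neighbor step that the rest of compute_dist presumably performs can be read off with an argsort; a sketch under that assumption:

import numpy as np

def top_k_neighbors(sim_mat, target_ids, k=5):
    # indices of the k most similar targets for each source row
    order = np.argsort(-sim_mat, axis=1)[:, :k]
    return [[target_ids[j] for j in row] for row in order]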
@@ -36,14 +41,14 @@ def load_embeddings(directory, LANGS):
lang_dir = f"{directory}/{lang}"
embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
for embed_file in embedding_files:
shard_id = embed_file.split('.')[-1]
shard_id = embed_file.split(".")[-1]
embeddings = np.fromfile(embed_file, dtype=np.float32)
num_rows = embeddings.shape[0] // DIM
embeddings = embeddings.reshape((num_rows, DIM))
with open(f'{lang_dir}/sentences.{lang}.{shard_id}') as sentence_file:
with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
for idx, line in enumerate(sentence_file):
sentence_id, sentence = line.strip().split('\t')
sentence_id, sentence = line.strip().split("\t")
sentence_texts[lang][sentence_id] = sentence
sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
@@ -55,7 +60,7 @@ def compute_accuracy(directory, LANGS):
top_1_accuracy = {}
top1_str = " ".join(LANGS) + '\n'
top1_str = " ".join(LANGS) + "\n"
for source_lang in LANGS:
top_1_accuracy[source_lang] = {}
top1_str += f"{source_lang} "
@@ -63,8 +68,8 @@ def compute_accuracy(directory, LANGS):
top1 = 0
top5 = 0
neighbors_map = compute_dist(
sentence_embeddings[source_lang],
sentence_embeddings[target_lang])
sentence_embeddings[source_lang], sentence_embeddings[target_lang]
)
for sentence_id, neighbors in neighbors_map.items():
if sentence_id == neighbors[0]:
top1 += 1
@@ -75,17 +80,13 @@ def compute_accuracy(directory, LANGS):
top1_str += "\n"
print(top1_str)
print(top1_str, file=open(f"{directory}/accuracy", 'w'))
print(top1_str, file=open(f"{directory}/accuracy", "w"))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Analyze encoder outputs')
parser.add_argument('directory',
help='Directory containing saved encoder outputs'
)
parser.add_argument('--langs',
help='Comma-separated list of langs'
)
parser = argparse.ArgumentParser(description="Analyze encoder outputs")
parser.add_argument("directory", help="Source language corpus")
parser.add_argument("--langs", help="List of langs")
args = parser.parse_args()
langs = args.langs.split(',')
langs = args.langs.split(",")
compute_accuracy(args.directory, langs)

View File

@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .models import latent_multilingual_transformer # noqa
from .modules import latent_layers # noqa
from .loss import latent_depth # noqa
from . import multilingual_translation_latent_depth # noqa
from . import multilingual_translation_latent_depth # noqa
from .loss import latent_depth # noqa
from .models import latent_multilingual_transformer # noqa
from .modules import latent_layers # noqa

View File

@@ -3,8 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import math
import torch
from torch.nn.modules.loss import _Loss
@@ -19,17 +20,16 @@ class LatentLayersKLLoss(_Loss):
eps = 1e-7
if prior == "uniform":
# uniform prior
kl_loss = (samples * (
torch.log(samples + eps) - math.log(0.5)
)).sum(-1)
kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
elif prior == "agged_posterior":
# aggregated posterior
y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
agged_q = torch.sum(y_t, dim=0)
row_norm = agged_q.sum(-1)
normed_agg_q = agged_q / row_norm
kl_loss = (samples * (
torch.log(samples + eps) - torch.log(normed_agg_q + eps))).sum(-1)
kl_loss = (
samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
).sum(-1)
else:
raise NotImplementedError("The specified prior is not implemented.")
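For the uniform prior, the expression above is sum_i q_i * (log q_i - log 0.5) over the layer-selection probabilities q, i.e. the q-weighted log-ratio against a fair coin on each layer. A tiny numeric check (probabilities hypothetical):

import math
import torch

samples = torch.tensor([0.9, 0.1])  # hypothetical per-layer selection probabilities
eps = 1e-7
kl = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
print(kl)  # 0.9*log(1.8) + 0.1*log(0.2) ~= 0.37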
@@ -37,7 +37,9 @@ class LatentLayersKLLoss(_Loss):
kl_loss /= layer_samples[0].size()[0]
kl_weight = min(
self.args.sparsity_weight,
(update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates
(update_num - self.args.soft_update)
* self.args.sparsity_weight
/ self.args.anneal_updates,
)
kl_loss *= kl_weight * sample_size
return kl_loss
@@ -58,15 +60,17 @@ class LatentLayersSparsityLoss(_Loss):
share_loss = 0
global_sparsity_loss = 0
layer_samples = torch.stack(layer_samples_list, dim=0)
if ((self.args.target_layers > 0 or self.args.share_weight > 0) and
update_num > (self.args.soft_update + self.args.anneal_updates)):
if (
self.args.target_layers > 0 or self.args.share_weight > 0
) and update_num > (self.args.soft_update + self.args.anneal_updates):
# anneal sparsity weight
if update_num < (self.args.anneal_updates + self.args.soft_update):
weight_anneal = 0
elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
weight_anneal = (
(update_num - self.args.soft_update - self.args.anneal_updates)
* self.args.share_weight / self.args.anneal_updates
* self.args.share_weight
/ self.args.anneal_updates
)
else:
weight_anneal = 1
@@ -75,12 +79,21 @@ class LatentLayersSparsityLoss(_Loss):
layer_utilization /= layer_samples.size()[0]
if self.args.share_weight > 0:
# encouraging sharing across languages
share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0)
batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss
share_loss = sum(
-1.0 * v * math.log(v) for v in layer_utilization if v > 0
)
batch_loss += (
weight_anneal * self.args.share_weight * sample_size * share_loss
)
if self.args.target_layers > 0:
# compute expected number of layers selected
expected_layers = sum(layer_utilization)
# compute l2 loss wrt target number of layers
global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss
batch_loss += (
weight_anneal
* self.args.share_weight
* sample_size
* global_sparsity_loss
)
return batch_loss
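A quick numeric reading of the target-layers term above (numbers hypothetical): if the per-layer utilizations sum to an expected 14.3 selected layers and --target-layers is 12, the squared penalty is (14.3 - 12) ** 2 ~= 5.29, which is then scaled by the annealed weight and the sample size.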

View File

@@ -3,34 +3,31 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
base_architecture,
TransformerEncoder,
TransformerDecoder,
)
from fairseq.models import register_model, register_model_architecture
from fairseq.models.multilingual_transformer import MultilingualTransformerModel
from .latent_transformer import (
LatentTransformerEncoder,
LatentTransformerDecoder,
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
base_architecture,
)
from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder
@register_model('latent_multilingual_transformer')
@register_model("latent_multilingual_transformer")
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
"""A variant of standard multilingual Transformer models which encoder and/or
decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
(https://arxiv.org/abs/2009.13102).
"""
@classmethod
def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
if is_encoder:
if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
return LatentTransformerEncoder(args, lang_dict, embed_tokens, num_logits=len(langs))
return LatentTransformerEncoder(
args, lang_dict, embed_tokens, num_logits=len(langs)
)
else:
return TransformerEncoder(args, lang_dict, embed_tokens)
else:
@@ -42,19 +39,21 @@ class LatentMultilingualTransformerModel(MultilingualTransformerModel):
return TransformerDecoder(args, lang_dict, embed_tokens)
@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer')
@register_model_architecture(
"latent_multilingual_transformer", "latent_multilingual_transformer"
)
def latent_multilingual_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
args.encoder_layers = getattr(args, 'encoder_layers', 12)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
args.decoder_layers = getattr(args, 'decoder_layers', 24)
args.share_encoders = getattr(args, 'share_encoders', True)
args.share_decoders = getattr(args, 'share_decoders', True)
args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True)
args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.decoder_layers = getattr(args, "decoder_layers", 24)
args.share_encoders = getattr(args, "share_encoders", True)
args.share_decoders = getattr(args, "share_decoders", True)
args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)
base_architecture(args)

View File

@@ -7,26 +7,27 @@ from typing import Any, Dict, Optional
import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
from ..modules.latent_layers import LayerSelect
from fairseq.models.transformer import TransformerDecoder, TransformerEncoder
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
from torch import Tensor
from ..modules.latent_layers import LayerSelect
class LatentTransformerEncoder(TransformerEncoder):
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
TransformerEncoder.
"""
def __init__(self, args, dictionary, embed_tokens, num_logits=1):
self.num_logits = num_logits
self.num_layers = args.encoder_layers
super().__init__(args, dictionary, embed_tokens)
self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
self.lang_idx = None
self.layers = nn.ModuleList([
self._build_encoder_layer(args, idx)
for idx in range(args.encoder_layers)
])
self.layers = nn.ModuleList(
[self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)]
)
def set_lang_idx(self, lang_idx):
self.lang_idx = lang_idx
@@ -50,6 +51,7 @@ class LatentTransformerEncoderLayer(TransformerEncoderLayer):
layer_select (LayerSelect, optional): instance of LayerSelect module with logits
parameters and sampling method.
"""
def __init__(self, args, idx, layer_select=None):
super().__init__(args)
self.idx = idx
@@ -63,7 +65,10 @@ class LatentTransformerDecoder(TransformerDecoder):
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
TransformerDecoder.
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1):
def __init__(
self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1
):
self.num_logits = num_logits
self.num_layers = args.decoder_layers
super().__init__(
@@ -71,16 +76,20 @@ class LatentTransformerDecoder(TransformerDecoder):
)
self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
self.lang_idx = None
self.layers = nn.ModuleList([
self._build_decoder_layer(args, no_encoder_attn, idx)
for idx in range(args.decoder_layers)
])
self.layers = nn.ModuleList(
[
self._build_decoder_layer(args, no_encoder_attn, idx)
for idx in range(args.decoder_layers)
]
)
def set_lang_idx(self, lang_idx):
self.lang_idx = lang_idx
def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
return LatentTransformerDecoderLayer(args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn)
return LatentTransformerDecoderLayer(
args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
)
def forward(
self,
@@ -119,8 +128,15 @@ class LatentTransformerDecoderLayer(TransformerDecoderLayer):
(default: False).
"""
def __init__(
self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
self,
args,
idx,
layer_select=None,
no_encoder_attn=False,
add_bias_kv=False,
add_zero_attn=False,
):
super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
self.idx = idx

View File

@@ -12,6 +12,7 @@ class LayerSelect(nn.Module):
either (soft) weighting or (hard) selection of residual connection.
https://arxiv.org/abs/2009.13102
"""
def __init__(self, num_layers, num_logits, args):
super(LayerSelect, self).__init__()
self.args = args
@@ -27,14 +28,14 @@ class LayerSelect(nn.Module):
@staticmethod
def add_args(parser):
parser.add_argument(
'--soft-select',
action='store_true',
help='use soft samples in training and inference'
"--soft-select",
action="store_true",
help="use soft samples in training an inference",
)
parser.add_argument('--sampling-tau', type=float, help='sampling temperature')
parser.add_argument("--sampling-tau", type=float, help="sampling temperature")
def sample(self, logit_idx):
""" To leverage the efficiency of distributed training, samples for all
"""To leverage the efficiency of distributed training, samples for all
layers are computed at once for each logit_idx. Logits are parameters
learned independently of each other.
@@ -43,7 +44,9 @@
"""
assert logit_idx is not None
self.samples = self._gumbel_sigmoid(
self.layer_logits[logit_idx, :].detach() if self.detach_grad else self.layer_logits[logit_idx, :],
self.layer_logits[logit_idx, :].detach()
if self.detach_grad
else self.layer_logits[logit_idx, :],
dim=-1,
tau=self.tau,
hard=self.hard_select,
@@ -54,10 +57,20 @@
sample = self.samples[i]
return sample
def _gumbel_sigmoid(self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5):
def _gumbel_sigmoid(
self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
):
# ~Gumbel(0,1)
gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
gumbels1 = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
.exponential_()
.log()
)
gumbels2 = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
.exponential_()
.log()
)
# Difference of two gumbels because we apply a sigmoid
gumbels1 = (logits + gumbels1 - gumbels2) / tau
y_soft = gumbels1.sigmoid()
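Passing the difference of two Gumbel variables through a sigmoid is the binary case of the Gumbel-softmax trick: sigmoid((logits + g1 - g2) / tau) is a reparameterized sample of a Bernoulli with success probability sigmoid(logits). The hard/straight-through branch is cut off by the hunk; a sketch of the standard continuation, assuming it mirrors torch.nn.functional.gumbel_softmax:

if hard:
    # binarize in the forward pass, keep y_soft's gradient in the backward pass
    y_hard = torch.zeros_like(logits).masked_fill(y_soft > threshold, 1.0)
    ret = y_hard - y_soft.detach() + y_soft
else:
    ret = y_soft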

View File

@@ -5,10 +5,11 @@
from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss
@register_task('multilingual_translation_latent_depth')
@register_task("multilingual_translation_latent_depth")
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
"""A task for multiple translation with latent depth.
@@ -39,7 +40,9 @@ class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
def __init__(self, args, dicts, training):
super().__init__(args, dicts, training)
self.src_langs, self.tgt_langs = zip(*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs])
self.src_langs, self.tgt_langs = zip(
*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]
)
if self.training and self.encoder_latent_layer:
assert self.args.share_encoders
if self.training and self.decoder_latent_layer:
@@ -47,46 +50,56 @@
if training or self.encoder_latent_layer or self.decoder_latent_layer:
self.lang_pairs = args.lang_pairs
else:
self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
self.eval_lang_pairs = self.lang_pairs
self.model_lang_pairs = self.lang_pairs
if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
self.kl_loss = LatentLayersKLLoss(self.args)
self.sparsity_loss = LatentLayersSparsityLoss(self.args)
def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad):
def _per_lang_pair_train_loss(
self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
):
src, tgt = lang_pair.split("-")
if self.encoder_latent_layer:
src_lang_idx = self.src_lang_idx_dict[src]
model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
model.models[lang_pair].encoder.layer_select.hard_select = update_num > self.args.soft_update
model.models[lang_pair].encoder.layer_select.hard_select = (
update_num > self.args.soft_update
)
if self.decoder_latent_layer:
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
model.models[lang_pair].decoder.layer_select.hard_select = update_num > self.args.soft_update
model.models[lang_pair].decoder.layer_select.hard_select = (
update_num > self.args.soft_update
)
loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
loss, sample_size, logging_output = criterion(
model.models[lang_pair], sample[lang_pair]
)
if self.encoder_latent_layer:
none_samples = sum(
1 if x is None else 0 for x in model.models[lang_pair].encoder.layer_select.layer_samples
1 if x is None else 0
for x in model.models[lang_pair].encoder.layer_select.layer_samples
)
if none_samples == 0 or self.args.prior != "agged_posterior":
loss += self.kl_loss(
model.models[lang_pair].encoder.layer_select.layer_samples,
src_lang_idx,
update_num,
sample_size
sample_size,
)
if self.decoder_latent_layer:
none_samples = sum(
1 if x is None else 0 for x in model.models[lang_pair].decoder.layer_select.layer_samples
1 if x is None else 0
for x in model.models[lang_pair].decoder.layer_select.layer_samples
)
if none_samples == 0 or self.args.prior != "agged_posterior":
loss += self.kl_loss(
model.models[lang_pair].decoder.layer_select.layer_samples,
tgt_lang_idx,
update_num,
sample_size
sample_size,
)
if ignore_grad:
loss *= 0
@@ -99,18 +112,31 @@
return loss, sample_size, logging_output
def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
agg_loss, agg_sample_size, agg_logging_output = super().train_step(
sample, model, criterion, optimizer, update_num, ignore_grad)
sample, model, criterion, optimizer, update_num, ignore_grad
)
# compute auxiliary loss from layer sparsity, based on all samples from all languages
if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
sparsity_loss = 0
if self.encoder_latent_layer:
sparsity_loss += self.sparsity_loss(
next(iter(model.models.values())).encoder.layer_select.layer_samples, update_num, agg_sample_size)
next(
iter(model.models.values())
).encoder.layer_select.layer_samples,
update_num,
agg_sample_size,
)
if self.decoder_latent_layer:
sparsity_loss += self.sparsity_loss(
next(iter(model.models.values())).decoder.layer_select.layer_samples, update_num, agg_sample_size)
next(
iter(model.models.values())
).decoder.layer_select.layer_samples,
update_num,
agg_sample_size,
)
if sparsity_loss > 0:
optimizer.backward(sparsity_loss)
return agg_loss, agg_sample_size, agg_logging_output
@@ -123,10 +149,14 @@
if self.decoder_latent_layer:
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
loss, sample_size, logging_output = criterion(
model.models[lang_pair], sample[lang_pair]
)
return loss, sample_size, logging_output
def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
if self.encoder_latent_layer or self.decoder_latent_layer:
for model in models:
if self.encoder_latent_layer:
@@ -137,15 +167,23 @@
assert model.decoder.layer_select is not None
tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
model.decoder.set_lang_idx(tgt_lang_idx)
return super().inference_step(generator, models, sample, prefix_tokens, constraints)
return super().inference_step(
generator, models, sample, prefix_tokens, constraints
)
@property
def encoder_latent_layer(self):
return hasattr(self.args, "encoder_latent_layer") and self.args.encoder_latent_layer
return (
hasattr(self.args, "encoder_latent_layer")
and self.args.encoder_latent_layer
)
@property
def decoder_latent_layer(self):
return hasattr(self.args, "decoder_latent_layer") and self.args.decoder_latent_layer
return (
hasattr(self.args, "decoder_latent_layer")
and self.args.decoder_latent_layer
)
@property
def src_lang_idx_dict(self):

View File

@ -8,37 +8,40 @@ Linformer: Self-Attention with Linear Complexity
import logging
from fairseq.models import (
register_model,
register_model_architecture,
)
from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import RobertaEncoder, RobertaModel
from fairseq.models.roberta import (
RobertaModel,
RobertaEncoder,
)
from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder
logger = logging.getLogger(__name__)
@register_model('linformer_roberta')
@register_model("linformer_roberta")
class LinformerModel(RobertaModel):
@staticmethod
def add_args(parser):
RobertaModel.add_args(parser)
# add args for Linformer
parser.add_argument('--compressed', type=int,
help='compressed ratio of sequence length')
parser.add_argument('--shared-kv-compressed', type=int,
help='share compressed matrix between k and v, in each layer')
parser.add_argument('--shared-layer-kv-compressed', type=int,
help='share compressed matrix between k and v and across all layers')
parser.add_argument('--freeze-compress', type=int,
help='freeze the parameters in compressed layer')
parser.add_argument(
"--compressed", type=int, help="compressed ratio of sequence length"
)
parser.add_argument(
"--shared-kv-compressed",
type=int,
help="share compressed matrix between k and v, in each layer",
)
parser.add_argument(
"--shared-layer-kv-compressed",
type=int,
help="share compressed matrix between k and v and across all layers",
)
parser.add_argument(
"--freeze-compress",
type=int,
help="freeze the parameters in compressed layer",
)
@classmethod
def build_model(cls, args, task):
@ -47,7 +50,7 @@ class LinformerModel(RobertaModel):
# make sure all arguments are present
base_architecture(args)
if not hasattr(args, 'max_positions'):
if not hasattr(args, "max_positions"):
args.max_positions = args.tokens_per_sample
encoder = LinformerEncoder(args, task.source_dictionary)
@ -85,47 +88,47 @@ class LinformerEncoder(RobertaEncoder):
)
@register_model_architecture('linformer_roberta', 'linformer_roberta')
@register_model_architecture("linformer_roberta", "linformer_roberta")
def base_architecture(args):
args.encoder_layers = getattr(args, 'encoder_layers', 12)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.activation_fn = getattr(args, 'activation_fn', 'gelu')
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.dropout = getattr(args, 'dropout', 0.1)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None)
args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0)
args.compressed = getattr(args, 'compressed', 4)
args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0)
args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0)
args.freeze_compress = getattr(args, 'freeze_compress', 0)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.compressed = getattr(args, "compressed", 4)
args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
args.freeze_compress = getattr(args, "freeze_compress", 0)
@register_model_architecture('linformer_roberta', 'linformer_roberta_base')
@register_model_architecture("linformer_roberta", "linformer_roberta_base")
def linformer_roberta_base_architecture(args):
base_architecture(args)
@register_model_architecture('linformer_roberta', 'linformer_roberta_large')
@register_model_architecture("linformer_roberta", "linformer_roberta_large")
def linformer_roberta_large_architecture(args):
args.encoder_layers = getattr(args, 'encoder_layers', 24)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
args.encoder_layers = getattr(args, "encoder_layers", 24)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.activation_fn = getattr(args, 'activation_fn', 'gelu')
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.dropout = getattr(args, 'dropout', 0.1)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
args.compressed = getattr(args, 'compressed', 4)
args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0)
args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
args.compressed = getattr(args, "compressed", 4)
args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
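As a rough illustration of the --compressed argument registered above (max_seq_len is an example value; compressed=4 matches the default in base_architecture), the ratio sets how far the attended sequence length is reduced, which is where Linformer's linear-complexity saving comes from:

max_seq_len, compressed = 512, 4            # max_seq_len is an assumed example
compress_seq_len = max_seq_len // compressed
print(compress_seq_len)  # keys/values are projected to 128 positions instead of 512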

View File

@ -6,8 +6,8 @@
import math
import torch.nn as nn
from fairseq.modules import TransformerSentenceEncoder
from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer
@ -117,7 +117,9 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
qn_block_size,
):
if self.shared_layer_kv_compressed == 1:
compress_layer = nn.Linear(self.max_seq_len, self.max_seq_len // self.compressed)
compress_layer = nn.Linear(
self.max_seq_len, self.max_seq_len // self.compressed
)
# initialize parameters for the compressed layer
nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2))
if self.freeze_compress == 1:
@ -139,8 +141,7 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
max_seq_len=self.max_seq_len,
shared_kv_compressed=self.shared_kv_compressed,
shared_compress_layer=(
None if self.shared_layer_kv_compressed == 0
else self.compress_layer
None if self.shared_layer_kv_compressed == 0 else self.compress_layer
),
freeze_compress=self.freeze_compress,
)
@ -156,7 +157,8 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder):
if self.shared_layer_kv_compressed:
for layer_idx in range(len(self.layers)):
new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format(
layer_idx, k[len(prefix + 'compress_layer.'):],
layer_idx,
k[len(prefix + "compress_layer.") :],
)
items_to_add[new_k] = state_dict[k]
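To make the state-dict remapping above concrete, a small standalone example (the prefix and checkpoint key are hypothetical) of how one shared compress_layer key is fanned out to every layer:

prefix = "encoder."                      # hypothetical module prefix
k = prefix + "compress_layer.weight"     # hypothetical checkpoint key
new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format(
    0, k[len(prefix + "compress_layer."):]
)
print(new_k)  # encoder.layers.0.shared_compress_layer.weight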

View File

@ -6,6 +6,7 @@
from typing import Callable
from fairseq.modules import TransformerSentenceEncoderLayer
from .multihead_linear_attention import MultiheadLinearAttention
@ -23,7 +24,7 @@ class LinformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = 'relu',
activation_fn: str = "relu",
export: bool = False,
q_noise: float = 0.0,
qn_block_size: int = 8,

View File

@ -9,10 +9,10 @@ from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor, nn
from torch.nn import Parameter
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter
@with_incremental_state
@ -65,16 +65,24 @@ class MultiheadLinearAttention(nn.Module):
"Self-attention requires query, key and " "value to be of the same size"
)
self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
self.k_proj = quant_noise(
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
# used to compress the sequence to a shorter subsequence
if shared_compress_layer is None:
self.compress_seq_len = max_seq_len // compressed
self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
if shared_kv_compressed == 0:
self.compress_v = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
self.compress_v = nn.Linear(
max_seq_len, self.compress_seq_len, bias=False
)
self.layerwise_sharing = False
else:
self.compress_k = shared_compress_layer
@ -83,7 +91,9 @@ class MultiheadLinearAttention(nn.Module):
self.layerwise_sharing = True
self.shared_kv_compressed = shared_kv_compressed
self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
@ -116,22 +126,28 @@ class MultiheadLinearAttention(nn.Module):
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
if not self.layerwise_sharing: # otherwise, we already initialize the parameters
nn.init.xavier_uniform_(self.compress_k.weight, gain=1/math.sqrt(2))
if (
not self.layerwise_sharing
):  # otherwise, the parameters were already initialized
nn.init.xavier_uniform_(self.compress_k.weight, gain=1 / math.sqrt(2))
if self.shared_kv_compressed == 0:
nn.init.xavier_uniform_(self.compress_v.weight, gain=1/math.sqrt(2))
nn.init.xavier_uniform_(
self.compress_v.weight, gain=1 / math.sqrt(2)
)
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
if not self.layerwise_sharing: # otherwise, we already initialize the parameters
if (
not self.layerwise_sharing
):  # otherwise, the parameters were already initialized
nn.init.xavier_uniform_(self.compress_k.weight)
if self.shared_kv_compressed == 0:
nn.init.xavier_uniform_(self.compress_v.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.)
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
@ -189,14 +205,26 @@ class MultiheadLinearAttention(nn.Module):
q = self.q_proj(query)
k_input = query.permute(1, 2, 0).contiguous() # B * C * T
k_input = F.linear(k_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
k_input = (
F.linear(k_input, self.compress_k.weight[:, 0:tgt_len])
.permute(2, 0, 1)
.contiguous()
)
k = self.k_proj(k_input)
v_input = query.permute(1, 2, 0).contiguous() # B * C * T
if self.shared_kv_compressed == 0:
v_input = F.linear(v_input, self.compress_v.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
v_input = (
F.linear(v_input, self.compress_v.weight[:, 0:tgt_len])
.permute(2, 0, 1)
.contiguous()
)
if self.shared_kv_compressed == 1: # use shared kv compressed linear layer
v_input = F.linear(v_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
v_input = (
F.linear(v_input, self.compress_k.weight[:, 0:tgt_len])
.permute(2, 0, 1)
.contiguous()
)
v = self.v_proj(v_input)
elif self.encoder_decoder_attention:
# encoder-decoder attention
@ -302,7 +330,9 @@ class MultiheadLinearAttention(nn.Module):
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = MultiheadLinearAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
attn_weights = MultiheadLinearAttention.apply_sparse_mask(
attn_weights, tgt_len, src_len, bsz
)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
@ -385,7 +415,9 @@ class MultiheadLinearAttention(nn.Module):
@torch.jit.export
def reorder_incremental_state(
self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
new_order: Tensor,
):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
@ -393,7 +425,9 @@ class MultiheadLinearAttention(nn.Module):
for k in input_buffer.keys():
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(0):
if self.encoder_decoder_attention and input_buffer_k.size(
0
) == new_order.size(0):
break
input_buffer[k] = input_buffer_k.index_select(0, new_order)
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
@ -428,8 +462,8 @@ class MultiheadLinearAttention(nn.Module):
# in_proj_weight used to be q + k + v with same dimensions
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim:2 * dim]
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim:]
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
keys_to_remove.append(k)
@ -438,9 +472,9 @@ class MultiheadLinearAttention(nn.Module):
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
dim:2 * dim
dim : 2 * dim
]
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim:]
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
keys_to_remove.append(prefix + "in_proj_bias")
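A minimal sketch of the length-compression trick used in the forward pass above (shapes and values assumed; slicing the projection matrix to the actual length mirrors compress_k.weight[:, 0:tgt_len]):

import torch
import torch.nn.functional as F

T, B, C, k, max_len = 50, 2, 64, 16, 512   # assumed sizes
x = torch.randn(T, B, C)                   # (time, batch, channels)
w = torch.randn(k, max_len)                # learned projection over positions
x_bct = x.permute(1, 2, 0).contiguous()    # B x C x T
y = F.linear(x_bct, w[:, :T])              # project T -> k: B x C x k
y = y.permute(2, 0, 1).contiguous()        # back to k x B x C
print(y.shape)                             # torch.Size([16, 2, 64])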

View File

@ -8,14 +8,16 @@
import sys
from indicnlp.tokenize.indic_tokenize import trivial_tokenize
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
from indicnlp.tokenize.indic_tokenize import trivial_tokenize
factory=IndicNormalizerFactory()
normalizer=factory.get_normalizer(sys.argv[1],remove_nuktas=False,nasals_mode='do_nothing')
factory = IndicNormalizerFactory()
normalizer = factory.get_normalizer(
sys.argv[1], remove_nuktas=False, nasals_mode="do_nothing"
)
for line in sys.stdin:
normalized_line=normalizer.normalize(line.strip())
tokenized_line=' '.join(trivial_tokenize(normalized_line, sys.argv[1]))
normalized_line = normalizer.normalize(line.strip())
tokenized_line = " ".join(trivial_tokenize(normalized_line, sys.argv[1]))
print(tokenized_line)

View File

@ -8,5 +8,6 @@ import sys
from pythainlp import word_tokenize
for line in sys.stdin:
print(" ".join(word_tokenize(line.strip())))

View File

@ -6,7 +6,9 @@
import fileinput
import sacrebleu
for line in fileinput.input():
print(sacrebleu.tokenize_zh(line))

View File

@ -6,19 +6,27 @@
import argparse
import fileinput
import sacremoses
def main():
parser = argparse.ArgumentParser(description='')
parser.add_argument('files', nargs='*', help='input files')
parser = argparse.ArgumentParser(description="")
parser.add_argument("files", nargs="*", help="input files")
args = parser.parse_args()
detok = sacremoses.MosesDetokenizer()
for line in fileinput.input(args.files, openhook=fileinput.hook_compressed):
print(detok.detokenize(line.strip().split(' ')).replace(' @', '').replace('@ ', '').replace(' =', '=').replace('= ', '=').replace(' ', ''))
print(
detok.detokenize(line.strip().split(" "))
.replace(" @", "")
.replace("@ ", "")
.replace(" =", "=")
.replace("= ", "=")
.replace(" ", "")
)
if __name__ == '__main__':
if __name__ == "__main__":
main()
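An illustrative run of the joiner-stripping replacements above on a made-up tokenized sample (only the @ and = rules are shown):

s = "new @-@ york = rates"               # made-up input
out = (
    s.replace(" @", "")
    .replace("@ ", "")
    .replace(" =", "=")
    .replace("= ", "=")
)
print(out)  # new-york=rates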

View File

@ -7,21 +7,22 @@ import math
from multiprocessing import Pool
import numpy as np
from fairseq import options
from fairseq.data import dictionary
from fairseq.scoring import bleu
from . import (
rerank_generate,
rerank_options,
rerank_score_bw,
rerank_score_lm,
rerank_options,
rerank_utils,
)
def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize):
def score_target_hypo(
args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
):
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
@ -61,11 +62,21 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
bitext2_score = None
bitext2_backwards = None
score = rerank_utils.get_score(a, b, c, target_len,
bitext1.rescore_score[i], bitext2_score, lm_score=lm_score,
lenpen=lenpen, src_len=bitext1.source_lengths[i],
tgt_len=bitext1.target_lengths[i], bitext1_backwards=bitext1.backwards,
bitext2_backwards=bitext2_backwards, normalize=normalize)
score = rerank_utils.get_score(
a,
b,
c,
target_len,
bitext1.rescore_score[i],
bitext2_score,
lm_score=lm_score,
lenpen=lenpen,
src_len=bitext1.source_lengths[i],
tgt_len=bitext1.target_lengths[i],
bitext1_backwards=bitext1.backwards,
bitext2_backwards=bitext2_backwards,
normalize=normalize,
)
if score > best_score:
best_score = score
@ -88,8 +99,11 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
for key in range(len(gen_keys)):
if args.prefix_len is None:
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
"pred and rescore hypo mismatch: i: " + str(key) + ", "
+ str(hypo_lst[key]) + str(gen_keys[key])
"pred and rescore hypo mismatch: i: "
+ str(key)
+ ", "
+ str(hypo_lst[key])
+ str(gen_keys[key])
+ str(gen_output.no_bpe_hypo[key])
)
sys_tok = dict.encode_line(hypo_lst[key])
@ -97,7 +111,9 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
scorer.add(ref_tok, sys_tok)
else:
full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
full_hypo = rerank_utils.get_full_from_prefix(
hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
)
sys_tok = dict.encode_line(full_hypo)
ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
scorer.add(ref_tok, sys_tok)
@ -107,20 +123,31 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
# recover the original ids from n-best list generation
for key in range(len(gen_output.no_bpe_target)):
if args.prefix_len is None:
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \
"pred and rescore hypo mismatch:"+"i:"+str(key)+str(hypo_lst[key]) + str(gen_output.no_bpe_hypo[key])
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
"pred and rescore hypo mismatch:"
+ "i:"
+ str(key)
+ str(hypo_lst[key])
+ str(gen_output.no_bpe_hypo[key])
)
ordered_hypos[gen_keys[key]] = hypo_lst[key]
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
gen_keys[key]
]
else:
full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
full_hypo = rerank_utils.get_full_from_prefix(
hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
)
ordered_hypos[gen_keys[key]] = full_hypo
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
gen_keys[key]
]
# write the hypos in the original order from nbest list generation
if args.num_shards == (len(bitext1_lst)):
with open(target_outfile, 'w') as t:
with open(hypo_outfile, 'w') as h:
with open(target_outfile, "w") as t:
with open(hypo_outfile, "w") as h:
for key in range(len(ordered_hypos)):
t.write(ordered_targets[key])
h.write(ordered_hypos[key])
@ -135,17 +162,38 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
def match_target_hypo(args, target_outfile, hypo_outfile):
"""combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
if len(args.weight1) == 1:
res = score_target_hypo(args, args.weight1[0], args.weight2[0],
args.weight3[0], args.lenpen[0], target_outfile,
hypo_outfile, True, args.normalize)
res = score_target_hypo(
args,
args.weight1[0],
args.weight2[0],
args.weight3[0],
args.lenpen[0],
target_outfile,
hypo_outfile,
True,
args.normalize,
)
rerank_scores = [res]
else:
print("launching pool")
with Pool(32) as p:
rerank_scores = p.starmap(score_target_hypo,
[(args, args.weight1[i], args.weight2[i], args.weight3[i],
args.lenpen[i], target_outfile, hypo_outfile,
False, args.normalize) for i in range(len(args.weight1))])
rerank_scores = p.starmap(
score_target_hypo,
[
(
args,
args.weight1[i],
args.weight2[i],
args.weight3[i],
args.lenpen[i],
target_outfile,
hypo_outfile,
False,
args.normalize,
)
for i in range(len(args.weight1))
],
)
if len(rerank_scores) > 1:
best_index = np.argmax(rerank_scores)
@ -155,11 +203,22 @@ def match_target_hypo(args, target_outfile, hypo_outfile):
print("best weight1", args.weight1[best_index])
print("best weight2", args.weight2[best_index])
print("best weight3", args.weight3[best_index])
return args.lenpen[best_index], args.weight1[best_index], \
args.weight2[best_index], args.weight3[best_index], best_score
return (
args.lenpen[best_index],
args.weight1[best_index],
args.weight2[best_index],
args.weight3[best_index],
best_score,
)
else:
return args.lenpen[0], args.weight1[0], args.weight2[0], args.weight3[0], rerank_scores[0]
return (
args.lenpen[0],
args.weight1[0],
args.weight2[0],
args.weight3[0],
rerank_scores[0],
)
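The internals of rerank_utils.get_score are not shown in this diff; a plausible reading of the weighted combination it computes (an assumption, not the verified implementation) is:

def combined_score(w1, s1, w2, s2, w3, lm, tgt_len, lenpen):
    # assumed form: weighted model scores, length-normalized by lenpen
    return (w1 * s1 + w2 * s2 + w3 * lm) / (tgt_len ** lenpen)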
def load_score_files(args):
@ -175,55 +234,100 @@ def load_score_files(args):
for shard_id in shard_ids:
using_nbest = args.nbest_list is not None
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir = \
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
args.gen_model_name, shard_id, args.num_shards, args.sampling,
args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
(
pre_gen,
left_to_right_preprocessed_dir,
right_to_left_preprocessed_dir,
backwards_preprocessed_dir,
lm_preprocessed_dir,
) = rerank_utils.get_directories(
args.data_dir_name,
args.num_rescore,
args.gen_subset,
args.gen_model_name,
shard_id,
args.num_shards,
args.sampling,
args.prefix_len,
args.target_prefix_frac,
args.source_prefix_frac,
)
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
rerank1_is_gen = (
args.gen_model == args.score_model1 and args.source_prefix_frac is None
)
rerank2_is_gen = (
args.gen_model == args.score_model2 and args.source_prefix_frac is None
)
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards1)
score1_file = rerank_utils.rescore_file_name(
pre_gen,
args.prefix_len,
args.model1_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards1,
)
if args.score_model2 is not None:
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards2)
score2_file = rerank_utils.rescore_file_name(
pre_gen,
args.prefix_len,
args.model2_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards2,
)
if args.language_model is not None:
lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
lm_score_file = rerank_utils.rescore_file_name(
pre_gen, args.prefix_len, args.lm_name, lm_file=True
)
# get gen output
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
if using_nbest:
print("Using predefined n-best list from interactive.py")
predictions_bpe_file = args.nbest_list
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
nbest=using_nbest, prefix_len=args.prefix_len,
target_prefix_frac=args.target_prefix_frac)
gen_output = rerank_utils.BitextOutputFromGen(
predictions_bpe_file,
bpe_symbol=args.remove_bpe,
nbest=using_nbest,
prefix_len=args.prefix_len,
target_prefix_frac=args.target_prefix_frac,
)
if rerank1_is_gen:
bitext1 = gen_output
else:
bitext1 = rerank_utils.BitextOutput(score1_file, args.backwards1, args.right_to_left1,
args.remove_bpe, args.prefix_len, args.target_prefix_frac,
args.source_prefix_frac)
bitext1 = rerank_utils.BitextOutput(
score1_file,
args.backwards1,
args.right_to_left1,
args.remove_bpe,
args.prefix_len,
args.target_prefix_frac,
args.source_prefix_frac,
)
if args.score_model2 is not None or args.nbest_list is not None:
if rerank2_is_gen:
bitext2 = gen_output
else:
bitext2 = rerank_utils.BitextOutput(score2_file, args.backwards2, args.right_to_left2,
args.remove_bpe, args.prefix_len, args.target_prefix_frac,
args.source_prefix_frac)
bitext2 = rerank_utils.BitextOutput(
score2_file,
args.backwards2,
args.right_to_left2,
args.remove_bpe,
args.prefix_len,
args.target_prefix_frac,
args.source_prefix_frac,
)
assert bitext2.source_lengths == bitext1.source_lengths, \
"source lengths for rescoring models do not match"
assert bitext2.target_lengths == bitext1.target_lengths, \
"target lengths for rescoring models do not match"
assert (
bitext2.source_lengths == bitext1.source_lengths
), "source lengths for rescoring models do not match"
assert (
bitext2.target_lengths == bitext1.target_lengths
), "target lengths for rescoring models do not match"
else:
if args.diff_bpe:
assert args.score_model2 is None
@ -232,8 +336,13 @@ def load_score_files(args):
bitext2 = None
if args.language_model is not None:
lm_res1 = rerank_utils.LMOutput(lm_score_file, args.lm_dict, args.prefix_len,
args.remove_bpe, args.target_prefix_frac)
lm_res1 = rerank_utils.LMOutput(
lm_score_file,
args.lm_dict,
args.prefix_len,
args.remove_bpe,
args.target_prefix_frac,
)
else:
lm_res1 = None
@ -259,28 +368,46 @@ def rerank(args):
shard_ids = [args.shard_id]
for shard_id in shard_ids:
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir = \
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
args.gen_model_name, shard_id, args.num_shards, args.sampling,
args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
(
pre_gen,
left_to_right_preprocessed_dir,
right_to_left_preprocessed_dir,
backwards_preprocessed_dir,
lm_preprocessed_dir,
) = rerank_utils.get_directories(
args.data_dir_name,
args.num_rescore,
args.gen_subset,
args.gen_model_name,
shard_id,
args.num_shards,
args.sampling,
args.prefix_len,
args.target_prefix_frac,
args.source_prefix_frac,
)
rerank_generate.gen_and_reprocess_nbest(args)
rerank_score_bw.score_bw(args)
rerank_score_lm.score_lm(args)
if args.write_hypos is None:
write_targets = pre_gen+"/matched_targets"
write_hypos = pre_gen+"/matched_hypos"
write_targets = pre_gen + "/matched_targets"
write_hypos = pre_gen + "/matched_hypos"
else:
write_targets = args.write_hypos+"_targets" + args.gen_subset
write_hypos = args.write_hypos+"_hypos" + args.gen_subset
write_targets = args.write_hypos + "_targets" + args.gen_subset
write_hypos = args.write_hypos + "_hypos" + args.gen_subset
if args.all_shards:
write_targets += "_all_shards"
write_hypos += "_all_shards"
best_lenpen, best_weight1, best_weight2, best_weight3, best_score = \
match_target_hypo(args, write_targets, write_hypos)
(
best_lenpen,
best_weight1,
best_weight2,
best_weight3,
best_score,
) = match_target_hypo(args, write_targets, write_hypos)
return best_lenpen, best_weight1, best_weight2, best_weight3, best_score
@ -291,5 +418,5 @@ def cli_main():
rerank(args)
if __name__ == '__main__':
if __name__ == "__main__":
cli_main()

View File

@ -8,9 +8,9 @@
Generate n-best translations using a trained model.
"""
from contextlib import redirect_stdout
import os
import subprocess
from contextlib import redirect_stdout
from fairseq import options
from fairseq_cli import generate, preprocess
@ -22,8 +22,12 @@ def gen_and_reprocess_nbest(args):
if args.score_dict_dir is None:
args.score_dict_dir = args.data
if args.prefix_len is not None:
assert args.right_to_left1 is False, "prefix length not compatible with right to left models"
assert args.right_to_left2 is False, "prefix length not compatible with right to left models"
assert (
args.right_to_left1 is False
), "prefix length not compatible with right to left models"
assert (
args.right_to_left2 is False
), "prefix length not compatible with right to left models"
if args.nbest_list is not None:
assert args.score_model2 is None
@ -35,27 +39,50 @@ def gen_and_reprocess_nbest(args):
scorer1_src = args.source_lang
scorer1_tgt = args.target_lang
store_data = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+args.data_dir_name
store_data = (
os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name
)
if not os.path.exists(store_data):
os.makedirs(store_data)
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir = \
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
args.gen_model_name, args.shard_id, args.num_shards,
args.sampling, args.prefix_len, args.target_prefix_frac,
args.source_prefix_frac)
assert not (args.right_to_left1 and args.backwards1), "backwards right to left not supported"
assert not (args.right_to_left2 and args.backwards2), "backwards right to left not supported"
assert not (args.prefix_len is not None and args.target_prefix_frac is not None), \
"target prefix frac and target prefix len incompatible"
(
pre_gen,
left_to_right_preprocessed_dir,
right_to_left_preprocessed_dir,
backwards_preprocessed_dir,
lm_preprocessed_dir,
) = rerank_utils.get_directories(
args.data_dir_name,
args.num_rescore,
args.gen_subset,
args.gen_model_name,
args.shard_id,
args.num_shards,
args.sampling,
args.prefix_len,
args.target_prefix_frac,
args.source_prefix_frac,
)
assert not (
args.right_to_left1 and args.backwards1
), "backwards right to left not supported"
assert not (
args.right_to_left2 and args.backwards2
), "backwards right to left not supported"
assert not (
args.prefix_len is not None and args.target_prefix_frac is not None
), "target prefix frac and target prefix len incompatible"
# make directory to store generation results
if not os.path.exists(pre_gen):
os.makedirs(pre_gen)
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
rerank1_is_gen = (
args.gen_model == args.score_model1 and args.source_prefix_frac is None
)
rerank2_is_gen = (
args.gen_model == args.score_model2 and args.source_prefix_frac is None
)
if args.nbest_list is not None:
rerank2_is_gen = True
@ -70,17 +97,25 @@ def gen_and_reprocess_nbest(args):
if not os.path.exists(backwards_preprocessed_dir):
os.makedirs(backwards_preprocessed_dir)
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards1)
score1_file = rerank_utils.rescore_file_name(
pre_gen,
args.prefix_len,
args.model1_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards1,
)
if args.score_model2 is not None:
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards2)
score2_file = rerank_utils.rescore_file_name(
pre_gen,
args.prefix_len,
args.model2_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards2,
)
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
using_nbest = args.nbest_list is not None
@ -92,17 +127,29 @@ def gen_and_reprocess_nbest(args):
if not os.path.isfile(predictions_bpe_file):
print("STEP 1: generate predictions using the p(T|S) model with bpe")
print(args.data)
param1 = [args.data,
"--path", args.gen_model,
"--shard-id", str(args.shard_id),
"--num-shards", str(args.num_shards),
"--nbest", str(args.num_rescore),
"--batch-size", str(args.batch_size),
"--beam", str(args.num_rescore),
"--batch-size", str(args.num_rescore),
"--gen-subset", args.gen_subset,
"--source-lang", args.source_lang,
"--target-lang", args.target_lang]
param1 = [
args.data,
"--path",
args.gen_model,
"--shard-id",
str(args.shard_id),
"--num-shards",
str(args.num_shards),
"--nbest",
str(args.num_rescore),
"--batch-size",
str(args.batch_size),
"--beam",
str(args.num_rescore),
"--batch-size",
str(args.num_rescore),
"--gen-subset",
args.gen_subset,
"--source-lang",
args.source_lang,
"--target-lang",
args.target_lang,
]
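# NOTE: "--batch-size" is passed twice in param1 above; argparse keeps the
# last occurrence, so generation effectively runs with batch size num_rescore.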
if args.sampling:
param1 += ["--sampling"]
@ -110,124 +157,229 @@ def gen_and_reprocess_nbest(args):
input_args = options.parse_args_and_arch(gen_parser, param1)
print(input_args)
with open(predictions_bpe_file, 'w') as f:
with open(predictions_bpe_file, "w") as f:
with redirect_stdout(f):
generate.main(input_args)
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
nbest=using_nbest, prefix_len=args.prefix_len,
target_prefix_frac=args.target_prefix_frac)
gen_output = rerank_utils.BitextOutputFromGen(
predictions_bpe_file,
bpe_symbol=args.remove_bpe,
nbest=using_nbest,
prefix_len=args.prefix_len,
target_prefix_frac=args.target_prefix_frac,
)
if args.diff_bpe:
rerank_utils.write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
gen_output.no_bpe_target, pre_gen+"/source_gen_bpe."+args.source_lang,
pre_gen+"/target_gen_bpe."+args.target_lang,
pre_gen+"/reference_gen_bpe."+args.target_lang)
rerank_utils.write_reprocessed(
gen_output.no_bpe_source,
gen_output.no_bpe_hypo,
gen_output.no_bpe_target,
pre_gen + "/source_gen_bpe." + args.source_lang,
pre_gen + "/target_gen_bpe." + args.target_lang,
pre_gen + "/reference_gen_bpe." + args.target_lang,
)
bitext_bpe = args.rescore_bpe_code
bpe_src_param = ["-c", bitext_bpe,
"--input", pre_gen+"/source_gen_bpe."+args.source_lang,
"--output", pre_gen+"/rescore_data."+args.source_lang]
bpe_tgt_param = ["-c", bitext_bpe,
"--input", pre_gen+"/target_gen_bpe."+args.target_lang,
"--output", pre_gen+"/rescore_data."+args.target_lang]
bpe_src_param = [
"-c",
bitext_bpe,
"--input",
pre_gen + "/source_gen_bpe." + args.source_lang,
"--output",
pre_gen + "/rescore_data." + args.source_lang,
]
bpe_tgt_param = [
"-c",
bitext_bpe,
"--input",
pre_gen + "/target_gen_bpe." + args.target_lang,
"--output",
pre_gen + "/rescore_data." + args.target_lang,
]
subprocess.call(["python",
os.path.join(os.path.dirname(__file__),
"subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
shell=False)
subprocess.call(
[
"python",
os.path.join(
os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
),
]
+ bpe_src_param,
shell=False,
)
subprocess.call(["python",
os.path.join(os.path.dirname(__file__),
"subword-nmt/subword_nmt/apply_bpe.py")] + bpe_tgt_param,
shell=False)
subprocess.call(
[
"python",
os.path.join(
os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
),
]
+ bpe_tgt_param,
shell=False,
)
if (not os.path.isfile(score1_file) and not rerank1_is_gen) or \
(args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen):
print("STEP 2: process the output of generate.py so we have clean text files with the translations")
if (not os.path.isfile(score1_file) and not rerank1_is_gen) or (
args.score_model2 is not None
and not os.path.isfile(score2_file)
and not rerank2_is_gen
):
print(
"STEP 2: process the output of generate.py so we have clean text files with the translations"
)
rescore_file = "/rescore_data"
if args.prefix_len is not None:
prefix_len_rescore_file = rescore_file + "prefix"+str(args.prefix_len)
prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
if args.target_prefix_frac is not None:
target_prefix_frac_rescore_file = rescore_file + "target_prefix_frac"+str(args.target_prefix_frac)
target_prefix_frac_rescore_file = (
rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
)
if args.source_prefix_frac is not None:
source_prefix_frac_rescore_file = rescore_file + "source_prefix_frac"+str(args.source_prefix_frac)
source_prefix_frac_rescore_file = (
rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
)
if not args.right_to_left1 or not args.right_to_left2:
if not args.diff_bpe:
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
pre_gen+rescore_file+"."+args.source_lang,
pre_gen+rescore_file+"."+args.target_lang,
pre_gen+"/reference_file", bpe_symbol=args.remove_bpe)
rerank_utils.write_reprocessed(
gen_output.source,
gen_output.hypo,
gen_output.target,
pre_gen + rescore_file + "." + args.source_lang,
pre_gen + rescore_file + "." + args.target_lang,
pre_gen + "/reference_file",
bpe_symbol=args.remove_bpe,
)
if args.prefix_len is not None:
bw_rescore_file = prefix_len_rescore_file
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
pre_gen+prefix_len_rescore_file+"."+args.source_lang,
pre_gen+prefix_len_rescore_file+"."+args.target_lang,
pre_gen+"/reference_file", prefix_len=args.prefix_len,
bpe_symbol=args.remove_bpe)
rerank_utils.write_reprocessed(
gen_output.source,
gen_output.hypo,
gen_output.target,
pre_gen + prefix_len_rescore_file + "." + args.source_lang,
pre_gen + prefix_len_rescore_file + "." + args.target_lang,
pre_gen + "/reference_file",
prefix_len=args.prefix_len,
bpe_symbol=args.remove_bpe,
)
elif args.target_prefix_frac is not None:
bw_rescore_file = target_prefix_frac_rescore_file
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
pre_gen+target_prefix_frac_rescore_file+"."+args.source_lang,
pre_gen+target_prefix_frac_rescore_file+"."+args.target_lang,
pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
target_prefix_frac=args.target_prefix_frac)
rerank_utils.write_reprocessed(
gen_output.source,
gen_output.hypo,
gen_output.target,
pre_gen
+ target_prefix_frac_rescore_file
+ "."
+ args.source_lang,
pre_gen
+ target_prefix_frac_rescore_file
+ "."
+ args.target_lang,
pre_gen + "/reference_file",
bpe_symbol=args.remove_bpe,
target_prefix_frac=args.target_prefix_frac,
)
else:
bw_rescore_file = rescore_file
if args.source_prefix_frac is not None:
fw_rescore_file = source_prefix_frac_rescore_file
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
pre_gen+source_prefix_frac_rescore_file+"."+args.source_lang,
pre_gen+source_prefix_frac_rescore_file+"."+args.target_lang,
pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
source_prefix_frac=args.source_prefix_frac)
rerank_utils.write_reprocessed(
gen_output.source,
gen_output.hypo,
gen_output.target,
pre_gen
+ source_prefix_frac_rescore_file
+ "."
+ args.source_lang,
pre_gen
+ source_prefix_frac_rescore_file
+ "."
+ args.target_lang,
pre_gen + "/reference_file",
bpe_symbol=args.remove_bpe,
source_prefix_frac=args.source_prefix_frac,
)
else:
fw_rescore_file = rescore_file
if args.right_to_left1 or args.right_to_left2:
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
pre_gen+"/right_to_left_rescore_data."+args.source_lang,
pre_gen+"/right_to_left_rescore_data."+args.target_lang,
pre_gen+"/right_to_left_reference_file",
right_to_left=True, bpe_symbol=args.remove_bpe)
rerank_utils.write_reprocessed(
gen_output.source,
gen_output.hypo,
gen_output.target,
pre_gen + "/right_to_left_rescore_data." + args.source_lang,
pre_gen + "/right_to_left_rescore_data." + args.target_lang,
pre_gen + "/right_to_left_reference_file",
right_to_left=True,
bpe_symbol=args.remove_bpe,
)
print("STEP 3: binarize the translations")
if not args.right_to_left1 or args.score_model2 is not None and not args.right_to_left2 or not rerank1_is_gen:
if (
not args.right_to_left1
or args.score_model2 is not None
and not args.right_to_left2
or not rerank1_is_gen
):
if args.backwards1 or args.backwards2:
if args.backwards_score_dict_dir is not None:
bw_dict = args.backwards_score_dict_dir
else:
bw_dict = args.score_dict_dir
bw_preprocess_param = ["--source-lang", scorer1_src,
"--target-lang", scorer1_tgt,
"--trainpref", pre_gen+bw_rescore_file,
"--srcdict", bw_dict + "/dict." + scorer1_src + ".txt",
"--tgtdict", bw_dict + "/dict." + scorer1_tgt + ".txt",
"--destdir", backwards_preprocessed_dir]
bw_preprocess_param = [
"--source-lang",
scorer1_src,
"--target-lang",
scorer1_tgt,
"--trainpref",
pre_gen + bw_rescore_file,
"--srcdict",
bw_dict + "/dict." + scorer1_src + ".txt",
"--tgtdict",
bw_dict + "/dict." + scorer1_tgt + ".txt",
"--destdir",
backwards_preprocessed_dir,
]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(bw_preprocess_param)
preprocess.main(input_args)
preprocess_param = ["--source-lang", scorer1_src,
"--target-lang", scorer1_tgt,
"--trainpref", pre_gen+fw_rescore_file,
"--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
"--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
"--destdir", left_to_right_preprocessed_dir]
preprocess_param = [
"--source-lang",
scorer1_src,
"--target-lang",
scorer1_tgt,
"--trainpref",
pre_gen + fw_rescore_file,
"--srcdict",
args.score_dict_dir + "/dict." + scorer1_src + ".txt",
"--tgtdict",
args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
"--destdir",
left_to_right_preprocessed_dir,
]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_param)
preprocess.main(input_args)
if args.right_to_left1 or args.right_to_left2:
preprocess_param = ["--source-lang", scorer1_src,
"--target-lang", scorer1_tgt,
"--trainpref", pre_gen+"/right_to_left_rescore_data",
"--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
"--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
"--destdir", right_to_left_preprocessed_dir]
preprocess_param = [
"--source-lang",
scorer1_src,
"--target-lang",
scorer1_tgt,
"--trainpref",
pre_gen + "/right_to_left_rescore_data",
"--srcdict",
args.score_dict_dir + "/dict." + scorer1_src + ".txt",
"--tgtdict",
args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
"--destdir",
right_to_left_preprocessed_dir,
]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_param)
preprocess.main(input_args)
@ -241,5 +393,5 @@ def cli_main():
gen_and_reprocess_nbest(args)
if __name__ == '__main__':
if __name__ == "__main__":
cli_main()

View File

@ -6,14 +6,14 @@
from fairseq import options
def get_reranking_parser(default_task='translation'):
parser = options.get_parser('Generation and reranking', default_task)
def get_reranking_parser(default_task="translation"):
parser = options.get_parser("Generation and reranking", default_task)
add_reranking_args(parser)
return parser
def get_tuning_parser(default_task='translation'):
parser = options.get_parser('Reranking tuning', default_task)
def get_tuning_parser(default_task="translation"):
parser = options.get_parser("Reranking tuning", default_task)
add_reranking_args(parser)
add_tuning_args(parser)
return parser
@ -110,17 +110,40 @@ def add_reranking_args(parser):
def add_tuning_args(parser):
group = parser.add_argument_group("Tuning")
group.add_argument('--lower-bound', default=[-0.7], nargs='+', type=float,
help='lower bound of search space')
group.add_argument('--upper-bound', default=[3], nargs='+', type=float,
help='upper bound of search space')
group.add_argument('--tune-param', default=['lenpen'], nargs='+',
choices=['lenpen', 'weight1', 'weight2', 'weight3'],
help='the parameter(s) to tune')
group.add_argument('--tune-subset', default='valid', choices=['valid', 'test', 'train'],
help='the subset to tune on ')
group.add_argument('--num-trials', default=1000, type=int,
help='number of trials to do for random search')
group.add_argument('--share-weights', action='store_true',
help='share weight2 and weight 3')
group.add_argument(
"--lower-bound",
default=[-0.7],
nargs="+",
type=float,
help="lower bound of search space",
)
group.add_argument(
"--upper-bound",
default=[3],
nargs="+",
type=float,
help="upper bound of search space",
)
group.add_argument(
"--tune-param",
default=["lenpen"],
nargs="+",
choices=["lenpen", "weight1", "weight2", "weight3"],
help="the parameter(s) to tune",
)
group.add_argument(
"--tune-subset",
default="valid",
choices=["valid", "test", "train"],
help="the subset to tune on ",
)
group.add_argument(
"--num-trials",
default=1000,
type=int,
help="number of trials to do for random search",
)
group.add_argument(
"--share-weights", action="store_true", help="share weight2 and weight 3"
)
return group

View File

@ -3,8 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import redirect_stdout
import os
from contextlib import redirect_stdout
from fairseq import options
from fairseq_cli import generate
@ -13,82 +13,124 @@ from . import rerank_options, rerank_utils
def score_bw(args):
if args.backwards1:
scorer1_src = args.target_lang
scorer1_tgt = args.source_lang
if args.backwards1:
scorer1_src = args.target_lang
scorer1_tgt = args.source_lang
else:
scorer1_src = args.source_lang
scorer1_tgt = args.target_lang
if args.score_model2 is not None:
if args.backwards2:
scorer2_src = args.target_lang
scorer2_tgt = args.source_lang
else:
scorer1_src = args.source_lang
scorer1_tgt = args.target_lang
scorer2_src = args.source_lang
scorer2_tgt = args.target_lang
if args.score_model2 is not None:
if args.backwards2:
scorer2_src = args.target_lang
scorer2_tgt = args.source_lang
else:
scorer2_src = args.source_lang
scorer2_tgt = args.target_lang
rerank1_is_gen = (
args.gen_model == args.score_model1 and args.source_prefix_frac is None
)
rerank2_is_gen = (
args.gen_model == args.score_model2 and args.source_prefix_frac is None
)
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
(
pre_gen,
left_to_right_preprocessed_dir,
right_to_left_preprocessed_dir,
backwards_preprocessed_dir,
lm_preprocessed_dir,
) = rerank_utils.get_directories(
args.data_dir_name,
args.num_rescore,
args.gen_subset,
args.gen_model_name,
args.shard_id,
args.num_shards,
args.sampling,
args.prefix_len,
args.target_prefix_frac,
args.source_prefix_frac,
)
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir = \
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
args.gen_model_name, args.shard_id, args.num_shards,
args.sampling, args.prefix_len, args.target_prefix_frac,
args.source_prefix_frac)
score1_file = rerank_utils.rescore_file_name(
pre_gen,
args.prefix_len,
args.model1_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards1,
)
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards1)
if args.score_model2 is not None:
score2_file = rerank_utils.rescore_file_name(
pre_gen,
args.prefix_len,
args.model2_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards2,
)
if args.score_model2 is not None:
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards2)
if args.right_to_left1:
rerank_data1 = right_to_left_preprocessed_dir
elif args.backwards1:
rerank_data1 = backwards_preprocessed_dir
else:
rerank_data1 = left_to_right_preprocessed_dir
if args.right_to_left1:
rerank_data1 = right_to_left_preprocessed_dir
elif args.backwards1:
rerank_data1 = backwards_preprocessed_dir
gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
if not rerank1_is_gen and not os.path.isfile(score1_file):
print("STEP 4: score the translations for model 1")
model_param1 = [
"--path",
args.score_model1,
"--source-lang",
scorer1_src,
"--target-lang",
scorer1_tgt,
]
gen_model1_param = [rerank_data1] + gen_param + model_param1
gen_parser = options.get_generation_parser()
input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)
with open(score1_file, "w") as f:
with redirect_stdout(f):
generate.main(input_args)
if (
args.score_model2 is not None
and not os.path.isfile(score2_file)
and not rerank2_is_gen
):
print("STEP 4: score the translations for model 2")
if args.right_to_left2:
rerank_data2 = right_to_left_preprocessed_dir
elif args.backwards2:
rerank_data2 = backwards_preprocessed_dir
else:
rerank_data1 = left_to_right_preprocessed_dir
rerank_data2 = left_to_right_preprocessed_dir
gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
if not rerank1_is_gen and not os.path.isfile(score1_file):
print("STEP 4: score the translations for model 1")
model_param2 = [
"--path",
args.score_model2,
"--source-lang",
scorer2_src,
"--target-lang",
scorer2_tgt,
]
gen_model2_param = [rerank_data2] + gen_param + model_param2
model_param1 = ["--path", args.score_model1, "--source-lang", scorer1_src, "--target-lang", scorer1_tgt]
gen_model1_param = [rerank_data1] + gen_param + model_param1
gen_parser = options.get_generation_parser()
input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)
gen_parser = options.get_generation_parser()
input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)
with open(score1_file, 'w') as f:
with redirect_stdout(f):
generate.main(input_args)
if args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen:
print("STEP 4: score the translations for model 2")
if args.right_to_left2:
rerank_data2 = right_to_left_preprocessed_dir
elif args.backwards2:
rerank_data2 = backwards_preprocessed_dir
else:
rerank_data2 = left_to_right_preprocessed_dir
model_param2 = ["--path", args.score_model2, "--source-lang", scorer2_src, "--target-lang", scorer2_tgt]
gen_model2_param = [rerank_data2] + gen_param + model_param2
gen_parser = options.get_generation_parser()
input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)
with open(score2_file, 'w') as f:
with redirect_stdout(f):
generate.main(input_args)
with open(score2_file, "w") as f:
with redirect_stdout(f):
generate.main(input_args)
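The scoring steps above reuse one small pattern: run generate.main while stdout is redirected into the score file. A standalone illustration (the filename is hypothetical):

from contextlib import redirect_stdout

with open("scores.txt", "w") as f:        # hypothetical output path
    with redirect_stdout(f):
        print("H-0\t-0.51\thello world")  # anything printed lands in the file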
def cli_main():
@ -97,5 +139,5 @@ def cli_main():
score_bw(args)
if __name__ == '__main__':
if __name__ == "__main__":
cli_main()

View File

@ -12,22 +12,38 @@ from . import rerank_options, rerank_utils
def score_lm(args):
using_nbest = args.nbest_list is not None
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir = \
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
args.gen_model_name, args.shard_id, args.num_shards,
args.sampling, args.prefix_len, args.target_prefix_frac,
args.source_prefix_frac)
(
pre_gen,
left_to_right_preprocessed_dir,
right_to_left_preprocessed_dir,
backwards_preprocessed_dir,
lm_preprocessed_dir,
) = rerank_utils.get_directories(
args.data_dir_name,
args.num_rescore,
args.gen_subset,
args.gen_model_name,
args.shard_id,
args.num_shards,
args.sampling,
args.prefix_len,
args.target_prefix_frac,
args.source_prefix_frac,
)
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
if using_nbest:
print("Using predefined n-best list from interactive.py")
predictions_bpe_file = args.nbest_list
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest)
gen_output = rerank_utils.BitextOutputFromGen(
predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest
)
if args.language_model is not None:
lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
lm_score_file = rerank_utils.rescore_file_name(
pre_gen, args.prefix_len, args.lm_name, lm_file=True
)
if args.language_model is not None and not os.path.isfile(lm_score_file):
print("STEP 4.5: language modeling for P(T)")
@ -38,10 +54,21 @@ def score_lm(args):
else:
bpe_status = "different"
rerank_utils.lm_scoring(lm_preprocessed_dir, bpe_status, gen_output, pre_gen,
args.lm_dict, args.lm_name, args.language_model,
args.lm_bpe_code, 128, lm_score_file, args.target_lang,
args.source_lang, prefix_len=args.prefix_len)
rerank_utils.lm_scoring(
lm_preprocessed_dir,
bpe_status,
gen_output,
pre_gen,
args.lm_dict,
args.lm_name,
args.language_model,
args.lm_bpe_code,
128,
lm_score_file,
args.target_lang,
args.source_lang,
prefix_len=args.prefix_len,
)
def cli_main():
@ -50,5 +77,5 @@ def cli_main():
score_lm(args)
if __name__ == '__main__':
if __name__ == "__main__":
cli_main()

View File

@ -5,8 +5,8 @@
import argparse
import random
import numpy as np
import numpy as np
from fairseq import options
from . import rerank, rerank_options
@ -14,7 +14,7 @@ from . import rerank, rerank_options
def random_search(args):
param_values = []
tuneable_parameters = ['lenpen', 'weight1', 'weight2', 'weight3']
tuneable_parameters = ["lenpen", "weight1", "weight2", "weight3"]
initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3]
for i, elem in enumerate(initial_params):
if type(elem) is not list:
@ -33,51 +33,60 @@ def random_search(args):
param_values += initial_params
random.seed(args.seed)
random_params = np.array([
[random.uniform(args.lower_bound[i], args.upper_bound[i]) for i in range(len(args.tune_param))]
for k in range(args.num_trials)
])
set_params = np.array([
[initial_params[i][0] for i in range(len(tuneable_parameters))]
for k in range(args.num_trials)
])
random_params = np.array(
[
[
random.uniform(args.lower_bound[i], args.upper_bound[i])
for i in range(len(args.tune_param))
]
for k in range(args.num_trials)
]
)
set_params = np.array(
[
[initial_params[i][0] for i in range(len(tuneable_parameters))]
for k in range(args.num_trials)
]
)
random_params = np.concatenate((random_params, set_params), 1)
rerank_args = vars(args).copy()
if args.nbest_list:
rerank_args['gen_subset'] = 'test'
rerank_args["gen_subset"] = "test"
else:
rerank_args['gen_subset'] = args.tune_subset
rerank_args["gen_subset"] = args.tune_subset
for k in range(len(tune_parameters)):
rerank_args[tune_parameters[k]] = list(random_params[:, k])
if args.share_weights:
k = tune_parameters.index('weight2')
rerank_args['weight3'] = list(random_params[:, k])
k = tune_parameters.index("weight2")
rerank_args["weight3"] = list(random_params[:, k])
rerank_args = argparse.Namespace(**rerank_args)
best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(rerank_args)
best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(
rerank_args
)
rerank_args = vars(args).copy()
rerank_args['lenpen'] = [best_lenpen]
rerank_args['weight1'] = [best_weight1]
rerank_args['weight2'] = [best_weight2]
rerank_args['weight3'] = [best_weight3]
rerank_args["lenpen"] = [best_lenpen]
rerank_args["weight1"] = [best_weight1]
rerank_args["weight2"] = [best_weight2]
rerank_args["weight3"] = [best_weight3]
# write the hypothesis from the valid set from the best trial
if args.gen_subset != "valid":
rerank_args['gen_subset'] = "valid"
rerank_args["gen_subset"] = "valid"
rerank_args = argparse.Namespace(**rerank_args)
rerank.rerank(rerank_args)
# test with the best hyperparameters on gen subset
rerank_args = vars(args).copy()
rerank_args['gen_subset'] = args.gen_subset
rerank_args['lenpen'] = [best_lenpen]
rerank_args['weight1'] = [best_weight1]
rerank_args['weight2'] = [best_weight2]
rerank_args['weight3'] = [best_weight3]
rerank_args["gen_subset"] = args.gen_subset
rerank_args["lenpen"] = [best_lenpen]
rerank_args["weight1"] = [best_weight1]
rerank_args["weight2"] = [best_weight2]
rerank_args["weight3"] = [best_weight3]
rerank_args = argparse.Namespace(**rerank_args)
rerank.rerank(rerank_args)
@ -89,5 +98,5 @@ def cli_main():
random_search(args)
if __name__ == '__main__':
if __name__ == "__main__":
cli_main()

View File

@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import redirect_stdout
import math
import os
import re
import subprocess
from contextlib import redirect_stdout
from fairseq import options
from fairseq_cli import eval_lm, preprocess
@ -20,7 +20,7 @@ def reprocess(fle):
# per source, so the values for hypothesis_dict are lists.
# parses output of generate.py
with open(fle, 'r') as f:
with open(fle, "r") as f:
txt = f.read()
"""reprocess generate.py output"""
@ -45,7 +45,9 @@ def reprocess(fle):
if line_type == "H":
h_txt = line[j:]
hypo = re.search(hp, h_txt)
assert hypo is not None, ("regular expression failed to find the hypothesis score")
assert (
hypo is not None
), "regular expression failed to find the hypothesis scoring"
_, i = hypo.span()
score = hypo.group()
if id_num in hypothesis_dict:
@ -56,9 +58,9 @@ def reprocess(fle):
score_dict[id_num] = [float(score)]
elif line_type == "S":
source_dict[id_num] = (line[j:])
source_dict[id_num] = line[j:]
elif line_type == "T":
target_dict[id_num] = (line[j:])
target_dict[id_num] = line[j:]
elif line_type == "P":
pos_scores = (line[j:]).split()
pos_scores = [float(x) for x in pos_scores]
@ -72,7 +74,7 @@ def reprocess(fle):
def reprocess_nbest(fle):
"""reprocess interactive.py output"""
with open(fle, 'r') as f:
with open(fle, "r") as f:
txt = f.read()
source_dict = {}
@ -82,7 +84,7 @@ def reprocess_nbest(fle):
pos_score_dict = {}
lines = txt.split("\n")
hp = re.compile(r'[-]?\d+[.]?\d+')
hp = re.compile(r"[-]?\d+[.]?\d+")
j = -1
for _i, line in enumerate(lines):
@ -119,59 +121,76 @@ def reprocess_nbest(fle):
return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
def write_reprocessed(sources, hypos, targets, source_outfile,
hypo_outfile, target_outfile, right_to_left=False,
prefix_len=None, bpe_symbol=None,
target_prefix_frac=None, source_prefix_frac=None):
def write_reprocessed(
sources,
hypos,
targets,
source_outfile,
hypo_outfile,
target_outfile,
right_to_left=False,
prefix_len=None,
bpe_symbol=None,
target_prefix_frac=None,
source_prefix_frac=None,
):
"""writes nbest hypothesis for rescoring"""
assert not (prefix_len is not None and target_prefix_frac is not None), \
"in writing reprocessed, only one type of prefix may be used"
assert not (prefix_len is not None and source_prefix_frac is not None), \
"in writing reprocessed, only one type of prefix may be used"
assert not (target_prefix_frac is not None and source_prefix_frac is not None), \
"in writing reprocessed, only one type of prefix may be used"
assert not (
prefix_len is not None and target_prefix_frac is not None
), "in writing reprocessed, only one type of prefix may be used"
assert not (
prefix_len is not None and source_prefix_frac is not None
), "in writing reprocessed, only one type of prefix may be used"
assert not (
target_prefix_frac is not None and source_prefix_frac is not None
), "in writing reprocessed, only one type of prefix may be used"
with open(source_outfile, 'w') as source_file, \
open(hypo_outfile, 'w') as hypo_file, \
open(target_outfile, 'w') as target_file:
with open(source_outfile, "w") as source_file, open(
hypo_outfile, "w"
) as hypo_file, open(target_outfile, "w") as target_file:
assert len(sources) == len(hypos), "sources and hypos list length mismatch"
if right_to_left:
for i in range(len(sources)):
for j in range(len(hypos[i])):
if prefix_len is None:
hypo_file.write(make_right_to_left(hypos[i][j])+"\n")
else:
raise NotImplementedError()
source_file.write(make_right_to_left(sources[i])+"\n")
target_file.write(make_right_to_left(targets[i])+"\n")
for j in range(len(hypos[i])):
if prefix_len is None:
hypo_file.write(make_right_to_left(hypos[i][j]) + "\n")
else:
raise NotImplementedError()
source_file.write(make_right_to_left(sources[i]) + "\n")
target_file.write(make_right_to_left(targets[i]) + "\n")
else:
for i in sorted(sources.keys()):
for j in range(len(hypos[i])):
if prefix_len is not None:
shortened = get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)+"\n"
hypo_file.write(shortened)
source_file.write(sources[i])
target_file.write(targets[i])
elif target_prefix_frac is not None:
num_words, shortened, num_bpe_tokens = \
calc_length_from_frac(hypos[i][j], target_prefix_frac, bpe_symbol)
shortened += "\n"
hypo_file.write(shortened)
source_file.write(sources[i])
target_file.write(targets[i])
elif source_prefix_frac is not None:
num_words, shortened, num_bpe_tokens = \
calc_length_from_frac(sources[i], source_prefix_frac, bpe_symbol)
shortened += "\n"
hypo_file.write(hypos[i][j])
source_file.write(shortened)
target_file.write(targets[i])
else:
hypo_file.write(hypos[i][j])
source_file.write(sources[i])
target_file.write(targets[i])
for j in range(len(hypos[i])):
if prefix_len is not None:
shortened = (
get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)
+ "\n"
)
hypo_file.write(shortened)
source_file.write(sources[i])
target_file.write(targets[i])
elif target_prefix_frac is not None:
num_words, shortened, num_bpe_tokens = calc_length_from_frac(
hypos[i][j], target_prefix_frac, bpe_symbol
)
shortened += "\n"
hypo_file.write(shortened)
source_file.write(sources[i])
target_file.write(targets[i])
elif source_prefix_frac is not None:
num_words, shortened, num_bpe_tokens = calc_length_from_frac(
sources[i], source_prefix_frac, bpe_symbol
)
shortened += "\n"
hypo_file.write(hypos[i][j])
source_file.write(shortened)
target_file.write(targets[i])
else:
hypo_file.write(hypos[i][j])
source_file.write(sources[i])
target_file.write(targets[i])
def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol):
@ -207,7 +226,9 @@ def get_prefix_from_len(sentence, bpe_symbol, prefix_len):
if bpe_count == 0:
return sentence[:prefix_len]
else:
return sentence[:prefix_len]+get_prefix_from_len(sentence[prefix_len:], bpe_symbol, bpe_count)
return sentence[:prefix_len] + get_prefix_from_len(
sentence[prefix_len:], bpe_symbol, bpe_count
)
def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len):
@ -225,9 +246,9 @@ def make_right_to_left(line):
def remove_bpe(line, bpe_symbol):
line = line.replace("\n", '')
line = (line + ' ').replace(bpe_symbol, '').rstrip()
return line+("\n")
line = line.replace("\n", "")
line = (line + " ").replace(bpe_symbol, "").rstrip()
return line + ("\n")
def remove_bpe_dict(pred_dict, bpe_symbol):
@ -242,7 +263,7 @@ def remove_bpe_dict(pred_dict, bpe_symbol):
def parse_bleu_scoring(line):
p = re.compile(r'(BLEU4 = )\d+[.]\d+')
p = re.compile(r"(BLEU4 = )\d+[.]\d+")
res = re.search(p, line)
assert res is not None, line
return float(res.group()[8:])
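As a sketch of what parse_bleu_scoring extracts, here is a made-up summary line in the style of generate.py output; slicing group()[8:] drops the "BLEU4 = " prefix, which is exactly 8 characters:

import re

line = "| Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.9"  # invented
m = re.search(r"(BLEU4 = )\d+[.]\d+", line)
assert m is not None
print(float(m.group()[8:]))  # 40.83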
@ -259,9 +280,21 @@ def get_full_from_prefix(hypo_prefix, hypos):
raise Exception()
def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=None,
lenpen=None, src_len=None, tgt_len=None, bitext1_backwards=False,
bitext2_backwards=False, normalize=False):
def get_score(
a,
b,
c,
target_len,
bitext_score1,
bitext_score2=None,
lm_score=None,
lenpen=None,
src_len=None,
tgt_len=None,
bitext1_backwards=False,
bitext2_backwards=False,
normalize=False,
):
if bitext1_backwards:
bitext1_norm = src_len
else:
@ -275,9 +308,13 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N
bitext2_norm = 1
bitext_score2 = 0
if normalize:
score = a*bitext_score1/bitext1_norm + b*bitext_score2/bitext2_norm+c*lm_score/src_len
score = (
a * bitext_score1 / bitext1_norm
+ b * bitext_score2 / bitext2_norm
+ c * lm_score / src_len
)
else:
score = a*bitext_score1 + b*bitext_score2+c*lm_score
score = a * bitext_score1 + b * bitext_score2 + c * lm_score
if lenpen is not None:
score /= (target_len) ** float(lenpen)
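A small worked example of the combined score, with all values invented and the forward-direction norms assumed to equal the target length:

a, b, c = 1.0, 0.5, 0.3                    # weight1, weight2, weight3 (hypothetical)
bitext_score1, bitext_score2 = -10.0, -12.0
lm_score = -8.0
bitext1_norm = bitext2_norm = 18           # tgt_len for non-backwards models (assumed)
src_len, target_len, lenpen = 20, 18, 1.0

score = (
    a * bitext_score1 / bitext1_norm
    + b * bitext_score2 / bitext2_norm
    + c * lm_score / src_len
)
score /= target_len ** lenpen
print(score)  # about -0.056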
@ -286,8 +323,16 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N
class BitextOutput(object):
def __init__(self, output_file, backwards, right_to_left, bpe_symbol,
prefix_len=None, target_prefix_frac=None, source_prefix_frac=None):
def __init__(
self,
output_file,
backwards,
right_to_left,
bpe_symbol,
prefix_len=None,
target_prefix_frac=None,
source_prefix_frac=None,
):
"""process output from rescoring"""
source, hypo, score, target, pos_score = reprocess(output_file)
if backwards:
@ -296,7 +341,9 @@ class BitextOutput(object):
self.hypo_fracs = target_prefix_frac
# remove length penalty so we can use raw scores
score, num_bpe_tokens = get_score_from_pos(pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards)
score, num_bpe_tokens = get_score_from_pos(
pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards
)
source_lengths = {}
target_lengths = {}
@ -341,7 +388,9 @@ class BitextOutput(object):
score[i] = float(score[i][0])
pos_score[i] = pos_score[i][0]
else:
assert len(hypo[i]) == 1, "expected only one hypothesis per source sentence"
assert (
len(hypo[i]) == 1
), "expected only one hypothesis per source sentence"
source[i] = remove_bpe(source[i], bpe_symbol)
target[i] = remove_bpe(target[i], bpe_symbol)
hypo[i] = remove_bpe(hypo[i][0], bpe_symbol)
@ -360,11 +409,26 @@ class BitextOutput(object):
class BitextOutputFromGen(object):
def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_len=None, target_prefix_frac=None):
def __init__(
self,
predictions_bpe_file,
bpe_symbol=None,
nbest=False,
prefix_len=None,
target_prefix_frac=None,
):
if nbest:
pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess_nbest(predictions_bpe_file)
(
pred_source,
pred_hypo,
pred_score,
pred_target,
pred_pos_score,
) = reprocess_nbest(predictions_bpe_file)
else:
pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(predictions_bpe_file)
pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(
predictions_bpe_file
)
assert len(pred_source) == len(pred_hypo)
assert len(pred_source) == len(pred_score)
@ -372,8 +436,9 @@ class BitextOutputFromGen(object):
assert len(pred_source) == len(pred_pos_score)
# remove length penalty so we can use raw scores
pred_score, num_bpe_tokens = get_score_from_pos(pred_pos_score, prefix_len, pred_hypo,
bpe_symbol, target_prefix_frac, False)
pred_score, num_bpe_tokens = get_score_from_pos(
pred_pos_score, prefix_len, pred_hypo, bpe_symbol, target_prefix_frac, False
)
self.source = pred_source
self.target = pred_target
@ -414,7 +479,9 @@ class BitextOutputFromGen(object):
index += 1
def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards):
def get_score_from_pos(
pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards
):
score_dict = {}
num_bpe_tokens_dict = {}
assert prefix_len is None or hypo_frac is None
@ -423,11 +490,15 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f
num_bpe_tokens_dict[key] = []
for i in range(len(pos_score_dict[key])):
if prefix_len is not None and not backwards:
num_bpe_tokens = get_num_bpe_tokens_from_len(hypo_dict[key][i], bpe_symbol, prefix_len)
num_bpe_tokens = get_num_bpe_tokens_from_len(
hypo_dict[key][i], bpe_symbol, prefix_len
)
score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens]))
num_bpe_tokens_dict[key].append(num_bpe_tokens)
elif hypo_frac is not None:
num_words, shortened, hypo_prefix_len = calc_length_from_frac(hypo_dict[key][i], hypo_frac, bpe_symbol)
num_words, shortened, hypo_prefix_len = calc_length_from_frac(
hypo_dict[key][i], hypo_frac, bpe_symbol
)
score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len]))
num_bpe_tokens_dict[key].append(hypo_prefix_len)
else:
@ -437,10 +508,26 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f
class LMOutput(object):
def __init__(self, lm_score_file, lm_dict=None, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
lm_sentences, lm_sen_scores, lm_sen_pos_scores, lm_no_bpe_sentences, lm_bpe_tokens = \
parse_lm(lm_score_file, prefix_len=prefix_len,
bpe_symbol=bpe_symbol, target_prefix_frac=target_prefix_frac)
def __init__(
self,
lm_score_file,
lm_dict=None,
prefix_len=None,
bpe_symbol=None,
target_prefix_frac=None,
):
(
lm_sentences,
lm_sen_scores,
lm_sen_pos_scores,
lm_no_bpe_sentences,
lm_bpe_tokens,
) = parse_lm(
lm_score_file,
prefix_len=prefix_len,
bpe_symbol=bpe_symbol,
target_prefix_frac=target_prefix_frac,
)
self.sentences = lm_sentences
self.score = lm_sen_scores
@ -452,7 +539,7 @@ class LMOutput(object):
def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
"""parse output of eval_lm"""
with open(input_file, 'r') as f:
with open(input_file, "r") as f:
text = f.readlines()
text = text[7:]
cleaned_text = text[:-2]
@ -467,20 +554,23 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No
if tokens[0].isdigit():
line_id = int(tokens[0])
scores = [float(x[1:-1]) for x in tokens[2::2]]
sentences[line_id] = " ".join(tokens[1::2][:-1])+"\n"
sentences[line_id] = " ".join(tokens[1::2][:-1]) + "\n"
if bpe_symbol is not None:
# exclude <eos> symbol to match output from generate.py
bpe_sen = " ".join(tokens[1::2][:-1])+"\n"
bpe_sen = " ".join(tokens[1::2][:-1]) + "\n"
no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol)
no_bpe_sentences[line_id] = no_bpe_sen
if prefix_len is not None:
num_bpe_tokens = get_num_bpe_tokens_from_len(bpe_sen, bpe_symbol, prefix_len)
num_bpe_tokens = get_num_bpe_tokens_from_len(
bpe_sen, bpe_symbol, prefix_len
)
sen_scores[line_id] = sum(scores[:num_bpe_tokens])
num_bpe_tokens_dict[line_id] = num_bpe_tokens
elif target_prefix_frac is not None:
num_words, shortened, target_prefix_len = calc_length_from_frac(bpe_sen, target_prefix_frac,
bpe_symbol)
num_words, shortened, target_prefix_len = calc_length_from_frac(
bpe_sen, target_prefix_frac, bpe_symbol
)
sen_scores[line_id] = sum(scores[:target_prefix_len])
num_bpe_tokens_dict[line_id] = target_prefix_len
else:
@ -492,160 +582,269 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No
return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict
def get_directories(data_dir_name, num_rescore, gen_subset,
fw_name, shard_id, num_shards,
sampling=False, prefix_len=None,
target_prefix_frac=None, source_prefix_frac=None):
nbest_file_id = "nbest_" + str(num_rescore) + \
"_subset_" + gen_subset + \
"_fw_name_" + fw_name + \
"_shard_" + str(shard_id) + \
"_of_" + str(num_shards)
def get_directories(
data_dir_name,
num_rescore,
gen_subset,
fw_name,
shard_id,
num_shards,
sampling=False,
prefix_len=None,
target_prefix_frac=None,
source_prefix_frac=None,
):
nbest_file_id = (
"nbest_"
+ str(num_rescore)
+ "_subset_"
+ gen_subset
+ "_fw_name_"
+ fw_name
+ "_shard_"
+ str(shard_id)
+ "_of_"
+ str(num_shards)
)
if sampling:
nbest_file_id += "_sampling"
# the directory containing all information for this nbest list
pre_gen = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+data_dir_name+"/"+nbest_file_id
pre_gen = (
os.path.join(os.path.dirname(__file__))
+ "/rerank_data/"
+ data_dir_name
+ "/"
+ nbest_file_id
)
# the directory to store the preprocessed nbest list, for left to right rescoring
left_to_right_preprocessed_dir = pre_gen+"/left_to_right_preprocessed"
left_to_right_preprocessed_dir = pre_gen + "/left_to_right_preprocessed"
if source_prefix_frac is not None:
left_to_right_preprocessed_dir = left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
left_to_right_preprocessed_dir = (
left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
)
# the directory to store the preprocessed nbest list, for right to left rescoring
right_to_left_preprocessed_dir = pre_gen+"/right_to_left_preprocessed"
right_to_left_preprocessed_dir = pre_gen + "/right_to_left_preprocessed"
# the directory to store the preprocessed nbest list, for backwards rescoring
backwards_preprocessed_dir = pre_gen+"/backwards"
backwards_preprocessed_dir = pre_gen + "/backwards"
if target_prefix_frac is not None:
backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_frac"+str(target_prefix_frac)
backwards_preprocessed_dir = (
backwards_preprocessed_dir + "/prefix_frac" + str(target_prefix_frac)
)
elif prefix_len is not None:
backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_"+str(prefix_len)
backwards_preprocessed_dir = (
backwards_preprocessed_dir + "/prefix_" + str(prefix_len)
)
# the directory to store the preprocessed nbest list, for rescoring with P(T)
lm_preprocessed_dir = pre_gen+"/lm_preprocessed"
lm_preprocessed_dir = pre_gen + "/lm_preprocessed"
return pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir
return (
pre_gen,
left_to_right_preprocessed_dir,
right_to_left_preprocessed_dir,
backwards_preprocessed_dir,
lm_preprocessed_dir,
)
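For orientation, the nbest_file_id naming that drives these directories, evaluated on invented argument values:

num_rescore, gen_subset, fw_name = 100, "valid", "fw_model"  # hypothetical args
shard_id, num_shards = 0, 1

nbest_file_id = (
    "nbest_" + str(num_rescore)
    + "_subset_" + gen_subset
    + "_fw_name_" + fw_name
    + "_shard_" + str(shard_id)
    + "_of_" + str(num_shards)
)
print(nbest_file_id)  # nbest_100_subset_valid_fw_name_fw_model_shard_0_of_1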
def lm_scoring(preprocess_directory, bpe_status, gen_output, pre_gen,
cur_lm_dict, cur_lm_name, cur_language_model, cur_lm_bpe_code,
batch_size, lm_score_file, target_lang, source_lang, prefix_len=None):
def lm_scoring(
preprocess_directory,
bpe_status,
gen_output,
pre_gen,
cur_lm_dict,
cur_lm_name,
cur_language_model,
cur_lm_bpe_code,
batch_size,
lm_score_file,
target_lang,
source_lang,
prefix_len=None,
):
if prefix_len is not None:
assert bpe_status == "different", "bpe status must be different to use prefix len"
assert (
bpe_status == "different"
), "bpe status must be different to use prefix len"
if bpe_status == "no bpe":
# run lm on output without bpe
write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
gen_output.no_bpe_target, pre_gen+"/rescore_data_no_bpe.de",
pre_gen+"/rescore_data_no_bpe.en", pre_gen+"/reference_file_no_bpe")
write_reprocessed(
gen_output.no_bpe_source,
gen_output.no_bpe_hypo,
gen_output.no_bpe_target,
pre_gen + "/rescore_data_no_bpe.de",
pre_gen + "/rescore_data_no_bpe.en",
pre_gen + "/reference_file_no_bpe",
)
preprocess_lm_param = ["--only-source",
"--trainpref", pre_gen+"/rescore_data_no_bpe."+target_lang,
"--srcdict", cur_lm_dict,
"--destdir", preprocess_directory]
preprocess_lm_param = [
"--only-source",
"--trainpref",
pre_gen + "/rescore_data_no_bpe." + target_lang,
"--srcdict",
cur_lm_dict,
"--destdir",
preprocess_directory,
]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
eval_lm_param = [preprocess_directory,
"--path", cur_language_model,
"--output-word-probs",
"--batch-size", str(batch_size),
"--max-tokens", "1024",
"--sample-break-mode", "eos",
"--gen-subset", "train"]
eval_lm_param = [
preprocess_directory,
"--path",
cur_language_model,
"--output-word-probs",
"--batch-size",
str(batch_size),
"--max-tokens",
"1024",
"--sample-break-mode",
"eos",
"--gen-subset",
"train",
]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, 'w') as f:
with open(lm_score_file, "w") as f:
with redirect_stdout(f):
eval_lm.main(input_args)
elif bpe_status == "shared":
preprocess_lm_param = ["--only-source",
"--trainpref", pre_gen+"/rescore_data."+target_lang,
"--srcdict", cur_lm_dict,
"--destdir", preprocess_directory]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
preprocess_lm_param = [
"--only-source",
"--trainpref",
pre_gen + "/rescore_data." + target_lang,
"--srcdict",
cur_lm_dict,
"--destdir",
preprocess_directory,
]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
eval_lm_param = [preprocess_directory,
"--path", cur_language_model,
"--output-word-probs",
"--batch-size", str(batch_size),
"--sample-break-mode", "eos",
"--gen-subset", "train"]
eval_lm_param = [
preprocess_directory,
"--path",
cur_language_model,
"--output-word-probs",
"--batch-size",
str(batch_size),
"--sample-break-mode",
"eos",
"--gen-subset",
"train",
]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, 'w') as f:
with redirect_stdout(f):
eval_lm.main(input_args)
with open(lm_score_file, "w") as f:
with redirect_stdout(f):
eval_lm.main(input_args)
elif bpe_status == "different":
rescore_file = pre_gen+"/rescore_data_no_bpe"
rescore_bpe = pre_gen+"/rescore_data_new_bpe"
rescore_file = pre_gen + "/rescore_data_no_bpe"
rescore_bpe = pre_gen + "/rescore_data_new_bpe"
rescore_file += "."
rescore_bpe += "."
write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
gen_output.no_bpe_target, rescore_file+source_lang,
rescore_file+target_lang, pre_gen+"/reference_file_no_bpe",
bpe_symbol=None)
write_reprocessed(
gen_output.no_bpe_source,
gen_output.no_bpe_hypo,
gen_output.no_bpe_target,
rescore_file + source_lang,
rescore_file + target_lang,
pre_gen + "/reference_file_no_bpe",
bpe_symbol=None,
)
# apply LM bpe to nbest list
bpe_src_param = ["-c", cur_lm_bpe_code,
"--input", rescore_file+target_lang,
"--output", rescore_bpe+target_lang]
subprocess.call(["python",
os.path.join(os.path.dirname(__file__),
"subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
shell=False)
bpe_src_param = [
"-c",
cur_lm_bpe_code,
"--input",
rescore_file + target_lang,
"--output",
rescore_bpe + target_lang,
]
subprocess.call(
[
"python",
os.path.join(
os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
),
]
+ bpe_src_param,
shell=False,
)
# uncomment to use fastbpe instead of subword-nmt bpe
# bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
# subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)
preprocess_dir = preprocess_directory
preprocess_lm_param = ["--only-source",
"--trainpref", rescore_bpe+target_lang,
"--srcdict", cur_lm_dict,
"--destdir", preprocess_dir]
preprocess_lm_param = [
"--only-source",
"--trainpref",
rescore_bpe + target_lang,
"--srcdict",
cur_lm_dict,
"--destdir",
preprocess_dir,
]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
eval_lm_param = [preprocess_dir,
"--path", cur_language_model,
"--output-word-probs",
"--batch-size", str(batch_size),
"--max-tokens", "1024",
"--sample-break-mode", "eos",
"--gen-subset", "train"]
eval_lm_param = [
preprocess_dir,
"--path",
cur_language_model,
"--output-word-probs",
"--batch-size",
str(batch_size),
"--max-tokens",
"1024",
"--sample-break-mode",
"eos",
"--gen-subset",
"train",
]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, 'w') as f:
with open(lm_score_file, "w") as f:
with redirect_stdout(f):
eval_lm.main(input_args)
def rescore_file_name(nbest_dir, prefix_len, scorer_name, lm_file=False,
target_prefix_frac=None, source_prefix_frac=None, backwards=None):
def rescore_file_name(
nbest_dir,
prefix_len,
scorer_name,
lm_file=False,
target_prefix_frac=None,
source_prefix_frac=None,
backwards=None,
):
if lm_file:
score_file = nbest_dir+"/lm_score_translations_model_"+scorer_name+".txt"
score_file = nbest_dir + "/lm_score_translations_model_" + scorer_name + ".txt"
else:
score_file = nbest_dir+"/"+scorer_name+"_score_translations.txt"
score_file = nbest_dir + "/" + scorer_name + "_score_translations.txt"
if backwards:
if prefix_len is not None:
score_file += "prefix_len"+str(prefix_len)
score_file += "prefix_len" + str(prefix_len)
elif target_prefix_frac is not None:
score_file += "target_prefix_frac"+str(target_prefix_frac)
score_file += "target_prefix_frac" + str(target_prefix_frac)
else:
if source_prefix_frac is not None:
score_file += "source_prefix_frac"+str(source_prefix_frac)
score_file += "source_prefix_frac" + str(source_prefix_frac)
return score_file

View File

@ -13,57 +13,66 @@ logging.getLogger().setLevel(logging.INFO)
def main():
parser = argparse.ArgumentParser(description='')
parser.add_argument('--en2fr', required=True,
help='path to en2fr model')
parser.add_argument('--fr2en', required=True,
help='path to fr2en mixture of experts model')
parser.add_argument('--user-dir',
help='path to fairseq examples/translation_moe/src directory')
parser.add_argument('--num-experts', type=int, default=10,
help='(keep at 10 unless using a different model)')
parser.add_argument('files', nargs='*', default=['-'],
help='input files to paraphrase; "-" for stdin')
parser = argparse.ArgumentParser(description="")
parser.add_argument("--en2fr", required=True, help="path to en2fr model")
parser.add_argument(
"--fr2en", required=True, help="path to fr2en mixture of experts model"
)
parser.add_argument(
"--user-dir", help="path to fairseq examples/translation_moe/src directory"
)
parser.add_argument(
"--num-experts",
type=int,
default=10,
help="(keep at 10 unless using a different model)",
)
parser.add_argument(
"files",
nargs="*",
default=["-"],
help='input files to paraphrase; "-" for stdin',
)
args = parser.parse_args()
if args.user_dir is None:
args.user_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # examples/
'translation_moe',
'src',
"translation_moe",
"src",
)
if os.path.exists(args.user_dir):
logging.info('found user_dir:' + args.user_dir)
logging.info("found user_dir:" + args.user_dir)
else:
raise RuntimeError(
'cannot find fairseq examples/translation_moe/src '
'(tried looking here: {})'.format(args.user_dir)
"cannot find fairseq examples/translation_moe/src "
"(tried looking here: {})".format(args.user_dir)
)
logging.info('loading en2fr model from:' + args.en2fr)
logging.info("loading en2fr model from:" + args.en2fr)
en2fr = TransformerModel.from_pretrained(
model_name_or_path=args.en2fr,
tokenizer='moses',
bpe='sentencepiece',
tokenizer="moses",
bpe="sentencepiece",
).eval()
logging.info('loading fr2en model from:' + args.fr2en)
logging.info("loading fr2en model from:" + args.fr2en)
fr2en = TransformerModel.from_pretrained(
model_name_or_path=args.fr2en,
tokenizer='moses',
bpe='sentencepiece',
tokenizer="moses",
bpe="sentencepiece",
user_dir=args.user_dir,
task='translation_moe',
task="translation_moe",
).eval()
def gen_paraphrases(en):
fr = en2fr.translate(en)
return [
fr2en.translate(fr, inference_step_args={'expert': i})
fr2en.translate(fr, inference_step_args={"expert": i})
for i in range(args.num_experts)
]
logging.info('Type the input sentence and press return:')
logging.info("Type the input sentence and press return:")
for line in fileinput.input(args.files):
line = line.strip()
if len(line) == 0:
@ -72,5 +81,5 @@ def main():
print(paraphrase)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -4,9 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import re
import argparse
import re
import sys
class OOVIndexError(IndexError):
@ -25,8 +25,8 @@ class OOVIndexError(IndexError):
def replace_oovs(source_in, target_in, target_out):
"""Replaces <unk-N> tokens in the target text with the corresponding word in
the source text.
"""
the source text.
"""
oov_re = re.compile("^<unk-([0-9]+)>$")

View File

@ -10,8 +10,8 @@ from itertools import zip_longest
def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
"""Replaces out-of-vocabulary words in source and target text with <unk-N>,
where N is the position of the word in the source sequence.
"""
where N is the position of the word in the source sequence.
"""
def format_unk(pos):
return "<unk-{}>".format(pos)

View File

@ -8,19 +8,17 @@ from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from fairseq import utils, metrics
from fairseq import metrics, utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
TransformerModel,
TransformerDecoder,
TransformerEncoder,
base_architecture,
DEFAULT_MAX_SOURCE_POSITIONS,
DEFAULT_MAX_TARGET_POSITIONS,
TransformerDecoder,
TransformerEncoder,
TransformerModel,
base_architecture,
)
from torch import Tensor

View File

@ -8,40 +8,44 @@ import os
import numpy as np
import torch
from fairseq.data import (
data_utils,
Dictionary,
encoders,
IdDataset,
ListDataset,
NestedDictionaryDataset,
NumSamplesDataset,
NumelDataset,
NumSamplesDataset,
RawLabelDataset,
RightPadDataset,
SortDataset,
data_utils,
encoders,
)
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.tasks import LegacyFairseqTask, register_task
@register_task('commonsense_qa')
@register_task("commonsense_qa")
class CommonsenseQATask(LegacyFairseqTask):
"""Task to finetune RoBERTa for Commonsense QA."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR',
help='path to data directory; we load <split>.jsonl')
parser.add_argument('--init-token', type=int, default=None,
help='add token at the beginning of each batch item')
parser.add_argument('--num-classes', type=int, default=5)
parser.add_argument(
"data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
)
parser.add_argument(
"--init-token",
type=int,
default=None,
help="add token at the beginning of each batch item",
)
parser.add_argument("--num-classes", type=int, default=5)
def __init__(self, args, vocab):
super().__init__(args)
self.vocab = vocab
self.mask = vocab.add_symbol('<mask>')
self.mask = vocab.add_symbol("<mask>")
self.bpe = encoders.build_bpe(args)
@ -53,20 +57,24 @@ class CommonsenseQATask(LegacyFairseqTask):
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol('<mask>')
dictionary.add_symbol("<mask>")
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
assert args.criterion == 'sentence_ranking', 'Must set --criterion=sentence_ranking'
assert (
args.criterion == "sentence_ranking"
), "Must set --criterion=sentence_ranking"
# load data and label dictionaries
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(vocab)))
vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
print("| dictionary: {} types".format(len(vocab)))
return cls(args, vocab)
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
def load_dataset(
self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
):
"""Load a given dataset split.
Args:
@ -77,16 +85,18 @@ class CommonsenseQATask(LegacyFairseqTask):
if self.bpe is not None:
s = self.bpe.encode(s)
tokens = self.vocab.encode_line(
s, append_eos=True, add_if_not_exist=False,
s,
append_eos=True,
add_if_not_exist=False,
).long()
if append_bos and self.args.init_token is not None:
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
return tokens
if data_path is None:
data_path = os.path.join(self.args.data, split + '.jsonl')
data_path = os.path.join(self.args.data, split + ".jsonl")
if not os.path.exists(data_path):
raise FileNotFoundError('Cannot find data: {}'.format(data_path))
raise FileNotFoundError("Cannot find data: {}".format(data_path))
src_tokens = [[] for i in range(self.args.num_classes)]
src_lengths = [[] for i in range(self.args.num_classes)]
@ -95,20 +105,23 @@ class CommonsenseQATask(LegacyFairseqTask):
with open(data_path) as h:
for line in h:
example = json.loads(line.strip())
if 'answerKey' in example:
label = ord(example['answerKey']) - ord('A')
if "answerKey" in example:
label = ord(example["answerKey"]) - ord("A")
labels.append(label)
question = example['question']['stem']
assert len(example['question']['choices']) == self.args.num_classes
question = example["question"]["stem"]
assert len(example["question"]["choices"]) == self.args.num_classes
# format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
question = 'Q: ' + question
question = "Q: " + question
question_toks = binarize(question, append_bos=True)
for i, choice in enumerate(example['question']['choices']):
src = 'A: ' + choice['text']
for i, choice in enumerate(example["question"]["choices"]):
src = "A: " + choice["text"]
src_bin = torch.cat([question_toks, binarize(src)])
src_tokens[i].append(src_bin)
src_lengths[i].append(len(src_bin))
assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes))
assert all(
len(src_tokens[0]) == len(src_tokens[i])
for i in range(self.args.num_classes)
)
assert len(src_tokens[0]) == len(src_lengths[0])
assert len(labels) == 0 or len(labels) == len(src_tokens[0])
@ -118,24 +131,26 @@ class CommonsenseQATask(LegacyFairseqTask):
src_lengths[i] = ListDataset(src_lengths[i])
dataset = {
'id': IdDataset(),
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(src_tokens[0], reduce=True),
"id": IdDataset(),
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_tokens[0], reduce=True),
}
for i in range(self.args.num_classes):
dataset.update({
'net_input{}'.format(i + 1): {
'src_tokens': RightPadDataset(
src_tokens[i],
pad_idx=self.source_dictionary.pad(),
),
'src_lengths': src_lengths[i],
dataset.update(
{
"net_input{}".format(i + 1): {
"src_tokens": RightPadDataset(
src_tokens[i],
pad_idx=self.source_dictionary.pad(),
),
"src_lengths": src_lengths[i],
}
}
})
)
if len(labels) > 0:
dataset.update({'target': RawLabelDataset(labels)})
dataset.update({"target": RawLabelDataset(labels)})
dataset = NestedDictionaryDataset(
dataset,
@ -149,17 +164,18 @@ class CommonsenseQATask(LegacyFairseqTask):
sort_order=[np.random.permutation(len(dataset))],
)
print('| Loaded {} with {} samples'.format(split, len(dataset)))
print("| Loaded {} with {} samples".format(split, len(dataset)))
self.datasets[split] = dataset
return self.datasets[split]
def build_model(self, args):
from fairseq import models
model = models.build_model(args, self)
model.register_classification_head(
'sentence_classification_head',
"sentence_classification_head",
num_classes=1,
)

View File

@ -8,7 +8,6 @@
import argparse
import contextlib
import sys
from collections import Counter
from multiprocessing import Pool
@ -26,23 +25,23 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--encoder-json",
help='path to encoder.json',
help="path to encoder.json",
)
parser.add_argument(
"--vocab-bpe",
type=str,
help='path to vocab.bpe',
help="path to vocab.bpe",
)
parser.add_argument(
"--inputs",
nargs="+",
default=['-'],
default=["-"],
help="input files to filter/encode",
)
parser.add_argument(
"--outputs",
nargs="+",
default=['-'],
default=["-"],
help="path to save encoded outputs",
)
parser.add_argument(
@ -53,18 +52,21 @@ def main():
parser.add_argument("--workers", type=int, default=20)
args = parser.parse_args()
assert len(args.inputs) == len(args.outputs), \
"number of input and output paths should match"
assert len(args.inputs) == len(
args.outputs
), "number of input and output paths should match"
with contextlib.ExitStack() as stack:
inputs = [
stack.enter_context(open(input, "r", encoding="utf-8"))
if input != "-" else sys.stdin
if input != "-"
else sys.stdin
for input in args.inputs
]
outputs = [
stack.enter_context(open(output, "w", encoding="utf-8"))
if output != "-" else sys.stdout
if output != "-"
else sys.stdout
for output in args.outputs
]
@ -87,7 +89,6 @@ def main():
class MultiprocessingEncoder(object):
def __init__(self, args):
self.args = args

View File

@ -25,7 +25,7 @@ def get_examples(data_dir, set_type):
examples = []
levels = ["middle", "high"]
set_type_c = set_type.split('-')
set_type_c = set_type.split("-")
if len(set_type_c) == 2:
levels = [set_type_c[1]]
set_type = set_type_c[0]
@ -33,13 +33,13 @@ def get_examples(data_dir, set_type):
cur_dir = os.path.join(data_dir, set_type, level)
for filename in os.listdir(cur_dir):
cur_path = os.path.join(cur_dir, filename)
with open(cur_path, 'r') as f:
with open(cur_path, "r") as f:
cur_data = json.load(f)
answers = cur_data["answers"]
options = cur_data["options"]
questions = cur_data["questions"]
context = cur_data["article"].replace("\n", " ")
context = re.sub(r'\s+', ' ', context)
context = re.sub(r"\s+", " ", context)
for i in range(len(answers)):
label = ord(answers[i]) - ord("A")
qa_list = []
@ -50,7 +50,7 @@ def get_examples(data_dir, set_type):
qa_cat = question.replace("_", option)
else:
qa_cat = " ".join([question, option])
qa_cat = re.sub(r'\s+', ' ', qa_cat)
qa_cat = re.sub(r"\s+", " ", qa_cat)
qa_list.append(qa_cat)
examples.append(InputExample(context, qa_list, label))
@ -64,11 +64,11 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input-dir",
help='input directory for downloaded RACE dataset',
help="input directory for downloaded RACE dataset",
)
parser.add_argument(
"--output-dir",
help='output directory for extracted data',
help="output directory for extracted data",
)
args = parser.parse_args()
@ -77,17 +77,20 @@ def main():
for set_type in ["train", "dev", "test-middle", "test-high"]:
examples = get_examples(args.input_dir, set_type)
qa_file_paths = [os.path.join(args.output_dir, set_type + ".input" + str(i + 1)) for i in range(4)]
qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths]
qa_file_paths = [
os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
for i in range(4)
]
qa_files = [open(qa_file_path, "w") for qa_file_path in qa_file_paths]
outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
outf_label_path = os.path.join(args.output_dir, set_type + ".label")
outf_context = open(outf_context_path, 'w')
outf_label = open(outf_label_path, 'w')
outf_context = open(outf_context_path, "w")
outf_label = open(outf_label_path, "w")
for example in examples:
outf_context.write(example.paragraph + '\n')
outf_context.write(example.paragraph + "\n")
for i in range(4):
qa_files[i].write(example.qa_list[i] + '\n')
outf_label.write(str(example.label) + '\n')
qa_files[i].write(example.qa_list[i] + "\n")
outf_label.write(str(example.label) + "\n")
for f in qa_files:
f.close()
@ -95,5 +98,5 @@ def main():
outf_context.close()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -7,19 +7,17 @@ import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import encoders
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data import encoders
@register_criterion('wsc')
@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
def __init__(self, args, task):
super().__init__(args, task)
if self.args.save_predictions is not None:
self.prediction_h = open(self.args.save_predictions, 'w')
self.prediction_h = open(self.args.save_predictions, "w")
else:
self.prediction_h = None
self.bpe = encoders.build_bpe(args)
@ -32,12 +30,16 @@ class WSCCriterion(LegacyFairseqCriterion):
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0)
parser.add_argument('--wsc-margin-beta', type=float, metavar='B', default=0.0)
parser.add_argument('--wsc-cross-entropy', action='store_true',
help='use cross entropy formulation instead of margin loss')
parser.add_argument('--save-predictions', metavar='FILE',
help='file to save predictions to')
parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
parser.add_argument(
"--wsc-cross-entropy",
action="store_true",
help="use cross entropy formulation instead of margin loss",
)
parser.add_argument(
"--save-predictions", metavar="FILE", help="file to save predictions to"
)
def get_masked_input(self, tokens, mask):
masked_tokens = tokens.clone()
@ -60,27 +62,26 @@ class WSCCriterion(LegacyFairseqCriterion):
)
else:
return (
- query_lprobs
+ self.args.wsc_margin_alpha * (
cand_lprobs - query_lprobs + self.args.wsc_margin_beta
).clamp(min=0)
-query_lprobs
+ self.args.wsc_margin_alpha
* (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
).sum()
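Written out, this is a hinge loss on log-probabilities: any candidate scoring above the query (less a margin of wsc_margin_beta) contributes a penalty scaled by wsc_margin_alpha, on top of the negative query log-prob. A toy evaluation with invented tensors:

import torch

alpha, beta = 1.0, 0.0                     # --wsc-margin-alpha / --wsc-margin-beta
query_lprobs = torch.tensor([-1.2])        # log-prob of the correct query span
cand_lprobs = torch.tensor([-1.5, -0.9])   # log-probs of two candidate spans

loss = (
    -query_lprobs
    + alpha * (cand_lprobs - query_lprobs + beta).clamp(min=0)
).sum()
print(loss)  # tensor(2.7000): 1.2 + 0 for the first candidate, 1.2 + 0.3 for the second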
def forward(self, model, sample, reduce=True):
# compute loss and accuracy
loss, nloss = 0., 0
loss, nloss = 0.0, 0
ncorrect, nqueries = 0, 0
for i, label in enumerate(sample['labels']):
for i, label in enumerate(sample["labels"]):
query_lprobs = self.get_lprobs(
model,
sample['query_tokens'][i].unsqueeze(0),
sample['query_masks'][i].unsqueeze(0),
sample["query_tokens"][i].unsqueeze(0),
sample["query_masks"][i].unsqueeze(0),
)
cand_lprobs = self.get_lprobs(
model,
sample['candidate_tokens'][i],
sample['candidate_masks'][i],
sample["candidate_tokens"][i],
sample["candidate_masks"][i],
)
pred = (query_lprobs >= cand_lprobs).all().item()
@ -95,72 +96,72 @@ class WSCCriterion(LegacyFairseqCriterion):
nloss += 1
loss += self.get_loss(query_lprobs, cand_lprobs)
id = sample['id'][i].item()
id = sample["id"][i].item()
if self.prediction_h is not None:
print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h)
print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
if nloss == 0:
loss = torch.tensor(0.0, requires_grad=True)
sample_size = nqueries if nqueries > 0 else 1
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['nsentences'],
'sample_size': sample_size,
'ncorrect': ncorrect,
'nqueries': nqueries,
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"ncorrect": ncorrect,
"nqueries": nqueries,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
"loss": loss_sum / sample_size / math.log(2),
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
nqueries = sum(log.get('nqueries', 0) for log in logging_outputs)
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
if nqueries > 0:
agg_output['accuracy'] = ncorrect / float(nqueries)
agg_output["accuracy"] = ncorrect / float(nqueries)
return agg_output
@register_criterion('winogrande')
@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
def forward(self, model, sample, reduce=True):
# compute loss and accuracy
query_lprobs = self.get_lprobs(
model,
sample['query_tokens'],
sample['query_masks'],
sample["query_tokens"],
sample["query_masks"],
)
cand_lprobs = self.get_lprobs(
model,
sample['candidate_tokens'],
sample['candidate_masks'],
sample["candidate_tokens"],
sample["candidate_masks"],
)
pred = query_lprobs >= cand_lprobs
loss = self.get_loss(query_lprobs, cand_lprobs)
sample_size = sample['query_tokens'].size(0)
sample_size = sample["query_tokens"].size(0)
ncorrect = pred.sum().item()
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['nsentences'],
'sample_size': sample_size,
'ncorrect': ncorrect,
'nqueries': sample_size,
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"ncorrect": ncorrect,
"nqueries": sample_size,
}
return loss, sample_size, logging_output

View File

@ -10,47 +10,51 @@ import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import (
data_utils,
Dictionary,
encoders,
IdDataset,
ListDataset,
NestedDictionaryDataset,
NumSamplesDataset,
NumelDataset,
NumSamplesDataset,
PadDataset,
SortDataset,
data_utils,
encoders,
)
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.tasks import LegacyFairseqTask, register_task
from . import wsc_utils
@register_task('wsc')
@register_task("wsc")
class WSCTask(LegacyFairseqTask):
"""Task to finetune RoBERTa for Winograd Schemas."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR',
help='path to data directory; we load <split>.jsonl')
parser.add_argument('--init-token', type=int, default=None,
help='add token at the beginning of each batch item')
parser.add_argument(
"data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
)
parser.add_argument(
"--init-token",
type=int,
default=None,
help="add token at the beginning of each batch item",
)
def __init__(self, args, vocab):
super().__init__(args)
self.vocab = vocab
self.mask = vocab.add_symbol('<mask>')
self.mask = vocab.add_symbol("<mask>")
self.bpe = encoders.build_bpe(args)
self.tokenizer = encoders.build_tokenizer(args)
# hack to handle GPT-2 BPE, which includes leading spaces
if args.bpe == 'gpt2':
if args.bpe == "gpt2":
self.leading_space = True
self.trailing_space = False
else:
@ -65,16 +69,16 @@ class WSCTask(LegacyFairseqTask):
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol('<mask>')
dictionary.add_symbol("<mask>")
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
assert args.criterion == 'wsc', 'Must set --criterion=wsc'
assert args.criterion == "wsc", "Must set --criterion=wsc"
# load data and label dictionaries
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(vocab)))
vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
print("| dictionary: {} types".format(len(vocab)))
return cls(args, vocab)
@ -84,7 +88,9 @@ class WSCTask(LegacyFairseqTask):
if self.bpe is not None:
s = self.bpe.encode(s)
tokens = self.vocab.encode_line(
s, append_eos=append_eos, add_if_not_exist=False,
s,
append_eos=append_eos,
add_if_not_exist=False,
).long()
if self.args.init_token is not None:
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
@ -98,19 +104,21 @@ class WSCTask(LegacyFairseqTask):
mask = torch.zeros_like(toks, dtype=torch.bool)
mask_start = len(self.binarize(prefix))
mask_size = len(self.binarize(leading_space + txt))
mask[mask_start:mask_start + mask_size] = 1
mask[mask_start : mask_start + mask_size] = 1
return toks, mask
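A toy sketch of how the mask built here lines up with the binarized tokens; the lengths are invented (a 5-token prefix followed by a 3-token span):

import torch

toks = torch.arange(12)        # stand-in for a binarized sentence
mask = torch.zeros_like(toks, dtype=torch.bool)
mask_start, mask_size = 5, 3   # len(binarize(prefix)), len(binarize(leading_space + txt))
mask[mask_start : mask_start + mask_size] = True
print(mask.nonzero().squeeze(-1))  # tensor([5, 6, 7])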
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
def load_dataset(
self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if data_path is None:
data_path = os.path.join(self.args.data, split + '.jsonl')
data_path = os.path.join(self.args.data, split + ".jsonl")
if not os.path.exists(data_path):
raise FileNotFoundError('Cannot find data: {}'.format(data_path))
raise FileNotFoundError("Cannot find data: {}".format(data_path))
query_tokens = []
query_masks = []
@ -121,13 +129,15 @@ class WSCTask(LegacyFairseqTask):
labels = []
for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
prefix = sentence[:pronoun_span.start].text
suffix = sentence[pronoun_span.end:].text_with_ws
prefix = sentence[: pronoun_span.start].text
suffix = sentence[pronoun_span.end :].text_with_ws
# spaCy spans include trailing spaces, but we need to know about
# leading spaces for the GPT-2 BPE
leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''
leading_space = (
" " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else ""
)
trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""
# get noun phrases, excluding pronouns and anything overlapping with the query
cand_spans = wsc_utils.filter_noun_chunks(
@ -152,7 +162,11 @@ class WSCTask(LegacyFairseqTask):
cand_toks, cand_masks = [], []
for cand_span in cand_spans:
toks, mask = self.binarize_with_mask(
cand_span.text, prefix, suffix, leading_space, trailing_space,
cand_span.text,
prefix,
suffix,
leading_space,
trailing_space,
)
cand_toks.append(toks)
cand_masks.append(mask)
@ -176,17 +190,17 @@ class WSCTask(LegacyFairseqTask):
candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
candidate_masks = ListDataset(candidate_masks, candidate_lengths)
labels = ListDataset(labels, [1]*len(labels))
labels = ListDataset(labels, [1] * len(labels))
dataset = {
'id': IdDataset(),
'query_tokens': query_tokens,
'query_masks': query_masks,
'candidate_tokens': candidate_tokens,
'candidate_masks': candidate_masks,
'labels': labels,
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(query_tokens, reduce=True),
"id": IdDataset(),
"query_tokens": query_tokens,
"query_masks": query_masks,
"candidate_tokens": candidate_tokens,
"candidate_masks": candidate_masks,
"labels": labels,
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(query_tokens, reduce=True),
}
nested_dataset = NestedDictionaryDataset(
@ -210,9 +224,9 @@ class WSCTask(LegacyFairseqTask):
def build_dataset_for_inference(self, sample_json):
with tempfile.NamedTemporaryFile(buffering=0) as h:
h.write((json.dumps(sample_json) + '\n').encode('utf-8'))
h.write((json.dumps(sample_json) + "\n").encode("utf-8"))
dataset = self.load_dataset(
'disambiguate_pronoun',
"disambiguate_pronoun",
data_path=h.name,
return_only=True,
)
@ -239,19 +253,19 @@ class WSCTask(LegacyFairseqTask):
return scores
cand_lprobs = get_lprobs(
sample['candidate_tokens'][0],
sample['candidate_masks'][0],
sample["candidate_tokens"][0],
sample["candidate_masks"][0],
)
if sample['query_tokens'][0] is not None:
if sample["query_tokens"][0] is not None:
query_lprobs = get_lprobs(
sample['query_tokens'][0].unsqueeze(0),
sample['query_masks'][0].unsqueeze(0),
sample["query_tokens"][0].unsqueeze(0),
sample["query_masks"][0].unsqueeze(0),
)
return (query_lprobs >= cand_lprobs).all().item() == 1
else:
best_idx = cand_lprobs.argmax().item()
full_cand = sample['candidate_tokens'][0][best_idx]
mask = sample['candidate_masks'][0][best_idx]
full_cand = sample["candidate_tokens"][0][best_idx]
mask = sample["candidate_masks"][0][best_idx]
toks = full_cand[mask.bool()]
return self.bpe.decode(self.source_dictionary.string(toks)).strip()
@ -264,7 +278,7 @@ class WSCTask(LegacyFairseqTask):
return self.vocab
@register_task('winogrande')
@register_task("winogrande")
class WinograndeTask(WSCTask):
"""
Task for WinoGrande dataset. Efficient implementation for Winograd schema
@ -273,24 +287,26 @@ class WinograndeTask(WSCTask):
@classmethod
def setup_task(cls, args, **kwargs):
assert args.criterion == 'winogrande', 'Must set --criterion=winogrande'
assert args.criterion == "winogrande", "Must set --criterion=winogrande"
# load data and label dictionaries
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(vocab)))
vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
print("| dictionary: {} types".format(len(vocab)))
return cls(args, vocab)
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
def load_dataset(
self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if data_path is None:
data_path = os.path.join(self.args.data, split + '.jsonl')
data_path = os.path.join(self.args.data, split + ".jsonl")
if not os.path.exists(data_path):
raise FileNotFoundError('Cannot find data: {}'.format(data_path))
raise FileNotFoundError("Cannot find data: {}".format(data_path))
query_tokens = []
query_masks = []
@ -299,19 +315,23 @@ class WinograndeTask(WSCTask):
candidate_masks = []
candidate_lengths = []
itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == 'test'))
itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test"))
for sample in itr:
sentence, pronoun_span, query, cand_text = sample
prefix = sentence[:pronoun_span[0]].rstrip()
suffix = sentence[pronoun_span[1]:]
prefix = sentence[: pronoun_span[0]].rstrip()
suffix = sentence[pronoun_span[1] :]
leading_space = ' ' if sentence[:pronoun_span[0]].endswith(' ') else ''
trailing_space = ''
leading_space = " " if sentence[: pronoun_span[0]].endswith(" ") else ""
trailing_space = ""
if query is not None:
query_toks, query_mask = self.binarize_with_mask(
query, prefix, suffix, leading_space, trailing_space,
query,
prefix,
suffix,
leading_space,
trailing_space,
)
query_len = len(query_toks)
else:
@ -322,7 +342,11 @@ class WinograndeTask(WSCTask):
query_lengths.append(query_len)
cand_toks, cand_mask = self.binarize_with_mask(
cand_text, prefix, suffix, leading_space, trailing_space,
cand_text,
prefix,
suffix,
leading_space,
trailing_space,
)
candidate_tokens.append(cand_toks)
@ -342,17 +366,19 @@ class WinograndeTask(WSCTask):
query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)
candidate_lengths = np.array(candidate_lengths)
candidate_tokens = get_pad_dataset_fn(candidate_tokens, candidate_lengths, self.vocab.pad())
candidate_tokens = get_pad_dataset_fn(
candidate_tokens, candidate_lengths, self.vocab.pad()
)
candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)
dataset = {
'id': IdDataset(),
'query_tokens': query_tokens,
'query_masks': query_masks,
'candidate_tokens': candidate_tokens,
'candidate_masks': candidate_masks,
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(query_tokens, reduce=True),
"id": IdDataset(),
"query_tokens": query_tokens,
"query_masks": query_masks,
"candidate_tokens": candidate_tokens,
"candidate_masks": candidate_masks,
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(query_tokens, reduce=True),
}
nested_dataset = NestedDictionaryDataset(

View File

@ -3,48 +3,48 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
import json
from functools import lru_cache
def convert_sentence_to_json(sentence):
if '_' in sentence:
prefix, rest = sentence.split('_', 1)
query, rest = rest.split('_', 1)
query_index = len(prefix.rstrip().split(' '))
if "_" in sentence:
prefix, rest = sentence.split("_", 1)
query, rest = rest.split("_", 1)
query_index = len(prefix.rstrip().split(" "))
else:
query, query_index = None, None
prefix, rest = sentence.split('[', 1)
pronoun, rest = rest.split(']', 1)
pronoun_index = len(prefix.rstrip().split(' '))
prefix, rest = sentence.split("[", 1)
pronoun, rest = rest.split("]", 1)
pronoun_index = len(prefix.rstrip().split(" "))
sentence = sentence.replace('_', '').replace('[', '').replace(']', '')
sentence = sentence.replace("_", "").replace("[", "").replace("]", "")
return {
'idx': 0,
'text': sentence,
'target': {
'span1_index': query_index,
'span1_text': query,
'span2_index': pronoun_index,
'span2_text': pronoun,
"idx": 0,
"text": sentence,
"target": {
"span1_index": query_index,
"span1_text": query,
"span2_index": pronoun_index,
"span2_text": pronoun,
},
}
def extended_noun_chunks(sentence):
noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
np_start, cur_np = 0, 'NONE'
np_start, cur_np = 0, "NONE"
for i, token in enumerate(sentence):
np_type = token.pos_ if token.pos_ in {'NOUN', 'PROPN'} else 'NONE'
np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
if np_type != cur_np:
if cur_np != 'NONE':
if cur_np != "NONE":
noun_chunks.add((np_start, i))
if np_type != 'NONE':
if np_type != "NONE":
np_start = i
cur_np = np_type
if cur_np != 'NONE':
if cur_np != "NONE":
noun_chunks.add((np_start, len(sentence)))
return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
@ -61,14 +61,14 @@ def find_token(sentence, start_pos):
def find_span(sentence, search_text, start=0):
search_text = search_text.lower()
for tok in sentence[start:]:
remainder = sentence[tok.i:].text.lower()
remainder = sentence[tok.i :].text.lower()
if remainder.startswith(search_text):
len_to_consume = len(search_text)
start_idx = tok.idx
for next_tok in sentence[tok.i:]:
for next_tok in sentence[tok.i :]:
end_idx = next_tok.idx + len(next_tok.text)
if end_idx - start_idx == len_to_consume:
span = sentence[tok.i:next_tok.i + 1]
span = sentence[tok.i : next_tok.i + 1]
return span
return None
@ -76,13 +76,15 @@ def find_span(sentence, search_text, start=0):
@lru_cache(maxsize=1)
def get_detokenizer():
from sacremoses import MosesDetokenizer
detok = MosesDetokenizer(lang='en')
detok = MosesDetokenizer(lang="en")
return detok
@lru_cache(maxsize=1)
def get_spacy_nlp():
import en_core_web_lg
nlp = en_core_web_lg.load()
return nlp
@ -95,45 +97,45 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
for line in fin:
sample = json.loads(line.strip())
if positive_only and 'label' in sample and not sample['label']:
if positive_only and "label" in sample and not sample["label"]:
# only consider examples where the query is correct
continue
target = sample['target']
target = sample["target"]
# clean up the query
query = target['span1_text']
query = target["span1_text"]
if query is not None:
if '\n' in query:
if "\n" in query:
continue
if query.endswith('.') or query.endswith(','):
if query.endswith(".") or query.endswith(","):
query = query[:-1]
# split tokens
tokens = sample['text'].split(' ')
tokens = sample["text"].split(" ")
def strip_pronoun(x):
return x.rstrip('.,"')
# find the pronoun
pronoun_idx = target['span2_index']
pronoun = strip_pronoun(target['span2_text'])
pronoun_idx = target["span2_index"]
pronoun = strip_pronoun(target["span2_text"])
if strip_pronoun(tokens[pronoun_idx]) != pronoun:
# hack: sometimes the index is misaligned
if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
pronoun_idx += 1
else:
raise Exception('Misaligned pronoun!')
raise Exception("Misaligned pronoun!")
assert strip_pronoun(tokens[pronoun_idx]) == pronoun
# split tokens before and after the pronoun
before = tokens[:pronoun_idx]
after = tokens[pronoun_idx + 1:]
after = tokens[pronoun_idx + 1 :]
# the GPT BPE attaches leading spaces to tokens, so we keep track
# of whether we need spaces before or after the pronoun
leading_space = ' ' if pronoun_idx > 0 else ''
trailing_space = ' ' if len(after) > 0 else ''
leading_space = " " if pronoun_idx > 0 else ""
trailing_space = " " if len(after) > 0 else ""
# detokenize
before = detok.detokenize(before, return_str=True)
@ -142,14 +144,14 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
# hack: when the pronoun ends in a period (or comma), move the
# punctuation to the "after" part
if pronoun.endswith('.') or pronoun.endswith(','):
if pronoun.endswith(".") or pronoun.endswith(","):
after = pronoun[-1] + trailing_space + after
pronoun = pronoun[:-1]
# hack: when the "after" part begins with a comma or period, remove
# the trailing space
if after.startswith('.') or after.startswith(','):
trailing_space = ''
if after.startswith(".") or after.startswith(","):
trailing_space = ""
# parse sentence with spacy
sentence = nlp(before + leading_space + pronoun + trailing_space + after)
@ -164,13 +166,13 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
# convert to format where pronoun is surrounded by "[]" and
# query is surrounded by "_"
query_span = find_span(sentence, query)
query_with_ws = '_{}_{}'.format(
query_with_ws = "_{}_{}".format(
query_span.text,
(' ' if query_span.text_with_ws.endswith(' ') else '')
(" " if query_span.text_with_ws.endswith(" ") else ""),
)
pronoun_with_ws = '[{}]{}'.format(
pronoun_with_ws = "[{}]{}".format(
pronoun_span.text,
(' ' if pronoun_span.text_with_ws.endswith(' ') else '')
(" " if pronoun_span.text_with_ws.endswith(" ") else ""),
)
if query_span.start < pronoun_span.start:
first = (query_span, query_with_ws)
@ -179,41 +181,45 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
first = (pronoun_span, pronoun_with_ws)
second = (query_span, query_with_ws)
sentence = (
sentence[:first[0].start].text_with_ws
sentence[: first[0].start].text_with_ws
+ first[1]
+ sentence[first[0].end:second[0].start].text_with_ws
+ sentence[first[0].end : second[0].start].text_with_ws
+ second[1]
+ sentence[second[0].end:].text
+ sentence[second[0].end :].text
)
yield sentence, sample.get('label', None)
yield sentence, sample.get("label", None)
else:
yield sentence, pronoun_span, query, sample.get('label', None)
yield sentence, pronoun_span, query, sample.get("label", None)
def winogrande_jsonl_iterator(input_fname, eval=False):
with open(input_fname) as fin:
for line in fin:
sample = json.loads(line.strip())
sentence, option1, option2 = sample['sentence'], sample['option1'],\
sample['option2']
sentence, option1, option2 = (
sample["sentence"],
sample["option1"],
sample["option2"],
)
pronoun_span = (sentence.index('_'), sentence.index('_') + 1)
pronoun_span = (sentence.index("_"), sentence.index("_") + 1)
if eval:
query, cand = option1, option2
else:
query = option1 if sample['answer'] == '1' else option2
cand = option2 if sample['answer'] == '1' else option1
query = option1 if sample["answer"] == "1" else option2
cand = option2 if sample["answer"] == "1" else option1
yield sentence, pronoun_span, query, cand
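As a self-contained illustration of the parsing logic above, here is a minimal sketch for a single line; the field names (sentence, option1, option2, answer) come from the code, while the sample sentence and the helper name parse_winogrande_line are invented:

import json

def parse_winogrande_line(line, eval_mode=False):
    # mirrors winogrande_jsonl_iterator for one jsonl line
    sample = json.loads(line.strip())
    sentence = sample["sentence"]
    pronoun_span = (sentence.index("_"), sentence.index("_") + 1)
    if eval_mode:
        query, cand = sample["option1"], sample["option2"]
    else:
        query = sample["option1"] if sample["answer"] == "1" else sample["option2"]
        cand = sample["option2"] if sample["answer"] == "1" else sample["option1"]
    return sentence, pronoun_span, query, cand

line = '{"sentence": "The trophy does not fit in the suitcase because _ is too big.", "option1": "trophy", "option2": "suitcase", "answer": "1"}'
print(parse_winogrande_line(line))
# ('The trophy ... because _ is too big.', (48, 49), 'trophy', 'suitcase')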
def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact_match=False):
def filter_noun_chunks(
chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
if exclude_pronouns:
chunks = [
np for np in chunks if (
np.lemma_ != '-PRON-'
and not all(tok.pos_ == 'PRON' for tok in np)
)
np
for np in chunks
if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
]
if exclude_query is not None:
@ -224,9 +230,8 @@ def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact
found = False
for excl in excl_txt:
if (
(not exact_match and (lower_chunk in excl or excl in lower_chunk))
or lower_chunk == excl
):
not exact_match and (lower_chunk in excl or excl in lower_chunk)
) or lower_chunk == excl:
found = True
break
if not found:

View File

@ -3,4 +3,4 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import criterions, models, eval # noqa
from . import criterions, eval, models # noqa

View File

@ -6,6 +6,7 @@
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
criterion_name = file[: file.find(".py")]

View File

@ -3,21 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from examples.simultaneous_translation.utils.latency import LatencyTraining
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion
)
from examples.simultaneous_translation.utils.latency import (
LatencyTraining
LabelSmoothedCrossEntropyCriterion,
)
@register_criterion('latency_augmented_label_smoothed_cross_entropy')
@register_criterion("latency_augmented_label_smoothed_cross_entropy")
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
LabelSmoothedCrossEntropyCriterion
):
def __init__(self, args, task):
super().__init__(args, task)
self.eps = args.label_smoothing
@ -40,7 +36,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
def add_args(parser):
super(
LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
LatencyAugmentedLabelSmoothedCrossEntropyCriterion
LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
).add_args(parser)
"""Add criterion-specific arguments to the parser."""
# fmt: off
@ -69,7 +65,8 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
# Get latency loss
latency_loss = self.latency_train.loss(
attn_list, source_padding_mask, target_padding_mask)
attn_list, source_padding_mask, target_padding_mask
)
loss += latency_loss

View File

@ -5,16 +5,20 @@
import importlib
import os
from fairseq import registry
build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry('--agent-type')
build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry(
"--agent-type"
)
DEFAULT_EOS = '</s>'
DEFAULT_EOS = "</s>"
GET = 0
SEND = 1
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('agents.' + module)
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("agents." + module)

View File

@ -3,14 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import GET, SEND, DEFAULT_EOS
import time
from multiprocessing.pool import ThreadPool as Pool
from functools import partial
from multiprocessing.pool import ThreadPool as Pool
from . import DEFAULT_EOS, GET, SEND
class Agent(object):
"an agent needs to follow this pattern"
def __init__(self, *args, **kwargs):
pass
@ -40,26 +42,26 @@ class Agent(object):
with Pool(10) as p:
p.map(
partial(self._decode_one, session),
[sent_id for sent_id in range(low, high + 1)]
[sent_id for sent_id in range(low, high + 1)],
)
else:
for sent_id in range(low, high + 1):
self._decode_one(session, sent_id)
print(f'Finished {low} to {high} in {time.time() - t0}s')
print(f"Finished {low} to {high} in {time.time() - t0}s")
def _decode_one(self, session, sent_id):
action = {}
self.reset()
states = self.init_states()
while action.get('value', None) != DEFAULT_EOS:
while action.get("value", None) != DEFAULT_EOS:
# take an action
action = self.policy(states)
if action['key'] == GET:
if action["key"] == GET:
new_states = session.get_src(sent_id, action["value"])
states = self.update_states(states, new_states)
elif action['key'] == SEND:
session.send_hypo(sent_id, action['value'])
elif action["key"] == SEND:
session.send_hypo(sent_id, action["value"])
print(" ".join(states["tokens"]["tgt"]))

View File

@ -3,11 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . agent import Agent
from . import DEFAULT_EOS, GET, SEND
from fairseq import checkpoint_utils, utils, tasks
import os
import json
import os
from fairseq import checkpoint_utils, tasks, utils
from . import DEFAULT_EOS, GET, SEND
from .agent import Agent
class SimulTransAgent(Agent):
@ -51,13 +53,15 @@ class SimulTransAgent(Agent):
raise NotImplementedError
def load_model(self, args):
args.user_dir = os.path.join(os.path.dirname(__file__), '..', '..')
args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
utils.import_user_module(args)
filename = args.model_path
if not os.path.exists(filename):
raise IOError("Model file not found: {}".format(filename))
state = checkpoint_utils.load_checkpoint_to_cpu(filename, json.loads(args.model_overrides))
state = checkpoint_utils.load_checkpoint_to_cpu(
filename, json.loads(args.model_overrides)
)
saved_args = state["args"]
saved_args.data = args.data_bin
@ -79,7 +83,7 @@ class SimulTransAgent(Agent):
"steps": {"src": 0, "tgt": 0},
"finished": False,
"finish_read": False,
"model_states": {}
"model_states": {},
}
def update_states(self, states, new_state):
@ -115,38 +119,38 @@ class SimulTransAgent(Agent):
def write_action(self, states):
token, index = self.model.predict_from_states(states)
if index == self.dict["tgt"].eos() or len(states["tokens"]["tgt"]) > self.max_len:
if (
index == self.dict["tgt"].eos()
or len(states["tokens"]["tgt"]) > self.max_len
):
# Finish this sentence if EOS is predicted
states["finished"] = True
end_idx_last_full_word = self._target_length(states)
else:
states["tokens"]["tgt"] += [token]
end_idx_last_full_word = (
self.word_splitter["tgt"]
.end_idx_last_full_word(states["tokens"]["tgt"])
end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
states["tokens"]["tgt"]
)
self._append_indices(states, [index], "tgt")
if end_idx_last_full_word > states["steps"]["tgt"]:
# Only send detokenized full words to the server
word = self.word_splitter["tgt"].merge(
states["tokens"]["tgt"][
states["steps"]["tgt"]: end_idx_last_full_word
]
states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
)
states["steps"]["tgt"] = end_idx_last_full_word
states["segments"]["tgt"] += [word]
return {'key': SEND, 'value': word}
return {"key": SEND, "value": word}
else:
return None
def read_action(self, states):
return {'key': GET, 'value': None}
return {"key": GET, "value": None}
def finish_action(self):
return {'key': SEND, 'value': DEFAULT_EOS}
return {"key": SEND, "value": DEFAULT_EOS}
def reset(self):
pass
@ -160,4 +164,4 @@ class SimulTransAgent(Agent):
states["indices"][key] += new_indices
def _target_length(self, states):
return len(states["tokens"]['tgt'])
return len(states["tokens"]["tgt"])

View File

@ -3,10 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . simul_trans_agent import SimulTransAgent
from . import DEFAULT_EOS, GET
from . import register_agent
from . word_splitter import SPLITTER_DICT
from . import DEFAULT_EOS, GET, register_agent
from .simul_trans_agent import SimulTransAgent
from .word_splitter import SPLITTER_DICT
@register_agent("simul_trans_text")
@ -15,11 +14,11 @@ class SimulTransTextAgent(SimulTransAgent):
self.word_splitter = {}
self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type](
getattr(args, f"src_splitter_path")
)
getattr(args, f"src_splitter_path")
)
self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type](
getattr(args, f"tgt_splitter_path")
)
getattr(args, f"tgt_splitter_path")
)
def load_dictionary(self, task):
self.dict = {}
@ -37,12 +36,16 @@ class SimulTransTextAgent(SimulTransAgent):
tokens = self.word_splitter["src"].split(new_word)
# Get indices from dictionary
# You can change this to your own dictionary
indices = self.dict["src"].encode_line(
tokens,
line_tokenizer=lambda x: x,
add_if_not_exist=False,
append_eos=False
).tolist()
indices = (
self.dict["src"]
.encode_line(
tokens,
line_tokenizer=lambda x: x,
add_if_not_exist=False,
append_eos=False,
)
.tolist()
)
else:
tokens = [new_word]
indices = [self.dict["src"].eos()]
@ -61,11 +64,11 @@ class SimulTransTextAgent(SimulTransAgent):
# At least one word has been read
if len(states["tokens"]["src"]) == 0:
return {'key': GET, 'value': None}
return {"key": GET, "value": None}
# Only request a new word if there are no buffered tokens
if len(states["tokens"]["src"]) <= states["steps"]["src"]:
return {'key': GET, 'value': None}
return {"key": GET, "value": None}
return None

View File

@ -40,6 +40,7 @@ class BPEWordSplitter(object):
def __init__(self, model_path):
super().__init__()
from subword_nmt.apply_bpe import BPE
with open(model_path) as f:
self.model = BPE(f)
@ -48,7 +49,7 @@ class BPEWordSplitter(object):
def end_idx_last_full_word(self, tokens):
# Begin-of-word indices
bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != '@@']
bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != "@@"]
if len(bow_indices) < 2:
return 0
@ -63,6 +64,7 @@ class SentencePieceModelWordSplitter(object):
def __init__(self, model_path):
super().__init__()
import sentencepiece as spm
self.model = spm.SentencePieceProcessor()
self.model.Load(model_path)
@ -71,7 +73,7 @@ class SentencePieceModelWordSplitter(object):
def end_idx_last_full_word(self, tokens):
# Begin-of-word indices
bow_indices = [i for i, t in enumerate(tokens) if t[0] == '\u2581']
bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"]
if len(bow_indices) < 2:
return 0

View File

@ -3,19 +3,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import requests
from typing import Optional
import requests
from scorers import build_scorer
class SimulSTEvaluationService(object):
DEFAULT_HOSTNAME = 'localhost'
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
self.hostname = hostname
self.port = port
self.base_url = f'http://{self.hostname}:{self.port}'
self.base_url = f"http://{self.hostname}:{self.port}"
def __enter__(self):
self.new_session()
@ -25,56 +26,53 @@ class SimulSTEvaluationService(object):
def new_session(self):
# start eval session
url = f'{self.base_url}'
url = f"{self.base_url}"
try:
_ = requests.post(url)
except Exception as e:
print(f'Failed to start an evaluation session: {e}')
print(f"Failed to start an evaluation session: {e}")
print('Evaluation session started.')
print("Evaluation session started.")
return self
def get_scores(self):
# end eval session
url = f'{self.base_url}/result'
url = f"{self.base_url}/result"
try:
r = requests.get(url)
print('Scores: {}'.format(r.json()))
print('Evaluation session finished.')
print("Scores: {}".format(r.json()))
print("Evaluation session finished.")
except Exception as e:
print(f'Failed to end an evaluation session: {e}')
print(f"Failed to end an evaluation session: {e}")
def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
url = f'{self.base_url}/src'
url = f"{self.base_url}/src"
params = {"sent_id": sent_id}
if extra_params is not None:
for key in extra_params.keys():
params[key] = extra_params[key]
try:
r = requests.get(
url,
params=params
)
r = requests.get(url, params=params)
except Exception as e:
print(f'Failed to request a source segment: {e}')
print(f"Failed to request a source segment: {e}")
return r.json()
def send_hypo(self, sent_id: int, hypo: str) -> None:
url = f'{self.base_url}/hypo'
url = f"{self.base_url}/hypo"
params = {"sent_id": sent_id}
try:
requests.put(url, params=params, data=hypo.encode("utf-8"))
except Exception as e:
print(f'Failed to send a translated segment: {e}')
print(f"Failed to send a translated segment: {e}")
def corpus_info(self):
url = f'{self.base_url}'
url = f"{self.base_url}"
try:
r = requests.get(url)
except Exception as e:
print(f'Failed to request corpus information: {e}')
print(f"Failed to request corpus information: {e}")
return r.json()
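Putting the methods above together, a hypothetical client round trip looks like this; only calls that appear in the class are used, and the hypothesis text is invented:

service = SimulSTEvaluationService()   # defaults to localhost:12321
service.new_session()                  # POST /
segment = service.get_src(0)           # GET /src?sent_id=0
service.send_hypo(0, "hello world")    # PUT /hypo?sent_id=0
service.get_scores()                   # GET /result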

View File

@ -3,20 +3,21 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from examples.simultaneous_translation.utils.latency import LatencyInference
import argparse
import torch
import json
import torch
from examples.simultaneous_translation.utils.latency import LatencyInference
LATENCY_METRICS = [
'differentiable_average_lagging',
'average_lagging',
'average_proportion',
"differentiable_average_lagging",
"average_lagging",
"average_proportion",
]
class LatencyScorer():
class LatencyScorer:
def __init__(self, start_from_zero=True):
self.recorder = []
self.scores = {}
@ -26,10 +27,7 @@ class LatencyScorer():
def update_reorder(self, list_of_dict):
self.recorder = []
for info in list_of_dict:
delays = [
int(x) - int(not self.start_from_zero)
for x in info["delays"]
]
delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]]
delays = torch.LongTensor(delays).unsqueeze(0)
src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
@ -59,7 +57,7 @@ if __name__ == "__main__":
scorer = LatencyInference()
recorder = []
with open(args.input, 'r') as f:
with open(args.input, "r") as f:
for line in f:
info = json.loads(line)
@ -74,7 +72,7 @@ if __name__ == "__main__":
average_results = {}
for metric in LATENCY_METRICS:
average_results[metric] = sum(
[x[metric][0, 0].item() for x in recorder]
) / len(recorder)
average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len(
recorder
)
print(f"{metric}: {average_results[metric]}")

View File

@ -5,37 +5,48 @@
import argparse
from agents import build_agent
from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
from fairseq.registry import REGISTRIES
from agents import build_agent
DEFAULT_HOSTNAME = 'localhost'
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
help='server hostname')
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
help='server port number')
parser.add_argument('--agent-type', default='simul_trans_text',
help='Agent type')
parser.add_argument('--scorer-type', default='text',
help='Scorer type')
parser.add_argument('--start-idx', type=int, default=0,
help='Start index of the sentence to evaluate')
parser.add_argument('--end-idx', type=int, default=float('inf'),
help='End index of the sentence to evaluate')
parser.add_argument('--scores', action="store_true",
help='Request scores from server')
parser.add_argument('--reset-server', action="store_true",
help='Reset the server')
parser.add_argument('--num-threads', type=int, default=10,
help='Number of threads used by agent')
parser.add_argument('--local', action="store_true", default=False,
help='Local evaluation')
parser.add_argument(
"--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
)
parser.add_argument(
"--port", type=int, default=DEFAULT_PORT, help="server port number"
)
parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
parser.add_argument("--scorer-type", default="text", help="Scorer type")
parser.add_argument(
"--start-idx",
type=int,
default=0,
help="Start index of the sentence to evaluate",
)
parser.add_argument(
"--end-idx",
type=int,
default=float("inf"),
help="End index of the sentence to evaluate",
)
parser.add_argument(
"--scores", action="store_true", help="Request scores from server"
)
parser.add_argument("--reset-server", action="store_true", help="Reset the server")
parser.add_argument(
"--num-threads", type=int, default=10, help="Number of threads used by agent"
)
parser.add_argument(
"--local", action="store_true", default=False, help="Local evaluation"
)
args, _ = parser.parse_known_args()

View File

@ -5,15 +5,15 @@
import importlib
import os
from fairseq import registry
(
build_scorer,
register_scorer,
SCORER_REGISTRIES,
_
) = registry.setup_registry('--scorer-type')
(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry(
"--scorer-type"
)
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('scorers.' + module)
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("scorers." + module)

View File

@ -3,16 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from vizseq.scorers.bleu import BLEUScorer
from vizseq.scorers.ter import TERScorer
from vizseq.scorers.meteor import METEORScorer
from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
from collections import defaultdict
import json
import os
from collections import defaultdict
from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
from vizseq.scorers.bleu import BLEUScorer
from vizseq.scorers.meteor import METEORScorer
from vizseq.scorers.ter import TERScorer
DEFAULT_EOS = '</s>'
DEFAULT_EOS = "</s>"
class SimulScorer(object):
@ -23,7 +24,7 @@ class SimulScorer(object):
self.output_files = {
"text": os.path.join(args.output, "text"),
"delay": os.path.join(args.output, "delay"),
"scores": os.path.join(args.output, "scores")
"scores": os.path.join(args.output, "scores"),
}
else:
self.output_files = None
@ -52,14 +53,7 @@ class SimulScorer(object):
def recv_hyp(self, sent_id, list_of_tokens):
for token in list_of_tokens:
self.translations[
sent_id
].append(
(
token,
self.steps[sent_id]
)
)
self.translations[sent_id].append((token, self.steps[sent_id]))
def reset(self):
self.steps = defaultdict(int)
@ -76,8 +70,9 @@ class SimulScorer(object):
delays += [[t[1] for t in self.translations[i]]]
bleu_score = BLEUScorer(
sent_level=False, corpus_level=True,
extra_args={'bleu_tokenizer': self.tokenizer}
sent_level=False,
corpus_level=True,
extra_args={"bleu_tokenizer": self.tokenizer},
).score(translations, [self.data["tgt"]])
ter_score = TERScorer(sent_level=False, corpus_level=True).score(
@ -92,16 +87,16 @@ class SimulScorer(object):
{"src_len": src_len, "delays": delay}
for src_len, delay in zip(self.src_lengths(), delays)
],
start_from_zero=False
start_from_zero=False,
)
scores = {
'BLEU': bleu_score[0],
'TER': ter_score[0],
'METEOR': meteor_score[0],
'DAL': latency_score['differentiable_average_lagging'],
'AL': latency_score['average_lagging'],
'AP': latency_score['average_proportion'],
"BLEU": bleu_score[0],
"TER": ter_score[0],
"METEOR": meteor_score[0],
"DAL": latency_score["differentiable_average_lagging"],
"AL": latency_score["average_lagging"],
"AP": latency_score["average_proportion"],
}
if self.output_files is not None:
@ -109,9 +104,9 @@ class SimulScorer(object):
os.makedirs(self.output_dir, exist_ok=True)
self.write_results_to_file(translations, delays, scores)
except BaseException as be:
print(f'Failed to write results to {self.output_dir}.')
print(f"Failed to write results to {self.output_dir}.")
print(be)
print('Skip writing predictions')
print("Skip writing predictions")
return scores
@ -125,12 +120,8 @@ class SimulScorer(object):
with open(self.output_files["delay"], "w") as f:
for i, delay in enumerate(delays):
f.write(
json.dumps(
{
"src_len": self.src_lengths()[i],
"delays": delay
}
) + "\n"
json.dumps({"src_len": self.src_lengths()[i], "delays": delay})
+ "\n"
)
with open(self.output_files["scores"], "w") as f:
@ -163,7 +154,7 @@ class SimulScorer(object):
list_to_return.append(
{
"path": item["input"]["path"].strip(),
"length": item["input"]["length_ms"]
"length": item["input"]["length_ms"],
}
)
return list_to_return

View File

@ -3,8 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . scorer import SimulScorer
from . import register_scorer
from .scorer import SimulScorer
@register_scorer("text")
@ -13,7 +13,7 @@ class SimulTextScorer(SimulScorer):
super().__init__(args)
self.data = {
"src": self._load_text_file(args.src_file, split=True),
"tgt": self._load_text_file(args.tgt_file, split=False)
"tgt": self._load_text_file(args.tgt_file, split=False),
}
def send_src(self, sent_id, *args):
@ -21,7 +21,7 @@ class SimulTextScorer(SimulScorer):
dict_to_return = {
"sent_id": sent_id,
"segment_id": self.steps[sent_id],
"segment": self.eos
"segment": self.eos,
}
# Consider EOS
self.steps[sent_id] = len(self.data["src"][sent_id]) + 1
@ -29,7 +29,7 @@ class SimulTextScorer(SimulScorer):
dict_to_return = {
"sent_id": sent_id,
"segment_id": self.steps[sent_id],
"segment": self.data["src"][sent_id][self.steps[sent_id]]
"segment": self.data["src"][sent_id][self.steps[sent_id]],
}
self.steps[sent_id] += 1

View File

@ -3,12 +3,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import sys
import json
from tornado import web, ioloop
from scorers import build_scorer
import sys
DEFAULT_HOSTNAME = 'localhost'
from scorers import build_scorer
from tornado import ioloop, web
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
@ -34,10 +36,10 @@ class ResultHandler(ScorerHandler):
class SourceHandler(ScorerHandler):
def get(self):
sent_id = int(self.get_argument('sent_id'))
sent_id = int(self.get_argument("sent_id"))
segment_size = None
if "segment_size" in self.request.arguments:
string = self.get_argument('segment_size')
string = self.get_argument("segment_size")
if len(string) > 0:
segment_size = int(string)
@ -48,8 +50,8 @@ class SourceHandler(ScorerHandler):
class HypothesisHandler(ScorerHandler):
def put(self):
sent_id = int(self.get_argument('sent_id'))
list_of_tokens = self.request.body.decode('utf-8').strip().split()
sent_id = int(self.get_argument("sent_id"))
list_of_tokens = self.request.body.decode("utf-8").strip().split()
self.scorer.recv_hyp(sent_id, list_of_tokens)
@ -67,18 +69,21 @@ def add_args():
def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
app = web.Application([
(r'/result', ResultHandler, dict(scorer=scorer)),
(r'/src', SourceHandler, dict(scorer=scorer)),
(r'/hypo', HypothesisHandler, dict(scorer=scorer)),
(r'/', EvalSessionHandler, dict(scorer=scorer)),
], debug=debug)
app = web.Application(
[
(r"/result", ResultHandler, dict(scorer=scorer)),
(r"/src", SourceHandler, dict(scorer=scorer)),
(r"/hypo", HypothesisHandler, dict(scorer=scorer)),
(r"/", EvalSessionHandler, dict(scorer=scorer)),
],
debug=debug,
)
app.listen(port, max_buffer_size=1024 ** 3)
sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
ioloop.IOLoop.current().start()
if __name__ == '__main__':
if __name__ == "__main__":
args = add_args()
scorer = build_scorer(args)
start_server(scorer, args.hostname, args.port, args.debug)

View File

@ -6,7 +6,10 @@
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
model_name = file[:file.find('.py')]
importlib.import_module('examples.simultaneous_translation.models.' + model_name)
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.models." + model_name
)

View File

@ -6,42 +6,34 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import (
register_model,
register_model_architecture,
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
TransformerMonotonicDecoderLayer,
TransformerMonotonicEncoderLayer,
)
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
TransformerModel,
TransformerEncoder,
TransformerDecoder,
TransformerEncoder,
TransformerModel,
base_architecture,
transformer_iwslt_de_en,
transformer_vaswani_wmt_en_de_big,
)
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
TransformerMonotonicDecoderLayer,
TransformerMonotonicEncoderLayer
)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model('transformer_unidirectional')
@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@register_model('transformer_monotonic')
@register_model("transformer_monotonic")
class TransformerMonotonicModel(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@ -62,26 +54,17 @@ class TransformerMonotonicModel(TransformerModel):
)
tgt_indices = tensor(
[
[self.decoder.dictionary.eos()]
+ states["indices"]["tgt"]
]
[[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
)
else:
src_indices = states["indices"]["src"][: 1 +
states["steps"]["src"]]
src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
tgt_indices = states["indices"]["tgt"]
return src_indices, None, tgt_indices
def predict_from_states(self, states):
decoder_states = self.decoder.output_layer(
states["decoder_features"]
)
lprobs = self.get_normalized_probs(
[decoder_states[:, -1:]],
log_probs=True
)
decoder_states = self.decoder.output_layer(states["decoder_features"])
lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)
index = lprobs.argmax(dim=-1)
@ -90,25 +73,24 @@ class TransformerMonotonicModel(TransformerModel):
return token, index[0, 0].item()
def decision_from_states(self, states):
'''
"""
This function takes the states dictionary as input and gives the agent
a decision on whether to read a token from the server. Moreover, the
decoder states are also calculated here, so we can directly generate a
target token without recomputing everything.
'''
"""
self.eval()
if len(states["tokens"]["src"]) == 0:
return 0
src_indices, src_lengths, tgt_indices = self._indices_from_states(
states)
src_indices, src_lengths, tgt_indices = self._indices_from_states(states)
# Update encoder states if needed
if (
"encoder_states" not in states or
states["encoder_states"][0].size(1) <= states["steps"]["src"]
"encoder_states" not in states
or states["encoder_states"][0].size(1) <= states["steps"]["src"]
):
encoder_out_dict = self.encoder(src_indices, src_lengths)
states["encoder_states"] = encoder_out_dict
@ -136,16 +118,14 @@ class TransformerMonotonicModel(TransformerModel):
class TransformerMonotonicEncoder(TransformerEncoder):
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
self.dictionary = dictionary
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerMonotonicEncoderLayer(args)
for i in range(args.encoder_layers)
])
self.layers.extend(
[TransformerMonotonicEncoderLayer(args) for i in range(args.encoder_layers)]
)
class TransformerMonotonicDecoder(TransformerDecoder):
@ -166,19 +146,24 @@ class TransformerMonotonicDecoder(TransformerDecoder):
self.dictionary = dictionary
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerMonotonicDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
])
self.layers.extend(
[
TransformerMonotonicDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
)
def pre_attention(
self, prev_output_tokens, encoder_out_dict,
incremental_state=None
self, prev_output_tokens, encoder_out_dict, incremental_state=None
):
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
) if self.embed_positions is not None else None
positions = (
self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
@ -216,8 +201,7 @@ class TransformerMonotonicDecoder(TransformerDecoder):
return x
def extract_features(
self, prev_output_tokens, encoder_out,
incremental_state=None, **unused
self, prev_output_tokens, encoder_out, incremental_state=None, **unused
):
"""
Similar to *forward* but only return features.
@ -228,14 +212,8 @@ class TransformerMonotonicDecoder(TransformerDecoder):
- a dictionary with any model-specific outputs
"""
# incremental_state = None
(
x,
encoder_outs,
encoder_padding_mask
) = self.pre_attention(
prev_output_tokens,
encoder_out,
incremental_state
(x, encoder_outs, encoder_padding_mask) = self.pre_attention(
prev_output_tokens, encoder_out, incremental_state
)
attn = None
inner_states = [x]
@ -250,7 +228,8 @@ class TransformerMonotonicDecoder(TransformerDecoder):
encoder_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
self_attn_mask=self.buffered_future_mask(x)
if incremental_state is None else None,
if incremental_state is None
else None,
)
inner_states.append(x)
@ -261,38 +240,30 @@ class TransformerMonotonicDecoder(TransformerDecoder):
step_list.append(curr_steps)
if incremental_state.get("online", False):
p_choose = attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
new_steps = (
curr_steps
+ (p_choose < 0.5).t().type_as(curr_steps)
p_choose = (
attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
)
new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)
if (new_steps >= incremental_state["steps"]["src"]).any():
# We need to prune the last self_attn saved_state
# if the model decides not to read;
# otherwise there will be a duplicated saved_state
for j in range(i + 1):
self.layers[j].prune_incremental_state(
incremental_state)
self.layers[j].prune_incremental_state(incremental_state)
return x, {"action": 0}
if (
incremental_state is not None
and not incremental_state.get("online", False)
):
if incremental_state is not None and not incremental_state.get("online", False):
# Here is for fast evaluation
fastest_step = torch.max(
torch.cat(step_list, dim=1),
dim=1,
keepdim=True
)[0] + 1
fastest_step = (
torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
)
if "fastest_step" in incremental_state:
incremental_state["fastest_step"] = torch.cat(
[incremental_state["fastest_step"], fastest_step],
dim=1
[incremental_state["fastest_step"], fastest_step], dim=1
)
else:
incremental_state["fastest_step"] = fastest_step
@ -310,25 +281,19 @@ class TransformerMonotonicDecoder(TransformerDecoder):
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
if "fastest_step" in incremental_state:
incremental_state["fastest_step"] = (
incremental_state["fastest_step"]
.index_select(0, new_order)
)
incremental_state["fastest_step"] = incremental_state[
"fastest_step"
].index_select(0, new_order)
@register_model_architecture(
'transformer_monotonic',
'transformer_monotonic'
)
@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_rchitecture(args):
base_architecture(args)
args.encoder_unidirectional = getattr(
args, 'encoder_unidirectional', False)
args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)
@register_model_architecture(
'transformer_monotonic',
'transformer_monotonic_iwslt_de_en'
"transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
transformer_iwslt_de_en(args)
@ -337,24 +302,21 @@ def transformer_monotonic_iwslt_de_en(args):
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
'transformer_monotonic',
'transformer_monotonic_vaswani_wmt_en_de_big'
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture(
'transformer_monotonic',
'transformer_monotonic_vaswani_wmt_en_fr_big'
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
transformer_monotonic_vaswani_wmt_en_fr_big(args)
@register_model_architecture(
'transformer_unidirectional',
'transformer_unidirectional_iwslt_de_en'
"transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
transformer_iwslt_de_en(args)

View File

@ -7,14 +7,18 @@ import importlib
import os
from fairseq import registry
(
build_monotonic_attention,
register_monotonic_attention,
MONOTONIC_ATTENTION_REGISTRY,
_
) = registry.setup_registry('--simul-type')
_,
) = registry.setup_registry("--simul-type")
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
model_name = file[:file.find('.py')]
importlib.import_module('examples.simultaneous_translation.modules.' + model_name)
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.modules." + model_name
)

View File

@ -4,22 +4,19 @@
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
from fairseq import utils
from fairseq.modules import MultiheadAttention
import torch.nn.functional as F
from examples.simultaneous_translation.utils.functions import (
exclusive_cumprod,
lengths_to_mask
lengths_to_mask,
)
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules import MultiheadAttention
from fairseq.utils import convert_padding_direction
from . import register_monotonic_attention
@ -28,6 +25,7 @@ class MonotonicAttention(nn.Module):
"""
Abstract class of monotonic attentions
"""
def __init__(self, args):
self.eps = args.attention_eps
self.mass_preservation = args.mass_preservation
@ -38,7 +36,8 @@ class MonotonicAttention(nn.Module):
self.energy_bias_init = args.energy_bias_init
self.energy_bias = (
nn.Parameter(self.energy_bias_init * torch.ones([1]))
if args.energy_bias is True else 0
if args.energy_bias is True
else 0
)
@staticmethod
@ -90,7 +89,7 @@ class MonotonicAttention(nn.Module):
if key_padding_mask is not None:
attn_energy = attn_energy.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
float('-inf'),
float("-inf"),
)
return attn_energy
@ -131,10 +130,7 @@ class MonotonicAttention(nn.Module):
alpha_i = (
p_choose[:, i]
* cumprod_1mp[:, i]
* torch.cumsum(
previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i],
dim=1
)
* torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
).clamp(0, 1.0)
previous_attn.append(alpha_i.unsqueeze(1))
@ -170,8 +166,7 @@ class MonotonicAttention(nn.Module):
# prev_monotonic_step: bsz, num_heads
bsz = bsz_num_heads // self.num_heads
prev_monotonic_step = monotonic_cache.get(
"step",
p_choose.new_zeros([bsz, self.num_heads]).long()
"step", p_choose.new_zeros([bsz, self.num_heads]).long()
)
bsz, num_heads = prev_monotonic_step.size()
assert num_heads == self.num_heads
@ -181,8 +176,7 @@ class MonotonicAttention(nn.Module):
p_choose = p_choose.view(bsz, num_heads, src_len)
if key_padding_mask is not None:
src_lengths = src_len - \
key_padding_mask.sum(dim=1, keepdim=True).long()
src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long()
else:
src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len
@ -197,10 +191,7 @@ class MonotonicAttention(nn.Module):
# left_pad_source = True:
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
max_steps = (
src_lengths - 1 if self.mass_preservation
else src_lengths
)
max_steps = src_lengths - 1 if self.mass_preservation else src_lengths
# finish_read: bsz, num_heads
finish_read = new_monotonic_step.eq(max_steps)
@ -210,11 +201,11 @@ class MonotonicAttention(nn.Module):
# only choose the p at monotonic steps
# p_choose_i: bsz , self.num_heads
p_choose_i = (
p_choose
.gather(
p_choose.gather(
2,
(step_offset + new_monotonic_step).unsqueeze(2)
.clamp(0, src_len - 1)
(step_offset + new_monotonic_step)
.unsqueeze(2)
.clamp(0, src_len - 1),
)
).squeeze(2)
@ -239,21 +230,17 @@ class MonotonicAttention(nn.Module):
# alpha: bsz * num_heads, 1, src_len
# new_monotonic_step: bsz, num_heads
alpha = (
p_choose
.new_zeros([bsz * self.num_heads, src_len])
.scatter(
1,
(step_offset + new_monotonic_step).view(bsz *
self.num_heads, 1).clamp(0, src_len - 1),
1
)
alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter(
1,
(step_offset + new_monotonic_step)
.view(bsz * self.num_heads, 1)
.clamp(0, src_len - 1),
1,
)
if not self.mass_preservation:
alpha = alpha.masked_fill(
(new_monotonic_step == max_steps).view(bsz * self.num_heads, 1),
0
(new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0
)
alpha = alpha.unsqueeze(1)
@ -266,8 +253,14 @@ class MonotonicAttention(nn.Module):
raise NotImplementedError
def forward(
self, query, key, value,
key_padding_mask=None, incremental_state=None, *args, **kwargs,
self,
query,
key,
value,
key_padding_mask=None,
incremental_state=None,
*args,
**kwargs,
):
tgt_len, bsz, embed_dim = query.size()
@ -280,25 +273,24 @@ class MonotonicAttention(nn.Module):
# expected alignment alpha
# bsz * self.num_heads, tgt_len, src_len
if incremental_state is not None:
alpha = self.expected_alignment_infer(p_choose, key_padding_mask, incremental_state)
alpha = self.expected_alignment_infer(
p_choose, key_padding_mask, incremental_state
)
else:
alpha = self.expected_alignment_train(p_choose, key_padding_mask)
# expected attention beta
# bsz * self.num_heads, tgt_len, src_len
beta = self.expected_attention(alpha, query, key, value, key_padding_mask, incremental_state)
beta = self.expected_attention(
alpha, query, key, value, key_padding_mask, incremental_state
)
attn_weights = beta
v_proj = self.v_proj_output(value)
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
attn = (
attn
.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
@ -318,26 +310,32 @@ class MonotonicAttention(nn.Module):
self._set_monotonic_buffer(incremental_state, input_buffer)
def _get_monotonic_buffer(self, incremental_state):
return utils.get_incremental_state(
self,
incremental_state,
'monotonic',
) or {}
return (
utils.get_incremental_state(
self,
incremental_state,
"monotonic",
)
or {}
)
def _set_monotonic_buffer(self, incremental_state, buffer):
utils.set_incremental_state(
self,
incremental_state,
'monotonic',
"monotonic",
buffer,
)
def get_pointer(self, incremental_state):
return utils.get_incremental_state(
self,
incremental_state,
'monotonic',
) or {}
return (
utils.get_incremental_state(
self,
incremental_state,
"monotonic",
)
or {}
)
def get_fastest_pointer(self, incremental_state):
return self.get_pointer(incremental_state)["step"].max(0)[0]
@ -354,23 +352,22 @@ class MonotonicAttention(nn.Module):
utils.set_incremental_state(
self,
incremental_state,
'monotonic',
"monotonic",
{"step": buffer},
)
@register_monotonic_attention("hard_aligned")
class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
def __init__(self, args):
MultiheadAttention.__init__(
self,
embed_dim=args.decoder_embed_dim,
num_heads=args.decoder_attention_heads,
kdim=getattr(args, 'encoder_embed_dim', None),
vdim=getattr(args, 'encoder_embed_dim', None),
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True
encoder_decoder_attention=True,
)
MonotonicAttention.__init__(self, args)
@ -395,21 +392,33 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
bsz = query.size(1)
q = self.q_in_proj[name](query)
q *= self.scaling
q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
q = (
q.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else:
q = None
if key is not None:
bsz = key.size(1)
k = self.k_in_proj[name](key)
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else:
k = None
if value is not None:
bsz = value.size(1)
v = self.v_in_proj[name](value)
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else:
v = None
@ -441,8 +450,7 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
if self.training:
# add noise here to encourage discreteness
noise = (
torch
.normal(self.noise_mean, self.noise_var, attn_energy.size())
torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
.type_as(attn_energy)
.to(attn_energy.device)
)
@ -454,9 +462,9 @@ class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
return p_choose.view(-1, tgt_len, src_len)
def expected_attention(self, alpha, *args):
'''
"""
For MMA-H, beta = alpha
'''
"""
return alpha
def v_proj_output(self, value):
@ -479,13 +487,19 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(
self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
nn.init.xavier_uniform_(
self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
else:
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
def expected_attention(self, alpha, query, key, value, key_padding_mask, incremental_state):
def expected_attention(
self, alpha, query, key, value, key_padding_mask, incremental_state
):
# monotonic attention; we will calculate MILk (infinite lookback) here
bsz_x_num_heads, tgt_len, src_len = alpha.size()
bsz = int(bsz_x_num_heads / self.num_heads)
@ -507,9 +521,10 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
monotonic_step += step_offset
mask = lengths_to_mask(
monotonic_step.view(-1), soft_energy.size(2), 1).unsqueeze(1)
monotonic_step.view(-1), soft_energy.size(2), 1
).unsqueeze(1)
soft_energy = soft_energy.masked_fill(~ mask.bool(), float('-inf'))
soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy)
exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
@ -524,14 +539,20 @@ class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHar
if key_padding_mask is not None:
if key_padding_mask.any():
exp_soft_energy_cumsum = (
exp_soft_energy_cumsum.view(-1, self.num_heads, tgt_len, src_len)
.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps)
exp_soft_energy_cumsum.view(
-1, self.num_heads, tgt_len, src_len
)
.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
)
.view(-1, tgt_len, src_len)
)
inner_items = alpha / exp_soft_energy_cumsum
beta = exp_soft_energy * torch.cumsum(inner_items.flip(dims=[2]), dim=2).flip(dims=[2])
beta = exp_soft_energy * torch.cumsum(
inner_items.flip(dims=[2]), dim=2
).flip(dims=[2])
beta = self.dropout_module(beta)
@ -547,7 +568,9 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
self.waitk_lagging = args.waitk_lagging
assert self.waitk_lagging > 0, f"Lagging has to been larger than 0, get {self.waitk_lagging}."
assert (
self.waitk_lagging > 0
), f"Lagging has to been larger than 0, get {self.waitk_lagging}."
@staticmethod
def add_args(parser):
@ -556,10 +579,13 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
MonotonicMultiheadAttentionWaitk,
).add_args(parser)
parser.add_argument('--waitk-lagging', type=int, required=True,
help='Wait k lagging')
parser.add_argument(
"--waitk-lagging", type=int, required=True, help="Wait k lagging"
)
def p_choose(self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None):
def p_choose(
self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None
):
"""
query: bsz, tgt_len
key: bsz, src_len
@ -574,16 +600,22 @@ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookba
if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
# Left pad source
# add -1 to the end
p_choose = p_choose.masked_fill(key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1)
p_choose = convert_padding_direction(p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True)
p_choose = p_choose.masked_fill(
key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1
)
p_choose = convert_padding_direction(
p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True
)
p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)
# remove -1
p_choose[p_choose.eq(-1)] = 0
# Extend to each head
p_choose = (
p_choose.contiguous().unsqueeze(1)
.expand(-1, self.num_heads, -1, -1).contiguous()
p_choose.contiguous()
.unsqueeze(1)
.expand(-1, self.num_heads, -1, -1)
.contiguous()
.view(-1, tgt_len, src_len)
)

View File

@ -3,37 +3,32 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.modules import (
LayerNorm,
TransformerEncoderLayer,
TransformerDecoderLayer
)
from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer
from . import build_monotonic_attention
class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
def forward(self, x, encoder_padding_mask):
seq_len, _, _ = x.size()
attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
attn_mask = attn_mask.masked_fill(attn_mask.bool(), float('-inf'))
attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
return super().forward(x, encoder_padding_mask, attn_mask)
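The attn_mask built above is a strictly upper-triangular -inf mask, which is what makes the encoder unidirectional; a quick sketch for seq_len = 3:

import torch

m = torch.ones(3, 3).triu(1)
print(m.masked_fill(m.bool(), float("-inf")))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])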
class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
def __init__(
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__(
args,
no_encoder_attn=True,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn
add_zero_attn=add_zero_attn,
)
self.encoder_attn = build_monotonic_attention(args)
self.encoder_attn_layer_norm = LayerNorm(
self.embed_dim,
export=getattr(args, 'char_inputs', False)
self.embed_dim, export=getattr(args, "char_inputs", False)
)
def prune_incremental_state(self, incremental_state):
@ -46,12 +41,8 @@ class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
input_buffer = {}
break
module._set_input_buffer(incremental_state, input_buffer)
prune(self.self_attn)
def get_steps(self, incremental_state):
return (
self.encoder_attn
._get_monotonic_buffer(
incremental_state
).get("step", 0)
)
return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0)

View File

@ -9,6 +9,6 @@ import os
# automatically import any Python files in the criterions/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('examples.simultaneous_translation.utils.' + module)
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("examples.simultaneous_translation.utils." + module)

View File

@ -16,7 +16,9 @@ def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
tensor_size = list(tensor.size())
tensor_size[dim] = 1
return_tensor = safe_cumprod(
torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim), dim=dim, eps=eps
torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim),
dim=dim,
eps=eps,
)
if dim == 0:
@ -132,12 +134,14 @@ def moving_sum(x, start_idx: int, end_idx: int):
# batch_size, 1, src_len
moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1])
moving_sum = torch.nn.functional.conv1d(
x,
moving_sum_weight,
padding=start_idx + end_idx - 1
).squeeze(1).t()
moving_sum = moving_sum[end_idx: -start_idx]
moving_sum = (
torch.nn.functional.conv1d(
x, moving_sum_weight, padding=start_idx + end_idx - 1
)
.squeeze(1)
.t()
)
moving_sum = moving_sum[end_idx:-start_idx]
assert src_len == moving_sum.size(0)
assert batch_size == moving_sum.size(1)
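The conv1d trick used by moving_sum can be seen in isolation with a single symmetric window; this is a simplification of the start/end split above, with invented numbers:

import torch
import torch.nn.functional as F

x = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # batch, channel, length
w = torch.ones(1, 1, 3)                     # all-ones kernel = windowed sum
print(F.conv1d(x, w, padding=2).squeeze())
# tensor([1., 3., 6., 9., 7., 4.]) -- each entry sums a window overlapping that position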

View File

@ -18,7 +18,7 @@ class LatencyMetric(object):
src_lens,
target_padding_mask=None,
batch_first: bool = False,
start_from_zero: bool = True
start_from_zero: bool = True,
):
assert len(delays.size()) == 2
assert len(src_lens.size()) == 2
@ -59,11 +59,7 @@ class LatencyMetric(object):
start_from_zero: bool = True,
):
delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
delays,
src_lens,
target_padding_mask,
batch_first,
start_from_zero
delays, src_lens, target_padding_mask, batch_first, start_from_zero
)
return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)
@ -89,10 +85,13 @@ class AverageProportion(LatencyMetric):
AP = 1 / (|x| * |y|) * sum_i^{|y|} delays_i
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
if target_padding_mask is not None:
AP = torch.sum(delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True)
AP = torch.sum(
delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True
)
else:
AP = torch.sum(delays, dim=0, keepdim=True)
@ -116,14 +115,24 @@ class AverageLagging(LatencyMetric):
gamma = |y| / |x|
tau = argmin_i(delays_i = |x|)
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
# tau = argmin_i(delays_i = |x|)
tgt_len, bsz = delays.size()
lagging_padding_mask = delays >= src_lens
lagging_padding_mask = torch.nn.functional.pad(lagging_padding_mask.t(), (1, 0)).t()[:-1, :]
lagging_padding_mask = torch.nn.functional.pad(
lagging_padding_mask.t(), (1, 0)
).t()[:-1, :]
gamma = tgt_lens / src_lens
lagging = delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma
lagging = (
delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
lagging.masked_fill_(lagging_padding_mask, 0)
tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
AL = lagging.sum(dim=0, keepdim=True) / tau
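A quick sanity check of the AL computation above: for one sentence with src_len = tgt_len = 4 (so gamma = 1) and wait-1 delays [1, 2, 3, 4], lagging_i = delays_i - (i - 1) / gamma = [1, 1, 1, 1]; the shifted padding mask leaves all four steps active, so tau = 4 and AL = (1 + 1 + 1 + 1) / 4 = 1, a constant one-token lag as expected.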
@ -149,6 +158,7 @@ class DifferentiableAverageLagging(LatencyMetric):
2. max(delays_i, delays'_{i-1} + 1 / gamma)
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
tgt_len, bsz = delays.size()
@ -163,13 +173,18 @@ class DifferentiableAverageLagging(LatencyMetric):
new_delays[i] = torch.cat(
[
new_delays[i - 1].unsqueeze(0) + 1 / gamma,
delays[i].unsqueeze(0)
delays[i].unsqueeze(0),
],
dim=0
dim=0,
).max(dim=0)[0]
DAL = (
new_delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma
new_delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
if target_padding_mask is not None:
DAL = DAL.masked_fill(target_padding_mask, 0)
@ -186,7 +201,7 @@ class LatencyMetricVariance(LatencyMetric):
src_lens,
target_padding_mask=None,
batch_first: bool = True,
start_from_zero: bool = True
start_from_zero: bool = True,
):
assert batch_first
assert len(delays.size()) == 3
@ -256,25 +271,21 @@ class LatencyInference(object):
src_lens = src_lens
delays = (
monotonic_step
.view(monotonic_step.size(0), -1, monotonic_step.size(-1))
.max(dim=1)[0]
)
delays = monotonic_step.view(
monotonic_step.size(0), -1, monotonic_step.size(-1)
).max(dim=1)[0]
delays = (
delays.masked_fill(delays >= src_lens, 0)
+ (src_lens - 1)
.expand_as(delays)
.masked_fill(delays < src_lens, 0)
)
delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as(
delays
).masked_fill(delays < src_lens, 0)
return_dict = {}
for key, func in self.metric_calculator.items():
return_dict[key] = func(
delays.float(), src_lens.float(),
delays.float(),
src_lens.float(),
target_padding_mask=None,
batch_first=True,
start_from_zero=True
start_from_zero=True,
).t()
return return_dict
@ -282,8 +293,13 @@ class LatencyInference(object):
class LatencyTraining(object):
def __init__(
self, avg_weight, var_weight, avg_type, var_type,
stay_on_last_token, average_method,
self,
avg_weight,
var_weight,
avg_type,
var_type,
stay_on_last_token,
average_method,
):
self.avg_weight = avg_weight
self.var_weight = var_weight
@ -319,17 +335,12 @@ class LatencyTraining(object):
attention = attention.view(-1, tgt_len, src_len)
if not self.stay_on_last_token:
residual_attention = \
1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
attention = torch.cat(
[attention[:, :, :-1], residual_attention],
dim=2
)
residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2)
# bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
steps = (
torch
.arange(1, 1 + src_len)
torch.arange(1, 1 + src_len)
.unsqueeze(0)
.unsqueeze(1)
.expand_as(attention)
@ -355,15 +366,12 @@ class LatencyTraining(object):
src_lens = src_lens.view(-1, 1)
# bsz * num_heads_num_layers, tgt_len, src_len
expected_delays = (steps * attention).sum(dim=2).view(
bsz, num_heads_x_layers, tgt_len
expected_delays = (
(steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len)
)
if target_padding_mask is not None:
expected_delays.masked_fill_(
target_padding_mask.unsqueeze(1),
0
)
expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0)
return expected_delays, src_lens
@ -371,8 +379,7 @@ class LatencyTraining(object):
bsz, num_heads_x_layers, tgt_len = expected_delays.size()
target_padding_mask = (
target_padding_mask
.unsqueeze(1)
target_padding_mask.unsqueeze(1)
.expand_as(expected_delays)
.contiguous()
.view(-1, tgt_len)
@ -396,8 +403,11 @@ class LatencyTraining(object):
if self.avg_weight > 0.0:
if self.avg_type in self.metric_calculator:
average_delays = self.metric_calculator[self.avg_type](
expected_delays, src_lens, target_padding_mask,
batch_first=True, start_from_zero=False
expected_delays,
src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
)
else:
raise RuntimeError(f"{self.avg_type} is not supported.")
@ -408,12 +418,17 @@ class LatencyTraining(object):
return 0.0
def var_loss(self, expected_delays, src_lens, target_padding_mask):
src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[:, :1]
src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[
:, :1
]
if self.var_weight > 0.0:
if self.var_type in self.variance_calculator:
variance_delays = self.variance_calculator[self.var_type](
expected_delays, src_lens, target_padding_mask,
batch_first=True, start_from_zero=False
expected_delays,
src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
)
else:
raise RuntimeError(f"{self.var_type} is not supported.")

View File

@ -1 +1 @@
from . import tasks, criterions, models # noqa
from . import criterions, models, tasks # noqa

View File

@ -6,9 +6,9 @@
# LICENSE file in the root directory of this source tree.
import torch
from examples.speech_recognition.data.replabels import pack_replabels
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
from examples.speech_recognition.data.replabels import pack_replabels
@register_criterion("asg_loss")

View File

@ -5,6 +5,7 @@
from .asr_dataset import AsrDataset
__all__ = [
'AsrDataset',
"AsrDataset",
]

View File

@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
from fairseq.data import FairseqDataset
@ -30,16 +31,22 @@ class AsrDataset(FairseqDataset):
"""
def __init__(
self, aud_paths, aud_durations_ms, tgt,
tgt_dict, ids, speakers,
num_mel_bins=80, frame_length=25.0, frame_shift=10.0
self,
aud_paths,
aud_durations_ms,
tgt,
tgt_dict,
ids,
speakers,
num_mel_bins=80,
frame_length=25.0,
frame_shift=10.0,
):
assert frame_length > 0
assert frame_shift > 0
assert all(x > frame_length for x in aud_durations_ms)
self.frame_sizes = [
int(1 + (d - frame_length) / frame_shift)
for d in aud_durations_ms
int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
]
assert len(aud_paths) > 0
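For concreteness with the defaults above (frame_length = 25.0 ms, frame_shift = 10.0 ms), a 1000 ms utterance yields int(1 + (1000 - 25) / 10) = 98 frames, the usual count for 25 ms windows hopped every 10 ms.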
@ -57,13 +64,17 @@ class AsrDataset(FairseqDataset):
self.frame_shift = frame_shift
self.s2s_collater = Seq2SeqCollater(
0, 1, pad_index=self.tgt_dict.pad(),
eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True
0,
1,
pad_index=self.tgt_dict.pad(),
eos_index=self.tgt_dict.eos(),
move_eos_to_beginning=True,
)
def __getitem__(self, index):
import torchaudio
import torchaudio.compliance.kaldi as kaldi
tgt_item = self.tgt[index] if self.tgt is not None else None
path = self.aud_paths[index]
@ -74,7 +85,7 @@ class AsrDataset(FairseqDataset):
sound,
num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length,
frame_shift=self.frame_shift
frame_shift=self.frame_shift,
)
output_cmvn = data_utils.apply_mv_norm(output)

View File

@ -12,18 +12,18 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils
class Seq2SeqCollater(object):
"""
Implements collate function mainly for seq2seq tasks
This expects each sample to contain feature (src_tokens) and
targets.
This collator is also used for aligned training task.
Implements collate function mainly for seq2seq tasks
This expects each sample to contain feature (src_tokens) and
targets.
This collator is also used for aligned training task.
"""
def __init__(

View File

@ -6,52 +6,74 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from collections import namedtuple
import concurrent.futures
from itertools import chain
import argparse
import os
import concurrent.futures
import json
import sentencepiece as spm
import multiprocessing
import os
from collections import namedtuple
from itertools import chain
import sentencepiece as spm
from fairseq.data import Dictionary
MILLISECONDS_TO_SECONDS = 0.001
def process_sample(aud_path, label, utt_id, sp, tgt_dict):
import torchaudio
input = {}
output = {}
si, ei = torchaudio.info(aud_path)
input["length_ms"] = int(si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS)
input["length_ms"] = int(
si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
)
input["path"] = aud_path
token = " ".join(sp.EncodeAsPieces(lable))
ids = tgt_dict.encode_line(token, append_eos=False)
output["text"] = lable
output["token"] = token
output["tokenid"] = ', '.join(map(str, [t.tolist() for t in ids]))
output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
return {utt_id: {"input": input, "output": output}}
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--audio-dirs", nargs="+", default=['-'], required=True,
help="input directories with audio files")
parser.add_argument("--labels", required=True,
help="aggregated input labels with format <ID LABEL> per line",
type=argparse.FileType('r', encoding='UTF-8'))
parser.add_argument("--spm-model", required=True,
help="sentencepiece model to use for encoding",
type=argparse.FileType('r', encoding='UTF-8'))
parser.add_argument("--dictionary", required=True,
help="file to load fairseq dictionary from",
type=argparse.FileType('r', encoding='UTF-8'))
parser.add_argument(
"--audio-dirs",
nargs="+",
default=["-"],
required=True,
help="input directories with audio files",
)
parser.add_argument(
"--labels",
required=True,
help="aggregated input labels with format <ID LABEL> per line",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument(
"--spm-model",
required=True,
help="sentencepiece model to use for encoding",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument(
"--dictionary",
required=True,
help="file to load fairseq dictionary from",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
parser.add_argument("--output", required=True, type=argparse.FileType('w'),
help="path to save json output")
parser.add_argument(
"--output",
required=True,
type=argparse.FileType("w"),
help="path to save json output",
)
args = parser.parse_args()
sp = spm.SentencePieceProcessor()
@ -64,15 +86,17 @@ def main():
(utt_id, label) = line.split(" ", 1)
labels[utt_id] = label
if len(labels) == 0:
raise Exception('No labels found in ', args.labels_path)
raise Exception("No labels found in ", args.labels_path)
Sample = namedtuple('Sample', 'aud_path utt_id')
Sample = namedtuple("Sample", "aud_path utt_id")
samples = []
for path, _, files in chain.from_iterable(os.walk(path) for path in args.audio_dirs):
for path, _, files in chain.from_iterable(
os.walk(path) for path in args.audio_dirs
):
for f in files:
if f.endswith(args.audio_format):
if len(os.path.splitext(f)) != 2:
raise Exception('Expect <utt_id.extension> file name. Got: ', f)
raise Exception("Expect <utt_id.extension> file name. Got: ", f)
utt_id = os.path.splitext(f)[0]
if utt_id not in labels:
continue
@ -81,12 +105,17 @@ def main():
utts = {}
num_cpu = multiprocessing.cpu_count()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
future_to_sample = {executor.submit(process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict): s for s in samples}
future_to_sample = {
executor.submit(
process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
): s
for s in samples
}
for future in concurrent.futures.as_completed(future_to_sample):
try:
data = future.result()
except Exception as exc:
print('generated an exception: ', exc)
print("generated an exception: ", exc)
else:
utts.update(data)
json.dump({"utts": utts}, args.output, indent=4)

View File

@ -8,17 +8,17 @@
Run inference for pre-processed data with a trained model.
"""
import editdistance
import logging
import math
import os
import sys
import editdistance
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, utils, tasks
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter
logging.basicConfig()
@ -52,10 +52,12 @@ output units",
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
)
parser.add_argument(
"--w2l-decoder", choices=["viterbi", "kenlm", "fairseqlm"], help="use a w2l decoder"
"--w2l-decoder",
choices=["viterbi", "kenlm", "fairseqlm"],
help="use a w2l decoder",
)
parser.add_argument("--lexicon", help="lexicon for w2l decoder")
parser.add_argument("--unit-lm", action='store_true', help="if using a unit lm")
parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
parser.add_argument("--beam-threshold", type=float, default=25.0)
parser.add_argument("--beam-size-token", type=float, default=100)
@ -87,10 +89,10 @@ def check_args(args):
# assert args.path is not None, "--path required for generation!"
# assert args.results_path is not None, "--results_path required for generation!"
assert (
not args.sampling or args.nbest == args.beam
not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
args.replace_unk is None or args.raw_text
args.replace_unk is None or args.raw_text
), "--replace-unk requires a raw text dataset (--raw-text)"
@ -110,7 +112,7 @@ def get_dataset_itr(args, task, models):
def process_predictions(
args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
for hypo in hypos[: min(len(hypos), args.nbest)]:
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
@ -122,16 +124,25 @@ def process_predictions(
if res_files is not None:
print(
"{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"]
"{} ({}-{})".format(hyp_pieces, speaker, id),
file=res_files["hypo.units"],
)
print(
"{} ({}-{})".format(hyp_words, speaker, id),
file=res_files["hypo.words"],
)
print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"])
tgt_pieces = tgt_dict.string(target_tokens)
tgt_words = post_process(tgt_pieces, args.remove_bpe)
if res_files is not None:
print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"])
print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"])
print(
"{} ({}-{})".format(tgt_pieces, speaker, id),
file=res_files["ref.units"],
)
print(
"{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
)
# only score top hypothesis
if not args.quiet:
logger.debug("HYPO:" + hyp_words)
@ -146,7 +157,7 @@ def process_predictions(
def prepare_result_files(args):
def get_res_file(file_prefix):
if args.num_shards > 1:
file_prefix = f'{args.shard_id}_{file_prefix}'
file_prefix = f"{args.shard_id}_{file_prefix}"
path = os.path.join(
args.results_path,
"{}-{}-{}.txt".format(
@ -166,15 +177,17 @@ def prepare_result_files(args):
}
def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=None, model_state=None):
def load_models_and_criterions(
filenames, data_path, arg_overrides=None, task=None, model_state=None
):
models = []
criterions = []
if arg_overrides is None:
arg_overrides = {}
arg_overrides['wer_args'] = None
arg_overrides['data'] = data_path
arg_overrides["wer_args"] = None
arg_overrides["data"] = data_path
if filenames is None:
assert model_state is not None
@ -205,8 +218,7 @@ def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=No
def optimize_models(args, use_cuda, models):
"""Optimize ensemble for generation
"""
"""Optimize ensemble for generation"""
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
@ -229,7 +241,7 @@ class ExistingEmissionsDecoder(object):
emissions = np.stack(self.emissions[ids])
except:
print([x.shape for x in self.emissions[ids]])
raise Exception('invalid sizes')
raise Exception("invalid sizes")
emissions = torch.from_numpy(emissions)
return self.decoder.decode(emissions)
@ -300,7 +312,9 @@ def main(args, task=None, model_state=None):
return W2lFairseqLMDecoder(args, task.target_dictionary)
else:
print('only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment')
print(
"only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
)
# please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
generator = build_generator(args)
@ -361,7 +375,11 @@ def main(args, task=None, model_state=None):
encoder_out = models[0](**sample["net_input"])
feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
for i, id in enumerate(sample["id"]):
padding = encoder_out["encoder_padding_mask"][i].cpu().numpy() if encoder_out["encoder_padding_mask"] is not None else None
padding = (
encoder_out["encoder_padding_mask"][i].cpu().numpy()
if encoder_out["encoder_padding_mask"] is not None
else None
)
features[id.item()] = (feat[i], padding)
continue
hypos = task.inference_step(generator, models, sample, prefix_tokens)
@ -372,20 +390,31 @@ def main(args, task=None, model_state=None):
speaker = None
# id = task.dataset(args.gen_subset).ids[int(sample_id)]
id = sample_id
toks = sample["target"][i, :] if 'target_label' not in sample else sample["target_label"][i, :]
target_tokens = (
utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
toks = (
sample["target"][i, :]
if "target_label" not in sample
else sample["target_label"][i, :]
)
target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
# Process top predictions
errs, length = process_predictions(
args, hypos[i], None, tgt_dict, target_tokens, res_files, speaker, id
args,
hypos[i],
None,
tgt_dict,
target_tokens,
res_files,
speaker,
id,
)
errs_t += errs
lengths_t += length
wps_meter.update(num_generated_tokens)
t.log({"wps": round(wps_meter.avg)})
num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
num_sentences += (
sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
)
wer = None
if args.dump_emissions:
@ -413,7 +442,7 @@ def main(args, task=None, model_state=None):
gen_timer.sum,
num_sentences / gen_timer.sum,
1.0 / gen_timer.avg,
)
)
)
logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
return task, wer
@ -424,6 +453,7 @@ def make_parser():
parser = add_asr_eval_argument(parser)
return parser
def cli_main():
parser = make_parser()
args = options.parse_args_and_arch(parser)

View File

@ -1,7 +1,8 @@
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
model_name = file[:file.find('.py')]
importlib.import_module('examples.speech_recognition.models.' + model_name)
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module("examples.speech_recognition.models." + model_name)

View File

@ -9,18 +9,22 @@ from collections.abc import Iterable
import torch
import torch.nn as nn
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqEncoderModel,
FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules import LinearizedConvolution
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VGGBlock
from fairseq.modules import (
LinearizedConvolution,
TransformerDecoderLayer,
TransformerEncoderLayer,
VGGBlock,
)
@register_model("asr_vggtransformer")
@ -29,6 +33,7 @@ class VGGTransformerModel(FairseqEncoderDecoderModel):
Transformers with convolutional context for ASR
https://arxiv.org/abs/1904.11660
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@ -602,18 +607,22 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.layers = nn.ModuleList()
if conv_config[-1][0] != transformer_config[0][0]:
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
self.layers.append(TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[0])
))
self.layers.append(
TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[0])
)
)
for i in range(1, len(transformer_config)):
if transformer_config[i - 1][0] != transformer_config[i][0]:
self.layers.append(
Linear(transformer_config[i - 1][0], transformer_config[i][0])
)
self.layers.append(TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[i])
))
self.layers.append(
TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[i])
)
)
self.fc_out = Linear(transformer_config[-1][0], vocab_size)
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
@ -713,6 +722,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
x = x.transpose(0, 1)
return x
@register_model("asr_vggtransformer_encoder")
class VGGTransformerEncoderModel(FairseqEncoderModel):
def __init__(self, encoder):

View File

@ -10,7 +10,6 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,

View File

@ -1,7 +1,8 @@
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
task_name = file[:file.find('.py')]
importlib.import_module('examples.speech_recognition.tasks.' + task_name)
if file.endswith(".py") and not file.startswith("_"):
task_name = file[: file.find(".py")]
importlib.import_module("examples.speech_recognition.tasks." + task_name)

View File

@ -9,10 +9,10 @@ import re
import sys
import torch
from fairseq.data import Dictionary
from fairseq.tasks import register_task, LegacyFairseqTask
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol
from fairseq.data import Dictionary
from fairseq.tasks import LegacyFairseqTask, register_task
def get_asr_dataset_from_json(data_json_path, tgt_dict):
@ -78,10 +78,20 @@ class SpeechRecognitionTask(LegacyFairseqTask):
parser.add_argument(
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
)
parser.add_argument('--max-source-positions', default=sys.maxsize, type=int, metavar='N',
help='max number of frames in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
parser.add_argument(
"--max-source-positions",
default=sys.maxsize,
type=int,
metavar="N",
help="max number of frames in the source sequence",
)
parser.add_argument(
"--max-target-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)
def __init__(self, args, tgt_dict):
super().__init__(args)

View File

@ -9,16 +9,18 @@
Wav2letter decoders.
"""
from collections import namedtuple, deque
import gc
import itertools as it
import numpy as np
import torch
import os.path as osp
import warnings
from collections import deque, namedtuple
import numpy as np
import torch
from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks
from fairseq.utils import apply_to_sample
from examples.speech_recognition.data.replabels import unpack_replabels
try:
from wav2letter.common import create_word_dict, load_words

View File

@ -4,66 +4,76 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from multiprocessing import cpu_count
import csv
import os
import os.path as op
from glob import glob
import zipfile
import csv
from functools import reduce
from typing import Dict, Any, List
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
from glob import glob
from multiprocessing import cpu_count
from typing import Any, Dict, List
import sentencepiece as sp
from tqdm import tqdm
import numpy as np
import sentencepiece as sp
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
from tqdm import tqdm
UNK_TOKEN, UNK_TOKEN_ID = '<unk>', 3
BOS_TOKEN, BOS_TOKEN_ID = '<s>', 0
EOS_TOKEN, EOS_TOKEN_ID = '</s>', 2
PAD_TOKEN, PAD_TOKEN_ID = '<pad>', 1
UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
def gen_vocab(
input_path: str, output_path_prefix: str, model_type='bpe',
vocab_size=1000,
input_path: str,
output_path_prefix: str,
model_type="bpe",
vocab_size=1000,
):
# Train SentencePiece Model
arguments = [
f'--input={input_path}',
f'--model_prefix={output_path_prefix}',
f'--model_type={model_type}',
f'--vocab_size={vocab_size}',
'--character_coverage=1.0',
f'--num_threads={cpu_count()}',
f'--unk_id={UNK_TOKEN_ID}',
f'--bos_id={BOS_TOKEN_ID}',
f'--eos_id={EOS_TOKEN_ID}',
f'--pad_id={PAD_TOKEN_ID}'
f"--input={input_path}",
f"--model_prefix={output_path_prefix}",
f"--model_type={model_type}",
f"--vocab_size={vocab_size}",
"--character_coverage=1.0",
f"--num_threads={cpu_count()}",
f"--unk_id={UNK_TOKEN_ID}",
f"--bos_id={BOS_TOKEN_ID}",
f"--eos_id={EOS_TOKEN_ID}",
f"--pad_id={PAD_TOKEN_ID}",
]
sp.SentencePieceTrainer.Train(' '.join(arguments))
sp.SentencePieceTrainer.Train(" ".join(arguments))
# Export fairseq dictionary
spm = sp.SentencePieceProcessor()
spm.Load(output_path_prefix + '.model')
spm.Load(output_path_prefix + ".model")
vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
assert vocab.get(UNK_TOKEN_ID) == UNK_TOKEN and \
vocab.get(PAD_TOKEN_ID) == PAD_TOKEN and \
vocab.get(BOS_TOKEN_ID) == BOS_TOKEN and \
vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
assert (
vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
)
vocab = {
i: s for i, s in vocab.items()
i: s
for i, s in vocab.items()
if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
}
with open(output_path_prefix + '.txt', 'w') as f_out:
with open(output_path_prefix + ".txt", "w") as f_out:
for _, s in sorted(vocab.items(), key=lambda x: x[0]):
f_out.write(f'{s} 1\n')
f_out.write(f"{s} 1\n")
def extract_fbank_features(waveform, sample_rate, output_path=None,
n_mel_bins=80, apply_utterance_cmvn=True,
overwrite=False):
def extract_fbank_features(
waveform,
sample_rate,
output_path=None,
n_mel_bins=80,
apply_utterance_cmvn=True,
overwrite=False,
):
if output_path is not None and op.exists(output_path) and not overwrite:
return
@ -74,8 +84,10 @@ def extract_fbank_features(waveform, sample_rate, output_path=None,
if features is None:
features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
if features is None:
raise ImportError('Please install pyKaldi or torchaudio to enable '
'online filterbank feature extraction')
raise ImportError(
"Please install pyKaldi or torchaudio to enable "
"online filterbank feature extraction"
)
if apply_utterance_cmvn:
cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
@ -89,8 +101,8 @@ def extract_fbank_features(waveform, sample_rate, output_path=None,
def create_zip(data_root, zip_path):
cwd = os.path.abspath(os.curdir)
os.chdir(data_root)
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_STORED) as f:
for filename in tqdm(glob('*.npy')):
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
for filename in tqdm(glob("*.npy")):
f.write(filename)
os.chdir(cwd)
@ -101,69 +113,80 @@ def is_npy_data(data: bytes) -> bool:
def get_zip_manifest(zip_root, zip_filename):
zip_path = op.join(zip_root, zip_filename)
with zipfile.ZipFile(zip_path, mode='r') as f:
with zipfile.ZipFile(zip_path, mode="r") as f:
info = f.infolist()
manifest = {}
for i in tqdm(info):
utt_id = op.splitext(i.filename)[0]
offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
manifest[utt_id] = f'{zip_filename}:{offset}:{file_size}'
with open(zip_path, 'rb') as f:
manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}"
with open(zip_path, "rb") as f:
f.seek(offset)
data = f.read(file_size)
assert len(data) > 1 and is_npy_data(data)
return manifest
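The offset arithmetic above leans on the ZIP local file header layout: 30 fixed bytes, then the filename (the extra field is assumed empty), after which the stored payload begins; since create_zip uses ZIP_STORED, that payload is the raw .npy bytes. A hypothetical reader for one manifest entry:

import io
import os.path as op

import numpy as np

def read_manifest_entry(zip_root, entry):
    # entry has the form "<zip_filename>:<offset>:<file_size>" as built above
    zip_filename, offset, file_size = entry.rsplit(":", 2)
    with open(op.join(zip_root, zip_filename), "rb") as f:
        f.seek(int(offset))
        data = f.read(int(file_size))
    return np.load(io.BytesIO(data))  # uncompressed (ZIP_STORED) .npy payload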
def gen_config_yaml(data_root, spm_filename, yaml_filename='config.yaml',
specaugment_policy='lb'):
assert specaugment_policy in {'lb', 'ld'}
def gen_config_yaml(
data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb"
):
assert specaugment_policy in {"lb", "ld"}
data_root = op.abspath(data_root)
writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
writer.set_audio_root(op.abspath(data_root))
writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
writer.set_input_channels(1)
writer.set_input_feat_per_channel(80)
if specaugment_policy == 'lb':
if specaugment_policy == "lb":
writer.set_specaugment_lb_policy()
else:
writer.set_specaugment_ld_policy()
writer.set_bpe_tokenizer(
{'bpe': 'sentencepiece',
'sentencepiece_model': op.join(data_root, spm_filename)}
{
"bpe": "sentencepiece",
"sentencepiece_model": op.join(data_root, spm_filename),
}
)
writer.set_feature_transforms('_train', ['specaugment'])
writer.set_feature_transforms("_train", ["specaugment"])
writer.flush()
def save_df_to_tsv(dataframe, path):
dataframe.to_csv(path, sep="\t", header=True, index=False, encoding="utf-8",
escapechar='\\', quoting=csv.QUOTE_NONE)
dataframe.to_csv(
path,
sep="\t",
header=True,
index=False,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
)
def filter_manifest_df(df, is_train_split=False, extra_filters=None,
min_n_frames=5, max_n_frames=3000):
def filter_manifest_df(
df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
filters = {
'no speech': df['audio'] == '',
f'short speech (<{min_n_frames} frames)': df['n_frames'] < min_n_frames,
'empty sentence': df['tgt_text'] == '',
"no speech": df["audio"] == "",
f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
"empty sentence": df["tgt_text"] == "",
}
if is_train_split:
filters[f'long speech (>{max_n_frames} frames)'] = \
df['n_frames'] > max_n_frames
filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
if extra_filters is not None:
filters.update(extra_filters)
invalid = reduce(lambda x, y: x | y, filters.values())
valid = ~invalid
print(
'| ' + ', '.join(f'{n}: {f.sum()}' for n, f in filters.items()) +
f', total {invalid.sum()} filtered, {valid.sum()} remained.'
"| "
+ ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
+ f", total {invalid.sum()} filtered, {valid.sum()} remained."
)
return df[valid]
class S2TDataConfigWriter(object):
DEFAULT_VOCAB_FILENAME = 'dict.txt'
DEFAULT_VOCAB_FILENAME = "dict.txt"
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
DEFAULT_INPUT_CHANNELS = 1
@ -171,48 +194,69 @@ class S2TDataConfigWriter(object):
try:
import yaml
except ImportError:
print('Please install PyYAML to load YAML files for S2T data config')
print("Please install PyYAML to load YAML files for S2T data config")
self.yaml = yaml
self.yaml_path = yaml_path
self.config = {}
def flush(self):
with open(self.yaml_path, 'w') as f:
with open(self.yaml_path, "w") as f:
self.yaml.dump(self.config, f)
def set_audio_root(self, audio_root=''):
self.config['audio_root'] = audio_root
def set_audio_root(self, audio_root=""):
self.config["audio_root"] = audio_root
def set_vocab_filename(self, vocab_filename='dict.txt'):
self.config['vocab_filename'] = vocab_filename
def set_vocab_filename(self, vocab_filename="dict.txt"):
self.config["vocab_filename"] = vocab_filename
def set_specaugment(self, time_wrap_w: int, freq_mask_n: int,
freq_mask_f: int, time_mask_n: int, time_mask_t: int,
time_mask_p: float):
self.config['specaugment'] = {
'time_wrap_W': time_wrap_w, 'freq_mask_N': freq_mask_n,
'freq_mask_F': freq_mask_f, 'time_mask_N': time_mask_n,
'time_mask_T': time_mask_t, 'time_mask_p': time_mask_p,
def set_specaugment(
self,
time_wrap_w: int,
freq_mask_n: int,
freq_mask_f: int,
time_mask_n: int,
time_mask_t: int,
time_mask_p: float,
):
self.config["specaugment"] = {
"time_wrap_W": time_wrap_w,
"freq_mask_N": freq_mask_n,
"freq_mask_F": freq_mask_f,
"time_mask_N": time_mask_n,
"time_mask_T": time_mask_t,
"time_mask_p": time_mask_p,
}
def set_specaugment_lb_policy(self):
self.set_specaugment(time_wrap_w=0, freq_mask_n=1, freq_mask_f=27,
time_mask_n=1, time_mask_t=100, time_mask_p=1.0)
self.set_specaugment(
time_wrap_w=0,
freq_mask_n=1,
freq_mask_f=27,
time_mask_n=1,
time_mask_t=100,
time_mask_p=1.0,
)
def set_specaugment_ld_policy(self):
self.set_specaugment(time_wrap_w=0, freq_mask_n=2, freq_mask_f=27,
time_mask_n=2, time_mask_t=100, time_mask_p=1.0)
self.set_specaugment(
time_wrap_w=0,
freq_mask_n=2,
freq_mask_f=27,
time_mask_n=2,
time_mask_t=100,
time_mask_p=1.0,
)
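These two presets appear to correspond to the LibriSpeech basic (LB) and double (LD) policies of the SpecAugment paper, with time warping disabled (time_wrap_w=0); "wrap" here is the source's spelling of "warp".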
def set_input_channels(self, input_channels=1):
self.config['input_channels'] = input_channels
self.config["input_channels"] = input_channels
def set_input_feat_per_channel(self, input_feat_per_channel=80):
self.config['input_feat_per_channel'] = input_feat_per_channel
self.config["input_feat_per_channel"] = input_feat_per_channel
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
self.config['bpe_tokenizer'] = bpe_tokenizer
self.config["bpe_tokenizer"] = bpe_tokenizer
def set_feature_transforms(self, split, transforms: List[str]):
if 'transforms' not in self.config:
self.config['transforms'] = {}
self.config['transforms'][split] = transforms
if "transforms" not in self.config:
self.config["transforms"] = {}
self.config["transforms"][split] = transforms

View File

@ -5,30 +5,35 @@
# LICENSE file in the root directory of this source tree.
import argparse
import csv
import logging
from tempfile import NamedTemporaryFile
import os
import os.path as op
import shutil
from typing import Tuple, Optional
import csv
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
from torch import Tensor
from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
extract_fbank_features, gen_config_yaml, filter_manifest_df
)
log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
class CoVoST(Dataset):
@ -44,40 +49,82 @@ class CoVoST(Dataset):
found at root path. (default: ``False``).
"""
CV_URL_TEMPLATE = "https://voice-prod-bundler-ee1969a6ce8178826482b88" \
"e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
COVOST_URL_TEMPLATE = "https://dl.fbaipublicfiles.com/covost/" \
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
CV_URL_TEMPLATE = (
"https://voice-prod-bundler-ee1969a6ce8178826482b88"
"e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
)
COVOST_URL_TEMPLATE = (
"https://dl.fbaipublicfiles.com/covost/"
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
)
VERSIONS = {2}
SPLITS = ['train', 'dev', 'test']
SPLITS = ["train", "dev", "test"]
CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}
XX_EN_LANGUAGES = {
1: ['fr', 'de', 'nl', 'ru', 'es', 'it', 'tr', 'fa', 'sv-SE', 'mn',
'zh-CN'],
2: ['fr', 'de', 'es', 'ca', 'it', 'ru', 'zh-CN', 'pt', 'fa', 'et', 'mn',
'nl', 'tr', 'ar', 'sv-SE', 'lv', 'sl', 'ta', 'ja', 'id', 'cy']
1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
2: [
"fr",
"de",
"es",
"ca",
"it",
"ru",
"zh-CN",
"pt",
"fa",
"et",
"mn",
"nl",
"tr",
"ar",
"sv-SE",
"lv",
"sl",
"ta",
"ja",
"id",
"cy",
],
}
EN_XX_LANGUAGES = {
1: [],
2: ['de', 'tr', 'fa', 'sv-SE', 'mn', 'zh-CN', 'cy', 'ca', 'sl', 'et',
'id',
'ar', 'ta', 'lv', 'ja']
2: [
"de",
"tr",
"fa",
"sv-SE",
"mn",
"zh-CN",
"cy",
"ca",
"sl",
"et",
"id",
"ar",
"ta",
"lv",
"ja",
],
}
def __init__(
self, root: str, split: str, source_language: str,
target_language: Optional[str] = None, version: int = 2,
download: bool = False
self,
root: str,
split: str,
source_language: str,
target_language: Optional[str] = None,
version: int = 2,
download: bool = False,
) -> None:
assert version in self.VERSIONS and split in self.SPLITS
assert source_language is not None
self.no_translation = (target_language is None)
self.no_translation = target_language is None
if not self.no_translation:
assert 'en' in {source_language, target_language}
if source_language == 'en':
assert "en" in {source_language, target_language}
if source_language == "en":
assert target_language in self.EN_XX_LANGUAGES[version]
else:
assert source_language in self.XX_EN_LANGUAGES[version]
@ -85,51 +132,60 @@ class CoVoST(Dataset):
# Hack here so that we can get "split" column from CoVoST TSV.
# Note that we use CoVoST train split for ASR which is an extension
# to Common Voice train split.
target_language = 'de' if source_language == 'en' else 'en'
target_language = "de" if source_language == "en" else "en"
self.root = os.path.join(root, 'raw')
self.root = os.path.join(root, "raw")
os.makedirs(self.root, exist_ok=True)
cv_url = self.CV_URL_TEMPLATE.format(ver=self.CV_VERSION_ID[version],
lang=source_language)
cv_url = self.CV_URL_TEMPLATE.format(
ver=self.CV_VERSION_ID[version], lang=source_language
)
cv_archive = os.path.join(self.root, os.path.basename(cv_url))
if download:
if not os.path.isfile(cv_archive):
download_url(cv_url, self.root, hash_value=None)
extract_archive(cv_archive)
covost_url = self.COVOST_URL_TEMPLATE.format(src_lang=source_language,
tgt_lang=target_language)
covost_url = self.COVOST_URL_TEMPLATE.format(
src_lang=source_language, tgt_lang=target_language
)
covost_archive = os.path.join(self.root, os.path.basename(covost_url))
if download:
if not os.path.isfile(covost_archive):
download_url(covost_url, self.root, hash_value=None)
extract_archive(covost_archive)
cv_tsv = self.load_from_tsv(os.path.join(self.root, 'validated.tsv'))
cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv"))
covost_tsv = self.load_from_tsv(
os.path.join(self.root,
os.path.basename(covost_url).replace('.tar.gz', ''))
os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", ""))
)
df = pd.merge(left=cv_tsv[['path', 'sentence', 'client_id']],
right=covost_tsv[['path', 'translation', 'split']],
how='inner', on='path')
if split == 'train':
df = df[(df['split'] == split) | (df['split'] == f'{split}_covost')]
df = pd.merge(
left=cv_tsv[["path", "sentence", "client_id"]],
right=covost_tsv[["path", "translation", "split"]],
how="inner",
on="path",
)
if split == "train":
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
else:
df = df[df['split'] == split]
self.data = df.to_dict(orient='index').items()
df = df[df["split"] == split]
self.data = df.to_dict(orient="index").items()
self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]
@classmethod
def load_from_tsv(cls, path: str):
return pd.read_csv(
path, sep='\t', header=0, encoding='utf-8', escapechar='\\',
quoting=csv.QUOTE_NONE, na_filter=False
path,
sep="\t",
header=0,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
na_filter=False,
)
def __getitem__(
self, n: int
self, n: int
) -> Tuple[Tensor, int, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset.
@ -141,12 +197,12 @@ class CoVoST(Dataset):
sample_id)``
"""
data = self.data[n]
path = os.path.join(self.root, 'clips', data['path'])
path = os.path.join(self.root, "clips", data["path"])
waveform, sample_rate = torchaudio.load(path)
sentence = data['sentence']
translation = None if self.no_translation else data['translation']
speaker_id = data['client_id']
_id = data['path'].replace('.mp3', '')
sentence = data["sentence"]
translation = None if self.no_translation else data["translation"]
speaker_id = data["client_id"]
_id = data["path"].replace(".mp3", "")
return waveform, sample_rate, sentence, translation, speaker_id, _id
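A minimal usage sketch for the dataset class above (root path hypothetical; with target_language omitted, translation comes back as None):

dataset = CoVoST("/data/covost/fr", "train", "fr", "en", download=False)
waveform, sample_rate, sentence, translation, speaker_id, utt_id = dataset[0]
print(sample_rate, sentence, "->", translation)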
def __len__(self) -> int:
@ -157,76 +213,82 @@ def process(args):
root = op.join(args.data_root, args.src_lang)
os.makedirs(root, exist_ok=True)
# Extract features
feature_root = op.join(root, 'fbank80')
feature_root = op.join(root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in CoVoST.SPLITS:
print(f'Fetching split {split}...')
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang,
download=True)
print('Extracting log mel filter bank features...')
print(f"Fetching split {split}...")
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True)
print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(waveform, sample_rate,
op.join(feature_root, f'{utt_id}.npy'))
extract_fbank_features(
waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
)
# Pack features into ZIP
zip_filename = 'fbank80.zip'
zip_filename = "fbank80.zip"
zip_path = op.join(root, zip_filename)
print('ZIPing features...')
print("ZIPing features...")
create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...')
zip_manifest = get_zip_manifest(args.data_root,
f'{args.src_lang}/{zip_filename}')
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}")
# Generate TSV manifest
print('Generating manifest...')
print("Generating manifest...")
train_text = []
task = f'asr_{args.src_lang}'
task = f"asr_{args.src_lang}"
if args.tgt_lang is not None:
task = f'st_{args.src_lang}_{args.tgt_lang}'
task = f"st_{args.src_lang}_{args.tgt_lang}"
for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest['id'].append(utt_id)
manifest['audio'].append(zip_manifest[utt_id])
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
manifest['tgt_text'].append(
src_utt if args.tgt_lang is None else tgt_utt
)
manifest['speaker'].append(speaker_id)
is_train_split = split.startswith('train')
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
manifest["speaker"].append(speaker_id)
is_train_split = split.startswith("train")
if is_train_split:
train_text.extend(manifest['tgt_text'])
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, op.join(root, f'{split}_{task}.tsv'))
save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv"))
# Generate vocab
vocab_size_str = '' if args.vocab_type == 'char' else str(args.vocab_size)
spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size_str}_{task}'
with NamedTemporaryFile(mode='w') as f:
vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + '\n')
gen_vocab(f.name, op.join(root, spm_filename_prefix),
args.vocab_type, args.vocab_size)
f.write(t + "\n")
gen_vocab(
f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size
)
# Generate config YAML
gen_config_yaml(root, spm_filename_prefix + '.model',
yaml_filename=f'config_{task}.yaml',
specaugment_policy='lb')
gen_config_yaml(
root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{task}.yaml",
specaugment_policy="lb",
)
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data-root', '-d', required=True, type=str)
parser.add_argument('--vocab-type', default='unigram', required=True,
type=str, choices=['bpe', 'unigram', 'char']),
parser.add_argument('--vocab-size', default=1000, type=int)
parser.add_argument('--src-lang', '-s', required=True, type=str)
parser.add_argument('--tgt-lang', '-t', type=str)
parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=1000, type=int)
parser.add_argument("--src-lang", "-s", required=True, type=str)
parser.add_argument("--tgt-lang", "-t", type=str)
args = parser.parse_args()
process(args)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -6,91 +6,114 @@
import argparse
import logging
from tempfile import NamedTemporaryFile
import os
import shutil
import os.path as op
import shutil
from tempfile import NamedTemporaryFile
from tqdm import tqdm
from torchaudio.datasets import LIBRISPEECH
import pandas as pd
from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
extract_fbank_features, gen_config_yaml
create_zip,
extract_fbank_features,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torchaudio.datasets import LIBRISPEECH
from tqdm import tqdm
log = logging.getLogger(__name__)
SPLITS = ['train-clean-100', 'train-clean-360', 'train-other-500', 'dev-clean',
'dev-other', 'test-clean', 'test-other']
SPLITS = [
"train-clean-100",
"train-clean-360",
"train-other-500",
"dev-clean",
"dev-other",
"test-clean",
"test-other",
]
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
def process(args):
os.makedirs(args.output_root, exist_ok=True)
# Extract features
feature_root = op.join(args.output_root, 'fbank80')
feature_root = op.join(args.output_root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in SPLITS:
print(f'Fetching split {split}...')
print(f"Fetching split {split}...")
dataset = LIBRISPEECH(args.output_root, url=split, download=True)
print('Extracting log mel filter bank features...')
print("Extracting log mel filter bank features...")
for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
extract_fbank_features(wav, sample_rate,
op.join(feature_root, f'{sample_id}.npy'))
sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
extract_fbank_features(
wav, sample_rate, op.join(feature_root, f"{sample_id}.npy")
)
# Pack features into ZIP
zip_filename = 'fbank80.zip'
zip_filename = "fbank80.zip"
zip_path = op.join(args.output_root, zip_filename)
print('ZIPing features...')
print("ZIPing features...")
create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...')
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.output_root, zip_filename)
# Generate TSV manifest
print('Generating manifest...')
print("Generating manifest...")
train_text = []
for split in SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = LIBRISPEECH(args.output_root, url=split)
for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
manifest['id'].append(sample_id)
manifest['audio'].append(zip_manifest[sample_id])
sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
manifest["id"].append(sample_id)
manifest["audio"].append(zip_manifest[sample_id])
duration_ms = int(wav.size(1) / sample_rate * 1000)
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
manifest['tgt_text'].append(utt)
manifest['speaker'].append(spk_id)
save_df_to_tsv(pd.DataFrame.from_dict(manifest),
op.join(args.output_root, f'{split}.tsv'))
if split.startswith('train'):
train_text.extend(manifest['tgt_text'])
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest["tgt_text"].append(utt)
manifest["speaker"].append(spk_id)
save_df_to_tsv(
pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv")
)
if split.startswith("train"):
train_text.extend(manifest["tgt_text"])
# Generate vocab
vocab_size = '' if args.vocab_type == 'char' else str(args.vocab_size)
spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size}'
with NamedTemporaryFile(mode='w') as f:
vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + '\n')
gen_vocab(f.name, op.join(args.output_root, spm_filename_prefix),
args.vocab_type, args.vocab_size)
f.write(t + "\n")
gen_vocab(
f.name,
op.join(args.output_root, spm_filename_prefix),
args.vocab_type,
args.vocab_size,
)
# Generate config YAML
gen_config_yaml(args.output_root, spm_filename_prefix + '.model',
specaugment_policy='ld')
gen_config_yaml(
args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld"
)
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--output-root', '-o', required=True, type=str)
parser.add_argument('--vocab-type', default='unigram', required=True,
type=str, choices=['bpe', 'unigram', 'char']),
parser.add_argument('--vocab-size', default=10000, type=int)
parser.add_argument("--output-root", "-o", required=True, type=str)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=10000, type=int)
args = parser.parse_args()
process(args)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -6,29 +6,34 @@
import argparse
import logging
from tempfile import NamedTemporaryFile
import os
import os.path as op
import shutil
from typing import Tuple
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple
from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
from torch import Tensor
from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
extract_fbank_features, gen_config_yaml, filter_manifest_df
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from tqdm import tqdm
log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
TASKS = ['asr', 'st']
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
TASKS = ["asr", "st"]
class MUSTC(Dataset):
@ -37,49 +42,55 @@ class MUSTC(Dataset):
waveform, sample_rate, source utterance, target utterance, speaker_id,
utterance_id
"""
SPLITS = ['train', 'dev', 'tst-COMMON', 'tst-HE']
LANGUAGES = ['de', 'es', 'fr', 'it', 'nl', 'pt', 'ro', 'ru']
SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
def __init__(self, root: str, lang: str, split: str) -> None:
assert split in self.SPLITS and lang in self.LANGUAGES
_root = op.join(root, f'en-{lang}', 'data', split)
wav_root, txt_root = op.join(_root, 'wav'), op.join(_root, 'txt')
_root = op.join(root, f"en-{lang}", "data", split)
wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt")
assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
# Load audio segments
try:
import yaml
except ImportError:
print('Please install PyYAML to load YAML files for '
'the MuST-C dataset')
with open(op.join(txt_root, f'{split}.yaml')) as f:
print("Please install PyYAML to load YAML files for " "the MuST-C dataset")
with open(op.join(txt_root, f"{split}.yaml")) as f:
segments = yaml.load(f, Loader=yaml.BaseLoader)
# Load source and target utterances
for _lang in ['en', lang]:
with open(op.join(txt_root, f'{split}.{_lang}')) as f:
for _lang in ["en", lang]:
with open(op.join(txt_root, f"{split}.{_lang}")) as f:
utterances = [r.strip() for r in f]
assert len(segments) == len(utterances)
for i, u in enumerate(utterances):
segments[i][_lang] = u
# Gather info
self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x['wav']):
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = op.join(wav_root, wav_filename)
sample_rate = torchaudio.info(wav_path)[0].rate
seg_group = sorted(_seg_group, key=lambda x: x['offset'])
seg_group = sorted(_seg_group, key=lambda x: x["offset"])
for i, segment in enumerate(seg_group):
offset = int(float(segment['offset']) * sample_rate)
n_frames = int(float(segment['duration']) * sample_rate)
_id = f'{op.splitext(wav_filename)[0]}_{i}'
offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment["duration"]) * sample_rate)
_id = f"{op.splitext(wav_filename)[0]}_{i}"
self.data.append(
(wav_path, offset, n_frames, sample_rate, segment['en'],
segment[lang], segment['speaker_id'], _id)
(
wav_path,
offset,
n_frames,
sample_rate,
segment["en"],
segment[lang],
segment["speaker_id"],
_id,
)
)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = \
self.data[n]
waveform, _ = torchaudio.load(wav_path, offset=offset,
num_frames=n_frames)
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
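Note that offset and num_frames here are in samples: the loop above scales the seconds stored in the split YAML by the wav's sample_rate before caching them in self.data.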
def __len__(self) -> int:
@ -88,85 +99,102 @@ class MUSTC(Dataset):
def process(args):
for lang in MUSTC.LANGUAGES:
cur_root = op.join(args.data_root, f'en-{lang}')
cur_root = op.join(args.data_root, f"en-{lang}")
if not op.isdir(cur_root):
print(f'{cur_root} does not exist. Skipped.')
print(f"{cur_root} does not exist. Skipped.")
continue
# Extract features
feature_root = op.join(cur_root, 'fbank80')
feature_root = op.join(cur_root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in MUSTC.SPLITS:
print(f'Fetching split {split}...')
print(f"Fetching split {split}...")
dataset = MUSTC(args.data_root, lang, split)
print('Extracting log mel filter bank features...')
print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(waveform, sample_rate,
op.join(feature_root, f'{utt_id}.npy'))
extract_fbank_features(
waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
)
# Pack features into ZIP
zip_filename = 'fbank80.zip'
zip_filename = "fbank80.zip"
zip_path = op.join(cur_root, zip_filename)
print('ZIPing features...')
print("ZIPing features...")
create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...')
zip_manifest = get_zip_manifest(args.data_root,
f'en-{lang}/{zip_filename}')
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}")
# Generate TSV manifest
print('Generating manifest...')
print("Generating manifest...")
train_text = {task: [] for task in TASKS}
for split in MUSTC.SPLITS:
is_train_split = split.startswith('train')
is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS}
text = {task: [] for task in TASKS}
dataset = MUSTC(args.data_root, lang, split)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest['id'].append(utt_id)
manifest['audio'].append(zip_manifest[utt_id])
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
text['asr'].append(src_utt)
text['st'].append(tgt_utt)
manifest['speaker'].append(speaker_id)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
text["asr"].append(src_utt)
text["st"].append(tgt_utt)
manifest["speaker"].append(speaker_id)
if is_train_split:
for task in TASKS:
train_text[task].extend(text[task])
for task in TASKS:
manifest['tgt_text'] = text[task]
manifest["tgt_text"] = text[task]
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, op.join(cur_root, f'{split}_{task}.tsv'))
save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv"))
# Generate vocab
for task in TASKS:
vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
if task == 'st':
if task == "st":
vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
vocab_size_str = '' if vocab_type == 'char' else str(vocab_size)
spm_filename_prefix = f'spm_{vocab_type}{vocab_size_str}_{task}'
with NamedTemporaryFile(mode='w') as f:
vocab_size_str = "" if vocab_type == "char" else str(vocab_size)
spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text[task]:
f.write(t + '\n')
gen_vocab(f.name, op.join(cur_root, spm_filename_prefix),
vocab_type, vocab_size)
f.write(t + "\n")
gen_vocab(
f.name,
op.join(cur_root, spm_filename_prefix),
vocab_type,
vocab_size,
)
# Generate config YAML
gen_config_yaml(cur_root, spm_filename_prefix + '.model',
yaml_filename=f'config_{task}.yaml',
specaugment_policy='lb')
gen_config_yaml(
cur_root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{task}.yaml",
specaugment_policy="lb",
)
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data-root', '-d', required=True, type=str)
parser.add_argument('--asr-vocab-type', default='unigram', required=True,
type=str, choices=['bpe', 'unigram', 'char']),
parser.add_argument('--st-vocab-type', default='unigram', required=True,
type=str, choices=['bpe', 'unigram', 'char']),
parser.add_argument('--asr-vocab-size', default=5000, type=int)
parser.add_argument('--st-vocab-size', default=8000, type=int)
parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument(
"--asr-vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
)
parser.add_argument(
"--st-vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
)
parser.add_argument("--asr-vocab-size", default=5000, type=int)
parser.add_argument("--st-vocab-size", default=8000, type=int)
args = parser.parse_args()
process(args)
if __name__ == '__main__':
if __name__ == "__main__":
main()
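
As an aside, the n_frames value written into the manifest above follows the usual 25 ms window / 10 ms hop filter-bank convention; a minimal sketch of that arithmetic (the helper name is ours, not from this file):

def num_fbank_frames(duration_ms: float, window_ms: int = 25, hop_ms: int = 10) -> int:
    # One frame for the first window, plus one per hop thereafter,
    # matching int(1 + (duration_ms - 25) / 10) in the manifest loop above.
    return int(1 + (duration_ms - window_ms) / hop_ms)

assert num_fbank_frames(1000) == 98  # a 1-second clip -> 98 frames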

View File

@ -12,9 +12,9 @@ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
"""
import argparse
from itertools import chain
import sys
import random
import sys
from itertools import chain
import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
@ -22,17 +22,21 @@ from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
def main():
parser = argparse.ArgumentParser(sys.argv[0])
parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
help='path to system output')
parser.add_argument('--ref', default='', metavar='FILE',
help='path to references')
parser.add_argument('--output', default='', metavar='FILE',
help='print outputs into a pretty format')
parser.add_argument(
"--sys", nargs="*", default="", metavar="FILE", help="path to system output"
)
parser.add_argument("--ref", default="", metavar="FILE", help="path to references")
parser.add_argument(
"--output",
default="",
metavar="FILE",
help="print outputs into a pretty format",
)
args = parser.parse_args()
if args.sys:
src, tgt, hypos, log_probs = load_sys(args.sys)
print('pairwise BLEU: %.2f' % pairwise(hypos))
print("pairwise BLEU: %.2f" % pairwise(hypos))
if args.output:
merge(src, tgt, hypos, log_probs, args.output)
@ -58,18 +62,18 @@ def load_sys(paths):
# S: source
# T: target
# D: detokenized system output
if line.startswith(('S-', 'T-', 'D-')):
i = int(line[line.find('-')+1:line.find('\t')])
if line.startswith('S-'):
src[i] = line.split('\t')[1]
if line.startswith('T-'):
tgt[i] = line.split('\t')[1]
if line.startswith('D-'):
if line.startswith(("S-", "T-", "D-")):
i = int(line[line.find("-") + 1 : line.find("\t")])
if line.startswith("S-"):
src[i] = line.split("\t")[1]
if line.startswith("T-"):
tgt[i] = line.split("\t")[1]
if line.startswith("D-"):
if i not in hypos:
hypos[i] = []
log_probs[i] = []
hypos[i].append(line.split('\t')[2])
log_probs[i].append(float(line.split('\t')[1]))
hypos[i].append(line.split("\t")[2])
log_probs[i].append(float(line.split("\t")[1]))
return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
@ -79,34 +83,34 @@ def load_ref(path):
src, tgt, refs = [], [], []
i = 0
while i < len(lines):
if lines[i].startswith('S-'):
src.append(lines[i].split('\t')[1].rstrip())
if lines[i].startswith("S-"):
src.append(lines[i].split("\t")[1].rstrip())
i += 1
elif lines[i].startswith('T-'):
tgt.append(lines[i].split('\t')[1].rstrip())
elif lines[i].startswith("T-"):
tgt.append(lines[i].split("\t")[1].rstrip())
i += 1
else:
a = []
while i < len(lines) and lines[i].startswith('R'):
a.append(lines[i].split('\t')[1].rstrip())
while i < len(lines) and lines[i].startswith("R"):
a.append(lines[i].split("\t")[1].rstrip())
i += 1
refs.append(a)
return src, tgt, refs
def merge(src, tgt, hypos, log_probs, path):
with open(path, 'w') as f:
with open(path, "w") as f:
for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
f.write(s + '\n')
f.write(t + '\n')
f.write('\n')
f.write(s + "\n")
f.write(t + "\n")
f.write("\n")
for h, lp in zip(hs, lps):
f.write('\t%f\t%s\n' % (lp, h.strip()))
f.write('------------------------------------------------------\n')
f.write("\t%f\t%s\n" % (lp, h.strip()))
f.write("------------------------------------------------------\n")
def corpus_bleu(sys_stream, ref_streams):
bleu = _corpus_bleu(sys_stream, ref_streams, tokenize='none')
bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
return bleu.score
@ -116,9 +120,11 @@ def sentence_bleu(hypothesis, reference):
bleu.counts[i] += 1
bleu.totals[i] += 1
bleu = compute_bleu(
bleu.counts, bleu.totals,
bleu.sys_len, bleu.ref_len,
smooth_method='exp',
bleu.counts,
bleu.totals,
bleu.sys_len,
bleu.ref_len,
smooth_method="exp",
)
return bleu.score
@ -150,7 +156,7 @@ def multi_ref(refs, hypos):
best = [k for k in range(len(rs)) if s[k] == s[j]]
a.add(random.choice(best))
ref_cnt += len(a)
print('#refs covered: %.2f' % (ref_cnt / len(refs)))
print("#refs covered: %.2f" % (ref_cnt / len(refs)))
# transpose refs and hypos
refs = list(zip(*refs))
@ -160,33 +166,32 @@ def multi_ref(refs, hypos):
k = len(hypos)
m = len(refs)
flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
duplicated_refs = [
[ref for ref in refs_i for _ in range(k)]
for refs_i in refs
]
duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
loo_bleus = []
for held_out_ref in range(m):
remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref+1:]
remaining_refs = (
duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
)
assert len(remaining_refs) == m - 1
loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
print('average multi-reference BLEU (leave-one-out): %.2f' % np.mean(loo_bleus))
print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))
def intra_ref(refs):
print('ref pairwise BLEU: %.2f' % pairwise(refs))
print("ref pairwise BLEU: %.2f" % pairwise(refs))
refs = list(zip(*refs))
m = len(refs)
concat_h = []
concat_rest = [[] for j in range(m - 1)]
for i, h in enumerate(refs):
rest = refs[:i] + refs[i+1:]
rest = refs[:i] + refs[i + 1 :]
concat_h.append(h)
for j in range(m - 1):
concat_rest[j].extend(rest[j])
concat_h = list(chain.from_iterable(concat_h))
bleu = corpus_bleu(concat_h, concat_rest)
print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
print("multi-reference BLEU (leave-one-out): %.2f" % bleu)
if __name__ == '__main__':
if __name__ == "__main__":
main()
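
For intuition, pairwise BLEU above scores each set of sampled hypotheses against every other set; a rough sketch of that idea (whether pairwise() averages exactly this way is an assumption; pass in the corpus_bleu() wrapper defined in this file):

def pairwise_sketch(hypos, bleu_fn):
    # hypos[j][i]: j-th sampled hypothesis for sentence i;
    # bleu_fn: a corpus-level scorer such as corpus_bleu() above.
    k = len(hypos)
    scores = [
        bleu_fn(hypos[i], [hypos[j]])
        for i in range(k)
        for j in range(k)
        if i != j
    ]
    return sum(scores) / len(scores)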

View File

@ -21,6 +21,6 @@ class LogSumExpMoE(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
posterior, = ctx.saved_tensors
(posterior,) = ctx.saved_tensors
grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
return grad_logp, None, None
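
For context, only the backward appears in this hunk; a forward consistent with it would be a standard logsumexp that stashes the MoE posterior for the custom gradient (a self-contained sketch, not necessarily the file's exact code):

import torch

class LogSumExpMoESketch(torch.autograd.Function):
    # Standard logsumexp forward; the gradient is overridden to use the
    # posterior, matching the backward shown in the hunk above.
    @staticmethod
    def forward(ctx, logp, posterior, dim=-1):
        ctx.save_for_backward(posterior)
        ctx.dim = dim
        return torch.logsumexp(logp, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        (posterior,) = ctx.saved_tensors
        return grad_output.unsqueeze(ctx.dim) * posterior, None, None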

View File

@ -26,15 +26,15 @@ class MeanPoolGatingNetwork(torch.nn.Module):
def forward(self, encoder_out):
if not (
hasattr(encoder_out, 'encoder_out')
and hasattr(encoder_out, 'encoder_padding_mask')
hasattr(encoder_out, "encoder_out")
and hasattr(encoder_out, "encoder_padding_mask")
and encoder_out.encoder_out.size(2) == self.embed_dim
):
raise ValueError('Unexpected format for encoder_out')
raise ValueError("Unexpected format for encoder_out")
# mean pooling over time
encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0
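
The zeroing of padded positions above sets up a masked mean over time; one way to complete it (a sketch, since the rest of the method is not shown in this hunk):

import torch

def masked_mean_pool(encoder_out, encoder_padding_mask):
    # encoder_out: B x T x C with padded steps already zeroed, as above;
    # encoder_padding_mask: B x T bool, True at padding positions.
    if encoder_padding_mask is None:
        return encoder_out.mean(dim=1)
    lengths = (~encoder_padding_mask).sum(dim=1, keepdim=True)
    return encoder_out.sum(dim=1) / lengths.type_as(encoder_out)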

View File

@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
import torch
from fairseq import metrics, utils
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask
@ -13,7 +12,7 @@ from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork
@register_task('translation_moe')
@register_task("translation_moe")
class TranslationMoETask(TranslationTask):
"""
Translation task for Mixture of Experts (MoE) models.
@ -58,19 +57,19 @@ class TranslationMoETask(TranslationTask):
# fmt: on
def __init__(self, args, src_dict, tgt_dict):
if args.method == 'sMoElp':
if args.method == "sMoElp":
# soft MoE with learned prior
self.uniform_prior = False
self.hard_selection = False
elif args.method == 'sMoEup':
elif args.method == "sMoEup":
# soft MoE with uniform prior
self.uniform_prior = True
self.hard_selection = False
elif args.method == 'hMoElp':
elif args.method == "hMoElp":
# hard MoE with learned prior
self.uniform_prior = False
self.hard_selection = True
elif args.method == 'hMoEup':
elif args.method == "hMoEup":
# hard MoE with uniform prior
self.uniform_prior = True
self.hard_selection = True
@ -78,50 +77,56 @@ class TranslationMoETask(TranslationTask):
# add indicator tokens for each expert
for i in range(args.num_experts):
# add to both dictionaries in case we're sharing embeddings
src_dict.add_symbol('<expert_{}>'.format(i))
tgt_dict.add_symbol('<expert_{}>'.format(i))
src_dict.add_symbol("<expert_{}>".format(i))
tgt_dict.add_symbol("<expert_{}>".format(i))
super().__init__(args, src_dict, tgt_dict)
def build_model(self, args):
from fairseq import models
model = models.build_model(args, self)
if not self.uniform_prior and not hasattr(model, 'gating_network'):
if not self.uniform_prior and not hasattr(model, "gating_network"):
if self.args.mean_pool_gating_network:
if getattr(args, 'mean_pool_gating_network_encoder_dim', None):
if getattr(args, "mean_pool_gating_network_encoder_dim", None):
encoder_dim = args.mean_pool_gating_network_encoder_dim
elif getattr(args, 'encoder_embed_dim', None):
elif getattr(args, "encoder_embed_dim", None):
# assume that encoder_embed_dim is the encoder's output dimension
encoder_dim = args.encoder_embed_dim
else:
raise ValueError('Must specify --mean-pool-gating-network-encoder-dim')
raise ValueError(
"Must specify --mean-pool-gating-network-encoder-dim"
)
if getattr(args, 'mean_pool_gating_network_dropout', None):
if getattr(args, "mean_pool_gating_network_dropout", None):
dropout = args.mean_pool_gating_network_dropout
elif getattr(args, 'dropout', None):
elif getattr(args, "dropout", None):
dropout = args.dropout
else:
raise ValueError('Must specify --mean-pool-gating-network-dropout')
raise ValueError("Must specify --mean-pool-gating-network-dropout")
model.gating_network = MeanPoolGatingNetwork(
encoder_dim, args.num_experts, dropout,
encoder_dim,
args.num_experts,
dropout,
)
else:
raise ValueError(
'translation_moe task with learned prior requires the model to '
'have a gating network; try using --mean-pool-gating-network'
"translation_moe task with learned prior requires the model to "
"have a gating network; try using --mean-pool-gating-network"
)
return model
def expert_index(self, i):
return i + self.tgt_dict.index('<expert_0>')
return i + self.tgt_dict.index("<expert_0>")
def _get_loss(self, sample, model, criterion):
assert hasattr(criterion, 'compute_loss'), \
'translation_moe task requires the criterion to implement the compute_loss() method'
assert hasattr(
criterion, "compute_loss"
), "translation_moe task requires the criterion to implement the compute_loss() method"
k = self.args.num_experts
bsz = sample['target'].size(0)
bsz = sample["target"].size(0)
def get_lprob_y(encoder_out, prev_output_tokens_k):
net_output = model.decoder(
@ -134,20 +139,22 @@ class TranslationMoETask(TranslationTask):
def get_lprob_yz(winners=None):
encoder_out = model.encoder(
src_tokens=sample['net_input']['src_tokens'],
src_lengths=sample['net_input']['src_lengths'],
src_tokens=sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
)
if winners is None:
lprob_y = []
for i in range(k):
prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
prev_output_tokens_k = sample["net_input"][
"prev_output_tokens"
].clone()
assert not prev_output_tokens_k.requires_grad
prev_output_tokens_k[:, 0] = self.expert_index(i)
lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
lprob_y = torch.cat(lprob_y, dim=1) # -> B x K
else:
prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
prev_output_tokens_k[:, 0] = self.expert_index(winners)
lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k) # -> B
@ -177,17 +184,21 @@ class TranslationMoETask(TranslationTask):
loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
loss = loss.sum()
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
sample_size = (
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
)
logging_output = {
'loss': utils.item(loss.data),
'ntokens': sample['ntokens'],
'nsentences': bsz,
'sample_size': sample_size,
'posterior': prob_z_xy.float().sum(dim=0).cpu(),
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": bsz,
"sample_size": sample_size,
"posterior": prob_z_xy.float().sum(dim=0).cpu(),
}
return loss, sample_size, logging_output
def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train()
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
if ignore_grad:
@ -201,7 +212,15 @@ class TranslationMoETask(TranslationTask):
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
return loss, sample_size, logging_output
def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None, constraints=None):
def inference_step(
self,
generator,
models,
sample,
prefix_tokens=None,
expert=None,
constraints=None,
):
expert = expert or self.args.gen_expert
with torch.no_grad():
return generator.generate(
@ -215,6 +234,6 @@ class TranslationMoETask(TranslationTask):
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
metrics.log_scalar(
'posterior',
sum(log['posterior'] for log in logging_outputs if 'posterior' in log)
"posterior",
sum(log["posterior"] for log in logging_outputs if "posterior" in log),
)
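
Since the <expert_i> symbols are appended to the dictionary in order, expert_index(i) reduces to a fixed offset lookup; a toy illustration (standalone, assuming fairseq is importable):

from fairseq.data import Dictionary

d = Dictionary()
for i in range(4):
    d.add_symbol("<expert_{}>".format(i))
base = d.index("<expert_0>")
# consecutive experts map to consecutive ids, as expert_index() relies on
assert all(d.index("<expert_{}>".format(i)) == base + i for i in range(4))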

View File

@ -4,37 +4,38 @@
# LICENSE file in the root directory of this source tree.
import argparse
import numpy as np
import sys
import numpy as np
aggregate_funcs = {
'std': np.std,
'var': np.var,
'median': np.median,
'mean': np.mean,
'min': np.min,
'max': np.max,
"std": np.std,
"var": np.var,
"median": np.median,
"mean": np.mean,
"min": np.min,
"max": np.max,
}
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_file', required=True, type=str)
parser.add_argument('-n', '--repeat_times', required=True, type=int)
parser.add_argument('-o', '--output_file', required=False)
parser.add_argument('-f', '--func', required=False, default='mean')
parser.add_argument("-i", "--input_file", required=True, type=str)
parser.add_argument("-n", "--repeat_times", required=True, type=int)
parser.add_argument("-o", "--output_file", required=False)
parser.add_argument("-f", "--func", required=False, default="mean")
args = parser.parse_args()
stream = open(args.output_file, 'w') if args.output_file else sys.stdout
stream = open(args.output_file, "w") if args.output_file else sys.stdout
segment_scores = []
for line in open(args.input_file):
segment_scores.append(float(line.strip()))
if len(segment_scores) == args.repeat_times:
stream.write('{}\n'.format(aggregate_funcs[args.func](segment_scores)))
stream.write("{}\n".format(aggregate_funcs[args.func](segment_scores)))
segment_scores = []
if __name__ == '__main__':
if __name__ == "__main__":
main()
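
For example, with -n 3 the script above collapses every three consecutive scores into one aggregate; an in-memory equivalent:

import numpy as np

scores = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
n = 3
print([float(np.mean(scores[i : i + n])) for i in range(0, len(scores), n)])
# -> [2.0, 5.0]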

View File

@ -4,14 +4,13 @@
# LICENSE file in the root directory of this source tree.
import argparse
import os
import sys
import subprocess
import tempfile
import math
from itertools import combinations
import os
import subprocess
import sys
import tempfile
from collections import defaultdict
from itertools import combinations
def read_translations(path, n_repeats):
@ -19,7 +18,7 @@ def read_translations(path, n_repeats):
segment_translations = []
translations = defaultdict(list)
for line in open(path):
segment_translations.append(' '.join(line.split()))
segment_translations.append(" ".join(line.split()))
if len(segment_translations) == n_repeats:
translations[segment_counter] = segment_translations
segment_translations = []
@ -30,42 +29,55 @@ def read_translations(path, n_repeats):
def generate_input(translations, n_repeats):
_, ref_path = tempfile.mkstemp()
_, mt_path = tempfile.mkstemp()
ref_fh = open(ref_path, 'w')
mt_fh = open(mt_path, 'w')
ref_fh = open(ref_path, "w")
mt_fh = open(mt_path, "w")
for segid in sorted(translations.keys()):
assert len(translations[segid]) == n_repeats
indexes = combinations(range(n_repeats), 2)
for idx1, idx2 in indexes:
mt_fh.write(translations[segid][idx1].strip() + '\n')
ref_fh.write(translations[segid][idx2].strip() + '\n')
sys.stderr.write('\nSaved translations to %s and %s' % (ref_path, mt_path))
mt_fh.write(translations[segid][idx1].strip() + "\n")
ref_fh.write(translations[segid][idx2].strip() + "\n")
sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
return ref_path, mt_path
def run_meteor(ref_path, mt_path, metric_path, lang='en'):
def run_meteor(ref_path, mt_path, metric_path, lang="en"):
_, out_path = tempfile.mkstemp()
subprocess.call([
'java', '-Xmx2G', '-jar', metric_path, mt_path, ref_path,
'-p', '0.5 0.2 0.6 0.75', # default parameters, only changed alpha to give equal weight to P and R
'-norm',
'-l', lang], stdout=open(out_path, 'w'))
subprocess.call(
[
"java",
"-Xmx2G",
"-jar",
metric_path,
mt_path,
ref_path,
"-p",
"0.5 0.2 0.6 0.75", # default parameters, only changed alpha to give equal weight to P and R
"-norm",
"-l",
lang,
],
stdout=open(out_path, "w"),
)
os.remove(ref_path)
os.remove(mt_path)
sys.stderr.write('\nSaved Meteor output to %s' % out_path)
sys.stderr.write("\nSaved Meteor output to %s" % out_path)
return out_path
def read_output(meteor_output_path, n_repeats):
n_combinations = math.factorial(n_repeats)/(math.factorial(2) * math.factorial(n_repeats - 2))
n_combinations = math.factorial(n_repeats) / (
math.factorial(2) * math.factorial(n_repeats - 2)
)
raw_scores = []
average_scores = []
for line in open(meteor_output_path):
if not line.startswith('Segment '):
if not line.startswith("Segment "):
continue
score = float(line.strip().split('\t')[1])
score = float(line.strip().split("\t")[1])
raw_scores.append(score)
if len(raw_scores) == n_combinations:
average_scores.append(sum(raw_scores)/n_combinations)
average_scores.append(sum(raw_scores) / n_combinations)
raw_scores = []
os.remove(meteor_output_path)
return average_scores
@ -73,25 +85,25 @@ def read_output(meteor_output_path, n_repeats):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input')
parser.add_argument('-n', '--repeat_times', type=int)
parser.add_argument('-m', '--meteor')
parser.add_argument('-o', '--output')
parser.add_argument("-i", "--input")
parser.add_argument("-n", "--repeat_times", type=int)
parser.add_argument("-m", "--meteor")
parser.add_argument("-o", "--output")
args = parser.parse_args()
translations = read_translations(args.input, args.repeat_times)
sys.stderr.write('\nGenerating input for Meteor...')
sys.stderr.write("\nGenerating input for Meteor...")
ref_path, mt_path = generate_input(translations, args.repeat_times)
sys.stderr.write('\nRunning Meteor...')
sys.stderr.write("\nRunning Meteor...")
out_path = run_meteor(ref_path, mt_path, args.meteor)
sys.stderr.write('\nReading output...')
sys.stderr.write("\nReading output...")
scores = read_output(out_path, args.repeat_times)
sys.stderr.write('\nWriting results...')
with open(args.output, 'w') as o:
sys.stderr.write("\nWriting results...")
with open(args.output, "w") as o:
for scr in scores:
o.write('{}\n'.format(scr))
o.write("{}\n".format(scr))
o.close()
if __name__ == '__main__':
if __name__ == "__main__":
main()
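
A quick sanity check of the combination count used in read_output above: it is C(n_repeats, 2), the number of hypothesis pairs Meteor scores per segment (integer division used here to keep the count an int):

import math

n_repeats = 5
n_combinations = math.factorial(n_repeats) // (
    math.factorial(2) * math.factorial(n_repeats - 2)
)
assert n_combinations == 10  # C(5, 2) pairs per segment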

View File

@ -8,21 +8,21 @@ import sys
def _normalize_spaces(line):
return ' '.join(line.split())
return " ".join(line.split())
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_file', required=True, type=str)
parser.add_argument('-n', '--repeat_times', required=True, type=int)
parser.add_argument('-o', '--output_file', required=False, type=str)
parser.add_argument("-i", "--input_file", required=True, type=str)
parser.add_argument("-n", "--repeat_times", required=True, type=int)
parser.add_argument("-o", "--output_file", required=False, type=str)
args = parser.parse_args()
stream = open(args.output_file, 'w') if args.output_file else sys.stdout
stream = open(args.output_file, "w") if args.output_file else sys.stdout
for line in open(args.input_file):
for _ in range(args.repeat_times):
stream.write(_normalize_spaces(line) + '\n')
stream.write(_normalize_spaces(line) + "\n")
if __name__ == '__main__':
if __name__ == "__main__":
main()
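
An in-memory equivalent of the script above, for reference:

def repeat_lines(lines, n):
    # Normalize whitespace, then emit each line n times, as the script does.
    return [" ".join(line.split()) for line in lines for _ in range(n)]

assert repeat_lines(["a  b "], 2) == ["a b", "a b"]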

View File

@ -8,30 +8,31 @@
Helper script to pre-compute embeddings for a wav2letter++ dataset
"""
import argparse
import glob
import os
import os.path as osp
import pprint
import glob, os, argparse
import soundfile as sf
import torch
import tqdm
from fairseq.models.wav2vec.wav2vec import Wav2VecModel
from torch import nn
from torch.utils.data import DataLoader
try:
import tqdm
except:
print("Install tqdm to use --log-format=tqdm")
from fairseq.models.wav2vec.wav2vec import Wav2VecModel
import tqdm
import soundfile as sf
from torch.utils.data import DataLoader
import os.path as osp
class FilesDataset:
def __init__(self, files, labels):
self.files = files
if labels and osp.exists(labels):
with open(labels, 'r') as lbl_f:
with open(labels, "r") as lbl_f:
self.labels = [line.rstrip() for line in lbl_f]
else:
self.labels = labels
@ -50,7 +51,7 @@ class FilesDataset:
if self.labels:
if isinstance(self.labels, str):
lbl_file = osp.splitext(fname)[0] + "." + self.labels
with open(lbl_file, 'r') as lblf:
with open(lbl_file, "r") as lblf:
lbls = lblf.readline()
assert lbls is not None
else:
@ -116,24 +117,24 @@ class DatasetWriter:
assert len(files) > 0
if self.args.shard is not None:
files = files[self.args.shard::self.args.num_shards]
files = files[self.args.shard :: self.args.num_shards]
lbls = []
with open(self.data_file(split), 'w') as srcf:
with open(self.data_file(split), "w") as srcf:
for line, lbl in self.iterate(files):
print(line, file=srcf)
if self.args.labels:
lbls.append(lbl + '\n')
lbls.append(lbl + "\n")
if self.args.labels:
assert all(a is not None for a in lbls)
with open(self.lbl_file(split), 'w') as lblf:
with open(self.lbl_file(split), "w") as lblf:
lblf.writelines(lbls)
def iterate(self, files):
data = self.load_data(files)
for samples in tqdm.tqdm(data, total=len(files)//32):
for samples in tqdm.tqdm(data, total=len(files) // 32):
for wav, lbl in samples:
x = wav.unsqueeze(0).float().cuda()
@ -162,7 +163,6 @@ class DatasetWriter:
idx = torch.cat(result, dim=0)
yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl
def lbl_file(self, name):
shard_part = "" if self.args.shard is None else f".{self.args.shard}"
return osp.join(self.output_dir, f"{name}.lbl{shard_part}")
@ -230,7 +230,9 @@ class DatasetWriter:
self.process_splits()
if hasattr(self.model.feature_extractor, "vars") and (self.args.shard is None or self.args.shard == 0):
if hasattr(self.model.feature_extractor, "vars") and (
self.args.shard is None or self.args.shard == 0
):
vars = (
self.model.feature_extractor.vars.view(
self.model.feature_extractor.banks,
@ -248,4 +250,4 @@ if __name__ == "__main__":
write_data = DatasetWriter()
write_data()
print("Done.")
print("Done.")

View File

@ -14,13 +14,12 @@ import os
from shutil import copy
import h5py
import soundfile as sf
import numpy as np
import soundfile as sf
import torch
from torch import nn
import tqdm
from fairseq.models.wav2vec.wav2vec import Wav2VecModel
from torch import nn
def read_audio(fname):
@ -33,7 +32,6 @@ def read_audio(fname):
class PretrainedWav2VecModel(nn.Module):
def __init__(self, fname):
super().__init__()
@ -55,32 +53,33 @@ class PretrainedWav2VecModel(nn.Module):
class EmbeddingWriterConfig(argparse.ArgumentParser):
def __init__(self):
super().__init__("Pre-compute embeddings for wav2letter++ datasets")
kwargs = {"action": "store", "type": str, "required": True}
self.add_argument("--input", "-i",
help="Input Directory", **kwargs)
self.add_argument("--output", "-o",
help="Output Directory", **kwargs)
self.add_argument("--model",
help="Path to model checkpoint", **kwargs)
self.add_argument("--split",
help="Dataset Splits", nargs='+', **kwargs)
self.add_argument("--ext", default="wav", required=False,
help="Audio file extension")
self.add_argument("--input", "-i", help="Input Directory", **kwargs)
self.add_argument("--output", "-o", help="Output Directory", **kwargs)
self.add_argument("--model", help="Path to model checkpoint", **kwargs)
self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs)
self.add_argument(
"--ext", default="wav", required=False, help="Audio file extension"
)
self.add_argument("--no-copy-labels", action="store_true",
help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.")
self.add_argument("--use-feat", action="store_true",
help="Use the feature vector ('z') instead of context vector ('c') for features")
self.add_argument("--gpu",
help="GPU to use", default=0, type=int)
self.add_argument(
"--no-copy-labels",
action="store_true",
help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.",
)
self.add_argument(
"--use-feat",
action="store_true",
help="Use the feature vector ('z') instead of context vector ('c') for features",
)
self.add_argument("--gpu", help="GPU to use", default=0, type=int)
class Prediction():
class Prediction:
""" Lightweight wrapper around a fairspeech embedding model """
def __init__(self, fname, gpu=0):
@ -95,7 +94,7 @@ class Prediction():
return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy()
class H5Writer():
class H5Writer:
""" Write features as hdf5 file in wav2letter++ compatible format """
def __init__(self, fname):
@ -112,7 +111,7 @@ class H5Writer():
class EmbeddingDatasetWriter(object):
""" Given a model and a wav2letter++ dataset, pre-compute and store embeddings
"""Given a model and a wav2letter++ dataset, pre-compute and store embeddings
Args:
input_root, str :
@ -123,13 +122,17 @@ class EmbeddingDatasetWriter(object):
Dataset split
"""
def __init__(self, input_root, output_root, split,
model_fname,
extension="wav",
gpu=0,
verbose=False,
use_feat=False,
):
def __init__(
self,
input_root,
output_root,
split,
model_fname,
extension="wav",
gpu=0,
verbose=False,
use_feat=False,
):
assert os.path.exists(model_fname)
@ -143,8 +146,9 @@ class EmbeddingDatasetWriter(object):
self.extension = extension
self.use_feat = use_feat
assert os.path.exists(self.input_path), \
"Input path '{}' does not exist".format(self.input_path)
assert os.path.exists(self.input_path), "Input path '{}' does not exist".format(
self.input_path
)
def _progress(self, iterable, **kwargs):
if self.verbose:
@ -176,7 +180,11 @@ class EmbeddingDatasetWriter(object):
def copy_labels(self):
self.require_output_path()
labels = list(filter(lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))))
labels = list(
filter(
lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))
)
)
for fname in tqdm.tqdm(labels):
copy(fname, self.output_path)
@ -191,10 +199,16 @@ class EmbeddingDatasetWriter(object):
paths = self.input_fnames
fnames_context = map(lambda x: os.path.join(self.output_path, x.replace("." + self.extension, ".h5context")), \
map(os.path.basename, paths))
fnames_context = map(
lambda x: os.path.join(
self.output_path, x.replace("." + self.extension, ".h5context")
),
map(os.path.basename, paths),
)
for name, target_fname in self._progress(zip(paths, fnames_context), total=len(self)):
for name, target_fname in self._progress(
zip(paths, fnames_context), total=len(self)
):
wav, sr = read_audio(name)
z, c = self.model(wav)
feat = z if self.use_feat else c
@ -204,7 +218,8 @@ class EmbeddingDatasetWriter(object):
def __repr__(self):
return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format(
n_files=len(self), **self.__dict__)
n_files=len(self), **self.__dict__
)
if __name__ == "__main__":

View File

@ -10,32 +10,50 @@ Data pre-processing: build vocabularies and binarize training data.
import argparse
import glob
import os
import soundfile
import random
import soundfile
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index')
parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D',
help='percentage of data to use as validation set (between 0 and 1)')
parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory')
parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for')
parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed')
parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG',
help='if set, path must contain this substring for a file to be included in the manifest')
parser.add_argument(
"root", metavar="DIR", help="root directory containing flac files to index"
)
parser.add_argument(
"--valid-percent",
default=0.01,
type=float,
metavar="D",
help="percentage of data to use as validation set (between 0 and 1)",
)
parser.add_argument(
"--dest", default=".", type=str, metavar="DIR", help="output directory"
)
parser.add_argument(
"--ext", default="flac", type=str, metavar="EXT", help="extension to look for"
)
parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed")
parser.add_argument(
"--path-must-contain",
default=None,
type=str,
metavar="FRAG",
help="if set, path must contain this substring for a file to be included in the manifest",
)
return parser
def main(args):
assert args.valid_percent >= 0 and args.valid_percent <= 1.
assert args.valid_percent >= 0 and args.valid_percent <= 1.0
dir_path = os.path.realpath(args.root)
search_path = os.path.join(dir_path, '**/*.' + args.ext)
search_path = os.path.join(dir_path, "**/*." + args.ext)
rand = random.Random(args.seed)
with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open(
os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f:
with open(os.path.join(args.dest, "train.tsv"), "w") as train_f, open(
os.path.join(args.dest, "valid.tsv"), "w"
) as valid_f:
print(dir_path, file=train_f)
print(dir_path, file=valid_f)
@ -47,10 +65,12 @@ def main(args):
frames = soundfile.info(fname).frames
dest = train_f if rand.random() > args.valid_percent else valid_f
print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest)
print(
"{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
main(args)
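
To make the split rule above concrete, a toy replay of the train/valid decision (file names and frame counts are fake):

import random

rand = random.Random(42)
files = [("a.flac", 16000), ("b.flac", 32000), ("c.flac", 24000)]
valid_percent = 0.01
for relpath, frames in files:
    # Same rule as the script: roughly 1% of files land in valid.tsv
    dest = "train.tsv" if rand.random() > valid_percent else "valid.tsv"
    print("{}: {}\t{}".format(dest, relpath, frames))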

View File

@ -4,16 +4,17 @@
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
__all__ = ['pdb']
__version__ = '1.0.0a0'
__all__ = ["pdb"]
__version__ = "1.0.0a0"
import sys
# backwards compatibility to support `from fairseq.meters import AverageMeter`
from fairseq.logging import meters, metrics, progress_bar # noqa
sys.modules['fairseq.meters'] = meters
sys.modules['fairseq.metrics'] = metrics
sys.modules['fairseq.progress_bar'] = progress_bar
sys.modules["fairseq.meters"] = meters
sys.modules["fairseq.metrics"] = metrics
sys.modules["fairseq.progress_bar"] = progress_bar
import fairseq.criterions # noqa
import fairseq.models # noqa
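
The sys.modules aliases above keep the legacy import paths importable; e.g. (assuming fairseq is installed):

from fairseq.logging.meters import AverageMeter as NewPath
from fairseq.meters import AverageMeter as OldPath  # served by the alias

assert OldPath is NewPath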

View File

@ -4,9 +4,4 @@
# LICENSE file in the root directory of this source tree.
# import models/tasks to register them
from . import ( # noqa
dummy_lm,
dummy_masked_lm,
dummy_model,
dummy_mt,
)
from . import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa

View File

@ -7,25 +7,27 @@ import logging
import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task('dummy_lm')
@register_task("dummy_lm")
class DummyLMTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('--dict-size', default=49996, type=int)
parser.add_argument('--dataset-size', default=100000, type=int)
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments '
'per sample for BERT dataset')
parser.add_argument("--dict-size", default=49996, type=int)
parser.add_argument("--dataset-size", default=100000, type=int)
parser.add_argument(
"--tokens-per-sample",
default=512,
type=int,
help="max number of total tokens over all segments "
"per sample for BERT dataset",
)
def __init__(self, args, dictionary):
super().__init__(args)
@ -44,8 +46,8 @@ class DummyLMTask(LegacyFairseqTask):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol('word{}'.format(i))
logger.info('dictionary: {} types'.format(len(dictionary)))
dictionary.add_symbol("word{}".format(i))
logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
@ -59,16 +61,16 @@ class DummyLMTask(LegacyFairseqTask):
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
self.datasets[split] = DummyDataset(
{
'id': 1,
'net_input': {
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
'src_lengths': torch.full(
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.args.tokens_per_sample, dtype=torch.long
),
},
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
'nsentences': bsz,
'ntokens': bsz * self.args.tokens_per_sample,
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
"nsentences": bsz,
"ntokens": bsz * self.args.tokens_per_sample,
},
num_items=self.args.dataset_size,
item_size=self.args.tokens_per_sample,
@ -84,7 +86,6 @@ class DummyLMTask(LegacyFairseqTask):
class DummyDataset(FairseqDataset):
def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch

View File

@ -7,32 +7,34 @@ import logging
import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task('dummy_masked_lm')
@register_task("dummy_masked_lm")
class DummyMaskedLMTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('--dict-size', default=49995, type=int)
parser.add_argument('--dataset-size', default=100000, type=int)
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments '
'per sample for BERT dataset')
parser.add_argument("--dict-size", default=49995, type=int)
parser.add_argument("--dataset-size", default=100000, type=int)
parser.add_argument(
"--tokens-per-sample",
default=512,
type=int,
help="max number of total tokens over all segments "
"per sample for BERT dataset",
)
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
# add mask token
self.mask_idx = dictionary.add_symbol('<mask>')
self.mask_idx = dictionary.add_symbol("<mask>")
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
mask_idx = 0
@ -52,8 +54,8 @@ class DummyMaskedLMTask(LegacyFairseqTask):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol('word{}'.format(i))
logger.info('dictionary: {} types'.format(len(dictionary)))
dictionary.add_symbol("word{}".format(i))
logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
@ -67,16 +69,16 @@ class DummyMaskedLMTask(LegacyFairseqTask):
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
self.datasets[split] = DummyDataset(
{
'id': 1,
'net_input': {
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
'src_lengths': torch.full(
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.args.tokens_per_sample, dtype=torch.long
),
},
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
'nsentences': bsz,
'ntokens': bsz * self.args.tokens_per_sample,
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
"nsentences": bsz,
"ntokens": bsz * self.args.tokens_per_sample,
},
num_items=self.args.dataset_size,
item_size=self.args.tokens_per_sample,
@ -92,7 +94,6 @@ class DummyMaskedLMTask(LegacyFairseqTask):
class DummyDataset(FairseqDataset):
def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch

View File

@ -5,7 +5,6 @@
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import Dictionary
from fairseq.models import (
FairseqDecoder,
@ -15,17 +14,16 @@ from fairseq.models import (
)
@register_model('dummy_model')
@register_model("dummy_model")
class DummyModel(FairseqLanguageModel):
def __init__(self, args, encoder):
super().__init__(encoder)
self.args = args
@staticmethod
def add_args(parser):
parser.add_argument('--num-layers', type=int, default=24)
parser.add_argument('--embed-dim', type=int, default=1024)
parser.add_argument("--num-layers", type=int, default=24)
parser.add_argument("--embed-dim", type=int, default=1024)
@classmethod
def build_model(cls, args, task):
@ -41,32 +39,35 @@ class DummyModel(FairseqLanguageModel):
class DummyEncoder(FairseqDecoder):
def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
super().__init__(Dictionary())
self.embed = nn.Embedding(
num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
)
self.layers_a = nn.ModuleList([
nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 3*embed_dim), # q, k, v input projection
nn.Linear(3*embed_dim, embed_dim), # skip self-attention
nn.Linear(embed_dim, embed_dim), # output projection
nn.Dropout(),
)
for i in range(num_layers)
])
self.layers_b = nn.ModuleList([
nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 4*embed_dim), # FFN
nn.ReLU(),
nn.Linear(4*embed_dim, embed_dim), # FFN
nn.Dropout(0.1),
)
for i in range(num_layers)
])
self.layers_a = nn.ModuleList(
[
nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 3 * embed_dim), # q, k, v input projection
nn.Linear(3 * embed_dim, embed_dim), # skip self-attention
nn.Linear(embed_dim, embed_dim), # output projection
nn.Dropout(),
)
for i in range(num_layers)
]
)
self.layers_b = nn.ModuleList(
[
nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 4 * embed_dim), # FFN
nn.ReLU(),
nn.Linear(4 * embed_dim, embed_dim), # FFN
nn.Dropout(0.1),
)
for i in range(num_layers)
]
)
self.out_proj = nn.Linear(embed_dim, num_embed)
def forward(self, tokens, masked_tokens=None):
@ -90,6 +91,6 @@ class DummyEncoder(FairseqDecoder):
return F.softmax(logits, dim=-1)
@register_model_architecture('dummy_model', 'dummy_model')
@register_model_architecture("dummy_model", "dummy_model")
def base_architecture(args):
pass

View File

@ -7,24 +7,22 @@ import logging
import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task('dummy_mt')
@register_task("dummy_mt")
class DummyMTTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('--dict-size', default=49996, type=int)
parser.add_argument('--dataset-size', default=100000, type=int)
parser.add_argument('--src-len', default=30, type=int)
parser.add_argument('--tgt-len', default=30, type=int)
parser.add_argument("--dict-size", default=49996, type=int)
parser.add_argument("--dataset-size", default=100000, type=int)
parser.add_argument("--src-len", default=30, type=int)
parser.add_argument("--tgt-len", default=30, type=int)
def __init__(self, args, dictionary):
super().__init__(args)
@ -41,8 +39,8 @@ class DummyMTTask(LegacyFairseqTask):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol('word{}'.format(i))
logger.info('dictionary: {} types'.format(len(dictionary)))
dictionary.add_symbol("word{}".format(i))
logger.info("dictionary: {} types".format(len(dictionary)))
args.max_source_positions = args.src_len + dictionary.pad() + 2
args.max_target_positions = args.tgt_len + dictionary.pad() + 2
@ -62,17 +60,17 @@ class DummyMTTask(LegacyFairseqTask):
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
self.datasets[split] = DummyDataset(
{
'id': 1,
'net_input': {
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
'src_lengths': torch.full(
(bsz, ), self.args.src_len, dtype=torch.long
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.args.src_len, dtype=torch.long
),
'prev_output_tokens': tgt.clone(),
"prev_output_tokens": tgt.clone(),
},
'target': tgt,
'nsentences': bsz,
'ntokens': bsz * self.args.tgt_len,
"target": tgt,
"nsentences": bsz,
"ntokens": bsz * self.args.tgt_len,
},
num_items=self.args.dataset_size,
item_size=item_size,
@ -88,7 +86,6 @@ class DummyMTTask(LegacyFairseqTask):
class DummyDataset(FairseqDataset):
def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch

View File

@ -6,9 +6,10 @@
import os
from collections import Counter
from fairseq.tokenizer import tokenize_line
import torch
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line
def safe_readline(f):
pos = f.tell()

View File

@ -67,12 +67,14 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
or is_better(val_loss, save_checkpoint.best)
)
if val_loss is not None and args.keep_best_checkpoints > 0:
checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
args.best_checkpoint_metric, val_loss)] = (
not hasattr(save_checkpoint, "best")
or is_better(val_loss, save_checkpoint.best)
checkpoint_conds[
"checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss)
] = not hasattr(save_checkpoint, "best") or is_better(
val_loss, save_checkpoint.best
)
checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = not args.no_last_checkpoints
checkpoint_conds[
"checkpoint_last{}.pt".format(suffix)
] = not args.no_last_checkpoints
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
if hasattr(save_checkpoint, "best"):
@ -112,10 +114,14 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
if args.keep_best_checkpoints > 0:
# only keep the best N checkpoints according to validation metric
checkpoints = checkpoint_paths(
args.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(args.best_checkpoint_metric))
args.save_dir,
pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
args.best_checkpoint_metric
),
)
if not args.maximize_best_checkpoint_metric:
checkpoints = checkpoints[::-1]
for old_chk in checkpoints[args.keep_best_checkpoints:]:
for old_chk in checkpoints[args.keep_best_checkpoints :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
@ -133,16 +139,23 @@ def load_checkpoint(args, trainer, **passthrough_args):
reset_meters = args.reset_meters
reset_dataloader = args.reset_dataloader
if getattr(args, 'finetune_from_model', None) is not None \
and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader):
raise ValueError("--finetune-from-model can not be set together with either --reset-optimizer"
" or reset_lr_scheduler or reset_meters or reset_dataloader")
if getattr(args, "finetune_from_model", None) is not None and (
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
):
raise ValueError(
"--finetune-from-model can not be set together with either --reset-optimizer"
" or reset_lr_scheduler or reset_meters or reset_dataloader"
)
suffix = getattr(args, "checkpoint_suffix", "")
if args.restore_file == "checkpoint_last.pt": # default value of restore_file is 'checkpoint_last.pt'
checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix))
if (
args.restore_file == "checkpoint_last.pt"
): # default value of restore_file is 'checkpoint_last.pt'
checkpoint_path = os.path.join(
args.save_dir, "checkpoint_last{}.pt".format(suffix)
)
first_launch = not PathManager.exists(checkpoint_path)
if getattr(args, 'finetune_from_model', None) is not None and first_launch:
if getattr(args, "finetune_from_model", None) is not None and first_launch:
# if there is no last checkpoint to restore, start the finetune from pretrained model
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
if PathManager.exists(args.finetune_from_model):
@ -151,19 +164,26 @@ def load_checkpoint(args, trainer, **passthrough_args):
reset_lr_scheduler = True
reset_meters = True
reset_dataloader = True
logger.info(f'loading pretrained model from {checkpoint_path}: '
'optimizer, lr scheduler, meters, dataloader will be reset')
logger.info(
f"loading pretrained model from {checkpoint_path}: "
"optimizer, lr scheduler, meters, dataloader will be reset"
)
else:
raise ValueError(f'--finetune-from-model {args.finetune_from_model} does not exist')
raise ValueError(
f"--funetune-from-model {args.finetune_from_model} does not exist"
)
elif getattr(args, "model_parallel_size", 1) > 1:
checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
else:
checkpoint_path = args.restore_file
if args.restore_file != "checkpoint_last.pt" and getattr(args, 'finetune_from_model', None):
if args.restore_file != "checkpoint_last.pt" and getattr(
args, "finetune_from_model", None
):
raise ValueError(
'--finetune-from-model and --restore-file (non-default value) '
'can not be specified together: ' + str(args))
"--finetune-from-model and --restore-file (non-default value) "
"can not be specified together: " + str(args)
)
extra_state = trainer.load_checkpoint(
checkpoint_path,
@ -213,7 +233,9 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
return state
def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1):
def load_model_ensemble(
filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
):
"""Loads an ensemble of models.
Args:
@ -222,18 +244,28 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, s
were used during model training
task (fairseq.tasks.FairseqTask, optional): task to use for loading
"""
assert not (strict and num_shards > 1), \
"Cannot load state dict with strict=True and checkpoint shards > 1"
assert not (
strict and num_shards > 1
), "Cannot load state dict with strict=True and checkpoint shards > 1"
ensemble, args, _task = load_model_ensemble_and_task(
filenames, arg_overrides, task, strict, suffix, num_shards,
filenames,
arg_overrides,
task,
strict,
suffix,
num_shards,
)
return ensemble, args
def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1):
def load_model_ensemble_and_task(
filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
):
from fairseq import tasks
assert not (strict and num_shards > 1), \
"Cannot load state dict with strict=True and checkpoint shards > 1"
assert not (
strict and num_shards > 1
), "Cannot load state dict with strict=True and checkpoint shards > 1"
ensemble = []
for filename in filenames:
orig_filename = filename
@ -533,7 +565,9 @@ def verify_checkpoint_directory(save_dir: str) -> None:
with open(temp_file_path, "w"):
pass
except OSError as e:
logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
logger.warning(
"Unable to access checkpoint save directory: {}".format(save_dir)
)
raise e
else:
os.remove(temp_file_path)
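
A hedged usage sketch of the ensemble loader whose signature was reflowed above (checkpoint path and data directory are hypothetical):

from fairseq import checkpoint_utils

models, args = checkpoint_utils.load_model_ensemble(
    ["checkpoints/checkpoint_best.pt"],          # hypothetical checkpoint
    arg_overrides={"data": "data-bin/example"},  # hypothetical data dir
)
model = models[0].eval()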

Some files were not shown because too many files have changed in this diff.