Support custom Dictionary implementations in 'preprocess.py' (#448)

Summary:
The `preprocess.py` script has been refactored to:

1. Use the `options` module for command-line argument parsing. This gives `preprocess.py` the ability to load custom modules with the `--user-dir` flag (already supported by all other binaries).
2. Move dictionary loading and building code into the Task implementation. This allows custom Dictionary classes to be used during the data generation step (see the sketch below).
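For illustration, a minimal sketch of a user module that exercises the new hooks. The plugin path, task name, and `MyDictionary` class are hypothetical; only the `load_dictionary`/`build_dictionary` overrides reflect the API added in this change:

```python
# my_plugins/my_task.py -- hypothetical user module loaded via --user-dir
from fairseq import tokenizer
from fairseq.data import Dictionary
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask
from fairseq.tokenizer import Tokenizer


class MyDictionary(Dictionary):
    """Example custom Dictionary, e.g. with an extra special symbol."""

    def __init__(self):
        super().__init__()
        self.sep_index = self.add_symbol('<sep>')


@register_task('my_translation')
class MyTranslationTask(TranslationTask):

    @classmethod
    def load_dictionary(cls, filename):
        # called by preprocess.py for --srcdict/--tgtdict and when binarizing
        return MyDictionary.load(filename)

    @classmethod
    def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1,
                         padding_factor=8):
        # called by preprocess.py when no pre-built dictionary is supplied
        d = MyDictionary()
        for filename in filenames:
            Tokenizer.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers)
        d.finalize(threshold=threshold, nwords=nwords,
                   padding_factor=padding_factor)
        return d
```

Running `python preprocess.py --user-dir my_plugins --task my_translation ...` would then build and load `MyDictionary` instances instead of the default `Dictionary`.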
Pull Request resolved: https://github.com/pytorch/fairseq/pull/448

Differential Revision: D13674819

Pulled By: myleott

fbshipit-source-id: b40648a98ed6c08284577e5ec25876e018d8c822
Davide Caroselli 2019-02-01 09:43:06 -08:00 committed by Facebook Github Bot
parent ec6f8ef99a
commit bbb4120b00
6 changed files with 140 additions and 129 deletions

View File: fairseq/options.py

@@ -17,6 +17,12 @@ from fairseq.tasks import TASK_REGISTRY
from fairseq.utils import import_user_module
def get_preprocessing_parser(default_task='translation'):
parser = get_parser('Preprocessing', default_task)
add_preprocess_args(parser)
return parser
def get_training_parser(default_task='translation'):
parser = get_parser('Trainer', default_task)
add_dataset_args(parser, train=True)
@@ -142,7 +148,7 @@ def get_parser(desc, default_task='translation'):
parser.add_argument('--fp16', action='store_true', help='use FP16')
parser.add_argument('--memory-efficient-fp16', action='store_true',
help='use a memory-efficient version of FP16 training; implies --fp16')
parser.add_argument('--fp16-init-scale', default=2**7, type=int,
parser.add_argument('--fp16-init-scale', default=2 ** 7, type=int,
help='default FP16 loss scale')
parser.add_argument('--fp16-scale-window', type=int,
help='number of updates before increasing loss scale')
@@ -159,6 +165,50 @@ def get_parser(desc, default_task='translation'):
return parser
def add_preprocess_args(parser):
group = parser.add_argument_group('Preprocessing')
# fmt: off
group.add_argument("-s", "--source-lang", default=None, metavar="SRC",
help="source language")
group.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
help="target language")
group.add_argument("--trainpref", metavar="FP", default=None,
help="train file prefix")
group.add_argument("--validpref", metavar="FP", default=None,
help="comma separated, valid file prefixes")
group.add_argument("--testpref", metavar="FP", default=None,
help="comma separated, test file prefixes")
group.add_argument("--destdir", metavar="DIR", default="data-bin",
help="destination dir")
group.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
help="map words appearing less than threshold times to unknown")
group.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
help="map words appearing less than threshold times to unknown")
group.add_argument("--tgtdict", metavar="FP",
help="reuse given target dictionary")
group.add_argument("--srcdict", metavar="FP",
help="reuse given source dictionary")
group.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
help="number of target words to retain")
group.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
help="number of source words to retain")
group.add_argument("--alignfile", metavar="ALIGN", default=None,
help="an alignment file (optional)")
group.add_argument("--output-format", metavar="FORMAT", default="binary",
choices=["binary", "raw"],
help="output format (optional)")
group.add_argument("--joined-dictionary", action="store_true",
help="Generate joined dictionary")
group.add_argument("--only-source", action="store_true",
help="Only process the source language")
group.add_argument("--padding-factor", metavar="N", default=8, type=int,
help="Pad dictionary size to be multiple of N")
group.add_argument("--workers", metavar="N", default=1, type=int,
help="number of parallel workers")
# fmt: on
return parser
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading')
# fmt: off

View File: fairseq/tasks/__init__.py

@@ -11,7 +11,6 @@ import os
from .fairseq_task import FairseqTask
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
@@ -73,3 +72,7 @@ for file in os.listdir(os.path.dirname(__file__)):
group_args = parser.add_argument_group('Additional command-line arguments')
TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + '_parser'] = parser
def get_task(name):
return TASK_REGISTRY[name]

View File: fairseq/tasks/fairseq_task.py

@@ -5,9 +5,12 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.data import data_utils, FairseqDataset, iterators
import torch
from fairseq import tokenizer
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
from fairseq.tokenizer import Tokenizer
class FairseqTask(object):
"""
@@ -24,6 +27,35 @@ class FairseqTask(object):
self.args = args
self.datasets = {}
@classmethod
def load_dictionary(cls, filename):
"""Load the dictionary from the filename
Args:
filename (str): the filename
"""
return Dictionary.load(filename)
@classmethod
def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
"""Build the dictionary
Args:
filenames (list): list of filenames
workers (int): number of concurrent workers
threshold (int): defines the minimum word count
nwords (int): defines the total number of words in the final dictionary,
including special symbols
padding_factor (int): can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
d = Dictionary()
for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
@@ -59,9 +91,9 @@ class FairseqTask(object):
return self.datasets[split]
def get_batch_iterator(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1,
seed=1, num_shards=1, shard_id=0, num_workers=0,
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1,
seed=1, num_shards=1, shard_id=0, num_workers=0,
):
"""
Get an iterator that yields batches of data from the given dataset.

View File: fairseq/tasks/translation.py

@@ -109,8 +109,8 @@ class TranslationTask(FairseqTask):
raise Exception('Could not infer language pair, please provide it explicitly')
# load dictionaries
src_dict = Dictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = Dictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
src_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()

View File: preprocess.py

@@ -9,64 +9,19 @@
Data pre-processing: build vocabularies and binarize training data.
"""
import argparse
from collections import Counter
from itertools import zip_longest
import os
import shutil
from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
from fairseq import options, tasks
from fairseq.data import indexed_dataset
from fairseq.tokenizer import Tokenizer
from multiprocessing import Pool
from fairseq.utils import import_user_module
def get_parser():
parser = argparse.ArgumentParser()
# fmt: off
parser.add_argument("-s", "--source-lang", default=None, metavar="SRC",
help="source language")
parser.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
help="target language")
parser.add_argument("--trainpref", metavar="FP", default=None,
help="train file prefix")
parser.add_argument("--validpref", metavar="FP", default=None,
help="comma separated, valid file prefixes")
parser.add_argument("--testpref", metavar="FP", default=None,
help="comma separated, test file prefixes")
parser.add_argument("--destdir", metavar="DIR", default="data-bin",
help="destination dir")
parser.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
help="map words appearing less than threshold times to unknown")
parser.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
help="map words appearing less than threshold times to unknown")
parser.add_argument("--tgtdict", metavar="FP",
help="reuse given target dictionary")
parser.add_argument("--srcdict", metavar="FP",
help="reuse given source dictionary")
parser.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
help="number of target words to retain")
parser.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
help="number of source words to retain")
parser.add_argument("--alignfile", metavar="ALIGN", default=None,
help="an alignment file (optional)")
parser.add_argument("--output-format", metavar="FORMAT", default="binary",
choices=["binary", "raw"],
help="output format (optional)")
parser.add_argument("--joined-dictionary", action="store_true",
help="Generate joined dictionary")
parser.add_argument("--only-source", action="store_true",
help="Only process the source language")
parser.add_argument("--padding-factor", metavar="N", default=8, type=int,
help="Pad dictionary size to be multiple of N")
parser.add_argument("--workers", metavar="N", default=1, type=int,
help="number of parallel workers")
# fmt: on
return parser
def main(args):
import_user_module(args)
@@ -74,6 +29,8 @@ def main(args):
os.makedirs(args.destdir, exist_ok=True)
target = not args.only_source
task = tasks.get_task(args.task)
def train_path(lang):
return "{}{}".format(args.trainpref, ("." + lang) if lang else "")
@@ -89,50 +46,57 @@ def main(args):
def dict_path(lang):
return dest_path("dict", lang) + ".txt"
if args.joined_dictionary:
assert not args.srcdict, "cannot combine --srcdict and --joined-dictionary"
assert not args.tgtdict, "cannot combine --tgtdict and --joined-dictionary"
src_dict = build_dictionary(
{train_path(lang) for lang in [args.source_lang, args.target_lang]},
args.workers,
def build_dictionary(filenames, src=False, tgt=False):
assert src ^ tgt
return task.build_dictionary(
filenames,
workers=args.workers,
threshold=args.thresholdsrc if src else args.thresholdtgt,
nwords=args.nwordssrc if src else args.nwordstgt,
padding_factor=args.padding_factor,
)
tgt_dict = src_dict
else:
if args.joined_dictionary:
assert (
not args.srcdict or not args.tgtdict
), "cannot use both --srcdict and --tgtdict with --joined-dictionary"
if args.srcdict:
src_dict = dictionary.Dictionary.load(args.srcdict)
src_dict = task.load_dictionary(args.srcdict)
elif args.tgtdict:
src_dict = task.load_dictionary(args.tgtdict)
else:
assert (
args.trainpref
), "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary([train_path(args.source_lang)], args.workers)
src_dict = build_dictionary({train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True)
tgt_dict = src_dict
else:
if args.srcdict:
src_dict = task.load_dictionary(args.srcdict)
else:
assert (
args.trainpref
), "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary([train_path(args.source_lang)], src=True)
if target:
if args.tgtdict:
tgt_dict = dictionary.Dictionary.load(args.tgtdict)
tgt_dict = task.load_dictionary(args.tgtdict)
else:
assert (
args.trainpref
), "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary(
[train_path(args.target_lang)], args.workers
)
tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
else:
tgt_dict = None
src_dict.finalize(
threshold=args.thresholdsrc,
nwords=args.nwordssrc,
padding_factor=args.padding_factor,
)
src_dict.save(dict_path(args.source_lang))
if target:
if not args.joined_dictionary:
tgt_dict.finalize(
threshold=args.thresholdtgt,
nwords=args.nwordstgt,
padding_factor=args.padding_factor,
)
if target and tgt_dict is not None:
tgt_dict.save(dict_path(args.target_lang))
def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
dict = dictionary.Dictionary.load(dict_path(lang))
dict = task.load_dictionary(dict_path(lang))
print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
n_seq_tok = [0, 0]
replaced = Counter()
@@ -229,8 +193,6 @@ def main(args):
assert args.trainpref, "--trainpref must be set if --alignfile is specified"
src_file_name = train_path(args.source_lang)
tgt_file_name = train_path(args.target_lang)
src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
freq_map = {}
with open(args.alignfile, "r", encoding='utf-8') as align_file:
with open(src_file_name, "r", encoding='utf-8') as src_file:
@@ -260,37 +222,16 @@ def main(args):
align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
with open(
os.path.join(
args.destdir,
"alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
),
"w", encoding='utf-8'
os.path.join(
args.destdir,
"alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
),
"w", encoding='utf-8'
) as f:
for k, v in align_dict.items():
print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def build_and_save_dictionary(
train_path, output_path, num_workers, freq_threshold, max_words, dict_cls=dictionary.Dictionary,
):
dict = build_dictionary([train_path], num_workers, dict_cls)
dict.finalize(threshold=freq_threshold, nwords=max_words)
dict_path = os.path.join(output_path, "dict.txt")
dict.save(dict_path)
return dict_path
def build_dictionary(
filenames,
workers,
dict_cls=dictionary.Dictionary,
):
d = dict_cls()
for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, workers)
return d
def binarize(args, filename, dict, output_prefix, lang, offset, end):
ds = indexed_dataset.IndexedDatasetBuilder(
dataset_dest_file(args, output_prefix, lang, "bin")
@@ -304,21 +245,6 @@ def binarize(args, filename, dict, output_prefix, lang, offset, end):
return res
def binarize_with_load(
args,
filename,
dict_path,
output_prefix,
lang,
offset,
end,
dict_cls=dictionary.Dictionary,
):
dict = dict_cls.load(dict_path)
binarize(args, filename, dict, output_prefix, lang, offset, end)
return dataset_dest_prefix(args, output_prefix, lang)
def dataset_dest_prefix(args, output_prefix, lang):
base = "{}/{}".format(args.destdir, output_prefix)
lang_part = (
@@ -346,6 +272,6 @@ def merge_files(files, outpath):
if __name__ == "__main__":
parser = get_parser()
parser = options.get_preprocessing_parser()
args = parser.parse_args()
main(args)
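
With the argument parser now provided by `options`, `preprocess.py` can also be driven programmatically, as the updated test helpers below do. A minimal sketch, assuming it is run from the repository root and using placeholder data paths:

```python
import os

from fairseq import options

import preprocess  # the top-level preprocess.py script

data_dir = 'examples/translation/sample'  # hypothetical directory with train.in / train.out
parser = options.get_preprocessing_parser()
args = parser.parse_args([
    '--source-lang', 'in',
    '--target-lang', 'out',
    '--trainpref', os.path.join(data_dir, 'train'),
    '--destdir', os.path.join(data_dir, 'data-bin'),
    '--joined-dictionary',
    '--workers', '2',
])
preprocess.main(args)
```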

View File: tests/utils.py

@@ -223,7 +223,7 @@ def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
def preprocess_translation_data(data_dir, extra_flags=None):
preprocess_parser = preprocess.get_parser()
preprocess_parser = options.get_preprocessing_parser()
preprocess_args = preprocess_parser.parse_args(
[
'--source-lang', 'in',
@@ -291,7 +291,7 @@ def generate_main(data_dir, extra_flags=None):
def preprocess_lm_data(data_dir):
preprocess_parser = preprocess.get_parser()
preprocess_parser = options.get_preprocessing_parser()
preprocess_args = preprocess_parser.parse_args([
'--only-source',
'--trainpref', os.path.join(data_dir, 'train.out'),