More fully deprecate --raw-text and --lazy-load (fixes #1488)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/947

Differential Revision: D19084273

Pulled By: myleott

fbshipit-source-id: de80d9abfac8e3d813a9c9b343b41327c500344e
Author: Myle Ott
Date: 2019-12-16 17:18:29 -08:00 (committed by Facebook Github Bot)
Commit: be3515b289 (parent: ebc1f91c7d)
10 changed files with 19 additions and 72 deletions


@@ -158,14 +158,6 @@ Fairseq supports FP16 training with the ``--fp16`` flag:
 
     > fairseq-train --fp16 (...)
 
-Lazily loading large training datasets
---------------------------------------
-
-By default fairseq loads the entire training set into system memory. For large
-datasets, the ``--lazy-load`` option can be used to instead load batches on-demand.
-For optimal performance, use the ``--num-workers`` option to control the number
-of background processes that will load batches.
-
 Distributed training
 --------------------
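
The same behavior remains available under the consolidated option: on-demand loading is now selected with --dataset-impl=lazy, still paired with --num-workers to control the background loader processes. A doc-style sketch of the replacement invocation (dataset path and worker count are illustrative, not taken from this commit):

    > fairseq-train data-bin/example --dataset-impl lazy --num-workers 4 (...)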


@@ -63,8 +63,8 @@ def check_args(args):
         not args.sampling or args.nbest == args.beam
     ), "--sampling requires --nbest to be equal to --beam"
     assert (
-        args.replace_unk is None or args.raw_text
-    ), "--replace-unk requires a raw text dataset (--raw-text)"
+        args.replace_unk is None or args.dataset_impl == "raw"
+    ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
 
 
 def get_dataset_itr(args, task):


@@ -326,6 +326,11 @@ def _upgrade_state_dict(state):
     # default to translation task
     if not hasattr(state["args"], "task"):
         state["args"].task = "translation"
+    # --raw-text and --lazy-load are deprecated
+    if getattr(state["args"], "raw_text", False):
+        state["args"].dataset_impl = "raw"
+    elif getattr(state["args"], "lazy_load", False):
+        state["args"].dataset_impl = "lazy"
 
     # set any missing default values in the task, model or other registries
     registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
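
The upgrade path can be checked in isolation; a minimal sketch of the same mapping (argparse.Namespace stands in for state["args"], and the values are illustrative, not from this commit):

    import argparse

    # simulate the saved args of a checkpoint trained with the old flag
    old_args = argparse.Namespace(task="translation", raw_text=True)

    # the mapping _upgrade_state_dict now applies when the checkpoint is loaded
    if getattr(old_args, "raw_text", False):
        old_args.dataset_impl = "raw"
    elif getattr(old_args, "lazy_load", False):
        old_args.dataset_impl = "lazy"

    assert old_args.dataset_impl == "raw"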


@@ -46,10 +46,6 @@ class CrossLingualLMTask(FairseqTask):
         parser.add_argument('--monolingual-langs', default='en', type=str,
                             help='comma separated list of languages for which we'
                                  ' want to train XLM on')
-        parser.add_argument('--raw-text', default=False, action='store_true',
-                            help='load raw text dataset')
-        parser.add_argument('--lazy-load', action='store_true',
-                            help='load the dataset lazily')
         parser.add_argument('--shuffle', action='store_true',
                             help='shuffle each monolingual dataset while'
                                  ' training')


@@ -32,8 +32,6 @@ class DenoisingTask(FairseqTask):
         parser.add_argument('--tokens-per-sample', default=512, type=int,
                             help='max number of total tokens over all segments'
                                  ' per sample for dataset')
-        parser.add_argument('--raw-text', default=False, action='store_true',
-                            help='load raw text dataset')
         parser.add_argument(
             '--sample-break-mode', default="complete_doc", type=str,
             help='mode for breaking sentence',


@@ -63,10 +63,6 @@ class LanguageModelingTask(FairseqTask):
                                  'If set to "eos", includes only one sentence per sample.')
         parser.add_argument('--tokens-per-sample', default=1024, type=int,
                             help='max number of tokens per sample for LM dataset')
-        parser.add_argument('--lazy-load', action='store_true',
-                            help='load the dataset lazily')
-        parser.add_argument('--raw-text', default=False, action='store_true',
-                            help='load raw text dataset')
         parser.add_argument('--output-dictionary-size', default=-1, type=int,
                             help='limit the size of output dictionary')
         parser.add_argument('--self-target', action='store_true',
@@ -97,17 +93,6 @@ class LanguageModelingTask(FairseqTask):
         Args:
             args (argparse.Namespace): parsed command-line arguments
         """
-        if getattr(args, "raw_text", False):
-            utils.deprecation_warning(
-                "--raw-text is deprecated, please use --dataset-impl=raw"
-            )
-            args.dataset_impl = "raw"
-        elif getattr(args, "lazy_load", False):
-            utils.deprecation_warning(
-                "--lazy-load is deprecated, please use --dataset-impl=lazy"
-            )
-            args.dataset_impl = "lazy"
-
         dictionary = None
         output_dictionary = None
         if args.data:
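
The shim removed here (and from the translation tasks below) relied on fairseq's deprecation helper, which is roughly the following thin wrapper (a simplified sketch, not the verbatim implementation):

    import warnings

    def deprecation_warning(message, stacklevel=3):
        # stacklevel points the warning at the caller rather than this helper
        warnings.warn(message, stacklevel=stacklevel)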


@@ -71,10 +71,6 @@ class MultilingualTranslationTask(FairseqTask):
                             help='source language (only needed for inference)')
         parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                             help='target language (only needed for inference)')
-        parser.add_argument('--lazy-load', action='store_true',
-                            help='load the dataset lazily')
-        parser.add_argument('--raw-text', default=False, action='store_true',
-                            help='load raw text dataset')
         parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                             help='pad the source on the left (default: True)')
         parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
@@ -123,12 +119,6 @@ class MultilingualTranslationTask(FairseqTask):
     def prepare(cls, args, **kargs):
         args.left_pad_source = options.eval_bool(args.left_pad_source)
         args.left_pad_target = options.eval_bool(args.left_pad_target)
-        if getattr(args, 'raw_text', False):
-            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
-            args.dataset_impl = 'raw'
-        elif getattr(args, 'lazy_load', False):
-            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
-            args.dataset_impl = 'lazy'
         if args.lang_pairs is None:
             raise ValueError('--lang-pairs is required. List all the language pairs in the training objective.')


@@ -8,6 +8,8 @@ import os
 
 from fairseq.data import (
     BacktranslationDataset,
+    data_utils,
+    indexed_dataset,
     IndexedCachedDataset,
     IndexedDataset,
     IndexedRawTextDataset,
@@ -143,21 +145,10 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
                 filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
             else:
                 filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
-            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
-                return True
-            elif not self.args.raw_text and IndexedDataset.exists(filename):
-                return True
-            return False
+            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)
 
-        def indexed_dataset(path, dictionary):
-            if self.args.raw_text:
-                return IndexedRawTextDataset(path, dictionary)
-            elif IndexedDataset.exists(path):
-                if self.args.lazy_load:
-                    return IndexedDataset(path, fix_lua_indexing=True)
-                else:
-                    return IndexedCachedDataset(path, fix_lua_indexing=True)
-            return None
+        def load_indexed_dataset(path, dictionary):
+            return data_utils.load_indexed_dataset(path, dictionary, self.args.dataset_impl)
 
         # load parallel datasets
         src_datasets, tgt_datasets = {}, {}
@@ -170,8 +161,8 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
                 prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
             else:
                 continue
-            src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
-            tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
+            src_datasets[lang_pair] = load_indexed_dataset(prefix + src, self.dicts[src])
+            tgt_datasets[lang_pair] = load_indexed_dataset(prefix + tgt, self.dicts[tgt])
             print('| parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
 
         if len(src_datasets) == 0:
             raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
@@ -184,7 +175,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
             if not split_exists(split, tgt, None, tgt):
                 raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path))
             filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
-            dataset = indexed_dataset(filename, self.dicts[tgt])
+            dataset = load_indexed_dataset(filename, self.dicts[tgt])
             lang_pair_dataset_tgt = LanguagePairDataset(
                 dataset,
                 dataset.sizes,
@@ -232,8 +223,8 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
             if not split_exists(split, tgt, None, tgt):
                 continue
             filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
-            tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt])
-            tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt])
+            tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
+            tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
             noising_dataset = NoisingDataset(
                 tgt_dataset1,
                 self.dicts[tgt],
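
Both new helpers bottom out in fairseq's indexed_dataset module, which dispatches on the implementation string. A simplified sketch of that dispatch (the real make_dataset also handles the mmap-backed format; this is not the verbatim implementation):

    from fairseq.data import (
        IndexedCachedDataset,
        IndexedDataset,
        IndexedRawTextDataset,
    )

    def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
        if impl == 'raw' and IndexedRawTextDataset.exists(path):
            # plain-text file, tokenized with the dictionary on access
            return IndexedRawTextDataset(path, dictionary)
        elif impl == 'lazy' and IndexedDataset.exists(path):
            # binarized data, read from disk item by item
            return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
        elif impl == 'cached' and IndexedDataset.exists(path):
            # binarized data, prefetched into memory (the old default behavior)
            return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
        return None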


@@ -134,10 +134,6 @@ class TranslationTask(FairseqTask):
                             help='source language')
         parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                             help='target language')
-        parser.add_argument('--lazy-load', action='store_true',
-                            help='load the dataset lazily')
-        parser.add_argument('--raw-text', action='store_true',
-                            help='load raw text dataset')
         parser.add_argument('--load-alignments', action='store_true',
                             help='load the binarized alignments')
         parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
@@ -168,12 +164,6 @@ class TranslationTask(FairseqTask):
         """
         args.left_pad_source = options.eval_bool(args.left_pad_source)
         args.left_pad_target = options.eval_bool(args.left_pad_target)
-        if getattr(args, 'raw_text', False):
-            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
-            args.dataset_impl = 'raw'
-        elif getattr(args, 'lazy_load', False):
-            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
-            args.dataset_impl = 'lazy'
 
         paths = args.data.split(os.pathsep)
         assert len(paths) > 0


@@ -17,8 +17,8 @@ def main(args):
     assert args.path is not None, '--path required for generation!'
     assert not args.sampling or args.nbest == args.beam, \
         '--sampling requires --nbest to be equal to --beam'
-    assert args.replace_unk is None or args.raw_text, \
-        '--replace-unk requires a raw text dataset (--raw-text)'
+    assert args.replace_unk is None or args.dataset_impl == 'raw', \
+        '--replace-unk requires a raw text dataset (--dataset-impl=raw)'
 
     utils.import_user_module(args)
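
One user-visible consequence of the new assert: --replace-unk must now be paired with the raw implementation explicitly. A doc-style sketch (dataset path and checkpoint name are placeholders):

    > fairseq-generate data-bin/example --path checkpoint.pt --replace-unk --dataset-impl raw (...)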