diff --git a/README.md b/README.md index 81593e79..3cc0e31e 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,13 @@ of various sequence-to-sequence models, including: Fairseq features: - multi-GPU (distributed) training on one machine or across multiple machines -- fast beam search generation on both CPU and GPU +- fast generation on both CPU and GPU with multiple search algorithms implemented: + - beam search + - Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) + - sampling (unconstrained and top-k) - large mini-batch training even on a single GPU via delayed updates - fast half-precision floating point (FP16) training -- extensible: easily register new models, criterions, and tasks +- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers We also provide [pre-trained models](#pre-trained-models) for several benchmark translation and language modeling datasets. @@ -34,7 +37,7 @@ translation and language modeling datasets. * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) * Python version 3.6 -Currently fairseq requires PyTorch version >= 0.4.0. +Currently fairseq requires PyTorch version >= 1.0.0. Please follow the instructions here: https://github.com/pytorch/pytorch#installation. If you use Docker make sure to increase the shared memory size either with diff --git a/distributed_train.py b/distributed_train.py deleted file mode 100644 index f3725a6e..00000000 --- a/distributed_train.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python3 -u -# Copyright (c) 2017-present, Facebook, Inc. -# All rights reserved. -# -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. - -import os -import socket -import subprocess - -from train import main as single_process_main -from fairseq import distributed_utils, options - - -def main(args): - if args.distributed_init_method is None and args.distributed_port > 0: - # We can determine the init method automatically for Slurm. 
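Note: the Slurm auto-detection deleted here moves into ``fairseq.distributed_utils.infer_init_method`` (added later in this diff), which additionally recognizes the environment variables set by ``python -m torch.distributed.launch``. A minimal sketch of that environment-variable branch (illustrative only, not the full implementation)::

    import os

    # roughly what infer_init_method does once the PyTorch launcher
    # has populated the standard rendezvous variables
    if all(k in os.environ for k in ('MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK')):
        init_method = 'tcp://{}:{}'.format(
            os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        world_size = int(os.environ['WORLD_SIZE'])
        rank = int(os.environ['RANK'])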
- node_list = os.environ.get('SLURM_JOB_NODELIST') - if node_list is not None: - try: - hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list]) - args.distributed_init_method = 'tcp://{host}:{port}'.format( - host=hostnames.split()[0].decode('utf-8'), - port=args.distributed_port) - args.distributed_rank = int(os.environ.get('SLURM_PROCID')) - args.device_id = int(os.environ.get('SLURM_LOCALID')) - except subprocess.CalledProcessError as e: # scontrol failed - raise e - except FileNotFoundError as e: # Slurm is not installed - pass - if args.distributed_init_method is None and args.distributed_port is None: - raise ValueError('--distributed-init-method or --distributed-port ' - 'must be specified for distributed training') - - args.distributed_rank = distributed_utils.distributed_init(args) - print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) - single_process_main(args) - - -if __name__ == '__main__': - parser = options.get_training_parser() - args = options.parse_args_and_arch(parser) - main(args) diff --git a/docs/criterions.rst b/docs/criterions.rst index ab64b27b..d6b8ca6b 100644 --- a/docs/criterions.rst +++ b/docs/criterions.rst @@ -6,8 +6,26 @@ Criterions ========== +Criterions compute the loss function given the model and batch, roughly:: + + loss = criterion(model, batch) + .. automodule:: fairseq.criterions :members: + .. autoclass:: fairseq.criterions.FairseqCriterion :members: :undoc-members: + +.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss :members: :undoc-members: +.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss :members: :undoc-members: +.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion :members: :undoc-members: +.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion :members: :undoc-members: diff --git a/docs/data.rst b/docs/data.rst index fed63278..55c27c91 100644 --- a/docs/data.rst +++ b/docs/data.rst @@ -21,6 +21,20 @@ mini-batches. .. autoclass:: fairseq.data.MonolingualDataset :members: +**Helper Datasets** + +These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and +provide additional functionality: + +.. autoclass:: fairseq.data.BacktranslationDataset :members: +.. autoclass:: fairseq.data.ConcatDataset :members: +.. autoclass:: fairseq.data.RoundRobinZipDatasets :members: +.. autoclass:: fairseq.data.TransformEosDataset :members: + Dictionary ---------- diff --git a/docs/getting_started.rst b/docs/getting_started.rst index fff2e294..6f5d5a67 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -27,21 +27,20 @@ interactively. Here, we use a beam size of 5: > MODEL_DIR=wmt14.en-fr.fconv-py > python interactive.py \ --path $MODEL_DIR/model.pt $MODEL_DIR \ - --beam 5 + --beam 5 --source-lang en --target-lang fr | loading model(s) from wmt14.en-fr.fconv-py/model.pt | [en] dictionary: 44206 types | [fr] dictionary: 44463 types | Type the input sentence and press return: > Why is it rare to discover new marine mam@@ mal species ? O Why is it rare to discover new marine mam@@ mal species ? - H -0.06429661810398102 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
- A 0 1 3 3 5 6 6 8 8 8 7 11 12 + H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ? + P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288 -This generation script produces four types of outputs: a line prefixed -with *S* shows the supplied source sentence after applying the -vocabulary; *O* is a copy of the original source sentence; *H* is the -hypothesis along with an average log-likelihood; and *A* is the -attention maxima for each word in the hypothesis, including the +This generation script produces three types of outputs: a line prefixed +with *O* is a copy of the original source sentence; *H* is the +hypothesis along with an average log-likelihood; and *P* is the +positional score per token position, including the end-of-sentence marker which is omitted from the text. See the `README `__ for a diff --git a/docs/lr_scheduler.rst b/docs/lr_scheduler.rst index ecc60467..b1b23ada 100644 --- a/docs/lr_scheduler.rst +++ b/docs/lr_scheduler.rst @@ -6,7 +6,29 @@ Learning Rate Schedulers ======================== -TODO +Learning Rate Schedulers update the learning rate over the course of training. +Learning rates can be updated after each update via :func:`step_update` or at +epoch boundaries via :func:`step`. .. automodule:: fairseq.optim.lr_scheduler :members: + +.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler + :members: + :undoc-members: + +.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule + :members: + :undoc-members: +.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule + :members: + :undoc-members: +.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule + :members: + :undoc-members: +.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau + :members: + :undoc-members: +.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule + :members: + :undoc-members: diff --git a/docs/modules.rst b/docs/modules.rst index 1cd3df0c..71989bbb 100644 --- a/docs/modules.rst +++ b/docs/modules.rst @@ -1,8 +1,8 @@ Modules ======= -Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be -helpful when implementing a new :class:`FairseqModel`. +Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may +be helpful when implementing a new :class:`~fairseq.models.FairseqModel`. .. automodule:: fairseq.modules :members: diff --git a/docs/optim.rst b/docs/optim.rst index e7c0a4a2..67370b33 100644 --- a/docs/optim.rst +++ b/docs/optim.rst @@ -6,5 +6,27 @@ Optimizers ========== +Optimizers update the Model parameters based on the gradients. + .. automodule:: fairseq.optim :members: + +.. autoclass:: fairseq.optim.FairseqOptimizer + :members: + :undoc-members: + +.. autoclass:: fairseq.optim.adagrad.Adagrad + :members: + :undoc-members: +.. autoclass:: fairseq.optim.adam.FairseqAdam + :members: + :undoc-members: +.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer + :members: + :undoc-members: +.. autoclass:: fairseq.optim.nag.FairseqNAG + :members: + :undoc-members: +..
autoclass:: fairseq.optim.sgd.SGD + :members: + :undoc-members: diff --git a/docs/overview.rst b/docs/overview.rst index d3ba171e..e8191c5a 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -22,12 +22,18 @@ fairseq implements the following high-level training flow:: for epoch in range(num_epochs): itr = task.get_batch_iterator(task.dataset('train')) for num_updates, batch in enumerate(itr): - loss = criterion(model, batch) - optimizer.backward(loss) + task.train_step(batch, model, criterion, optimizer) + average_and_clip_gradients() optimizer.step() lr_scheduler.step_update(num_updates) lr_scheduler.step(epoch) +where the default implementation for ``task.train_step`` is roughly:: + + def train_step(self, batch, model, criterion, optimizer): + loss = criterion(model, batch) + optimizer.backward(loss) + **Registering new plug-ins** New plug-ins are *registered* through a set of ``@register`` function diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst index b198e783..262c90ef 100644 --- a/docs/tutorial_classifying_names.rst +++ b/docs/tutorial_classifying_names.rst @@ -353,17 +353,16 @@ The model files should appear in the :file:`checkpoints/` directory. ------------------------------- Finally we can write a short script to evaluate our model on new inputs. Create -a new file named :file:`eval_classify.py` with the following contents:: +a new file named :file:`eval_classifier.py` with the following contents:: from fairseq import data, options, tasks, utils from fairseq.tokenizer import Tokenizer # Parse command-line arguments for generation - parser = options.get_generation_parser() + parser = options.get_generation_parser(default_task='simple_classification') args = options.parse_args_and_arch(parser) # Setup task - args.task = 'simple_classification' task = tasks.setup_task(args) # Load model diff --git a/eval_lm.py b/eval_lm.py index f1219643..929dc046 100644 --- a/eval_lm.py +++ b/eval_lm.py @@ -55,7 +55,9 @@ def main(parsed_args): # Load ensemble print('| loading model(s) from {}'.format(parsed_args.path)) - models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides)) + models, args = utils.load_ensemble_for_inference( + parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides), + ) for arg in vars(parsed_args).keys(): if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}: @@ -83,9 +85,10 @@ def main(parsed_args): max_positions=utils.resolve_max_positions(*[ model.max_positions() for model in models ]), + ignore_invalid_inputs=True, num_shards=args.num_shards, shard_id=args.shard_id, - ignore_invalid_inputs=True, + num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) gen_timer = StopwatchMeter() diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index acd66567..380c7b04 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -9,7 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary from .fairseq_dataset import FairseqDataset from .backtranslation_dataset import BacktranslationDataset from .concat_dataset import ConcatDataset -from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset +from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset from .language_pair_dataset import LanguagePairDataset from .monolingual_dataset import MonolingualDataset
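Note: registration works the same way for each of the plug-in types above. A minimal criterion sketch (illustrative only; the name ``simple_cross_entropy`` is hypothetical, and the body mirrors what the bundled ``CrossEntropyCriterion`` does with the ``(loss, sample_size, logging_output)`` return contract)::

    import torch.nn.functional as F

    from fairseq.criterions import FairseqCriterion, register_criterion


    @register_criterion('simple_cross_entropy')  # hypothetical plug-in name
    class SimpleCrossEntropyCriterion(FairseqCriterion):

        def forward(self, model, sample, reduce=True):
            # run the model and normalize its outputs to log-probabilities
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            loss = F.nll_loss(
                lprobs.view(-1, lprobs.size(-1)),
                target.view(-1),
                ignore_index=self.padding_idx,
                reduction='sum' if reduce else 'none',
            )
            sample_size = sample['ntokens']
            logging_output = {'loss': loss.data, 'sample_size': sample_size}
            return loss, sample_size, logging_output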
from .round_robin_zip_datasets import RoundRobinZipDatasets @@ -33,7 +33,6 @@ __all__ = [ 'GroupedIterator', 'IndexedCachedDataset', 'IndexedDataset', - 'IndexedInMemoryDataset', 'IndexedRawTextDataset', 'LanguagePairDataset', 'MonolingualDataset', diff --git a/fairseq/data/backtranslation_dataset.py b/fairseq/data/backtranslation_dataset.py index cabdcc64..7697c321 100644 --- a/fairseq/data/backtranslation_dataset.py +++ b/fairseq/data/backtranslation_dataset.py @@ -56,6 +56,28 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True): class BacktranslationDataset(FairseqDataset): + """ + Sets up a backtranslation dataset which takes a tgt batch, generates + a src using a tgt-src backtranslation function (*backtranslation_fn*), + and returns the corresponding `{generated src, input tgt}` batch. + + Args: + tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be + backtranslated. Only the source side of this dataset will be used. + After backtranslation, the source sentences in this dataset will be + returned as the targets. + backtranslation_fn (callable): function to call to generate + backtranslations. This is typically the `generate` method of a + :class:`~fairseq.sequence_generator.SequenceGenerator` object. + max_len_a, max_len_b (int, int): will be used to compute + `maxlen = max_len_a * src_len + max_len_b`, which will be passed + into *backtranslation_fn*. + output_collater (callable, optional): function to call on the + backtranslated samples to create the final batch + (default: ``tgt_dataset.collater``). + cuda: use GPU for generation + """ + def __init__( self, tgt_dataset, @@ -66,27 +88,6 @@ class BacktranslationDataset(FairseqDataset): cuda=True, **kwargs ): - """ - Sets up a backtranslation dataset which takes a tgt batch, generates - a src using a tgt-src backtranslation function (*backtranslation_fn*), - and returns the corresponding `{generated src, input tgt}` batch. - - Args: - tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be - backtranslated. Only the source side of this dataset will be - used. After backtranslation, the source sentences in this - dataset will be returned as the targets. - backtranslation_fn (callable): function to call to generate - backtranslations. This is typically the `generate` method of a - :class:`~fairseq.sequence_generator.SequenceGenerator` object. - max_len_a, max_len_b (int, int): will be used to compute - `maxlen = max_len_a * src_len + max_len_b`, which will be - passed into *backtranslation_fn*. 
- output_collater (callable, optional): function to call on the - backtranslated samples to create the final batch (default: - ``tgt_dataset.collater``) - cuda: use GPU for generation - """ self.tgt_dataset = tgt_dataset self.backtranslation_fn = backtranslation_fn self.max_len_a = max_len_a @@ -166,11 +167,10 @@ class BacktranslationDataset(FairseqDataset): """ tgt_size = self.tgt_dataset.size(index)[0] return (tgt_size, tgt_size) - + @property def supports_prefetch(self): - return self.tgt_dataset.supports_prefetch() + return getattr(self.tgt_dataset, 'supports_prefetch', False) def prefetch(self, indices): return self.tgt_dataset.prefetch(indices) - diff --git a/fairseq/data/concat_dataset.py b/fairseq/data/concat_dataset.py index 2331075e..f3d4c3f3 100644 --- a/fairseq/data/concat_dataset.py +++ b/fairseq/data/concat_dataset.py @@ -29,18 +29,18 @@ class ConcatDataset(FairseqDataset): if isinstance(sample_ratios, int): sample_ratios = [sample_ratios] * len(self.datasets) self.sample_ratios = sample_ratios - self.cummulative_sizes = self.cumsum(self.datasets, sample_ratios) + self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios) self.real_sizes = [len(d) for d in self.datasets] def __len__(self): - return self.cummulative_sizes[-1] + return self.cumulative_sizes[-1] def __getitem__(self, idx): - dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx) + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) if dataset_idx == 0: sample_idx = idx else: - sample_idx = idx - self.cummulative_sizes[dataset_idx - 1] + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] sample_idx = sample_idx % self.real_sizes[dataset_idx] return self.datasets[dataset_idx][sample_idx] @@ -54,7 +54,7 @@ class ConcatDataset(FairseqDataset): def prefetch(self, indices): frm = 0 - for to, ds in zip(self.cummulative_sizes, self.datasets): + for to, ds in zip(self.cumulative_sizes, self.datasets): real_size = len(ds) ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) frm = to diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index af4d0654..31438bc8 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -81,8 +81,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False): size_fn (callable): function that returns the size of a given index max_positions (tuple): filter elements larger than this size. Comparisons are done component-wise. - raise_exception (bool, optional): if ``True``, raise an exception - if any elements are filtered. Default: ``False`` + raise_exception (bool, optional): if ``True``, raise an exception if + any elements are filtered (default: False). """ def check_size(idx): if isinstance(max_positions, float) or isinstance(max_positions, int): @@ -128,12 +128,12 @@ def batch_by_size( indices (List[int]): ordered list of dataset indices num_tokens_fn (callable): function that returns the number of tokens at a given index - max_tokens (int, optional): max number of tokens in each batch. - Default: ``None`` + max_tokens (int, optional): max number of tokens in each batch + (default: None). max_sentences (int, optional): max number of sentences in each - batch. Default: ``None`` + batch (default: None). required_batch_size_multiple (int, optional): require batch size to - be a multiple of N. Default: ``1`` + be a multiple of N (default: 1). 
""" max_tokens = max_tokens if max_tokens is not None else float('Inf') max_sentences = max_sentences if max_sentences is not None else float('Inf') diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index f56f8093..3d3d4e56 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -200,11 +200,15 @@ class Dictionary(object): t[-1] = self.eos() return t + class TruncatedDictionary(object): def __init__(self, wrapped_dict, length): - self.__class__ = type(wrapped_dict.__class__.__name__, - (self.__class__, wrapped_dict.__class__), {}) + self.__class__ = type( + wrapped_dict.__class__.__name__, + (self.__class__, wrapped_dict.__class__), + {} + ) self.__dict__ = wrapped_dict.__dict__ self.wrapped_dict = wrapped_dict self.length = min(len(self.wrapped_dict), length) diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index b7e93b2c..eb2b6ae3 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -7,8 +7,6 @@ import torch.utils.data -from fairseq.data import data_utils - class FairseqDataset(torch.utils.data.Dataset): """A dataset that provides helpers for batching.""" @@ -51,7 +49,9 @@ class FairseqDataset(torch.utils.data.Dataset): @property def supports_prefetch(self): + """Whether this dataset supports prefetching.""" return False def prefetch(self, indices): + """Prefetch the data required for this epoch.""" raise NotImplementedError diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index 661daf95..3c662205 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -52,13 +52,12 @@ def data_file_path(prefix_path): class IndexedDataset(torch.utils.data.Dataset): """Loader for TorchNet IndexedDataset""" - def __init__(self, path, fix_lua_indexing=False, read_data=True): + def __init__(self, path, fix_lua_indexing=False): super().__init__() self.fix_lua_indexing = fix_lua_indexing self.read_index(path) self.data_file = None - if read_data: - self.read_data(path) + self.path = path def read_index(self, path): with open(index_file_path(path), 'rb') as f: @@ -85,8 +84,10 @@ class IndexedDataset(torch.utils.data.Dataset): self.data_file.close() def __getitem__(self, i): + if not self.data_file: + self.read_data(self.path) self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = int(self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]) a = np.empty(tensor_size, dtype=self.dtype) self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) @@ -98,12 +99,6 @@ class IndexedDataset(torch.utils.data.Dataset): def __len__(self): return self.size - def read_into(self, start, dst): - self.data_file.seek(start * self.element_size) - self.data_file.readinto(dst) - if self.fix_lua_indexing: - dst -= 1 # subtract 1 for 0-based indexing - @staticmethod def exists(path): return ( @@ -111,11 +106,15 @@ class IndexedDataset(torch.utils.data.Dataset): os.path.exists(data_file_path(path)) ) + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + class IndexedCachedDataset(IndexedDataset): def __init__(self, path, fix_lua_indexing=False): - super().__init__(path, fix_lua_indexing, True) + super().__init__(path, fix_lua_indexing=fix_lua_indexing) self.cache = None self.cache_index = {} @@ -126,6 +125,8 @@ class IndexedCachedDataset(IndexedDataset): def prefetch(self, indices): if all(i in self.cache_index for i in indices): return + if not 
self.data_file: + self.read_data(self.path) indices = sorted(set(indices)) total_size = 0 for i in indices: @@ -153,34 +154,7 @@ class IndexedCachedDataset(IndexedDataset): return item -class IndexedInMemoryDataset(IndexedDataset): - """Loader for TorchNet IndexedDataset, keeps all the data in memory""" - - def read_data(self, path): - self.data_file = open(data_file_path(path), 'rb') - self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype) - self.data_file.readinto(self.buffer) - self.data_file.close() - if self.fix_lua_indexing: - self.buffer -= 1 # subtract 1 for 0-based indexing - - def read_into(self, start, dst): - if self.token_blob is None: - self.token_blob = [t for l in self.tokens_list for t in l] - np.copyto(dst, self.token_blob[start:]) - - def __del__(self): - pass - - def __getitem__(self, i): - self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] - a = np.empty(tensor_size, dtype=self.dtype) - np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]]) - return torch.from_numpy(a).long() - - -class IndexedRawTextDataset(IndexedDataset): +class IndexedRawTextDataset(torch.utils.data.Dataset): """Takes a text file as input and binarizes it in memory at instantiation. Original lines are also kept in memory""" @@ -205,6 +179,10 @@ class IndexedRawTextDataset(IndexedDataset): self.sizes.append(len(tokens)) self.sizes = np.array(self.sizes) + def check_index(self, i): + if i < 0 or i >= self.size: + raise IndexError('index out of range') + def __getitem__(self, i): self.check_index(i) return self.tokens_list[i] @@ -252,7 +230,7 @@ class IndexedDatasetBuilder(object): self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) def merge_file_(self, another_file): - index = IndexedDataset(another_file, read_data=False) + index = IndexedDataset(another_file) assert index.dtype == self.dtype begin = self.data_offsets[-1] diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 184f8001..cbd305f1 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -69,17 +69,19 @@ class EpochBatchIterator(object): batch_sampler (~torch.utils.data.Sampler): an iterator over batches of indices seed (int, optional): seed for random number generator for - reproducibility. Default: 1 + reproducibility (default: 1). num_shards (int, optional): shard the data iterator into N - shards. Default: 1 + shards (default: 1). shard_id (int, optional): which shard of the data iterator to - return. Default: 0 - buffer_size (int, optional): number of batches to buffer. Default: 5 + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). 
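+
+    Example (an illustrative sketch; ``dataset`` is any map-style dataset
+    with a ``collater`` method and ``batch_sampler`` an iterable of index
+    batches)::
+
+        itr = EpochBatchIterator(dataset, dataset.collater, batch_sampler)
+        for batch in itr.next_epoch_itr(shuffle=True):
+            pass  # train on batch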
""" def __init__( self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0, - buffer_size=5, + num_workers=0, ): assert isinstance(dataset, torch.utils.data.Dataset) self.dataset = dataset @@ -88,14 +90,12 @@ class EpochBatchIterator(object): self.seed = seed self.num_shards = num_shards self.shard_id = shard_id - self.buffer_size = buffer_size + self.num_workers = num_workers self.epoch = 0 self._cur_epoch_itr = None self._next_epoch_itr = None - self._supports_prefetch = ( - hasattr(dataset, 'supports_prefetch') and dataset.supports_prefetch - ) + self._supports_prefetch = getattr(dataset, 'supports_prefetch', False) def __len__(self): return len(self.frozen_batches) @@ -105,11 +105,10 @@ class EpochBatchIterator(object): Args: shuffle (bool, optional): shuffle batches before returning the - iterator. Default: ``True`` + iterator (default: True). fix_batches_to_gpus: ensure that batches are always allocated to the same shards across epochs. Requires - that :attr:`dataset` supports prefetching. Default: - ``False`` + that :attr:`dataset` supports prefetching (default: False). """ if self._next_epoch_itr is not None: self._cur_epoch_itr = self._next_epoch_itr @@ -117,7 +116,8 @@ class EpochBatchIterator(object): else: self.epoch += 1 self._cur_epoch_itr = self._get_iterator_for_epoch( - self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus) + self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus, + ) return self._cur_epoch_itr def end_of_epoch(self): @@ -179,50 +179,14 @@ class EpochBatchIterator(object): batches = self.frozen_batches batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) - return CountingIterator(BufferedIterator( - torch.utils.data.DataLoader( - self.dataset, - collate_fn=self.collate_fn, - batch_sampler=batches, - ), - buffer_size=self.buffer_size, + return CountingIterator(torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_sampler=batches, + num_workers=self.num_workers, )) -class BufferedIterator(object): - """Wrapper around an iterable that prefetches items into a buffer. - - Args: - iterable (iterable): iterable to wrap - buffer_size (int): number of items to prefetch and buffer - """ - - def __init__(self, iterable, buffer_size): - self.iterable = iterable - - self.q = queue.Queue(maxsize=buffer_size) - self.thread = threading.Thread(target=self._load_q, daemon=True) - self.thread.start() - - def __len__(self): - return len(self.iterable) - - def __iter__(self): - return self - - def __next__(self): - x = self.q.get() - if x is None: - self.thread.join() - raise StopIteration - return x[0] - - def _load_q(self): - for x in self.iterable: - self.q.put([x]) # wrap in list so that it's never None - self.q.put(None) - - class GroupedIterator(object): """Wrapper around an iterable that returns groups (chunks) of items. @@ -261,7 +225,7 @@ class ShardedIterator(object): num_shards (int): number of shards to split the iterable into shard_id (int): which shard to iterator over fill_value (Any, optional): padding value when the iterable doesn't - evenly divide *num_shards*. Default: ``None`` + evenly divide *num_shards* (default: None). 
""" def __init__(self, iterable, num_shards, shard_id, fill_value=None): diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 2c949acc..6c060d04 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -79,23 +79,23 @@ class LanguagePairDataset(FairseqDataset): tgt (torch.utils.data.Dataset, optional): target dataset to wrap tgt_sizes (List[int], optional): target sentence lengths tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary - left_pad_source (bool, optional): pad source tensors on the left side. - Default: ``True`` - left_pad_target (bool, optional): pad target tensors on the left side. - Default: ``False`` - max_source_positions (int, optional): max number of tokens in the source - sentence. Default: ``1024`` - max_target_positions (int, optional): max number of tokens in the target - sentence. Default: ``1024`` - shuffle (bool, optional): shuffle dataset elements before batching. - Default: ``True`` + left_pad_source (bool, optional): pad source tensors on the left side + (default: True). + left_pad_target (bool, optional): pad target tensors on the left side + (default: False). + max_source_positions (int, optional): max number of tokens in the + source sentence (default: 1024). + max_target_positions (int, optional): max number of tokens in the + target sentence (default: 1024). + shuffle (bool, optional): shuffle dataset elements before batching + (default: True). input_feeding (bool, optional): create a shifted version of the targets - to be passed into the model for input feeding/teacher forcing. - Default: ``True`` - remove_eos_from_source (bool, optional): if set, removes eos from end of - source if it's present. Default: ``False`` + to be passed into the model for input feeding/teacher forcing + (default: True). + remove_eos_from_source (bool, optional): if set, removes eos from end + of source if it's present (default: False). append_eos_to_target (bool, optional): if set, appends eos to end of - target if it's absent. Default: ``False`` + target if it's absent (default: False). """ def __init__( @@ -223,15 +223,13 @@ class LanguagePairDataset(FairseqDataset): indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] - def prefetch(self, indices): - self.src.prefetch(indices) - self.tgt.prefetch(indices) - @property def supports_prefetch(self): return ( - hasattr(self.src, 'supports_prefetch') - and self.src.supports_prefetch - and hasattr(self.tgt, 'supports_prefetch') - and self.tgt.supports_prefetch + getattr(self.src, 'supports_prefetch', False) + and getattr(self.tgt, 'supports_prefetch', False) ) + + def prefetch(self, indices): + self.src.prefetch(indices) + self.tgt.prefetch(indices) diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py index e6d056fd..f465406e 100644 --- a/fairseq/data/monolingual_dataset.py +++ b/fairseq/data/monolingual_dataset.py @@ -9,7 +9,6 @@ import numpy as np import torch from . import data_utils, FairseqDataset -from typing import List def collate(samples, pad_idx, eos_idx): @@ -53,8 +52,8 @@ class MonolingualDataset(FairseqDataset): dataset (torch.utils.data.Dataset): dataset to wrap sizes (List[int]): sentence lengths vocab (~fairseq.data.Dictionary): vocabulary - shuffle (bool, optional): shuffle the elements before batching. 
- Default: ``True`` + shuffle (bool, optional): shuffle the elements before batching + (default: True). """ def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle, @@ -66,8 +65,8 @@ self.add_eos_for_other_targets = add_eos_for_other_targets self.shuffle = shuffle - assert targets is None or all( - t in {'self', 'future', 'past'} for t in targets), "targets must be none or one of 'self', 'future', 'past'" + assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \ + "targets must be none or one of 'self', 'future', 'past'" if targets is not None and len(targets) == 0: targets = None self.targets = targets @@ -185,7 +184,7 @@ @property def supports_prefetch(self): - return self.dataset.supports_prefetch + return getattr(self.dataset, 'supports_prefetch', False) def prefetch(self, indices): self.dataset.prefetch(indices) diff --git a/fairseq/data/noising.py b/fairseq/data/noising.py index e0029ab9..e66fd753 100644 --- a/fairseq/data/noising.py +++ b/fairseq/data/noising.py @@ -245,11 +245,12 @@ class NoisingDataset(torch.utils.data.Dataset): **kwargs ): """ - Sets up a noising dataset which takes a src batch, generates - a noisy src using a noising config, and returns the - corresponding {noisy src, original src} batch + Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the + samples based on the supplied noising configuration. + Args: - src_dataset: dataset which will be used to build self.src_dataset -- + src_dataset (~torch.utils.data.Dataset): dataset to wrap; used + to build self.src_dataset -- a LanguagePairDataset with src dataset as the source dataset and None as the target dataset. Should NOT have padding so that src_lengths are accurately calculated by language_pair_dataset collate function. We use language_pair_dataset here to encapsulate the tgt_dataset so we can re-use the LanguagePairDataset collater to format the batches in the structure that SequenceGenerator expects. - src_dict: src dict - src_dict: src dictionary - seed: seed to use when generating random noise - noiser: a pre-initialized noiser. If this is None, a noiser will - be created using noising_class and kwargs. - noising_class: class to use when initializing noiser - kwargs: noising args for configuring noising to apply - Note that there is no equivalent argparse code for these args - anywhere in our top level train scripts yet. Integration is - still in progress. You can still, however, test out this dataset - functionality with the appropriate args as in the corresponding - unittest: test_noising_dataset. + src_dict (~fairseq.data.Dictionary): source dictionary + seed (int): seed to use when generating random noise + noiser (WordNoising): a pre-initialized :class:`WordNoising` + instance. If this is None, a new instance will be created using + *noising_class* and *kwargs*. + noising_class (class, optional): class to use to initialize a + default :class:`WordNoising` instance. + kwargs (dict, optional): arguments to initialize the default + :class:`WordNoising` instance created via *noising_class*.
""" - self.src_dataset = src_dataset self.src_dict = src_dict + self.seed = seed self.noiser = noiser if noiser is not None else noising_class( dictionary=src_dict, **kwargs, ) - self.seed = seed def __getitem__(self, index): """ diff --git a/fairseq/data/round_robin_zip_datasets.py b/fairseq/data/round_robin_zip_datasets.py index 2021f397..675352d1 100644 --- a/fairseq/data/round_robin_zip_datasets.py +++ b/fairseq/data/round_robin_zip_datasets.py @@ -13,13 +13,16 @@ from . import FairseqDataset class RoundRobinZipDatasets(FairseqDataset): - """Zip multiple FairseqDatasets together, repeating shorter datasets in a - round-robin fashion to match the length of the longest one. + """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together. + + Shorter datasets are repeated in a round-robin fashion to match the length + of the longest one. Args: - datasets: a dictionary of FairseqDatasets - eval_key: an optional key used at evaluation time that causes this - instance to pass-through batches from `datasets[eval_key]`. + datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of + :class:`~fairseq.data.FairseqDataset` instances. + eval_key (str, optional): a key used at evaluation time that causes + this instance to pass-through batches from *datasets[eval_key]*. """ def __init__(self, datasets, eval_key=None): @@ -107,3 +110,14 @@ class RoundRobinZipDatasets(FairseqDataset): dataset.valid_size(self._map_index(key, index), max_positions[key]) for key, dataset in self.datasets.items() ) + + @property + def supports_prefetch(self): + return all( + getattr(dataset, 'supports_prefetch', False) + for dataset in self.datasets.values() + ) + + def prefetch(self, indices): + for key, dataset in self.datasets.items(): + dataset.prefetch([self._map_index(key, index) for index in indices]) diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index ff4bc030..b2b3cf01 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -14,32 +14,32 @@ from . import FairseqDataset class TokenBlockDataset(FairseqDataset): - """Break a 1d tensor of tokens into blocks. - - The blocks are fetched from the original tensor so no additional memory is allocated. + """Break a Dataset of tokens into blocks. Args: - tokens: 1d tensor of tokens to break into blocks - sizes: sentence lengths (required for 'complete' and 'eos') - block_size: maximum block size (ignored in 'eos' break mode) - break_mode: Mode used for breaking tokens. Values can be one of: + dataset (~torch.utils.data.Dataset): dataset to break into blocks + sizes (List[int]): sentence lengths (required for 'complete' and 'eos') + block_size (int): maximum block size (ignored in 'eos' break mode) + break_mode (str, optional): Mode used for breaking tokens. Values can + be one of: - 'none': break tokens into equally sized blocks (up to block_size) - 'complete': break tokens into blocks (up to block_size) such that blocks contains complete sentences, although block_size may be exceeded if some sentences exceed block_size - 'eos': each block contains one sentence (block_size is ignored) - include_targets: return next tokens as targets + include_targets (bool, optional): return next tokens as targets + (default: False). 
""" - def __init__(self, ds, block_size, pad, eos, break_mode=None, include_targets=False): + def __init__(self, dataset, sizes, block_size, pad, eos, break_mode=None, include_targets=False): super().__init__() - self.dataset = ds + self.dataset = dataset self.pad = pad self.eos = eos self.include_targets = include_targets self.slice_indices = [] - self.cache_index = {} - sizes = ds.sizes + + assert len(dataset) == len(sizes) if break_mode is None or break_mode == 'none': total_size = sum(sizes) @@ -77,44 +77,66 @@ class TokenBlockDataset(FairseqDataset): self.sizes = np.array([e - s for s, e in self.slice_indices]) - def __getitem__(self, index): - s, e = self.cache_index[index] + # build index mapping block indices to the underlying dataset indices + self.block_to_dataset_index = [] + ds_idx, ds_remaining = -1, 0 + for to_consume in self.sizes: + if ds_remaining == 0: + ds_idx += 1 + ds_remaining = sizes[ds_idx] + start_ds_idx = ds_idx + start_offset = sizes[ds_idx] - ds_remaining + while to_consume > ds_remaining: + to_consume -= ds_remaining + ds_idx += 1 + ds_remaining = sizes[ds_idx] + ds_remaining -= to_consume + self.block_to_dataset_index.append(( + start_ds_idx, # starting index in dataset + start_offset, # starting offset within starting index + ds_idx, # ending index in dataset + )) + assert ds_remaining == 0 + assert ds_idx == len(self.dataset) - 1 - item = torch.from_numpy(self.cache[s:e]).long() + def __getitem__(self, index): + start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index] + buffer = torch.cat([ + self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1) + ]) + slice_s, slice_e = self.slice_indices[index] + length = slice_e - slice_s + s, e = start_offset, start_offset + length + item = buffer[s:e] if self.include_targets: - # target is the sentence, for source, rotate item one token to the left (would start with eos) - # past target is rotated to the left by 2 (padded if its first) + # *target* is the original sentence (=item) + # *source* is rotated left by 1 (maybe left-padded with eos) + # *past_target* is rotated left by 2 (left-padded as needed) if s == 0: - source = np.concatenate([[self.eos], self.cache[0:e - 1]]) - past_target = np.concatenate([[self.pad, self.eos], self.cache[0:e - 2]]) + source = torch.cat([item.new([self.eos]), buffer[0:e - 1]]) + past_target = torch.cat([item.new([self.pad, self.eos]), buffer[0:e - 2]]) else: - source = self.cache[s - 1: e - 1] + source = buffer[s - 1:e - 1] if s == 1: - past_target = np.concatenate([[self.eos], self.cache[0:e - 2]]) + past_target = torch.cat([item.new([self.eos]), buffer[0:e - 2]]) else: - past_target = self.cache[s - 2:e - 2] + past_target = buffer[s - 2:e - 2] - return torch.from_numpy(source).long(), item, torch.from_numpy(past_target).long() + return source, item, past_target return item def __len__(self): return len(self.slice_indices) - def prefetch(self, indices): - indices.sort() - total_size = 0 - for idx in indices: - s, e = self.slice_indices[idx] - total_size += e - s - self.cache = np.empty(total_size, dtype=np.int32) - start = 0 - for idx in indices: - s, e = self.slice_indices[idx] - self.dataset.read_into(s, self.cache[start:start + e - s]) - self.cache_index[idx] = (start, start + e - s) - start += e - s - @property def supports_prefetch(self): - return True + return getattr(self.dataset, 'supports_prefetch', False) + + def prefetch(self, indices): + self.dataset.prefetch({ + ds_idx + for index in indices + for start_ds_idx, _, end_ds_idx in 
[self.block_to_dataset_index[index]] + for ds_idx in range(start_ds_idx, end_ds_idx + 1) + }) diff --git a/fairseq/data/transform_eos_dataset.py b/fairseq/data/transform_eos_dataset.py index 6f8158ad..6ac4c6c8 100644 --- a/fairseq/data/transform_eos_dataset.py +++ b/fairseq/data/transform_eos_dataset.py @@ -11,7 +11,7 @@ from . import FairseqDataset class TransformEosDataset(FairseqDataset): - """A dataset wrapper that appends/prepends/strips EOS. + """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS. Note that the transformation is applied in :func:`collater`. @@ -111,7 +111,7 @@ class TransformEosDataset(FairseqDataset): @property def supports_prefetch(self): - return self.dataset.supports_prefetch() + return getattr(self.dataset, 'supports_prefetch', False) def prefetch(self, indices): return self.dataset.prefetch(indices) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 07a3625c..0916675b 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -6,7 +6,9 @@ # can be found in the PATENTS file in the same directory. from collections import namedtuple +import os import pickle +import subprocess import torch from torch import nn @@ -42,6 +44,38 @@ else: import torch.distributed as dist_no_c10d +def infer_init_method(args): + if args.distributed_init_method is not None: + return + + # support torch.distributed.launch + if all(key in os.environ for key in [ + 'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK' + ]): + args.distributed_init_method = 'tcp://{addr}:{port}'.format( + addr=os.environ['MASTER_ADDR'], + port=os.environ['MASTER_PORT'], + ) + args.distributed_world_size = int(os.environ['WORLD_SIZE']) + args.distributed_rank = int(os.environ['RANK']) + + # we can determine the init method automatically for Slurm + elif args.distributed_port > 0: + node_list = os.environ.get('SLURM_JOB_NODELIST') + if node_list is not None: + try: + hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list]) + args.distributed_init_method = 'tcp://{host}:{port}'.format( + host=hostnames.split()[0].decode('utf-8'), + port=args.distributed_port) + args.distributed_rank = int(os.environ.get('SLURM_PROCID')) + args.device_id = int(os.environ.get('SLURM_LOCALID')) + except subprocess.CalledProcessError as e: # scontrol failed + raise e + except FileNotFoundError: # Slurm is not installed + pass + + def distributed_init(args): if args.distributed_world_size == 1: raise ValueError('Cannot initialize distributed with distributed_world_size=1') @@ -158,7 +192,7 @@ def all_gather_list(data, group=None, max_size=16384): pickle.loads(bytes(out_buffer[2:size+2].tolist())) ) return result - except pickle.UnpicklingError as e: + except pickle.UnpicklingError: raise Exception( 'Unable to unpickle data from other workers. all_gather_list requires all ' 'workers to enter the function together, so this error usually indicates ' @@ -167,4 +201,3 @@ def all_gather_list(data, group=None, max_size=16384): 'in your training script that can cause one worker to finish an epoch ' 'while other workers are still iterating over their portions of the data.' ) - diff --git a/fairseq/legacy_distributed_data_parallel.py b/fairseq/legacy_distributed_data_parallel.py index 07858b72..289411fb 100644 --- a/fairseq/legacy_distributed_data_parallel.py +++ b/fairseq/legacy_distributed_data_parallel.py @@ -12,7 +12,7 @@ computation (e.g., AdaptiveSoftmax) and which therefore do not work with the c10d version of DDP. 
This version also supports the *accumulate_grads* feature, which allows faster -training with --update-freq. +training with `--update-freq`. """ import copy @@ -27,18 +27,18 @@ from . import distributed_utils class LegacyDistributedDataParallel(nn.Module): """Implements distributed data parallelism at the module level. - A simplified version of torch.nn.parallel.DistributedDataParallel. - This version uses a c10d process group for communication and does - not broadcast buffers. + A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`. + This version uses a c10d process group for communication and does not + broadcast buffers. Args: - module: module to be parallelized - world_size: number of parallel workers + module (~torch.nn.Module): module to be parallelized + world_size (int): number of parallel workers process_group (optional): the c10d process group to be used for distributed data all-reduction. If None, the default process group will be used. - buffer_size: number of elements to buffer before performing all-reduce - (default: 256M). + buffer_size (int, optional): number of elements to buffer before + performing all-reduce (default: 256M). """ def __init__(self, module, world_size, process_group=None, buffer_size=2**28): diff --git a/fairseq/models/fconv.py b/fairseq/models/fconv.py index cdef45b2..9e77d95a 100644 --- a/fairseq/models/fconv.py +++ b/fairseq/models/fconv.py @@ -179,10 +179,8 @@ class FConvEncoder(FairseqEncoder): connections are added between layers when ``residual=1`` (which is the default behavior). dropout (float, optional): dropout to be applied before each conv layer - normalization_constant (float, optional): multiplies the result of the - residual block by sqrt(value) - left_pad (bool, optional): whether the input is left-padded. Default: - ``True`` + left_pad (bool, optional): whether the input is left-padded + (default: True). 
""" def __init__( @@ -215,7 +213,7 @@ class FConvEncoder(FairseqEncoder): self.residuals = [] layer_in_channels = [in_channels] - for i, (out_channels, kernel_size, residual) in enumerate(convolutions): + for _, (out_channels, kernel_size, residual) in enumerate(convolutions): if residual == 0: residual_dim = out_channels else: diff --git a/fairseq/models/fconv_self_att.py b/fairseq/models/fconv_self_att.py index 19cbb184..1261f86f 100644 --- a/fairseq/models/fconv_self_att.py +++ b/fairseq/models/fconv_self_att.py @@ -524,6 +524,7 @@ def base_architecture(args): args.pretrained_checkpoint = getattr(args, 'pretrained_checkpoint', '') args.pretrained = getattr(args, 'pretrained', 'False') + @register_model_architecture('fconv_self_att', 'fconv_self_att_wp') def fconv_self_att_wp(args): args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256) diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py index 56cb2c3c..d4974b0e 100644 --- a/fairseq/models/lstm.py +++ b/fairseq/models/lstm.py @@ -196,7 +196,6 @@ class LSTMEncoder(FairseqEncoder): if bidirectional: self.output_units *= 2 - def forward(self, src_tokens, src_lengths): if self.left_pad: # convert left-padding to right-padding @@ -235,7 +234,8 @@ class LSTMEncoder(FairseqEncoder): if self.bidirectional: def combine_bidir(outs): - return outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous().view(self.num_layers, bsz, -1) + out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous() + return out.view(self.num_layers, bsz, -1) final_hiddens = combine_bidir(final_hiddens) final_cells = combine_bidir(final_cells) @@ -340,7 +340,6 @@ class LSTMDecoder(FairseqIncrementalDecoder): elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) - def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): encoder_out = encoder_out_dict['encoder_out'] encoder_padding_mask = encoder_out_dict['encoder_padding_mask'] @@ -504,6 +503,7 @@ def base_architecture(args): args.share_all_embeddings = getattr(args, 'share_all_embeddings', False) args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') + @register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en') def lstm_wiseman_iwslt_de_en(args): args.dropout = getattr(args, 'dropout', 0.1) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 20b3a97b..795a5fa7 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -219,7 +219,7 @@ class TransformerLanguageModel(FairseqLanguageModel): # make sure all arguments are present in older models base_lm_architecture(args) - if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj == False: + if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False: # backward compatibility args.tie_adaptive_proj = True @@ -229,15 +229,17 @@ class TransformerLanguageModel(FairseqLanguageModel): args.max_target_positions = args.tokens_per_sample if args.character_embeddings: - embed_tokens = CharacterTokenEmbedder(task.dictionary, eval(args.character_filters), - args.character_embedding_dim, - args.decoder_embed_dim, - args.char_embedder_highway_layers, - ) + embed_tokens = CharacterTokenEmbedder( + task.dictionary, eval(args.character_filters), + args.character_embedding_dim, args.decoder_embed_dim, + args.char_embedder_highway_layers, + ) elif args.adaptive_input: - embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), 
args.decoder_input_dim, - args.adaptive_input_factor, args.decoder_embed_dim, - options.eval_str_list(args.adaptive_input_cutoff, type=int)) + embed_tokens = AdaptiveInput( + len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim, + args.adaptive_input_factor, args.decoder_embed_dim, + options.eval_str_list(args.adaptive_input_cutoff, type=int), + ) else: embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad()) @@ -248,7 +250,9 @@ class TransformerLanguageModel(FairseqLanguageModel): args.adaptive_softmax_cutoff, args.adaptive_input_cutoff) assert args.decoder_input_dim == args.decoder_output_dim - decoder = TransformerDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False) + decoder = TransformerDecoder( + args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False, + ) return TransformerLanguageModel(decoder) @@ -261,8 +265,8 @@ class TransformerEncoder(FairseqEncoder): args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding - left_pad (bool, optional): whether the input is left-padded. Default: - ``True`` + left_pad (bool, optional): whether the input is left-padded + (default: True). """ def __init__(self, args, dictionary, embed_tokens, left_pad=True): @@ -382,10 +386,12 @@ class TransformerDecoder(FairseqIncrementalDecoder): args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding - no_encoder_attn (bool, optional): whether to attend to encoder outputs. - Default: ``False`` - left_pad (bool, optional): whether the input is left-padded. Default: - ``False`` + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + left_pad (bool, optional): whether the input is left-padded + (default: False). + final_norm (bool, optional): apply layer norm to the output of the + final decoder layer (default: True). """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True): @@ -634,8 +640,8 @@ class TransformerDecoderLayer(nn.Module): Args: args (argparse.Namespace): parsed command-line arguments - no_encoder_attn (bool, optional): whether to attend to encoder outputs. - Default: ``False`` + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). 
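+
+    Example (illustrative; *args* is a fully populated architecture
+    namespace, e.g. after ``base_architecture(args)``)::
+
+        layer = TransformerDecoderLayer(args, no_encoder_attn=True)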
""" def __init__(self, args, no_encoder_attn=False): diff --git a/fairseq/modules/adaptive_input.py b/fairseq/modules/adaptive_input.py index db8dce22..3ad8603d 100644 --- a/fairseq/modules/adaptive_input.py +++ b/fairseq/modules/adaptive_input.py @@ -7,7 +7,6 @@ import torch -import torch.nn.functional as F from torch import nn from typing import List @@ -16,13 +15,13 @@ from typing import List class AdaptiveInput(nn.Module): def __init__( - self, - vocab_size: int, - padding_idx: int, - initial_dim: int, - factor: float, - output_dim: int, - cutoff: List[int], + self, + vocab_size: int, + padding_idx: int, + initial_dim: int, + factor: float, + output_dim: int, + cutoff: List[int], ): super().__init__() diff --git a/fairseq/modules/adaptive_softmax.py b/fairseq/modules/adaptive_softmax.py index 50bafd74..1446f527 100644 --- a/fairseq/modules/adaptive_softmax.py +++ b/fairseq/modules/adaptive_softmax.py @@ -113,8 +113,9 @@ class AdaptiveSoftmax(nn.Module): m = nn.Sequential( proj, nn.Dropout(self.dropout), - nn.Linear(dim, self.cutoff[i + 1] - self.cutoff[i], bias=False) \ - if tied_emb is None else TiedLinear(tied_emb, transpose=False) + nn.Linear( + dim, self.cutoff[i + 1] - self.cutoff[i], bias=False, + ) if tied_emb is None else TiedLinear(tied_emb, transpose=False), ) self.tail.append(m) diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 2424d728..901eea86 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -9,7 +9,7 @@ import importlib import os from .fairseq_optimizer import FairseqOptimizer -from .fp16_optimizer import FP16Optimizer +from .fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer OPTIMIZER_REGISTRY = {} diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 22b26b6e..a73fba00 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -70,10 +70,11 @@ class FairseqOptimizer(object): group.update(optimizer_overrides) def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves.""" loss.backward() def multiply_grads(self, c): - """Multiplies grads by a constant ``c``.""" + """Multiplies grads by a constant *c*.""" for p in self.params: if p.grad is not None: p.grad.data.mul_(c) diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 19e6b57f..c1358bb2 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -45,6 +45,164 @@ class DynamicLossScaler(object): return False +class FP16Optimizer(optim.FairseqOptimizer): + """ + Wrap an *optimizer* to support FP16 (mixed precision) training. 
+ """ + + def __init__(self, args, params, fp32_optimizer, fp32_params): + super().__init__(args, params) + self.fp32_optimizer = fp32_optimizer + self.fp32_params = fp32_params + + if getattr(args, 'fp16_scale_window', None) is None: + if len(args.update_freq) > 1: + raise ValueError( + '--fp16-scale-window must be given explicitly when using a ' + 'custom --update-freq schedule' + ) + scale_window = 2**14 / args.distributed_world_size / args.update_freq[0] + else: + scale_window = args.fp16_scale_window + + self.scaler = DynamicLossScaler( + init_scale=args.fp16_init_scale, + scale_window=scale_window, + tolerance=args.fp16_scale_tolerance, + ) + + @classmethod + def build_optimizer(cls, args, params): + """ + Args: + args (argparse.Namespace): fairseq args + params (iterable): iterable of parameters to optimize + """ + # create FP32 copy of parameters and grads + total_param_size = sum(p.data.numel() for p in params) + fp32_params = params[0].new(0).float().new(total_param_size) + offset = 0 + for p in params: + numel = p.data.numel() + fp32_params[offset:offset+numel].copy_(p.data.view(-1)) + offset += numel + fp32_params = torch.nn.Parameter(fp32_params) + fp32_params.grad = fp32_params.data.new(total_param_size) + + fp32_optimizer = optim.build_optimizer(args, [fp32_params]) + return cls(args, params, fp32_optimizer, fp32_params) + + @property + def optimizer(self): + return self.fp32_optimizer.optimizer + + @property + def optimizer_config(self): + return self.fp32_optimizer.optimizer_config + + def get_lr(self): + return self.fp32_optimizer.get_lr() + + def set_lr(self, lr): + self.fp32_optimizer.set_lr(lr) + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = self.fp32_optimizer.state_dict() + state_dict['loss_scale'] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + if 'loss_scale' in state_dict: + self.scaler.loss_scale = state_dict['loss_scale'] + self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. 
+ """ + loss = loss * self.scaler.loss_scale + loss.backward() + self._needs_sync = True + + def _sync_fp16_grads_to_fp32(self, multiply_grads=1.): + if self._needs_sync: + # copy FP16 grads to FP32 + offset = 0 + for p in self.params: + if not p.requires_grad: + continue + grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape) + numel = grad_data.numel() + self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1)) + offset += numel + + # correct for dynamic loss scaler + self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale) + + self._needs_sync = False + + def multiply_grads(self, c): + """Multiplies grads by a constant ``c``.""" + if self._needs_sync: + self._sync_fp16_grads_to_fp32(c) + else: + self.fp32_params.grad.data.mul_(c) + + def clip_grad_norm(self, max_norm): + """Clips gradient norm and updates dynamic loss scaler.""" + self._sync_fp16_grads_to_fp32() + grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm) + + # detect overflow and adjust loss scale + overflow = DynamicLossScaler.has_overflow(grad_norm) + self.scaler.update_scale(overflow) + if overflow: + if self.scaler.loss_scale <= self.args.min_loss_scale: + # Use FloatingPointError as an uncommon error that parent + # functions can safely catch to stop training. + raise FloatingPointError(( + 'Minimum loss scale reached ({}). Your loss is probably exploding. ' + 'Try lowering the learning rate, using gradient clipping or ' + 'increasing the batch size.' + ).format(self.args.min_loss_scale)) + raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale)) + return grad_norm + + def step(self, closure=None): + """Performs a single optimization step.""" + self._sync_fp16_grads_to_fp32() + self.fp32_optimizer.step(closure) + + # copy FP32 params back into FP16 model + offset = 0 + for p in self.params: + if not p.requires_grad: + continue + numel = p.data.numel() + p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data)) + offset += numel + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self.fp32_optimizer.zero_grad() + for p in self.params: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + self._needs_sync = False + + class ConvertToFP32(object): """ A wrapper around a list of params that will convert them to FP32 on the @@ -94,14 +252,13 @@ class ConvertToFP32(object): raise StopIteration -class FP16Optimizer(optim.FairseqOptimizer): +class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer): """ Wrap an *optimizer* to support FP16 (mixed precision) training. - Args: - args (argparse.Namespace): fairseq args - params (iterable): iterable of parameters to optimize - optimizer (~fairseq.optim.FairseqOptimizer): optimizer to wrap + Compared to :class:`fairseq.optim.FP16Optimizer`, this version uses less + memory by copying between FP16 and FP32 parameters on-the-fly. The tradeoff + is reduced optimization speed, which can be mitigated with `--update-freq`. 
""" def __init__(self, args, params, optimizer): @@ -124,10 +281,15 @@ class FP16Optimizer(optim.FairseqOptimizer): tolerance=args.fp16_scale_tolerance, ) - @staticmethod - def build_optimizer(args, params): + @classmethod + def build_optimizer(cls, args, params): + """ + Args: + args (argparse.Namespace): fairseq args + params (iterable): iterable of parameters to optimize + """ fp16_optimizer = optim.build_optimizer(args, params) - return FP16Optimizer(args, params, fp16_optimizer) + return cls(args, params, fp16_optimizer) @property def optimizer(self): @@ -164,6 +326,12 @@ class FP16Optimizer(optim.FairseqOptimizer): ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer) def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ loss = loss * self.scaler.loss_scale loss.backward() self._grads_are_scaled = True @@ -178,7 +346,7 @@ class FP16Optimizer(optim.FairseqOptimizer): assert multiply_grads == 1. def multiply_grads(self, c): - """Multiplies grads by a constant ``c``.""" + """Multiplies grads by a constant *c*.""" if self._grads_are_scaled: self._unscale_grads(c) else: diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index 976fee72..4493a8b9 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -13,18 +13,25 @@ from . import FairseqLRScheduler, register_lr_scheduler @register_lr_scheduler('cosine') class CosineSchedule(FairseqLRScheduler): """Assign LR based on a cyclical schedule that follows the cosine function. - See https://arxiv.org/pdf/1608.03983.pdf for details + + See https://arxiv.org/pdf/1608.03983.pdf for details. + We also support a warmup phase where we linearly increase the learning rate - from some initial learning rate (`--warmup-init-lr`) until the configured - learning rate (`--lr`). - During warmup: + from some initial learning rate (``--warmup-init-lr``) until the configured + learning rate (``--lr``). + + During warmup:: + lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) lr = lrs[update_num] - After warmup: + + After warmup:: + lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) - where - t_curr is current percentage of updates within the current period range - t_i is the current period range, which is scaled by t_mul after every iteration + + where ``t_curr`` is current percentage of updates within the current period + range and ``t_i`` is the current period range, which is scaled by ``t_mul`` + after every iteration. 
""" def __init__(self, args, optimizer): @@ -39,7 +46,7 @@ class CosineSchedule(FairseqLRScheduler): if args.warmup_init_lr < 0: args.warmup_init_lr = args.lr[0] - self.min_lr = args.lr[0] + self.min_lr = args.lr[0] self.max_lr = args.max_lr assert self.max_lr > self.min_lr, 'max_lr must be more than lr' @@ -98,7 +105,7 @@ class CosineSchedule(FairseqLRScheduler): t_curr = curr_updates - (self.period * i) lr_shrink = self.lr_shrink ** i - min_lr = self.min_lr * lr_shrink + min_lr = self.min_lr * lr_shrink max_lr = self.max_lr * lr_shrink self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index b17e0f5f..a3a48f04 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -13,22 +13,19 @@ class InverseSquareRootSchedule(FairseqLRScheduler): """Decay the LR based on the inverse square root of the update number. We also support a warmup phase where we linearly increase the learning rate - from some initial learning rate (`--warmup-init-lr`) until the configured - learning rate (`--lr`). Thereafter we decay proportional to the number of + from some initial learning rate (``--warmup-init-lr``) until the configured + learning rate (``--lr``). Thereafter we decay proportional to the number of updates, with a decay factor set to align with the configured learning rate. - During warmup: + During warmup:: lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) lr = lrs[update_num] - After warmup: - - lr = decay_factor / sqrt(update_num) - - where + After warmup:: decay_factor = args.lr * sqrt(args.warmup_updates) + lr = decay_factor / sqrt(update_num) """ def __init__(self, args, optimizer): diff --git a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py index 66deba31..7b3307a0 100644 --- a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py @@ -14,8 +14,7 @@ from . import FairseqLRScheduler, register_lr_scheduler class TriangularSchedule(FairseqLRScheduler): """Assign LR based on a triangular cyclical schedule. - See https://arxiv.org/pdf/1506.01186.pdf for details - + See https://arxiv.org/pdf/1506.01186.pdf for details. """ def __init__(self, args, optimizer): diff --git a/fairseq/options.py b/fairseq/options.py index e71f22b9..900eaae8 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -107,6 +107,8 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False): args.update_freq = eval_str_list(args.update_freq, type=int) if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None: args.max_sentences_valid = args.max_sentences + if getattr(args, 'memory_efficient_fp16', False): + args.fp16 = True # Apply architecture configuration. 
if hasattr(args, 'arch'): @@ -128,7 +130,10 @@ def get_parser(desc, default_task='translation'): choices=['json', 'none', 'simple', 'tqdm']) parser.add_argument('--seed', default=1, type=int, metavar='N', help='pseudo random number generator seed') + parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA') parser.add_argument('--fp16', action='store_true', help='use FP16') + parser.add_argument('--memory-efficient-fp16', action='store_true', + help='use a memory-efficient version of FP16 training; implies --fp16') parser.add_argument('--fp16-init-scale', default=2**7, type=int, help='default FP16 loss scale') parser.add_argument('--fp16-scale-window', type=int, @@ -147,6 +152,8 @@ def get_parser(desc, default_task='translation'): def add_dataset_args(parser, train=False, gen=False): group = parser.add_argument_group('Dataset and data loading') # fmt: off + group.add_argument('--num-workers', default=0, type=int, metavar='N', + help='how many subprocesses to use for data loading') group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true', help='ignore too long or too short lines in valid and test set') group.add_argument('--max-tokens', type=int, metavar='N', @@ -178,7 +185,7 @@ def add_distributed_training_args(parser): group = parser.add_argument_group('Distributed training') # fmt: off group.add_argument('--distributed-world-size', type=int, metavar='N', - default=torch.cuda.device_count(), + default=max(1, torch.cuda.device_count()), help='total number of GPUs across all nodes (default: all visible GPUs)') group.add_argument('--distributed-rank', default=0, type=int, help='rank of the current worker') @@ -189,7 +196,7 @@ def add_distributed_training_args(parser): 'establish initial connetion') group.add_argument('--distributed-port', default=-1, type=int, help='port number (not required if using --distributed-init-method)') - group.add_argument('--device-id', default=0, type=int, + group.add_argument('--device-id', '--local_rank', default=0, type=int, help='which GPU to use (usually configured automatically)') group.add_argument('--ddp-backend', default='c10d', type=str, choices=['c10d', 'no_c10d'], @@ -197,8 +204,8 @@ def add_distributed_training_args(parser): group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB', help='bucket size for reduction') group.add_argument('--fix-batches-to-gpus', action='store_true', - help='Don\'t shuffle batches between GPUs, this reduces overall ' - 'randomness and may affect precision but avoids the cost of' + help='don\'t shuffle batches between GPUs; this reduces overall ' + 'randomness and may affect precision but avoids the cost of ' 're-reading the data') # fmt: on return group @@ -263,7 +270,9 @@ def add_checkpoint_args(parser): group.add_argument('--save-interval-updates', type=int, default=0, metavar='N', help='save a checkpoint (and validate) every N updates') group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N', - help='keep last N checkpoints saved with --save-interval-updates') + help='keep the last N checkpoints saved with --save-interval-updates') + group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N', + help='keep last N epoch checkpoints') group.add_argument('--no-save', action='store_true', help='don\'t save models or checkpoints') group.add_argument('--no-epoch-checkpoints', action='store_true', @@ -280,11 +289,11 @@ def add_common_eval_args(group): help='path(s) to model file(s), colon separated') 
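+    # e.g. ``--path ckpt/model1.pt:ckpt/model2.pt`` loads a two-model
+    # ensemble (the paths here are hypothetical)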
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None, help='remove BPE tokens before scoring') - group.add_argument('--cpu', action='store_true', help='generate on CPU') group.add_argument('--quiet', action='store_true', help='only print final scores') group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT', - help='a dictionary used to override model args at generation that were used during model training') + help='a dictionary used to override model args at generation ' + 'that were used during model training') # fmt: on diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 2bce2e82..76300d72 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -22,6 +22,7 @@ class SequenceGenerator(object): match_source_len=False, no_repeat_ngram_size=0 ): """Generates translations of a given source sentence. + Args: beam_size (int, optional): beam width (default: 1) min/maxlen (int, optional): the length of the generated output will @@ -90,11 +91,14 @@ class SequenceGenerator(object): cuda=False, timer=None, prefix_size=0, ): """Iterate over a batched dataset and yield individual translations. + Args: - maxlen_a/b: generate sequences of maximum length ax + b, - where x is the source sentence length. - cuda: use GPU for generation - timer: StopwatchMeter for timing generations. + maxlen_a/b (int, optional): generate sequences of maximum length + ``ax + b``, where ``x`` is the source sentence length. + cuda (bool, optional): use GPU for generation + timer (StopwatchMeter, optional): time generations + prefix_size (int, optional): prefill the generation with the gold + prefix up to this length. """ if maxlen_b is None: maxlen_b = self.maxlen @@ -132,12 +136,13 @@ class SequenceGenerator(object): """Generate a batch of translations. Args: - encoder_input: dictionary containing the inputs to - model.encoder.forward - beam_size: int overriding the beam size. defaults to - self.beam_size - max_len: maximum length of the generated sequence - prefix_tokens: force decoder to begin with these tokens + encoder_input (dict): dictionary containing the inputs to + *model.encoder.forward*. + beam_size (int, optional): overriding the beam size + (default: *self.beam_size*). + max_len (int, optional): maximum length of the generated sequence + prefix_tokens (LongTensor, optional): force decoder to begin with + these tokens """ with torch.no_grad(): return self._generate(encoder_input, beam_size, maxlen, prefix_tokens) diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py index ce6ab31a..2495088e 100644 --- a/fairseq/sequence_scorer.py +++ b/fairseq/sequence_scorer.py @@ -87,4 +87,3 @@ class SequenceScorer(object): index=sample['target'].data.unsqueeze(-1), ) return avg_probs.squeeze(2), avg_attn - diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 430a49a2..acf4771c 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) 2017-present, Facebook, Inc. # All rights reserved. # -# This source code is licensed under the license found in the LICENSE file in # the root directory of this source tree. An additional grant of patent rights +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. 
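+
+# A hedged usage sketch of the task registry defined in this package
+# (``MyTask`` and ``'my_task'`` are hypothetical names):
+#
+#     from fairseq.tasks import FairseqTask, register_task
+#
+#     @register_task('my_task')
+#     class MyTask(FairseqTask):
+#         ...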
import argparse diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 0c54e505..a8983618 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -61,29 +61,32 @@ class FairseqTask(object): def get_batch_iterator( self, dataset, max_tokens=None, max_sentences=None, max_positions=None, ignore_invalid_inputs=False, required_batch_size_multiple=1, - seed=1, num_shards=1, shard_id=0, + seed=1, num_shards=1, shard_id=0, num_workers=0, ): """ Get an iterator that yields batches of data from the given dataset. Args: dataset (~fairseq.data.FairseqDataset): dataset to batch - max_tokens (int, optional): max number of tokens in each batch. - Default: ``None`` + max_tokens (int, optional): max number of tokens in each batch + (default: None). max_sentences (int, optional): max number of sentences in each - batch. Default: ``None`` + batch (default: None). max_positions (optional): max sentence length supported by the - model. Default: ``None`` + model (default: None). ignore_invalid_inputs (bool, optional): don't raise Exception for - sentences that are too long. Default: ``False`` + sentences that are too long (default: False). required_batch_size_multiple (int, optional): require batch size to - be a multiple of N. Default: ``1`` + be a multiple of N (default: 1). seed (int, optional): seed for random number generator for - reproducibility. Default: ``1`` + reproducibility (default: 1). num_shards (int, optional): shard the data iterator into N - shards. Default: ``1`` + shards (default: 1). shard_id (int, optional): which shard of the data iterator to - return. Default: ``0`` + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). Returns: ~fairseq.iterators.EpochBatchIterator: a batched iterator over the @@ -114,6 +117,7 @@ class FairseqTask(object): seed=seed, num_shards=num_shards, shard_id=shard_id, + num_workers=num_workers, ) def build_model(self, args): diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 882ad3bf..bfbac862 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -10,9 +10,15 @@ import numpy as np import os from fairseq.data import ( - ConcatDataset, Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset, - MonolingualDataset, TokenBlockDataset, TruncatedDictionary, - IndexedCachedDataset, IndexedDataset) + ConcatDataset, + Dictionary, + IndexedCachedDataset, + IndexedDataset, + IndexedRawTextDataset, + MonolingualDataset, + TokenBlockDataset, + TruncatedDictionary, +) from . 
import FairseqTask, register_task @@ -60,6 +66,8 @@ class LanguageModelingTask(FairseqTask): 'If set to "eos", includes only one sentence per sample.') parser.add_argument('--tokens-per-sample', default=1024, type=int, help='max number of tokens per sample for LM dataset') + parser.add_argument('--lazy-load', action='store_true', + help='load the dataset lazily') parser.add_argument('--raw-text', default=False, action='store_true', help='load raw text dataset') parser.add_argument('--output-dictionary-size', default=-1, type=int, @@ -139,7 +147,10 @@ class LanguageModelingTask(FairseqTask): if self.args.raw_text and IndexedRawTextDataset.exists(path): ds = IndexedRawTextDataset(path, self.dictionary) elif not self.args.raw_text and IndexedDataset.exists(path): - ds = IndexedDataset(path, fix_lua_indexing=True) + if self.args.lazy_load: + ds = IndexedDataset(path, fix_lua_indexing=True) + else: + ds = IndexedCachedDataset(path, fix_lua_indexing=True) else: if k > 0: break @@ -148,9 +159,11 @@ class LanguageModelingTask(FairseqTask): loaded_datasets.append( TokenBlockDataset( - ds, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(), + ds, ds.sizes, self.args.tokens_per_sample, + pad=self.dictionary.pad(), eos=self.dictionary.eos(), break_mode=self.args.sample_break_mode, include_targets=True, - )) + ) + ) print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1]))) diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index b37ae479..69fcb124 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -12,8 +12,12 @@ import torch from fairseq import options from fairseq.data import ( - Dictionary, LanguagePairDataset, IndexedInMemoryDataset, - IndexedRawTextDataset, RoundRobinZipDatasets, + Dictionary, + IndexedCachedDataset, + IndexedDataset, + IndexedRawTextDataset, + LanguagePairDataset, + RoundRobinZipDatasets, ) from fairseq.models import FairseqMultiModel @@ -55,6 +59,8 @@ class MultilingualTranslationTask(FairseqTask): help='source language (only needed for inference)') parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language (only needed for inference)') + parser.add_argument('--lazy-load', action='store_true', + help='load the dataset lazily') parser.add_argument('--raw-text', action='store_true', help='load raw text dataset') parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', @@ -112,15 +118,18 @@ class MultilingualTranslationTask(FairseqTask): filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang)) if self.args.raw_text and IndexedRawTextDataset.exists(filename): return True - elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename): + elif not self.args.raw_text and IndexedDataset.exists(filename): return True return False def indexed_dataset(path, dictionary): if self.args.raw_text: return IndexedRawTextDataset(path, dictionary) - elif IndexedInMemoryDataset.exists(path): - return IndexedInMemoryDataset(path, fix_lua_indexing=True) + elif IndexedDataset.exists(path): + if self.args.lazy_load: + return IndexedDataset(path, fix_lua_indexing=True) + else: + return IndexedCachedDataset(path, fix_lua_indexing=True) return None def sort_lang_pair(lang_pair): diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 7d4e62f0..f14189a7 100644 --- a/fairseq/tasks/translation.py +++ 
b/fairseq/tasks/translation.py @@ -6,13 +6,17 @@ # can be found in the PATENTS file in the same directory. import itertools -import numpy as np import os from fairseq import options, utils from fairseq.data import ( - data_utils, Dictionary, LanguagePairDataset, ConcatDataset, - IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset + ConcatDataset, + data_utils, + Dictionary, + IndexedCachedDataset, + IndexedDataset, + IndexedRawTextDataset, + LanguagePairDataset, ) from . import FairseqTask, register_task @@ -49,6 +53,8 @@ class TranslationTask(FairseqTask): help='source language') parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language') + parser.add_argument('--lazy-load', action='store_true', + help='load the dataset lazily') parser.add_argument('--raw-text', action='store_true', help='load raw text dataset') parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', @@ -132,7 +138,10 @@ class TranslationTask(FairseqTask): if self.args.raw_text: return IndexedRawTextDataset(path, dictionary) elif IndexedDataset.exists(path): - return IndexedCachedDataset(path, fix_lua_indexing=True) + if self.args.lazy_load: + return IndexedDataset(path, fix_lua_indexing=True) + else: + return IndexedCachedDataset(path, fix_lua_indexing=True) return None src_datasets = [] diff --git a/fairseq/tokenizer.py b/fairseq/tokenizer.py index 378c1f29..80c1b4f4 100644 --- a/fairseq/tokenizer.py +++ b/fairseq/tokenizer.py @@ -6,10 +6,11 @@ # can be found in the PATENTS file in the same directory. from collections import Counter -import os, re +from multiprocessing import Pool +import os +import re import torch -from multiprocessing import Pool SPACE_NORMALIZER = re.compile(r"\s+") @@ -27,7 +28,8 @@ def safe_readline(f): return f.readline() except UnicodeDecodeError: pos -= 1 - f.seek(pos) # search where this character begins + f.seek(pos) # search where this character begins + class Tokenizer: @@ -41,7 +43,7 @@ class Tokenizer: end = offset + chunk_size f.seek(offset) if offset > 0: - safe_readline(f) # drop first incomplete line + safe_readline(f) # drop first incomplete line line = f.readline() while line: for word in tokenize(line): @@ -73,14 +75,17 @@ class Tokenizer: merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word)) @staticmethod - def binarize(filename, dict, consumer, tokenize=tokenize_line, - append_eos=True, reverse_order=False, - offset=0, end=-1): + def binarize( + filename, dict, consumer, tokenize=tokenize_line, append_eos=True, + reverse_order=False, offset=0, end=-1, + ): nseq, ntok = 0, 0 replaced = Counter() + def replaced_consumer(word, idx): if idx == dict.unk_index and word != dict.unk_word: replaced.update([word]) + with open(filename, 'r') as f: f.seek(offset) # next(f) breaks f.tell(), hence readline() must be used diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 64089ccf..9fe70ebc 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -30,22 +30,23 @@ class Trainer(object): """ def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None): - - if not torch.cuda.is_available(): - raise NotImplementedError('Training on CPU is not supported') - self.args = args self.task = task # copy model and criterion to current device - self.criterion = criterion.cuda() + self.criterion = criterion + self._model = model + self.cuda = torch.cuda.is_available() and not args.cpu if args.fp16: - self._model = model.half().cuda() - else: - self._model = 
model.cuda() + self._model = self._model.half() + if self.cuda: + self.criterion = self.criterion.cuda() + self._model = self._model.cuda() self._dummy_batch = dummy_batch self._oom_batch = oom_batch + + self._lr_scheduler = None self._num_updates = 0 self._optim_history = None self._optimizer = None @@ -71,7 +72,6 @@ class Trainer(object): self.meters['wall'] = TimeMeter() # wall time in seconds self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds - @property def model(self): if self._wrapped_model is None: @@ -89,19 +89,26 @@ class Trainer(object): self._build_optimizer() return self._optimizer + @property + def lr_scheduler(self): + if self._lr_scheduler is None: + self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer) + return self._lr_scheduler + def _build_optimizer(self): + params = list(filter(lambda p: p.requires_grad, self.model.parameters())) if self.args.fp16: - if torch.cuda.get_device_capability(0)[0] < 7: + if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: print('| WARNING: your device does NOT support faster training with --fp16, ' 'please switch to FP32 which is likely to be faster') - params = list(filter(lambda p: p.requires_grad, self.model.parameters())) - self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) + if self.args.memory_efficient_fp16: + self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params) + else: + self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) else: - if torch.cuda.get_device_capability(0)[0] >= 7: + if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: print('| NOTICE: your device may support faster training with --fp16') - self._optimizer = optim.build_optimizer(self.args, self.model.parameters()) - - self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer) + self._optimizer = optim.build_optimizer(self.args, params) def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" @@ -151,7 +158,8 @@ class Trainer(object): # reproducible results when resuming from checkpoints seed = self.args.seed + self.get_num_updates() torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + if self.cuda: + torch.cuda.manual_seed(seed) self.model.train() self.zero_grad() @@ -296,7 +304,8 @@ class Trainer(object): for p in self.model.parameters(): if p.grad is not None: del p.grad # free some memory - torch.cuda.empty_cache() + if self.cuda: + torch.cuda.empty_cache() return self.valid_step(sample, raise_oom=True) else: raise e @@ -377,4 +386,6 @@ class Trainer(object): def _prepare_sample(self, sample): if sample is None or len(sample) == 0: return None - return utils.move_to_cuda(sample) + if self.cuda: + sample = utils.move_to_cuda(sample) + return sample diff --git a/fairseq/utils.py b/fairseq/utils.py index 6dfe2b21..6a40e0e2 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -378,6 +378,14 @@ def item(tensor): return tensor +def clip_grad_norm_(tensor, max_norm): + grad_norm = item(torch.norm(tensor)) + if grad_norm > max_norm > 0: + clip_coef = max_norm / (grad_norm + 1e-6) + tensor.mul_(clip_coef) + return grad_norm + + def fill_with_neg_inf(t): """FP16-compatible function that fills a tensor with -inf.""" return t.float().fill_(float('-inf')).type_as(t) diff --git a/fb_train.py b/fb_train.py new file mode 100644 index 00000000..4befddc6 --- /dev/null +++ b/fb_train.py @@ -0,0 +1,18 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. 
+# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import torch.fb.rendezvous.zeus # noqa: F401 + +from fairseq import options + +from train import main + + +if __name__ == '__main__': + parser = options.get_training_parser() + args = options.parse_args_and_arch(parser) + main(args) diff --git a/generate.py b/generate.py index 211ed37d..39c9eb8d 100644 --- a/generate.py +++ b/generate.py @@ -11,7 +11,7 @@ Translate pre-processed data with a trained model. import torch -from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils +from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.sequence_generator import SequenceGenerator from fairseq.sequence_scorer import SequenceScorer @@ -41,7 +41,9 @@ def main(args): # Load ensemble print('| loading model(s) from {}'.format(args.path)) - models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides)) + models, _model_args = utils.load_ensemble_for_inference( + args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides), + ) # Optimize ensemble for generation for model in models: @@ -69,6 +71,7 @@ def main(args): required_batch_size_multiple=8, num_shards=args.num_shards, shard_id=args.shard_id, + num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) # Initialize generator diff --git a/interactive.py b/interactive.py index 78c77182..30680b3d 100644 --- a/interactive.py +++ b/interactive.py @@ -75,8 +75,9 @@ def main(args): # Load ensemble print('| loading model(s) from {}'.format(args.path)) - model_paths = args.path.split(':') - models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides)) + models, _model_args = utils.load_ensemble_for_inference( + args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides), + ) # Set dictionaries tgt_dict = task.target_dictionary diff --git a/multiprocessing_train.py b/multiprocessing_train.py deleted file mode 100644 index 91ab87a4..00000000 --- a/multiprocessing_train.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 -u -# Copyright (c) 2017-present, Facebook, Inc. -# All rights reserved. -# -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. - -import os -import signal -import torch - -from fairseq import distributed_utils, options - -from train import main as single_process_main - - -def main(args): - if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d': - print('| WARNING: when using --update-freq on a single machine, you ' - 'will get better performance with --ddp-backend=no_c10d') - - mp = torch.multiprocessing.get_context('spawn') - - # Create a thread to listen for errors in the child processes. - error_queue = mp.SimpleQueue() - error_handler = ErrorHandler(error_queue) - - # Train with multiprocessing. 
- procs = [] - base_rank = args.distributed_rank - for i in range(torch.cuda.device_count()): - args.distributed_rank = base_rank + i - args.device_id = i - procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True)) - procs[i].start() - error_handler.add_child(procs[i].pid) - for p in procs: - p.join() - - -def run(args, error_queue): - try: - args.distributed_rank = distributed_utils.distributed_init(args) - single_process_main(args) - except KeyboardInterrupt: - pass # killed by parent, do nothing - except Exception: - # propagate exception to parent process, keeping original traceback - import traceback - error_queue.put((args.distributed_rank, traceback.format_exc())) - - -class ErrorHandler(object): - """A class that listens for exceptions in children processes and propagates - the tracebacks to the parent process.""" - - def __init__(self, error_queue): - import signal - import threading - self.error_queue = error_queue - self.children_pids = [] - self.error_thread = threading.Thread(target=self.error_listener, daemon=True) - self.error_thread.start() - signal.signal(signal.SIGUSR1, self.signal_handler) - - def add_child(self, pid): - self.children_pids.append(pid) - - def error_listener(self): - (rank, original_trace) = self.error_queue.get() - self.error_queue.put((rank, original_trace)) - os.kill(os.getpid(), signal.SIGUSR1) - - def signal_handler(self, signalnum, stackframe): - for pid in self.children_pids: - os.kill(pid, signal.SIGINT) # kill children processes - (rank, original_trace) = self.error_queue.get() - msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n" - msg += original_trace - raise Exception(msg) - - -if __name__ == '__main__': - parser = options.get_training_parser() - args = options.parse_args_and_arch(parser) - main(args) diff --git a/preprocess.py b/preprocess.py index 95ae4d64..56fa0f83 100644 --- a/preprocess.py +++ b/preprocess.py @@ -18,7 +18,7 @@ import shutil from fairseq.data import indexed_dataset, dictionary from fairseq.tokenizer import Tokenizer, tokenize_line -from multiprocessing import Pool, Manager, Process +from multiprocessing import Pool def get_parser(): diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 7d86f2f7..406876e2 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -50,6 +50,14 @@ class TestTranslation(unittest.TestCase): train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16']) generate_main(data_dir) + def test_memory_efficient_fp16(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_memory_efficient_fp16') as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--memory-efficient-fp16']) + generate_main(data_dir) + def test_update_freq(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_update_freq') as data_dir: @@ -68,8 +76,7 @@ class TestTranslation(unittest.TestCase): data_dir, 'fconv_iwslt_de_en', ['--max-target-positions', '5'], ) self.assertTrue( - 'skip this example with --skip-invalid-size-inputs-valid-test' \ - in str(context.exception) + 'skip this example with --skip-invalid-size-inputs-valid-test' in str(context.exception) ) train_translation_model( data_dir, 'fconv_iwslt_de_en', diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 980ca971..aee56e45 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -12,10 +12,6 @@ 
import os import tempfile import unittest -import torch - -from fairseq import options - from . import test_binaries @@ -79,6 +75,12 @@ class TestReproducibility(unittest.TestCase): '--fp16-init-scale', '4096', ]) + def test_reproducibility_memory_efficient_fp16(self): + self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [ + '--memory-efficient-fp16', + '--fp16-init-scale', '4096', + ]) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_train.py b/tests/test_train.py index 2d9b3168..cfffc3ed 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -39,8 +39,10 @@ def mock_dict(): def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch): - tokens = torch.LongTensor(list(range(epoch_size))) - tokens_ds = data.TokenBlockDataset(tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False) + tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1) + tokens_ds = data.TokenBlockDataset( + tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False, + ) trainer = mock_trainer(epoch, num_updates, iterations_in_epoch) dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False) epoch_itr = data.EpochBatchIterator( @@ -64,7 +66,6 @@ class TestLoadCheckpoint(unittest.TestCase): self.applied_patches = [patch(p, d) for p, d in self.patches.items()] [p.start() for p in self.applied_patches] - def test_load_partial_checkpoint(self): with contextlib.redirect_stdout(StringIO()): trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50) diff --git a/train.py b/train.py index 4cee08c7..3475a5eb 100644 --- a/train.py +++ b/train.py @@ -28,9 +28,8 @@ def main(args): args.max_tokens = 6000 print(args) - if not torch.cuda.is_available(): - raise NotImplementedError('Training on CPU is not supported') - torch.cuda.set_device(args.device_id) + if torch.cuda.is_available() and not args.cpu: + torch.cuda.set_device(args.device_id) torch.manual_seed(args.seed) # Setup task, e.g., translation, language modeling, etc. 
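+    # e.g. ``--task translation`` resolves to TranslationTask via
+    # tasks.setup_task(args), which loads the dictionaries and datasets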
@@ -74,6 +73,7 @@ def main(args): seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, + num_workers=args.num_workers, ) # Load the latest checkpoint if one is available @@ -211,6 +211,7 @@ def validate(args, trainer, task, epoch_itr, subsets): seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, + num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, @@ -306,7 +307,15 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): # remove old checkpoints; checkpoints are sorted in descending order checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt') for old_chk in checkpoints[args.keep_interval_updates:]: - os.remove(old_chk) + if os.path.lexists(old_chk): + os.remove(old_chk) + + if args.keep_last_epochs > 0: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint\d+\.pt') + for old_chk in checkpoints[args.keep_last_epochs:]: + if os.path.lexists(old_chk): + os.remove(old_chk) def load_checkpoint(args, trainer, epoch_itr): @@ -346,23 +355,50 @@ def load_dataset_splits(task, splits): raise e +def distributed_main(i, args): + import socket + args.device_id = i + if args.distributed_rank is None: # torch.multiprocessing.spawn + args.distributed_rank = i + args.distributed_rank = distributed_utils.distributed_init(args) + print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) + main(args) + + if __name__ == '__main__': parser = options.get_training_parser() args = options.parse_args_and_arch(parser) - if args.distributed_port > 0 or args.distributed_init_method is not None: - from distributed_train import main as distributed_main + if args.distributed_init_method is None: + distributed_utils.infer_init_method(args) - distributed_main(args) + if args.distributed_init_method is not None: + # distributed training + distributed_main(args.device_id, args) + args.distributed_rank = distributed_utils.distributed_init(args) + main(args) elif args.distributed_world_size > 1: - from multiprocessing_train import main as multiprocessing_main - - # Set distributed training parameters for a single node. - args.distributed_world_size = torch.cuda.device_count() + # fallback for single node with multiple GPUs port = random.randint(10000, 20000) args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) - args.distributed_port = port + 1 + args.distributed_rank = None # set based on device id + print( + '''| NOTE: you may get better performance with: - multiprocessing_main(args) + python -m torch.distributed.launch --nproc_per_node {ngpu} train.py {no_c10d}(...) + '''.format( + ngpu=args.distributed_world_size, + no_c10d=( + '--ddp-backend=no_c10d ' if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d' + else '' + ), + ) + ) + torch.multiprocessing.spawn( + fn=distributed_main, + args=(args, ), + nprocs=args.distributed_world_size, + ) else: + # single GPU training main(args)
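
For context, the single-node fallback added above follows the standard
``torch.multiprocessing.spawn`` pattern. A minimal self-contained sketch of
that pattern (not fairseq code; ``worker``, the port and the sizes are
illustrative)::

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def worker(rank, world_size):
        # each spawned process receives its index as ``rank``, joins the
        # same TCP rendezvous, and binds to one GPU
        dist.init_process_group(
            backend='nccl', init_method='tcp://localhost:12345',
            world_size=world_size, rank=rank,
        )
        torch.cuda.set_device(rank)
        # ... build the model/optimizer and run the training loop here ...


    if __name__ == '__main__':
        world_size = torch.cuda.device_count()
        mp.spawn(fn=worker, args=(world_size,), nprocs=world_size)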