Merge internal changes (#283)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/283

Pull Request resolved: https://github.com/pytorch/fairseq/pull/428

Differential Revision: D13564190

Pulled By: myleott

fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
Myle Ott 2019-01-04 20:00:49 -08:00 committed by Facebook Github Bot
parent 0cb87130e7
commit 7633129ba8
59 changed files with 837 additions and 555 deletions

View File

@@ -19,10 +19,13 @@ of various sequence-to-sequence models, including:
 Fairseq features:
 - multi-GPU (distributed) training on one machine or across multiple machines
-- fast beam search generation on both CPU and GPU
+- fast generation on both CPU and GPU with multiple search algorithms implemented:
+  - beam search
+  - Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
+  - sampling (unconstrained and top-k)
 - large mini-batch training even on a single GPU via delayed updates
 - fast half-precision floating point (FP16) training
-- extensible: easily register new models, criterions, and tasks
+- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers

 We also provide [pre-trained models](#pre-trained-models) for several benchmark
 translation and language modeling datasets.

@@ -34,7 +37,7 @@ translation and language modeling datasets.
 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
 * Python version 3.6
-Currently fairseq requires PyTorch version >= 0.4.0.
+Currently fairseq requires PyTorch version >= 1.0.0.
 Please follow the instructions here: https://github.com/pytorch/pytorch#installation.
 If you use Docker make sure to increase the shared memory size either with

View File

@@ -1,45 +0,0 @@
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os
import socket
import subprocess

from train import main as single_process_main
from fairseq import distributed_utils, options


def main(args):
    if args.distributed_init_method is None and args.distributed_port > 0:
        # We can determine the init method automatically for Slurm.
        node_list = os.environ.get('SLURM_JOB_NODELIST')
        if node_list is not None:
            try:
                hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                args.distributed_init_method = 'tcp://{host}:{port}'.format(
                    host=hostnames.split()[0].decode('utf-8'),
                    port=args.distributed_port)
                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
                args.device_id = int(os.environ.get('SLURM_LOCALID'))
            except subprocess.CalledProcessError as e:  # scontrol failed
                raise e
            except FileNotFoundError as e:  # Slurm is not installed
                pass

    if args.distributed_init_method is None and args.distributed_port is None:
        raise ValueError('--distributed-init-method or --distributed-port '
                         'must be specified for distributed training')

    args.distributed_rank = distributed_utils.distributed_init(args)
    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
    single_process_main(args)


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)

View File

@@ -6,8 +6,26 @@
 Criterions
 ==========

+Criterions compute the loss function given the model and batch, roughly::
+
+  loss = criterion(model, batch)
+
 .. automodule:: fairseq.criterions
     :members:

 .. autoclass:: fairseq.criterions.FairseqCriterion
     :members:
     :undoc-members:
+
+.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
+    :members:
+    :undoc-members:
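
For orientation, a rough sketch of what a registered criterion looks like, modeled on the
cross-entropy criterion documented above; the ``forward(model, sample)`` contract follows the
``loss = criterion(model, batch)`` summary, but exact signatures in this fairseq version may differ::

    import torch.nn.functional as F

    from fairseq import utils
    from fairseq.criterions import FairseqCriterion, register_criterion


    @register_criterion('my_cross_entropy')  # hypothetical plug-in name
    class MyCrossEntropyCriterion(FairseqCriterion):

        def forward(self, model, sample, reduce=True):
            # run the model and score the targets with token-level log-probabilities
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            loss = F.nll_loss(
                lprobs.view(-1, lprobs.size(-1)), target.view(-1),
                ignore_index=self.padding_idx,
                reduction='sum' if reduce else 'none',
            )
            sample_size = sample['ntokens']
            logging_output = {
                'loss': utils.item(loss.data) if reduce else loss.data,
                'ntokens': sample['ntokens'],
                'sample_size': sample_size,
            }
            return loss, sample_size, logging_output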

View File

@@ -21,6 +21,20 @@ mini-batches.
 .. autoclass:: fairseq.data.MonolingualDataset
     :members:

+**Helper Datasets**
+
+These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
+provide additional functionality:
+
+.. autoclass:: fairseq.data.BacktranslationDataset
+    :members:
+
+.. autoclass:: fairseq.data.ConcatDataset
+    :members:
+
+.. autoclass:: fairseq.data.RoundRobinZipDatasets
+    :members:
+
+.. autoclass:: fairseq.data.TransformEosDataset
+    :members:
+
 Dictionary
 ----------

@@ -32,6 +46,8 @@ Dictionary
 Iterators
 ---------

+.. autoclass:: fairseq.data.BufferedIterator
+    :members:
+
 .. autoclass:: fairseq.data.CountingIterator
     :members:

 .. autoclass:: fairseq.data.EpochBatchIterator

View File

@@ -27,21 +27,20 @@ interactively. Here, we use a beam size of 5:
 > MODEL_DIR=wmt14.en-fr.fconv-py
 > python interactive.py \
   --path $MODEL_DIR/model.pt $MODEL_DIR \
-  --beam 5
+  --beam 5 --source-lang en --target-lang fr
 | loading model(s) from wmt14.en-fr.fconv-py/model.pt
 | [en] dictionary: 44206 types
 | [fr] dictionary: 44463 types
 | Type the input sentence and press return:
 > Why is it rare to discover new marine mam@@ mal species ?
 O Why is it rare to discover new marine mam@@ mal species ?
-H -0.06429661810398102 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
-A 0 1 3 3 5 6 6 8 8 8 7 11 12
+H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
+P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288

-This generation script produces four types of outputs: a line prefixed
-with *S* shows the supplied source sentence after applying the
-vocabulary; *O* is a copy of the original source sentence; *H* is the
-hypothesis along with an average log-likelihood; and *A* is the
-attention maxima for each word in the hypothesis, including the
+This generation script produces three types of outputs: a line prefixed
+with *O* is a copy of the original source sentence; *H* is the
+hypothesis along with an average log-likelihood; and *P* is the
+positional score per token position, including the
 end-of-sentence marker which is omitted from the text.

 See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a

View File

@@ -6,7 +6,29 @@
 Learning Rate Schedulers
 ========================

-TODO
+Learning Rate Schedulers update the learning rate over the course of training.
+Learning rates can be updated after each update via :func:`step_update` or at
+epoch boundaries via :func:`step`.

 .. automodule:: fairseq.optim.lr_scheduler
     :members:

+.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.reduce_angular_lr_scheduler.TriangularSchedule
+    :members:
+    :undoc-members:
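
As a concrete illustration of one schedule listed above, here is a self-contained sketch of the
usual inverse-square-root warmup rule; the exact parameterization of ``InverseSquareRootSchedule``
may differ, so treat the defaults below as assumptions::

    import math

    def inverse_sqrt_lr(num_updates, peak_lr=5e-4, warmup_updates=4000, warmup_init_lr=1e-7):
        """Linear warmup to peak_lr, then decay proportional to 1/sqrt(num_updates)."""
        if num_updates < warmup_updates:
            step = (peak_lr - warmup_init_lr) / warmup_updates
            return warmup_init_lr + num_updates * step
        decay_factor = peak_lr * math.sqrt(warmup_updates)
        return decay_factor / math.sqrt(num_updates)

    print(inverse_sqrt_lr(100), inverse_sqrt_lr(4000), inverse_sqrt_lr(16000))

A per-update schedule like this one is driven through :func:`step_update`, while a schedule such as
``ReduceLROnPlateau`` reacts to validation loss at epoch boundaries through :func:`step`.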

View File

@@ -1,8 +1,8 @@
 Modules
 =======

-Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be
-helpful when implementing a new :class:`FairseqModel`.
+Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
+be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.

 .. automodule:: fairseq.modules
     :members:

View File

@@ -6,5 +6,27 @@
 Optimizers
 ==========

+Optimizers update the Model parameters based on the gradients.
+
 .. automodule:: fairseq.optim
     :members:

+.. autoclass:: fairseq.optim.FairseqOptimizer
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.adagrad.Adagrad
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.adam.FairseqAdam
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.nag.FairseqNAG
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.sgd.SGD
+    :members:
+    :undoc-members:
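
The one-line summary above maps onto the ``optimizer.backward(loss)`` / ``optimizer.step()`` split
used in the overview pseudocode below. A minimal standalone sketch of that flow with a plain
``torch.optim`` optimizer (for illustration only, not the fairseq wrapper API)::

    import torch

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    x, y = torch.randn(4, 8), torch.randint(0, 2, (4,))
    loss = torch.nn.functional.cross_entropy(model(x), y)

    optimizer.zero_grad()
    loss.backward()                                           # the wrapper exposes this as optimizer.backward(loss)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)   # cf. average_and_clip_gradients() in the overview
    optimizer.step()                                          # update parameters from the accumulated gradients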

View File

@@ -22,12 +22,18 @@ fairseq implements the following high-level training flow::
     for epoch in range(num_epochs):
         itr = task.get_batch_iterator(task.dataset('train'))
         for num_updates, batch in enumerate(itr):
-            loss = criterion(model, batch)
-            optimizer.backward(loss)
+            task.train_step(batch, model, criterion, optimizer)
+            average_and_clip_gradients()
             optimizer.step()
             lr_scheduler.step_update(num_updates)
         lr_scheduler.step(epoch)

+where the default implementation for ``train.train_step`` is roughly::
+
+    def train_step(self, batch, model, criterion, optimizer):
+        loss = criterion(model, batch)
+        optimizer.backward(loss)
+
 **Registering new plug-ins**

 New plug-ins are *registered* through a set of ``@register`` function
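
The ``@register`` decorators referred to here are the same ones visible later in this diff, e.g.
``@register_model_architecture('fconv_self_att', 'fconv_self_att_wp')``. A hedged sketch of
registering a new named architecture for the existing ``lstm`` model (the architecture name and
hyperparameters below are made up for illustration)::

    from fairseq.models import register_model_architecture
    from fairseq.models.lstm import base_architecture

    @register_model_architecture('lstm', 'lstm_tiny_example')
    def lstm_tiny_example(args):
        # only fill in values the user did not already set on the command line
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 64)
        args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 64)
        args.dropout = getattr(args, 'dropout', 0.1)
        base_architecture(args)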

View File

@@ -353,17 +353,16 @@ The model files should appear in the :file:`checkpoints/` directory.
 -------------------------------

 Finally we can write a short script to evaluate our model on new inputs. Create
-a new file named :file:`eval_classify.py` with the following contents::
+a new file named :file:`eval_classifier.py` with the following contents::

     from fairseq import data, options, tasks, utils
     from fairseq.tokenizer import Tokenizer

     # Parse command-line arguments for generation
-    parser = options.get_generation_parser()
+    parser = options.get_generation_parser(default_task='simple_classification')
     args = options.parse_args_and_arch(parser)

     # Setup task
-    args.task = 'simple_classification'
     task = tasks.setup_task(args)

     # Load model

View File

@ -55,7 +55,9 @@ def main(parsed_args):
# Load ensemble # Load ensemble
print('| loading model(s) from {}'.format(parsed_args.path)) print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides)) models, args = utils.load_ensemble_for_inference(
parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
)
for arg in vars(parsed_args).keys(): for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}: if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
@ -83,9 +85,10 @@ def main(parsed_args):
max_positions=utils.resolve_max_positions(*[ max_positions=utils.resolve_max_positions(*[
model.max_positions() for model in models model.max_positions() for model in models
]), ]),
ignore_invalid_inputs=True,
num_shards=args.num_shards, num_shards=args.num_shards,
shard_id=args.shard_id, shard_id=args.shard_id,
ignore_invalid_inputs=True, num_workers=args.num_workers,
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter() gen_timer = StopwatchMeter()

View File

@ -9,7 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset from .fairseq_dataset import FairseqDataset
from .backtranslation_dataset import BacktranslationDataset from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset from .monolingual_dataset import MonolingualDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets from .round_robin_zip_datasets import RoundRobinZipDatasets
@ -33,7 +33,6 @@ __all__ = [
'GroupedIterator', 'GroupedIterator',
'IndexedCachedDataset', 'IndexedCachedDataset',
'IndexedDataset', 'IndexedDataset',
'IndexedInMemoryDataset',
'IndexedRawTextDataset', 'IndexedRawTextDataset',
'LanguagePairDataset', 'LanguagePairDataset',
'MonolingualDataset', 'MonolingualDataset',

View File

@ -56,6 +56,28 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
class BacktranslationDataset(FairseqDataset): class BacktranslationDataset(FairseqDataset):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation function (*backtranslation_fn*),
and returns the corresponding `{generated src, input tgt}` batch.
Args:
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
backtranslated. Only the source side of this dataset will be used.
After backtranslation, the source sentences in this dataset will be
returned as the targets.
backtranslation_fn (callable): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
max_len_a, max_len_b (int, int): will be used to compute
`maxlen = max_len_a * src_len + max_len_b`, which will be passed
into *backtranslation_fn*.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch
(default: ``tgt_dataset.collater``).
cuda: use GPU for generation
"""
def __init__( def __init__(
self, self,
tgt_dataset, tgt_dataset,
@ -66,27 +88,6 @@ class BacktranslationDataset(FairseqDataset):
cuda=True, cuda=True,
**kwargs **kwargs
): ):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation function (*backtranslation_fn*),
and returns the corresponding `{generated src, input tgt}` batch.
Args:
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
backtranslated. Only the source side of this dataset will be
used. After backtranslation, the source sentences in this
dataset will be returned as the targets.
backtranslation_fn (callable): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
max_len_a, max_len_b (int, int): will be used to compute
`maxlen = max_len_a * src_len + max_len_b`, which will be
passed into *backtranslation_fn*.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch (default:
``tgt_dataset.collater``)
cuda: use GPU for generation
"""
self.tgt_dataset = tgt_dataset self.tgt_dataset = tgt_dataset
self.backtranslation_fn = backtranslation_fn self.backtranslation_fn = backtranslation_fn
self.max_len_a = max_len_a self.max_len_a = max_len_a
@ -166,11 +167,10 @@ class BacktranslationDataset(FairseqDataset):
""" """
tgt_size = self.tgt_dataset.size(index)[0] tgt_size = self.tgt_dataset.size(index)[0]
return (tgt_size, tgt_size) return (tgt_size, tgt_size)
@property @property
def supports_prefetch(self): def supports_prefetch(self):
return self.tgt_dataset.supports_prefetch() return getattr(self.tgt_dataset, 'supports_prefetch', False)
def prefetch(self, indices): def prefetch(self, indices):
return self.tgt_dataset.prefetch(indices) return self.tgt_dataset.prefetch(indices)

View File

@ -29,18 +29,18 @@ class ConcatDataset(FairseqDataset):
if isinstance(sample_ratios, int): if isinstance(sample_ratios, int):
sample_ratios = [sample_ratios] * len(self.datasets) sample_ratios = [sample_ratios] * len(self.datasets)
self.sample_ratios = sample_ratios self.sample_ratios = sample_ratios
self.cummulative_sizes = self.cumsum(self.datasets, sample_ratios) self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.real_sizes = [len(d) for d in self.datasets] self.real_sizes = [len(d) for d in self.datasets]
def __len__(self): def __len__(self):
return self.cummulative_sizes[-1] return self.cumulative_sizes[-1]
def __getitem__(self, idx): def __getitem__(self, idx):
dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx) dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0: if dataset_idx == 0:
sample_idx = idx sample_idx = idx
else: else:
sample_idx = idx - self.cummulative_sizes[dataset_idx - 1] sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
sample_idx = sample_idx % self.real_sizes[dataset_idx] sample_idx = sample_idx % self.real_sizes[dataset_idx]
return self.datasets[dataset_idx][sample_idx] return self.datasets[dataset_idx][sample_idx]
@ -54,7 +54,7 @@ class ConcatDataset(FairseqDataset):
def prefetch(self, indices): def prefetch(self, indices):
frm = 0 frm = 0
for to, ds in zip(self.cummulative_sizes, self.datasets): for to, ds in zip(self.cumulative_sizes, self.datasets):
real_size = len(ds) real_size = len(ds)
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
frm = to frm = to
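
The renamed ``cumulative_sizes`` list drives a standard bisect lookup that maps a global index to a
(dataset, local index) pair, with upsampled datasets wrapping around. A self-contained sketch of the
same arithmetic, independent of fairseq::

    import bisect
    import itertools

    datasets = [['a0', 'a1', 'a2'], ['b0', 'b1']]
    sample_ratios = [1, 2]                       # second dataset is upsampled 2x
    cumulative_sizes = list(itertools.accumulate(
        len(d) * r for d, r in zip(datasets, sample_ratios)
    ))                                           # -> [3, 7]

    def lookup(idx):
        dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
        sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % len(datasets[dataset_idx])   # wrap around for upsampled datasets
        return datasets[dataset_idx][sample_idx]

    assert [lookup(i) for i in range(7)] == ['a0', 'a1', 'a2', 'b0', 'b1', 'b0', 'b1']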

View File

@ -81,8 +81,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
size_fn (callable): function that returns the size of a given index size_fn (callable): function that returns the size of a given index
max_positions (tuple): filter elements larger than this size. max_positions (tuple): filter elements larger than this size.
Comparisons are done component-wise. Comparisons are done component-wise.
raise_exception (bool, optional): if ``True``, raise an exception raise_exception (bool, optional): if ``True``, raise an exception if
if any elements are filtered. Default: ``False`` any elements are filtered (default: False).
""" """
def check_size(idx): def check_size(idx):
if isinstance(max_positions, float) or isinstance(max_positions, int): if isinstance(max_positions, float) or isinstance(max_positions, int):
@ -128,12 +128,12 @@ def batch_by_size(
indices (List[int]): ordered list of dataset indices indices (List[int]): ordered list of dataset indices
num_tokens_fn (callable): function that returns the number of tokens at num_tokens_fn (callable): function that returns the number of tokens at
a given index a given index
max_tokens (int, optional): max number of tokens in each batch. max_tokens (int, optional): max number of tokens in each batch
Default: ``None`` (default: None).
max_sentences (int, optional): max number of sentences in each max_sentences (int, optional): max number of sentences in each
batch. Default: ``None`` batch (default: None).
required_batch_size_multiple (int, optional): require batch size to required_batch_size_multiple (int, optional): require batch size to
be a multiple of N. Default: ``1`` be a multiple of N (default: 1).
""" """
max_tokens = max_tokens if max_tokens is not None else float('Inf') max_tokens = max_tokens if max_tokens is not None else float('Inf')
max_sentences = max_sentences if max_sentences is not None else float('Inf') max_sentences = max_sentences if max_sentences is not None else float('Inf')
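
A self-contained sketch of the batching policy these arguments describe (cap each batch by a padded
token budget and a sentence budget); it mirrors the documented semantics rather than the exact
fairseq implementation::

    def batch_by_size_sketch(indices, num_tokens_fn, max_tokens=None, max_sentences=None):
        max_tokens = max_tokens if max_tokens is not None else float('inf')
        max_sentences = max_sentences if max_sentences is not None else float('inf')
        batches, batch, batch_max_len = [], [], 0
        for idx in indices:
            batch_max_len = max(batch_max_len, num_tokens_fn(idx))
            # padded size if this sentence were added to the current batch
            too_big = (len(batch) + 1) * batch_max_len > max_tokens or len(batch) + 1 > max_sentences
            if batch and too_big:
                batches.append(batch)
                batch, batch_max_len = [], num_tokens_fn(idx)
            batch.append(idx)
        if batch:
            batches.append(batch)
        return batches

    # sentence lengths 3, 4, 2, 9 with a 10-token budget -> [[0, 1], [2], [3]]
    lengths = [3, 4, 2, 9]
    print(batch_by_size_sketch(range(len(lengths)), lambda i: lengths[i], max_tokens=10))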

View File

@ -200,11 +200,15 @@ class Dictionary(object):
t[-1] = self.eos() t[-1] = self.eos()
return t return t
class TruncatedDictionary(object): class TruncatedDictionary(object):
def __init__(self, wrapped_dict, length): def __init__(self, wrapped_dict, length):
self.__class__ = type(wrapped_dict.__class__.__name__, self.__class__ = type(
(self.__class__, wrapped_dict.__class__), {}) wrapped_dict.__class__.__name__,
(self.__class__, wrapped_dict.__class__),
{}
)
self.__dict__ = wrapped_dict.__dict__ self.__dict__ = wrapped_dict.__dict__
self.wrapped_dict = wrapped_dict self.wrapped_dict = wrapped_dict
self.length = min(len(self.wrapped_dict), length) self.length = min(len(self.wrapped_dict), length)

View File

@ -7,8 +7,6 @@
import torch.utils.data import torch.utils.data
from fairseq.data import data_utils
class FairseqDataset(torch.utils.data.Dataset): class FairseqDataset(torch.utils.data.Dataset):
"""A dataset that provides helpers for batching.""" """A dataset that provides helpers for batching."""
@ -51,7 +49,9 @@ class FairseqDataset(torch.utils.data.Dataset):
@property @property
def supports_prefetch(self): def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return False return False
def prefetch(self, indices): def prefetch(self, indices):
"""Prefetch the data required for this epoch."""
raise NotImplementedError raise NotImplementedError

View File

@ -52,13 +52,12 @@ def data_file_path(prefix_path):
class IndexedDataset(torch.utils.data.Dataset): class IndexedDataset(torch.utils.data.Dataset):
"""Loader for TorchNet IndexedDataset""" """Loader for TorchNet IndexedDataset"""
def __init__(self, path, fix_lua_indexing=False, read_data=True): def __init__(self, path, fix_lua_indexing=False):
super().__init__() super().__init__()
self.fix_lua_indexing = fix_lua_indexing self.fix_lua_indexing = fix_lua_indexing
self.read_index(path) self.read_index(path)
self.data_file = None self.data_file = None
if read_data: self.path = path
self.read_data(path)
def read_index(self, path): def read_index(self, path):
with open(index_file_path(path), 'rb') as f: with open(index_file_path(path), 'rb') as f:
@ -85,8 +84,10 @@ class IndexedDataset(torch.utils.data.Dataset):
self.data_file.close() self.data_file.close()
def __getitem__(self, i): def __getitem__(self, i):
if not self.data_file:
self.read_data(self.path)
self.check_index(i) self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] tensor_size = int(self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]])
a = np.empty(tensor_size, dtype=self.dtype) a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a) self.data_file.readinto(a)
@ -98,12 +99,6 @@ class IndexedDataset(torch.utils.data.Dataset):
def __len__(self): def __len__(self):
return self.size return self.size
def read_into(self, start, dst):
self.data_file.seek(start * self.element_size)
self.data_file.readinto(dst)
if self.fix_lua_indexing:
dst -= 1 # subtract 1 for 0-based indexing
@staticmethod @staticmethod
def exists(path): def exists(path):
return ( return (
@ -111,11 +106,15 @@ class IndexedDataset(torch.utils.data.Dataset):
os.path.exists(data_file_path(path)) os.path.exists(data_file_path(path))
) )
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset): class IndexedCachedDataset(IndexedDataset):
def __init__(self, path, fix_lua_indexing=False): def __init__(self, path, fix_lua_indexing=False):
super().__init__(path, fix_lua_indexing, True) super().__init__(path, fix_lua_indexing=fix_lua_indexing)
self.cache = None self.cache = None
self.cache_index = {} self.cache_index = {}
@ -126,6 +125,8 @@ class IndexedCachedDataset(IndexedDataset):
def prefetch(self, indices): def prefetch(self, indices):
if all(i in self.cache_index for i in indices): if all(i in self.cache_index for i in indices):
return return
if not self.data_file:
self.read_data(self.path)
indices = sorted(set(indices)) indices = sorted(set(indices))
total_size = 0 total_size = 0
for i in indices: for i in indices:
@ -153,34 +154,7 @@ class IndexedCachedDataset(IndexedDataset):
return item return item
class IndexedInMemoryDataset(IndexedDataset): class IndexedRawTextDataset(torch.utils.data.Dataset):
"""Loader for TorchNet IndexedDataset, keeps all the data in memory"""
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb')
self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
self.data_file.readinto(self.buffer)
self.data_file.close()
if self.fix_lua_indexing:
self.buffer -= 1 # subtract 1 for 0-based indexing
def read_into(self, start, dst):
if self.token_blob is None:
self.token_blob = [t for l in self.tokens_list for t in l]
np.copyto(dst, self.token_blob[start:])
def __del__(self):
pass
def __getitem__(self, i):
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
return torch.from_numpy(a).long()
class IndexedRawTextDataset(IndexedDataset):
"""Takes a text file as input and binarizes it in memory at instantiation. """Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory""" Original lines are also kept in memory"""
@ -205,6 +179,10 @@ class IndexedRawTextDataset(IndexedDataset):
self.sizes.append(len(tokens)) self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes) self.sizes = np.array(self.sizes)
def check_index(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
def __getitem__(self, i): def __getitem__(self, i):
self.check_index(i) self.check_index(i)
return self.tokens_list[i] return self.tokens_list[i]
@ -252,7 +230,7 @@ class IndexedDatasetBuilder(object):
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def merge_file_(self, another_file): def merge_file_(self, another_file):
index = IndexedDataset(another_file, read_data=False) index = IndexedDataset(another_file)
assert index.dtype == self.dtype assert index.dtype == self.dtype
begin = self.data_offsets[-1] begin = self.data_offsets[-1]

View File

@ -69,17 +69,19 @@ class EpochBatchIterator(object):
batch_sampler (~torch.utils.data.Sampler): an iterator over batches of batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
indices indices
seed (int, optional): seed for random number generator for seed (int, optional): seed for random number generator for
reproducibility. Default: 1 reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N num_shards (int, optional): shard the data iterator into N
shards. Default: 1 shards (default: 1).
shard_id (int, optional): which shard of the data iterator to shard_id (int, optional): which shard of the data iterator to
return. Default: 0 return (default: 0).
buffer_size (int, optional): number of batches to buffer. Default: 5 num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
""" """
def __init__( def __init__(
self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0, self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
buffer_size=5, num_workers=0,
): ):
assert isinstance(dataset, torch.utils.data.Dataset) assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset self.dataset = dataset
@ -88,14 +90,12 @@ class EpochBatchIterator(object):
self.seed = seed self.seed = seed
self.num_shards = num_shards self.num_shards = num_shards
self.shard_id = shard_id self.shard_id = shard_id
self.buffer_size = buffer_size self.num_workers = num_workers
self.epoch = 0 self.epoch = 0
self._cur_epoch_itr = None self._cur_epoch_itr = None
self._next_epoch_itr = None self._next_epoch_itr = None
self._supports_prefetch = ( self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)
hasattr(dataset, 'supports_prefetch') and dataset.supports_prefetch
)
def __len__(self): def __len__(self):
return len(self.frozen_batches) return len(self.frozen_batches)
@ -105,11 +105,10 @@ class EpochBatchIterator(object):
Args: Args:
shuffle (bool, optional): shuffle batches before returning the shuffle (bool, optional): shuffle batches before returning the
iterator. Default: ``True`` iterator (default: True).
fix_batches_to_gpus: ensure that batches are always fix_batches_to_gpus: ensure that batches are always
allocated to the same shards across epochs. Requires allocated to the same shards across epochs. Requires
that :attr:`dataset` supports prefetching. Default: that :attr:`dataset` supports prefetching (default: False).
``False``
""" """
if self._next_epoch_itr is not None: if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr self._cur_epoch_itr = self._next_epoch_itr
@ -117,7 +116,8 @@ class EpochBatchIterator(object):
else: else:
self.epoch += 1 self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch( self._cur_epoch_itr = self._get_iterator_for_epoch(
self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus) self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
)
return self._cur_epoch_itr return self._cur_epoch_itr
def end_of_epoch(self): def end_of_epoch(self):
@ -179,50 +179,14 @@ class EpochBatchIterator(object):
batches = self.frozen_batches batches = self.frozen_batches
batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
return CountingIterator(BufferedIterator( return CountingIterator(torch.utils.data.DataLoader(
torch.utils.data.DataLoader( self.dataset,
self.dataset, collate_fn=self.collate_fn,
collate_fn=self.collate_fn, batch_sampler=batches,
batch_sampler=batches, num_workers=self.num_workers,
),
buffer_size=self.buffer_size,
)) ))
class BufferedIterator(object):
"""Wrapper around an iterable that prefetches items into a buffer.
Args:
iterable (iterable): iterable to wrap
buffer_size (int): number of items to prefetch and buffer
"""
def __init__(self, iterable, buffer_size):
self.iterable = iterable
self.q = queue.Queue(maxsize=buffer_size)
self.thread = threading.Thread(target=self._load_q, daemon=True)
self.thread.start()
def __len__(self):
return len(self.iterable)
def __iter__(self):
return self
def __next__(self):
x = self.q.get()
if x is None:
self.thread.join()
raise StopIteration
return x[0]
def _load_q(self):
for x in self.iterable:
self.q.put([x]) # wrap in list so that it's never None
self.q.put(None)
class GroupedIterator(object): class GroupedIterator(object):
"""Wrapper around an iterable that returns groups (chunks) of items. """Wrapper around an iterable that returns groups (chunks) of items.
@ -261,7 +225,7 @@ class ShardedIterator(object):
num_shards (int): number of shards to split the iterable into num_shards (int): number of shards to split the iterable into
shard_id (int): which shard to iterator over shard_id (int): which shard to iterator over
fill_value (Any, optional): padding value when the iterable doesn't fill_value (Any, optional): padding value when the iterable doesn't
evenly divide *num_shards*. Default: ``None`` evenly divide *num_shards* (default: None).
""" """
def __init__(self, iterable, num_shards, shard_id, fill_value=None): def __init__(self, iterable, num_shards, shard_id, fill_value=None):
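
The rewritten iterator above hands batching off to the standard :class:`torch.utils.data.DataLoader`
``batch_sampler`` path, with ``num_workers`` controlling subprocess loading. A minimal standalone
illustration of that PyTorch API, independent of fairseq::

    import torch
    from torch.utils.data import DataLoader, Dataset

    class SquaresDataset(Dataset):
        def __len__(self):
            return 10

        def __getitem__(self, i):
            return torch.tensor([i, i * i])

    # explicit batches of indices, analogous to the frozen/sharded batches above
    batches = [[0, 1, 2], [3, 4], [5, 6, 7, 8, 9]]

    loader = DataLoader(
        SquaresDataset(),
        collate_fn=torch.stack,   # stack per-sample tensors into one batch tensor
        batch_sampler=batches,    # each element is one batch of dataset indices
        num_workers=0,            # >0 loads batches in subprocesses
    )

    for batch in loader:
        print(batch.shape)        # (3, 2), (2, 2), (5, 2)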

View File

@ -79,23 +79,23 @@ class LanguagePairDataset(FairseqDataset):
tgt (torch.utils.data.Dataset, optional): target dataset to wrap tgt (torch.utils.data.Dataset, optional): target dataset to wrap
tgt_sizes (List[int], optional): target sentence lengths tgt_sizes (List[int], optional): target sentence lengths
tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
left_pad_source (bool, optional): pad source tensors on the left side. left_pad_source (bool, optional): pad source tensors on the left side
Default: ``True`` (default: True).
left_pad_target (bool, optional): pad target tensors on the left side. left_pad_target (bool, optional): pad target tensors on the left side
Default: ``False`` (default: False).
max_source_positions (int, optional): max number of tokens in the source max_source_positions (int, optional): max number of tokens in the
sentence. Default: ``1024`` source sentence (default: 1024).
max_target_positions (int, optional): max number of tokens in the target max_target_positions (int, optional): max number of tokens in the
sentence. Default: ``1024`` target sentence (default: 1024).
shuffle (bool, optional): shuffle dataset elements before batching. shuffle (bool, optional): shuffle dataset elements before batching
Default: ``True`` (default: True).
input_feeding (bool, optional): create a shifted version of the targets input_feeding (bool, optional): create a shifted version of the targets
to be passed into the model for input feeding/teacher forcing. to be passed into the model for input feeding/teacher forcing
Default: ``True`` (default: True).
remove_eos_from_source (bool, optional): if set, removes eos from end of remove_eos_from_source (bool, optional): if set, removes eos from end
source if it's present. Default: ``False`` of source if it's present (default: False).
append_eos_to_target (bool, optional): if set, appends eos to end of append_eos_to_target (bool, optional): if set, appends eos to end of
target if it's absent. Default: ``False`` target if it's absent (default: False).
""" """
def __init__( def __init__(
@ -223,15 +223,13 @@ class LanguagePairDataset(FairseqDataset):
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
def prefetch(self, indices):
self.src.prefetch(indices)
self.tgt.prefetch(indices)
@property @property
def supports_prefetch(self): def supports_prefetch(self):
return ( return (
hasattr(self.src, 'supports_prefetch') getattr(self.src, 'supports_prefetch', False)
and self.src.supports_prefetch and getattr(self.tgt, 'supports_prefetch', False)
and hasattr(self.tgt, 'supports_prefetch')
and self.tgt.supports_prefetch
) )
def prefetch(self, indices):
self.src.prefetch(indices)
self.tgt.prefetch(indices)

View File

@ -9,7 +9,6 @@ import numpy as np
import torch import torch
from . import data_utils, FairseqDataset from . import data_utils, FairseqDataset
from typing import List
def collate(samples, pad_idx, eos_idx): def collate(samples, pad_idx, eos_idx):
@ -53,8 +52,8 @@ class MonolingualDataset(FairseqDataset):
dataset (torch.utils.data.Dataset): dataset to wrap dataset (torch.utils.data.Dataset): dataset to wrap
sizes (List[int]): sentence lengths sizes (List[int]): sentence lengths
vocab (~fairseq.data.Dictionary): vocabulary vocab (~fairseq.data.Dictionary): vocabulary
shuffle (bool, optional): shuffle the elements before batching. shuffle (bool, optional): shuffle the elements before batching
Default: ``True`` (default: True).
""" """
def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle, def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
@ -66,8 +65,8 @@ class MonolingualDataset(FairseqDataset):
self.add_eos_for_other_targets = add_eos_for_other_targets self.add_eos_for_other_targets = add_eos_for_other_targets
self.shuffle = shuffle self.shuffle = shuffle
assert targets is None or all( assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
t in {'self', 'future', 'past'} for t in targets), "targets must be none or one of 'self', 'future', 'past'" "targets must be none or one of 'self', 'future', 'past'"
if targets is not None and len(targets) == 0: if targets is not None and len(targets) == 0:
targets = None targets = None
self.targets = targets self.targets = targets
@ -185,7 +184,7 @@ class MonolingualDataset(FairseqDataset):
@property @property
def supports_prefetch(self): def supports_prefetch(self):
return self.dataset.supports_prefetch return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices): def prefetch(self, indices):
self.dataset.prefetch(indices) self.dataset.prefetch(indices)

View File

@ -245,11 +245,12 @@ class NoisingDataset(torch.utils.data.Dataset):
**kwargs **kwargs
): ):
""" """
Sets up a noising dataset which takes a src batch, generates Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the
a noisy src using a noising config, and returns the samples based on the supplied noising configuration.
corresponding {noisy src, original src} batch
Args: Args:
src_dataset: dataset which will be used to build self.src_dataset -- src_dataset (~torch.utils.data.Dataset): dataset to wrap.
to build self.src_dataset --
a LanguagePairDataset with src dataset as the source dataset and a LanguagePairDataset with src dataset as the source dataset and
None as the target dataset. Should NOT have padding so that None as the target dataset. Should NOT have padding so that
src_lengths are accurately calculated by language_pair_dataset src_lengths are accurately calculated by language_pair_dataset
@ -257,26 +258,22 @@ class NoisingDataset(torch.utils.data.Dataset):
We use language_pair_dataset here to encapsulate the tgt_dataset We use language_pair_dataset here to encapsulate the tgt_dataset
so we can re-use the LanguagePairDataset collater to format the so we can re-use the LanguagePairDataset collater to format the
batches in the structure that SequenceGenerator expects. batches in the structure that SequenceGenerator expects.
src_dict: src dict src_dict (~fairseq.data.Dictionary): source dictionary
src_dict: src dictionary seed (int): seed to use when generating random noise
seed: seed to use when generating random noise noiser (WordNoising): a pre-initialized :class:`WordNoising`
noiser: a pre-initialized noiser. If this is None, a noiser will instance. If this is None, a new instance will be created using
be created using noising_class and kwargs. *noising_class* and *kwargs*.
noising_class: class to use when initializing noiser noising_class (class, optional): class to use to initialize a
kwargs: noising args for configuring noising to apply default :class:`WordNoising` instance.
Note that there is no equivalent argparse code for these args kwargs (dict, optional): arguments to initialize the default
anywhere in our top level train scripts yet. Integration is :class:`WordNoising` instance given by *noiser*.
still in progress. You can still, however, test out this dataset
functionality with the appropriate args as in the corresponding
unittest: test_noising_dataset.
""" """
self.src_dataset = src_dataset self.src_dataset = src_dataset
self.src_dict = src_dict self.src_dict = src_dict
self.seed = seed
self.noiser = noiser if noiser is not None else noising_class( self.noiser = noiser if noiser is not None else noising_class(
dictionary=src_dict, **kwargs, dictionary=src_dict, **kwargs,
) )
self.seed = seed
def __getitem__(self, index): def __getitem__(self, index):
""" """

View File

@ -13,13 +13,16 @@ from . import FairseqDataset
class RoundRobinZipDatasets(FairseqDataset): class RoundRobinZipDatasets(FairseqDataset):
"""Zip multiple FairseqDatasets together, repeating shorter datasets in a """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.
round-robin fashion to match the length of the longest one.
Shorter datasets are repeated in a round-robin fashion to match the length
of the longest one.
Args: Args:
datasets: a dictionary of FairseqDatasets datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
eval_key: an optional key used at evaluation time that causes this :class:`~fairseq.data.FairseqDataset` instances.
instance to pass-through batches from `datasets[eval_key]`. eval_key (str, optional): a key used at evaluation time that causes
this instance to pass-through batches from *datasets[eval_key]*.
""" """
def __init__(self, datasets, eval_key=None): def __init__(self, datasets, eval_key=None):
@ -107,3 +110,14 @@ class RoundRobinZipDatasets(FairseqDataset):
dataset.valid_size(self._map_index(key, index), max_positions[key]) dataset.valid_size(self._map_index(key, index), max_positions[key])
for key, dataset in self.datasets.items() for key, dataset in self.datasets.items()
) )
@property
def supports_prefetch(self):
return all(
getattr(dataset, 'supports_prefetch', False)
for dataset in self.datasets.values()
)
def prefetch(self, indices):
for key, dataset in self.datasets.items():
dataset.prefetch([self._map_index(key, index) for index in indices])

View File

@ -14,32 +14,32 @@ from . import FairseqDataset
class TokenBlockDataset(FairseqDataset): class TokenBlockDataset(FairseqDataset):
"""Break a 1d tensor of tokens into blocks. """Break a Dataset of tokens into blocks.
The blocks are fetched from the original tensor so no additional memory is allocated.
Args: Args:
tokens: 1d tensor of tokens to break into blocks dataset (~torch.utils.data.Dataset): dataset to break into blocks
sizes: sentence lengths (required for 'complete' and 'eos') sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
block_size: maximum block size (ignored in 'eos' break mode) block_size (int): maximum block size (ignored in 'eos' break mode)
break_mode: Mode used for breaking tokens. Values can be one of: break_mode (str, optional): Mode used for breaking tokens. Values can
be one of:
- 'none': break tokens into equally sized blocks (up to block_size) - 'none': break tokens into equally sized blocks (up to block_size)
- 'complete': break tokens into blocks (up to block_size) such that - 'complete': break tokens into blocks (up to block_size) such that
blocks contains complete sentences, although block_size may be blocks contains complete sentences, although block_size may be
exceeded if some sentences exceed block_size exceeded if some sentences exceed block_size
- 'eos': each block contains one sentence (block_size is ignored) - 'eos': each block contains one sentence (block_size is ignored)
include_targets: return next tokens as targets include_targets (bool, optional): return next tokens as targets
(default: False).
""" """
def __init__(self, ds, block_size, pad, eos, break_mode=None, include_targets=False): def __init__(self, dataset, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
super().__init__() super().__init__()
self.dataset = ds self.dataset = dataset
self.pad = pad self.pad = pad
self.eos = eos self.eos = eos
self.include_targets = include_targets self.include_targets = include_targets
self.slice_indices = [] self.slice_indices = []
self.cache_index = {}
sizes = ds.sizes assert len(dataset) == len(sizes)
if break_mode is None or break_mode == 'none': if break_mode is None or break_mode == 'none':
total_size = sum(sizes) total_size = sum(sizes)
@ -77,44 +77,66 @@ class TokenBlockDataset(FairseqDataset):
self.sizes = np.array([e - s for s, e in self.slice_indices]) self.sizes = np.array([e - s for s, e in self.slice_indices])
def __getitem__(self, index): # build index mapping block indices to the underlying dataset indices
s, e = self.cache_index[index] self.block_to_dataset_index = []
ds_idx, ds_remaining = -1, 0
for to_consume in self.sizes:
if ds_remaining == 0:
ds_idx += 1
ds_remaining = sizes[ds_idx]
start_ds_idx = ds_idx
start_offset = sizes[ds_idx] - ds_remaining
while to_consume > ds_remaining:
to_consume -= ds_remaining
ds_idx += 1
ds_remaining = sizes[ds_idx]
ds_remaining -= to_consume
self.block_to_dataset_index.append((
start_ds_idx, # starting index in dataset
start_offset, # starting offset within starting index
ds_idx, # ending index in dataset
))
assert ds_remaining == 0
assert ds_idx == len(self.dataset) - 1
item = torch.from_numpy(self.cache[s:e]).long() def __getitem__(self, index):
start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
buffer = torch.cat([
self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)
])
slice_s, slice_e = self.slice_indices[index]
length = slice_e - slice_s
s, e = start_offset, start_offset + length
item = buffer[s:e]
if self.include_targets: if self.include_targets:
# target is the sentence, for source, rotate item one token to the left (would start with eos) # *target* is the original sentence (=item)
# past target is rotated to the left by 2 (padded if its first) # *source* is rotated left by 1 (maybe left-padded with eos)
# *past_target* is rotated left by 2 (left-padded as needed)
if s == 0: if s == 0:
source = np.concatenate([[self.eos], self.cache[0:e - 1]]) source = torch.cat([item.new([self.eos]), buffer[0:e - 1]])
past_target = np.concatenate([[self.pad, self.eos], self.cache[0:e - 2]]) past_target = torch.cat([item.new([self.pad, self.eos]), buffer[0:e - 2]])
else: else:
source = self.cache[s - 1: e - 1] source = buffer[s - 1:e - 1]
if s == 1: if s == 1:
past_target = np.concatenate([[self.eos], self.cache[0:e - 2]]) past_target = torch.cat([item.new([self.eos]), buffer[0:e - 2]])
else: else:
past_target = self.cache[s - 2:e - 2] past_target = buffer[s - 2:e - 2]
return torch.from_numpy(source).long(), item, torch.from_numpy(past_target).long() return source, item, past_target
return item return item
def __len__(self): def __len__(self):
return len(self.slice_indices) return len(self.slice_indices)
def prefetch(self, indices):
indices.sort()
total_size = 0
for idx in indices:
s, e = self.slice_indices[idx]
total_size += e - s
self.cache = np.empty(total_size, dtype=np.int32)
start = 0
for idx in indices:
s, e = self.slice_indices[idx]
self.dataset.read_into(s, self.cache[start:start + e - s])
self.cache_index[idx] = (start, start + e - s)
start += e - s
@property @property
def supports_prefetch(self): def supports_prefetch(self):
return True return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
self.dataset.prefetch({
ds_idx
for index in indices
for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
for ds_idx in range(start_ds_idx, end_ds_idx + 1)
})
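
The ``source``/``past_target`` construction above is just an index shift over the concatenated
buffer. A toy, fairseq-independent sketch of the same scheme (the values eos=2 and pad=1 are
arbitrary choices here)::

    import torch

    eos, pad = 2, 1
    buffer = torch.tensor([10, 11, 12, eos, 20, 21, eos])   # two sentences, each ending in eos
    s, e = 0, 4                                             # block slice within the buffer

    item = buffer[s:e]                                      # target:      [10, 11, 12, 2]
    if s == 0:
        source = torch.cat([item.new([eos]), buffer[0:e - 1]])            # [2, 10, 11, 12]
        past_target = torch.cat([item.new([pad, eos]), buffer[0:e - 2]])  # [1, 2, 10, 11]
    else:
        source = buffer[s - 1:e - 1]
        past_target = torch.cat([item.new([eos]), buffer[0:e - 2]]) if s == 1 else buffer[s - 2:e - 2]

    print(item.tolist(), source.tolist(), past_target.tolist())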

View File

@ -11,7 +11,7 @@ from . import FairseqDataset
class TransformEosDataset(FairseqDataset): class TransformEosDataset(FairseqDataset):
"""A dataset wrapper that appends/prepends/strips EOS. """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.
Note that the transformation is applied in :func:`collater`. Note that the transformation is applied in :func:`collater`.
@ -111,7 +111,7 @@ class TransformEosDataset(FairseqDataset):
@property @property
def supports_prefetch(self): def supports_prefetch(self):
return self.dataset.supports_prefetch() return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices): def prefetch(self, indices):
return self.dataset.prefetch(indices) return self.dataset.prefetch(indices)

View File

@ -6,7 +6,9 @@
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
from collections import namedtuple from collections import namedtuple
import os
import pickle import pickle
import subprocess
import torch import torch
from torch import nn from torch import nn
@ -42,6 +44,38 @@ else:
import torch.distributed as dist_no_c10d import torch.distributed as dist_no_c10d
def infer_init_method(args):
if args.distributed_init_method is not None:
return
# support torch.distributed.launch
if all(key in os.environ for key in [
'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK'
]):
args.distributed_init_method = 'tcp://{addr}:{port}'.format(
addr=os.environ['MASTER_ADDR'],
port=os.environ['MASTER_PORT'],
)
args.distributed_world_size = int(os.environ['WORLD_SIZE'])
args.distributed_rank = int(os.environ['RANK'])
# we can determine the init method automatically for Slurm
elif args.distributed_port > 0:
node_list = os.environ.get('SLURM_JOB_NODELIST')
if node_list is not None:
try:
hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
args.distributed_init_method = 'tcp://{host}:{port}'.format(
host=hostnames.split()[0].decode('utf-8'),
port=args.distributed_port)
args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
args.device_id = int(os.environ.get('SLURM_LOCALID'))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError: # Slurm is not installed
pass
def distributed_init(args): def distributed_init(args):
if args.distributed_world_size == 1: if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1') raise ValueError('Cannot initialize distributed with distributed_world_size=1')
@ -158,7 +192,7 @@ def all_gather_list(data, group=None, max_size=16384):
pickle.loads(bytes(out_buffer[2:size+2].tolist())) pickle.loads(bytes(out_buffer[2:size+2].tolist()))
) )
return result return result
except pickle.UnpicklingError as e: except pickle.UnpicklingError:
raise Exception( raise Exception(
'Unable to unpickle data from other workers. all_gather_list requires all ' 'Unable to unpickle data from other workers. all_gather_list requires all '
'workers to enter the function together, so this error usually indicates ' 'workers to enter the function together, so this error usually indicates '
@ -167,4 +201,3 @@ def all_gather_list(data, group=None, max_size=16384):
'in your training script that can cause one worker to finish an epoch ' 'in your training script that can cause one worker to finish an epoch '
'while other workers are still iterating over their portions of the data.' 'while other workers are still iterating over their portions of the data.'
) )

View File

@ -12,7 +12,7 @@ computation (e.g., AdaptiveSoftmax) and which therefore do not work with the
c10d version of DDP. c10d version of DDP.
This version also supports the *accumulate_grads* feature, which allows faster This version also supports the *accumulate_grads* feature, which allows faster
training with --update-freq. training with `--update-freq`.
""" """
import copy import copy
@ -27,18 +27,18 @@ from . import distributed_utils
class LegacyDistributedDataParallel(nn.Module): class LegacyDistributedDataParallel(nn.Module):
"""Implements distributed data parallelism at the module level. """Implements distributed data parallelism at the module level.
A simplified version of torch.nn.parallel.DistributedDataParallel. A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
This version uses a c10d process group for communication and does This version uses a c10d process group for communication and does not
not broadcast buffers. broadcast buffers.
Args: Args:
module: module to be parallelized module (~torch.nn.Module): module to be parallelized
world_size: number of parallel workers world_size (int): number of parallel workers
process_group (optional): the c10d process group to be used for process_group (optional): the c10d process group to be used for
distributed data all-reduction. If None, the default process group distributed data all-reduction. If None, the default process group
will be used. will be used.
buffer_size: number of elements to buffer before performing all-reduce buffer_size (int, optional): number of elements to buffer before
(default: 256M). performing all-reduce (default: 256M).
""" """
def __init__(self, module, world_size, process_group=None, buffer_size=2**28): def __init__(self, module, world_size, process_group=None, buffer_size=2**28):

View File

@ -179,10 +179,8 @@ class FConvEncoder(FairseqEncoder):
connections are added between layers when ``residual=1`` (which is connections are added between layers when ``residual=1`` (which is
the default behavior). the default behavior).
dropout (float, optional): dropout to be applied before each conv layer dropout (float, optional): dropout to be applied before each conv layer
normalization_constant (float, optional): multiplies the result of the left_pad (bool, optional): whether the input is left-padded
residual block by sqrt(value) (default: True).
left_pad (bool, optional): whether the input is left-padded. Default:
``True``
""" """
def __init__( def __init__(
@ -215,7 +213,7 @@ class FConvEncoder(FairseqEncoder):
self.residuals = [] self.residuals = []
layer_in_channels = [in_channels] layer_in_channels = [in_channels]
for i, (out_channels, kernel_size, residual) in enumerate(convolutions): for _, (out_channels, kernel_size, residual) in enumerate(convolutions):
if residual == 0: if residual == 0:
residual_dim = out_channels residual_dim = out_channels
else: else:

View File

@ -524,6 +524,7 @@ def base_architecture(args):
args.pretrained_checkpoint = getattr(args, 'pretrained_checkpoint', '') args.pretrained_checkpoint = getattr(args, 'pretrained_checkpoint', '')
args.pretrained = getattr(args, 'pretrained', 'False') args.pretrained = getattr(args, 'pretrained', 'False')
@register_model_architecture('fconv_self_att', 'fconv_self_att_wp') @register_model_architecture('fconv_self_att', 'fconv_self_att_wp')
def fconv_self_att_wp(args): def fconv_self_att_wp(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256) args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)

View File

@ -196,7 +196,6 @@ class LSTMEncoder(FairseqEncoder):
if bidirectional: if bidirectional:
self.output_units *= 2 self.output_units *= 2
def forward(self, src_tokens, src_lengths): def forward(self, src_tokens, src_lengths):
if self.left_pad: if self.left_pad:
# convert left-padding to right-padding # convert left-padding to right-padding
@ -235,7 +234,8 @@ class LSTMEncoder(FairseqEncoder):
if self.bidirectional: if self.bidirectional:
def combine_bidir(outs): def combine_bidir(outs):
return outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous().view(self.num_layers, bsz, -1) out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
return out.view(self.num_layers, bsz, -1)
final_hiddens = combine_bidir(final_hiddens) final_hiddens = combine_bidir(final_hiddens)
final_cells = combine_bidir(final_cells) final_cells = combine_bidir(final_cells)
@ -340,7 +340,6 @@ class LSTMDecoder(FairseqIncrementalDecoder):
elif not self.share_input_output_embed: elif not self.share_input_output_embed:
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
encoder_out = encoder_out_dict['encoder_out'] encoder_out = encoder_out_dict['encoder_out']
encoder_padding_mask = encoder_out_dict['encoder_padding_mask'] encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@ -504,6 +503,7 @@ def base_architecture(args):
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False) args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')
@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en') @register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
def lstm_wiseman_iwslt_de_en(args): def lstm_wiseman_iwslt_de_en(args):
args.dropout = getattr(args, 'dropout', 0.1) args.dropout = getattr(args, 'dropout', 0.1)


@ -219,7 +219,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
# make sure all arguments are present in older models # make sure all arguments are present in older models
base_lm_architecture(args) base_lm_architecture(args)
if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj == False: if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False:
# backward compatibility # backward compatibility
args.tie_adaptive_proj = True args.tie_adaptive_proj = True
@ -229,15 +229,17 @@ class TransformerLanguageModel(FairseqLanguageModel):
args.max_target_positions = args.tokens_per_sample args.max_target_positions = args.tokens_per_sample
if args.character_embeddings: if args.character_embeddings:
embed_tokens = CharacterTokenEmbedder(task.dictionary, eval(args.character_filters), embed_tokens = CharacterTokenEmbedder(
args.character_embedding_dim, task.dictionary, eval(args.character_filters),
args.decoder_embed_dim, args.character_embedding_dim, args.decoder_embed_dim,
args.char_embedder_highway_layers, args.char_embedder_highway_layers,
) )
elif args.adaptive_input: elif args.adaptive_input:
embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim, embed_tokens = AdaptiveInput(
args.adaptive_input_factor, args.decoder_embed_dim, len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
options.eval_str_list(args.adaptive_input_cutoff, type=int)) args.adaptive_input_factor, args.decoder_embed_dim,
options.eval_str_list(args.adaptive_input_cutoff, type=int),
)
else: else:
embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad()) embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
@ -248,7 +250,9 @@ class TransformerLanguageModel(FairseqLanguageModel):
args.adaptive_softmax_cutoff, args.adaptive_input_cutoff) args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
assert args.decoder_input_dim == args.decoder_output_dim assert args.decoder_input_dim == args.decoder_output_dim
decoder = TransformerDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False) decoder = TransformerDecoder(
args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False,
)
return TransformerLanguageModel(decoder) return TransformerLanguageModel(decoder)
@ -261,8 +265,8 @@ class TransformerEncoder(FairseqEncoder):
args (argparse.Namespace): parsed command-line arguments args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding embed_tokens (torch.nn.Embedding): input embedding
left_pad (bool, optional): whether the input is left-padded. Default: left_pad (bool, optional): whether the input is left-padded
``True`` (default: True).
""" """
def __init__(self, args, dictionary, embed_tokens, left_pad=True): def __init__(self, args, dictionary, embed_tokens, left_pad=True):
@ -382,10 +386,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
args (argparse.Namespace): parsed command-line arguments args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs. no_encoder_attn (bool, optional): whether to attend to encoder outputs
Default: ``False`` (default: False).
left_pad (bool, optional): whether the input is left-padded. Default: left_pad (bool, optional): whether the input is left-padded
``False`` (default: False).
final_norm (bool, optional): apply layer norm to the output of the
final decoder layer (default: True).
""" """
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True): def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
@ -634,8 +640,8 @@ class TransformerDecoderLayer(nn.Module):
Args: Args:
args (argparse.Namespace): parsed command-line arguments args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs. no_encoder_attn (bool, optional): whether to attend to encoder outputs
Default: ``False`` (default: False).
""" """
def __init__(self, args, no_encoder_attn=False): def __init__(self, args, no_encoder_attn=False):


@ -7,7 +7,6 @@
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from typing import List from typing import List
@ -16,13 +15,13 @@ from typing import List
class AdaptiveInput(nn.Module): class AdaptiveInput(nn.Module):
def __init__( def __init__(
self, self,
vocab_size: int, vocab_size: int,
padding_idx: int, padding_idx: int,
initial_dim: int, initial_dim: int,
factor: float, factor: float,
output_dim: int, output_dim: int,
cutoff: List[int], cutoff: List[int],
): ):
super().__init__() super().__init__()


@ -113,8 +113,9 @@ class AdaptiveSoftmax(nn.Module):
m = nn.Sequential( m = nn.Sequential(
proj, proj,
nn.Dropout(self.dropout), nn.Dropout(self.dropout),
nn.Linear(dim, self.cutoff[i + 1] - self.cutoff[i], bias=False) \ nn.Linear(
if tied_emb is None else TiedLinear(tied_emb, transpose=False) dim, self.cutoff[i + 1] - self.cutoff[i], bias=False,
) if tied_emb is None else TiedLinear(tied_emb, transpose=False),
) )
self.tail.append(m) self.tail.append(m)


@ -9,7 +9,7 @@ import importlib
import os import os
from .fairseq_optimizer import FairseqOptimizer from .fairseq_optimizer import FairseqOptimizer
from .fp16_optimizer import FP16Optimizer from .fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
OPTIMIZER_REGISTRY = {} OPTIMIZER_REGISTRY = {}


@ -70,10 +70,11 @@ class FairseqOptimizer(object):
group.update(optimizer_overrides) group.update(optimizer_overrides)
def backward(self, loss): def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
loss.backward() loss.backward()
def multiply_grads(self, c): def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``.""" """Multiplies grads by a constant *c*."""
for p in self.params: for p in self.params:
if p.grad is not None: if p.grad is not None:
p.grad.data.mul_(c) p.grad.data.mul_(c)


@ -45,6 +45,164 @@ class DynamicLossScaler(object):
return False return False
class FP16Optimizer(optim.FairseqOptimizer):
"""
Wrap an *optimizer* to support FP16 (mixed precision) training.
"""
def __init__(self, args, params, fp32_optimizer, fp32_params):
super().__init__(args, params)
self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params
if getattr(args, 'fp16_scale_window', None) is None:
if len(args.update_freq) > 1:
raise ValueError(
'--fp16-scale-window must be given explicitly when using a '
'custom --update-freq schedule'
)
scale_window = 2**14 / args.distributed_world_size / args.update_freq[0]
else:
scale_window = args.fp16_scale_window
self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
)
@classmethod
def build_optimizer(cls, args, params):
"""
Args:
args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize
"""
# create FP32 copy of parameters and grads
total_param_size = sum(p.data.numel() for p in params)
fp32_params = params[0].new(0).float().new(total_param_size)
offset = 0
for p in params:
numel = p.data.numel()
fp32_params[offset:offset+numel].copy_(p.data.view(-1))
offset += numel
fp32_params = torch.nn.Parameter(fp32_params)
fp32_params.grad = fp32_params.data.new(total_param_size)
fp32_optimizer = optim.build_optimizer(args, [fp32_params])
return cls(args, params, fp32_optimizer, fp32_params)
@property
def optimizer(self):
return self.fp32_optimizer.optimizer
@property
def optimizer_config(self):
return self.fp32_optimizer.optimizer_config
def get_lr(self):
return self.fp32_optimizer.get_lr()
def set_lr(self, lr):
self.fp32_optimizer.set_lr(lr)
def state_dict(self):
"""Return the optimizer's state dict."""
state_dict = self.fp32_optimizer.state_dict()
state_dict['loss_scale'] = self.scaler.loss_scale
return state_dict
def load_state_dict(self, state_dict, optimizer_overrides=None):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
if 'loss_scale' in state_dict:
self.scaler.loss_scale = state_dict['loss_scale']
self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
function additionally dynamically scales the loss to avoid gradient
underflow.
"""
loss = loss * self.scaler.loss_scale
loss.backward()
self._needs_sync = True
def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
if self._needs_sync:
# copy FP16 grads to FP32
offset = 0
for p in self.params:
if not p.requires_grad:
continue
grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
numel = grad_data.numel()
self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
offset += numel
# correct for dynamic loss scaler
self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale)
self._needs_sync = False
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
if self._needs_sync:
self._sync_fp16_grads_to_fp32(c)
else:
self.fp32_params.grad.data.mul_(c)
def clip_grad_norm(self, max_norm):
"""Clips gradient norm and updates dynamic loss scaler."""
self._sync_fp16_grads_to_fp32()
grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)
# detect overflow and adjust loss scale
overflow = DynamicLossScaler.has_overflow(grad_norm)
self.scaler.update_scale(overflow)
if overflow:
if self.scaler.loss_scale <= self.args.min_loss_scale:
# Use FloatingPointError as an uncommon error that parent
# functions can safely catch to stop training.
raise FloatingPointError((
'Minimum loss scale reached ({}). Your loss is probably exploding. '
'Try lowering the learning rate, using gradient clipping or '
'increasing the batch size.'
).format(self.args.min_loss_scale))
raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
return grad_norm
def step(self, closure=None):
"""Performs a single optimization step."""
self._sync_fp16_grads_to_fp32()
self.fp32_optimizer.step(closure)
# copy FP32 params back into FP16 model
offset = 0
for p in self.params:
if not p.requires_grad:
continue
numel = p.data.numel()
p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
offset += numel
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
self.fp32_optimizer.zero_grad()
for p in self.params:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
self._needs_sync = False
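
To make the new wrapper's API concrete, here is a minimal sketch of driving FP16Optimizer the way the trainer does, assuming a toy half-precision model on a GPU and an argparse namespace carrying only the fields read above; the model, optimizer choice and hyperparameter values are illustrative, not fairseq defaults.

import argparse
import torch
from fairseq import optim

# Toy FP16 model; in fairseq these parameters come from model.half() (GPU assumed).
model = torch.nn.Linear(8, 8).half().cuda()
params = list(filter(lambda p: p.requires_grad, model.parameters()))

# Only the fields used by FP16Optimizer above are filled in; 'sgd' is assumed to be
# a registered fairseq optimizer configured via lr/momentum/weight_decay.
args = argparse.Namespace(
    optimizer='sgd', lr=[0.1], momentum=0.0, weight_decay=0.0,
    fp16_init_scale=2 ** 7, fp16_scale_window=None, fp16_scale_tolerance=0.0,
    min_loss_scale=1e-4, distributed_world_size=8, update_freq=[2],
)
# With these values the default loss-scale window is 2**14 / 8 / 2 = 1024 updates.

optimizer = optim.FP16Optimizer.build_optimizer(args, params)
loss = model(torch.randn(4, 8).half().cuda()).sum()
optimizer.backward(loss)        # scales the loss before calling backward()
optimizer.clip_grad_norm(0.1)   # syncs FP16 grads to FP32 and updates the loss scale
optimizer.step()                # FP32 update, then params are copied back to FP16
optimizer.zero_grad()
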
class ConvertToFP32(object): class ConvertToFP32(object):
""" """
A wrapper around a list of params that will convert them to FP32 on the A wrapper around a list of params that will convert them to FP32 on the
@ -94,14 +252,13 @@ class ConvertToFP32(object):
raise StopIteration raise StopIteration
class FP16Optimizer(optim.FairseqOptimizer): class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
""" """
Wrap an *optimizer* to support FP16 (mixed precision) training. Wrap an *optimizer* to support FP16 (mixed precision) training.
Args: Compared to :class:`fairseq.optim.FP16Optimizer`, this version uses less
args (argparse.Namespace): fairseq args memory by copying between FP16 and FP32 parameters on-the-fly. The tradeoff
params (iterable): iterable of parameters to optimize is reduced optimization speed, which can be mitigated with `--update-freq`.
optimizer (~fairseq.optim.FairseqOptimizer): optimizer to wrap
""" """
def __init__(self, args, params, optimizer): def __init__(self, args, params, optimizer):
@ -124,10 +281,15 @@ class FP16Optimizer(optim.FairseqOptimizer):
tolerance=args.fp16_scale_tolerance, tolerance=args.fp16_scale_tolerance,
) )
@staticmethod @classmethod
def build_optimizer(args, params): def build_optimizer(cls, args, params):
"""
Args:
args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize
"""
fp16_optimizer = optim.build_optimizer(args, params) fp16_optimizer = optim.build_optimizer(args, params)
return FP16Optimizer(args, params, fp16_optimizer) return cls(args, params, fp16_optimizer)
@property @property
def optimizer(self): def optimizer(self):
@ -164,6 +326,12 @@ class FP16Optimizer(optim.FairseqOptimizer):
ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer) ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)
def backward(self, loss): def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
function additionally dynamically scales the loss to avoid gradient
underflow.
"""
loss = loss * self.scaler.loss_scale loss = loss * self.scaler.loss_scale
loss.backward() loss.backward()
self._grads_are_scaled = True self._grads_are_scaled = True
@ -178,7 +346,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
assert multiply_grads == 1. assert multiply_grads == 1.
def multiply_grads(self, c): def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``.""" """Multiplies grads by a constant *c*."""
if self._grads_are_scaled: if self._grads_are_scaled:
self._unscale_grads(c) self._unscale_grads(c)
else: else:
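
As a back-of-the-envelope comparison of the two wrappers: FP16Optimizer above allocates a flat FP32 master copy of all parameters plus a matching FP32 gradient buffer, which MemoryEfficientFP16Optimizer avoids by converting between FP16 and FP32 on the fly. A rough sketch of that overhead for a hypothetical model size:

# Extra GPU memory held by FP16Optimizer's flat FP32 copy (params + grads);
# the memory-efficient variant trades this away for slower optimizer steps.
num_params = 200_000_000            # hypothetical model size
extra_bytes = num_params * 4 * 2    # FP32 params + FP32 grad buffer
print(extra_bytes / 1024 ** 3)      # ~1.49 GiB
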


@ -13,18 +13,25 @@ from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('cosine') @register_lr_scheduler('cosine')
class CosineSchedule(FairseqLRScheduler): class CosineSchedule(FairseqLRScheduler):
"""Assign LR based on a cyclical schedule that follows the cosine function. """Assign LR based on a cyclical schedule that follows the cosine function.
See https://arxiv.org/pdf/1608.03983.pdf for details
See https://arxiv.org/pdf/1608.03983.pdf for details.
We also support a warmup phase where we linearly increase the learning rate We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (`--warmup-init-lr`) until the configured from some initial learning rate (``--warmup-init-lr``) until the configured
learning rate (`--lr`). learning rate (``--lr``).
During warmup:
During warmup::
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num] lr = lrs[update_num]
After warmup:
After warmup::
lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i))
where
t_curr is current percentage of updates within the current period range where ``t_curr`` is current percentage of updates within the current period
t_i is the current period range, which is scaled by t_mul after every iteration range and ``t_i`` is the current period range, which is scaled by ``t_mul``
after every iteration.
""" """
def __init__(self, args, optimizer): def __init__(self, args, optimizer):
@ -39,7 +46,7 @@ class CosineSchedule(FairseqLRScheduler):
if args.warmup_init_lr < 0: if args.warmup_init_lr < 0:
args.warmup_init_lr = args.lr[0] args.warmup_init_lr = args.lr[0]
self.min_lr = args.lr[0] self.min_lr = args.lr[0]
self.max_lr = args.max_lr self.max_lr = args.max_lr
assert self.max_lr > self.min_lr, 'max_lr must be more than lr' assert self.max_lr > self.min_lr, 'max_lr must be more than lr'
@ -98,7 +105,7 @@ class CosineSchedule(FairseqLRScheduler):
t_curr = curr_updates - (self.period * i) t_curr = curr_updates - (self.period * i)
lr_shrink = self.lr_shrink ** i lr_shrink = self.lr_shrink ** i
min_lr = self.min_lr * lr_shrink min_lr = self.min_lr * lr_shrink
max_lr = self.max_lr * lr_shrink max_lr = self.max_lr * lr_shrink
self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
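
To connect the reworded docstring to the code path above, a small standalone sketch of the post-warmup cosine rule, assuming fixed-length periods (t_mult = 1); the learning-rate bounds, period and shrink factor are made-up values rather than fairseq defaults:

import math

def cosine_lr(num_updates, min_lr=1e-5, max_lr=1e-3, period=5000, lr_shrink=1.0):
    i = num_updates // period              # index of the current period
    t_curr = num_updates - i * period      # progress within that period
    lo, hi = min_lr * lr_shrink ** i, max_lr * lr_shrink ** i
    return lo + 0.5 * (hi - lo) * (1 + math.cos(math.pi * t_curr / period))

print(cosine_lr(0))      # 1e-3: each period starts at max_lr
print(cosine_lr(2500))   # ~5.05e-4: halfway through, the LR sits between the bounds
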


@ -13,22 +13,19 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
"""Decay the LR based on the inverse square root of the update number. """Decay the LR based on the inverse square root of the update number.
We also support a warmup phase where we linearly increase the learning rate We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (`--warmup-init-lr`) until the configured from some initial learning rate (``--warmup-init-lr``) until the configured
learning rate (`--lr`). Thereafter we decay proportional to the number of learning rate (``--lr``). Thereafter we decay proportional to the number of
updates, with a decay factor set to align with the configured learning rate. updates, with a decay factor set to align with the configured learning rate.
During warmup: During warmup::
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num] lr = lrs[update_num]
After warmup: After warmup::
lr = decay_factor / sqrt(update_num)
where
decay_factor = args.lr * sqrt(args.warmup_updates) decay_factor = args.lr * sqrt(args.warmup_updates)
lr = decay_factor / sqrt(update_num)
""" """
def __init__(self, args, optimizer): def __init__(self, args, optimizer):
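
The two formulas from the corrected docstring, written out as a tiny standalone function; the learning rate and warmup length below are illustrative values only:

import math

def inverse_sqrt_lr(update_num, lr=5e-4, warmup_init_lr=1e-7, warmup_updates=4000):
    if update_num < warmup_updates:
        # linear warmup, matching torch.linspace(warmup_init_lr, lr, warmup_updates)
        return warmup_init_lr + (lr - warmup_init_lr) * update_num / (warmup_updates - 1)
    decay_factor = lr * math.sqrt(warmup_updates)
    return decay_factor / math.sqrt(update_num)

print(inverse_sqrt_lr(4000))    # == lr right at the end of warmup
print(inverse_sqrt_lr(16000))   # lr / 2 after 4x the warmup updates
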


@ -14,8 +14,7 @@ from . import FairseqLRScheduler, register_lr_scheduler
class TriangularSchedule(FairseqLRScheduler): class TriangularSchedule(FairseqLRScheduler):
"""Assign LR based on a triangular cyclical schedule. """Assign LR based on a triangular cyclical schedule.
See https://arxiv.org/pdf/1506.01186.pdf for details See https://arxiv.org/pdf/1506.01186.pdf for details.
""" """
def __init__(self, args, optimizer): def __init__(self, args, optimizer):


@ -107,6 +107,8 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
args.update_freq = eval_str_list(args.update_freq, type=int) args.update_freq = eval_str_list(args.update_freq, type=int)
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None: if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences args.max_sentences_valid = args.max_sentences
if getattr(args, 'memory_efficient_fp16', False):
args.fp16 = True
# Apply architecture configuration. # Apply architecture configuration.
if hasattr(args, 'arch'): if hasattr(args, 'arch'):
@ -128,7 +130,10 @@ def get_parser(desc, default_task='translation'):
choices=['json', 'none', 'simple', 'tqdm']) choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--seed', default=1, type=int, metavar='N', parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed') help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
parser.add_argument('--fp16', action='store_true', help='use FP16') parser.add_argument('--fp16', action='store_true', help='use FP16')
parser.add_argument('--memory-efficient-fp16', action='store_true',
help='use a memory-efficient version of FP16 training; implies --fp16')
parser.add_argument('--fp16-init-scale', default=2**7, type=int, parser.add_argument('--fp16-init-scale', default=2**7, type=int,
help='default FP16 loss scale') help='default FP16 loss scale')
parser.add_argument('--fp16-scale-window', type=int, parser.add_argument('--fp16-scale-window', type=int,
@ -147,6 +152,8 @@ def get_parser(desc, default_task='translation'):
def add_dataset_args(parser, train=False, gen=False): def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading') group = parser.add_argument_group('Dataset and data loading')
# fmt: off # fmt: off
group.add_argument('--num-workers', default=0, type=int, metavar='N',
help='how many subprocesses to use for data loading')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true', group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='ignore too long or too short lines in valid and test set') help='ignore too long or too short lines in valid and test set')
group.add_argument('--max-tokens', type=int, metavar='N', group.add_argument('--max-tokens', type=int, metavar='N',
@ -178,7 +185,7 @@ def add_distributed_training_args(parser):
group = parser.add_argument_group('Distributed training') group = parser.add_argument_group('Distributed training')
# fmt: off # fmt: off
group.add_argument('--distributed-world-size', type=int, metavar='N', group.add_argument('--distributed-world-size', type=int, metavar='N',
default=torch.cuda.device_count(), default=max(1, torch.cuda.device_count()),
help='total number of GPUs across all nodes (default: all visible GPUs)') help='total number of GPUs across all nodes (default: all visible GPUs)')
group.add_argument('--distributed-rank', default=0, type=int, group.add_argument('--distributed-rank', default=0, type=int,
help='rank of the current worker') help='rank of the current worker')
@ -189,7 +196,7 @@ def add_distributed_training_args(parser):
'establish initial connection') 'establish initial connection')
group.add_argument('--distributed-port', default=-1, type=int, group.add_argument('--distributed-port', default=-1, type=int,
help='port number (not required if using --distributed-init-method)') help='port number (not required if using --distributed-init-method)')
group.add_argument('--device-id', default=0, type=int, group.add_argument('--device-id', '--local_rank', default=0, type=int,
help='which GPU to use (usually configured automatically)') help='which GPU to use (usually configured automatically)')
group.add_argument('--ddp-backend', default='c10d', type=str, group.add_argument('--ddp-backend', default='c10d', type=str,
choices=['c10d', 'no_c10d'], choices=['c10d', 'no_c10d'],
@ -197,8 +204,8 @@ def add_distributed_training_args(parser):
group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB', group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
help='bucket size for reduction') help='bucket size for reduction')
group.add_argument('--fix-batches-to-gpus', action='store_true', group.add_argument('--fix-batches-to-gpus', action='store_true',
help='Don\'t shuffle batches between GPUs, this reduces overall ' help='don\'t shuffle batches between GPUs; this reduces overall '
'randomness and may affect precision but avoids the cost of' 'randomness and may affect precision but avoids the cost of '
're-reading the data') 're-reading the data')
# fmt: on # fmt: on
return group return group
@ -263,7 +270,9 @@ def add_checkpoint_args(parser):
group.add_argument('--save-interval-updates', type=int, default=0, metavar='N', group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
help='save a checkpoint (and validate) every N updates') help='save a checkpoint (and validate) every N updates')
group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N', group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
help='keep last N checkpoints saved with --save-interval-updates') help='keep the last N checkpoints saved with --save-interval-updates')
group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N',
help='keep last N epoch checkpoints')
group.add_argument('--no-save', action='store_true', group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints') help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true', group.add_argument('--no-epoch-checkpoints', action='store_true',
@ -280,11 +289,11 @@ def add_common_eval_args(group):
help='path(s) to model file(s), colon separated') help='path(s) to model file(s), colon separated')
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None, group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring') help='remove BPE tokens before scoring')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--quiet', action='store_true', group.add_argument('--quiet', action='store_true',
help='only print final scores') help='only print final scores')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT', group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override model args at generation that were used during model training') help='a dictionary used to override model args at generation '
'that were used during model training')
# fmt: on # fmt: on


@ -22,6 +22,7 @@ class SequenceGenerator(object):
match_source_len=False, no_repeat_ngram_size=0 match_source_len=False, no_repeat_ngram_size=0
): ):
"""Generates translations of a given source sentence. """Generates translations of a given source sentence.
Args: Args:
beam_size (int, optional): beam width (default: 1) beam_size (int, optional): beam width (default: 1)
min/maxlen (int, optional): the length of the generated output will min/maxlen (int, optional): the length of the generated output will
@ -90,11 +91,14 @@ class SequenceGenerator(object):
cuda=False, timer=None, prefix_size=0, cuda=False, timer=None, prefix_size=0,
): ):
"""Iterate over a batched dataset and yield individual translations. """Iterate over a batched dataset and yield individual translations.
Args: Args:
maxlen_a/b: generate sequences of maximum length ax + b, maxlen_a/b (int, optional): generate sequences of maximum length
where x is the source sentence length. ``ax + b``, where ``x`` is the source sentence length.
cuda: use GPU for generation cuda (bool, optional): use GPU for generation
timer: StopwatchMeter for timing generations. timer (StopwatchMeter, optional): time generations
prefix_size (int, optional): prefill the generation with the gold
prefix up to this length.
""" """
if maxlen_b is None: if maxlen_b is None:
maxlen_b = self.maxlen maxlen_b = self.maxlen
@ -132,12 +136,13 @@ class SequenceGenerator(object):
"""Generate a batch of translations. """Generate a batch of translations.
Args: Args:
encoder_input: dictionary containing the inputs to encoder_input (dict): dictionary containing the inputs to
model.encoder.forward *model.encoder.forward*.
beam_size: int overriding the beam size. defaults to beam_size (int, optional): overriding the beam size
self.beam_size (default: *self.beam_size*).
max_len: maximum length of the generated sequence max_len (int, optional): maximum length of the generated sequence
prefix_tokens: force decoder to begin with these tokens prefix_tokens (LongTensor, optional): force decoder to begin with
these tokens
""" """
with torch.no_grad(): with torch.no_grad():
return self._generate(encoder_input, beam_size, maxlen, prefix_tokens) return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)


@ -87,4 +87,3 @@ class SequenceScorer(object):
index=sample['target'].data.unsqueeze(-1), index=sample['target'].data.unsqueeze(-1),
) )
return avg_probs.squeeze(2), avg_attn return avg_probs.squeeze(2), avg_attn


@ -1,7 +1,8 @@
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # the root directory of this source tree. An additional grant of patent rights # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import argparse import argparse


@ -61,29 +61,32 @@ class FairseqTask(object):
def get_batch_iterator( def get_batch_iterator(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None, self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1, ignore_invalid_inputs=False, required_batch_size_multiple=1,
seed=1, num_shards=1, shard_id=0, seed=1, num_shards=1, shard_id=0, num_workers=0,
): ):
""" """
Get an iterator that yields batches of data from the given dataset. Get an iterator that yields batches of data from the given dataset.
Args: Args:
dataset (~fairseq.data.FairseqDataset): dataset to batch dataset (~fairseq.data.FairseqDataset): dataset to batch
max_tokens (int, optional): max number of tokens in each batch. max_tokens (int, optional): max number of tokens in each batch
Default: ``None`` (default: None).
max_sentences (int, optional): max number of sentences in each max_sentences (int, optional): max number of sentences in each
batch. Default: ``None`` batch (default: None).
max_positions (optional): max sentence length supported by the max_positions (optional): max sentence length supported by the
model. Default: ``None`` model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long. Default: ``False`` sentences that are too long (default: False).
required_batch_size_multiple (int, optional): require batch size to required_batch_size_multiple (int, optional): require batch size to
be a multiple of N. Default: ``1`` be a multiple of N (default: 1).
seed (int, optional): seed for random number generator for seed (int, optional): seed for random number generator for
reproducibility. Default: ``1`` reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N num_shards (int, optional): shard the data iterator into N
shards. Default: ``1`` shards (default: 1).
shard_id (int, optional): which shard of the data iterator to shard_id (int, optional): which shard of the data iterator to
return. Default: ``0`` return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
Returns: Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
@ -114,6 +117,7 @@ class FairseqTask(object):
seed=seed, seed=seed,
num_shards=num_shards, num_shards=num_shards,
shard_id=shard_id, shard_id=shard_id,
num_workers=num_workers,
) )
def build_model(self, args): def build_model(self, args):
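
From the caller's side, a hedged sketch of requesting a validation iterator with the new num_workers argument; it assumes `task` is an already set-up fairseq task whose 'valid' split has been loaded, and the batch-size limits are placeholders:

itr = task.get_batch_iterator(
    dataset=task.dataset('valid'),
    max_tokens=4096,
    max_sentences=64,
    num_workers=4,   # new in this commit: subprocesses used for data loading
).next_epoch_itr(shuffle=False)

for sample in itr:
    ...  # run validation on each batch
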


@ -10,9 +10,15 @@ import numpy as np
import os import os
from fairseq.data import ( from fairseq.data import (
ConcatDataset, Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset, ConcatDataset,
MonolingualDataset, TokenBlockDataset, TruncatedDictionary, Dictionary,
IndexedCachedDataset, IndexedDataset) IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
MonolingualDataset,
TokenBlockDataset,
TruncatedDictionary,
)
from . import FairseqTask, register_task from . import FairseqTask, register_task
@ -60,6 +66,8 @@ class LanguageModelingTask(FairseqTask):
'If set to "eos", includes only one sentence per sample.') 'If set to "eos", includes only one sentence per sample.')
parser.add_argument('--tokens-per-sample', default=1024, type=int, parser.add_argument('--tokens-per-sample', default=1024, type=int,
help='max number of tokens per sample for LM dataset') help='max number of tokens per sample for LM dataset')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', default=False, action='store_true', parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset') help='load raw text dataset')
parser.add_argument('--output-dictionary-size', default=-1, type=int, parser.add_argument('--output-dictionary-size', default=-1, type=int,
@ -139,7 +147,10 @@ class LanguageModelingTask(FairseqTask):
if self.args.raw_text and IndexedRawTextDataset.exists(path): if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary) ds = IndexedRawTextDataset(path, self.dictionary)
elif not self.args.raw_text and IndexedDataset.exists(path): elif not self.args.raw_text and IndexedDataset.exists(path):
ds = IndexedDataset(path, fix_lua_indexing=True) if self.args.lazy_load:
ds = IndexedDataset(path, fix_lua_indexing=True)
else:
ds = IndexedCachedDataset(path, fix_lua_indexing=True)
else: else:
if k > 0: if k > 0:
break break
@ -148,9 +159,11 @@ class LanguageModelingTask(FairseqTask):
loaded_datasets.append( loaded_datasets.append(
TokenBlockDataset( TokenBlockDataset(
ds, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(), ds, ds.sizes, self.args.tokens_per_sample,
pad=self.dictionary.pad(), eos=self.dictionary.eos(),
break_mode=self.args.sample_break_mode, include_targets=True, break_mode=self.args.sample_break_mode, include_targets=True,
)) )
)
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1]))) print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))


@ -12,8 +12,12 @@ import torch
from fairseq import options from fairseq import options
from fairseq.data import ( from fairseq.data import (
Dictionary, LanguagePairDataset, IndexedInMemoryDataset, Dictionary,
IndexedRawTextDataset, RoundRobinZipDatasets, IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
RoundRobinZipDatasets,
) )
from fairseq.models import FairseqMultiModel from fairseq.models import FairseqMultiModel
@ -55,6 +59,8 @@ class MultilingualTranslationTask(FairseqTask):
help='source language (only needed for inference)') help='source language (only needed for inference)')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language (only needed for inference)') help='target language (only needed for inference)')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', action='store_true', parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset') help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
@ -112,15 +118,18 @@ class MultilingualTranslationTask(FairseqTask):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang)) filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename): if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True return True
elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename): elif not self.args.raw_text and IndexedDataset.exists(filename):
return True return True
return False return False
def indexed_dataset(path, dictionary): def indexed_dataset(path, dictionary):
if self.args.raw_text: if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary) return IndexedRawTextDataset(path, dictionary)
elif IndexedInMemoryDataset.exists(path): elif IndexedDataset.exists(path):
return IndexedInMemoryDataset(path, fix_lua_indexing=True) if self.args.lazy_load:
return IndexedDataset(path, fix_lua_indexing=True)
else:
return IndexedCachedDataset(path, fix_lua_indexing=True)
return None return None
def sort_lang_pair(lang_pair): def sort_lang_pair(lang_pair):


@ -6,13 +6,17 @@
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import itertools import itertools
import numpy as np
import os import os
from fairseq import options, utils from fairseq import options, utils
from fairseq.data import ( from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset, ConcatDataset, ConcatDataset,
IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset data_utils,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
) )
from . import FairseqTask, register_task from . import FairseqTask, register_task
@ -49,6 +53,8 @@ class TranslationTask(FairseqTask):
help='source language') help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language') help='target language')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', action='store_true', parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset') help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
@ -132,7 +138,10 @@ class TranslationTask(FairseqTask):
if self.args.raw_text: if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary) return IndexedRawTextDataset(path, dictionary)
elif IndexedDataset.exists(path): elif IndexedDataset.exists(path):
return IndexedCachedDataset(path, fix_lua_indexing=True) if self.args.lazy_load:
return IndexedDataset(path, fix_lua_indexing=True)
else:
return IndexedCachedDataset(path, fix_lua_indexing=True)
return None return None
src_datasets = [] src_datasets = []


@ -6,10 +6,11 @@
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
from collections import Counter from collections import Counter
import os, re from multiprocessing import Pool
import os
import re
import torch import torch
from multiprocessing import Pool
SPACE_NORMALIZER = re.compile(r"\s+") SPACE_NORMALIZER = re.compile(r"\s+")
@ -27,7 +28,8 @@ def safe_readline(f):
return f.readline() return f.readline()
except UnicodeDecodeError: except UnicodeDecodeError:
pos -= 1 pos -= 1
f.seek(pos) # search where this character begins f.seek(pos) # search where this character begins
class Tokenizer: class Tokenizer:
@ -41,7 +43,7 @@ class Tokenizer:
end = offset + chunk_size end = offset + chunk_size
f.seek(offset) f.seek(offset)
if offset > 0: if offset > 0:
safe_readline(f) # drop first incomplete line safe_readline(f) # drop first incomplete line
line = f.readline() line = f.readline()
while line: while line:
for word in tokenize(line): for word in tokenize(line):
@ -73,14 +75,17 @@ class Tokenizer:
merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word)) merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
@staticmethod @staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line, def binarize(
append_eos=True, reverse_order=False, filename, dict, consumer, tokenize=tokenize_line, append_eos=True,
offset=0, end=-1): reverse_order=False, offset=0, end=-1,
):
nseq, ntok = 0, 0 nseq, ntok = 0, 0
replaced = Counter() replaced = Counter()
def replaced_consumer(word, idx): def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word: if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word]) replaced.update([word])
with open(filename, 'r') as f: with open(filename, 'r') as f:
f.seek(offset) f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used # next(f) breaks f.tell(), hence readline() must be used


@ -30,22 +30,23 @@ class Trainer(object):
""" """
def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None): def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
self.args = args self.args = args
self.task = task self.task = task
# copy model and criterion to current device # copy model and criterion to current device
self.criterion = criterion.cuda() self.criterion = criterion
self._model = model
self.cuda = torch.cuda.is_available() and not args.cpu
if args.fp16: if args.fp16:
self._model = model.half().cuda() self._model = self._model.half()
else: if self.cuda:
self._model = model.cuda() self.criterion = self.criterion.cuda()
self._model = self._model.cuda()
self._dummy_batch = dummy_batch self._dummy_batch = dummy_batch
self._oom_batch = oom_batch self._oom_batch = oom_batch
self._lr_scheduler = None
self._num_updates = 0 self._num_updates = 0
self._optim_history = None self._optim_history = None
self._optimizer = None self._optimizer = None
@ -71,7 +72,6 @@ class Trainer(object):
self.meters['wall'] = TimeMeter() # wall time in seconds self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
@property @property
def model(self): def model(self):
if self._wrapped_model is None: if self._wrapped_model is None:
@ -89,19 +89,26 @@ class Trainer(object):
self._build_optimizer() self._build_optimizer()
return self._optimizer return self._optimizer
@property
def lr_scheduler(self):
if self._lr_scheduler is None:
self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
return self._lr_scheduler
def _build_optimizer(self): def _build_optimizer(self):
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
if self.args.fp16: if self.args.fp16:
if torch.cuda.get_device_capability(0)[0] < 7: if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
print('| WARNING: your device does NOT support faster training with --fp16, ' print('| WARNING: your device does NOT support faster training with --fp16, '
'please switch to FP32 which is likely to be faster') 'please switch to FP32 which is likely to be faster')
params = list(filter(lambda p: p.requires_grad, self.model.parameters())) if self.args.memory_efficient_fp16:
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
else:
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
else: else:
if torch.cuda.get_device_capability(0)[0] >= 7: if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16') print('| NOTICE: your device may support faster training with --fp16')
self._optimizer = optim.build_optimizer(self.args, self.model.parameters()) self._optimizer = optim.build_optimizer(self.args, params)
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)
def save_checkpoint(self, filename, extra_state): def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file.""" """Save all training state in a checkpoint file."""
@ -151,7 +158,8 @@ class Trainer(object):
# reproducible results when resuming from checkpoints # reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates() seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) if self.cuda:
torch.cuda.manual_seed(seed)
self.model.train() self.model.train()
self.zero_grad() self.zero_grad()
@ -296,7 +304,8 @@ class Trainer(object):
for p in self.model.parameters(): for p in self.model.parameters():
if p.grad is not None: if p.grad is not None:
del p.grad # free some memory del p.grad # free some memory
torch.cuda.empty_cache() if self.cuda:
torch.cuda.empty_cache()
return self.valid_step(sample, raise_oom=True) return self.valid_step(sample, raise_oom=True)
else: else:
raise e raise e
@ -377,4 +386,6 @@ class Trainer(object):
def _prepare_sample(self, sample): def _prepare_sample(self, sample):
if sample is None or len(sample) == 0: if sample is None or len(sample) == 0:
return None return None
return utils.move_to_cuda(sample) if self.cuda:
sample = utils.move_to_cuda(sample)
return sample


@ -378,6 +378,14 @@ def item(tensor):
return tensor return tensor
def clip_grad_norm_(tensor, max_norm):
grad_norm = item(torch.norm(tensor))
if grad_norm > max_norm > 0:
clip_coef = max_norm / (grad_norm + 1e-6)
tensor.mul_(clip_coef)
return grad_norm
def fill_with_neg_inf(t): def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf.""" """FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t) return t.float().fill_(float('-inf')).type_as(t)
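
A small sketch of the new gradient-clipping helper on a plain tensor, to make the in-place behaviour and the return value explicit; it assumes the helper lands in fairseq.utils, matching the utils.clip_grad_norm_ call in the FP16 optimizer hunk earlier in this commit, and the numbers are arbitrary:

import torch
from fairseq import utils

grad = torch.full((4,), 3.0)            # L2 norm is 6.0
norm = utils.clip_grad_norm_(grad, max_norm=1.5)
print(norm)          # 6.0 -- the norm *before* clipping is returned
print(grad.norm())   # ~1.5 -- the tensor itself was rescaled in place
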

fb_train.py (new file, 18 lines)

@ -0,0 +1,18 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.fb.rendezvous.zeus # noqa: F401
from fairseq import options
from train import main
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)


@ -11,7 +11,7 @@ Translate pre-processed data with a trained model.
import torch import torch
from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer from fairseq.sequence_scorer import SequenceScorer
@ -41,7 +41,9 @@ def main(args):
# Load ensemble # Load ensemble
print('| loading model(s) from {}'.format(args.path)) print('| loading model(s) from {}'.format(args.path))
models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides)) models, _model_args = utils.load_ensemble_for_inference(
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
)
# Optimize ensemble for generation # Optimize ensemble for generation
for model in models: for model in models:
@ -69,6 +71,7 @@ def main(args):
required_batch_size_multiple=8, required_batch_size_multiple=8,
num_shards=args.num_shards, num_shards=args.num_shards,
shard_id=args.shard_id, shard_id=args.shard_id,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
# Initialize generator # Initialize generator


@ -75,8 +75,9 @@ def main(args):
# Load ensemble # Load ensemble
print('| loading model(s) from {}'.format(args.path)) print('| loading model(s) from {}'.format(args.path))
model_paths = args.path.split(':') models, _model_args = utils.load_ensemble_for_inference(
models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides)) args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
)
# Set dictionaries # Set dictionaries
tgt_dict = task.target_dictionary tgt_dict = task.target_dictionary


@ -1,87 +0,0 @@
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import signal
import torch
from fairseq import distributed_utils, options
from train import main as single_process_main
def main(args):
if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
print('| WARNING: when using --update-freq on a single machine, you '
'will get better performance with --ddp-backend=no_c10d')
mp = torch.multiprocessing.get_context('spawn')
# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Train with multiprocessing.
procs = []
base_rank = args.distributed_rank
for i in range(torch.cuda.device_count()):
args.distributed_rank = base_rank + i
args.device_id = i
procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True))
procs[i].start()
error_handler.add_child(procs[i].pid)
for p in procs:
p.join()
def run(args, error_queue):
try:
args.distributed_rank = distributed_utils.distributed_init(args)
single_process_main(args)
except KeyboardInterrupt:
pass # killed by parent, do nothing
except Exception:
# propagate exception to parent process, keeping original traceback
import traceback
error_queue.put((args.distributed_rank, traceback.format_exc()))
class ErrorHandler(object):
"""A class that listens for exceptions in children processes and propagates
the tracebacks to the parent process."""
def __init__(self, error_queue):
import signal
import threading
self.error_queue = error_queue
self.children_pids = []
self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
self.error_thread.start()
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
self.children_pids.append(pid)
def error_listener(self):
(rank, original_trace) = self.error_queue.get()
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, signalnum, stackframe):
for pid in self.children_pids:
os.kill(pid, signal.SIGINT) # kill children processes
(rank, original_trace) = self.error_queue.get()
msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
msg += original_trace
raise Exception(msg)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)


@ -18,7 +18,7 @@ import shutil
from fairseq.data import indexed_dataset, dictionary from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line from fairseq.tokenizer import Tokenizer, tokenize_line
from multiprocessing import Pool, Manager, Process from multiprocessing import Pool
def get_parser(): def get_parser():


@ -50,6 +50,14 @@ class TestTranslation(unittest.TestCase):
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16']) train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16'])
generate_main(data_dir) generate_main(data_dir)
def test_memory_efficient_fp16(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_memory_efficient_fp16') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--memory-efficient-fp16'])
generate_main(data_dir)
def test_update_freq(self): def test_update_freq(self):
with contextlib.redirect_stdout(StringIO()): with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_update_freq') as data_dir: with tempfile.TemporaryDirectory('test_update_freq') as data_dir:
@ -68,8 +76,7 @@ class TestTranslation(unittest.TestCase):
data_dir, 'fconv_iwslt_de_en', ['--max-target-positions', '5'], data_dir, 'fconv_iwslt_de_en', ['--max-target-positions', '5'],
) )
self.assertTrue( self.assertTrue(
'skip this example with --skip-invalid-size-inputs-valid-test' \ 'skip this example with --skip-invalid-size-inputs-valid-test' in str(context.exception)
in str(context.exception)
) )
train_translation_model( train_translation_model(
data_dir, 'fconv_iwslt_de_en', data_dir, 'fconv_iwslt_de_en',


@ -12,10 +12,6 @@ import os
import tempfile import tempfile
import unittest import unittest
import torch
from fairseq import options
from . import test_binaries from . import test_binaries
@ -79,6 +75,12 @@ class TestReproducibility(unittest.TestCase):
'--fp16-init-scale', '4096', '--fp16-init-scale', '4096',
]) ])
def test_reproducibility_memory_efficient_fp16(self):
self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [
'--memory-efficient-fp16',
'--fp16-init-scale', '4096',
])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()


@ -39,8 +39,10 @@ def mock_dict():
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch): def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
tokens = torch.LongTensor(list(range(epoch_size))) tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
tokens_ds = data.TokenBlockDataset(tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False) tokens_ds = data.TokenBlockDataset(
tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch) trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False) dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
epoch_itr = data.EpochBatchIterator( epoch_itr = data.EpochBatchIterator(
@ -64,7 +66,6 @@ class TestLoadCheckpoint(unittest.TestCase):
self.applied_patches = [patch(p, d) for p, d in self.patches.items()] self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
[p.start() for p in self.applied_patches] [p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self): def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()): with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50) trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)


@ -28,9 +28,8 @@ def main(args):
args.max_tokens = 6000 args.max_tokens = 6000
print(args) print(args)
if not torch.cuda.is_available(): if torch.cuda.is_available() and not args.cpu:
raise NotImplementedError('Training on CPU is not supported') torch.cuda.set_device(args.device_id)
torch.cuda.set_device(args.device_id)
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
# Setup task, e.g., translation, language modeling, etc. # Setup task, e.g., translation, language modeling, etc.
@ -74,6 +73,7 @@ def main(args):
seed=args.seed, seed=args.seed,
num_shards=args.distributed_world_size, num_shards=args.distributed_world_size,
shard_id=args.distributed_rank, shard_id=args.distributed_rank,
num_workers=args.num_workers,
) )
# Load the latest checkpoint if one is available # Load the latest checkpoint if one is available
@ -211,6 +211,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
seed=args.seed, seed=args.seed,
num_shards=args.distributed_world_size, num_shards=args.distributed_world_size,
shard_id=args.distributed_rank, shard_id=args.distributed_rank,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
progress = progress_bar.build_progress_bar( progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch, args, itr, epoch_itr.epoch,
@ -306,7 +307,15 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
# remove old checkpoints; checkpoints are sorted in descending order # remove old checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt') checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
for old_chk in checkpoints[args.keep_interval_updates:]: for old_chk in checkpoints[args.keep_interval_updates:]:
os.remove(old_chk) if os.path.lexists(old_chk):
os.remove(old_chk)
if args.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint\d+\.pt')
for old_chk in checkpoints[args.keep_last_epochs:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
def load_checkpoint(args, trainer, epoch_itr): def load_checkpoint(args, trainer, epoch_itr):
@ -346,23 +355,50 @@ def load_dataset_splits(task, splits):
raise e raise e
def distributed_main(i, args):
import socket
args.device_id = i
if args.distributed_rank is None: # torch.multiprocessing.spawn
args.distributed_rank = i
args.distributed_rank = distributed_utils.distributed_init(args)
print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
main(args)
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_training_parser() parser = options.get_training_parser()
args = options.parse_args_and_arch(parser) args = options.parse_args_and_arch(parser)
if args.distributed_port > 0 or args.distributed_init_method is not None: if args.distributed_init_method is None:
from distributed_train import main as distributed_main distributed_utils.infer_init_method(args)
distributed_main(args) if args.distributed_init_method is not None:
# distributed training
distributed_main(args.device_id, args)
args.distributed_rank = distributed_utils.distributed_init(args)
main(args)
elif args.distributed_world_size > 1: elif args.distributed_world_size > 1:
from multiprocessing_train import main as multiprocessing_main # fallback for single node with multiple GPUs
# Set distributed training parameters for a single node.
args.distributed_world_size = torch.cuda.device_count()
port = random.randint(10000, 20000) port = random.randint(10000, 20000)
args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
args.distributed_port = port + 1 args.distributed_rank = None # set based on device id
print(
'''| NOTE: you may get better performance with:
multiprocessing_main(args) python -m torch.distributed.launch --nproc_per_node {ngpu} train.py {no_c10d}(...)
'''.format(
ngpu=args.distributed_world_size,
no_c10d=(
'--ddp-backend=no_c10d ' if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d'
else ''
),
)
)
torch.multiprocessing.spawn(
fn=distributed_main,
args=(args, ),
nprocs=args.distributed_world_size,
)
else: else:
# single GPU training
main(args) main(args)