Merge internal changes (#283)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/283

Pull Request resolved: https://github.com/pytorch/fairseq/pull/428

Differential Revision: D13564190

Pulled By: myleott

fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
Myle Ott 2019-01-04 20:00:49 -08:00 committed by Facebook Github Bot
parent 0cb87130e7
commit 7633129ba8
59 changed files with 837 additions and 555 deletions

View File

@ -19,10 +19,13 @@ of various sequence-to-sequence models, including:
Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
- fast beam search generation on both CPU and GPU
- fast generation on both CPU and GPU with multiple search algorithms implemented:
- beam search
- Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
- sampling (unconstrained and top-k)
- large mini-batch training even on a single GPU via delayed updates
- fast half-precision floating point (FP16) training
- extensible: easily register new models, criterions, and tasks
- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers
We also provide [pre-trained models](#pre-trained-models) for several benchmark
translation and language modeling datasets.
@ -34,7 +37,7 @@ translation and language modeling datasets.
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* Python version 3.6
Currently fairseq requires PyTorch version >= 0.4.0.
Currently fairseq requires PyTorch version >= 1.0.0.
Please follow the instructions here: https://github.com/pytorch/pytorch#installation.
If you use Docker make sure to increase the shared memory size either with

View File

@ -1,45 +0,0 @@
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import socket
import subprocess
from train import main as single_process_main
from fairseq import distributed_utils, options
def main(args):
if args.distributed_init_method is None and args.distributed_port > 0:
# We can determine the init method automatically for Slurm.
node_list = os.environ.get('SLURM_JOB_NODELIST')
if node_list is not None:
try:
hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
args.distributed_init_method = 'tcp://{host}:{port}'.format(
host=hostnames.split()[0].decode('utf-8'),
port=args.distributed_port)
args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
args.device_id = int(os.environ.get('SLURM_LOCALID'))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError as e: # Slurm is not installed
pass
if args.distributed_init_method is None and args.distributed_port is None:
raise ValueError('--distributed-init-method or --distributed-port '
'must be specified for distributed training')
args.distributed_rank = distributed_utils.distributed_init(args)
print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
single_process_main(args)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)

View File

@ -6,8 +6,26 @@
Criterions
==========
Criterions compute the loss function given the model and batch, roughly::
loss = criterion(model, batch)
.. automodule:: fairseq.criterions
:members:
.. autoclass:: fairseq.criterions.FairseqCriterion
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
:members:
:undoc-members:
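New criterions follow the same pattern as the built-in ones above. As a rough
sketch (the decorator and base class come from ``fairseq.criterions``; the
exact ``forward`` signature and logging keys may differ between versions)::

    import torch.nn.functional as F

    from fairseq.criterions import FairseqCriterion, register_criterion


    @register_criterion('my_cross_entropy')
    class MyCrossEntropyCriterion(FairseqCriterion):
        """Token-level cross entropy that ignores padding positions."""

        def forward(self, model, sample, reduce=True):
            # run the model and compute log-probabilities over the vocabulary
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            # self.padding_idx is assumed to be set by FairseqCriterion from
            # the task's target dictionary
            loss = F.nll_loss(
                lprobs.view(-1, lprobs.size(-1)),
                target.view(-1),
                ignore_index=self.padding_idx,
                reduction='sum' if reduce else 'none',
            )
            sample_size = sample['ntokens']
            logging_output = {
                'loss': loss.item() if reduce else loss.data,
                'sample_size': sample_size,
            }
            return loss, sample_size, logging_output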

View File

@ -21,6 +21,20 @@ mini-batches.
.. autoclass:: fairseq.data.MonolingualDataset
:members:
**Helper Datasets**
These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
provide additional functionality:
.. autoclass:: fairseq.data.BacktranslationDataset
:members:
.. autoclass:: fairseq.data.ConcatDataset
:members:
.. autoclass:: fairseq.data.RoundRobinZipDatasets
:members:
.. autoclass:: fairseq.data.TransformEosDataset
:members:
Dictionary
----------
@ -32,6 +46,8 @@ Dictionary
Iterators
---------
.. autoclass:: fairseq.data.BufferedIterator
:members:
.. autoclass:: fairseq.data.CountingIterator
:members:
.. autoclass:: fairseq.data.EpochBatchIterator

View File

@ -27,21 +27,20 @@ interactively. Here, we use a beam size of 5:
> MODEL_DIR=wmt14.en-fr.fconv-py
> python interactive.py \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5
--beam 5 --source-lang en --target-lang fr
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
> Why is it rare to discover new marine mam@@ mal species ?
O Why is it rare to discover new marine mam@@ mal species ?
H -0.06429661810398102 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
A 0 1 3 3 5 6 6 8 8 8 7 11 12
H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288
This generation script produces four types of outputs: a line prefixed
with *S* shows the supplied source sentence after applying the
vocabulary; *O* is a copy of the original source sentence; *H* is the
hypothesis along with an average log-likelihood; and *A* is the
attention maxima for each word in the hypothesis, including the
This generation script produces three types of outputs: a line prefixed
with *O* is a copy of the original source sentence; *H* is the
hypothesis along with an average log-likelihood; and *P* is the
positional score per token position, including the
end-of-sentence marker which is omitted from the text.
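In this example the *H* score is simply the average of the per-token scores on
the *P* line, which can be checked with a few lines of Python (a small sketch,
not part of the shipped scripts)::

    # positional scores from the P line above, one per output token
    # (the last entry corresponds to the end-of-sentence marker)
    p_scores = [
        -0.2221, -0.3122, -0.1289, -0.2673, -0.1711, -0.1930, -0.1101, -0.1660,
        -0.1003, -0.0740, -0.1101, -0.0814, -0.1238, -0.0985, -0.1288,
    ]
    h_score = -0.1525060087442398  # score reported on the H line above

    assert abs(h_score - sum(p_scores) / len(p_scores)) < 1e-3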
See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a

View File

@ -6,7 +6,29 @@
Learning Rate Schedulers
========================
TODO
Learning Rate Schedulers update the learning rate over the course of training.
Learning rates can be updated after each update via :func:`step_update` or at
epoch boundaries via :func:`step`.
.. automodule:: fairseq.optim.lr_scheduler
:members:
.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.reduce_angular_lr_scheduler.TriangularSchedule
:members:
:undoc-members:
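New schedulers are registered in the same way as the built-in ones. A minimal
sketch, assuming the ``register_lr_scheduler`` decorator and the
``FairseqLRScheduler`` base class documented above (the ``--total-num-update``
option here is hypothetical; a real scheduler would declare its options via
``add_args``)::

    from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


    @register_lr_scheduler('linear_decay')
    class LinearDecaySchedule(FairseqLRScheduler):
        """Linearly decay the learning rate from --lr down to zero."""

        def __init__(self, args, optimizer):
            super().__init__(args, optimizer)
            self.base_lr = args.lr[0]
            # hypothetical option; a real scheduler would register it in add_args()
            self.total_updates = getattr(args, 'total_num_update', 100000)

        def step_update(self, num_updates):
            """Update the learning rate after each optimizer update."""
            self.lr = self.base_lr * max(0.0, 1.0 - num_updates / self.total_updates)
            self.optimizer.set_lr(self.lr)
            return self.lr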

View File

@ -1,8 +1,8 @@
Modules
=======
Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be
helpful when implementing a new :class:`FairseqModel`.
Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.
.. automodule:: fairseq.modules
:members:

View File

@ -6,5 +6,27 @@
Optimizers
==========
Optimizers update the Model parameters based on the gradients.
.. automodule:: fairseq.optim
:members:
.. autoclass:: fairseq.optim.FairseqOptimizer
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adagrad.Adagrad
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adam.FairseqAdam
:members:
:undoc-members:
.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
:members:
:undoc-members:
.. autoclass:: fairseq.optim.nag.FairseqNAG
:members:
:undoc-members:
.. autoclass:: fairseq.optim.sgd.SGD
:members:
:undoc-members:

View File

@ -22,12 +22,18 @@ fairseq implements the following high-level training flow::
for epoch in range(num_epochs):
itr = task.get_batch_iterator(task.dataset('train'))
for num_updates, batch in enumerate(itr):
loss = criterion(model, batch)
optimizer.backward(loss)
task.train_step(batch, model, criterion, optimizer)
average_and_clip_gradients()
optimizer.step()
lr_scheduler.step_update(num_updates)
lr_scheduler.step(epoch)
where the default implementation for ``task.train_step`` is roughly::
def train_step(self, batch, model, criterion, optimizer):
loss = criterion(model, batch)
optimizer.backward(loss)
**Registering new plug-ins**
New plug-ins are *registered* through a set of ``@register`` function

View File

@ -353,17 +353,16 @@ The model files should appear in the :file:`checkpoints/` directory.
-------------------------------
Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classify.py` with the following contents::
a new file named :file:`eval_classifier.py` with the following contents::
from fairseq import data, options, tasks, utils
from fairseq.tokenizer import Tokenizer
# Parse command-line arguments for generation
parser = options.get_generation_parser()
parser = options.get_generation_parser(default_task='simple_classification')
args = options.parse_args_and_arch(parser)
# Setup task
args.task = 'simple_classification'
task = tasks.setup_task(args)
# Load model

View File

@ -55,7 +55,9 @@ def main(parsed_args):
# Load ensemble
print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides))
models, args = utils.load_ensemble_for_inference(
parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
)
for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
@ -83,9 +85,10 @@ def main(parsed_args):
max_positions=utils.resolve_max_positions(*[
model.max_positions() for model in models
]),
ignore_invalid_inputs=True,
num_shards=args.num_shards,
shard_id=args.shard_id,
ignore_invalid_inputs=True,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter()

View File

@ -9,7 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset
from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
@ -33,7 +33,6 @@ __all__ = [
'GroupedIterator',
'IndexedCachedDataset',
'IndexedDataset',
'IndexedInMemoryDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'MonolingualDataset',

View File

@ -56,6 +56,28 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
class BacktranslationDataset(FairseqDataset):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation function (*backtranslation_fn*),
and returns the corresponding `{generated src, input tgt}` batch.
Args:
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
backtranslated. Only the source side of this dataset will be used.
After backtranslation, the source sentences in this dataset will be
returned as the targets.
backtranslation_fn (callable): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
max_len_a, max_len_b (int, int): will be used to compute
`maxlen = max_len_a * src_len + max_len_b`, which will be passed
into *backtranslation_fn*.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch
(default: ``tgt_dataset.collater``).
cuda: use GPU for generation
"""
def __init__(
self,
tgt_dataset,
@ -66,27 +88,6 @@ class BacktranslationDataset(FairseqDataset):
cuda=True,
**kwargs
):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation function (*backtranslation_fn*),
and returns the corresponding `{generated src, input tgt}` batch.
Args:
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
backtranslated. Only the source side of this dataset will be
used. After backtranslation, the source sentences in this
dataset will be returned as the targets.
backtranslation_fn (callable): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
max_len_a, max_len_b (int, int): will be used to compute
`maxlen = max_len_a * src_len + max_len_b`, which will be
passed into *backtranslation_fn*.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch (default:
``tgt_dataset.collater``)
cuda: use GPU for generation
"""
self.tgt_dataset = tgt_dataset
self.backtranslation_fn = backtranslation_fn
self.max_len_a = max_len_a
@ -166,11 +167,10 @@ class BacktranslationDataset(FairseqDataset):
"""
tgt_size = self.tgt_dataset.size(index)[0]
return (tgt_size, tgt_size)
@property
def supports_prefetch(self):
return self.tgt_dataset.supports_prefetch()
return getattr(self.tgt_dataset, 'supports_prefetch', False)
def prefetch(self, indices):
return self.tgt_dataset.prefetch(indices)
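# Rough usage sketch based on the docstring above (``tgt_dataset`` and
# ``backtranslation_generator`` are placeholders for a monolingual dataset and
# a pre-built fairseq SequenceGenerator; they are not defined in this module):
#
#   dataset = BacktranslationDataset(
#       tgt_dataset=tgt_dataset,
#       backtranslation_fn=backtranslation_generator.generate,
#       max_len_a=1, max_len_b=200,            # maxlen = 1 * src_len + 200
#       output_collater=tgt_dataset.collater,  # the documented default
#   )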

View File

@ -29,18 +29,18 @@ class ConcatDataset(FairseqDataset):
if isinstance(sample_ratios, int):
sample_ratios = [sample_ratios] * len(self.datasets)
self.sample_ratios = sample_ratios
self.cummulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.real_sizes = [len(d) for d in self.datasets]
def __len__(self):
return self.cummulative_sizes[-1]
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx)
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
sample_idx = sample_idx % self.real_sizes[dataset_idx]
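# worked example (assuming cumsum scales each dataset's length by its sample
# ratio): with dataset lengths [3, 5] and sample_ratios [2, 1],
# cumulative_sizes == [6, 11]; idx=4 falls in the first dataset and the modulo
# wraps the upsampled index back into range (sample_idx = 4 % 3 = 1), while
# idx=7 falls in the second dataset (sample_idx = (7 - 6) % 5 = 1)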
return self.datasets[dataset_idx][sample_idx]
@ -54,7 +54,7 @@ class ConcatDataset(FairseqDataset):
def prefetch(self, indices):
frm = 0
for to, ds in zip(self.cummulative_sizes, self.datasets):
for to, ds in zip(self.cumulative_sizes, self.datasets):
real_size = len(ds)
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
frm = to

View File

@ -81,8 +81,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
size_fn (callable): function that returns the size of a given index
max_positions (tuple): filter elements larger than this size.
Comparisons are done component-wise.
raise_exception (bool, optional): if ``True``, raise an exception
if any elements are filtered. Default: ``False``
raise_exception (bool, optional): if ``True``, raise an exception if
any elements are filtered (default: False).
"""
def check_size(idx):
if isinstance(max_positions, float) or isinstance(max_positions, int):
@ -128,12 +128,12 @@ def batch_by_size(
indices (List[int]): ordered list of dataset indices
num_tokens_fn (callable): function that returns the number of tokens at
a given index
max_tokens (int, optional): max number of tokens in each batch.
Default: ``None``
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch. Default: ``None``
batch (default: None).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N. Default: ``1``
be a multiple of N (default: 1).
"""
max_tokens = max_tokens if max_tokens is not None else float('Inf')
max_sentences = max_sentences if max_sentences is not None else float('Inf')

View File

@ -200,11 +200,15 @@ class Dictionary(object):
t[-1] = self.eos()
return t
class TruncatedDictionary(object):
def __init__(self, wrapped_dict, length):
self.__class__ = type(wrapped_dict.__class__.__name__,
(self.__class__, wrapped_dict.__class__), {})
self.__class__ = type(
wrapped_dict.__class__.__name__,
(self.__class__, wrapped_dict.__class__),
{}
)
self.__dict__ = wrapped_dict.__dict__
self.wrapped_dict = wrapped_dict
self.length = min(len(self.wrapped_dict), length)
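# note: reassigning __class__ above creates, at runtime, a subclass of both
# TruncatedDictionary and the wrapped dictionary's class, and sharing __dict__
# lets the truncated view reuse the wrapped dictionary's state; isinstance()
# checks against the original dictionary type therefore still pass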

View File

@ -7,8 +7,6 @@
import torch.utils.data
from fairseq.data import data_utils
class FairseqDataset(torch.utils.data.Dataset):
"""A dataset that provides helpers for batching."""
@ -51,7 +49,9 @@ class FairseqDataset(torch.utils.data.Dataset):
@property
def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return False
def prefetch(self, indices):
"""Prefetch the data required for this epoch."""
raise NotImplementedError

View File

@ -52,13 +52,12 @@ def data_file_path(prefix_path):
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for TorchNet IndexedDataset"""
def __init__(self, path, fix_lua_indexing=False, read_data=True):
def __init__(self, path, fix_lua_indexing=False):
super().__init__()
self.fix_lua_indexing = fix_lua_indexing
self.read_index(path)
self.data_file = None
if read_data:
self.read_data(path)
self.path = path
def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
@ -85,8 +84,10 @@ class IndexedDataset(torch.utils.data.Dataset):
self.data_file.close()
def __getitem__(self, i):
if not self.data_file:
self.read_data(self.path)
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
tensor_size = int(self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]])
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
@ -98,12 +99,6 @@ class IndexedDataset(torch.utils.data.Dataset):
def __len__(self):
return self.size
def read_into(self, start, dst):
self.data_file.seek(start * self.element_size)
self.data_file.readinto(dst)
if self.fix_lua_indexing:
dst -= 1 # subtract 1 for 0-based indexing
@staticmethod
def exists(path):
return (
@ -111,11 +106,15 @@ class IndexedDataset(torch.utils.data.Dataset):
os.path.exists(data_file_path(path))
)
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path, fix_lua_indexing=False):
super().__init__(path, fix_lua_indexing, True)
super().__init__(path, fix_lua_indexing=fix_lua_indexing)
self.cache = None
self.cache_index = {}
@ -126,6 +125,8 @@ class IndexedCachedDataset(IndexedDataset):
def prefetch(self, indices):
if all(i in self.cache_index for i in indices):
return
if not self.data_file:
self.read_data(self.path)
indices = sorted(set(indices))
total_size = 0
for i in indices:
@ -153,34 +154,7 @@ class IndexedCachedDataset(IndexedDataset):
return item
class IndexedInMemoryDataset(IndexedDataset):
"""Loader for TorchNet IndexedDataset, keeps all the data in memory"""
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb')
self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
self.data_file.readinto(self.buffer)
self.data_file.close()
if self.fix_lua_indexing:
self.buffer -= 1 # subtract 1 for 0-based indexing
def read_into(self, start, dst):
if self.token_blob is None:
self.token_blob = [t for l in self.tokens_list for t in l]
np.copyto(dst, self.token_blob[start:])
def __del__(self):
pass
def __getitem__(self, i):
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
return torch.from_numpy(a).long()
class IndexedRawTextDataset(IndexedDataset):
class IndexedRawTextDataset(torch.utils.data.Dataset):
"""Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory"""
@ -205,6 +179,10 @@ class IndexedRawTextDataset(IndexedDataset):
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
def check_index(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
def __getitem__(self, i):
self.check_index(i)
return self.tokens_list[i]
@ -252,7 +230,7 @@ class IndexedDatasetBuilder(object):
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def merge_file_(self, another_file):
index = IndexedDataset(another_file, read_data=False)
index = IndexedDataset(another_file)
assert index.dtype == self.dtype
begin = self.data_offsets[-1]

View File

@ -69,17 +69,19 @@ class EpochBatchIterator(object):
batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
indices
seed (int, optional): seed for random number generator for
reproducibility. Default: 1
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards. Default: 1
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return. Default: 0
buffer_size (int, optional): number of batches to buffer. Default: 5
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
"""
def __init__(
self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
buffer_size=5,
num_workers=0,
):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
@ -88,14 +90,12 @@ class EpochBatchIterator(object):
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
self.buffer_size = buffer_size
self.num_workers = num_workers
self.epoch = 0
self._cur_epoch_itr = None
self._next_epoch_itr = None
self._supports_prefetch = (
hasattr(dataset, 'supports_prefetch') and dataset.supports_prefetch
)
self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)
def __len__(self):
return len(self.frozen_batches)
@ -105,11 +105,10 @@ class EpochBatchIterator(object):
Args:
shuffle (bool, optional): shuffle batches before returning the
iterator. Default: ``True``
iterator (default: True).
fix_batches_to_gpus: ensure that batches are always
allocated to the same shards across epochs. Requires
that :attr:`dataset` supports prefetching. Default:
``False``
that :attr:`dataset` supports prefetching (default: False).
"""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
@ -117,7 +116,8 @@ class EpochBatchIterator(object):
else:
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(
self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus)
self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
)
return self._cur_epoch_itr
def end_of_epoch(self):
@ -179,50 +179,14 @@ class EpochBatchIterator(object):
batches = self.frozen_batches
batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
return CountingIterator(BufferedIterator(
torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_sampler=batches,
),
buffer_size=self.buffer_size,
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_sampler=batches,
num_workers=self.num_workers,
))
class BufferedIterator(object):
"""Wrapper around an iterable that prefetches items into a buffer.
Args:
iterable (iterable): iterable to wrap
buffer_size (int): number of items to prefetch and buffer
"""
def __init__(self, iterable, buffer_size):
self.iterable = iterable
self.q = queue.Queue(maxsize=buffer_size)
self.thread = threading.Thread(target=self._load_q, daemon=True)
self.thread.start()
def __len__(self):
return len(self.iterable)
def __iter__(self):
return self
def __next__(self):
x = self.q.get()
if x is None:
self.thread.join()
raise StopIteration
return x[0]
def _load_q(self):
for x in self.iterable:
self.q.put([x]) # wrap in list so that it's never None
self.q.put(None)
class GroupedIterator(object):
"""Wrapper around an iterable that returns groups (chunks) of items.
@ -261,7 +225,7 @@ class ShardedIterator(object):
num_shards (int): number of shards to split the iterable into
shard_id (int): which shard to iterator over
fill_value (Any, optional): padding value when the iterable doesn't
evenly divide *num_shards*. Default: ``None``
evenly divide *num_shards* (default: None).
"""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):

View File

@ -79,23 +79,23 @@ class LanguagePairDataset(FairseqDataset):
tgt (torch.utils.data.Dataset, optional): target dataset to wrap
tgt_sizes (List[int], optional): target sentence lengths
tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
left_pad_source (bool, optional): pad source tensors on the left side.
Default: ``True``
left_pad_target (bool, optional): pad target tensors on the left side.
Default: ``False``
max_source_positions (int, optional): max number of tokens in the source
sentence. Default: ``1024``
max_target_positions (int, optional): max number of tokens in the target
sentence. Default: ``1024``
shuffle (bool, optional): shuffle dataset elements before batching.
Default: ``True``
left_pad_source (bool, optional): pad source tensors on the left side
(default: True).
left_pad_target (bool, optional): pad target tensors on the left side
(default: False).
max_source_positions (int, optional): max number of tokens in the
source sentence (default: 1024).
max_target_positions (int, optional): max number of tokens in the
target sentence (default: 1024).
shuffle (bool, optional): shuffle dataset elements before batching
(default: True).
input_feeding (bool, optional): create a shifted version of the targets
to be passed into the model for input feeding/teacher forcing.
Default: ``True``
remove_eos_from_source (bool, optional): if set, removes eos from end of
source if it's present. Default: ``False``
to be passed into the model for input feeding/teacher forcing
(default: True).
remove_eos_from_source (bool, optional): if set, removes eos from end
of source if it's present (default: False).
append_eos_to_target (bool, optional): if set, appends eos to end of
target if it's absent. Default: ``False``
target if it's absent (default: False).
"""
def __init__(
@ -223,15 +223,13 @@ class LanguagePairDataset(FairseqDataset):
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
def prefetch(self, indices):
self.src.prefetch(indices)
self.tgt.prefetch(indices)
@property
def supports_prefetch(self):
return (
hasattr(self.src, 'supports_prefetch')
and self.src.supports_prefetch
and hasattr(self.tgt, 'supports_prefetch')
and self.tgt.supports_prefetch
getattr(self.src, 'supports_prefetch', False)
and getattr(self.tgt, 'supports_prefetch', False)
)
def prefetch(self, indices):
self.src.prefetch(indices)
self.tgt.prefetch(indices)

View File

@ -9,7 +9,6 @@ import numpy as np
import torch
from . import data_utils, FairseqDataset
from typing import List
def collate(samples, pad_idx, eos_idx):
@ -53,8 +52,8 @@ class MonolingualDataset(FairseqDataset):
dataset (torch.utils.data.Dataset): dataset to wrap
sizes (List[int]): sentence lengths
vocab (~fairseq.data.Dictionary): vocabulary
shuffle (bool, optional): shuffle the elements before batching.
Default: ``True``
shuffle (bool, optional): shuffle the elements before batching
(default: True).
"""
def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
@ -66,8 +65,8 @@ class MonolingualDataset(FairseqDataset):
self.add_eos_for_other_targets = add_eos_for_other_targets
self.shuffle = shuffle
assert targets is None or all(
t in {'self', 'future', 'past'} for t in targets), "targets must be none or one of 'self', 'future', 'past'"
assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
"targets must be none or one of 'self', 'future', 'past'"
if targets is not None and len(targets) == 0:
targets = None
self.targets = targets
@ -185,7 +184,7 @@ class MonolingualDataset(FairseqDataset):
@property
def supports_prefetch(self):
return self.dataset.supports_prefetch
return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
self.dataset.prefetch(indices)

View File

@ -245,11 +245,12 @@ class NoisingDataset(torch.utils.data.Dataset):
**kwargs
):
"""
Sets up a noising dataset which takes a src batch, generates
a noisy src using a noising config, and returns the
corresponding {noisy src, original src} batch
Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the
samples based on the supplied noising configuration.
Args:
src_dataset: dataset which will be used to build self.src_dataset --
src_dataset (~torch.utils.data.Dataset): dataset to wrap; used to build self.src_dataset --
a LanguagePairDataset with src dataset as the source dataset and
None as the target dataset. Should NOT have padding so that
src_lengths are accurately calculated by language_pair_dataset
@ -257,26 +258,22 @@ class NoisingDataset(torch.utils.data.Dataset):
We use language_pair_dataset here to encapsulate the tgt_dataset
so we can re-use the LanguagePairDataset collater to format the
batches in the structure that SequenceGenerator expects.
src_dict: src dict
src_dict: src dictionary
seed: seed to use when generating random noise
noiser: a pre-initialized noiser. If this is None, a noiser will
be created using noising_class and kwargs.
noising_class: class to use when initializing noiser
kwargs: noising args for configuring noising to apply
Note that there is no equivalent argparse code for these args
anywhere in our top level train scripts yet. Integration is
still in progress. You can still, however, test out this dataset
functionality with the appropriate args as in the corresponding
unittest: test_noising_dataset.
src_dict (~fairseq.data.Dictionary): source dictionary
seed (int): seed to use when generating random noise
noiser (WordNoising): a pre-initialized :class:`WordNoising`
instance. If this is None, a new instance will be created using
*noising_class* and *kwargs*.
noising_class (class, optional): class to use to initialize a
default :class:`WordNoising` instance.
kwargs (dict, optional): arguments to initialize the default
:class:`WordNoising` instance given by *noiser*.
"""
self.src_dataset = src_dataset
self.src_dict = src_dict
self.seed = seed
self.noiser = noiser if noiser is not None else noising_class(
dictionary=src_dict, **kwargs,
)
self.seed = seed
def __getitem__(self, index):
"""

View File

@ -13,13 +13,16 @@ from . import FairseqDataset
class RoundRobinZipDatasets(FairseqDataset):
"""Zip multiple FairseqDatasets together, repeating shorter datasets in a
round-robin fashion to match the length of the longest one.
"""Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.
Shorter datasets are repeated in a round-robin fashion to match the length
of the longest one.
Args:
datasets: a dictionary of FairseqDatasets
eval_key: an optional key used at evaluation time that causes this
instance to pass-through batches from `datasets[eval_key]`.
datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
:class:`~fairseq.data.FairseqDataset` instances.
eval_key (str, optional): a key used at evaluation time that causes
this instance to pass-through batches from *datasets[eval_key]*.
"""
def __init__(self, datasets, eval_key=None):
@ -107,3 +110,14 @@ class RoundRobinZipDatasets(FairseqDataset):
dataset.valid_size(self._map_index(key, index), max_positions[key])
for key, dataset in self.datasets.items()
)
@property
def supports_prefetch(self):
return all(
getattr(dataset, 'supports_prefetch', False)
for dataset in self.datasets.values()
)
def prefetch(self, indices):
for key, dataset in self.datasets.items():
dataset.prefetch([self._map_index(key, index) for index in indices])

View File

@ -14,32 +14,32 @@ from . import FairseqDataset
class TokenBlockDataset(FairseqDataset):
"""Break a 1d tensor of tokens into blocks.
The blocks are fetched from the original tensor so no additional memory is allocated.
"""Break a Dataset of tokens into blocks.
Args:
tokens: 1d tensor of tokens to break into blocks
sizes: sentence lengths (required for 'complete' and 'eos')
block_size: maximum block size (ignored in 'eos' break mode)
break_mode: Mode used for breaking tokens. Values can be one of:
dataset (~torch.utils.data.Dataset): dataset to break into blocks
sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
block_size (int): maximum block size (ignored in 'eos' break mode)
break_mode (str, optional): Mode used for breaking tokens. Values can
be one of:
- 'none': break tokens into equally sized blocks (up to block_size)
- 'complete': break tokens into blocks (up to block_size) such that
blocks contains complete sentences, although block_size may be
exceeded if some sentences exceed block_size
- 'eos': each block contains one sentence (block_size is ignored)
include_targets: return next tokens as targets
include_targets (bool, optional): return next tokens as targets
(default: False).
"""
def __init__(self, ds, block_size, pad, eos, break_mode=None, include_targets=False):
def __init__(self, dataset, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
super().__init__()
self.dataset = ds
self.dataset = dataset
self.pad = pad
self.eos = eos
self.include_targets = include_targets
self.slice_indices = []
self.cache_index = {}
sizes = ds.sizes
assert len(dataset) == len(sizes)
if break_mode is None or break_mode == 'none':
total_size = sum(sizes)
@ -77,44 +77,66 @@ class TokenBlockDataset(FairseqDataset):
self.sizes = np.array([e - s for s, e in self.slice_indices])
def __getitem__(self, index):
s, e = self.cache_index[index]
# build index mapping block indices to the underlying dataset indices
self.block_to_dataset_index = []
ds_idx, ds_remaining = -1, 0
for to_consume in self.sizes:
if ds_remaining == 0:
ds_idx += 1
ds_remaining = sizes[ds_idx]
start_ds_idx = ds_idx
start_offset = sizes[ds_idx] - ds_remaining
while to_consume > ds_remaining:
to_consume -= ds_remaining
ds_idx += 1
ds_remaining = sizes[ds_idx]
ds_remaining -= to_consume
self.block_to_dataset_index.append((
start_ds_idx, # starting index in dataset
start_offset, # starting offset within starting index
ds_idx, # ending index in dataset
))
assert ds_remaining == 0
assert ds_idx == len(self.dataset) - 1
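# worked example: with per-sentence sizes [3, 2, 4], block_size=4 and
# break_mode 'none', the blocks have sizes [4, 4, 1] and the mapping built
# above is roughly:
#   [(0, 0, 1),   # block 0 starts at sentence 0, offset 0, ends in sentence 1
#    (1, 1, 2),   # block 1 starts at sentence 1, offset 1, ends in sentence 2
#    (2, 3, 2)]   # block 2 starts at sentence 2, offset 3, ends in sentence 2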
item = torch.from_numpy(self.cache[s:e]).long()
def __getitem__(self, index):
start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
buffer = torch.cat([
self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)
])
slice_s, slice_e = self.slice_indices[index]
length = slice_e - slice_s
s, e = start_offset, start_offset + length
item = buffer[s:e]
if self.include_targets:
# target is the sentence, for source, rotate item one token to the left (would start with eos)
# past target is rotated to the left by 2 (padded if its first)
# *target* is the original sentence (=item)
# *source* is rotated left by 1 (maybe left-padded with eos)
# *past_target* is rotated left by 2 (left-padded as needed)
if s == 0:
source = np.concatenate([[self.eos], self.cache[0:e - 1]])
past_target = np.concatenate([[self.pad, self.eos], self.cache[0:e - 2]])
source = torch.cat([item.new([self.eos]), buffer[0:e - 1]])
past_target = torch.cat([item.new([self.pad, self.eos]), buffer[0:e - 2]])
else:
source = self.cache[s - 1: e - 1]
source = buffer[s - 1:e - 1]
if s == 1:
past_target = np.concatenate([[self.eos], self.cache[0:e - 2]])
past_target = torch.cat([item.new([self.eos]), buffer[0:e - 2]])
else:
past_target = self.cache[s - 2:e - 2]
past_target = buffer[s - 2:e - 2]
return torch.from_numpy(source).long(), item, torch.from_numpy(past_target).long()
return source, item, past_target
return item
def __len__(self):
return len(self.slice_indices)
def prefetch(self, indices):
indices.sort()
total_size = 0
for idx in indices:
s, e = self.slice_indices[idx]
total_size += e - s
self.cache = np.empty(total_size, dtype=np.int32)
start = 0
for idx in indices:
s, e = self.slice_indices[idx]
self.dataset.read_into(s, self.cache[start:start + e - s])
self.cache_index[idx] = (start, start + e - s)
start += e - s
@property
def supports_prefetch(self):
return True
return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
self.dataset.prefetch({
ds_idx
for index in indices
for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
for ds_idx in range(start_ds_idx, end_ds_idx + 1)
})

View File

@ -11,7 +11,7 @@ from . import FairseqDataset
class TransformEosDataset(FairseqDataset):
"""A dataset wrapper that appends/prepends/strips EOS.
"""A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.
Note that the transformation is applied in :func:`collater`.
@ -111,7 +111,7 @@ class TransformEosDataset(FairseqDataset):
@property
def supports_prefetch(self):
return self.dataset.supports_prefetch()
return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)

View File

@ -6,7 +6,9 @@
# can be found in the PATENTS file in the same directory.
from collections import namedtuple
import os
import pickle
import subprocess
import torch
from torch import nn
@ -42,6 +44,38 @@ else:
import torch.distributed as dist_no_c10d
def infer_init_method(args):
if args.distributed_init_method is not None:
return
# support torch.distributed.launch
if all(key in os.environ for key in [
'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK'
]):
args.distributed_init_method = 'tcp://{addr}:{port}'.format(
addr=os.environ['MASTER_ADDR'],
port=os.environ['MASTER_PORT'],
)
args.distributed_world_size = int(os.environ['WORLD_SIZE'])
args.distributed_rank = int(os.environ['RANK'])
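# e.g. `python -m torch.distributed.launch --nproc_per_node=8 train.py ...`
# exports MASTER_ADDR/MASTER_PORT for the rendezvous host plus a distinct
# RANK (and WORLD_SIZE) per process, which is all we need for the tcp:// URL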
# we can determine the init method automatically for Slurm
elif args.distributed_port > 0:
node_list = os.environ.get('SLURM_JOB_NODELIST')
if node_list is not None:
try:
hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
args.distributed_init_method = 'tcp://{host}:{port}'.format(
host=hostnames.split()[0].decode('utf-8'),
port=args.distributed_port)
args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
args.device_id = int(os.environ.get('SLURM_LOCALID'))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError: # Slurm is not installed
pass
def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')
@ -158,7 +192,7 @@ def all_gather_list(data, group=None, max_size=16384):
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
return result
except pickle.UnpicklingError as e:
except pickle.UnpicklingError:
raise Exception(
'Unable to unpickle data from other workers. all_gather_list requires all '
'workers to enter the function together, so this error usually indicates '
@ -167,4 +201,3 @@ def all_gather_list(data, group=None, max_size=16384):
'in your training script that can cause one worker to finish an epoch '
'while other workers are still iterating over their portions of the data.'
)

View File

@ -12,7 +12,7 @@ computation (e.g., AdaptiveSoftmax) and which therefore do not work with the
c10d version of DDP.
This version also supports the *accumulate_grads* feature, which allows faster
training with --update-freq.
training with `--update-freq`.
"""
import copy
@ -27,18 +27,18 @@ from . import distributed_utils
class LegacyDistributedDataParallel(nn.Module):
"""Implements distributed data parallelism at the module level.
A simplified version of torch.nn.parallel.DistributedDataParallel.
This version uses a c10d process group for communication and does
not broadcast buffers.
A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
This version uses a c10d process group for communication and does not
broadcast buffers.
Args:
module: module to be parallelized
world_size: number of parallel workers
module (~torch.nn.Module): module to be parallelized
world_size (int): number of parallel workers
process_group (optional): the c10d process group to be used for
distributed data all-reduction. If None, the default process group
will be used.
buffer_size: number of elements to buffer before performing all-reduce
(default: 256M).
buffer_size (int, optional): number of elements to buffer before
performing all-reduce (default: 256M).
"""
def __init__(self, module, world_size, process_group=None, buffer_size=2**28):

View File

@ -179,10 +179,8 @@ class FConvEncoder(FairseqEncoder):
connections are added between layers when ``residual=1`` (which is
the default behavior).
dropout (float, optional): dropout to be applied before each conv layer
normalization_constant (float, optional): multiplies the result of the
residual block by sqrt(value)
left_pad (bool, optional): whether the input is left-padded. Default:
``True``
left_pad (bool, optional): whether the input is left-padded
(default: True).
"""
def __init__(
@ -215,7 +213,7 @@ class FConvEncoder(FairseqEncoder):
self.residuals = []
layer_in_channels = [in_channels]
for i, (out_channels, kernel_size, residual) in enumerate(convolutions):
for _, (out_channels, kernel_size, residual) in enumerate(convolutions):
if residual == 0:
residual_dim = out_channels
else:

View File

@ -524,6 +524,7 @@ def base_architecture(args):
args.pretrained_checkpoint = getattr(args, 'pretrained_checkpoint', '')
args.pretrained = getattr(args, 'pretrained', 'False')
@register_model_architecture('fconv_self_att', 'fconv_self_att_wp')
def fconv_self_att_wp(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)

View File

@ -196,7 +196,6 @@ class LSTMEncoder(FairseqEncoder):
if bidirectional:
self.output_units *= 2
def forward(self, src_tokens, src_lengths):
if self.left_pad:
# convert left-padding to right-padding
@ -235,7 +234,8 @@ class LSTMEncoder(FairseqEncoder):
if self.bidirectional:
def combine_bidir(outs):
return outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous().view(self.num_layers, bsz, -1)
out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
return out.view(self.num_layers, bsz, -1)
final_hiddens = combine_bidir(final_hiddens)
final_cells = combine_bidir(final_cells)
@ -340,7 +340,6 @@ class LSTMDecoder(FairseqIncrementalDecoder):
elif not self.share_input_output_embed:
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
encoder_out = encoder_out_dict['encoder_out']
encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@ -504,6 +503,7 @@ def base_architecture(args):
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')
@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
def lstm_wiseman_iwslt_de_en(args):
args.dropout = getattr(args, 'dropout', 0.1)

View File

@ -219,7 +219,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
# make sure all arguments are present in older models
base_lm_architecture(args)
if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj == False:
if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False:
# backward compatibility
args.tie_adaptive_proj = True
@ -229,15 +229,17 @@ class TransformerLanguageModel(FairseqLanguageModel):
args.max_target_positions = args.tokens_per_sample
if args.character_embeddings:
embed_tokens = CharacterTokenEmbedder(task.dictionary, eval(args.character_filters),
args.character_embedding_dim,
args.decoder_embed_dim,
args.char_embedder_highway_layers,
)
embed_tokens = CharacterTokenEmbedder(
task.dictionary, eval(args.character_filters),
args.character_embedding_dim, args.decoder_embed_dim,
args.char_embedder_highway_layers,
)
elif args.adaptive_input:
embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
args.adaptive_input_factor, args.decoder_embed_dim,
options.eval_str_list(args.adaptive_input_cutoff, type=int))
embed_tokens = AdaptiveInput(
len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
args.adaptive_input_factor, args.decoder_embed_dim,
options.eval_str_list(args.adaptive_input_cutoff, type=int),
)
else:
embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
@ -248,7 +250,9 @@ class TransformerLanguageModel(FairseqLanguageModel):
args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
assert args.decoder_input_dim == args.decoder_output_dim
decoder = TransformerDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
decoder = TransformerDecoder(
args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False,
)
return TransformerLanguageModel(decoder)
@ -261,8 +265,8 @@ class TransformerEncoder(FairseqEncoder):
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
left_pad (bool, optional): whether the input is left-padded. Default:
``True``
left_pad (bool, optional): whether the input is left-padded
(default: True).
"""
def __init__(self, args, dictionary, embed_tokens, left_pad=True):
@ -382,10 +386,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs.
Default: ``False``
left_pad (bool, optional): whether the input is left-padded. Default:
``False``
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
left_pad (bool, optional): whether the input is left-padded
(default: False).
final_norm (bool, optional): apply layer norm to the output of the
final decoder layer (default: True).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
@ -634,8 +640,8 @@ class TransformerDecoderLayer(nn.Module):
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs.
Default: ``False``
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, no_encoder_attn=False):

View File

@ -7,7 +7,6 @@
import torch
import torch.nn.functional as F
from torch import nn
from typing import List
@ -16,13 +15,13 @@ from typing import List
class AdaptiveInput(nn.Module):
def __init__(
self,
vocab_size: int,
padding_idx: int,
initial_dim: int,
factor: float,
output_dim: int,
cutoff: List[int],
self,
vocab_size: int,
padding_idx: int,
initial_dim: int,
factor: float,
output_dim: int,
cutoff: List[int],
):
super().__init__()

View File

@ -113,8 +113,9 @@ class AdaptiveSoftmax(nn.Module):
m = nn.Sequential(
proj,
nn.Dropout(self.dropout),
nn.Linear(dim, self.cutoff[i + 1] - self.cutoff[i], bias=False) \
if tied_emb is None else TiedLinear(tied_emb, transpose=False)
nn.Linear(
dim, self.cutoff[i + 1] - self.cutoff[i], bias=False,
) if tied_emb is None else TiedLinear(tied_emb, transpose=False),
)
self.tail.append(m)

View File

@ -9,7 +9,7 @@ import importlib
import os
from .fairseq_optimizer import FairseqOptimizer
from .fp16_optimizer import FP16Optimizer
from .fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
OPTIMIZER_REGISTRY = {}

View File

@ -70,10 +70,11 @@ class FairseqOptimizer(object):
group.update(optimizer_overrides)
def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
loss.backward()
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
"""Multiplies grads by a constant *c*."""
for p in self.params:
if p.grad is not None:
p.grad.data.mul_(c)

View File

@ -45,6 +45,164 @@ class DynamicLossScaler(object):
return False
class FP16Optimizer(optim.FairseqOptimizer):
"""
Wrap an *optimizer* to support FP16 (mixed precision) training.
"""
def __init__(self, args, params, fp32_optimizer, fp32_params):
super().__init__(args, params)
self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params
if getattr(args, 'fp16_scale_window', None) is None:
if len(args.update_freq) > 1:
raise ValueError(
'--fp16-scale-window must be given explicitly when using a '
'custom --update-freq schedule'
)
scale_window = 2**14 / args.distributed_world_size / args.update_freq[0]
else:
scale_window = args.fp16_scale_window
self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
)
@classmethod
def build_optimizer(cls, args, params):
"""
Args:
args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize
"""
# create FP32 copy of parameters and grads
total_param_size = sum(p.data.numel() for p in params)
fp32_params = params[0].new(0).float().new(total_param_size)
offset = 0
for p in params:
numel = p.data.numel()
fp32_params[offset:offset+numel].copy_(p.data.view(-1))
offset += numel
fp32_params = torch.nn.Parameter(fp32_params)
fp32_params.grad = fp32_params.data.new(total_param_size)
fp32_optimizer = optim.build_optimizer(args, [fp32_params])
return cls(args, params, fp32_optimizer, fp32_params)
@property
def optimizer(self):
return self.fp32_optimizer.optimizer
@property
def optimizer_config(self):
return self.fp32_optimizer.optimizer_config
def get_lr(self):
return self.fp32_optimizer.get_lr()
def set_lr(self, lr):
self.fp32_optimizer.set_lr(lr)
def state_dict(self):
"""Return the optimizer's state dict."""
state_dict = self.fp32_optimizer.state_dict()
state_dict['loss_scale'] = self.scaler.loss_scale
return state_dict
def load_state_dict(self, state_dict, optimizer_overrides=None):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
if 'loss_scale' in state_dict:
self.scaler.loss_scale = state_dict['loss_scale']
self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
function additionally dynamically scales the loss to avoid gradient
underflow.
"""
loss = loss * self.scaler.loss_scale
loss.backward()
self._needs_sync = True
def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
if self._needs_sync:
# copy FP16 grads to FP32
offset = 0
for p in self.params:
if not p.requires_grad:
continue
grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
numel = grad_data.numel()
self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
offset += numel
# correct for dynamic loss scaler
self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale)
self._needs_sync = False
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
if self._needs_sync:
self._sync_fp16_grads_to_fp32(c)
else:
self.fp32_params.grad.data.mul_(c)
def clip_grad_norm(self, max_norm):
"""Clips gradient norm and updates dynamic loss scaler."""
self._sync_fp16_grads_to_fp32()
grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)
# detect overflow and adjust loss scale
overflow = DynamicLossScaler.has_overflow(grad_norm)
self.scaler.update_scale(overflow)
if overflow:
if self.scaler.loss_scale <= self.args.min_loss_scale:
# Use FloatingPointError as an uncommon error that parent
# functions can safely catch to stop training.
raise FloatingPointError((
'Minimum loss scale reached ({}). Your loss is probably exploding. '
'Try lowering the learning rate, using gradient clipping or '
'increasing the batch size.'
).format(self.args.min_loss_scale))
raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
return grad_norm
def step(self, closure=None):
"""Performs a single optimization step."""
self._sync_fp16_grads_to_fp32()
self.fp32_optimizer.step(closure)
# copy FP32 params back into FP16 model
offset = 0
for p in self.params:
if not p.requires_grad:
continue
numel = p.data.numel()
p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
offset += numel
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
self.fp32_optimizer.zero_grad()
for p in self.params:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
self._needs_sync = False
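# Rough usage sketch (a hypothetical training step; `args`, `model`,
# `criterion` and `sample` are assumed to come from the usual fairseq setup):
#
#   params = [p for p in model.half().parameters() if p.requires_grad]
#   optimizer = FP16Optimizer.build_optimizer(args, params)
#
#   loss, sample_size, logging_output = criterion(model, sample)
#   optimizer.backward(loss)                  # scales the loss before backward()
#   optimizer.clip_grad_norm(args.clip_norm)  # syncs/unscales grads; may raise OverflowError
#   optimizer.step()                          # steps the FP32 optimizer, copies params back to FP16
#   optimizer.zero_grad()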
class ConvertToFP32(object):
"""
A wrapper around a list of params that will convert them to FP32 on the
@ -94,14 +252,13 @@ class ConvertToFP32(object):
raise StopIteration
class FP16Optimizer(optim.FairseqOptimizer):
class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
"""
Wrap an *optimizer* to support FP16 (mixed precision) training.
Args:
args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize
optimizer (~fairseq.optim.FairseqOptimizer): optimizer to wrap
Compared to :class:`fairseq.optim.FP16Optimizer`, this version uses less
memory by copying between FP16 and FP32 parameters on-the-fly. The tradeoff
is reduced optimization speed, which can be mitigated with `--update-freq`.
"""
def __init__(self, args, params, optimizer):
@ -124,10 +281,15 @@ class FP16Optimizer(optim.FairseqOptimizer):
tolerance=args.fp16_scale_tolerance,
)
@staticmethod
def build_optimizer(args, params):
@classmethod
def build_optimizer(cls, args, params):
"""
Args:
args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize
"""
fp16_optimizer = optim.build_optimizer(args, params)
return FP16Optimizer(args, params, fp16_optimizer)
return cls(args, params, fp16_optimizer)
@property
def optimizer(self):
@ -164,6 +326,12 @@ class FP16Optimizer(optim.FairseqOptimizer):
ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)
def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
function additionally dynamically scales the loss to avoid gradient
underflow.
"""
loss = loss * self.scaler.loss_scale
loss.backward()
self._grads_are_scaled = True
@ -178,7 +346,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
assert multiply_grads == 1.
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
"""Multiplies grads by a constant *c*."""
if self._grads_are_scaled:
self._unscale_grads(c)
else:

View File

@ -13,18 +13,25 @@ from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('cosine')
class CosineSchedule(FairseqLRScheduler):
"""Assign LR based on a cyclical schedule that follows the cosine function.
See https://arxiv.org/pdf/1608.03983.pdf for details
See https://arxiv.org/pdf/1608.03983.pdf for details.
We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (`--warmup-init-lr`) until the configured
learning rate (`--lr`).
During warmup:
from some initial learning rate (``--warmup-init-lr``) until the configured
learning rate (``--lr``).
During warmup::
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup:
After warmup::
lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i))
where
t_curr is current percentage of updates within the current period range
t_i is the current period range, which is scaled by t_mul after every iteration
where ``t_curr`` is current percentage of updates within the current period
range and ``t_i`` is the current period range, which is scaled by ``t_mul``
after every iteration.
"""
def __init__(self, args, optimizer):
@ -39,7 +46,7 @@ class CosineSchedule(FairseqLRScheduler):
if args.warmup_init_lr < 0:
args.warmup_init_lr = args.lr[0]
self.min_lr = args.lr[0]
self.min_lr = args.lr[0]
self.max_lr = args.max_lr
assert self.max_lr > self.min_lr, 'max_lr must be more than lr'
@ -98,7 +105,7 @@ class CosineSchedule(FairseqLRScheduler):
t_curr = curr_updates - (self.period * i)
lr_shrink = self.lr_shrink ** i
min_lr = self.min_lr * lr_shrink
min_lr = self.min_lr * lr_shrink
max_lr = self.max_lr * lr_shrink
self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
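Editor's note: read literally, the docstring and the update rule above amount to the following standalone evaluation of the schedule. This is a simplified sketch with a fixed period length; the per-period lr_shrink and t_mul handling done by the real class is left out.

import math

def cosine_schedule_lr(update_num, *, warmup_init_lr, lr, warmup_updates,
                       lr_min, lr_max, period):
    if update_num < warmup_updates:
        # During warmup: linear interpolation from warmup_init_lr to lr.
        step = (lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + step * update_num
    # After warmup: cosine annealing between lr_min and lr_max within a period.
    t_curr = (update_num - warmup_updates) % period
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / period))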

View File

@ -13,22 +13,19 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
"""Decay the LR based on the inverse square root of the update number.
We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (`--warmup-init-lr`) until the configured
learning rate (`--lr`). Thereafter we decay proportional to the number of
from some initial learning rate (``--warmup-init-lr``) until the configured
learning rate (``--lr``). Thereafter we decay proportional to the number of
updates, with a decay factor set to align with the configured learning rate.
During warmup:
During warmup::
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup:
lr = decay_factor / sqrt(update_num)
where
After warmup::
decay_factor = args.lr * sqrt(args.warmup_updates)
lr = decay_factor / sqrt(update_num)
"""
def __init__(self, args, optimizer):
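Editor's note: the formulas above imply that the peak learning rate is reached exactly at the end of warmup and then decays as 1/sqrt(update_num). A minimal standalone evaluation (hypothetical helper name, same arithmetic):

import math

def inverse_sqrt_lr(update_num, *, warmup_init_lr, lr, warmup_updates):
    if update_num < warmup_updates:
        # Linear warmup from warmup_init_lr to lr.
        step = (lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + step * update_num
    # decay_factor is chosen so the curve passes through `lr` at the end of warmup.
    decay_factor = lr * math.sqrt(warmup_updates)
    return decay_factor / math.sqrt(update_num)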

View File

@ -14,8 +14,7 @@ from . import FairseqLRScheduler, register_lr_scheduler
class TriangularSchedule(FairseqLRScheduler):
"""Assign LR based on a triangular cyclical schedule.
See https://arxiv.org/pdf/1506.01186.pdf for details
See https://arxiv.org/pdf/1506.01186.pdf for details.
"""
def __init__(self, args, optimizer):

View File

@ -107,6 +107,8 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
args.update_freq = eval_str_list(args.update_freq, type=int)
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences
if getattr(args, 'memory_efficient_fp16', False):
args.fp16 = True
# Apply architecture configuration.
if hasattr(args, 'arch'):
@ -128,7 +130,10 @@ def get_parser(desc, default_task='translation'):
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
parser.add_argument('--fp16', action='store_true', help='use FP16')
parser.add_argument('--memory-efficient-fp16', action='store_true',
help='use a memory-efficient version of FP16 training; implies --fp16')
parser.add_argument('--fp16-init-scale', default=2**7, type=int,
help='default FP16 loss scale')
parser.add_argument('--fp16-scale-window', type=int,
@ -147,6 +152,8 @@ def get_parser(desc, default_task='translation'):
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading')
# fmt: off
group.add_argument('--num-workers', default=0, type=int, metavar='N',
help='how many subprocesses to use for data loading')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='ignore too long or too short lines in valid and test set')
group.add_argument('--max-tokens', type=int, metavar='N',
@ -178,7 +185,7 @@ def add_distributed_training_args(parser):
group = parser.add_argument_group('Distributed training')
# fmt: off
group.add_argument('--distributed-world-size', type=int, metavar='N',
default=torch.cuda.device_count(),
default=max(1, torch.cuda.device_count()),
help='total number of GPUs across all nodes (default: all visible GPUs)')
group.add_argument('--distributed-rank', default=0, type=int,
help='rank of the current worker')
@ -189,7 +196,7 @@ def add_distributed_training_args(parser):
'establish initial connection')
group.add_argument('--distributed-port', default=-1, type=int,
help='port number (not required if using --distributed-init-method)')
group.add_argument('--device-id', default=0, type=int,
group.add_argument('--device-id', '--local_rank', default=0, type=int,
help='which GPU to use (usually configured automatically)')
group.add_argument('--ddp-backend', default='c10d', type=str,
choices=['c10d', 'no_c10d'],
@ -197,8 +204,8 @@ def add_distributed_training_args(parser):
group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
help='bucket size for reduction')
group.add_argument('--fix-batches-to-gpus', action='store_true',
help='Don\'t shuffle batches between GPUs, this reduces overall '
'randomness and may affect precision but avoids the cost of'
help='don\'t shuffle batches between GPUs; this reduces overall '
'randomness and may affect precision but avoids the cost of '
're-reading the data')
# fmt: on
return group
@ -263,7 +270,9 @@ def add_checkpoint_args(parser):
group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
help='save a checkpoint (and validate) every N updates')
group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
help='keep last N checkpoints saved with --save-interval-updates')
help='keep the last N checkpoints saved with --save-interval-updates')
group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N',
help='keep last N epoch checkpoints')
group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
@ -280,11 +289,11 @@ def add_common_eval_args(group):
help='path(s) to model file(s), colon separated')
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--quiet', action='store_true',
help='only print final scores')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override model args at generation that were used during model training')
help='a dictionary used to override model args at generation '
'that were used during model training')
# fmt: on

View File

@ -22,6 +22,7 @@ class SequenceGenerator(object):
match_source_len=False, no_repeat_ngram_size=0
):
"""Generates translations of a given source sentence.
Args:
beam_size (int, optional): beam width (default: 1)
min/maxlen (int, optional): the length of the generated output will
@ -90,11 +91,14 @@ class SequenceGenerator(object):
cuda=False, timer=None, prefix_size=0,
):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b: generate sequences of maximum length ax + b,
where x is the source sentence length.
cuda: use GPU for generation
timer: StopwatchMeter for timing generations.
maxlen_a/b (int, optional): generate sequences of maximum length
``ax + b``, where ``x`` is the source sentence length.
cuda (bool, optional): use GPU for generation
timer (StopwatchMeter, optional): time generations
prefix_size (int, optional): prefill the generation with the gold
prefix up to this length.
"""
if maxlen_b is None:
maxlen_b = self.maxlen
@ -132,12 +136,13 @@ class SequenceGenerator(object):
"""Generate a batch of translations.
Args:
encoder_input: dictionary containing the inputs to
model.encoder.forward
beam_size: int overriding the beam size. defaults to
self.beam_size
max_len: maximum length of the generated sequence
prefix_tokens: force decoder to begin with these tokens
encoder_input (dict): dictionary containing the inputs to
*model.encoder.forward*.
beam_size (int, optional): overriding the beam size
(default: *self.beam_size*).
max_len (int, optional): maximum length of the generated sequence
prefix_tokens (LongTensor, optional): force decoder to begin with
these tokens
"""
with torch.no_grad():
return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)
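Editor's note: a minimal sketch of driving the generate() entry point documented above. The keyword names follow the docstring; the encoder_input keys (src_tokens/src_lengths) are an assumption about typical translation models and are not shown in this diff.

def translate_batch(generator, src_tokens, src_lengths, beam_size=5):
    # encoder_input is "a dictionary containing the inputs to
    # model.encoder.forward"; these particular keys are assumed.
    encoder_input = {'src_tokens': src_tokens, 'src_lengths': src_lengths}
    # generate() already wraps the search in torch.no_grad(), so the caller
    # does not need an extra context manager.
    return generator.generate(encoder_input, beam_size=beam_size)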

View File

@ -87,4 +87,3 @@ class SequenceScorer(object):
index=sample['target'].data.unsqueeze(-1),
)
return avg_probs.squeeze(2), avg_attn

View File

@ -1,7 +1,8 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in # the root directory of this source tree. An additional grant of patent rights
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse

View File

@ -61,29 +61,32 @@ class FairseqTask(object):
def get_batch_iterator(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1,
seed=1, num_shards=1, shard_id=0,
seed=1, num_shards=1, shard_id=0, num_workers=0,
):
"""
Get an iterator that yields batches of data from the given dataset.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_tokens (int, optional): max number of tokens in each batch.
Default: ``None``
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch. Default: ``None``
batch (default: None).
max_positions (optional): max sentence length supported by the
model. Default: ``None``
model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long. Default: ``False``
sentences that are too long (default: False).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N. Default: ``1``
be a multiple of N (default: 1).
seed (int, optional): seed for random number generator for
reproducibility. Default: ``1``
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards. Default: ``1``
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return. Default: ``0``
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
@ -114,6 +117,7 @@ class FairseqTask(object):
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
)
def build_model(self, args):
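Editor's note: a usage sketch for the iterator documented above, using only arguments listed in the docstring. The split name, max_positions value and the task.dataset() accessor are placeholders and assumptions for illustration.

itr = task.get_batch_iterator(
    dataset=task.dataset('valid'),      # assumed accessor for a loaded split
    max_tokens=4000,
    max_sentences=None,
    max_positions=(1024, 1024),
    ignore_invalid_inputs=True,
    required_batch_size_multiple=8,
    seed=1,
    num_shards=1,
    shard_id=0,
    num_workers=4,                      # new: subprocess data loading
).next_epoch_itr(shuffle=False)

for sample in itr:
    pass  # e.g. hand `sample` to a trainer or generator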

View File

@ -10,9 +10,15 @@ import numpy as np
import os
from fairseq.data import (
ConcatDataset, Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
MonolingualDataset, TokenBlockDataset, TruncatedDictionary,
IndexedCachedDataset, IndexedDataset)
ConcatDataset,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
MonolingualDataset,
TokenBlockDataset,
TruncatedDictionary,
)
from . import FairseqTask, register_task
@ -60,6 +66,8 @@ class LanguageModelingTask(FairseqTask):
'If set to "eos", includes only one sentence per sample.')
parser.add_argument('--tokens-per-sample', default=1024, type=int,
help='max number of tokens per sample for LM dataset')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
parser.add_argument('--output-dictionary-size', default=-1, type=int,
@ -139,7 +147,10 @@ class LanguageModelingTask(FairseqTask):
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
elif not self.args.raw_text and IndexedDataset.exists(path):
ds = IndexedDataset(path, fix_lua_indexing=True)
if self.args.lazy_load:
ds = IndexedDataset(path, fix_lua_indexing=True)
else:
ds = IndexedCachedDataset(path, fix_lua_indexing=True)
else:
if k > 0:
break
@ -148,9 +159,11 @@ class LanguageModelingTask(FairseqTask):
loaded_datasets.append(
TokenBlockDataset(
ds, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(),
ds, ds.sizes, self.args.tokens_per_sample,
pad=self.dictionary.pad(), eos=self.dictionary.eos(),
break_mode=self.args.sample_break_mode, include_targets=True,
))
)
)
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))

View File

@ -12,8 +12,12 @@ import torch
from fairseq import options
from fairseq.data import (
Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
IndexedRawTextDataset, RoundRobinZipDatasets,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
RoundRobinZipDatasets,
)
from fairseq.models import FairseqMultiModel
@ -55,6 +59,8 @@ class MultilingualTranslationTask(FairseqTask):
help='source language (only needed for inference)')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language (only needed for inference)')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
@ -112,15 +118,18 @@ class MultilingualTranslationTask(FairseqTask):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
elif not self.args.raw_text and IndexedDataset.exists(filename):
return True
return False
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedInMemoryDataset.exists(path):
return IndexedInMemoryDataset(path, fix_lua_indexing=True)
elif IndexedDataset.exists(path):
if self.args.lazy_load:
return IndexedDataset(path, fix_lua_indexing=True)
else:
return IndexedCachedDataset(path, fix_lua_indexing=True)
return None
def sort_lang_pair(lang_pair):

View File

@ -6,13 +6,17 @@
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os
from fairseq import options, utils
from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
ConcatDataset,
data_utils,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
)
from . import FairseqTask, register_task
@ -49,6 +53,8 @@ class TranslationTask(FairseqTask):
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
@ -132,7 +138,10 @@ class TranslationTask(FairseqTask):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedDataset.exists(path):
return IndexedCachedDataset(path, fix_lua_indexing=True)
if self.args.lazy_load:
return IndexedDataset(path, fix_lua_indexing=True)
else:
return IndexedCachedDataset(path, fix_lua_indexing=True)
return None
src_datasets = []
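Editor's note: the same raw-text / lazy / cached branching now appears in the language modeling, multilingual translation and translation tasks. A shared helper expressing that choice might look like the sketch below; it is illustrative only and not part of this commit.

from fairseq.data import (
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
)

def open_indexed_dataset(path, dictionary, *, raw_text=False, lazy_load=False):
    if raw_text and IndexedRawTextDataset.exists(path):
        return IndexedRawTextDataset(path, dictionary)
    elif not raw_text and IndexedDataset.exists(path):
        if lazy_load:
            # --lazy-load: read samples from disk on demand.
            return IndexedDataset(path, fix_lua_indexing=True)
        # Default: cache the binarized data in memory for faster access.
        return IndexedCachedDataset(path, fix_lua_indexing=True)
    return None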

View File

@ -6,10 +6,11 @@
# can be found in the PATENTS file in the same directory.
from collections import Counter
import os, re
from multiprocessing import Pool
import os
import re
import torch
from multiprocessing import Pool
SPACE_NORMALIZER = re.compile(r"\s+")
@ -27,7 +28,8 @@ def safe_readline(f):
return f.readline()
except UnicodeDecodeError:
pos -= 1
f.seek(pos) # search where this character begins
f.seek(pos) # search where this character begins
class Tokenizer:
@ -41,7 +43,7 @@ class Tokenizer:
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
@ -73,14 +75,17 @@ class Tokenizer:
merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line,
append_eos=True, reverse_order=False,
offset=0, end=-1):
def binarize(
filename, dict, consumer, tokenize=tokenize_line, append_eos=True,
reverse_order=False, offset=0, end=-1,
):
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word])
with open(filename, 'r') as f:
f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used
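Editor's note: the offset/end plumbing above exists so multiple workers can binarize disjoint byte ranges of one file. The sketch below (hypothetical helper, same mechanics) shows the pattern: drop the partial line at the start of a chunk, read with readline() so f.tell() stays meaningful, and stop once the position passes the chunk end.

import os

def process_chunk(filename, worker_id, num_workers, consume_line):
    with open(filename, 'r', encoding='utf-8') as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_workers
        offset = worker_id * chunk_size
        end = offset + chunk_size
        f.seek(offset)
        if offset > 0:
            # Drop the (possibly partial) line we landed in; the previous
            # worker reads one line past its own `end`, so nothing is lost.
            # fairseq's safe_readline() additionally backs up over a split
            # UTF-8 character before doing this.
            f.readline()
        # readline() keeps f.tell() accurate; next(f) fills a read-ahead
        # buffer and would make the position check below unusable.
        line = f.readline()
        while line:
            consume_line(line)
            if f.tell() > end:
                break
            line = f.readline()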

View File

@ -30,22 +30,23 @@ class Trainer(object):
"""
def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
self.args = args
self.task = task
# copy model and criterion to current device
self.criterion = criterion.cuda()
self.criterion = criterion
self._model = model
self.cuda = torch.cuda.is_available() and not args.cpu
if args.fp16:
self._model = model.half().cuda()
else:
self._model = model.cuda()
self._model = self._model.half()
if self.cuda:
self.criterion = self.criterion.cuda()
self._model = self._model.cuda()
self._dummy_batch = dummy_batch
self._oom_batch = oom_batch
self._lr_scheduler = None
self._num_updates = 0
self._optim_history = None
self._optimizer = None
@ -71,7 +72,6 @@ class Trainer(object):
self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
@property
def model(self):
if self._wrapped_model is None:
@ -89,19 +89,26 @@ class Trainer(object):
self._build_optimizer()
return self._optimizer
@property
def lr_scheduler(self):
if self._lr_scheduler is None:
self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
return self._lr_scheduler
def _build_optimizer(self):
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
if self.args.fp16:
if torch.cuda.get_device_capability(0)[0] < 7:
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
print('| WARNING: your device does NOT support faster training with --fp16, '
'please switch to FP32 which is likely to be faster')
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
if self.args.memory_efficient_fp16:
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
else:
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
else:
if torch.cuda.get_device_capability(0)[0] >= 7:
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16')
self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)
self._optimizer = optim.build_optimizer(self.args, params)
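Editor's note: the capability checks above key off the GPU's compute capability major version; FP16 tensor cores require 7.0 or newer (Volta/Turing). A standalone version of the check, assuming a CUDA-enabled PyTorch build:

import torch

def fp16_is_fast(device_index=0):
    # Mirrors the _build_optimizer() warning: --fp16 only pays off on
    # devices with compute capability >= 7.0.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device_index)
    return major >= 7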
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
@ -151,7 +158,8 @@ class Trainer(object):
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if self.cuda:
torch.cuda.manual_seed(seed)
self.model.train()
self.zero_grad()
@ -296,7 +304,8 @@ class Trainer(object):
for p in self.model.parameters():
if p.grad is not None:
del p.grad # free some memory
torch.cuda.empty_cache()
if self.cuda:
torch.cuda.empty_cache()
return self.valid_step(sample, raise_oom=True)
else:
raise e
@ -377,4 +386,6 @@ class Trainer(object):
def _prepare_sample(self, sample):
if sample is None or len(sample) == 0:
return None
return utils.move_to_cuda(sample)
if self.cuda:
sample = utils.move_to_cuda(sample)
return sample

View File

@ -378,6 +378,14 @@ def item(tensor):
return tensor
def clip_grad_norm_(tensor, max_norm):
grad_norm = item(torch.norm(tensor))
if grad_norm > max_norm > 0:
clip_coef = max_norm / (grad_norm + 1e-6)
tensor.mul_(clip_coef)
return grad_norm
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t)
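Editor's note: a short standalone demonstration of the two utilities above, assuming they live in fairseq.utils (an inference from the surrounding functions); the tensor values are arbitrary.

import torch
from fairseq import utils

g = torch.full((4,), 3.0)                  # stand-in for flattened gradients
norm = utils.clip_grad_norm_(g, max_norm=1.0)
# `norm` is the pre-clipping norm (6.0 here); `g` was rescaled in place so its
# norm is ~1.0. Passing max_norm=0 disables clipping entirely because of the
# chained `grad_norm > max_norm > 0` comparison.

mask = utils.fill_with_neg_inf(torch.zeros(2, 2, dtype=torch.float16))
# Returns an FP16 tensor filled with -inf; filling an FP32 temporary and
# casting back with type_as() sidesteps FP16 fill limitations.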

fb_train.py (new file, 18 lines added)
View File

@ -0,0 +1,18 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.fb.rendezvous.zeus # noqa: F401
from fairseq import options
from train import main
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)

View File

@ -11,7 +11,7 @@ Translate pre-processed data with a trained model.
import torch
from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils
from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
@ -41,7 +41,9 @@ def main(args):
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))
models, _model_args = utils.load_ensemble_for_inference(
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
)
# Optimize ensemble for generation
for model in models:
@ -69,6 +71,7 @@ def main(args):
required_batch_size_multiple=8,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
# Initialize generator

View File

@ -75,8 +75,9 @@ def main(args):
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
model_paths = args.path.split(':')
models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides))
models, _model_args = utils.load_ensemble_for_inference(
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
)
# Set dictionaries
tgt_dict = task.target_dictionary

View File

@ -1,87 +0,0 @@
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import signal
import torch
from fairseq import distributed_utils, options
from train import main as single_process_main
def main(args):
if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
print('| WARNING: when using --update-freq on a single machine, you '
'will get better performance with --ddp-backend=no_c10d')
mp = torch.multiprocessing.get_context('spawn')
# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Train with multiprocessing.
procs = []
base_rank = args.distributed_rank
for i in range(torch.cuda.device_count()):
args.distributed_rank = base_rank + i
args.device_id = i
procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True))
procs[i].start()
error_handler.add_child(procs[i].pid)
for p in procs:
p.join()
def run(args, error_queue):
try:
args.distributed_rank = distributed_utils.distributed_init(args)
single_process_main(args)
except KeyboardInterrupt:
pass # killed by parent, do nothing
except Exception:
# propagate exception to parent process, keeping original traceback
import traceback
error_queue.put((args.distributed_rank, traceback.format_exc()))
class ErrorHandler(object):
"""A class that listens for exceptions in children processes and propagates
the tracebacks to the parent process."""
def __init__(self, error_queue):
import signal
import threading
self.error_queue = error_queue
self.children_pids = []
self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
self.error_thread.start()
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
self.children_pids.append(pid)
def error_listener(self):
(rank, original_trace) = self.error_queue.get()
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, signalnum, stackframe):
for pid in self.children_pids:
os.kill(pid, signal.SIGINT) # kill children processes
(rank, original_trace) = self.error_queue.get()
msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
msg += original_trace
raise Exception(msg)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)

View File

@ -18,7 +18,7 @@ import shutil
from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
from multiprocessing import Pool, Manager, Process
from multiprocessing import Pool
def get_parser():

View File

@ -50,6 +50,14 @@ class TestTranslation(unittest.TestCase):
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16'])
generate_main(data_dir)
def test_memory_efficient_fp16(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_memory_efficient_fp16') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--memory-efficient-fp16'])
generate_main(data_dir)
def test_update_freq(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_update_freq') as data_dir:
@ -68,8 +76,7 @@ class TestTranslation(unittest.TestCase):
data_dir, 'fconv_iwslt_de_en', ['--max-target-positions', '5'],
)
self.assertTrue(
'skip this example with --skip-invalid-size-inputs-valid-test' \
in str(context.exception)
'skip this example with --skip-invalid-size-inputs-valid-test' in str(context.exception)
)
train_translation_model(
data_dir, 'fconv_iwslt_de_en',

View File

@ -12,10 +12,6 @@ import os
import tempfile
import unittest
import torch
from fairseq import options
from . import test_binaries
@ -79,6 +75,12 @@ class TestReproducibility(unittest.TestCase):
'--fp16-init-scale', '4096',
])
def test_reproducibility_memory_efficient_fp16(self):
self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [
'--memory-efficient-fp16',
'--fp16-init-scale', '4096',
])
if __name__ == '__main__':
unittest.main()

View File

@ -39,8 +39,10 @@ def mock_dict():
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
tokens = torch.LongTensor(list(range(epoch_size)))
tokens_ds = data.TokenBlockDataset(tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False)
tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
tokens_ds = data.TokenBlockDataset(
tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
epoch_itr = data.EpochBatchIterator(
@ -64,7 +66,6 @@ class TestLoadCheckpoint(unittest.TestCase):
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
[p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)

View File

@ -28,9 +28,8 @@ def main(args):
args.max_tokens = 6000
print(args)
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
torch.cuda.set_device(args.device_id)
if torch.cuda.is_available() and not args.cpu:
torch.cuda.set_device(args.device_id)
torch.manual_seed(args.seed)
# Setup task, e.g., translation, language modeling, etc.
@ -74,6 +73,7 @@ def main(args):
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
num_workers=args.num_workers,
)
# Load the latest checkpoint if one is available
@ -211,6 +211,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch,
@ -306,7 +307,15 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
for old_chk in checkpoints[args.keep_interval_updates:]:
os.remove(old_chk)
if os.path.lexists(old_chk):
os.remove(old_chk)
if args.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint\d+\.pt')
for old_chk in checkpoints[args.keep_last_epochs:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
def load_checkpoint(args, trainer, epoch_itr):
@ -346,23 +355,50 @@ def load_dataset_splits(task, splits):
raise e
def distributed_main(i, args):
import socket
args.device_id = i
if args.distributed_rank is None: # torch.multiprocessing.spawn
args.distributed_rank = i
args.distributed_rank = distributed_utils.distributed_init(args)
print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
main(args)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
if args.distributed_port > 0 or args.distributed_init_method is not None:
from distributed_train import main as distributed_main
if args.distributed_init_method is None:
distributed_utils.infer_init_method(args)
distributed_main(args)
if args.distributed_init_method is not None:
# distributed training
distributed_main(args.device_id, args)
args.distributed_rank = distributed_utils.distributed_init(args)
main(args)
elif args.distributed_world_size > 1:
from multiprocessing_train import main as multiprocessing_main
# Set distributed training parameters for a single node.
args.distributed_world_size = torch.cuda.device_count()
# fallback for single node with multiple GPUs
port = random.randint(10000, 20000)
args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
args.distributed_port = port + 1
args.distributed_rank = None # set based on device id
print(
'''| NOTE: you may get better performance with:
multiprocessing_main(args)
python -m torch.distributed.launch --nproc_per_node {ngpu} train.py {no_c10d}(...)
'''.format(
ngpu=args.distributed_world_size,
no_c10d=(
'--ddp-backend=no_c10d ' if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d'
else ''
),
)
)
torch.multiprocessing.spawn(
fn=distributed_main,
args=(args, ),
nprocs=args.distributed_world_size,
)
else:
# single GPU training
main(args)
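Editor's note: one detail of the new single-node path is that torch.multiprocessing.spawn invokes its target as fn(i, *args), injecting the process index first; that is why distributed_main() now takes `i` as its first argument and uses it for the device id and the rank fallback. A minimal standalone illustration, unrelated to fairseq internals:

import torch.multiprocessing as mp

def worker(i, prefix):
    # `i` is the process index supplied by spawn(); it plays the role of
    # args.device_id / args.distributed_rank in distributed_main() above.
    print('{} worker {} started'.format(prefix, i))

if __name__ == '__main__':
    mp.spawn(fn=worker, args=('demo',), nprocs=2)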