Merge internal changes (#283)
Summary: Pull Request resolved: https://github.com/pytorch/translate/pull/283
Pull Request resolved: https://github.com/pytorch/fairseq/pull/428
Differential Revision: D13564190
Pulled By: myleott
fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
parent 0cb87130e7
commit 7633129ba8
@@ -19,10 +19,13 @@ of various sequence-to-sequence models, including:

Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
- fast beam search generation on both CPU and GPU
- fast generation on both CPU and GPU with multiple search algorithms implemented:
  - beam search
  - Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
  - sampling (unconstrained and top-k)
- large mini-batch training even on a single GPU via delayed updates
- fast half-precision floating point (FP16) training
- extensible: easily register new models, criterions, and tasks
- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers

We also provide [pre-trained models](#pre-trained-models) for several benchmark
translation and language modeling datasets.
@@ -34,7 +37,7 @@ translation and language modeling datasets.
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* Python version 3.6

Currently fairseq requires PyTorch version >= 0.4.0.
Currently fairseq requires PyTorch version >= 1.0.0.
Please follow the instructions here: https://github.com/pytorch/pytorch#installation.

If you use Docker make sure to increase the shared memory size either with
@@ -1,45 +0,0 @@
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os
import socket
import subprocess

from train import main as single_process_main
from fairseq import distributed_utils, options


def main(args):
    if args.distributed_init_method is None and args.distributed_port > 0:
        # We can determine the init method automatically for Slurm.
        node_list = os.environ.get('SLURM_JOB_NODELIST')
        if node_list is not None:
            try:
                hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                args.distributed_init_method = 'tcp://{host}:{port}'.format(
                    host=hostnames.split()[0].decode('utf-8'),
                    port=args.distributed_port)
                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
                args.device_id = int(os.environ.get('SLURM_LOCALID'))
            except subprocess.CalledProcessError as e:  # scontrol failed
                raise e
            except FileNotFoundError as e:  # Slurm is not installed
                pass
    if args.distributed_init_method is None and args.distributed_port is None:
        raise ValueError('--distributed-init-method or --distributed-port '
                         'must be specified for distributed training')

    args.distributed_rank = distributed_utils.distributed_init(args)
    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
    single_process_main(args)


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
@@ -6,8 +6,26 @@
Criterions
==========

Criterions compute the loss function given the model and batch, roughly::

    loss = criterion(model, batch)
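For orientation, here is a minimal sketch of a concrete criterion plugged in through this interface; it is not part of this commit, and the 'simple_cross_entropy' name and logging-output fields are illustrative assumptions:

    import torch.nn.functional as F

    from fairseq.criterions import FairseqCriterion, register_criterion


    @register_criterion('simple_cross_entropy')  # illustrative name
    class SimpleCrossEntropyCriterion(FairseqCriterion):

        def forward(self, model, sample, reduce=True):
            """Compute the loss for the given sample; return (loss, sample_size, logging_output)."""
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output).view(-1)
            loss = F.nll_loss(
                lprobs.view(-1, lprobs.size(-1)), target,
                ignore_index=self.padding_idx,
                reduction='sum' if reduce else 'none',
            )
            sample_size = sample['ntokens']
            logging_output = {
                'loss': loss.item() if reduce else loss,  # illustrative fields
                'ntokens': sample['ntokens'],
                'sample_size': sample_size,
            }
            return loss, sample_size, logging_output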
.. automodule:: fairseq.criterions
    :members:

.. autoclass:: fairseq.criterions.FairseqCriterion
    :members:
    :undoc-members:

.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
    :members:
    :undoc-members:
.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
    :members:
    :undoc-members:
.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
    :members:
    :undoc-members:
.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
    :members:
    :undoc-members:
@@ -21,6 +21,20 @@ mini-batches.
.. autoclass:: fairseq.data.MonolingualDataset
    :members:

**Helper Datasets**

These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
provide additional functionality:

.. autoclass:: fairseq.data.BacktranslationDataset
    :members:
.. autoclass:: fairseq.data.ConcatDataset
    :members:
.. autoclass:: fairseq.data.RoundRobinZipDatasets
    :members:
.. autoclass:: fairseq.data.TransformEosDataset
    :members:
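To make one of these wrappers concrete, here is a small hypothetical example of ConcatDataset with upsampling (not part of this commit; the toy dataset is a stand-in, and only the sample_ratios behaviour shown later in this diff is relied on):

    import torch.utils.data

    from fairseq.data import ConcatDataset


    class ToyDataset(torch.utils.data.Dataset):
        """Stand-in dataset; only indexing and length matter for this illustration."""

        def __init__(self, items):
            self.items = items

        def __getitem__(self, index):
            return self.items[index]

        def __len__(self):
            return len(self.items)


    d1, d2 = ToyDataset([0, 1, 2]), ToyDataset([10, 11])
    # sample_ratios repeats d2 twice per pass relative to d1
    combined = ConcatDataset([d1, d2], sample_ratios=[1, 2])
    assert len(combined) == len(d1) + 2 * len(d2)  # 3 + 4 = 7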
Dictionary
----------
@@ -32,6 +46,8 @@ Dictionary
Iterators
---------

.. autoclass:: fairseq.data.BufferedIterator
    :members:
.. autoclass:: fairseq.data.CountingIterator
    :members:
.. autoclass:: fairseq.data.EpochBatchIterator
@@ -27,21 +27,20 @@ interactively. Here, we use a beam size of 5:
    > MODEL_DIR=wmt14.en-fr.fconv-py
    > python interactive.py \
      --path $MODEL_DIR/model.pt $MODEL_DIR \
      --beam 5
      --beam 5 --source-lang en --target-lang fr
    | loading model(s) from wmt14.en-fr.fconv-py/model.pt
    | [en] dictionary: 44206 types
    | [fr] dictionary: 44463 types
    | Type the input sentence and press return:
    > Why is it rare to discover new marine mam@@ mal species ?
    O Why is it rare to discover new marine mam@@ mal species ?
    H -0.06429661810398102 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
    A 0 1 3 3 5 6 6 8 8 8 7 11 12
    H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
    P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288

This generation script produces four types of outputs: a line prefixed
with *S* shows the supplied source sentence after applying the
vocabulary; *O* is a copy of the original source sentence; *H* is the
hypothesis along with an average log-likelihood; and *A* is the
attention maxima for each word in the hypothesis, including the
This generation script produces three types of outputs: a line prefixed
with *O* is a copy of the original source sentence; *H* is the
hypothesis along with an average log-likelihood; and *P* is the
positional score per token position, including the
end-of-sentence marker which is omitted from the text.

See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
@@ -6,7 +6,29 @@
Learning Rate Schedulers
========================

TODO
Learning Rate Schedulers update the learning rate over the course of training.
Learning rates can be updated after each update via :func:`step_update` or at
epoch boundaries via :func:`step`.
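For orientation, a minimal sketch of a scheduler plugged in through this interface; it is not part of this commit, and the linear-warmup rule and 'toy_warmup' name are illustrative only:

    from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


    @register_lr_scheduler('toy_warmup')  # illustrative name
    class ToyWarmupSchedule(FairseqLRScheduler):

        def __init__(self, args, optimizer):
            super().__init__(args, optimizer)
            self.warmup_updates = 1000  # assumed constant; normally exposed via add_args
            self.peak_lr = args.lr[0]

        def step(self, epoch, val_loss=None):
            """Called at epoch boundaries; may track the validation loss."""
            super().step(epoch, val_loss)
            return self.optimizer.get_lr()

        def step_update(self, num_updates):
            """Called after every optimizer update; linearly warm up to the peak LR."""
            lr = self.peak_lr * min(1.0, num_updates / float(self.warmup_updates))
            self.optimizer.set_lr(lr)
            return lr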
.. automodule:: fairseq.optim.lr_scheduler
    :members:

.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.reduce_angular_lr_scheduler.TriangularSchedule
    :members:
    :undoc-members:
@@ -1,8 +1,8 @@
Modules
=======

Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be
helpful when implementing a new :class:`FairseqModel`.
Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.

.. automodule:: fairseq.modules
    :members:
@@ -6,5 +6,27 @@
Optimizers
==========

Optimizers update the Model parameters based on the gradients.
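For orientation, a minimal sketch of exposing a torch optimizer through this interface; it is not part of this commit, and the Adadelta choice and 'toy_adadelta' name are illustrative only:

    import torch.optim

    from fairseq.optim import FairseqOptimizer, register_optimizer


    @register_optimizer('toy_adadelta')  # illustrative name
    class ToyAdadelta(FairseqOptimizer):

        def __init__(self, args, params):
            super().__init__(args, params)
            self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)

        @property
        def optimizer_config(self):
            # kwargs forwarded to the underlying torch optimizer; keeping them in
            # one place lets a reloaded checkpoint pick up new command-line values
            return {
                'lr': self.args.lr[0],
                'weight_decay': getattr(self.args, 'weight_decay', 0.0),
            }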
.. automodule:: fairseq.optim
    :members:

.. autoclass:: fairseq.optim.FairseqOptimizer
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.adagrad.Adagrad
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.adam.FairseqAdam
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.nag.FairseqNAG
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.sgd.SGD
    :members:
    :undoc-members:
@@ -22,12 +22,18 @@ fairseq implements the following high-level training flow::
    for epoch in range(num_epochs):
        itr = task.get_batch_iterator(task.dataset('train'))
        for num_updates, batch in enumerate(itr):
            loss = criterion(model, batch)
            optimizer.backward(loss)
            task.train_step(batch, model, criterion, optimizer)
            average_and_clip_gradients()
            optimizer.step()
            lr_scheduler.step_update(num_updates)
        lr_scheduler.step(epoch)

where the default implementation for ``train.train_step`` is roughly::

    def train_step(self, batch, model, criterion, optimizer):
        loss = criterion(model, batch)
        optimizer.backward(loss)

**Registering new plug-ins**

New plug-ins are *registered* through a set of ``@register`` function
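For context, a hedged sketch of how a task can override the default train_step shown above (not part of this commit; FairseqTask and register_task are the registration machinery the excerpt refers to, and the zero-gradient tweak is illustrative):

    from fairseq.tasks import FairseqTask, register_task


    @register_task('toy_translation')  # illustrative name
    class ToyTranslationTask(FairseqTask):

        def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
            """One forward/backward pass; the trainer handles clipping and optimizer.step()."""
            model.train()
            loss, sample_size, logging_output = criterion(model, sample)
            if ignore_grad:
                loss *= 0  # dummy batch, e.g. when a worker has run out of real data
            optimizer.backward(loss)
            return loss, sample_size, logging_output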
@@ -353,17 +353,16 @@ The model files should appear in the :file:`checkpoints/` directory.
-------------------------------

Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classify.py` with the following contents::
a new file named :file:`eval_classifier.py` with the following contents::

    from fairseq import data, options, tasks, utils
    from fairseq.tokenizer import Tokenizer

    # Parse command-line arguments for generation
    parser = options.get_generation_parser()
    parser = options.get_generation_parser(default_task='simple_classification')
    args = options.parse_args_and_arch(parser)

    # Setup task
    args.task = 'simple_classification'
    task = tasks.setup_task(args)

    # Load model
@@ -55,7 +55,9 @@ def main(parsed_args):

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
@@ -83,9 +85,10 @@ def main(parsed_args):
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
@@ -9,7 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset
from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
@@ -33,7 +33,6 @@ __all__ = [
    'GroupedIterator',
    'IndexedCachedDataset',
    'IndexedDataset',
    'IndexedInMemoryDataset',
    'IndexedRawTextDataset',
    'LanguagePairDataset',
    'MonolingualDataset',
@@ -56,6 +56,28 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):


class BacktranslationDataset(FairseqDataset):
    """
    Sets up a backtranslation dataset which takes a tgt batch, generates
    a src using a tgt-src backtranslation function (*backtranslation_fn*),
    and returns the corresponding `{generated src, input tgt}` batch.

    Args:
        tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
            backtranslated. Only the source side of this dataset will be used.
            After backtranslation, the source sentences in this dataset will be
            returned as the targets.
        backtranslation_fn (callable): function to call to generate
            backtranslations. This is typically the `generate` method of a
            :class:`~fairseq.sequence_generator.SequenceGenerator` object.
        max_len_a, max_len_b (int, int): will be used to compute
            `maxlen = max_len_a * src_len + max_len_b`, which will be passed
            into *backtranslation_fn*.
        output_collater (callable, optional): function to call on the
            backtranslated samples to create the final batch
            (default: ``tgt_dataset.collater``).
        cuda: use GPU for generation
    """

    def __init__(
        self,
        tgt_dataset,
@@ -66,27 +88,6 @@ class BacktranslationDataset(FairseqDataset):
        cuda=True,
        **kwargs
    ):
        """
        Sets up a backtranslation dataset which takes a tgt batch, generates
        a src using a tgt-src backtranslation function (*backtranslation_fn*),
        and returns the corresponding `{generated src, input tgt}` batch.

        Args:
            tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
                backtranslated. Only the source side of this dataset will be
                used. After backtranslation, the source sentences in this
                dataset will be returned as the targets.
            backtranslation_fn (callable): function to call to generate
                backtranslations. This is typically the `generate` method of a
                :class:`~fairseq.sequence_generator.SequenceGenerator` object.
            max_len_a, max_len_b (int, int): will be used to compute
                `maxlen = max_len_a * src_len + max_len_b`, which will be
                passed into *backtranslation_fn*.
            output_collater (callable, optional): function to call on the
                backtranslated samples to create the final batch (default:
                ``tgt_dataset.collater``)
            cuda: use GPU for generation
        """
        self.tgt_dataset = tgt_dataset
        self.backtranslation_fn = backtranslation_fn
        self.max_len_a = max_len_a
@@ -166,11 +167,10 @@ class BacktranslationDataset(FairseqDataset):
        """
        tgt_size = self.tgt_dataset.size(index)[0]
        return (tgt_size, tgt_size)

    @property
    def supports_prefetch(self):
        return self.tgt_dataset.supports_prefetch()
        return getattr(self.tgt_dataset, 'supports_prefetch', False)

    def prefetch(self, indices):
        return self.tgt_dataset.prefetch(indices)
@@ -29,18 +29,18 @@ class ConcatDataset(FairseqDataset):
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cummulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        return self.cummulative_sizes[-1]
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx)
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return self.datasets[dataset_idx][sample_idx]

@@ -54,7 +54,7 @@ class ConcatDataset(FairseqDataset):

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cummulative_sizes, self.datasets):
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to
@@ -81,8 +81,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
        size_fn (callable): function that returns the size of a given index
        max_positions (tuple): filter elements larger than this size.
            Comparisons are done component-wise.
        raise_exception (bool, optional): if ``True``, raise an exception
            if any elements are filtered. Default: ``False``
        raise_exception (bool, optional): if ``True``, raise an exception if
            any elements are filtered (default: False).
    """
    def check_size(idx):
        if isinstance(max_positions, float) or isinstance(max_positions, int):
@@ -128,12 +128,12 @@ def batch_by_size(
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch.
            Default: ``None``
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch. Default: ``None``
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N. Default: ``1``
            be a multiple of N (default: 1).
    """
    max_tokens = max_tokens if max_tokens is not None else float('Inf')
    max_sentences = max_sentences if max_sentences is not None else float('Inf')
@@ -200,11 +200,15 @@ class Dictionary(object):
            t[-1] = self.eos()
        return t


class TruncatedDictionary(object):

    def __init__(self, wrapped_dict, length):
        self.__class__ = type(wrapped_dict.__class__.__name__,
                              (self.__class__, wrapped_dict.__class__), {})
        self.__class__ = type(
            wrapped_dict.__class__.__name__,
            (self.__class__, wrapped_dict.__class__),
            {}
        )
        self.__dict__ = wrapped_dict.__dict__
        self.wrapped_dict = wrapped_dict
        self.length = min(len(self.wrapped_dict), length)
@@ -7,8 +7,6 @@

import torch.utils.data

from fairseq.data import data_utils


class FairseqDataset(torch.utils.data.Dataset):
    """A dataset that provides helpers for batching."""
@@ -51,7 +49,9 @@ class FairseqDataset(torch.utils.data.Dataset):

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError
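Since several datasets in this commit switch to `getattr(dataset, 'supports_prefetch', False)`, here is a minimal hypothetical subclass that opts into the prefetch hooks defined above (not part of this commit; the in-memory cache is illustrative):

    import torch

    from fairseq.data import FairseqDataset


    class ToyPrefetchableDataset(FairseqDataset):

        def __init__(self, examples):
            self.examples = examples  # e.g. a list of 1-D LongTensors
            self._cache = {}

        def __getitem__(self, index):
            return self._cache.get(index, self.examples[index])

        def __len__(self):
            return len(self.examples)

        @property
        def supports_prefetch(self):
            return True

        def prefetch(self, indices):
            # materialize only the items the upcoming epoch will actually touch
            self._cache = {i: self.examples[i].clone() for i in indices}


    ds = ToyPrefetchableDataset([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
    ds.prefetch([1])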
@@ -52,13 +52,12 @@ def data_file_path(prefix_path):
class IndexedDataset(torch.utils.data.Dataset):
    """Loader for TorchNet IndexedDataset"""

    def __init__(self, path, fix_lua_indexing=False, read_data=True):
    def __init__(self, path, fix_lua_indexing=False):
        super().__init__()
        self.fix_lua_indexing = fix_lua_indexing
        self.read_index(path)
        self.data_file = None
        if read_data:
            self.read_data(path)
        self.path = path

    def read_index(self, path):
        with open(index_file_path(path), 'rb') as f:
@@ -85,8 +84,10 @@ class IndexedDataset(torch.utils.data.Dataset):
        self.data_file.close()

    def __getitem__(self, i):
        if not self.data_file:
            self.read_data(self.path)
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
        tensor_size = int(self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]])
        a = np.empty(tensor_size, dtype=self.dtype)
        self.data_file.seek(self.data_offsets[i] * self.element_size)
        self.data_file.readinto(a)
@@ -98,12 +99,6 @@ class IndexedDataset(torch.utils.data.Dataset):
    def __len__(self):
        return self.size

    def read_into(self, start, dst):
        self.data_file.seek(start * self.element_size)
        self.data_file.readinto(dst)
        if self.fix_lua_indexing:
            dst -= 1  # subtract 1 for 0-based indexing

    @staticmethod
    def exists(path):
        return (
@@ -111,11 +106,15 @@ class IndexedDataset(torch.utils.data.Dataset):
            os.path.exists(data_file_path(path))
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__(path, fix_lua_indexing, True)
        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
        self.cache = None
        self.cache_index = {}

@@ -126,6 +125,8 @@ class IndexedCachedDataset(IndexedDataset):
    def prefetch(self, indices):
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
@@ -153,34 +154,7 @@ class IndexedCachedDataset(IndexedDataset):
        return item


class IndexedInMemoryDataset(IndexedDataset):
    """Loader for TorchNet IndexedDataset, keeps all the data in memory"""

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb')
        self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
        self.data_file.readinto(self.buffer)
        self.data_file.close()
        if self.fix_lua_indexing:
            self.buffer -= 1  # subtract 1 for 0-based indexing

    def read_into(self, start, dst):
        if self.token_blob is None:
            self.token_blob = [t for l in self.tokens_list for t in l]
        np.copyto(dst, self.token_blob[start:])

    def __del__(self):
        pass

    def __getitem__(self, i):
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
        return torch.from_numpy(a).long()


class IndexedRawTextDataset(IndexedDataset):
class IndexedRawTextDataset(torch.utils.data.Dataset):
    """Takes a text file as input and binarizes it in memory at instantiation.
    Original lines are also kept in memory"""

@@ -205,6 +179,10 @@ class IndexedRawTextDataset(IndexedDataset):
            self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def check_index(self, i):
        if i < 0 or i >= self.size:
            raise IndexError('index out of range')

    def __getitem__(self, i):
        self.check_index(i)
        return self.tokens_list[i]
@@ -252,7 +230,7 @@ class IndexedDatasetBuilder(object):
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file, read_data=False)
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        begin = self.data_offsets[-1]
@@ -69,17 +69,19 @@ class EpochBatchIterator(object):
        batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
            indices
        seed (int, optional): seed for random number generator for
            reproducibility. Default: 1
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N
            shards. Default: 1
            shards (default: 1).
        shard_id (int, optional): which shard of the data iterator to
            return. Default: 0
        buffer_size (int, optional): number of batches to buffer. Default: 5
            return (default: 0).
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
    """

    def __init__(
        self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
        buffer_size=5,
        num_workers=0,
    ):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
@@ -88,14 +90,12 @@ class EpochBatchIterator(object):
        self.seed = seed
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.buffer_size = buffer_size
        self.num_workers = num_workers

        self.epoch = 0
        self._cur_epoch_itr = None
        self._next_epoch_itr = None
        self._supports_prefetch = (
            hasattr(dataset, 'supports_prefetch') and dataset.supports_prefetch
        )
        self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)

    def __len__(self):
        return len(self.frozen_batches)
@@ -105,11 +105,10 @@ class EpochBatchIterator(object):

        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator. Default: ``True``
                iterator (default: True).
            fix_batches_to_gpus: ensure that batches are always
                allocated to the same shards across epochs. Requires
                that :attr:`dataset` supports prefetching. Default:
                ``False``
                that :attr:`dataset` supports prefetching (default: False).
        """
        if self._next_epoch_itr is not None:
            self._cur_epoch_itr = self._next_epoch_itr
@@ -117,7 +116,8 @@ class EpochBatchIterator(object):
        else:
            self.epoch += 1
            self._cur_epoch_itr = self._get_iterator_for_epoch(
                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus)
                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
            )
        return self._cur_epoch_itr

    def end_of_epoch(self):
@@ -179,50 +179,14 @@ class EpochBatchIterator(object):
            batches = self.frozen_batches
        batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])

        return CountingIterator(BufferedIterator(
            torch.utils.data.DataLoader(
                self.dataset,
                collate_fn=self.collate_fn,
                batch_sampler=batches,
            ),
            buffer_size=self.buffer_size,
        return CountingIterator(torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_sampler=batches,
            num_workers=self.num_workers,
        ))


class BufferedIterator(object):
    """Wrapper around an iterable that prefetches items into a buffer.

    Args:
        iterable (iterable): iterable to wrap
        buffer_size (int): number of items to prefetch and buffer
    """

    def __init__(self, iterable, buffer_size):
        self.iterable = iterable

        self.q = queue.Queue(maxsize=buffer_size)
        self.thread = threading.Thread(target=self._load_q, daemon=True)
        self.thread.start()

    def __len__(self):
        return len(self.iterable)

    def __iter__(self):
        return self

    def __next__(self):
        x = self.q.get()
        if x is None:
            self.thread.join()
            raise StopIteration
        return x[0]

    def _load_q(self):
        for x in self.iterable:
            self.q.put([x])  # wrap in list so that it's never None
        self.q.put(None)


class GroupedIterator(object):
    """Wrapper around an iterable that returns groups (chunks) of items.

@@ -261,7 +225,7 @@ class ShardedIterator(object):
        num_shards (int): number of shards to split the iterable into
        shard_id (int): which shard to iterator over
        fill_value (Any, optional): padding value when the iterable doesn't
            evenly divide *num_shards*. Default: ``None``
            evenly divide *num_shards* (default: None).
    """

    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
@@ -79,23 +79,23 @@ class LanguagePairDataset(FairseqDataset):
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side.
            Default: ``True``
        left_pad_target (bool, optional): pad target tensors on the left side.
            Default: ``False``
        max_source_positions (int, optional): max number of tokens in the source
            sentence. Default: ``1024``
        max_target_positions (int, optional): max number of tokens in the target
            sentence. Default: ``1024``
        shuffle (bool, optional): shuffle dataset elements before batching.
            Default: ``True``
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        max_source_positions (int, optional): max number of tokens in the
            source sentence (default: 1024).
        max_target_positions (int, optional): max number of tokens in the
            target sentence (default: 1024).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for input feeding/teacher forcing.
            Default: ``True``
        remove_eos_from_source (bool, optional): if set, removes eos from end of
            source if it's present. Default: ``False``
            to be passed into the model for input feeding/teacher forcing
            (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent. Default: ``False``
            target if it's absent (default: False).
    """

    def __init__(
@@ -223,15 +223,13 @@ class LanguagePairDataset(FairseqDataset):
            indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
        return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]

    def prefetch(self, indices):
        self.src.prefetch(indices)
        self.tgt.prefetch(indices)

    @property
    def supports_prefetch(self):
        return (
            hasattr(self.src, 'supports_prefetch')
            and self.src.supports_prefetch
            and hasattr(self.tgt, 'supports_prefetch')
            and self.tgt.supports_prefetch
            getattr(self.src, 'supports_prefetch', False)
            and getattr(self.tgt, 'supports_prefetch', False)
        )

    def prefetch(self, indices):
        self.src.prefetch(indices)
        self.tgt.prefetch(indices)
@@ -9,7 +9,6 @@ import numpy as np
import torch

from . import data_utils, FairseqDataset
from typing import List


def collate(samples, pad_idx, eos_idx):
@@ -53,8 +52,8 @@ class MonolingualDataset(FairseqDataset):
        dataset (torch.utils.data.Dataset): dataset to wrap
        sizes (List[int]): sentence lengths
        vocab (~fairseq.data.Dictionary): vocabulary
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        shuffle (bool, optional): shuffle the elements before batching
            (default: True).
    """

    def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
@@ -66,8 +65,8 @@ class MonolingualDataset(FairseqDataset):
        self.add_eos_for_other_targets = add_eos_for_other_targets
        self.shuffle = shuffle

        assert targets is None or all(
            t in {'self', 'future', 'past'} for t in targets), "targets must be none or one of 'self', 'future', 'past'"
        assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
            "targets must be none or one of 'self', 'future', 'past'"
        if targets is not None and len(targets) == 0:
            targets = None
        self.targets = targets
@@ -185,7 +184,7 @@ class MonolingualDataset(FairseqDataset):

    @property
    def supports_prefetch(self):
        return self.dataset.supports_prefetch
        return getattr(self.dataset, 'supports_prefetch', False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)
@@ -245,11 +245,12 @@ class NoisingDataset(torch.utils.data.Dataset):
        **kwargs
    ):
        """
        Sets up a noising dataset which takes a src batch, generates
        a noisy src using a noising config, and returns the
        corresponding {noisy src, original src} batch
        Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the
        samples based on the supplied noising configuration.

        Args:
            src_dataset: dataset which will be used to build self.src_dataset --
            src_dataset (~torch.utils.data.Dataset): dataset to wrap.
                to build self.src_dataset --
                a LanguagePairDataset with src dataset as the source dataset and
                None as the target dataset. Should NOT have padding so that
                src_lengths are accurately calculated by language_pair_dataset
@@ -257,26 +258,22 @@ class NoisingDataset(torch.utils.data.Dataset):
                We use language_pair_dataset here to encapsulate the tgt_dataset
                so we can re-use the LanguagePairDataset collater to format the
                batches in the structure that SequenceGenerator expects.
            src_dict: src dict
            src_dict: src dictionary
            seed: seed to use when generating random noise
            noiser: a pre-initialized noiser. If this is None, a noiser will
                be created using noising_class and kwargs.
            noising_class: class to use when initializing noiser
            kwargs: noising args for configuring noising to apply
                Note that there is no equivalent argparse code for these args
                anywhere in our top level train scripts yet. Integration is
                still in progress. You can still, however, test out this dataset
                functionality with the appropriate args as in the corresponding
                unittest: test_noising_dataset.
            src_dict (~fairseq.data.Dictionary): source dictionary
            seed (int): seed to use when generating random noise
            noiser (WordNoising): a pre-initialized :class:`WordNoising`
                instance. If this is None, a new instance will be created using
                *noising_class* and *kwargs*.
            noising_class (class, optional): class to use to initialize a
                default :class:`WordNoising` instance.
            kwargs (dict, optional): arguments to initialize the default
                :class:`WordNoising` instance given by *noiser*.
        """

        self.src_dataset = src_dataset
        self.src_dict = src_dict
        self.seed = seed
        self.noiser = noiser if noiser is not None else noising_class(
            dictionary=src_dict, **kwargs,
        )
        self.seed = seed

    def __getitem__(self, index):
        """
@@ -13,13 +13,16 @@ from . import FairseqDataset


class RoundRobinZipDatasets(FairseqDataset):
    """Zip multiple FairseqDatasets together, repeating shorter datasets in a
    round-robin fashion to match the length of the longest one.
    """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.

    Shorter datasets are repeated in a round-robin fashion to match the length
    of the longest one.

    Args:
        datasets: a dictionary of FairseqDatasets
        eval_key: an optional key used at evaluation time that causes this
            instance to pass-through batches from `datasets[eval_key]`.
        datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
            :class:`~fairseq.data.FairseqDataset` instances.
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass-through batches from *datasets[eval_key]*.
    """

    def __init__(self, datasets, eval_key=None):
@@ -107,3 +110,14 @@ class RoundRobinZipDatasets(FairseqDataset):
            dataset.valid_size(self._map_index(key, index), max_positions[key])
            for key, dataset in self.datasets.items()
        )

    @property
    def supports_prefetch(self):
        return all(
            getattr(dataset, 'supports_prefetch', False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        for key, dataset in self.datasets.items():
            dataset.prefetch([self._map_index(key, index) for index in indices])
@@ -14,32 +14,32 @@ from . import FairseqDataset


class TokenBlockDataset(FairseqDataset):
    """Break a 1d tensor of tokens into blocks.

    The blocks are fetched from the original tensor so no additional memory is allocated.
    """Break a Dataset of tokens into blocks.

    Args:
        tokens: 1d tensor of tokens to break into blocks
        sizes: sentence lengths (required for 'complete' and 'eos')
        block_size: maximum block size (ignored in 'eos' break mode)
        break_mode: Mode used for breaking tokens. Values can be one of:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
        block_size (int): maximum block size (ignored in 'eos' break mode)
        break_mode (str, optional): Mode used for breaking tokens. Values can
            be one of:
            - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
              blocks contains complete sentences, although block_size may be
              exceeded if some sentences exceed block_size
            - 'eos': each block contains one sentence (block_size is ignored)
        include_targets: return next tokens as targets
        include_targets (bool, optional): return next tokens as targets
            (default: False).
    """

    def __init__(self, ds, block_size, pad, eos, break_mode=None, include_targets=False):
    def __init__(self, dataset, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
        super().__init__()
        self.dataset = ds
        self.dataset = dataset
        self.pad = pad
        self.eos = eos
        self.include_targets = include_targets
        self.slice_indices = []
        self.cache_index = {}
        sizes = ds.sizes

        assert len(dataset) == len(sizes)

        if break_mode is None or break_mode == 'none':
            total_size = sum(sizes)
@@ -77,44 +77,66 @@ class TokenBlockDataset(FairseqDataset):

        self.sizes = np.array([e - s for s, e in self.slice_indices])

    def __getitem__(self, index):
        s, e = self.cache_index[index]
        # build index mapping block indices to the underlying dataset indices
        self.block_to_dataset_index = []
        ds_idx, ds_remaining = -1, 0
        for to_consume in self.sizes:
            if ds_remaining == 0:
                ds_idx += 1
                ds_remaining = sizes[ds_idx]
            start_ds_idx = ds_idx
            start_offset = sizes[ds_idx] - ds_remaining
            while to_consume > ds_remaining:
                to_consume -= ds_remaining
                ds_idx += 1
                ds_remaining = sizes[ds_idx]
            ds_remaining -= to_consume
            self.block_to_dataset_index.append((
                start_ds_idx,  # starting index in dataset
                start_offset,  # starting offset within starting index
                ds_idx,  # ending index in dataset
            ))
        assert ds_remaining == 0
        assert ds_idx == len(self.dataset) - 1

        item = torch.from_numpy(self.cache[s:e]).long()
    def __getitem__(self, index):
        start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
        buffer = torch.cat([
            self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)
        ])
        slice_s, slice_e = self.slice_indices[index]
        length = slice_e - slice_s
        s, e = start_offset, start_offset + length
        item = buffer[s:e]

        if self.include_targets:
            # target is the sentence, for source, rotate item one token to the left (would start with eos)
            # past target is rotated to the left by 2 (padded if its first)
            # *target* is the original sentence (=item)
            # *source* is rotated left by 1 (maybe left-padded with eos)
            # *past_target* is rotated left by 2 (left-padded as needed)
            if s == 0:
                source = np.concatenate([[self.eos], self.cache[0:e - 1]])
                past_target = np.concatenate([[self.pad, self.eos], self.cache[0:e - 2]])
                source = torch.cat([item.new([self.eos]), buffer[0:e - 1]])
                past_target = torch.cat([item.new([self.pad, self.eos]), buffer[0:e - 2]])
            else:
                source = self.cache[s - 1: e - 1]
                source = buffer[s - 1:e - 1]
                if s == 1:
                    past_target = np.concatenate([[self.eos], self.cache[0:e - 2]])
                    past_target = torch.cat([item.new([self.eos]), buffer[0:e - 2]])
                else:
                    past_target = self.cache[s - 2:e - 2]
                    past_target = buffer[s - 2:e - 2]

            return torch.from_numpy(source).long(), item, torch.from_numpy(past_target).long()
            return source, item, past_target
        return item

    def __len__(self):
        return len(self.slice_indices)

    def prefetch(self, indices):
        indices.sort()
        total_size = 0
        for idx in indices:
            s, e = self.slice_indices[idx]
            total_size += e - s
        self.cache = np.empty(total_size, dtype=np.int32)
        start = 0
        for idx in indices:
            s, e = self.slice_indices[idx]
            self.dataset.read_into(s, self.cache[start:start + e - s])
            self.cache_index[idx] = (start, start + e - s)
            start += e - s

    @property
    def supports_prefetch(self):
        return True
        return getattr(self.dataset, 'supports_prefetch', False)

    def prefetch(self, indices):
        self.dataset.prefetch({
            ds_idx
            for index in indices
            for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
            for ds_idx in range(start_ds_idx, end_ds_idx + 1)
        })
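A small hypothetical walk-through of the three break modes described in the docstring above (the sentence lengths and block size are made up, not taken from this commit):

    # sentence lengths [3, 2, 5, 2], block_size = 6
    sizes, block_size = [3, 2, 5, 2], 6

    # break_mode='none':     concatenate all 12 tokens and cut every 6 tokens,
    #                        ignoring sentence boundaries -> block sizes [6, 6]
    # break_mode='complete': pack whole sentences while a block stays <= 6 tokens
    #                        -> blocks [3+2], [5], [2]; a sentence longer than
    #                        block_size would still get its own (oversized) block
    # break_mode='eos':      one sentence per block -> block sizes [3, 2, 5, 2]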
@@ -11,7 +11,7 @@ from . import FairseqDataset


class TransformEosDataset(FairseqDataset):
    """A dataset wrapper that appends/prepends/strips EOS.
    """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.

    Note that the transformation is applied in :func:`collater`.

@@ -111,7 +111,7 @@ class TransformEosDataset(FairseqDataset):

    @property
    def supports_prefetch(self):
        return self.dataset.supports_prefetch()
        return getattr(self.dataset, 'supports_prefetch', False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
@@ -6,7 +6,9 @@
# can be found in the PATENTS file in the same directory.

from collections import namedtuple
import os
import pickle
import subprocess

import torch
from torch import nn
@@ -42,6 +44,38 @@ else:
    import torch.distributed as dist_no_c10d


def infer_init_method(args):
    if args.distributed_init_method is not None:
        return

    # support torch.distributed.launch
    if all(key in os.environ for key in [
        'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK'
    ]):
        args.distributed_init_method = 'tcp://{addr}:{port}'.format(
            addr=os.environ['MASTER_ADDR'],
            port=os.environ['MASTER_PORT'],
        )
        args.distributed_world_size = int(os.environ['WORLD_SIZE'])
        args.distributed_rank = int(os.environ['RANK'])

    # we can determine the init method automatically for Slurm
    elif args.distributed_port > 0:
        node_list = os.environ.get('SLURM_JOB_NODELIST')
        if node_list is not None:
            try:
                hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                args.distributed_init_method = 'tcp://{host}:{port}'.format(
                    host=hostnames.split()[0].decode('utf-8'),
                    port=args.distributed_port)
                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
                args.device_id = int(os.environ.get('SLURM_LOCALID'))
            except subprocess.CalledProcessError as e:  # scontrol failed
                raise e
            except FileNotFoundError:  # Slurm is not installed
                pass


def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError('Cannot initialize distributed with distributed_world_size=1')
@@ -158,7 +192,7 @@ def all_gather_list(data, group=None, max_size=16384):
            pickle.loads(bytes(out_buffer[2:size+2].tolist()))
        )
        return result
    except pickle.UnpicklingError as e:
    except pickle.UnpicklingError:
        raise Exception(
            'Unable to unpickle data from other workers. all_gather_list requires all '
            'workers to enter the function together, so this error usually indicates '
@@ -167,4 +201,3 @@ def all_gather_list(data, group=None, max_size=16384):
            'in your training script that can cause one worker to finish an epoch '
            'while other workers are still iterating over their portions of the data.'
        )
@@ -12,7 +12,7 @@ computation (e.g., AdaptiveSoftmax) and which therefore do not work with the
c10d version of DDP.

This version also supports the *accumulate_grads* feature, which allows faster
training with --update-freq.
training with `--update-freq`.
"""

import copy
@@ -27,18 +27,18 @@ from . import distributed_utils
class LegacyDistributedDataParallel(nn.Module):
    """Implements distributed data parallelism at the module level.

    A simplified version of torch.nn.parallel.DistributedDataParallel.
    This version uses a c10d process group for communication and does
    not broadcast buffers.
    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
    This version uses a c10d process group for communication and does not
    broadcast buffers.

    Args:
        module: module to be parallelized
        world_size: number of parallel workers
        module (~torch.nn.Module): module to be parallelized
        world_size (int): number of parallel workers
        process_group (optional): the c10d process group to be used for
            distributed data all-reduction. If None, the default process group
            will be used.
        buffer_size: number of elements to buffer before performing all-reduce
            (default: 256M).
        buffer_size (int, optional): number of elements to buffer before
            performing all-reduce (default: 256M).
    """

    def __init__(self, module, world_size, process_group=None, buffer_size=2**28):
@@ -179,10 +179,8 @@ class FConvEncoder(FairseqEncoder):
            connections are added between layers when ``residual=1`` (which is
            the default behavior).
        dropout (float, optional): dropout to be applied before each conv layer
        normalization_constant (float, optional): multiplies the result of the
            residual block by sqrt(value)
        left_pad (bool, optional): whether the input is left-padded. Default:
            ``True``
        left_pad (bool, optional): whether the input is left-padded
            (default: True).
    """

    def __init__(
@@ -215,7 +213,7 @@ class FConvEncoder(FairseqEncoder):
        self.residuals = []

        layer_in_channels = [in_channels]
        for i, (out_channels, kernel_size, residual) in enumerate(convolutions):
        for _, (out_channels, kernel_size, residual) in enumerate(convolutions):
            if residual == 0:
                residual_dim = out_channels
            else:
@@ -524,6 +524,7 @@ def base_architecture(args):
    args.pretrained_checkpoint = getattr(args, 'pretrained_checkpoint', '')
    args.pretrained = getattr(args, 'pretrained', 'False')


@register_model_architecture('fconv_self_att', 'fconv_self_att_wp')
def fconv_self_att_wp(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
@@ -196,7 +196,6 @@ class LSTMEncoder(FairseqEncoder):
        if bidirectional:
            self.output_units *= 2


    def forward(self, src_tokens, src_lengths):
        if self.left_pad:
            # convert left-padding to right-padding
@@ -235,7 +234,8 @@ class LSTMEncoder(FairseqEncoder):
        if self.bidirectional:

            def combine_bidir(outs):
                return outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous().view(self.num_layers, bsz, -1)
                out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
                return out.view(self.num_layers, bsz, -1)

            final_hiddens = combine_bidir(final_hiddens)
            final_cells = combine_bidir(final_cells)
@@ -340,7 +340,6 @@ class LSTMDecoder(FairseqIncrementalDecoder):
        elif not self.share_input_output_embed:
            self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)


    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
        encoder_out = encoder_out_dict['encoder_out']
        encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@@ -504,6 +503,7 @@ def base_architecture(args):
    args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')


@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
def lstm_wiseman_iwslt_de_en(args):
    args.dropout = getattr(args, 'dropout', 0.1)
@@ -219,7 +219,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
        # make sure all arguments are present in older models
        base_lm_architecture(args)

        if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj == False:
        if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False:
            # backward compatibility
            args.tie_adaptive_proj = True

@@ -229,15 +229,17 @@ class TransformerLanguageModel(FairseqLanguageModel):
            args.max_target_positions = args.tokens_per_sample

        if args.character_embeddings:
            embed_tokens = CharacterTokenEmbedder(task.dictionary, eval(args.character_filters),
                                                  args.character_embedding_dim,
                                                  args.decoder_embed_dim,
                                                  args.char_embedder_highway_layers,
                                                  )
            embed_tokens = CharacterTokenEmbedder(
                task.dictionary, eval(args.character_filters),
                args.character_embedding_dim, args.decoder_embed_dim,
                args.char_embedder_highway_layers,
            )
        elif args.adaptive_input:
            embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
                                         args.adaptive_input_factor, args.decoder_embed_dim,
                                         options.eval_str_list(args.adaptive_input_cutoff, type=int))
            embed_tokens = AdaptiveInput(
                len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
                args.adaptive_input_factor, args.decoder_embed_dim,
                options.eval_str_list(args.adaptive_input_cutoff, type=int),
            )
        else:
            embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())

@@ -248,7 +250,9 @@ class TransformerLanguageModel(FairseqLanguageModel):
                args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
            assert args.decoder_input_dim == args.decoder_output_dim

        decoder = TransformerDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
        decoder = TransformerDecoder(
            args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False,
        )
        return TransformerLanguageModel(decoder)


@@ -261,8 +265,8 @@ class TransformerEncoder(FairseqEncoder):
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
        left_pad (bool, optional): whether the input is left-padded. Default:
            ``True``
        left_pad (bool, optional): whether the input is left-padded
            (default: True).
    """

    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
@@ -382,10 +386,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs.
            Default: ``False``
        left_pad (bool, optional): whether the input is left-padded. Default:
            ``False``
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        left_pad (bool, optional): whether the input is left-padded
            (default: False).
        final_norm (bool, optional): apply layer norm to the output of the
            final decoder layer (default: True).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
@@ -634,8 +640,8 @@ class TransformerDecoderLayer(nn.Module):

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs.
            Default: ``False``
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, no_encoder_attn=False):
@@ -7,7 +7,6 @@


import torch
import torch.nn.functional as F
from torch import nn

from typing import List
@@ -16,13 +15,13 @@ from typing import List
class AdaptiveInput(nn.Module):

    def __init__(
            self,
            vocab_size: int,
            padding_idx: int,
            initial_dim: int,
            factor: float,
            output_dim: int,
            cutoff: List[int],
        self,
        vocab_size: int,
        padding_idx: int,
        initial_dim: int,
        factor: float,
        output_dim: int,
        cutoff: List[int],
    ):
        super().__init__()
@@ -113,8 +113,9 @@ class AdaptiveSoftmax(nn.Module):
            m = nn.Sequential(
                proj,
                nn.Dropout(self.dropout),
                nn.Linear(dim, self.cutoff[i + 1] - self.cutoff[i], bias=False) \
                if tied_emb is None else TiedLinear(tied_emb, transpose=False)
                nn.Linear(
                    dim, self.cutoff[i + 1] - self.cutoff[i], bias=False,
                ) if tied_emb is None else TiedLinear(tied_emb, transpose=False),
            )

            self.tail.append(m)
@ -9,7 +9,7 @@ import importlib
|
||||
import os
|
||||
|
||||
from .fairseq_optimizer import FairseqOptimizer
|
||||
from .fp16_optimizer import FP16Optimizer
|
||||
from .fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
|
||||
|
||||
|
||||
OPTIMIZER_REGISTRY = {}
|
||||
|
@ -70,10 +70,11 @@ class FairseqOptimizer(object):
|
||||
group.update(optimizer_overrides)
|
||||
|
||||
def backward(self, loss):
|
||||
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
|
||||
loss.backward()
|
||||
|
||||
def multiply_grads(self, c):
|
||||
"""Multiplies grads by a constant ``c``."""
|
||||
"""Multiplies grads by a constant *c*."""
|
||||
for p in self.params:
|
||||
if p.grad is not None:
|
||||
p.grad.data.mul_(c)
|
||||
|
@ -45,6 +45,164 @@ class DynamicLossScaler(object):
|
||||
return False
|
||||
|
||||
|
||||
class FP16Optimizer(optim.FairseqOptimizer):
|
||||
"""
|
||||
Wrap an *optimizer* to support FP16 (mixed precision) training.
|
||||
"""
|
||||
|
||||
def __init__(self, args, params, fp32_optimizer, fp32_params):
|
||||
super().__init__(args, params)
|
||||
self.fp32_optimizer = fp32_optimizer
|
||||
self.fp32_params = fp32_params
|
||||
|
||||
if getattr(args, 'fp16_scale_window', None) is None:
|
||||
if len(args.update_freq) > 1:
|
||||
raise ValueError(
|
||||
'--fp16-scale-window must be given explicitly when using a '
|
||||
'custom --update-freq schedule'
|
||||
)
|
||||
scale_window = 2**14 / args.distributed_world_size / args.update_freq[0]
|
||||
else:
|
||||
scale_window = args.fp16_scale_window
|
||||
|
||||
self.scaler = DynamicLossScaler(
|
||||
init_scale=args.fp16_init_scale,
|
||||
scale_window=scale_window,
|
||||
tolerance=args.fp16_scale_tolerance,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_optimizer(cls, args, params):
|
||||
"""
|
||||
Args:
|
||||
args (argparse.Namespace): fairseq args
|
||||
params (iterable): iterable of parameters to optimize
|
||||
"""
|
||||
# create FP32 copy of parameters and grads
|
||||
total_param_size = sum(p.data.numel() for p in params)
|
||||
fp32_params = params[0].new(0).float().new(total_param_size)
|
||||
offset = 0
|
||||
for p in params:
|
||||
numel = p.data.numel()
|
||||
fp32_params[offset:offset+numel].copy_(p.data.view(-1))
|
||||
offset += numel
|
||||
fp32_params = torch.nn.Parameter(fp32_params)
|
||||
fp32_params.grad = fp32_params.data.new(total_param_size)
|
||||
|
||||
fp32_optimizer = optim.build_optimizer(args, [fp32_params])
|
||||
return cls(args, params, fp32_optimizer, fp32_params)
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
return self.fp32_optimizer.optimizer
|
||||
|
||||
@property
|
||||
def optimizer_config(self):
|
||||
return self.fp32_optimizer.optimizer_config
|
||||
|
||||
def get_lr(self):
|
||||
return self.fp32_optimizer.get_lr()
|
||||
|
||||
def set_lr(self, lr):
|
||||
self.fp32_optimizer.set_lr(lr)
|
||||
|
||||
def state_dict(self):
|
||||
"""Return the optimizer's state dict."""
|
||||
state_dict = self.fp32_optimizer.state_dict()
|
||||
state_dict['loss_scale'] = self.scaler.loss_scale
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict, optimizer_overrides=None):
|
||||
"""Load an optimizer state dict.
|
||||
|
||||
In general we should prefer the configuration of the existing optimizer
|
||||
instance (e.g., learning rate) over that found in the state_dict. This
|
||||
allows us to resume training from a checkpoint using a new set of
|
||||
optimizer args.
|
||||
"""
|
||||
if 'loss_scale' in state_dict:
|
||||
self.scaler.loss_scale = state_dict['loss_scale']
|
||||
self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
|
||||
|
||||
def backward(self, loss):
|
||||
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
|
||||
|
||||
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
|
||||
function additionally dynamically scales the loss to avoid gradient
|
||||
underflow.
|
||||
"""
|
||||
loss = loss * self.scaler.loss_scale
|
||||
loss.backward()
|
||||
self._needs_sync = True
|
||||
|
||||
def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
|
||||
if self._needs_sync:
|
||||
# copy FP16 grads to FP32
|
||||
offset = 0
|
||||
for p in self.params:
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
|
||||
numel = grad_data.numel()
|
||||
self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
|
||||
offset += numel
|
||||
|
||||
# correct for dynamic loss scaler
|
||||
self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale)
|
||||
|
||||
self._needs_sync = False
|
||||
|
||||
def multiply_grads(self, c):
|
||||
"""Multiplies grads by a constant ``c``."""
|
||||
if self._needs_sync:
|
||||
self._sync_fp16_grads_to_fp32(c)
|
||||
else:
|
||||
self.fp32_params.grad.data.mul_(c)
|
||||
|
||||
    def clip_grad_norm(self, max_norm):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()
        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            if self.scaler.loss_scale <= self.args.min_loss_scale:
                # Use FloatingPointError as an uncommon error that parent
                # functions can safely catch to stop training.
                raise FloatingPointError((
                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
                    'Try lowering the learning rate, using gradient clipping or '
                    'increasing the batch size.'
                ).format(self.args.min_loss_scale))
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
        return grad_norm
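Dynamic loss scaling only helps if the caller skips the parameter update when clip_grad_norm raises OverflowError; by that point update_scale() has already reduced the scale for the next attempt. A minimal sketch of that calling pattern (not the fairseq Trainer itself):

def fp16_train_step(optimizer, loss, max_norm):
    # optimizer is an FP16Optimizer; backward() scales the loss internally
    optimizer.backward(loss)
    try:
        grad_norm = optimizer.clip_grad_norm(max_norm)  # unscales grads, checks overflow
        optimizer.step()
        optimizer.zero_grad()
        return grad_norm
    except OverflowError:
        # inf/nan gradients at the current loss scale: drop this update and retry
        # the next batch with the smaller scale set by update_scale()
        optimizer.zero_grad()
        return None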
def step(self, closure=None):
|
||||
"""Performs a single optimization step."""
|
||||
self._sync_fp16_grads_to_fp32()
|
||||
self.fp32_optimizer.step(closure)
|
||||
|
||||
# copy FP32 params back into FP16 model
|
||||
offset = 0
|
||||
for p in self.params:
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
numel = p.data.numel()
|
||||
p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
|
||||
offset += numel
|
||||
|
||||
def zero_grad(self):
|
||||
"""Clears the gradients of all optimized parameters."""
|
||||
self.fp32_optimizer.zero_grad()
|
||||
for p in self.params:
|
||||
if p.grad is not None:
|
||||
p.grad.detach_()
|
||||
p.grad.zero_()
|
||||
self._needs_sync = False
|
||||
|
||||
|
||||
class ConvertToFP32(object):
|
||||
"""
|
||||
A wrapper around a list of params that will convert them to FP32 on the
|
||||
@ -94,14 +252,13 @@ class ConvertToFP32(object):
|
||||
raise StopIteration
|
||||
|
||||
|
||||
class FP16Optimizer(optim.FairseqOptimizer):
class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.

    Args:
        args (argparse.Namespace): fairseq args
        params (iterable): iterable of parameters to optimize
        optimizer (~fairseq.optim.FairseqOptimizer): optimizer to wrap
    Compared to :class:`fairseq.optim.FP16Optimizer`, this version uses less
    memory by copying between FP16 and FP32 parameters on-the-fly. The tradeoff
    is reduced optimization speed, which can be mitigated with `--update-freq`.
    """

    def __init__(self, args, params, optimizer):
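The Trainer changes further down in this diff choose between the two wrappers based on the new --memory-efficient-fp16 flag; condensed (variable names abridged), that selection looks like:

params = list(filter(lambda p: p.requires_grad, model.parameters()))
if args.memory_efficient_fp16:
    # no flat FP32 master copy; unscale/convert on the fly (less memory, slower steps)
    optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(args, params)
else:
    # keep FP32 master params and grads (more memory, faster steps)
    optimizer = optim.FP16Optimizer.build_optimizer(args, params)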
@ -124,10 +281,15 @@ class FP16Optimizer(optim.FairseqOptimizer):
|
||||
tolerance=args.fp16_scale_tolerance,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_optimizer(args, params):
|
||||
@classmethod
|
||||
def build_optimizer(cls, args, params):
|
||||
"""
|
||||
Args:
|
||||
args (argparse.Namespace): fairseq args
|
||||
params (iterable): iterable of parameters to optimize
|
||||
"""
|
||||
fp16_optimizer = optim.build_optimizer(args, params)
|
||||
return FP16Optimizer(args, params, fp16_optimizer)
|
||||
return cls(args, params, fp16_optimizer)
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
@ -164,6 +326,12 @@ class FP16Optimizer(optim.FairseqOptimizer):
|
||||
ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)
|
||||
|
||||
def backward(self, loss):
|
||||
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
|
||||
|
||||
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
|
||||
function additionally dynamically scales the loss to avoid gradient
|
||||
underflow.
|
||||
"""
|
||||
loss = loss * self.scaler.loss_scale
|
||||
loss.backward()
|
||||
self._grads_are_scaled = True
|
||||
@ -178,7 +346,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
|
||||
assert multiply_grads == 1.
|
||||
|
||||
def multiply_grads(self, c):
|
||||
"""Multiplies grads by a constant ``c``."""
|
||||
"""Multiplies grads by a constant *c*."""
|
||||
if self._grads_are_scaled:
|
||||
self._unscale_grads(c)
|
||||
else:
|
||||
|
@@ -13,18 +13,25 @@ from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('cosine')
class CosineSchedule(FairseqLRScheduler):
    """Assign LR based on a cyclical schedule that follows the cosine function.
    See https://arxiv.org/pdf/1608.03983.pdf for details

    See https://arxiv.org/pdf/1608.03983.pdf for details.

    We also support a warmup phase where we linearly increase the learning rate
    from some initial learning rate (`--warmup-init-lr`) until the configured
    learning rate (`--lr`).
    During warmup:
    from some initial learning rate (``--warmup-init-lr``) until the configured
    learning rate (``--lr``).

    During warmup::

        lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
        lr = lrs[update_num]
    After warmup:

    After warmup::

        lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i))
    where
        t_curr is current percentage of updates within the current period range
        t_i is the current period range, which is scaled by t_mul after every iteration

    where ``t_curr`` is current percentage of updates within the current period
    range and ``t_i`` is the current period range, which is scaled by ``t_mul``
    after every iteration.
    """

    def __init__(self, args, optimizer):
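A standalone sketch of the schedule the docstring describes, using the same hyperparameter names; illustrative only, and it omits the t_mul-based period rescaling that the real class also applies:

import math

def cosine_lr(update_num, warmup_init_lr, lr, warmup_updates,
              min_lr, max_lr, period, lr_shrink=1.0):
    if update_num < warmup_updates:
        # linear warmup from warmup_init_lr to lr
        step = (lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + step * update_num
    curr = update_num - warmup_updates
    i = curr // period                 # index of the current cosine cycle
    t_curr = curr - i * period         # position within the current cycle
    shrink = lr_shrink ** i
    lo, hi = min_lr * shrink, max_lr * shrink
    return lo + 0.5 * (hi - lo) * (1 + math.cos(math.pi * t_curr / period))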
@ -39,7 +46,7 @@ class CosineSchedule(FairseqLRScheduler):
|
||||
if args.warmup_init_lr < 0:
|
||||
args.warmup_init_lr = args.lr[0]
|
||||
|
||||
self.min_lr = args.lr[0]
|
||||
self.min_lr = args.lr[0]
|
||||
self.max_lr = args.max_lr
|
||||
|
||||
assert self.max_lr > self.min_lr, 'max_lr must be more than lr'
|
||||
@ -98,7 +105,7 @@ class CosineSchedule(FairseqLRScheduler):
|
||||
t_curr = curr_updates - (self.period * i)
|
||||
|
||||
lr_shrink = self.lr_shrink ** i
|
||||
min_lr = self.min_lr * lr_shrink
|
||||
min_lr = self.min_lr * lr_shrink
|
||||
max_lr = self.max_lr * lr_shrink
|
||||
|
||||
self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
|
||||
|
@@ -13,22 +13,19 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
    """Decay the LR based on the inverse square root of the update number.

    We also support a warmup phase where we linearly increase the learning rate
    from some initial learning rate (`--warmup-init-lr`) until the configured
    learning rate (`--lr`). Thereafter we decay proportional to the number of
    from some initial learning rate (``--warmup-init-lr``) until the configured
    learning rate (``--lr``). Thereafter we decay proportional to the number of
    updates, with a decay factor set to align with the configured learning rate.

    During warmup:
    During warmup::

        lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
        lr = lrs[update_num]

    After warmup:

        lr = decay_factor / sqrt(update_num)

    where
    After warmup::

        decay_factor = args.lr * sqrt(args.warmup_updates)
        lr = decay_factor / sqrt(update_num)
    """

    def __init__(self, args, optimizer):
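The same schedule written out as a standalone function, following the docstring's two formulas (illustrative only):

import math

def inverse_sqrt_lr(update_num, warmup_init_lr, lr, warmup_updates):
    if update_num < warmup_updates:
        step = (lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + step * update_num
    decay_factor = lr * math.sqrt(warmup_updates)
    return decay_factor / math.sqrt(update_num)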
|
@ -14,8 +14,7 @@ from . import FairseqLRScheduler, register_lr_scheduler
|
||||
class TriangularSchedule(FairseqLRScheduler):
|
||||
"""Assign LR based on a triangular cyclical schedule.
|
||||
|
||||
See https://arxiv.org/pdf/1506.01186.pdf for details
|
||||
|
||||
See https://arxiv.org/pdf/1506.01186.pdf for details.
|
||||
"""
|
||||
|
||||
def __init__(self, args, optimizer):
|
||||
|
@ -107,6 +107,8 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
|
||||
args.update_freq = eval_str_list(args.update_freq, type=int)
|
||||
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
|
||||
args.max_sentences_valid = args.max_sentences
|
||||
if getattr(args, 'memory_efficient_fp16', False):
|
||||
args.fp16 = True
|
||||
|
||||
# Apply architecture configuration.
|
||||
if hasattr(args, 'arch'):
|
||||
@ -128,7 +130,10 @@ def get_parser(desc, default_task='translation'):
|
||||
choices=['json', 'none', 'simple', 'tqdm'])
|
||||
parser.add_argument('--seed', default=1, type=int, metavar='N',
|
||||
help='pseudo random number generator seed')
|
||||
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
|
||||
parser.add_argument('--fp16', action='store_true', help='use FP16')
|
||||
parser.add_argument('--memory-efficient-fp16', action='store_true',
|
||||
help='use a memory-efficient version of FP16 training; implies --fp16')
|
||||
parser.add_argument('--fp16-init-scale', default=2**7, type=int,
|
||||
help='default FP16 loss scale')
|
||||
parser.add_argument('--fp16-scale-window', type=int,
|
||||
@ -147,6 +152,8 @@ def get_parser(desc, default_task='translation'):
|
||||
def add_dataset_args(parser, train=False, gen=False):
|
||||
group = parser.add_argument_group('Dataset and data loading')
|
||||
# fmt: off
|
||||
group.add_argument('--num-workers', default=0, type=int, metavar='N',
|
||||
help='how many subprocesses to use for data loading')
|
||||
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
|
||||
help='ignore too long or too short lines in valid and test set')
|
||||
group.add_argument('--max-tokens', type=int, metavar='N',
|
||||
@ -178,7 +185,7 @@ def add_distributed_training_args(parser):
|
||||
group = parser.add_argument_group('Distributed training')
|
||||
# fmt: off
|
||||
group.add_argument('--distributed-world-size', type=int, metavar='N',
|
||||
default=torch.cuda.device_count(),
|
||||
default=max(1, torch.cuda.device_count()),
|
||||
help='total number of GPUs across all nodes (default: all visible GPUs)')
|
||||
group.add_argument('--distributed-rank', default=0, type=int,
|
||||
help='rank of the current worker')
|
||||
@ -189,7 +196,7 @@ def add_distributed_training_args(parser):
|
||||
'establish initial connetion')
|
||||
group.add_argument('--distributed-port', default=-1, type=int,
|
||||
help='port number (not required if using --distributed-init-method)')
|
||||
group.add_argument('--device-id', default=0, type=int,
|
||||
group.add_argument('--device-id', '--local_rank', default=0, type=int,
|
||||
help='which GPU to use (usually configured automatically)')
|
||||
group.add_argument('--ddp-backend', default='c10d', type=str,
|
||||
choices=['c10d', 'no_c10d'],
|
||||
@ -197,8 +204,8 @@ def add_distributed_training_args(parser):
|
||||
group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
|
||||
help='bucket size for reduction')
|
||||
group.add_argument('--fix-batches-to-gpus', action='store_true',
|
||||
help='Don\'t shuffle batches between GPUs, this reduces overall '
|
||||
'randomness and may affect precision but avoids the cost of'
|
||||
help='don\'t shuffle batches between GPUs; this reduces overall '
|
||||
'randomness and may affect precision but avoids the cost of '
|
||||
're-reading the data')
|
||||
# fmt: on
|
||||
return group
|
||||
@ -263,7 +270,9 @@ def add_checkpoint_args(parser):
|
||||
group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
|
||||
help='save a checkpoint (and validate) every N updates')
|
||||
group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
|
||||
help='keep last N checkpoints saved with --save-interval-updates')
|
||||
help='keep the last N checkpoints saved with --save-interval-updates')
|
||||
group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N',
|
||||
help='keep last N epoch checkpoints')
|
||||
group.add_argument('--no-save', action='store_true',
|
||||
help='don\'t save models or checkpoints')
|
||||
group.add_argument('--no-epoch-checkpoints', action='store_true',
|
||||
@ -280,11 +289,11 @@ def add_common_eval_args(group):
|
||||
help='path(s) to model file(s), colon separated')
|
||||
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
|
||||
help='remove BPE tokens before scoring')
|
||||
group.add_argument('--cpu', action='store_true', help='generate on CPU')
|
||||
group.add_argument('--quiet', action='store_true',
|
||||
help='only print final scores')
|
||||
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
|
||||
help='a dictionary used to override model args at generation that were used during model training')
|
||||
help='a dictionary used to override model args at generation '
|
||||
'that were used during model training')
|
||||
# fmt: on
|
||||
|
||||
|
||||
|
@ -22,6 +22,7 @@ class SequenceGenerator(object):
|
||||
match_source_len=False, no_repeat_ngram_size=0
|
||||
):
|
||||
"""Generates translations of a given source sentence.
|
||||
|
||||
Args:
|
||||
beam_size (int, optional): beam width (default: 1)
|
||||
min/maxlen (int, optional): the length of the generated output will
|
||||
@ -90,11 +91,14 @@ class SequenceGenerator(object):
|
||||
cuda=False, timer=None, prefix_size=0,
|
||||
):
|
||||
"""Iterate over a batched dataset and yield individual translations.
|
||||
|
||||
Args:
|
||||
maxlen_a/b: generate sequences of maximum length ax + b,
|
||||
where x is the source sentence length.
|
||||
cuda: use GPU for generation
|
||||
timer: StopwatchMeter for timing generations.
|
||||
maxlen_a/b (int, optional): generate sequences of maximum length
|
||||
``ax + b``, where ``x`` is the source sentence length.
|
||||
cuda (bool, optional): use GPU for generation
|
||||
timer (StopwatchMeter, optional): time generations
|
||||
prefix_size (int, optional): prefill the generation with the gold
|
||||
prefix up to this length.
|
||||
"""
|
||||
if maxlen_b is None:
|
||||
maxlen_b = self.maxlen
|
||||
@ -132,12 +136,13 @@ class SequenceGenerator(object):
|
||||
"""Generate a batch of translations.
|
||||
|
||||
Args:
|
||||
encoder_input: dictionary containing the inputs to
|
||||
model.encoder.forward
|
||||
beam_size: int overriding the beam size. defaults to
|
||||
self.beam_size
|
||||
max_len: maximum length of the generated sequence
|
||||
prefix_tokens: force decoder to begin with these tokens
|
||||
encoder_input (dict): dictionary containing the inputs to
|
||||
*model.encoder.forward*.
|
||||
beam_size (int, optional): overriding the beam size
|
||||
(default: *self.beam_size*).
|
||||
max_len (int, optional): maximum length of the generated sequence
|
||||
prefix_tokens (LongTensor, optional): force decoder to begin with
|
||||
these tokens
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)
|
||||
|
@ -87,4 +87,3 @@ class SequenceScorer(object):
|
||||
index=sample['target'].data.unsqueeze(-1),
|
||||
)
|
||||
return avg_probs.squeeze(2), avg_attn
|
||||
|
||||
|
@@ -1,7 +1,8 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in # the root directory of this source tree. An additional grant of patent rights
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
|
||||
|
@ -61,29 +61,32 @@ class FairseqTask(object):
|
||||
def get_batch_iterator(
|
||||
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
|
||||
ignore_invalid_inputs=False, required_batch_size_multiple=1,
|
||||
seed=1, num_shards=1, shard_id=0,
|
||||
seed=1, num_shards=1, shard_id=0, num_workers=0,
|
||||
):
|
||||
"""
|
||||
Get an iterator that yields batches of data from the given dataset.
|
||||
|
||||
Args:
|
||||
dataset (~fairseq.data.FairseqDataset): dataset to batch
|
||||
max_tokens (int, optional): max number of tokens in each batch.
|
||||
Default: ``None``
|
||||
max_tokens (int, optional): max number of tokens in each batch
|
||||
(default: None).
|
||||
max_sentences (int, optional): max number of sentences in each
|
||||
batch. Default: ``None``
|
||||
batch (default: None).
|
||||
max_positions (optional): max sentence length supported by the
|
||||
model. Default: ``None``
|
||||
model (default: None).
|
||||
ignore_invalid_inputs (bool, optional): don't raise Exception for
|
||||
sentences that are too long. Default: ``False``
|
||||
sentences that are too long (default: False).
|
||||
required_batch_size_multiple (int, optional): require batch size to
|
||||
be a multiple of N. Default: ``1``
|
||||
be a multiple of N (default: 1).
|
||||
seed (int, optional): seed for random number generator for
|
||||
reproducibility. Default: ``1``
|
||||
reproducibility (default: 1).
|
||||
num_shards (int, optional): shard the data iterator into N
|
||||
shards. Default: ``1``
|
||||
shards (default: 1).
|
||||
shard_id (int, optional): which shard of the data iterator to
|
||||
return. Default: ``0``
|
||||
return (default: 0).
|
||||
num_workers (int, optional): how many subprocesses to use for data
|
||||
loading. 0 means the data will be loaded in the main process
|
||||
(default: 0).
|
||||
|
||||
Returns:
|
||||
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
|
||||
@ -114,6 +117,7 @@ class FairseqTask(object):
|
||||
seed=seed,
|
||||
num_shards=num_shards,
|
||||
shard_id=shard_id,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
def build_model(self, args):
|
||||
|
@ -10,9 +10,15 @@ import numpy as np
|
||||
import os
|
||||
|
||||
from fairseq.data import (
|
||||
ConcatDataset, Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
|
||||
MonolingualDataset, TokenBlockDataset, TruncatedDictionary,
|
||||
IndexedCachedDataset, IndexedDataset)
|
||||
ConcatDataset,
|
||||
Dictionary,
|
||||
IndexedCachedDataset,
|
||||
IndexedDataset,
|
||||
IndexedRawTextDataset,
|
||||
MonolingualDataset,
|
||||
TokenBlockDataset,
|
||||
TruncatedDictionary,
|
||||
)
|
||||
|
||||
from . import FairseqTask, register_task
|
||||
|
||||
@ -60,6 +66,8 @@ class LanguageModelingTask(FairseqTask):
|
||||
'If set to "eos", includes only one sentence per sample.')
|
||||
parser.add_argument('--tokens-per-sample', default=1024, type=int,
|
||||
help='max number of tokens per sample for LM dataset')
|
||||
parser.add_argument('--lazy-load', action='store_true',
|
||||
help='load the dataset lazily')
|
||||
parser.add_argument('--raw-text', default=False, action='store_true',
|
||||
help='load raw text dataset')
|
||||
parser.add_argument('--output-dictionary-size', default=-1, type=int,
|
||||
@ -139,7 +147,10 @@ class LanguageModelingTask(FairseqTask):
|
||||
if self.args.raw_text and IndexedRawTextDataset.exists(path):
|
||||
ds = IndexedRawTextDataset(path, self.dictionary)
|
||||
elif not self.args.raw_text and IndexedDataset.exists(path):
|
||||
ds = IndexedDataset(path, fix_lua_indexing=True)
|
||||
if self.args.lazy_load:
|
||||
ds = IndexedDataset(path, fix_lua_indexing=True)
|
||||
else:
|
||||
ds = IndexedCachedDataset(path, fix_lua_indexing=True)
|
||||
else:
|
||||
if k > 0:
|
||||
break
|
||||
@ -148,9 +159,11 @@ class LanguageModelingTask(FairseqTask):
|
||||
|
||||
loaded_datasets.append(
|
||||
TokenBlockDataset(
|
||||
ds, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(),
|
||||
ds, ds.sizes, self.args.tokens_per_sample,
|
||||
pad=self.dictionary.pad(), eos=self.dictionary.eos(),
|
||||
break_mode=self.args.sample_break_mode, include_targets=True,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
|
||||
|
||||
|
@ -12,8 +12,12 @@ import torch
|
||||
|
||||
from fairseq import options
|
||||
from fairseq.data import (
|
||||
Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
|
||||
IndexedRawTextDataset, RoundRobinZipDatasets,
|
||||
Dictionary,
|
||||
IndexedCachedDataset,
|
||||
IndexedDataset,
|
||||
IndexedRawTextDataset,
|
||||
LanguagePairDataset,
|
||||
RoundRobinZipDatasets,
|
||||
)
|
||||
from fairseq.models import FairseqMultiModel
|
||||
|
||||
@ -55,6 +59,8 @@ class MultilingualTranslationTask(FairseqTask):
|
||||
help='source language (only needed for inference)')
|
||||
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
|
||||
help='target language (only needed for inference)')
|
||||
parser.add_argument('--lazy-load', action='store_true',
|
||||
help='load the dataset lazily')
|
||||
parser.add_argument('--raw-text', action='store_true',
|
||||
help='load raw text dataset')
|
||||
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
|
||||
@ -112,15 +118,18 @@ class MultilingualTranslationTask(FairseqTask):
|
||||
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
|
||||
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
|
||||
return True
|
||||
elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
|
||||
elif not self.args.raw_text and IndexedDataset.exists(filename):
|
||||
return True
|
||||
return False
|
||||
|
||||
def indexed_dataset(path, dictionary):
|
||||
if self.args.raw_text:
|
||||
return IndexedRawTextDataset(path, dictionary)
|
||||
elif IndexedInMemoryDataset.exists(path):
|
||||
return IndexedInMemoryDataset(path, fix_lua_indexing=True)
|
||||
elif IndexedDataset.exists(path):
|
||||
if self.args.lazy_load:
|
||||
return IndexedDataset(path, fix_lua_indexing=True)
|
||||
else:
|
||||
return IndexedCachedDataset(path, fix_lua_indexing=True)
|
||||
return None
|
||||
|
||||
def sort_lang_pair(lang_pair):
|
||||
|
@ -6,13 +6,17 @@
|
||||
# can be found in the PATENTS file in the same directory.
|
||||
|
||||
import itertools
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from fairseq import options, utils
|
||||
from fairseq.data import (
|
||||
data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
|
||||
IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
|
||||
ConcatDataset,
|
||||
data_utils,
|
||||
Dictionary,
|
||||
IndexedCachedDataset,
|
||||
IndexedDataset,
|
||||
IndexedRawTextDataset,
|
||||
LanguagePairDataset,
|
||||
)
|
||||
|
||||
from . import FairseqTask, register_task
|
||||
@ -49,6 +53,8 @@ class TranslationTask(FairseqTask):
|
||||
help='source language')
|
||||
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
|
||||
help='target language')
|
||||
parser.add_argument('--lazy-load', action='store_true',
|
||||
help='load the dataset lazily')
|
||||
parser.add_argument('--raw-text', action='store_true',
|
||||
help='load raw text dataset')
|
||||
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
|
||||
@ -132,7 +138,10 @@ class TranslationTask(FairseqTask):
|
||||
if self.args.raw_text:
|
||||
return IndexedRawTextDataset(path, dictionary)
|
||||
elif IndexedDataset.exists(path):
|
||||
return IndexedCachedDataset(path, fix_lua_indexing=True)
|
||||
if self.args.lazy_load:
|
||||
return IndexedDataset(path, fix_lua_indexing=True)
|
||||
else:
|
||||
return IndexedCachedDataset(path, fix_lua_indexing=True)
|
||||
return None
|
||||
|
||||
src_datasets = []
|
||||
|
@ -6,10 +6,11 @@
|
||||
# can be found in the PATENTS file in the same directory.
|
||||
|
||||
from collections import Counter
|
||||
import os, re
|
||||
from multiprocessing import Pool
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from multiprocessing import Pool
|
||||
|
||||
SPACE_NORMALIZER = re.compile(r"\s+")
|
||||
|
||||
@ -27,7 +28,8 @@ def safe_readline(f):
|
||||
return f.readline()
|
||||
except UnicodeDecodeError:
|
||||
pos -= 1
|
||||
f.seek(pos) # search where this character begins
|
||||
f.seek(pos) # search where this character begins
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
|
||||
@ -41,7 +43,7 @@ class Tokenizer:
|
||||
end = offset + chunk_size
|
||||
f.seek(offset)
|
||||
if offset > 0:
|
||||
safe_readline(f) # drop first incomplete line
|
||||
safe_readline(f) # drop first incomplete line
|
||||
line = f.readline()
|
||||
while line:
|
||||
for word in tokenize(line):
|
||||
@ -73,14 +75,17 @@ class Tokenizer:
|
||||
merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
|
||||
|
||||
@staticmethod
|
||||
def binarize(filename, dict, consumer, tokenize=tokenize_line,
|
||||
append_eos=True, reverse_order=False,
|
||||
offset=0, end=-1):
|
||||
def binarize(
|
||||
filename, dict, consumer, tokenize=tokenize_line, append_eos=True,
|
||||
reverse_order=False, offset=0, end=-1,
|
||||
):
|
||||
nseq, ntok = 0, 0
|
||||
replaced = Counter()
|
||||
|
||||
def replaced_consumer(word, idx):
|
||||
if idx == dict.unk_index and word != dict.unk_word:
|
||||
replaced.update([word])
|
||||
|
||||
with open(filename, 'r') as f:
|
||||
f.seek(offset)
|
||||
# next(f) breaks f.tell(), hence readline() must be used
|
||||
|
@ -30,22 +30,23 @@ class Trainer(object):
|
||||
"""
|
||||
|
||||
def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise NotImplementedError('Training on CPU is not supported')
|
||||
|
||||
self.args = args
|
||||
self.task = task
|
||||
|
||||
# copy model and criterion to current device
|
||||
self.criterion = criterion.cuda()
|
||||
self.criterion = criterion
|
||||
self._model = model
|
||||
self.cuda = torch.cuda.is_available() and not args.cpu
|
||||
if args.fp16:
|
||||
self._model = model.half().cuda()
|
||||
else:
|
||||
self._model = model.cuda()
|
||||
self._model = self._model.half()
|
||||
if self.cuda:
|
||||
self.criterion = self.criterion.cuda()
|
||||
self._model = self._model.cuda()
|
||||
|
||||
self._dummy_batch = dummy_batch
|
||||
self._oom_batch = oom_batch
|
||||
|
||||
self._lr_scheduler = None
|
||||
self._num_updates = 0
|
||||
self._optim_history = None
|
||||
self._optimizer = None
|
||||
@ -71,7 +72,6 @@ class Trainer(object):
|
||||
self.meters['wall'] = TimeMeter() # wall time in seconds
|
||||
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
|
||||
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
if self._wrapped_model is None:
|
||||
@ -89,19 +89,26 @@ class Trainer(object):
|
||||
self._build_optimizer()
|
||||
return self._optimizer
|
||||
|
||||
@property
|
||||
def lr_scheduler(self):
|
||||
if self._lr_scheduler is None:
|
||||
self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
|
||||
return self._lr_scheduler
|
||||
|
||||
def _build_optimizer(self):
|
||||
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
|
||||
if self.args.fp16:
|
||||
if torch.cuda.get_device_capability(0)[0] < 7:
|
||||
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
|
||||
print('| WARNING: your device does NOT support faster training with --fp16, '
|
||||
'please switch to FP32 which is likely to be faster')
|
||||
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
|
||||
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
|
||||
if self.args.memory_efficient_fp16:
|
||||
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
|
||||
else:
|
||||
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
|
||||
else:
|
||||
if torch.cuda.get_device_capability(0)[0] >= 7:
|
||||
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
|
||||
print('| NOTICE: your device may support faster training with --fp16')
|
||||
self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
|
||||
|
||||
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)
|
||||
self._optimizer = optim.build_optimizer(self.args, params)
|
||||
|
||||
def save_checkpoint(self, filename, extra_state):
|
||||
"""Save all training state in a checkpoint file."""
|
||||
@ -151,7 +158,8 @@ class Trainer(object):
|
||||
# reproducible results when resuming from checkpoints
|
||||
seed = self.args.seed + self.get_num_updates()
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
if self.cuda:
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
self.model.train()
|
||||
self.zero_grad()
|
||||
@ -296,7 +304,8 @@ class Trainer(object):
|
||||
for p in self.model.parameters():
|
||||
if p.grad is not None:
|
||||
del p.grad # free some memory
|
||||
torch.cuda.empty_cache()
|
||||
if self.cuda:
|
||||
torch.cuda.empty_cache()
|
||||
return self.valid_step(sample, raise_oom=True)
|
||||
else:
|
||||
raise e
|
||||
@ -377,4 +386,6 @@ class Trainer(object):
|
||||
def _prepare_sample(self, sample):
|
||||
if sample is None or len(sample) == 0:
|
||||
return None
|
||||
return utils.move_to_cuda(sample)
|
||||
if self.cuda:
|
||||
sample = utils.move_to_cuda(sample)
|
||||
return sample
|
||||
|
@@ -378,6 +378,14 @@ def item(tensor):
    return tensor


def clip_grad_norm_(tensor, max_norm):
    grad_norm = item(torch.norm(tensor))
    if grad_norm > max_norm > 0:
        clip_coef = max_norm / (grad_norm + 1e-6)
        tensor.mul_(clip_coef)
    return grad_norm


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)
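A hypothetical usage note: unlike a per-parameter clipper, this helper rescales one flat gradient tensor in place (in FP16Optimizer above it is the FP32 copy of all gradients) and returns the pre-clipping norm.

import torch
from fairseq import utils

flat_grads = torch.randn(1024)          # stand-in for fp32_params.grad.data
norm_before = utils.clip_grad_norm_(flat_grads, max_norm=1.0)
# flat_grads has been rescaled in place if its norm exceeded max_norm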
|
||||
|
fb_train.py (new file, 18 lines)
@@ -0,0 +1,18 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.fb.rendezvous.zeus  # noqa: F401

from fairseq import options

from train import main


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
@ -11,7 +11,7 @@ Translate pre-processed data with a trained model.
|
||||
|
||||
import torch
|
||||
|
||||
from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils
|
||||
from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
|
||||
from fairseq.meters import StopwatchMeter, TimeMeter
|
||||
from fairseq.sequence_generator import SequenceGenerator
|
||||
from fairseq.sequence_scorer import SequenceScorer
|
||||
@ -41,7 +41,9 @@ def main(args):
|
||||
|
||||
# Load ensemble
|
||||
print('| loading model(s) from {}'.format(args.path))
|
||||
models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))
|
||||
models, _model_args = utils.load_ensemble_for_inference(
|
||||
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
|
||||
)
|
||||
|
||||
# Optimize ensemble for generation
|
||||
for model in models:
|
||||
@ -69,6 +71,7 @@ def main(args):
|
||||
required_batch_size_multiple=8,
|
||||
num_shards=args.num_shards,
|
||||
shard_id=args.shard_id,
|
||||
num_workers=args.num_workers,
|
||||
).next_epoch_itr(shuffle=False)
|
||||
|
||||
# Initialize generator
|
||||
|
@ -75,8 +75,9 @@ def main(args):
|
||||
|
||||
# Load ensemble
|
||||
print('| loading model(s) from {}'.format(args.path))
|
||||
model_paths = args.path.split(':')
|
||||
models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides))
|
||||
models, _model_args = utils.load_ensemble_for_inference(
|
||||
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
|
||||
)
|
||||
|
||||
# Set dictionaries
|
||||
tgt_dict = task.target_dictionary
|
||||
|
@ -1,87 +0,0 @@
|
||||
#!/usr/bin/env python3 -u
|
||||
# Copyright (c) 2017-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the LICENSE file in
|
||||
# the root directory of this source tree. An additional grant of patent rights
|
||||
# can be found in the PATENTS file in the same directory.
|
||||
|
||||
import os
|
||||
import signal
|
||||
import torch
|
||||
|
||||
from fairseq import distributed_utils, options
|
||||
|
||||
from train import main as single_process_main
|
||||
|
||||
|
||||
def main(args):
|
||||
if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
|
||||
print('| WARNING: when using --update-freq on a single machine, you '
|
||||
'will get better performance with --ddp-backend=no_c10d')
|
||||
|
||||
mp = torch.multiprocessing.get_context('spawn')
|
||||
|
||||
# Create a thread to listen for errors in the child processes.
|
||||
error_queue = mp.SimpleQueue()
|
||||
error_handler = ErrorHandler(error_queue)
|
||||
|
||||
# Train with multiprocessing.
|
||||
procs = []
|
||||
base_rank = args.distributed_rank
|
||||
for i in range(torch.cuda.device_count()):
|
||||
args.distributed_rank = base_rank + i
|
||||
args.device_id = i
|
||||
procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True))
|
||||
procs[i].start()
|
||||
error_handler.add_child(procs[i].pid)
|
||||
for p in procs:
|
||||
p.join()
|
||||
|
||||
|
||||
def run(args, error_queue):
|
||||
try:
|
||||
args.distributed_rank = distributed_utils.distributed_init(args)
|
||||
single_process_main(args)
|
||||
except KeyboardInterrupt:
|
||||
pass # killed by parent, do nothing
|
||||
except Exception:
|
||||
# propagate exception to parent process, keeping original traceback
|
||||
import traceback
|
||||
error_queue.put((args.distributed_rank, traceback.format_exc()))
|
||||
|
||||
|
||||
class ErrorHandler(object):
|
||||
"""A class that listens for exceptions in children processes and propagates
|
||||
the tracebacks to the parent process."""
|
||||
|
||||
def __init__(self, error_queue):
|
||||
import signal
|
||||
import threading
|
||||
self.error_queue = error_queue
|
||||
self.children_pids = []
|
||||
self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
|
||||
self.error_thread.start()
|
||||
signal.signal(signal.SIGUSR1, self.signal_handler)
|
||||
|
||||
def add_child(self, pid):
|
||||
self.children_pids.append(pid)
|
||||
|
||||
def error_listener(self):
|
||||
(rank, original_trace) = self.error_queue.get()
|
||||
self.error_queue.put((rank, original_trace))
|
||||
os.kill(os.getpid(), signal.SIGUSR1)
|
||||
|
||||
def signal_handler(self, signalnum, stackframe):
|
||||
for pid in self.children_pids:
|
||||
os.kill(pid, signal.SIGINT) # kill children processes
|
||||
(rank, original_trace) = self.error_queue.get()
|
||||
msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
|
||||
msg += original_trace
|
||||
raise Exception(msg)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = options.get_training_parser()
|
||||
args = options.parse_args_and_arch(parser)
|
||||
main(args)
|
@ -18,7 +18,7 @@ import shutil
|
||||
|
||||
from fairseq.data import indexed_dataset, dictionary
|
||||
from fairseq.tokenizer import Tokenizer, tokenize_line
|
||||
from multiprocessing import Pool, Manager, Process
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -50,6 +50,14 @@ class TestTranslation(unittest.TestCase):
|
||||
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16'])
|
||||
generate_main(data_dir)
|
||||
|
||||
def test_memory_efficient_fp16(self):
|
||||
with contextlib.redirect_stdout(StringIO()):
|
||||
with tempfile.TemporaryDirectory('test_memory_efficient_fp16') as data_dir:
|
||||
create_dummy_data(data_dir)
|
||||
preprocess_translation_data(data_dir)
|
||||
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--memory-efficient-fp16'])
|
||||
generate_main(data_dir)
|
||||
|
||||
def test_update_freq(self):
|
||||
with contextlib.redirect_stdout(StringIO()):
|
||||
with tempfile.TemporaryDirectory('test_update_freq') as data_dir:
|
||||
@ -68,8 +76,7 @@ class TestTranslation(unittest.TestCase):
|
||||
data_dir, 'fconv_iwslt_de_en', ['--max-target-positions', '5'],
|
||||
)
|
||||
self.assertTrue(
|
||||
'skip this example with --skip-invalid-size-inputs-valid-test' \
|
||||
in str(context.exception)
|
||||
'skip this example with --skip-invalid-size-inputs-valid-test' in str(context.exception)
|
||||
)
|
||||
train_translation_model(
|
||||
data_dir, 'fconv_iwslt_de_en',
|
||||
|
@ -12,10 +12,6 @@ import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from fairseq import options
|
||||
|
||||
from . import test_binaries
|
||||
|
||||
|
||||
@ -79,6 +75,12 @@ class TestReproducibility(unittest.TestCase):
|
||||
'--fp16-init-scale', '4096',
|
||||
])
|
||||
|
||||
def test_reproducibility_memory_efficient_fp16(self):
|
||||
self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [
|
||||
'--memory-efficient-fp16',
|
||||
'--fp16-init-scale', '4096',
|
||||
])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -39,8 +39,10 @@ def mock_dict():
|
||||
|
||||
|
||||
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
|
||||
tokens = torch.LongTensor(list(range(epoch_size)))
|
||||
tokens_ds = data.TokenBlockDataset(tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False)
|
||||
tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
|
||||
tokens_ds = data.TokenBlockDataset(
|
||||
tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
|
||||
)
|
||||
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
|
||||
dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
|
||||
epoch_itr = data.EpochBatchIterator(
|
||||
@ -64,7 +66,6 @@ class TestLoadCheckpoint(unittest.TestCase):
|
||||
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
|
||||
[p.start() for p in self.applied_patches]
|
||||
|
||||
|
||||
def test_load_partial_checkpoint(self):
|
||||
with contextlib.redirect_stdout(StringIO()):
|
||||
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
|
||||
|
train.py (62 lines changed)
@ -28,9 +28,8 @@ def main(args):
|
||||
args.max_tokens = 6000
|
||||
print(args)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise NotImplementedError('Training on CPU is not supported')
|
||||
torch.cuda.set_device(args.device_id)
|
||||
if torch.cuda.is_available() and not args.cpu:
|
||||
torch.cuda.set_device(args.device_id)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
# Setup task, e.g., translation, language modeling, etc.
|
||||
@ -74,6 +73,7 @@ def main(args):
|
||||
seed=args.seed,
|
||||
num_shards=args.distributed_world_size,
|
||||
shard_id=args.distributed_rank,
|
||||
num_workers=args.num_workers,
|
||||
)
|
||||
|
||||
# Load the latest checkpoint if one is available
|
||||
@ -211,6 +211,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
|
||||
seed=args.seed,
|
||||
num_shards=args.distributed_world_size,
|
||||
shard_id=args.distributed_rank,
|
||||
num_workers=args.num_workers,
|
||||
).next_epoch_itr(shuffle=False)
|
||||
progress = progress_bar.build_progress_bar(
|
||||
args, itr, epoch_itr.epoch,
|
||||
@ -306,7 +307,15 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
|
||||
# remove old checkpoints; checkpoints are sorted in descending order
|
||||
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
|
||||
for old_chk in checkpoints[args.keep_interval_updates:]:
|
||||
os.remove(old_chk)
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
|
||||
if args.keep_last_epochs > 0:
|
||||
# remove old epoch checkpoints; checkpoints are sorted in descending order
|
||||
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint\d+\.pt')
|
||||
for old_chk in checkpoints[args.keep_last_epochs:]:
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
|
||||
|
||||
def load_checkpoint(args, trainer, epoch_itr):
|
||||
@ -346,23 +355,50 @@ def load_dataset_splits(task, splits):
|
||||
raise e
|
||||
|
||||
|
||||
def distributed_main(i, args):
    import socket
    args.device_id = i
    if args.distributed_rank is None:  # torch.multiprocessing.spawn
        args.distributed_rank = i
    args.distributed_rank = distributed_utils.distributed_init(args)
    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
    main(args)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = options.get_training_parser()
|
||||
args = options.parse_args_and_arch(parser)
|
||||
|
||||
if args.distributed_port > 0 or args.distributed_init_method is not None:
|
||||
from distributed_train import main as distributed_main
|
||||
if args.distributed_init_method is None:
|
||||
distributed_utils.infer_init_method(args)
|
||||
|
||||
distributed_main(args)
|
||||
if args.distributed_init_method is not None:
|
||||
# distributed training
|
||||
distributed_main(args.device_id, args)
|
||||
args.distributed_rank = distributed_utils.distributed_init(args)
|
||||
main(args)
|
||||
elif args.distributed_world_size > 1:
|
||||
from multiprocessing_train import main as multiprocessing_main
|
||||
|
||||
# Set distributed training parameters for a single node.
|
||||
args.distributed_world_size = torch.cuda.device_count()
|
||||
# fallback for single node with multiple GPUs
|
||||
port = random.randint(10000, 20000)
|
||||
args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
|
||||
args.distributed_port = port + 1
|
||||
args.distributed_rank = None # set based on device id
|
||||
print(
|
||||
'''| NOTE: you may get better performance with:
|
||||
|
||||
multiprocessing_main(args)
|
||||
python -m torch.distributed.launch --nproc_per_node {ngpu} train.py {no_c10d}(...)
|
||||
'''.format(
|
||||
ngpu=args.distributed_world_size,
|
||||
no_c10d=(
|
||||
'--ddp-backend=no_c10d ' if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d'
|
||||
else ''
|
||||
),
|
||||
)
|
||||
)
|
||||
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        # single GPU training
        main(args)
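torch.multiprocessing.spawn injects the process index as the first argument of its target, which is why distributed_main(i, args) above derives the device id (and, when unset, the rank) from i. A minimal standalone sketch of the same launch pattern, with hypothetical names:

import torch.multiprocessing as mp

def worker(i, world_size):
    # i is supplied by spawn(); it plays the role of device_id / local rank above
    print('worker {} of {}'.format(i, world_size))

if __name__ == '__main__':
    world_size = 4  # hypothetical; train.py uses args.distributed_world_size
    mp.spawn(fn=worker, args=(world_size,), nprocs=world_size)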