mirror of https://github.com/facebookresearch/fairseq.git

Clean up FairseqTask so that it's easier to extend/add new tasks

parent 6296de825a
commit 2e507d3cb4
@@ -59,11 +59,13 @@ def main(parsed_args):
     assert len(models) > 0

-    itr = data.EpochBatchIterator(
+    itr = task.get_batch_iterator(
         dataset=task.dataset(args.gen_subset),
         max_tokens=args.max_tokens or 36000,
         max_sentences=args.max_sentences,
-        max_positions=models[0].max_positions(),
+        max_positions=utils.resolve_max_positions(*[
+            model.max_positions() for model in models
+        ]),
         num_shards=args.num_shards,
         shard_id=args.shard_id,
         ignore_invalid_inputs=True,
@@ -12,8 +12,6 @@ import os

 import numpy as np
 import torch

-from . import FairseqDataset
-

 def infer_language_pair(path):
     """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
@@ -99,42 +97,35 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):


 class EpochBatchIterator(object):
-    """Iterate over a FairseqDataset and yield batches bucketed by size.
+    """A multi-epoch iterator over a :class:`~torch.utils.data.Dataset`.

-    Batches may contain sequences of different lengths. This iterator can be
-    reused across multiple epochs with the next_epoch_itr() method.
+    Compared to :class:`~torch.utils.data.DataLoader`, this iterator:
+
+    - can be reused across multiple epochs with the :func:`next_epoch_itr`
+      method (optionally shuffled between epochs)
+    - can be serialized/deserialized with the :func:`state_dict` and
+      :func:`load_state_dict` methods
+    - supports sharding with the ``num_shards`` and ``shard_id`` arguments

     Args:
-        dataset: a FairseqDataset
-        max_tokens: max number of tokens in each batch
-        max_sentences: max number of sentences in each batch
-        max_positions: max sentence length supported by the model
-        ignore_invalid_inputs: don't raise Exception for sentences that are too long
-        required_batch_size_multiple: require batch size to be a multiple of N
-        seed: seed for random number generator for reproducibility
-        num_shards: shard the data iterator into N shards
-        shard_id: which shard of the data iterator to return
+        dataset (Dataset): dataset from which to load the data
+        batch_sampler (Sampler): an iterator over batches of indices
+        seed (int, optional): seed for random number generator for
+            reproducibility. Default: ``1``
+        num_shards (int, optional): shard the data iterator into N
+            shards. Default: ``1``
+        shard_id (int, optional): which shard of the data iterator to
+            return. Default: ``0``
     """

-    def __init__(
-        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
-        ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1,
-        num_shards=1, shard_id=0,
-    ):
-        assert isinstance(dataset, FairseqDataset)
+    def __init__(self, dataset, batch_sampler, seed=1, num_shards=1, shard_id=0):
+        assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
-        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
-        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
-        self.max_positions = max_positions
-        self.ignore_invalid_inputs = ignore_invalid_inputs
-        self.bsz_mult = required_batch_size_multiple
+        self.frozen_batches = tuple(batch_sampler)
         self.seed = seed
         self.num_shards = num_shards
         self.shard_id = shard_id

-        with numpy_seed(self.seed):
-            self.frozen_batches = tuple(self._batch_generator())
-
         self.epoch = 0
         self._cur_epoch_itr = None
         self._next_epoch_itr = None
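Not part of the commit, but for orientation: a minimal usage sketch of the new constructor, assuming `dataset` is any map-style torch Dataset that provides the `collater` used by the underlying DataLoader (batch contents here are illustrative):

from fairseq import data

itr = data.EpochBatchIterator(
    dataset=dataset,                  # assumed: torch Dataset with a collater
    batch_sampler=[[0, 1], [2, 3]],   # illustrative pre-computed index batches
    seed=1,
)
epoch_itr = itr.next_epoch_itr(shuffle=True)  # shuffles the frozen batches
for batch in epoch_itr:
    pass  # forward/backward on the collated batch
state = itr.state_dict()              # serializable; see load_state_dict()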
@@ -143,7 +134,13 @@ class EpochBatchIterator(object):
         return len(self.frozen_batches)

     def next_epoch_itr(self, shuffle=True):
-        """Shuffle batches and return a new iterator over the dataset."""
+        """
+        Return a new iterator over the dataset.
+
+        Args:
+            shuffle (bool, optional): shuffle batches before returning the
+                iterator. Default: ``True``
+        """
         if self._next_epoch_itr is not None:
             self._cur_epoch_itr = self._next_epoch_itr
             self._next_epoch_itr = None
@@ -153,10 +150,12 @@ class EpochBatchIterator(object):
         return self._cur_epoch_itr

     def end_of_epoch(self):
+        """Returns whether the most recent epoch iterator has been exhausted"""
         return not self._cur_epoch_itr.has_next()

     @property
     def iterations_in_epoch(self):
+        """The number of consumed batches in the current epoch."""
         if self._cur_epoch_itr is not None:
             return self._cur_epoch_itr.count
         elif self._next_epoch_itr is not None:
@@ -193,55 +192,6 @@ class EpochBatchIterator(object):
             batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
         ))

-    def _batch_generator(self):
-        batch = []
-
-        def is_batch_full(num_tokens):
-            if len(batch) == 0:
-                return False
-            if len(batch) == self.max_sentences:
-                return True
-            if num_tokens > self.max_tokens:
-                return True
-            return False
-
-        sample_len = 0
-        sample_lens = []
-        ignored = []
-        for idx in self.dataset.ordered_indices():
-            if not self.dataset.valid_size(idx, self.max_positions):
-                if self.ignore_invalid_inputs:
-                    ignored.append(idx)
-                    continue
-                raise Exception((
-                    'Size of sample #{} is invalid, max_positions={}, skip this '
-                    'example with --skip-invalid-size-inputs-valid-test'
-                ).format(idx, self.max_positions))
-
-            sample_lens.append(self.dataset.num_tokens(idx))
-            sample_len = max(sample_len, sample_lens[-1])
-            num_tokens = (len(batch) + 1) * sample_len
-            if is_batch_full(num_tokens):
-                mod_len = max(
-                    self.bsz_mult * (len(batch) // self.bsz_mult),
-                    len(batch) % self.bsz_mult,
-                )
-                yield batch[:mod_len]
-                batch = batch[mod_len:]
-                sample_lens = sample_lens[mod_len:]
-                sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
-
-            batch.append(idx)
-
-        if len(batch) > 0:
-            yield batch
-
-        if len(ignored) > 0:
-            print((
-                '| WARNING: {} samples have invalid sizes and will be skipped, '
-                'max_positions={}, first few sample ids={}'
-            ).format(len(ignored), self.max_positions, ignored[:10]))
-

 @contextlib.contextmanager
 def numpy_seed(seed):
@@ -256,3 +206,112 @@ def numpy_seed(seed):
         yield
     finally:
         np.random.set_state(state)
+
+
+def collect_filtered(function, iterable, filtered):
+    """
+    Similar to :func:`filter` but collects filtered elements in ``filtered``.
+
+    Args:
+        function (callable): function that returns ``False`` for elements that
+            should be filtered
+        iterable (iterable): iterable to filter
+        filtered (list): list to store filtered elements
+    """
+    for el in iterable:
+        if function(el):
+            yield el
+        else:
+            filtered.append(el)
+
+
+def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
+    """
+    Filter indices based on their size.
+
+    Args:
+        indices (List[int]): ordered list of dataset indices
+        size_fn (callable): function that returns the size of a given index
+        max_positions (tuple): filter elements larger than this size.
+            Comparisons are done component-wise.
+        raise_exception (bool, optional): if ``True``, raise an exception
+            if any elements are filtered. Default: ``False``
+    """
+    def check_size(idx):
+        if isinstance(max_positions, float) or isinstance(max_positions, int):
+            return size_fn(idx) <= max_positions
+        else:
+            return all(a <= b for a, b in zip(size_fn(idx), max_positions))
+
+    ignored = []
+    itr = collect_filtered(check_size, indices, ignored)
+
+    for idx in itr:
+        if len(ignored) > 0 and raise_exception:
+            raise Exception((
+                'Size of sample #{} is invalid (={}) since max_positions={}, '
+                'skip this example with --skip-invalid-size-inputs-valid-test'
+            ).format(idx, size_fn(idx), max_positions))
+        yield idx
+
+    if len(ignored) > 0:
+        print((
+            '| WARNING: {} samples have invalid sizes and will be skipped, '
+            'max_positions={}, first few sample ids={}'
+        ).format(len(ignored), max_positions, ignored[:10]))
+
+
+def batch_by_size(
+    indices, num_tokens_fn, max_tokens=None, max_sentences=None,
+    required_batch_size_multiple=1,
+):
+    """
+    Yield mini-batches of indices bucketed by size. Batches may contain
+    sequences of different lengths.
+
+    Args:
+        indices (List[int]): ordered list of dataset indices
+        num_tokens_fn (callable): function that returns the number of tokens at
+            a given index
+        max_tokens (int, optional): max number of tokens in each batch.
+            Default: ``None``
+        max_sentences (int, optional): max number of sentences in each
+            batch. Default: ``None``
+        required_batch_size_multiple (int, optional): require batch size to
+            be a multiple of N. Default: ``1``
+    """
+    max_tokens = max_tokens if max_tokens is not None else float('Inf')
+    max_sentences = max_sentences if max_sentences is not None else float('Inf')
+    bsz_mult = required_batch_size_multiple
+
+    batch = []
+
+    def is_batch_full(num_tokens):
+        if len(batch) == 0:
+            return False
+        if len(batch) == max_sentences:
+            return True
+        if num_tokens > max_tokens:
+            return True
+        return False
+
+    sample_len = 0
+    sample_lens = []
+    ignored = []
+    for idx in indices:
+        sample_lens.append(num_tokens_fn(idx))
+        sample_len = max(sample_len, sample_lens[-1])
+        num_tokens = (len(batch) + 1) * sample_len
+        if is_batch_full(num_tokens):
+            mod_len = max(
+                bsz_mult * (len(batch) // bsz_mult),
+                len(batch) % bsz_mult,
+            )
+            yield batch[:mod_len]
+            batch = batch[mod_len:]
+            sample_lens = sample_lens[mod_len:]
+            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+
+        batch.append(idx)
+
+    if len(batch) > 0:
+        yield batch
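Taken together, these three helpers replace the old `_batch_generator`; a sketch (not from the commit) of how they compose, mirroring `FairseqTask.get_batch_iterator` further below and assuming `dataset` implements the new FairseqDataset API (`ordered_indices`, `size`, `num_tokens`):

from fairseq.data.data_utils import numpy_seed, filter_by_size, batch_by_size

with numpy_seed(1):
    indices = dataset.ordered_indices()          # deterministic ordering
indices = filter_by_size(indices, dataset.size, max_positions=(1024, 1024))
batch_sampler = batch_by_size(
    indices, dataset.num_tokens,
    max_tokens=4000, required_batch_size_multiple=8,
)
batches = list(batch_sampler)                    # both helpers are lazy generators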
@@ -7,6 +7,8 @@

 import torch.utils.data

+from fairseq.data import data_utils
+

 class FairseqDataset(torch.utils.data.Dataset):
     """A dataset that provides helpers for batching."""
@@ -18,7 +20,14 @@ class FairseqDataset(torch.utils.data.Dataset):
         raise NotImplementedError

     def collater(self, samples):
-        """Merge a list of samples to form a mini-batch."""
+        """Merge a list of samples to form a mini-batch.
+
+        Args:
+            samples (List[int]): sample indices to collate
+
+        Returns:
+            dict: a mini-batch suitable for forwarding with a Model
+        """
         raise NotImplementedError

     def get_dummy_batch(self, num_tokens, max_positions):
@@ -26,13 +35,16 @@ class FairseqDataset(torch.utils.data.Dataset):
         raise NotImplementedError

     def num_tokens(self, index):
-        """Return an example's length (number of tokens), used for batching."""
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
+        raise NotImplementedError
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
         raise NotImplementedError

     def ordered_indices(self):
-        """Ordered indices for batching."""
-        raise NotImplementedError
-
-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
         raise NotImplementedError
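To make the new contract concrete, here is a hypothetical minimal subclass (not from the commit) implementing the methods that the batching helpers in data_utils rely on:

import numpy as np
import torch

from fairseq.data import FairseqDataset


class ToySequenceDataset(FairseqDataset):
    """Illustrative only: a dataset of 1D LongTensor token sequences."""

    def __init__(self, sequences):
        self.sequences = sequences

    def __getitem__(self, index):
        return self.sequences[index]

    def __len__(self):
        return len(self.sequences)

    def collater(self, samples):
        # naive right-padding with zeros
        max_len = max(s.numel() for s in samples)
        batch = torch.zeros(len(samples), max_len, dtype=torch.long)
        for i, s in enumerate(samples):
            batch[i, :s.numel()] = s
        return {
            'net_input': {'src_tokens': batch},
            'ntokens': int(sum(s.numel() for s in samples)),
        }

    def num_tokens(self, index):
        return self.sequences[index].numel()

    def size(self, index):
        return self.sequences[index].numel()

    def ordered_indices(self):
        # sort by length so batches contain similarly-sized sequences
        return np.argsort([s.numel() for s in self.sequences])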
@@ -8,6 +8,8 @@
 import numpy as np
 import torch

+from fairseq import utils
+
 from . import data_utils, FairseqDataset
@@ -59,7 +61,27 @@ def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False):


 class LanguagePairDataset(FairseqDataset):
-    """A pair of torch.utils.data.Datasets."""
+    """
+    A pair of torch.utils.data.Datasets.
+
+    Args:
+        src (torch.utils.data.Dataset): source dataset to wrap
+        src_sizes (List[int]): source sentence lengths
+        src_dict (fairseq.data.Dictionary): source vocabulary
+        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
+        tgt_sizes (List[int], optional): target sentence lengths
+        tgt_dict (fairseq.data.Dictionary, optional): target vocabulary
+        left_pad_source (bool, optional): pad source tensors on the left side.
+            Default: ``True``
+        left_pad_target (bool, optional): pad target tensors on the left side.
+            Default: ``False``
+        max_source_positions (int, optional): max number of tokens in the source
+            sentence. Default: ``1024``
+        max_target_positions (int, optional): max number of tokens in the target
+            sentence. Default: ``1024``
+        shuffle (bool, optional): shuffle dataset elements before batching.
+            Default: ``True``
+    """

     def __init__(
         self, src, src_sizes, src_dict,
@@ -95,15 +117,43 @@ class LanguagePairDataset(FairseqDataset):
         return len(self.src)

     def collater(self, samples):
-        """Merge a list of samples to form a mini-batch."""
+        """Merge a list of samples to form a mini-batch.
+
+        Returned mini-batches contain the following keys:
+
+        - `id` (torch.LongTensor): example IDs in the original input order
+        - `ntokens` (int): total number of tokens in the batch
+        - `net_input` (dict): the input to the Model, containing keys:
+
+          - `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
+            the source sentence of shape `(bsz, src_len)`. Padding will appear
+            on the left if ``left_pad_source`` is True.
+          - `src_lengths` (torch.LongTensor): 1D Tensor of the unpadded lengths
+            of each source sentence of shape `(bsz)`
+          - `prev_output_tokens` (torch.LongTensor): a padded 2D Tensor of
+            tokens in the target sentence, shifted right by one position for
+            input feeding/teacher forcing, of shape `(bsz, tgt_len)`. Padding
+            will appear on the left if ``left_pad_target`` is True.
+
+        - `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
+          target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
+          left if ``left_pad_target`` is True.
+
+        Args:
+            samples (List[dict]): samples to collate
+
+        Returns:
+            dict: a mini-batch suitable for forwarding with a Model
+        """
         return collate(
             samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
             left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
         )

     def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
-        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
-        src_len, tgt_len = min(src_len, max_source_positions), min(tgt_len, max_target_positions)
+        """Return a dummy batch with a given number of tokens."""
+        src_len, tgt_len = utils.resolve_max_positions(
+            (src_len, tgt_len),
+            max_positions,
+            (self.max_source_positions, self.max_target_positions),
+        )
         bsz = num_tokens // max(src_len, tgt_len)
         return self.collater([
             {
@@ -115,11 +165,18 @@ class LanguagePairDataset(FairseqDataset):
         ])

     def num_tokens(self, index):
-        """Return an example's length (number of tokens), used for batching."""
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
         return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
+
     def ordered_indices(self):
-        """Ordered indices for batching."""
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
         if self.shuffle:
             indices = np.random.permutation(len(self))
         else:
@@ -127,18 +184,3 @@ class LanguagePairDataset(FairseqDataset):
         if self.tgt_sizes is not None:
             indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
         return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
-
-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
-        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
-        return (
-            self.src_sizes[index] <= max_source_positions
-            and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
-        )
-
-    def _get_max_positions(self, max_positions):
-        if max_positions is None:
-            return self.max_source_positions, self.max_target_positions
-        assert len(max_positions) == 2
-        max_src_pos, max_tgt_pos = max_positions
-        return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)
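Note how the new `size()` pairs with `filter_by_size`: for this dataset it returns a `(src_len, tgt_len)` tuple, compared component-wise against a `(max_source_positions, max_target_positions)` tuple. A tiny illustration (values made up):

max_positions = (1024, 1024)
example_size = (512, 2000)    # source fits, target does not
keep = all(a <= b for a, b in zip(example_size, max_positions))  # False -> filtered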
@@ -31,7 +31,16 @@ def collate(samples, pad_idx, eos_idx):


 class MonolingualDataset(FairseqDataset):
-    """A wrapper around torch.utils.data.Dataset for monolingual data."""
+    """
+    A wrapper around torch.utils.data.Dataset for monolingual data.
+
+    Args:
+        dataset (torch.utils.data.Dataset): dataset to wrap
+        sizes (List[int]): sentence lengths
+        vocab (fairseq.data.Dictionary): vocabulary
+        shuffle (bool, optional): shuffle the elements before batching.
+            Default: ``True``
+    """

     def __init__(self, dataset, sizes, vocab, shuffle):
         self.dataset = dataset
@@ -47,12 +56,31 @@ class MonolingualDataset(FairseqDataset):
         return len(self.dataset)

     def collater(self, samples):
-        """Merge a list of samples to form a mini-batch."""
+        """Merge a list of samples to form a mini-batch.
+
+        Returned mini-batches contain the following keys:
+
+        - `id` (torch.LongTensor): example IDs in the original input order
+        - `ntokens` (int): total number of tokens in the batch
+        - `net_input` (dict): the input to the Model, containing keys:
+
+          - `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
+            the source sentence of shape `(bsz, src_len)`. Padding will appear
+            on the right.
+
+        - `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
+          target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
+          right.
+
+        Args:
+            samples (List[dict]): samples to collate
+
+        Returns:
+            dict: a mini-batch suitable for forwarding with a Model
+        """
         return collate(samples, self.vocab.pad(), self.vocab.eos())

     def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
-        assert isinstance(max_positions, float) or isinstance(max_positions, int)
-        tgt_len = min(tgt_len, max_positions)
+        """Return a dummy batch with a given number of tokens."""
+        if isinstance(max_positions, float) or isinstance(max_positions, int):
+            tgt_len = min(tgt_len, max_positions)
         bsz = num_tokens // tgt_len
         target = self.vocab.dummy_sentence(tgt_len + 1)
         source, target = target[:-1], target[1:]
@@ -62,19 +90,21 @@ class MonolingualDataset(FairseqDataset):
         ])

     def num_tokens(self, index):
-        """Return an example's length (number of tokens), used for batching."""
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
+        return self.sizes[index]
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
         return self.sizes[index]

     def ordered_indices(self):
-        """Ordered indices for batching."""
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
         if self.shuffle:
             order = [np.random.permutation(len(self))]
         else:
             order = [np.arange(len(self))]
         order.append(np.flip(self.sizes, 0))
         return np.lexsort(order)
-
-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
-        assert isinstance(max_positions, float) or isinstance(max_positions, int)
-        return self.sizes[index] <= max_positions
@@ -5,11 +5,13 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+from fairseq.data import data_utils, FairseqDataset
+

 class FairseqTask(object):
     """
-    A Task defines the data format, stores shared state (e.g., dictionaries) and
-    provides helpers for building the model/criterion and calculating the loss.
+    Tasks store dictionaries and provide helpers for loading/iterating over
+    Datasets, initializing the Model/Criterion and calculating the loss.
     """

     @staticmethod
@@ -37,6 +39,62 @@ class FairseqTask(object):
             raise TypeError('Datasets are expected to be of type FairseqDataset')
         return self.datasets[split]

+    def get_batch_iterator(
+        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
+        ignore_invalid_inputs=False, required_batch_size_multiple=1,
+        seed=1, num_shards=1, shard_id=0,
+    ):
+        """
+        Generate batches of indices.
+
+        Args:
+            dataset (FairseqDataset): dataset to batch
+            max_tokens (int, optional): max number of tokens in each batch.
+                Default: ``None``
+            max_sentences (int, optional): max number of sentences in each
+                batch. Default: ``None``
+            max_positions (optional): max sentence length supported by the
+                model. Default: ``None``
+            ignore_invalid_inputs (bool, optional): don't raise Exception for
+                sentences that are too long. Default: ``False``
+            required_batch_size_multiple (int, optional): require batch size to
+                be a multiple of N. Default: ``1``
+            seed (int, optional): seed for random number generator for
+                reproducibility. Default: ``1``
+            num_shards (int, optional): shard the data iterator into N
+                shards. Default: ``1``
+            shard_id (int, optional): which shard of the data iterator to
+                return. Default: ``0``
+
+        Returns:
+            EpochBatchIterator: a batched iterator over the given dataset split
+        """
+        assert isinstance(dataset, FairseqDataset)
+
+        # get indices ordered by example size
+        with data_utils.numpy_seed(seed):
+            indices = dataset.ordered_indices()
+
+        # filter examples that are too large
+        indices = data_utils.filter_by_size(
+            indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
+        )
+
+        # create mini-batches with given size constraints
+        batch_sampler = data_utils.batch_by_size(
+            indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
+            required_batch_size_multiple=required_batch_size_multiple,
+        )
+
+        # return a reusable, sharded iterator
+        return data_utils.EpochBatchIterator(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            seed=seed,
+            num_shards=num_shards,
+            shard_id=shard_id,
+        )
+
     def build_model(self, args):
         from fairseq import models
         return models.build_model(args, self)
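End to end, callers now go through the task; a hedged sketch of the intended call pattern (argument values illustrative; `args` and `model` are assumed to be built elsewhere, as in the train.py changes below):

from fairseq import tasks, utils

task = tasks.setup_task(args)
task.load_dataset('train')
epoch_itr = task.get_batch_iterator(
    dataset=task.dataset('train'),
    max_tokens=4000,                      # illustrative token budget
    max_positions=utils.resolve_max_positions(
        task.max_positions(),             # task-level limit (may be None)
        model.max_positions(),            # model-level limit
    ),
    ignore_invalid_inputs=True,
)
for batch in epoch_itr.next_epoch_itr(shuffle=True):
    pass  # one training step per mini-batch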
@@ -48,6 +106,9 @@ class FairseqTask(object):
     def get_loss(self, model, criterion, sample):
         return criterion(model, sample)

+    def max_positions(self):
+        return None
+
     @property
     def source_dictionary(self):
         raise NotImplementedError
@@ -139,6 +139,9 @@ class TranslationTask(FairseqTask):
             max_target_positions=self.args.max_target_positions,
         )

+    def max_positions(self):
+        return (self.args.max_source_positions, self.args.max_target_positions)
+
     @property
     def source_dictionary(self):
         return self.src_dict
@@ -150,7 +150,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
     ensemble = []
     for state in states:
         args = state['args']

         if model_arg_overrides is not None:
             args = _override_model_args(args, model_arg_overrides)
@@ -399,3 +399,17 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
         idx = int(m.group(1)) if len(m.groups()) > 0 else i
         entries.append((idx, m.group(0)))
     return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
+
+
+def resolve_max_positions(*args):
+    """Resolve max position constraints from multiple sources."""
+    max_positions = None
+    for arg in args:
+        if max_positions is None:
+            max_positions = arg
+        elif arg is not None:
+            if isinstance(arg, float) or isinstance(arg, int):
+                max_positions = min(max_positions, arg)
+            else:
+                max_positions = tuple(map(min, zip(max_positions, arg)))
+    return max_positions
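`resolve_max_positions` takes the element-wise minimum across any number of constraints, treating ``None`` as "no constraint"; scalars are compared with ``min`` and tuples component-wise (mixing a scalar with a tuple is not supported in this version). Worked examples (values illustrative):

resolve_max_positions(None, (1024, 1024))          # -> (1024, 1024)
resolve_max_positions((1024, 1024), (512, 2048))   # -> (512, 1024)
resolve_max_positions(100000, 2048, 1024)          # -> 1024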
@@ -54,11 +54,14 @@ def main(args):
     align_dict = utils.load_align_dict(args.replace_unk)

     # Load dataset (possibly sharded)
-    itr = data.EpochBatchIterator(
+    itr = task.get_batch_iterator(
         dataset=task.dataset(args.gen_subset),
         max_tokens=args.max_tokens,
         max_sentences=args.max_sentences,
-        max_positions=models[0].max_positions(),
+        max_positions=utils.resolve_max_positions(
+            task.max_positions(),
+            *[model.max_positions() for model in models]
+        ),
         ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
         required_batch_size_multiple=8,
         num_shards=args.num_shards,
@@ -32,14 +32,14 @@ def buffered_read(buffer_size):
         yield buffer


-def make_batches(lines, args, src_dict, max_positions):
+def make_batches(lines, args, task, max_positions):
     tokens = [
-        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
+        tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
         for src_str in lines
     ]
     lengths = np.array([t.numel() for t in tokens])
-    itr = data.EpochBatchIterator(
-        dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
+    itr = task.get_batch_iterator(
+        dataset=data.LanguagePairDataset(tokens, lengths, task.source_dictionary),
         max_tokens=args.max_tokens,
         max_sentences=args.max_sentences,
         max_positions=max_positions,
@@ -76,7 +76,6 @@ def main(args):
     models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides))

     # Set dictionaries
-    src_dict = task.source_dictionary
     tgt_dict = task.target_dictionary

     # Optimize ensemble for generation
@@ -151,13 +150,18 @@ def main(args):

         return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]

+    max_positions = utils.resolve_max_positions(
+        task.max_positions(),
+        *[model.max_positions() for model in models]
+    )
+
     if args.buffer_size > 1:
         print('| Sentence buffer size:', args.buffer_size)
     print('| Type the input sentence and press return:')
     for inputs in buffered_read(args.buffer_size):
         indices = []
         results = []
-        for batch, batch_indices in make_batches(inputs, args, src_dict, models[0].max_positions()):
+        for batch, batch_indices in make_batches(inputs, args, task, max_positions):
             indices.extend(batch_indices)
             results += process_batch(batch)
@@ -44,7 +44,7 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
     trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
     epoch_itr = data.EpochBatchIterator(
         dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
-        max_tokens=1,
+        batch_sampler=[[i] for i in range(epoch_size)],
     )
     return trainer, epoch_itr
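The test above also exercises the iterator's resumability; a hypothetical round-trip (assuming `dataset` is a collater-providing dataset such as the toy sketch earlier):

from fairseq import data

epoch_itr = data.EpochBatchIterator(
    dataset=dataset,
    batch_sampler=[[i] for i in range(100)],
)
itr = epoch_itr.next_epoch_itr()
next(itr)                              # consume one batch
state = epoch_itr.state_dict()         # records epoch and iterations_in_epoch

restored = data.EpochBatchIterator(
    dataset=dataset,
    batch_sampler=[[i] for i in range(100)],
)
restored.load_state_dict(state)        # resume mid-epoch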
train.py
@@ -12,7 +12,7 @@ import os
 import math
 import torch

-from fairseq import data, distributed_utils, options, progress_bar, tasks, utils
+from fairseq import distributed_utils, options, progress_bar, tasks, utils
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
@@ -57,11 +57,14 @@ def main(args):
     ))

     # Initialize dataloader
-    max_positions = trainer.get_model().max_positions()
-    epoch_itr = data.EpochBatchIterator(
+    max_positions = utils.resolve_max_positions(
+        task.max_positions(),
+        trainer.get_model().max_positions(),
+    )
+    epoch_itr = task.get_batch_iterator(
         dataset=task.dataset(args.train_subset),
         max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences_valid,
+        max_sentences=args.max_sentences,
         max_positions=max_positions,
         ignore_invalid_inputs=True,
         required_batch_size_multiple=8,
@@ -193,11 +196,14 @@ def validate(args, trainer, task, epoch_itr, subsets):
     valid_losses = []
     for subset in subsets:
         # Initialize data iterator
-        itr = data.EpochBatchIterator(
+        itr = task.get_batch_iterator(
             dataset=task.dataset(subset),
             max_tokens=args.max_tokens,
             max_sentences=args.max_sentences_valid,
-            max_positions=trainer.get_model().max_positions(),
+            max_positions=utils.resolve_max_positions(
+                task.max_positions(),
+                trainer.get_model().max_positions(),
+            ),
             ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
             required_batch_size_multiple=8,
             seed=args.seed,