Clean up FairseqTask so that it's easier to extend/add new tasks

Myle Ott 2018-08-30 12:17:33 -04:00
parent 6296de825a
commit 2e507d3cb4
12 changed files with 373 additions and 137 deletions

View File

@ -59,11 +59,13 @@ def main(parsed_args):
assert len(models) > 0
itr = data.EpochBatchIterator(
itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens or 36000,
max_sentences=args.max_sentences,
max_positions=models[0].max_positions(),
max_positions=utils.resolve_max_positions(*[
model.max_positions() for model in models
]),
num_shards=args.num_shards,
shard_id=args.shard_id,
ignore_invalid_inputs=True,

View File

@ -12,8 +12,6 @@ import os
import numpy as np
import torch
from . import FairseqDataset
def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
@ -99,42 +97,35 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=Fal
class EpochBatchIterator(object):
"""Iterate over a FairseqDataset and yield batches bucketed by size.
"""A multi-epoch iterator over a :class:`~torch.utils.data.Dataset`.
Batches may contain sequences of different lengths. This iterator can be
reused across multiple epochs with the next_epoch_itr() method.
Compared to :class:`~torch.utils.data.DataLoader`, this iterator:
- can be reused across multiple epochs with the :func:`next_epoch_itr`
method (optionally shuffled between epochs)
- can be serialized/deserialized with the :func:`state_dict` and
:func:`load_state_dict` methods
- supports sharding with the ``num_shards`` and ``shard_id`` arguments
Args:
dataset: a FairseqDataset
max_tokens: max number of tokens in each batch
max_sentences: max number of sentences in each batch
max_positions: max sentence length supported by the model
ignore_invalid_inputs: don't raise Exception for sentences that are too long
required_batch_size_multiple: require batch size to be a multiple of N
seed: seed for random number generator for reproducibility
num_shards: shard the data iterator into N shards
shard_id: which shard of the data iterator to return
dataset (Dataset): dataset from which to load the data
batch_sampler (Sampler): an iterator over batches of indices
seed (int, optional): seed for random number generator for
reproducibility. Default: ``1``
num_shards (int, optional): shard the data iterator into N
shards. Default: ``1``
shard_id (int, optional): which shard of the data iterator to
return. Default: ``0``
"""
def __init__(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1,
num_shards=1, shard_id=0,
):
assert isinstance(dataset, FairseqDataset)
def __init__(self, dataset, batch_sampler, seed=1, num_shards=1, shard_id=0):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
self.max_positions = max_positions
self.ignore_invalid_inputs = ignore_invalid_inputs
self.bsz_mult = required_batch_size_multiple
self.frozen_batches = tuple(batch_sampler)
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
with numpy_seed(self.seed):
self.frozen_batches = tuple(self._batch_generator())
self.epoch = 0
self._cur_epoch_itr = None
self._next_epoch_itr = None
@ -143,7 +134,13 @@ class EpochBatchIterator(object):
return len(self.frozen_batches)
def next_epoch_itr(self, shuffle=True):
"""Shuffle batches and return a new iterator over the dataset."""
"""
Return a new iterator over the dataset.
Args:
shuffle (bool, optional): shuffle batches before returning the
iterator. Default: ``True``
"""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
@ -153,10 +150,12 @@ class EpochBatchIterator(object):
return self._cur_epoch_itr
def end_of_epoch(self):
"""Returns whether the most recent epoch iterator has been exhausted"""
return not self._cur_epoch_itr.has_next()
@property
def iterations_in_epoch(self):
"""The number of consumed batches in the current epoch."""
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.count
elif self._next_epoch_itr is not None:
@ -193,55 +192,6 @@ class EpochBatchIterator(object):
batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
))
def _batch_generator(self):
batch = []
def is_batch_full(num_tokens):
if len(batch) == 0:
return False
if len(batch) == self.max_sentences:
return True
if num_tokens > self.max_tokens:
return True
return False
sample_len = 0
sample_lens = []
ignored = []
for idx in self.dataset.ordered_indices():
if not self.dataset.valid_size(idx, self.max_positions):
if self.ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception((
'Size of sample #{} is invalid, max_positions={}, skip this '
'example with --skip-invalid-size-inputs-valid-test'
).format(idx, self.max_positions))
sample_lens.append(self.dataset.num_tokens(idx))
sample_len = max(sample_len, sample_lens[-1])
num_tokens = (len(batch) + 1) * sample_len
if is_batch_full(num_tokens):
mod_len = max(
self.bsz_mult * (len(batch) // self.bsz_mult),
len(batch) % self.bsz_mult,
)
yield batch[:mod_len]
batch = batch[mod_len:]
sample_lens = sample_lens[mod_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
yield batch
if len(ignored) > 0:
print((
'| WARNING: {} samples have invalid sizes and will be skipped, '
'max_positions={}, first few sample ids={}'
).format(len(ignored), self.max_positions, ignored[:10]))
@contextlib.contextmanager
def numpy_seed(seed):
@ -256,3 +206,112 @@ def numpy_seed(seed):
yield
finally:
np.random.set_state(state)
def collect_filtered(function, iterable, filtered):
"""
Similar to :func:`filter` but collects filtered elements in ``filtered``.
Args:
function (callable): function that returns ``False`` for elements that
should be filtered
iterable (iterable): iterable to filter
filtered (list): list to store filtered elements
"""
for el in iterable:
if function(el):
yield el
else:
filtered.append(el)
def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
"""
Filter indices based on their size.
Args:
indices (List[int]): ordered list of dataset indices
size_fn (callable): function that returns the size of a given index
max_positions (tuple): filter elements larger than this size.
Comparisons are done component-wise.
raise_exception (bool, optional): if ``True``, raise an exception
if any elements are filtered. Default: ``False``
"""
def check_size(idx):
if isinstance(max_positions, float) or isinstance(max_positions, int):
return size_fn(idx) < max_positions
else:
return all(a <= b for a, b in zip(size_fn(idx), max_positions))
ignored = []
itr = collect_filtered(check_size, indices, ignored)
for idx in itr:
if len(ignored) > 0 and raise_exception:
raise Exception((
'Size of sample #{} is invalid (={}) since max_positions={}, '
'skip this example with --skip-invalid-size-inputs-valid-test'
).format(idx, size_fn(idx), max_positions))
yield idx
if len(ignored) > 0:
print((
'| WARNING: {} samples have invalid sizes and will be skipped, '
'max_positions={}, first few sample ids={}'
).format(len(ignored), max_positions, ignored[:10]))
def batch_by_size(
indices, num_tokens_fn, max_tokens=None, max_sentences=None,
required_batch_size_multiple=1,
):
"""
Yield mini-batches of indices bucketed by size. Batches may contain
sequences of different lengths.
Args:
indices (List[int]): ordered list of dataset indices
num_tokens_fn (callable): function that returns the number of tokens at
a given index
max_tokens (int, optional): max number of tokens in each batch.
Default: ``None``
max_sentences (int, optional): max number of sentences in each
batch. Default: ``None``
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N. Default: ``1``
"""
max_tokens = max_tokens if max_tokens is not None else float('Inf')
max_sentences = max_sentences if max_sentences is not None else float('Inf')
bsz_mult = required_batch_size_multiple
batch = []
def is_batch_full(num_tokens):
if len(batch) == 0:
return False
if len(batch) == max_sentences:
return True
if num_tokens > max_tokens:
return True
return False
sample_len = 0
sample_lens = []
ignored = []
for idx in indices:
sample_lens.append(num_tokens_fn(idx))
sample_len = max(sample_len, sample_lens[-1])
num_tokens = (len(batch) + 1) * sample_len
if is_batch_full(num_tokens):
mod_len = max(
bsz_mult * (len(batch) // bsz_mult),
len(batch) % bsz_mult,
)
yield batch[:mod_len]
batch = batch[mod_len:]
sample_lens = sample_lens[mod_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
yield batch
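
Editor's note: taken together, the new module-level helpers replace the old monolithic _batch_generator. The sketch below is not part of the commit; the function name, token/sentence limits, and the assumption that the dataset reports scalar sizes are all illustrative. It shows how the helpers compose, which is essentially what FairseqTask.get_batch_iterator does further down in this diff:

from fairseq.data import data_utils

def make_epoch_iterator(dataset, max_tokens=4000, max_sentences=None,
                        max_positions=1024, seed=1):
    # order examples deterministically for this seed
    with data_utils.numpy_seed(seed):
        indices = dataset.ordered_indices()
    # drop (rather than raise on) examples that exceed max_positions
    indices = data_utils.filter_by_size(
        indices, dataset.size, max_positions, raise_exception=False,
    )
    # bucket the surviving indices into batches under the token/sentence caps
    batch_sampler = data_utils.batch_by_size(
        indices, dataset.num_tokens,
        max_tokens=max_tokens, max_sentences=max_sentences,
        required_batch_size_multiple=8,
    )
    # wrap everything in the reusable, shardable epoch iterator
    return data_utils.EpochBatchIterator(
        dataset=dataset, batch_sampler=batch_sampler, seed=seed,
    )

# epoch_itr = make_epoch_iterator(my_dataset)
# for batch in epoch_itr.next_epoch_itr(shuffle=True):
#     ...                                    # train on the batch
# state = epoch_itr.state_dict()             # serialize mid-epoch progress
# epoch_itr.load_state_dict(state)           # and restore it later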

View File

@ -7,6 +7,8 @@
import torch.utils.data
from fairseq.data import data_utils
class FairseqDataset(torch.utils.data.Dataset):
"""A dataset that provides helpers for batching."""
@ -18,7 +20,14 @@ class FairseqDataset(torch.utils.data.Dataset):
raise NotImplementedError
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch suitable for forwarding with a Model
"""
raise NotImplementedError
def get_dummy_batch(self, num_tokens, max_positions):
@ -26,13 +35,16 @@ class FairseqDataset(torch.utils.data.Dataset):
raise NotImplementedError
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
raise NotImplementedError
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
raise NotImplementedError
def ordered_indices(self):
"""Ordered indices for batching."""
raise NotImplementedError
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
raise NotImplementedError
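
Editor's note: for reference, a minimal hypothetical subclass satisfying this interface might look like the following. It is illustrative only (pad index 0 is assumed, the batch dict is deliberately small, and get_dummy_batch is omitted):

import numpy as np
import torch
from fairseq.data import FairseqDataset

class ListDataset(FairseqDataset):
    """Toy dataset over a list of 1D LongTensors (pad index 0 assumed)."""

    def __init__(self, tensors):
        self.tensors = tensors
        self.sizes = np.array([t.numel() for t in tensors])

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def collater(self, samples):
        # right-pad every sample to the longest one in the mini-batch
        max_len = max(s.numel() for s in samples)
        tokens = torch.zeros(len(samples), max_len, dtype=torch.long)
        for i, s in enumerate(samples):
            tokens[i, :s.numel()] = s
        return {
            'net_input': {'src_tokens': tokens},
            'ntokens': sum(s.numel() for s in samples),
        }

    def num_tokens(self, index):
        return self.sizes[index]  # used to enforce --max-tokens

    def size(self, index):
        return self.sizes[index]  # compared against --max-positions

    def ordered_indices(self):
        # sort by length so each batch contains similarly sized examples
        return np.argsort(self.sizes, kind='mergesort')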

View File

@ -8,6 +8,8 @@
import numpy as np
import torch
from fairseq import utils
from . import data_utils, FairseqDataset
@ -59,7 +61,27 @@ def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=Fal
class LanguagePairDataset(FairseqDataset):
"""A pair of torch.utils.data.Datasets."""
"""
A pair of torch.utils.data.Datasets.
Args:
src (torch.utils.data.Dataset): source dataset to wrap
src_sizes (List[int]): source sentence lengths
src_dict (fairseq.data.Dictionary): source vocabulary
tgt (torch.utils.data.Dataset, optional): target dataset to wrap
tgt_sizes (List[int], optional): target sentence lengths
tgt_dict (fairseq.data.Dictionary, optional): target vocabulary
left_pad_source (bool, optional): pad source tensors on the left side.
Default: ``True``
left_pad_target (bool, optional): pad target tensors on the left side.
Default: ``False``
max_source_positions (int, optional): max number of tokens in the source
sentence. Default: ``1024``
max_target_positions (int, optional): max number of tokens in the target
sentence. Default: ``1024``
shuffle (bool, optional): shuffle dataset elements before batching.
Default: ``True``
"""
def __init__(
self, src, src_sizes, src_dict,
@ -95,15 +117,43 @@ class LanguagePairDataset(FairseqDataset):
return len(self.src)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
"""Merge a list of samples to form a mini-batch.
Returned mini-batches contain the following keys:
- `id` (torch.LongTensor): example IDs in the original input order
- `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys:
- `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
the source sentence of shape `(bsz, src_len)`. Padding will appear
on the left if ``left_pad_source`` is True.
- `src_lengths` (torch.LongTensor): 1D Tensor of the unpadded lengths
of each source sentence of shape `(bsz)`
- `prev_output_tokens` (torch.LongTensor): a padded 2D Tensor of
tokens in the target sentence, shifted right by one position for
input feeding/teacher forcing, of shape `(bsz, tgt_len)`. Padding
will appear on the left if ``left_pad_target`` is True.
- `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
left if ``left_pad_target`` is True.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch suitable for forwarding with a Model
"""
return collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
)
def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
src_len, tgt_len = min(src_len, max_source_positions), min(tgt_len, max_target_positions)
"""Return a dummy batch with a given number of tokens."""
src_len, tgt_len = utils.resolve_max_positions(
(src_len, tgt_len),
max_positions,
(self.max_source_positions, self.max_target_positions),
)
bsz = num_tokens // max(src_len, tgt_len)
return self.collater([
{
@ -115,11 +165,18 @@ class LanguagePairDataset(FairseqDataset):
])
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
def ordered_indices(self):
"""Ordered indices for batching."""
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
indices = np.random.permutation(len(self))
else:
@ -127,18 +184,3 @@ class LanguagePairDataset(FairseqDataset):
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
return (
self.src_sizes[index] <= max_source_positions
and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
)
def _get_max_positions(self, max_positions):
if max_positions is None:
return self.max_source_positions, self.max_target_positions
assert len(max_positions) == 2
max_src_pos, max_tgt_pos = max_positions
return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)
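
Editor's note: a hedged illustration of consuming a collated mini-batch; the helper below is hypothetical and not part of the commit, but the keys it uses follow the collater docstring above:

def forward_batch(model, dataset, indices):
    """Hypothetical helper: collate a few examples and run them through a model."""
    batch = dataset.collater([dataset[i] for i in indices])
    # net_input carries src_tokens, src_lengths and prev_output_tokens, as documented above
    net_output = model(**batch['net_input'])
    lprobs = model.get_normalized_probs(net_output, log_probs=True)
    return lprobs, batch['target']  # target is (bsz, tgt_len), padded per left_pad_target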

View File

@ -31,7 +31,16 @@ def collate(samples, pad_idx, eos_idx):
class MonolingualDataset(FairseqDataset):
"""A wrapper around torch.utils.data.Dataset for monolingual data."""
"""
A wrapper around torch.utils.data.Dataset for monolingual data.
Args:
dataset (torch.utils.data.Dataset): dataset to wrap
sizes (List[int]): sentence lengths
vocab (fairseq.data.Dictionary): vocabulary
shuffle (bool, optional): shuffle the elements before batching.
Default: ``True``
"""
def __init__(self, dataset, sizes, vocab, shuffle):
self.dataset = dataset
@ -47,12 +56,31 @@ class MonolingualDataset(FairseqDataset):
return len(self.dataset)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
"""Merge a list of samples to form a mini-batch.
Returned mini-batches contain the following keys:
- `id` (torch.LongTensor): example IDs in the original input order
- `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys:
- `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
the source sentence of shape `(bsz, src_len)`. Padding will appear
on the right.
- `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
right.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch suitable for forwarding with a Model
"""
return collate(samples, self.vocab.pad(), self.vocab.eos())
def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
assert isinstance(max_positions, float) or isinstance(max_positions, int)
tgt_len = min(tgt_len, max_positions)
"""Return a dummy batch with a given number of tokens."""
if isinstance(max_positions, float) or isinstance(max_positions, int):
tgt_len = min(tgt_len, max_positions)
bsz = num_tokens // tgt_len
target = self.vocab.dummy_sentence(tgt_len + 1)
source, target = target[:-1], target[1:]
@ -62,19 +90,21 @@ class MonolingualDataset(FairseqDataset):
])
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return self.sizes[index]
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return self.sizes[index]
def ordered_indices(self):
"""Ordered indices for batching."""
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(np.flip(self.sizes, 0))
return np.lexsort(order)
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
assert isinstance(max_positions, float) or isinstance(max_positions, int)
return self.sizes[index] <= max_positions

View File

@ -5,11 +5,13 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.data import data_utils, FairseqDataset
class FairseqTask(object):
"""
A Task defines the data format, stores shared state (e.g., dictionaries) and
provides helpers for building the model/criterion and calculating the loss.
Tasks store dictionaries and provide helpers for loading/iterating over
Datasets, initializing the Model/Criterion and calculating the loss.
"""
@staticmethod
@ -37,6 +39,62 @@ class FairseqTask(object):
raise TypeError('Datasets are expected to be of type FairseqDataset')
return self.datasets[split]
def get_batch_iterator(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1,
seed=1, num_shards=1, shard_id=0,
):
"""
Get an iterator that yields batches of data from the given dataset.
Args:
dataset (FairseqDataset): dataset to batch
max_tokens (int, optional): max number of tokens in each batch.
Default: ``None``
max_sentences (int, optional): max number of sentences in each
batch. Default: ``None``
max_positions (optional): max sentence length supported by the
model. Default: ``None``
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long. Default: ``False``
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N. Default: ``1``
seed (int, optional): seed for random number generator for
reproducibility. Default: ``1``
num_shards (int, optional): shard the data iterator into N
shards. Default: ``1``
shard_id (int, optional): which shard of the data iterator to
return. Default: ``0``
Returns:
EpochBatchIterator: a batched iterator over the given dataset split
"""
assert isinstance(dataset, FairseqDataset)
# get indices ordered by example size
with data_utils.numpy_seed(seed):
indices = dataset.ordered_indices()
# filter examples that are too large
indices = data_utils.filter_by_size(
indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
)
# create mini-batches with given size constraints
batch_sampler = data_utils.batch_by_size(
indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
# return a reusable, sharded iterator
return data_utils.EpochBatchIterator(
dataset=dataset,
batch_sampler=batch_sampler,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
)
def build_model(self, args):
from fairseq import models
return models.build_model(args, self)
@ -48,6 +106,9 @@ class FairseqTask(object):
def get_loss(self, model, criterion, sample):
return criterion(model, sample)
def max_positions(self):
return None
@property
def source_dictionary(self):
raise NotImplementedError
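
Editor's note: a sketch of the call pattern this commit introduces at the entry points (compare generate.py and train.py below). The argument names mirror the diff; the helper itself and the elided training-loop body are illustrative assumptions:

from fairseq import utils

def train_one_epoch(args, task, model, criterion):
    # task.load_dataset(args.train_subset) is assumed to have been called already
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            model.max_positions(),
        ),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
    )
    for sample in epoch_itr.next_epoch_itr(shuffle=True):
        # get_loss simply delegates to the criterion (see above)
        loss, sample_size, logging_output = task.get_loss(model, criterion, sample)
        # backward pass and optimizer step omitted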

View File

@ -139,6 +139,9 @@ class TranslationTask(FairseqTask):
max_target_positions=self.args.max_target_positions,
)
def max_positions(self):
return (self.args.max_source_positions, self.args.max_target_positions)
@property
def source_dictionary(self):
return self.src_dict

View File

@ -150,7 +150,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
ensemble = []
for state in states:
args = state['args']
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
@ -399,3 +399,17 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
idx = int(m.group(1)) if len(m.groups()) > 0 else i
entries.append((idx, m.group(0)))
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
def resolve_max_positions(*args):
"""Resolve max position constraints from multiple sources."""
max_positions = None
for arg in args:
if max_positions is None:
max_positions = arg
elif arg is not None:
if isinstance(arg, float) or isinstance(arg, int):
max_positions = min(max_positions, arg)
else:
max_positions = tuple(map(min, zip(max_positions, arg)))
return max_positions
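
Editor's note: illustrative behavior of resolve_max_positions, assuming all non-None arguments share the same shape (scalars with scalars, tuples with tuples), as at the call sites in this commit:

from fairseq.utils import resolve_max_positions

# translation models report (max_source_positions, max_target_positions);
# None entries (e.g. a task without limits) are simply skipped
assert resolve_max_positions(None, (1024, 1024), (512, 2048)) == (512, 1024)
assert resolve_max_positions(4096, 1024) == 1024
assert resolve_max_positions(None) is None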

View File

@ -54,11 +54,14 @@ def main(args):
align_dict = utils.load_align_dict(args.replace_unk)
# Load dataset (possibly sharded)
itr = data.EpochBatchIterator(
itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=models[0].max_positions(),
max_positions=utils.resolve_max_positions(
task.max_positions(),
*[model.max_positions() for model in models]
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
num_shards=args.num_shards,

View File

@ -32,14 +32,14 @@ def buffered_read(buffer_size):
yield buffer
def make_batches(lines, args, src_dict, max_positions):
def make_batches(lines, args, task, max_positions):
tokens = [
tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
for src_str in lines
]
lengths = np.array([t.numel() for t in tokens])
itr = data.EpochBatchIterator(
dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
itr = task.get_batch_iterator(
dataset=data.LanguagePairDataset(tokens, lengths, task.source_dictionary),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
@ -76,7 +76,6 @@ def main(args):
models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides))
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Optimize ensemble for generation
@ -151,13 +150,18 @@ def main(args):
return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
max_positions = utils.resolve_max_positions(
task.max_positions(),
*[model.max_positions() for model in models]
)
if args.buffer_size > 1:
print('| Sentence buffer size:', args.buffer_size)
print('| Type the input sentence and press return:')
for inputs in buffered_read(args.buffer_size):
indices = []
results = []
for batch, batch_indices in make_batches(inputs, args, src_dict, models[0].max_positions()):
for batch, batch_indices in make_batches(inputs, args, task, max_positions):
indices.extend(batch_indices)
results += process_batch(batch)

View File

@ -44,7 +44,7 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
epoch_itr = data.EpochBatchIterator(
dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
max_tokens=1,
batch_sampler=[[i] for i in range(epoch_size)],
)
return trainer, epoch_itr

View File

@ -12,7 +12,7 @@ import os
import math
import torch
from fairseq import data, distributed_utils, options, progress_bar, tasks, utils
from fairseq import distributed_utils, options, progress_bar, tasks, utils
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
@ -57,11 +57,14 @@ def main(args):
))
# Initialize dataloader
max_positions = trainer.get_model().max_positions()
epoch_itr = data.EpochBatchIterator(
max_positions = utils.resolve_max_positions(
task.max_positions(),
trainer.get_model().max_positions(),
)
epoch_itr = task.get_batch_iterator(
dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences_valid,
max_sentences=args.max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=True,
required_batch_size_multiple=8,
@ -193,11 +196,14 @@ def validate(args, trainer, task, epoch_itr, subsets):
valid_losses = []
for subset in subsets:
# Initialize data iterator
itr = data.EpochBatchIterator(
itr = task.get_batch_iterator(
dataset=task.dataset(subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences_valid,
max_positions=trainer.get_model().max_positions(),
max_positions=utils.resolve_max_positions(
task.max_positions(),
trainer.get_model().max_positions(),
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
seed=args.seed,