Deprecate dummy_batch (#699)

Summary:
It was tedious to define these; let's try just lazily taking the first batch instead.
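
A rough, hypothetical sketch of that idea (the class name LazyDummyBatchTrainer and the toy batch are made up; the real change is the trainer.py hunk below): instead of every dataset synthesizing a get_dummy_batch(), the trainer caches the first real batch it sees and reuses it as the dummy/OOM placeholder.

class LazyDummyBatchTrainer:
    def __init__(self, dummy_batch=None, oom_batch=None):
        self._dummy_batch = dummy_batch            # may stay None until the first step
        self._oom_batch = oom_batch or dummy_batch

    def train_step(self, samples):
        # Lazily adopt the first batch as the placeholder that warms the caching
        # allocator and keeps DistributedDataParallel workers in step.
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]
        return self._dummy_batch  # the real forward/backward/update would go here

trainer = LazyDummyBatchTrainer()
trainer.train_step([{'id': 0, 'net_input': {}}])
assert trainer._dummy_batch is not None
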
Pull Request resolved: https://github.com/pytorch/fairseq/pull/699

Differential Revision: D15188266

Pulled By: myleott

fbshipit-source-id: a4c9f7ee3111278faaffa8a22ba91ed5f50e143d
Myle Ott 2019-05-04 16:31:00 -07:00 committed by Facebook Github Bot
parent 7a5996fdc7
commit fc1a19a38d
12 changed files with 12 additions and 128 deletions

View File

@@ -10,7 +10,6 @@ import torch
from fairseq import utils
from . import FairseqDataset
from .language_pair_dataset import collate as language_pair_collate, generate_dummy_batch
def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
@@ -141,19 +140,6 @@ class BacktranslationDataset(FairseqDataset):
        )
        return self.output_collater(samples)
    def get_dummy_batch(self, num_tokens, max_positions):
        """Just use the tgt dataset get_dummy_batch"""
        def collate_fn(samples):
            return language_pair_collate(
                samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
                input_feeding=True,
            )
        dummy_batch = generate_dummy_batch(
            num_tokens, collate_fn,
            self.src_dict, tgt_dict=self.tgt_dict)
        dummy_batch['is_dummy'] = True
        return dummy_batch
    def num_tokens(self, index):
        """Just use the tgt dataset num_tokens"""
        return self.tgt_dataset.num_tokens(index)

View File

@@ -28,10 +28,6 @@ class FairseqDataset(torch.utils.data.Dataset):
        """
        raise NotImplementedError
    def get_dummy_batch(self, num_tokens, max_positions):
        """Return a dummy batch with a given number of tokens."""
        raise NotImplementedError
    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""

View File

@@ -68,19 +68,6 @@ def collate(
    return batch
def generate_dummy_batch(num_tokens, collate_fn, src_dict, src_len=128, tgt_dict=None, tgt_len=128):
    """Return a dummy batch with a given number of tokens."""
    bsz = num_tokens // max(src_len, tgt_len)
    return collate_fn([
        {
            'id': i,
            'source': src_dict.dummy_sentence(src_len),
            'target': tgt_dict.dummy_sentence(tgt_len) if tgt_dict is not None else None,
        }
        for i in range(bsz)
    ])
class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.
@@ -198,15 +185,6 @@ class LanguagePairDataset(FairseqDataset):
            input_feeding=self.input_feeding,
        )
    def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
        """Return a dummy batch with a given number of tokens."""
        src_len, tgt_len = utils.resolve_max_positions(
            (src_len, tgt_len),
            max_positions,
            (self.max_source_positions, self.max_target_positions),
        )
        return generate_dummy_batch(num_tokens, self.collater, self.src_dict, src_len, self.tgt_dict, tgt_len)
    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""

View File

@@ -62,9 +62,6 @@ class LMContextWindowDataset(FairseqDataset):
        return sample
    def get_dummy_batch(self, *args, **kwargs):
        return self.dataset.get_dummy_batch(*args, **kwargs)
    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

View File

@@ -294,33 +294,6 @@ class MaskedLMDataset(FairseqDataset):
        """
        return self._collate(samples, self.vocab.pad(), self.vocab.eos())
    def get_dummy_batch(
        self,
        num_tokens: int,
        max_positions: Union[float, int],
        tgt_len: int = 12
    ):
        """
        Return a dummy batch with a given number of tokens.
        """
        if isinstance(max_positions, float) or isinstance(max_positions, int):
            tgt_len = min(tgt_len, max_positions)
        source = self.vocab.dummy_sentence(tgt_len)
        sentence_target = 0
        bsz = num_tokens // tgt_len
        return self.collater(
            [
                {
                    "id": i,
                    "block_one": source,
                    "block_two": source if self.has_pairs else None,
                    "sentence_target": sentence_target if self.has_pairs else None,
                }
                for i in range(bsz)
            ]
        )
    def num_tokens(
        self,
        index: int
@@ -358,4 +331,4 @@ class MaskedLMDataset(FairseqDataset):
        return getattr(self.dataset, "supports_prefetch", False)
    def prefetch(self, indices):
        self.dataset.prefetch(indices)
        self.dataset.prefetch(indices)

View File

@@ -174,21 +174,6 @@ class MonolingualDataset(FairseqDataset):
        """
        return collate(samples, self.vocab.pad(), self.vocab.eos())
    def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
        """Return a dummy batch with a given number of tokens."""
        if isinstance(max_positions, float) or isinstance(max_positions, int):
            tgt_len = min(tgt_len, max_positions)
        bsz = max(num_tokens // tgt_len, 1)
        target = self.vocab.dummy_sentence(tgt_len + 2)
        source, past_target, future_target = target[1:-1], target[2:], target[:-2]
        source, target = self._make_source_target(source, past_target, future_target)
        source, target = self._maybe_add_bos(source, target)
        return self.collater([
            {'id': i, 'source': source, 'target': target}
            for i in range(bsz)
        ])
    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""

View File

@@ -116,15 +116,6 @@ class MultiCorpusSampledDataset(FairseqDataset):
        selected_samples = [sample[selected_key] for sample in samples]
        return self.datasets[selected_key].collater(selected_samples)
    def get_dummy_batch(self, num_tokens: int, max_positions: int):
        """
        Return a dummy batch with a given number of tokens. Assumes that the
        max_positions specified is the same for all underlying datasets.
        """
        return self.datasets[self.default_key].get_dummy_batch(
            num_tokens, max_positions
        )
    def num_tokens(self, index: int):
        """
        Return an example's length (number of tokens), used for batching. Here

View File

@@ -72,17 +72,6 @@ class RoundRobinZipDatasets(FairseqDataset):
            # at evaluation time it's useful to pass-through batches from a single key
            return self.datasets[self.eval_key].collater(samples)
    def get_dummy_batch(self, max_tokens, max_positions):
        if self.eval_key is None:
            # TODO should max_tokens be used independently for each batch like this?
            return OrderedDict([
                (key, dataset.get_dummy_batch(max_tokens, max_positions[key]))
                for key, dataset in self.datasets.items()
            ])
        else:
            # at evaluation time it's useful to return a single batch directly
            return self.datasets[self.eval_key].get_dummy_batch(max_tokens, max_positions[self.eval_key])
    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        # TODO make it configurable whether to use max() or sum() here

View File

@@ -96,9 +96,6 @@ class TransformEosDataset(FairseqDataset):
        samples = list(map(transform, samples))
        return self.dataset.collater(samples)
    def get_dummy_batch(self, *args, **kwargs):
        return self.dataset.get_dummy_batch(*args, **kwargs)
    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

View File

@@ -60,9 +60,6 @@ class TransformEosLangPairDataset(FairseqDataset):
        return samples
    def get_dummy_batch(self, *args, **kwargs):
        return self.dataset.get_dummy_batch(*args, **kwargs)
    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

View File

@@ -32,7 +32,7 @@ class Trainer(object):
    communication of the gradients across workers.
    """
    def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
    def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=None):
        self.args = args
        self.task = task
@@ -47,7 +47,7 @@ class Trainer(object):
            self._model = self._model.cuda()
        self._dummy_batch = dummy_batch
        self._oom_batch = oom_batch
        self._oom_batch = oom_batch or dummy_batch
        self._lr_scheduler = None
        self._num_updates = 0
@@ -177,6 +177,9 @@ class Trainer(object):
    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]
        self._set_seed()
        self.model.train()
        self.criterion.train()
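Note that dummy_batch is deprecated rather than removed: Trainer.__init__ still accepts explicit dummy_batch and oom_batch arguments (now defaulting to None) and only falls back to the first batch seen by train_step when none is supplied.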

View File

@@ -50,18 +50,8 @@ def main(args):
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))
    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
    # placeholder DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset(args.train_subset).get_dummy_batch(args.max_tokens, max_positions)
    oom_batch = task.dataset(args.train_subset).get_dummy_batch(1, max_positions)
    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
@@ -73,7 +63,10 @@ def main(args):
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            model.max_positions(),
        ),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
@@ -83,8 +76,7 @@ def main(args):
    )
    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])
    load_checkpoint(args, trainer, epoch_itr)
    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf