Add FairseqTask

A Task defines the data format, stores shared state (e.g., dictionaries) and provides helpers for building the model/criterion and calculating the loss.

Changes:
- Add TranslationTask and LanguageModelingTask. New tasks can be registered with the @register_task decorator (see the sketch after this list).
- Add EpochBatchIterator to encapsulate batching and to save/restore the dataloader position.
- Remove the LEFT_PAD_* constants and make left padding configurable per task.
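
As a rough illustration (not part of this commit), a new task might be registered as in the sketch below. The decorator and the setup_task/load_dataset/dataset/target_dictionary entry points match the eval_lm.py changes further down; the fairseq.tasks module layout, the FairseqTask constructor signature, and the self.datasets convention are assumptions.

```
# Hypothetical sketch of the new task API; names flagged below are assumptions.
import os

from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task  # assumed module layout


@register_task('toy_language_modeling')  # hypothetical task name
class ToyLanguageModelingTask(FairseqTask):

    @classmethod
    def setup_task(cls, args, **kwargs):
        # Shared state (e.g. the dictionary) lives on the task.
        dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
        return cls(args, dictionary)

    def __init__(self, args, dictionary):
        super().__init__(args)  # assumed constructor signature
        self.dictionary = dictionary

    def load_dataset(self, split):
        # Populate self.datasets[split] with a FairseqDataset here (assumed
        # convention), so that task.dataset(split) can return it.
        raise NotImplementedError

    @property
    def target_dictionary(self):
        return self.dictionary
```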
Myle Ott 2018-06-12 13:39:41 -04:00
parent 2de9353273
commit ff68a9ef50
47 changed files with 1195 additions and 939 deletions

View File

@ -9,45 +9,43 @@
import numpy as np
import torch
from fairseq import options, utils, progress_bar
from fairseq.data import data_utils, data_loaders
from fairseq import data, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
def main(args):
assert args.path is not None, '--path required for evaluation!'
if args.tokens_per_sample is None:
args.tokens_per_sample = 1024
print(args)
if args.max_target_positions is None:
args.max_target_positions = 1024
use_cuda = torch.cuda.is_available() and not args.cpu
dataset = data_loaders.load_dataset(args, [args.gen_subset], False)
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
models, _ = utils.load_ensemble_for_inference(args.path.split(','), dataset.src_dict, dataset.dst_dict)
print('| Dictionary: {} types'.format(len(dataset.src_dict)))
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
models, _ = utils.load_ensemble_for_inference(args.path.split(','), task)
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
model.make_generation_fast_()
model.src_dict = dataset.src_dict
model.dst_dict = dataset.dst_dict
itr = dataset.eval_dataloader(
args.gen_subset,
itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset),
max_sentences=args.max_sentences or 4,
max_positions=args.max_target_positions or 1024,
descending=True,
)
itr = data_utils.ShardedIterator(itr, args.num_shards, args.shard_id)
max_positions=model.max_positions(),
num_shards=args.num_shards,
shard_id=args.shard_id,
).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter()
scorer = SequenceScorer(models)
scorer = SequenceScorer(models, task.target_dictionary)
if use_cuda:
scorer.cuda()
@ -62,7 +60,7 @@ def main(args):
inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
if inf_scores.any():
print('| Skipping tokens with inf scores:',
dataset.src_dict.string(hypo['tokens'][inf_scores.nonzero()]))
task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum()
count += pos_scores.numel()

View File

@ -22,11 +22,11 @@ $ python preprocess.py --only-source \
# Train the model:
# If it runs out of memory, try to reduce max-tokens and max-target-positions
$ mkdir -p checkpoints/wikitext-103
$ python train.py data-bin/wikitext-103 --save-dir /checkpoints/wikitext-103 \
--max-epoch 35 --arch fconv_lm --optimizer nag --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
--decoder-layers '[(850, 6)] * 3 + [(850,1)] + [(850,5)] * 4 + [(850,1)] + [(850,4)] * 3 + [(1024,4)] + [(2048, 4)]' \
--decoder-embed-dim 280 --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
--adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --max-target-positions 1024
$ python train.py --task language_modeling data-bin/wikitext-103 \
--max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \
--lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
--clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
--adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024
# Evaluate:
$ python eval_lm.py data-bin/wikitext-103 --path 'checkpoints/wiki103/checkpoint_best.pt'

View File

@ -28,6 +28,3 @@ $ python train.py data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip-
# Generate:
$ python generate.py data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1
```

View File

@ -103,4 +103,3 @@ $ python generate.py data-bin/fconv_wmt_en_fr \
--path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe
```

View File

@ -15,8 +15,8 @@ CRITERION_REGISTRY = {}
CRITERION_CLASS_NAMES = set()
def build_criterion(args, src_dict, dst_dict):
return CRITERION_REGISTRY[args.criterion](args, src_dict, dst_dict)
def build_criterion(args, task):
return CRITERION_REGISTRY[args.criterion](args, task)
def register_criterion(name):
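
Criterions now receive the task in place of the source/target dictionaries, and the base class reads the padding index from task.target_dictionary (see the FairseqCriterion diff below). A hedged sketch of a criterion under the new signature; the forward body and the (loss, sample_size, logging_output) return convention are assumptions modeled on a plain NLL loss, not code from this commit.

```
import torch.nn.functional as F

from fairseq.criterions import FairseqCriterion, register_criterion  # assumed module path


@register_criterion('toy_cross_entropy')  # hypothetical name
class ToyCrossEntropyCriterion(FairseqCriterion):

    def __init__(self, args, task):
        # The base class now sets self.padding_idx from task.target_dictionary.
        super().__init__(args, task)

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample['net_input'])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(
            lprobs.view(-1, lprobs.size(-1)), target,
            ignore_index=self.padding_idx, reduce=reduce,  # PyTorch 0.4-era kwarg
        )
        sample_size = sample['ntokens']
        logging_output = {
            'loss': loss.item() if reduce else loss,
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
```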

View File

@ -19,8 +19,8 @@ class AdaptiveLoss(FairseqCriterion):
graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""
def __init__(self, args, src_dict, dst_dict):
super().__init__(args, src_dict, dst_dict)
def __init__(self, args, task):
super().__init__(args, task)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.

View File

@ -16,8 +16,8 @@ from . import FairseqCriterion, register_criterion
@register_criterion('cross_entropy')
class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, args, src_dict, dst_dict):
super().__init__(args, src_dict, dst_dict)
def __init__(self, args, task):
super().__init__(args, task)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.

View File

@ -10,10 +10,10 @@ from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
def __init__(self, args, src_dict, dst_dict):
def __init__(self, args, task):
super().__init__()
self.args = args
self.padding_idx = dst_dict.pad()
self.padding_idx = task.target_dictionary.pad()
@staticmethod
def add_args(parser):

View File

@ -15,8 +15,8 @@ from . import FairseqCriterion, register_criterion
@register_criterion('label_smoothed_cross_entropy')
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(self, args, src_dict, dst_dict):
super().__init__(args, src_dict, dst_dict)
def __init__(self, args, task):
super().__init__(args, task)
self.eps = args.label_smoothing
@staticmethod

View File

@ -6,7 +6,10 @@
# can be found in the PATENTS file in the same directory.
from .dictionary import Dictionary
from .token_block_dataset import TokenBlockDataset
from .language_dataset import LanguageDatasets
from .fairseq_dataset import FairseqDataset
from .indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
from .data_utils import EpochBatchIterator

View File

@ -1,10 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# padding constants
LEFT_PAD_SOURCE = True
LEFT_PAD_TARGET = False

View File

@ -1,24 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from fairseq.data import LanguagePairDataset, MonolingualDataset
from fairseq.data.data_utils import infer_language_pair
def load_dataset(args, splits, is_raw):
""" Detect if we have a multi language dataset, or a single language dataset """
if args.source_lang is None and args.target_lang is None:
# find language pair automatically
args.source_lang, args.target_lang = infer_language_pair(args.data, splits)
if args.source_lang is None and args.target_lang is None and all(
os.path.exists(os.path.join(args.data, '{}.bin'.format(split))) for split in splits):
cls = MonolingualDataset
else:
cls = LanguagePairDataset
return cls.create_dataset(args, splits, is_raw)

View File

@ -6,51 +6,25 @@
# can be found in the PATENTS file in the same directory.
import contextlib
import glob
import itertools
import math
import numbers
import numpy as np
import os
import numpy as np
import torch
from torch.autograd import Variable
import torch.utils.data
from fairseq.data.dictionary import Dictionary
from fairseq.data.indexed_dataset import SizedDataset
from . import FairseqDataset
def has_binary_files(data_dir, splits):
for split in splits:
if len(glob.glob(os.path.join(data_dir, '{}*.bin'.format(split)))) == 0:
return False
return True
def infer_language_pair(path, splits):
def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
for filename in os.listdir(path):
parts = filename.split('.')
for split in splits:
if len(parts) >= 3 and parts[0] == split and parts[-1] == 'idx':
src, dst = parts[1].split('-')
break
if len(parts) >= 3 and len(parts[1].split('-')) == 2:
return parts[1].split('-')
return src, dst
def load_dictionaries(path, src_lang, dst_lang):
"""Load dictionaries for a given language pair."""
src_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(src_lang)))
dst_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(dst_lang)))
return src_dict, dst_dict
def fmt_path(path, fmt, *args):
return os.path.join(path, fmt.format(*args))
class ShardedIterator(object):
"""A sharded wrapper around an iterable (padded to length)."""
@ -78,7 +52,35 @@ class ShardedIterator(object):
return next(self.itr)[1]
class CountingIterator(object):
"""Wrapper around an iterable that maintains the iteration count."""
def __init__(self, iterable):
self.iterable = iterable
self.count = 0
self.itr = iter(self)
def __len__(self):
return len(self.iterable)
def __iter__(self):
for x in self.iterable:
self.count += 1
yield x
def __next__(self):
return next(self.itr)
def has_next(self):
return self.count < len(self)
def skip(self, num_to_skip):
next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
return self
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx)
@ -96,114 +98,149 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=Fal
return res
def _valid_size(src_size, dst_size, max_positions):
if isinstance(max_positions, numbers.Number):
max_src_positions, max_dst_positions = max_positions, max_positions
else:
max_src_positions, max_dst_positions = max_positions
if src_size < 1 or src_size > max_src_positions:
return False
if dst_size is not None and (dst_size < 1 or dst_size > max_dst_positions):
return False
return True
class EpochBatchIterator(object):
"""Iterate over a FairseqDataset and yield batches bucketed by size.
Batches may contain sequences of different lengths. This iterator can be
reused across multiple epochs with the next_epoch_itr() method.
def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=False, allow_different_src_lens=False,
required_batch_size_multiple=1):
batch = []
mult = required_batch_size_multiple
Args:
dataset: a FairseqDataset
max_tokens: max number of tokens in each batch
max_sentences: max number of sentences in each batch
max_positions: max sentence length supported by the model
ignore_invalid_inputs: don't raise Exception for sentences that are too long
required_batch_size_multiple: require batch size to be a multiple of N
seed: seed for random number generator for reproducibility
num_shards: shard the data iterator into N shards
shard_id: which shard of the data iterator to return
"""
def yield_batch(next_idx, num_tokens):
if len(batch) == 0:
def __init__(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1,
num_shards=1, shard_id=0,
):
assert isinstance(dataset, FairseqDataset)
self.dataset = dataset
self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
self.max_positions = max_positions
self.ignore_invalid_inputs = ignore_invalid_inputs
self.bsz_mult = required_batch_size_multiple
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
with numpy_seed(self.seed):
self.frozen_batches = tuple(self._batch_generator())
self.epoch = 0
self._cur_epoch_itr = None
self._next_epoch_itr = None
def __len__(self):
return len(self.frozen_batches)
def next_epoch_itr(self, shuffle=True):
"""Shuffle batches and return a new iterator over the dataset."""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
return self._cur_epoch_itr
def end_of_epoch(self):
return not self._cur_epoch_itr.has_next()
@property
def iterations_in_epoch(self):
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.count
elif self._next_epoch_itr is not None:
return self._next_epoch_itr.count
return 0
def state_dict(self):
return {
'epoch': self.epoch,
'iterations_in_epoch': self.iterations_in_epoch,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
itr_pos = state_dict.get('iterations_in_epoch', 0)
if itr_pos > 0:
# fast-forward epoch iterator
itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
if itr_pos < len(itr):
self._next_epoch_itr = itr.skip(itr_pos)
def _get_iterator_for_epoch(self, epoch, shuffle):
if shuffle:
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with numpy_seed(self.seed + epoch):
batches = list(self.frozen_batches) # copy
np.random.shuffle(batches)
else:
batches = self.frozen_batches
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.dataset.collater,
batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
))
def _batch_generator(self):
batch = []
def is_batch_full(num_tokens):
if len(batch) == 0:
return False
if len(batch) == self.max_sentences:
return True
if num_tokens > self.max_tokens:
return True
return False
if len(batch) == max_sentences:
return True
if num_tokens > max_tokens:
return True
if not allow_different_src_lens and \
(src.sizes[batch[0]] != src.sizes[next_idx]):
return True
return False
sample_len = 0
sample_lens = []
ignored = []
for idx in map(int, indices):
src_size = src.sizes[idx]
dst_size = dst.sizes[idx] if dst else src_size
if not _valid_size(src_size, dst_size, max_positions):
if ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception((
"Sample #{} has size (src={}, dst={}) but max size is {}."
" Skip this example with --skip-invalid-size-inputs-valid-test"
).format(idx, src_size, dst_size, max_positions))
sample_len = 0
sample_lens = []
ignored = []
for idx in self.dataset.ordered_indices():
if not self.dataset.valid_size(idx, self.max_positions):
if self.ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception((
'Size of sample #{} is invalid, max_positions={}, skip this '
'example with --skip-invalid-size-inputs-valid-test'
).format(idx, self.max_positions))
sample_lens.append(max(src_size, dst_size))
sample_len = max(sample_len, sample_lens[-1])
num_tokens = (len(batch) + 1) * sample_len
if yield_batch(idx, num_tokens):
mod8_len = max(mult * (len(batch) // mult), len(batch) % mult)
yield batch[:mod8_len]
batch = batch[mod8_len:]
sample_lens = sample_lens[mod8_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
sample_lens.append(self.dataset.num_tokens(idx))
sample_len = max(sample_len, sample_lens[-1])
num_tokens = (len(batch) + 1) * sample_len
if is_batch_full(num_tokens):
mod_len = max(
self.bsz_mult * (len(batch) // self.bsz_mult),
len(batch) % self.bsz_mult,
)
yield batch[:mod_len]
batch = batch[mod_len:]
sample_lens = sample_lens[mod_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
batch.append(idx)
if len(batch) > 0:
yield batch
if len(batch) > 0:
yield batch
if len(ignored) > 0:
print("Warning! {} samples are either too short or too long "
"and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False,
descending=False, required_batch_size_multiple=1, allow_different_src_lens=False):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, SizedDataset) and (dst is None or isinstance(dst, SizedDataset))
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.argsort(src.sizes, kind='mergesort')
if descending:
indices = np.flip(indices, 0)
return list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=allow_different_src_lens,
required_batch_size_multiple=required_batch_size_multiple,
))
def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024),
required_batch_size_multiple=1):
"""Returns batches of indices bucketed by size. Batches may contain
sequences of different lengths."""
assert isinstance(src, SizedDataset) and isinstance(dst, SizedDataset)
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.random.permutation(len(src))
# sort by sizes
indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
batches = list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=True, allow_different_src_lens=True,
required_batch_size_multiple=required_batch_size_multiple,
))
return batches
if len(ignored) > 0:
print((
'| WARNING: {} samples have invalid sizes and will be skipped, '
'max_positions={}, first few sample ids={}'
).format(len(ignored), self.max_positions, ignored[:10]))
@contextlib.contextmanager
@ -219,21 +256,3 @@ def numpy_seed(seed):
yield
finally:
np.random.set_state(state)
def get_dummy_batch(ntokens, src_dict, dst_dict, src_len=128, tgt_len=128):
bsz = int(ntokens / max(src_len, tgt_len))
bsz = math.ceil(bsz / 8) * 8
assert src_dict.pad() == dst_dict.pad()
pad_idx = src_dict.pad()
src_vocab, dst_vocab = len(src_dict), len(dst_dict)
dummy_batch = {}
dummy_batch['id'] = Variable(torch.arange(bsz).long().cuda())
dummy_batch['ntokens'] = tgt_len * bsz
dummy_batch['target'] = Variable(torch.Tensor(bsz, tgt_len).uniform_(pad_idx + 1, dst_vocab - 1).long().cuda())
input = {}
input['prev_output_tokens'] = Variable(dummy_batch['target'].data.clone())
input['src_lengths'] = Variable(torch.LongTensor(bsz).fill_(src_len).cuda())
input['src_tokens'] = Variable(torch.Tensor(bsz, src_len).uniform_(pad_idx + 1, src_vocab - 1).long().cuda())
dummy_batch['net_input'] = input
return dummy_batch
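
For reference, a usage sketch of the EpochBatchIterator added above, mirroring the eval_lm.py change at the top of this diff. Here `dataset` is assumed to be any FairseqDataset and `model` any model exposing max_positions(); both are placeholders, not code from this commit.

```
from fairseq import data

epoch_itr = data.EpochBatchIterator(
    dataset=dataset,                      # any FairseqDataset (placeholder)
    max_tokens=4000,
    max_sentences=None,
    max_positions=model.max_positions(),  # placeholder model
    seed=1,
    num_shards=1,
    shard_id=0,
)

# Training loop: one pass over the data per call to next_epoch_itr().
itr = epoch_itr.next_epoch_itr(shuffle=True)
for sample in itr:
    pass  # forward/backward on `sample` goes here

# The batching position can be checkpointed and restored across restarts.
state = epoch_itr.state_dict()
epoch_itr.load_state_dict(state)
```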

View File

@ -188,3 +188,8 @@ class Dictionary(object):
return self.save(fd, threshold, nwords)
for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
print('{} {}'.format(symbol, count), file=f)
def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t

View File

@ -0,0 +1,38 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.utils.data
class FairseqDataset(torch.utils.data.Dataset):
"""A dataset that provides helpers for batching."""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
raise NotImplementedError
def get_dummy_batch(self, num_tokens, max_positions):
"""Return a dummy batch with a given number of tokens."""
raise NotImplementedError
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
raise NotImplementedError
def ordered_indices(self):
"""Ordered indices for batching."""
raise NotImplementedError
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
raise NotImplementedError
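
To make the contract concrete, here is a self-contained toy subclass (not part of this commit) implementing just the pieces EpochBatchIterator relies on: collater, num_tokens, ordered_indices, and valid_size.

```
import numpy as np
import torch

from fairseq.data import FairseqDataset


class ToyDataset(FairseqDataset):
    """Toy dataset of variable-length 1d LongTensors (illustration only)."""

    def __init__(self, tensors, pad_idx=0):
        self.tensors = tensors
        self.sizes = np.array([t.numel() for t in tensors])
        self.pad_idx = pad_idx

    def __getitem__(self, index):
        return {'id': index, 'source': self.tensors[index]}

    def __len__(self):
        return len(self.tensors)

    def collater(self, samples):
        # Right-pad all sources to the longest sample in the batch.
        sizes = [s['source'].numel() for s in samples]
        res = samples[0]['source'].new(len(samples), max(sizes)).fill_(self.pad_idx)
        for i, s in enumerate(samples):
            res[i, :sizes[i]].copy_(s['source'])
        return {
            'id': torch.LongTensor([s['id'] for s in samples]),
            'net_input': {'src_tokens': res},
        }

    def num_tokens(self, index):
        return int(self.sizes[index])

    def ordered_indices(self):
        return np.argsort(self.sizes, kind='mergesort')

    def valid_size(self, index, max_positions):
        return max_positions is None or self.sizes[index] <= max_positions
```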

View File

@ -10,7 +10,6 @@ import struct
import numpy as np
import torch
import torch.utils.data
from fairseq.tokenizer import Tokenizer
@ -50,16 +49,7 @@ def data_file_path(prefix_path):
return prefix_path + '.bin'
class SizedDataset(torch.utils.data.Dataset):
def __init__(self):
self._sizes = None
@property
def sizes(self):
return self._sizes
class IndexedDataset(SizedDataset):
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for TorchNet IndexedDataset"""
def __init__(self, path):
@ -74,7 +64,7 @@ class IndexedDataset(SizedDataset):
self.size, self.s = struct.unpack('<QQ', f.read(16))
self.dim_offsets = read_longs(f, self.size + 1)
self.data_offsets = read_longs(f, self.size + 1)
self._sizes = read_longs(f, self.s)
self.sizes = read_longs(f, self.s)
self.read_data(path)
def read_data(self, path):
@ -93,7 +83,7 @@ class IndexedDataset(SizedDataset):
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
return torch.from_numpy(a)
return torch.from_numpy(a).long() - 1 # subtract 1 for 0-based indexing
def __len__(self):
return self.size
@ -114,6 +104,7 @@ class IndexedInMemoryDataset(IndexedDataset):
self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
self.data_file.readinto(self.buffer)
self.data_file.close()
self.buffer -= 1 # subtract 1 for 0-based indexing
def __del__(self):
pass
@ -123,7 +114,7 @@ class IndexedInMemoryDataset(IndexedDataset):
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
return torch.from_numpy(a)
return torch.from_numpy(a).long()
class IndexedRawTextDataset(IndexedDataset):
@ -133,7 +124,7 @@ class IndexedRawTextDataset(IndexedDataset):
def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
self.tokens_list = []
self.lines = []
self._sizes = []
self.sizes = []
self.append_eos = append_eos
self.reverse_order = reverse_order
self.read_data(path, dictionary)
@ -146,10 +137,10 @@ class IndexedRawTextDataset(IndexedDataset):
tokens = Tokenizer.tokenize(
line, dictionary, add_if_not_exist=False,
append_eos=self.append_eos, reverse_order=self.reverse_order,
) + 1 # +1 for Lua compatibility
).long()
self.tokens_list.append(tokens)
self._sizes.append(len(tokens))
self._sizes = np.array(self._sizes)
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
def __getitem__(self, i):
self.check_index(i)
@ -165,6 +156,10 @@ class IndexedRawTextDataset(IndexedDataset):
def __len__(self):
return self.size
@staticmethod
def exists(path):
return os.path.exists(path)
class IndexedDatasetBuilder(object):
element_sizes = {

View File

@ -1,80 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import torch
from fairseq.data.data_utils import numpy_seed, uneven_batches_by_size, ShardedIterator, batches_by_size
class LanguageDatasets(object):
def __init__(self, src, dst, src_dict, dst_dict):
self.src = src
self.dst = dst
self.src_dict = src_dict
self.dst_dict = dst_dict
self.splits = {}
assert self.src_dict.pad() == self.dst_dict.pad()
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
def train_dataloader_generator(
self, split, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), seed=None, sample_without_replacement=0,
shard_id=0, num_shards=1
):
dataset = self.splits[split]
with numpy_seed(seed):
batches = uneven_batches_by_size(
dataset.src, dataset.dst, max_tokens=max_tokens,
max_sentences=max_sentences, max_positions=max_positions,
# FP16: during training keep the batch size a multiple of 8
required_batch_size_multiple=8,
)
frozen_batches = tuple(batches) # freeze
def dataloader(b):
b = ShardedIterator(b, num_shards, shard_id, fill_value=[])
return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=b)
for epoch in itertools.count(1):
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with numpy_seed(seed + epoch):
batches = list(frozen_batches) # copy
np.random.shuffle(batches)
if sample_without_replacement > 0:
# emit sub-epoch dataloaders
while len(batches) >= sample_without_replacement:
sampled_batches = batches[:sample_without_replacement]
remaining_batches = batches[sample_without_replacement:]
yield dataloader(sampled_batches)
batches = remaining_batches
if len(batches) > 0:
yield dataloader(batches)
else:
# emit full dataloader
yield dataloader(batches)
def eval_dataloader(self, split, num_workers=0, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False,
descending=False, shard_id=0, num_shards=1):
dataset = self.splits[split]
batch_sampler = batches_by_size(
dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending,
allow_different_src_lens=True)
batch_sampler = ShardedIterator(batch_sampler, num_shards, shard_id, fill_value=[])
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)

View File

@ -5,29 +5,24 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import os
import numpy as np
import torch
import torch.utils
from fairseq.data import LanguageDatasets
from fairseq.data.consts import LEFT_PAD_TARGET, LEFT_PAD_SOURCE
from fairseq.data.data_utils import fmt_path, load_dictionaries, collate_tokens
from fairseq.data.indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
from . import data_utils, FairseqDataset
def collate(samples, pad_idx, eos_idx, has_target):
def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return collate_tokens(
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=LEFT_PAD_SOURCE)
src_tokens = merge('source', left_pad=left_pad_source)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
@ -36,19 +31,20 @@ def collate(samples, pad_idx, eos_idx, has_target):
prev_output_tokens = None
target = None
ntokens = None
if has_target:
target = merge('target', left_pad=LEFT_PAD_TARGET)
if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=LEFT_PAD_TARGET,
left_pad=left_pad_target,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
else:
ntokens = sum(len(s['source']) for s in samples)
return {
'id': id,
@ -62,93 +58,87 @@ def collate(samples, pad_idx, eos_idx, has_target):
}
class LanguagePairDataset(torch.utils.data.Dataset):
class LanguagePairDataset(FairseqDataset):
"""A pair of torch.utils.data.Datasets."""
def __init__(self, src, dst, pad_idx, eos_idx):
def __init__(
self, src, src_sizes, src_dict,
tgt=None, tgt_sizes=None, tgt_dict=None,
left_pad_source=True, left_pad_target=False,
max_source_positions=1024, max_target_positions=1024,
shuffle=True,
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
self.src = src
self.dst = dst
self.pad_idx = pad_idx
self.eos_idx = eos_idx
self.tgt = tgt
self.src_sizes = np.array(src_sizes)
self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.left_pad_source = left_pad_source
self.left_pad_target = left_pad_target
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
self.shuffle = shuffle
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
res = {'id': i, 'source': source}
if self.dst:
res['target'] = self.dst[i].long() - 1
return res
def __getitem__(self, index):
return {
'id': index,
'source': self.src[index],
'target': self.tgt[index] if self.tgt is not None else None,
}
def __len__(self):
return len(self.src)
def collater(self, samples):
return collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
"""Merge a list of samples to form a mini-batch."""
return collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
)
@staticmethod
def create_dataset(args, splits, is_raw):
src, dst = args.source_lang, args.target_lang
assert src is not None and dst is not None, 'Source and target languages should be provided'
def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
src_len, tgt_len = min(src_len, max_source_positions), min(tgt_len, max_target_positions)
bsz = num_tokens // max(src_len, tgt_len)
return self.collater([
{
'id': i,
'source': self.src_dict.dummy_sentence(src_len),
'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
}
for i in range(bsz)
])
src_dict, dst_dict = load_dictionaries(args.data, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
def create_raw_dataset():
"""Loads specified data splits (e.g., test, train or valid) from raw text
files in the specified folder."""
def ordered_indices(self):
"""Ordered indices for batching."""
if self.shuffle:
indices = np.random.permutation(len(self))
else:
indices = np.arange(len(self))
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
# Load dataset from raw text files
for split in splits:
src_path = os.path.join(args.data, '{}.{}'.format(split, src))
dst_path = os.path.join(args.data, '{}.{}'.format(split, dst))
dataset.splits[split] = LanguagePairDataset(
IndexedRawTextDataset(src_path, src_dict),
IndexedRawTextDataset(dst_path, dst_dict),
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
return (
self.src_sizes[index] <= max_source_positions
and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
)
def create_binary_dataset():
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
# Load dataset from binary files
def all_splits_exist(src, dst, lang):
for split in splits:
filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
if not os.path.exists(os.path.join(args.data, filename)):
return False
return True
# infer langcode
if all_splits_exist(src, dst, src):
langcode = '{}-{}'.format(src, dst)
elif all_splits_exist(dst, src, src):
langcode = '{}-{}'.format(dst, src)
else:
raise Exception('Dataset cannot be loaded from path: ' + args.data)
for split in splits:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
src_path = fmt_path(args.data, '{}.{}.{}', prefix, langcode, src)
dst_path = fmt_path(args.data, '{}.{}.{}', prefix, langcode, dst)
if not IndexedInMemoryDataset.exists(src_path):
break
target_dataset = None
if IndexedInMemoryDataset.exists(dst_path):
target_dataset = IndexedInMemoryDataset(dst_path)
dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path),
target_dataset,
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
return create_raw_dataset() if is_raw else create_binary_dataset()
def _get_max_positions(self, max_positions):
if max_positions is None:
return self.max_source_positions, self.max_target_positions
assert len(max_positions) == 2
max_src_pos, max_tgt_pos = max_positions
return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)
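
A construction sketch for the refactored LanguagePairDataset: sizes and dictionaries are now passed in explicitly, and the left_pad_* flags replace the old global LEFT_PAD_SOURCE/LEFT_PAD_TARGET constants. The file paths below are hypothetical.

```
from fairseq.data import Dictionary, IndexedRawTextDataset, LanguagePairDataset

# Hypothetical paths; any tokenized parallel corpus with fairseq dictionaries works.
src_dict = Dictionary.load('data-bin/dict.de.txt')
tgt_dict = Dictionary.load('data-bin/dict.en.txt')
src = IndexedRawTextDataset('corpus/valid.de', src_dict)
tgt = IndexedRawTextDataset('corpus/valid.en', tgt_dict)

pair = LanguagePairDataset(
    src, src.sizes, src_dict,
    tgt, tgt.sizes, tgt_dict,
    left_pad_source=True,   # formerly the global LEFT_PAD_SOURCE constant
    left_pad_target=False,  # formerly LEFT_PAD_TARGET
)

# Collate a small batch (assumes the corpus has at least two sentences).
batch = pair.collater([pair[i] for i in range(2)])
```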

View File

@ -5,117 +5,77 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from fairseq.data import TokenBlockDataset, Dictionary, LanguageDatasets
from fairseq.data.indexed_dataset import IndexedInMemoryDataset
from fairseq.data.data_utils import fmt_path, collate_tokens
from . import data_utils, FairseqDataset
def collate(samples, pad_idx, eos_idx, has_target):
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad=False, move_eos_to_beginning=False,
return data_utils.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
)
id = torch.LongTensor([s['id'] for s in samples])
# language models only have a decoder which is not padding-aware, so don't left pad for them
src_tokens = merge('source')
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
_, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
target = None
ntokens = None
if has_target:
target = merge('target')
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'id': torch.LongTensor([s['id'] for s in samples]),
'ntokens': sum(len(s['target']) for s in samples),
'net_input': {
'src_tokens': src_tokens,
'src_tokens': merge('source'),
},
'target': target,
'target': merge('target'),
}
class MonolingualDataset(Dataset):
class MonolingualDataset(FairseqDataset):
"""A wrapper around torch.utils.data.Dataset for monolingual data."""
def __init__(self, tokens, sizes, token_block_size, break_mode, pad_idx, eos_idx, next_token_is_target):
def __init__(self, dataset, sizes, vocab, shuffle):
self.dataset = dataset
self.sizes = np.array(sizes)
self.vocab = vocab
self.shuffle = shuffle
if next_token_is_target:
self.src = TokenBlockDataset(tokens, token_block_size, sizes, offset=1, break_mode=break_mode)
self.dst = TokenBlockDataset(tokens, token_block_size, sizes, offset=0, break_mode=break_mode)
else:
self.src = TokenBlockDataset(tokens, token_block_size, sizes, offset=0, break_mode=break_mode)
self.dst = None
self.pad_idx = pad_idx
self.eos_idx = eos_idx
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
res = {'id': i, 'source': source}
if self.dst:
res['target'] = self.dst[i].long() - 1
return res
def __getitem__(self, index):
source, target = self.dataset[index]
return {'id': index, 'source': source, 'target': target}
def __len__(self):
return len(self.src)
return len(self.dataset)
def collater(self, samples):
return collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
"""Merge a list of samples to form a mini-batch."""
return collate(samples, self.vocab.pad(), self.vocab.eos())
@staticmethod
def create_dataset(args, splits, is_raw):
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
assert isinstance(max_positions, float) or isinstance(max_positions, int)
tgt_len = min(tgt_len, max_positions)
bsz = num_tokens // tgt_len
target = self.vocab.dummy_sentence(tgt_len + 1)
source, target = target[:-1], target[1:]
return self.collater([
{'id': i, 'source': source, 'target': target}
for i in range(bsz)
])
if is_raw:
raise Exception('raw text single language data sets are currently not supported')
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
source, target = self.dataset[index]
return len(source)
assert args.sample_break_mode == 'eos' or args.max_target_positions is not None
def ordered_indices(self):
"""Ordered indices for batching."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)
path = args.data
dict = Dictionary.load(os.path.join(path, 'dict.txt'))
dataset = LanguageDatasets(None, None, dict, dict)
assert all(os.path.exists(os.path.join(path, '{}.bin'.format(split))) for split in splits)
for split in splits:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
split_path = fmt_path(path, '{}', prefix)
if not IndexedInMemoryDataset.exists(split_path):
break
ds = IndexedInMemoryDataset(split_path)
tokens = torch.from_numpy(ds.buffer)
dataset.splits[prefix] = MonolingualDataset(
tokens,
ds.sizes,
args.max_target_positions,
args.sample_break_mode,
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
next_token_is_target=True,
)
return dataset
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
assert isinstance(max_positions, float) or isinstance(max_positions, int)
return self.sizes[index] <= max_positions

View File

@ -5,47 +5,49 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import numpy as np
import torch
from fairseq.data.indexed_dataset import SizedDataset
class TokenBlockDataset(torch.utils.data.Dataset):
"""Break a 1d tensor of tokens into blocks.
class TokenBlockDataset(SizedDataset):
"""Given a 1d tensor of tokens, this dataset will break tokens into blocks based on parameters. The blocks are
fetched from the original tensor so no additional memory is allocated"""
The blocks are fetched from the original tensor so no additional memory is allocated.
def __init__(self, tokens, block_size, sizes, offset=0, break_mode=None):
"""
Args:
tokens: torch tensor of tokens to break into blocks
block_size: An integer. the maximum size of each block (note this has no effect in 'eos' break mode)
sizes: A list of integers. sizes of sentences in the block. the sum of the sizes should add up to the
length of tokens
offset: An integer. rotates the tokens by this much before computing blocks. useful for language model targets
break_mode: A boolean if None/'none' then breaks tokens into equally sized blocks of size block_size
if 'complete' then breaks tokens into block sizes of up to block_size such that each block
contains complete sentences. block_size may be exceeded if some sentences exceed block_size
if 'eos' then each block contains a single sentence. does not respect block_size"""
Args:
tokens: 1d tensor of tokens to break into blocks
sizes: sentence lengths (required for 'complete' and 'eos')
block_size: maximum block size (ignored in 'eos' break mode)
break_mode: Mode used for breaking tokens. Values can be one of:
- 'none': break tokens into equally sized blocks (up to block_size)
- 'complete': break tokens into blocks (up to block_size) such that
blocks contains complete sentences, although block_size may be
exceeded if some sentences exceed block_size
- 'eos': each block contains one sentence (block_size is ignored)
include_targets: return next tokens as targets
"""
def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
super().__init__()
self.tokens = tokens
self.offset = offset
self.total_size = len(tokens)
self.include_targets = include_targets
self.slice_indices = []
if break_mode is None or break_mode == 'none':
length = math.ceil(tokens.numel() / block_size)
length = math.ceil(len(tokens) / block_size)
def block_at(i):
start = i * block_size
end = min(start + block_size, len(tokens))
return (start, end)
self.slice_indices = [block_at(i) for i in np.arange(length)]
self.slice_indices = [block_at(i) for i in range(length)]
elif break_mode == 'complete':
assert sizes is not None and sum(sizes) == len(tokens)
tok_idx = 0
sz_idx = 0
curr_size = 0
@ -60,6 +62,7 @@ class TokenBlockDataset(SizedDataset):
if curr_size > 0:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
elif break_mode == 'eos':
assert sizes is not None and sum(sizes) == len(tokens)
curr = 0
for sz in sizes:
# skip samples with just 1 example (which would be just the eos token)
@ -67,19 +70,20 @@ class TokenBlockDataset(SizedDataset):
self.slice_indices.append((curr, curr + sz))
curr += sz
else:
raise Exception('invalid break_mode. Supported values: none, complete, eos')
raise ValueError('Invalid break_mode: ' + break_mode)
self._sizes = np.array([e - s for s, e in self.slice_indices])
self.sizes = np.array([e - s for s, e in self.slice_indices])
def _slice(self, s, e):
# this will copy only the first block if offset > 0, instead of all blocks if we just rotated
# the tensor with torch.cat()
if s < self.offset:
return torch.cat([self.tokens[s - self.offset:], self.tokens[s:e - self.offset]])
return self.tokens[s - self.offset:e - self.offset]
def __getitem__(self, i):
return self._slice(*self.slice_indices[i])
def __getitem__(self, index):
s, e = self.slice_indices[index]
item = torch.LongTensor(self.tokens[s:e])
if self.include_targets:
if e == self.total_size:
return item[:-1], item[1:]
else:
return item, torch.LongTensor(self.tokens[s + 1:e + 1])
else:
return item
def __len__(self):
return len(self.slice_indices)
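
A toy sketch of the new constructor order and the three break modes; the token values are arbitrary, and the sentence sizes must sum to len(tokens) for the 'complete' and 'eos' modes.

```
from fairseq.data import TokenBlockDataset

tokens = list(range(1, 10))  # toy stream of 9 token ids
sizes = [3, 2, 4]            # toy sentence lengths, summing to len(tokens)

fixed = TokenBlockDataset(tokens, sizes, block_size=4, break_mode=None)        # equal-sized blocks
whole = TokenBlockDataset(tokens, sizes, block_size=4, break_mode='complete')  # whole sentences per block
sents = TokenBlockDataset(tokens, sizes, block_size=4, break_mode='eos')       # one sentence per block

# include_targets=True yields (source, next-token target) pairs for LM training.
lm = TokenBlockDataset(tokens, sizes, block_size=4, break_mode=None, include_targets=True)
source, target = lm[0]
```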

View File

@ -49,8 +49,8 @@ class FP16Trainer(Trainer):
We do forward/backward with FP16 and compute the loss + optimize with FP32.
"""
def __init__(self, args, model, criterion):
super().__init__(args, model, criterion)
def __init__(self, args, task, model, criterion):
super().__init__(args, task, model, criterion)
# convert model to FP16 (but keep criterion FP32)
self.model.half()

View File

@ -15,13 +15,14 @@ from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel
from .composite_encoder import CompositeEncoder # noqa: F401
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
def build_model(args, src_dict, dst_dict):
return ARCH_MODEL_REGISTRY[args.arch].build_model(args, src_dict, dst_dict)
def build_model(args, task):
return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)
def register_model(name):

View File

@ -23,10 +23,27 @@ class BaseFairseqModel(nn.Module):
"""Add model-specific arguments to the parser."""
pass
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
raise NotImplementedError
def get_targets(self, sample, net_output):
"""Get targets from either the sample or the net's output."""
return sample['target']
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.decoder.get_normalized_probs(net_output, log_probs, sample)
def max_positions(self):
"""Maximum length supported by the model."""
raise NotImplementedError
def max_decoder_positions(self):
"""Maximum length supported by the decoder."""
return self.decoder.max_positions()
def load_state_dict(self, state_dict, strict=True):
"""Copies parameters and buffers from state_dict into this module and
its descendants.
@ -87,33 +104,14 @@ class FairseqModel(BaseFairseqModel):
assert isinstance(self.encoder, FairseqEncoder)
assert isinstance(self.decoder, FairseqDecoder)
self.src_dict = encoder.dictionary
self.dst_dict = decoder.dictionary
assert self.src_dict.pad() == self.dst_dict.pad()
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
@classmethod
def build_model(cls, args, src_dict, dst_dict):
"""Build a new model instance."""
raise NotImplementedError
def forward(self, src_tokens, src_lengths, prev_output_tokens):
encoder_out = self.encoder(src_tokens, src_lengths)
decoder_out = self.decoder(prev_output_tokens, encoder_out)
return decoder_out
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.decoder.get_normalized_probs(net_output, log_probs, sample)
def max_encoder_positions(self):
"""Maximum input length supported by the encoder."""
return self.encoder.max_positions()
def max_decoder_positions(self):
"""Maximum output length supported by the decoder."""
return self.decoder.max_positions()
def max_positions(self):
"""Maximum length supported by the model."""
return (self.encoder.max_positions(), self.decoder.max_positions())
class FairseqLanguageModel(BaseFairseqModel):
@ -124,16 +122,9 @@ class FairseqLanguageModel(BaseFairseqModel):
self.decoder = decoder
assert isinstance(self.decoder, FairseqDecoder)
def forward(self, src_tokens, **unused):
def forward(self, src_tokens):
return self.decoder(src_tokens)
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.decoder.get_normalized_probs(net_output, log_probs, sample)
def max_decoder_positions(self):
"""Maximum output length supported by the decoder."""
def max_positions(self):
"""Maximum length supported by the model."""
return self.decoder.max_positions()
def max_encoder_positions(self):
return self.max_decoder_positions()

View File

@ -11,7 +11,6 @@ import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
from fairseq.modules import (
AdaptiveSoftmax, BeamableMM, GradMultiply, LearnedPositionalEmbedding,
LinearizedConvolution,
@ -58,26 +57,23 @@ class FConvModel(FairseqModel):
' to be equal)')
@classmethod
def build_model(cls, args, src_dict, dst_dict):
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure that all args are properly defaulted (in case there are any new ones)
base_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = args.max_positions
args.max_target_positions = args.max_positions
encoder_embed_dict = None
if args.encoder_embed_path:
encoder_embed_dict = utils.parse_embedding(args.encoder_embed_path)
utils.print_embed_overlap(encoder_embed_dict, src_dict)
utils.print_embed_overlap(encoder_embed_dict, task.source_dictionary)
decoder_embed_dict = None
if args.decoder_embed_path:
decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path)
utils.print_embed_overlap(decoder_embed_dict, dst_dict)
utils.print_embed_overlap(decoder_embed_dict, task.target_dictionary)
encoder = FConvEncoder(
src_dict,
dictionary=task.source_dictionary,
embed_dim=args.encoder_embed_dim,
embed_dict=encoder_embed_dict,
convolutions=eval(args.encoder_layers),
@ -86,7 +82,7 @@ class FConvModel(FairseqModel):
normalization_constant=args.normalization_constant,
)
decoder = FConvDecoder(
dst_dict,
dictionary=task.target_dictionary,
embed_dim=args.decoder_embed_dim,
embed_dict=decoder_embed_dict,
convolutions=eval(args.decoder_layers),
@ -125,27 +121,28 @@ class FConvLanguageModel(FairseqLanguageModel):
help='multiplies the result of the residual block by sqrt(value)')
@classmethod
def build_model(cls, args, dict, *_):
def build_model(cls, args, task):
"""Build a new model instance."""
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = args.max_positions
args.max_target_positions = args.max_positions
# make sure all arguments are present in older models
base_lm_architecture(args)
if hasattr(args, 'max_target_positions'):
args.tokens_per_sample = args.max_target_positions
decoder = FConvDecoder(
dict,
dictionary=task.target_dictionary,
embed_dim=args.decoder_embed_dim,
convolutions=eval(args.decoder_layers),
out_embed_dim=args.decoder_embed_dim,
attention=eval(args.decoder_attention),
dropout=args.dropout,
max_positions=args.max_target_positions,
max_positions=args.tokens_per_sample,
share_embed=False,
positional_embeddings=False,
adaptive_softmax_cutoff=options.eval_str_list(args.adaptive_softmax_cutoff,
type=int) if args.criterion == 'adaptive_loss' else None,
adaptive_softmax_cutoff=(
options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
if args.criterion == 'adaptive_loss' else None
),
normalization_constant=args.normalization_constant,
)
return FConvLanguageModel(decoder)
@ -154,12 +151,15 @@ class FConvLanguageModel(FairseqLanguageModel):
class FConvEncoder(FairseqEncoder):
"""Convolutional encoder"""
def __init__(self, dictionary, embed_dim=512, embed_dict=None,
max_positions=1024, convolutions=((512, 3),) * 20, dropout=0.1,
normalization_constant=0.5):
def __init__(
self, dictionary, embed_dim=512, embed_dict=None, max_positions=1024,
convolutions=((512, 3),) * 20, dropout=0.1, normalization_constant=0.5,
left_pad=True,
):
super().__init__(dictionary)
self.dropout = dropout
self.normalization_constant = normalization_constant
self.left_pad = left_pad
self.num_attention_layers = None
num_embeddings = len(dictionary)
@ -172,7 +172,7 @@ class FConvEncoder(FairseqEncoder):
max_positions,
embed_dim,
self.padding_idx,
left_pad=LEFT_PAD_SOURCE,
left_pad=self.left_pad,
)
convolutions = extend_conv_spec(convolutions)
@ -329,14 +329,18 @@ class AttentionLayer(nn.Module):
class FConvDecoder(FairseqIncrementalDecoder):
"""Convolutional decoder"""
def __init__(self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
max_positions=1024, convolutions=((512, 3),) * 20,
attention=True, dropout=0.1, share_embed=False, positional_embeddings=True,
adaptive_softmax_cutoff=None, normalization_constant=0.5):
def __init__(
self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
dropout=0.1, share_embed=False, positional_embeddings=True,
adaptive_softmax_cutoff=None, normalization_constant=0.5,
left_pad=False,
):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([2]))
self.dropout = dropout
self.normalization_constant = normalization_constant
self.left_pad = left_pad
convolutions = extend_conv_spec(convolutions)
in_channels = convolutions[0][0]
@ -357,7 +361,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
max_positions,
embed_dim,
padding_idx,
left_pad=LEFT_PAD_TARGET,
left_pad=self.left_pad,
) if positional_embeddings else None
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
@ -609,6 +613,22 @@ def base_lm_architecture(args):
args.normalization_constant = getattr(args, 'normalization_constant', 0.5)
@register_model_architecture('fconv_lm', 'fconv_lm_dauphin_wikitext103')
def fconv_lm_dauphin_wikitext103(args):
layers = '[(850, 6)] * 3'
layers += ' + [(850, 1)] * 1'
layers += ' + [(850, 5)] * 4'
layers += ' + [(850, 1)] * 1'
layers += ' + [(850, 4)] * 3'
layers += ' + [(1024, 4)] * 1'
layers += ' + [(2048, 4)] * 1'
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 280)
args.decoder_layers = getattr(args, 'decoder_layers', layers)
args.decoder_attention = getattr(args, 'decoder_attention', 'False')
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,20000,200000')
base_lm_architecture(args)
@register_model_architecture('fconv', 'fconv')
def base_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)

View File

@ -12,7 +12,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
from fairseq.modules import (
DownsampledMultiHeadAttention, GradMultiply, LearnedPositionalEmbedding,
LinearizedConvolution,
@ -78,7 +77,7 @@ class FConvModelSelfAtt(FairseqModel):
help='use pretrained model when training [True, ...]')
@classmethod
def build_model(cls, args, src_dict, dst_dict):
def build_model(cls, args, task):
trained_encoder, trained_decoder = None, None
pretrained = eval(args.pretrained)
if pretrained:
@ -86,8 +85,7 @@ class FConvModelSelfAtt(FairseqModel):
trained_model = utils.load_ensemble_for_inference(
# not actually for inference, but loads pretrained model parameters
filenames=[args.pretrained_checkpoint],
src_dict=src_dict,
dst_dict=dst_dict,
task=task,
)[0][0]
trained_decoder = list(trained_model.children())[1]
trained_encoder = list(trained_model.children())[0]
@ -100,7 +98,7 @@ class FConvModelSelfAtt(FairseqModel):
"""Build a new model instance."""
encoder = FConvEncoder(
src_dict,
task.source_dictionary,
embed_dim=args.encoder_embed_dim,
convolutions=eval(args.encoder_layers),
dropout=args.dropout,
@ -110,7 +108,7 @@ class FConvModelSelfAtt(FairseqModel):
)
decoder = FConvDecoder(
dst_dict,
task.target_dictionary,
embed_dim=args.decoder_embed_dim,
convolutions=eval(args.decoder_layers),
out_embed_dim=args.decoder_out_embed_dim,
@ -140,11 +138,12 @@ class FConvEncoder(FairseqEncoder):
def __init__(
self, dictionary, embed_dim=512, max_positions=1024,
convolutions=((512, 3),) * 20, dropout=0.1, attention=False,
attention_nheads=1,
attention_nheads=1, left_pad=True,
):
super().__init__(dictionary)
self.dropout = dropout
self.num_attention_layers = None
self.left_pad = left_pad
num_embeddings = len(dictionary)
self.padding_idx = dictionary.pad()
@ -153,7 +152,7 @@ class FConvEncoder(FairseqEncoder):
max_positions,
embed_dim,
self.padding_idx,
left_pad=LEFT_PAD_SOURCE,
left_pad=self.left_pad,
)
def expand_bool_array(val):
@ -239,13 +238,14 @@ class FConvDecoder(FairseqDecoder):
convolutions=((512, 3),) * 8, attention=True, dropout=0.1,
selfattention=False, attention_nheads=1, selfattention_nheads=1,
project_input=False, gated_attention=False, downsample=False,
pretrained=False, trained_decoder=None,
pretrained=False, trained_decoder=None, left_pad=False,
):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([2]))
self.pretrained = pretrained
self.pretrained_decoder = trained_decoder
self.dropout = dropout
self.left_pad = left_pad
in_channels = convolutions[0][0]
def expand_bool_array(val):
@ -269,7 +269,7 @@ class FConvDecoder(FairseqDecoder):
max_positions,
embed_dim,
padding_idx,
left_pad=LEFT_PAD_TARGET,
left_pad=self.left_pad,
)
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)

View File

@ -11,7 +11,6 @@ import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.data import consts
from . import (
FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model,
@ -63,7 +62,7 @@ class LSTMModel(FairseqModel):
help='dropout probability for decoder output')
@classmethod
def build_model(cls, args, src_dict, dst_dict):
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure that all args are properly defaulted (in case there are any new ones)
base_architecture(args)
@ -79,14 +78,14 @@ class LSTMModel(FairseqModel):
pretrained_encoder_embed = None
if args.encoder_embed_path:
pretrained_encoder_embed = load_pretrained_embedding_from_file(
args.encoder_embed_path, src_dict, args.encoder_embed_dim)
args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim)
pretrained_decoder_embed = None
if args.decoder_embed_path:
pretrained_decoder_embed = load_pretrained_embedding_from_file(
args.decoder_embed_path, dst_dict, args.decoder_embed_dim)
args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim)
encoder = LSTMEncoder(
dictionary=src_dict,
dictionary=task.source_dictionary,
embed_dim=args.encoder_embed_dim,
hidden_size=args.encoder_hidden_size,
num_layers=args.encoder_layers,
@ -96,7 +95,7 @@ class LSTMModel(FairseqModel):
pretrained_embed=pretrained_encoder_embed,
)
decoder = LSTMDecoder(
dictionary=dst_dict,
dictionary=task.target_dictionary,
embed_dim=args.decoder_embed_dim,
hidden_size=args.decoder_hidden_size,
out_embed_dim=args.decoder_out_embed_dim,
@ -114,11 +113,9 @@ class LSTMModel(FairseqModel):
class LSTMEncoder(FairseqEncoder):
"""LSTM encoder."""
def __init__(
self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
dropout_in=0.1, dropout_out=0.1, bidirectional=False,
left_pad_source=consts.LEFT_PAD_SOURCE,
pretrained_embed=None,
padding_value=0.,
self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
dropout_in=0.1, dropout_out=0.1, bidirectional=False,
left_pad=True, pretrained_embed=None, padding_value=0.,
):
super().__init__(dictionary)
self.num_layers = num_layers
@ -141,7 +138,7 @@ class LSTMEncoder(FairseqEncoder):
dropout=self.dropout_out,
bidirectional=bidirectional,
)
self.left_pad_source = left_pad_source
self.left_pad = left_pad
self.padding_value = padding_value
self.output_units = hidden_size
@ -149,7 +146,7 @@ class LSTMEncoder(FairseqEncoder):
self.output_units *= 2
def forward(self, src_tokens, src_lengths):
if self.left_pad_source:
if self.left_pad:
# convert left-padding to right-padding
src_tokens = utils.convert_padding_direction(
src_tokens,
@ -248,10 +245,9 @@ class AttentionLayer(nn.Module):
class LSTMDecoder(FairseqIncrementalDecoder):
"""LSTM decoder."""
def __init__(
self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
encoder_embed_dim=512, encoder_output_units=512,
pretrained_embed=None,
self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
encoder_embed_dim=512, encoder_output_units=512, pretrained_embed=None,
):
super().__init__(dictionary)
self.dropout_in = dropout_in


@ -11,7 +11,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
from fairseq.modules import (
LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding,
@ -68,8 +67,9 @@ class TransformerModel(FairseqModel):
' (requires shared dictionary and embed dim)')
@classmethod
def build_model(cls, args, src_dict, dst_dict):
def build_model(cls, args, task):
"""Build a new model instance."""
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
@ -77,7 +77,7 @@ class TransformerModel(FairseqModel):
return Embedding(num_embeddings, embed_dim, padding_idx)
if args.share_all_embeddings:
if src_dict != dst_dict:
if src_dict != tgt_dict:
raise RuntimeError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
@ -87,17 +87,17 @@ class TransformerModel(FairseqModel):
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
decoder_embed_tokens = build_embedding(dst_dict, args.decoder_embed_dim)
decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim)
encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
decoder = TransformerDecoder(args, dst_dict, decoder_embed_tokens)
decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
return TransformerModel(encoder, decoder)
class TransformerEncoder(FairseqEncoder):
"""Transformer encoder."""
def __init__(self, args, dictionary, embed_tokens):
def __init__(self, args, dictionary, embed_tokens, left_pad=True):
super().__init__(dictionary)
self.dropout = args.dropout
@ -108,7 +108,7 @@ class TransformerEncoder(FairseqEncoder):
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, self.padding_idx,
left_pad=LEFT_PAD_SOURCE,
left_pad=left_pad,
learned=args.encoder_learned_pos,
)
@ -157,7 +157,7 @@ class TransformerEncoder(FairseqEncoder):
class TransformerDecoder(FairseqIncrementalDecoder):
"""Transformer decoder."""
def __init__(self, args, dictionary, embed_tokens):
def __init__(self, args, dictionary, embed_tokens, left_pad=False):
super().__init__(dictionary)
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
@ -169,7 +169,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, padding_idx,
left_pad=LEFT_PAD_TARGET,
left_pad=left_pad,
learned=args.decoder_learned_pos,
)
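With the LEFT_PAD_* constants gone, the padding side is now a constructor argument on each encoder and decoder; the defaults above (left_pad=True for encoders, left_pad=False for decoders) reproduce the old behaviour. For intuition, the two layouts look like this, using pad index 1 as in the default Dictionary; this is a purely illustrative sketch, not part of the commit:

import torch

pad = 1
# left-padded batch (encoder default): padding precedes the tokens
src_tokens = torch.LongTensor([
    [pad, pad, 4, 5, 2],
    [6,   7,   8, 9, 2],
])
# right-padded batch (decoder default): padding follows the tokens
tgt_tokens = torch.LongTensor([
    [4, 5, 2, pad, pad],
    [6, 7, 8, 9,   2],
])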


@ -48,7 +48,7 @@ class FixedSchedule(FairseqLRScheduler):
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if num_updates <= self.args.warmup_updates:
if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
self.warmup_factor = num_updates / float(self.args.warmup_updates)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
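The added warmup_updates > 0 check matters when warmup is disabled: with --warmup-updates 0, any call where num_updates is still 0 would satisfy the old condition and hit a division by zero. A minimal illustration of the expression the guard now skips:

warmup_updates = 0   # --warmup-updates 0, i.e. warmup disabled
num_updates = 0
if num_updates <= warmup_updates:                  # old check: True
    factor = num_updates / float(warmup_updates)   # raises ZeroDivisionError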


@ -13,10 +13,11 @@ from fairseq.criterions import CRITERION_REGISTRY
from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
from fairseq.tasks import TASK_REGISTRY
def get_training_parser():
parser = get_parser('Trainer')
def get_training_parser(default_task='translation'):
parser = get_parser('Trainer', default_task)
add_dataset_args(parser, train=True)
add_distributed_training_args(parser)
add_model_args(parser)
@ -25,8 +26,8 @@ def get_training_parser():
return parser
def get_generation_parser(interactive=False):
parser = get_parser('Generation')
def get_generation_parser(interactive=False, default_task='translation'):
parser = get_parser('Generation', default_task)
add_dataset_args(parser, gen=True)
add_generation_args(parser)
if interactive:
@ -34,8 +35,8 @@ def get_generation_parser(interactive=False):
return parser
def get_eval_lm_parser():
parser = get_parser('Evaluate Language Model')
def get_eval_lm_parser(default_task='language_modeling'):
parser = get_parser('Evaluate Language Model', default_task)
add_dataset_args(parser, gen=True)
add_eval_lm_args(parser)
return parser
@ -85,6 +86,8 @@ def parse_args_and_arch(parser, input_args=None):
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
if hasattr(args, 'lr_scheduler'):
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
if hasattr(args, 'task'):
TASK_REGISTRY[args.task].add_args(parser)
# Parse a second time.
args = parser.parse_args(input_args)
@ -104,7 +107,7 @@ def parse_args_and_arch(parser, input_args=None):
return args
def get_parser(desc):
def get_parser(desc, default_task='translation'):
parser = argparse.ArgumentParser(
description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
@ -114,34 +117,24 @@ def get_parser(desc):
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
# Task definitions can be found under fairseq/tasks/
parser.add_argument(
'--task', metavar='TASK', default=default_task, choices=TASK_REGISTRY.keys(),
help='task: {} (default: {})'.format(', '.join(TASK_REGISTRY.keys()), default_task)
)
return parser
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading')
group.add_argument('data', metavar='DIR',
help='path to data directory')
group.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
group.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
group.add_argument('--max-target-positions', '--tokens-per-sample', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='ignore too long or too short lines in valid and test set')
group.add_argument('--max-tokens', type=int, metavar='N',
help='maximum number of tokens in a batch')
group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
help='maximum number of sentences in a batch')
group.add_argument('--sample-break-mode', metavar='VAL',
choices=['none', 'complete', 'eos'],
help='If omitted or "none", fills each sample with tokens-per-sample'
' tokens. If set to "complete", splits samples only at the end'
' of sentence, but may include multiple sentences per sample.'
' If set to "eos", includes only one sentence per sample.')
if train:
group.add_argument('--train-subset', default='train', metavar='SPLIT',
choices=['train', 'valid', 'test'],
@ -152,10 +145,6 @@ def add_dataset_args(parser, train=False, gen=False):
group.add_argument('--max-sentences-valid', type=int, metavar='N',
help='maximum number of sentences in a validation batch'
' (defaults to --max-sentences)')
group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
help='If bigger than 0, use that number of mini-batches for each epoch,'
' where each sample is drawn randomly without replacement from the'
' dataset')
if gen:
group.add_argument('--gen-subset', default='test', metavar='SPLIT',
help='data subset to generate (train, valid, test)')
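Since parse_args_and_arch now also expands task-specific arguments via TASK_REGISTRY[args.task].add_args(parser), generation and eval parsers are driven the same way as the training parser. A minimal sketch with illustrative paths and flag values:

from fairseq import options

parser = options.get_generation_parser()       # default_task='translation'
args = options.parse_args_and_arch(parser, [
    'data-bin/example',                         # positional 'data' arg now comes from the task
    '--task', 'translation',
    '--path', 'checkpoints/checkpoint_best.pt',
    '--beam', '5',
])
print(args.task, args.left_pad_source)          # task args are filled in by the second parse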


@ -14,7 +14,7 @@ from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(
self, models, beam_size=1, minlen=1, maxlen=None, stop_early=True,
self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False,
sampling=False, sampling_topk=-1, sampling_temperature=1,
):
@ -28,13 +28,10 @@ class SequenceGenerator(object):
normalize_scores: Normalize scores by the length of the output.
"""
self.models = models
self.pad = models[0].dst_dict.pad()
self.unk = models[0].dst_dict.unk()
self.eos = models[0].dst_dict.eos()
assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
assert all(m.dst_dict.unk() == self.unk for m in self.models[1:])
assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
self.vocab_size = len(models[0].dst_dict)
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
self.minlen = minlen
max_decoder_len = min(m.max_decoder_positions() for m in self.models)
@ -70,6 +67,8 @@ class SequenceGenerator(object):
for sample in data_itr:
s = utils.make_variable(sample, volatile=True, cuda=cuda)
if 'net_input' not in s:
continue
input = s['net_input']
srclen = input['src_tokens'].size(1)
if timer is not None:


@ -11,10 +11,9 @@ from fairseq import utils
class SequenceScorer(object):
"""Scores the target for a given source sentence."""
def __init__(self, models):
def __init__(self, models, tgt_dict):
self.models = models
self.pad = models[0].dst_dict.pad()
assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
self.pad = tgt_dict.pad()
def cuda(self):
for model in self.models:

fairseq/tasks/__init__.py (new file, 43 lines)

@ -0,0 +1,43 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_task import FairseqTask
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
def setup_task(args):
return TASK_REGISTRY[args.task].setup_task(args)
def register_task(name):
"""Decorator to register a new task."""
def register_task_cls(cls):
if name in TASK_REGISTRY:
raise ValueError('Cannot register duplicate task ({})'.format(name))
if not issubclass(cls, FairseqTask):
raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__))
if cls.__name__ in TASK_CLASS_NAMES:
raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__))
TASK_REGISTRY[name] = cls
TASK_CLASS_NAMES.add(cls.__name__)
return cls
return register_task_cls
# automatically import any Python files in the tasks/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.tasks.' + module)
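For reference, registering a new task from user code follows the same pattern as the built-in tasks; the task name and class below are purely illustrative, and the module must be imported somewhere (the loop above only auto-imports files placed under fairseq/tasks/):

from fairseq.tasks import FairseqTask, register_task

@register_task('copy_task')            # illustrative name; must be unique in TASK_REGISTRY
class CopyTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        parser.add_argument('data', metavar='DIR', help='path to data directory')

    @classmethod
    def setup_task(cls, args, **kwargs):
        # build dictionaries or other shared state here before instantiating
        return cls(args)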


@ -0,0 +1,57 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import criterions, models
from fairseq.data import FairseqDataset
class FairseqTask(object):
"""
A Task defines the data format, stores shared state (e.g., dictionaries) and
provides helpers for building the model/criterion and calculating the loss.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
pass
def __init__(self, args):
self.args = args
self.datasets = {}
@classmethod
def setup_task(cls, args, **kwargs):
raise NotImplementedError
def load_dataset(self, split):
raise NotImplementedError
def dataset(self, split):
"""Return a dataset split."""
if split not in self.datasets:
raise KeyError('Dataset not loaded: ' + split)
if not isinstance(self.datasets[split], FairseqDataset):
raise TypeError('Datasets are expected to be of type FairseqDataset')
return self.datasets[split]
def build_model(self, args):
return models.build_model(args, self)
def build_criterion(self, args):
return criterions.build_criterion(args, self)
def get_loss(self, model, criterion, sample):
return criterion(model, sample)
@property
def source_dictionary(self):
raise NotImplementedError
@property
def target_dictionary(self):
raise NotImplementedError
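Putting the helpers together, a driver script interacts with a task roughly as follows (argument parsing as in options above; the loss computation is what the Trainer delegates to task.get_loss):

from fairseq import options, tasks

parser = options.get_training_parser()     # --task defaults to 'translation'
args = options.parse_args_and_arch(parser)

task = tasks.setup_task(args)              # builds dictionaries and other shared state
task.load_dataset(args.train_subset)       # populates task.datasets[args.train_subset]

model = task.build_model(args)             # forwards to the registered model's build_model(args, task)
criterion = task.build_criterion(args)
# the Trainer then computes each step's loss via task.get_loss(model, criterion, sample)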


@ -0,0 +1,66 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from fairseq.data import (
Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
MonolingualDataset, TokenBlockDataset,
)
from . import FairseqTask, register_task
@register_task('language_modeling')
class LanguageModelingTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('--sample-break-mode', metavar='VAL',
choices=['none', 'complete', 'eos'],
help='If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
'of sentence, but may include multiple sentences per sample. '
'If set to "eos", includes only one sentence per sample.')
parser.add_argument('--tokens-per-sample', default=1024, type=int, metavar='N',
help='max number of tokens per sample for LM dataset')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
@classmethod
def setup_task(cls, args, **kwargs):
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split):
"""Load a dataset split."""
path = os.path.join(self.args.data, split)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
tokens = ds.tokens_list
elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
ds = IndexedInMemoryDataset(path)
tokens = ds.buffer
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
dataset = TokenBlockDataset(
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True, # return next tokens as targets
)
self.datasets[split] = MonolingualDataset(dataset, dataset.sizes, self.dictionary, shuffle=False)
@property
def target_dictionary(self):
return self.dictionary
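A rough sketch of setting the language modeling task up programmatically; the data path is illustrative and is expected to contain dict.txt plus binarized (or raw) splits:

import argparse
from fairseq import tasks

args = argparse.Namespace(
    task='language_modeling',
    data='data-bin/wikitext-103',    # illustrative path
    raw_text=False,
    tokens_per_sample=512,
    sample_break_mode='none',
)
task = tasks.setup_task(args)        # loads <data>/dict.txt
task.load_dataset('valid')
print(len(task.dataset('valid')), 'examples,', len(task.target_dictionary), 'types')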


@ -0,0 +1,112 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from fairseq import options
from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
IndexedRawTextDataset,
)
from . import FairseqTask, register_task
@register_task('translation')
class TranslationTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left (default: True)')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
help='pad the target on the left (default: False)')
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args)
self.src_dict = src_dict
self.tgt_dict = tgt_dict
@classmethod
def setup_task(cls, args, **kwargs):
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
# find language pair automatically
if args.source_lang is None or args.target_lang is None:
args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data)
if args.source_lang is None or args.target_lang is None:
raise Exception('Could not infer language pair, please provide it explicitly')
# load dictionaries
src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))
return cls(args, src_dict, tgt_dict)
def load_dataset(self, split):
"""Load a dataset split."""
def split_exists(src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
return True
return False
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedInMemoryDataset.exists(path):
return IndexedInMemoryDataset(path)
return None
src_dataset = indexed_dataset(prefix + src, self.src_dict)
tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict)
self.datasets[split] = LanguagePairDataset(
src_dataset, src_dataset.sizes, self.src_dict,
tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
)
@property
def source_dictionary(self):
return self.src_dict
@property
def target_dictionary(self):
return self.tgt_dict
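The per-task padding options replace the old LEFT_PAD_* constants; a sketch of configuring them when setting the translation task up by hand (path and languages illustrative, values shown are the defaults):

import argparse
from fairseq import tasks

args = argparse.Namespace(
    task='translation',
    data='data-bin/iwslt14.tokenized.de-en',   # illustrative path
    source_lang='de',
    target_lang='en',
    raw_text=False,
    left_pad_source='True',    # parsed with options.eval_bool in setup_task
    left_pad_target='False',
    max_source_positions=1024,
    max_target_positions=1024,
)
task = tasks.setup_task(args)
task.load_dataset('valid')
valid = task.dataset('valid')   # a LanguagePairDataset built with the configured padding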


@ -27,7 +27,7 @@ class Trainer(object):
torch.distributed.all_reduce.
"""
def __init__(self, args, model, criterion):
def __init__(self, args, task, model, criterion):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
@ -35,6 +35,7 @@ class Trainer(object):
self.args = args
# copy model and criterion to current device
self.task = task
self.model = model.cuda()
self.criterion = criterion.cuda()
@ -67,6 +68,7 @@ class Trainer(object):
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
if distributed_utils.is_master(self.args): # only save one checkpoint
extra_state['train_meters'] = self.meters
utils.save_state(
filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
@ -90,6 +92,10 @@ class Trainer(object):
self._num_updates = last_optim['num_updates']
if 'train_meters' in extra_state:
self.meters = extra_state['train_meters']
del extra_state['train_meters']
return extra_state
def train_step(self, sample, update_params=True):
@ -99,9 +105,14 @@ class Trainer(object):
# initialize optimizer and LR scheduler if hasn't been loaded from the checkpoint
self._build_optimizer()
sample = self._prepare_sample(sample, volatile=False)
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# forward and backward pass
sample = self._prepare_sample(sample, volatile=False)
loss, sample_size, logging_output, oom_fwd = self._forward(sample)
oom_bwd = self._backward(loss)
@ -182,7 +193,7 @@ class Trainer(object):
try:
with utils.maybe_no_grad(eval):
# calculate loss and sample size
loss, sample_size, logging_output_ = self.criterion(self.model, sample)
loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
logging_output.update(logging_output_)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
@ -311,6 +322,10 @@ class Trainer(object):
"""Adjust the learning rate based on the validation loss."""
return self.lr_scheduler.step(epoch, val_loss)
def lr_step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.lr_scheduler.step_update(num_updates)
def get_lr(self):
"""Get the current learning rate."""
return self.optimizer.get_lr()
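The trainer now receives the task and routes loss computation through it; construction mirrors the order used in train.py (args as returned by options.parse_args_and_arch, CUDA required):

from fairseq import tasks
from fairseq.trainer import Trainer

task = tasks.setup_task(args)
model = task.build_model(args)
criterion = task.build_criterion(args)
trainer = Trainer(args, task, model, criterion)   # moves model and criterion to the current device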


@ -72,7 +72,7 @@ def load_model_state(filename, model):
# load model parameters
try:
model.load_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True)
except Exception:
raise Exception('Cannot load model parameters from checkpoint, '
'please ensure that the architectures match')
@ -120,23 +120,26 @@ def _upgrade_state_dict(state):
# keep track of number of updates
if 'num_updates' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['num_updates'] = 0
# old model checkpoints may not have separate source/target positions
if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
state['args'].max_source_positions = state['args'].max_positions
state['args'].max_target_positions = state['args'].max_positions
# use stateful training data iterator
if 'train_iterator' not in state['extra_state']:
state['extra_state']['train_iterator'] = {
'epoch': state['extra_state']['epoch'],
'iterations_in_epoch': 0,
}
return state
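For older checkpoints, the backfilled entry above simply marks the start of the saved epoch; after the upgrade, extra_state contains roughly the following (values illustrative):

extra_state = {
    'epoch': 2,
    'train_iterator': {
        'epoch': 2,                # epoch the iterator was in when the checkpoint was saved
        'iterations_in_epoch': 0,  # old checkpoints always resume from the start of the epoch
    },
    # other saved state (e.g. 'best', training meters) is preserved as-is
}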
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
data_dir=None, model_arg_overrides=None):
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
"""Load an ensemble of models for inference.
The source and target dictionaries can be given explicitly, or loaded from
the `data_dir` directory.
model_arg_overrides is an optional dictionary of the form {'arg_name': arg} that
overrides model args that were used during model training
"""
from fairseq import models
from fairseq.data import data_utils
# load model architectures and weights
states = []
for filename in filenames:
@ -149,14 +152,10 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
if src_dict is None or dst_dict is None:
assert data_dir is not None
src_dict, dst_dict = data_utils.load_dictionaries(data_dir, args.source_lang, args.target_lang)
# build ensemble
ensemble = []
for state in states:
model = models.build_model(args, src_dict, dst_dict)
model = task.build_model(args)
model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True)
ensemble.append(model)
@ -308,15 +307,15 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
return ' '.join(hypo_tokens)
def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict, remove_bpe):
def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
from fairseq import tokenizer
hypo_str = dst_dict.string(hypo_tokens, remove_bpe)
hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
if align_dict is not None:
hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, dst_dict.unk_string())
hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
if align_dict is not None or remove_bpe is not None:
# Convert back to tokens for evaluating with unk replacement or without BPE
# Note that the dictionary can be modified inside the method.
hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, dst_dict, add_if_not_exist=True)
hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
return hypo_tokens, hypo_str, alignment


@ -8,8 +8,7 @@
import torch
from fairseq import bleu, options, progress_bar, tokenizer, utils
from fairseq.data import data_utils, data_loaders
from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
@ -17,65 +16,67 @@ from fairseq.sequence_scorer import SequenceScorer
def main(args):
assert args.path is not None, '--path required for generation!'
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert args.replace_unk is None or args.raw_text, \
'--replace-unk requires a raw text dataset (--raw-text)'
if args.max_tokens is None and args.max_sentences is None:
args.max_tokens = 12000
print(args)
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset
dataset = data_loaders.load_dataset(args, [args.gen_subset], args.replace_unk is not None)
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
model_paths = args.path.split(',')
models, _ = utils.load_ensemble_for_inference(model_paths, dataset.src_dict, dataset.dst_dict)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
models, _ = utils.load_ensemble_for_inference([args.path], task)
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
)
model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
# Load dataset (possibly sharded)
max_positions = min(model.max_encoder_positions() for model in models)
itr = dataset.eval_dataloader(
args.gen_subset,
# Load dataset (possibly sharded)
itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
)
itr = data_utils.ShardedIterator(itr, args.num_shards, args.shard_id)
max_positions=models[0].max_positions(),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
num_shards=args.num_shards,
shard_id=args.shard_id,
).next_epoch_itr(shuffle=False)
# Initialize generator
gen_timer = StopwatchMeter()
if args.score_reference:
translator = SequenceScorer(models)
translator = SequenceScorer(models, task.target_dictionary)
else:
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk,
minlen=args.min_len)
models, task.target_dictionary, beam_size=args.beam,
stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
)
if use_cuda:
translator.cuda()
# Generate and compute BLEU score
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
num_sentences = 0
has_target = True
with progress_bar.build_progress_bar(args, itr) as t:
@ -84,7 +85,9 @@ def main(args):
else:
translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size)
cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
)
wps_meter = TimeMeter()
for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth
@ -93,12 +96,12 @@ def main(args):
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
else:
src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
src_str = src_dict.string(src_tokens, args.remove_bpe)
if has_target:
target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
if not args.quiet:
print('S-{}\t{}'.format(sample_id, src_str))
@ -112,7 +115,7 @@ def main(args):
src_str=src_str,
alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict,
dst_dict=dataset.dst_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
@ -135,7 +138,7 @@ def main(args):
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(
target_str, dataset.dst_dict, add_if_not_exist=True)
target_str, tgt_dict, add_if_not_exist=True)
scorer.add(target_tokens, hypo_tokens)
wps_meter.update(src_tokens.size(0))


@ -6,17 +6,17 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import namedtuple
import numpy as np
import sys
import torch
from collections import namedtuple
from torch.autograd import Variable
from fairseq import options, tokenizer, utils
from fairseq.data.data_utils import collate_tokens
from fairseq.data.consts import LEFT_PAD_SOURCE
from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator
Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos alignments')
@ -33,44 +33,52 @@ def buffered_read(buffer_size):
yield buffer
def make_batches(lines, batch_size, src_dict):
tokens = [tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long() for src_str in lines]
lengths = [t.numel() for t in tokens]
indices = np.argsort(lengths)
num_batches = np.ceil(len(indices) / batch_size)
batches = np.array_split(indices, num_batches)
for batch_idxs in batches:
batch_toks = [tokens[i] for i in batch_idxs]
batch_toks = collate_tokens(batch_toks, src_dict.pad(), src_dict.eos(), LEFT_PAD_SOURCE,
move_eos_to_beginning=False)
def make_batches(lines, args, src_dict, max_positions):
tokens = [
tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
for src_str in lines
]
lengths = np.array([t.numel() for t in tokens])
itr = data.EpochBatchIterator(
dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
).next_epoch_itr(shuffle=False)
for batch in itr:
yield Batch(
srcs=[lines[i] for i in batch_idxs],
tokens=batch_toks,
lengths=tokens[0].new([lengths[i] for i in batch_idxs]),
), batch_idxs
srcs=[lines[i] for i in batch['id']],
tokens=batch['net_input']['src_tokens'],
lengths=batch['net_input']['src_lengths'],
), batch['id']
def main(args):
print(args)
if args.buffer_size < 1:
args.buffer_size = 1
if args.max_tokens is None and args.max_sentences is None:
args.max_sentences = 1
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
'--max-sentences/--batch-size cannot be larger than --buffer-size'
if args.buffer_size < 1:
args.buffer_size = 1
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Setup task, e.g., translation
task = tasks.setup_task(args)
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
model_paths = args.path.split(',')
models, model_args = utils.load_ensemble_for_inference(model_paths, data_dir=args.data)
src_dict, dst_dict = models[0].src_dict, models[0].dst_dict
models, model_args = utils.load_ensemble_for_inference(model_paths, task)
print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Optimize ensemble for generation
for model in models:
@ -80,10 +88,11 @@ def main(args):
# Initialize generator
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
models, tgt_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk,
minlen=args.min_len)
minlen=args.min_len,
)
if use_cuda:
translator.cuda()
@ -106,7 +115,7 @@ def main(args):
src_str=src_str,
alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict,
dst_dict=dst_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
@ -135,7 +144,7 @@ def main(args):
for inputs in buffered_read(args.buffer_size):
indices = []
results = []
for batch, batch_indices in make_batches(inputs, max(1, args.max_sentences or 1), src_dict):
for batch, batch_indices in make_batches(inputs, args, src_dict, models[0].max_positions()):
indices.extend(batch_indices)
results += process_batch(batch)


@ -148,6 +148,7 @@ def train_translation_model(data_dir, arch, extra_flags=None):
train_args = options.parse_args_and_arch(
train_parser,
[
'--task', 'translation',
data_dir,
'--save-dir', data_dir,
'--arch', arch,
@ -166,15 +167,18 @@ def train_translation_model(data_dir, arch, extra_flags=None):
def generate_main(data_dir):
generate_parser = options.get_generation_parser()
generate_args = generate_parser.parse_args([
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--beam', '3',
'--batch-size', '64',
'--max-len-b', '5',
'--gen-subset', 'valid',
'--no-progress-bar',
])
generate_args = options.parse_args_and_arch(
generate_parser,
[
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--beam', '3',
'--batch-size', '64',
'--max-len-b', '5',
'--gen-subset', 'valid',
'--no-progress-bar',
],
)
# evaluate model in batch mode
generate.main(generate_args)
@ -205,6 +209,7 @@ def train_language_model(data_dir, arch):
train_args = options.parse_args_and_arch(
train_parser,
[
'--task', 'language_modeling',
data_dir,
'--arch', arch,
'--optimizer', 'nag',
@ -214,7 +219,7 @@ def train_language_model(data_dir, arch):
'--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
'--decoder-embed-dim', '280',
'--max-tokens', '500',
'--max-target-positions', '500',
'--tokens-per-sample', '500',
'--save-dir', data_dir,
'--max-epoch', '1',
'--no-progress-bar',
@ -226,11 +231,14 @@ def train_language_model(data_dir, arch):
def eval_lm_main(data_dir):
eval_lm_parser = options.get_eval_lm_parser()
eval_lm_args = eval_lm_parser.parse_args([
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--no-progress-bar',
])
eval_lm_args = options.parse_args_and_arch(
eval_lm_parser,
[
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--no-progress-bar',
],
)
eval_lm.main(eval_lm_args)

tests/test_data_utils.py (new file, 29 lines)

@ -0,0 +1,29 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import unittest
from fairseq.data import data_utils
class TestDataUtils(unittest.TestCase):
def test_counting_iterator(self):
x = list(range(10))
itr = data_utils.CountingIterator(x)
self.assertTrue(itr.has_next())
self.assertEqual(next(itr), 0)
self.assertEqual(next(itr), 1)
itr.skip(3)
self.assertEqual(next(itr), 5)
itr.skip(3)
self.assertEqual(next(itr), 9)
self.assertFalse(itr.has_next())
if __name__ == '__main__':
unittest.main()


@ -46,12 +46,13 @@ class TestLabelSmoothing(unittest.TestCase):
[0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
[0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
]).unsqueeze(0).expand(2, 3, 7) # add batch dimension
self.model = test_utils.TestModel.build_model(self.args, self.d, self.d)
self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d)
self.model = self.task.build_model(self.args)
def test_nll_loss(self):
self.args.label_smoothing = 0.1
nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
nll_crit = CrossEntropyCriterion(self.args, self.task)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
@ -59,7 +60,7 @@ class TestLabelSmoothing(unittest.TestCase):
def test_padding(self):
self.args.label_smoothing = 0.1
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
loss, _, logging_output = crit(self.model, self.sample)
def get_one_no_padding(idx):
@ -68,7 +69,7 @@ class TestLabelSmoothing(unittest.TestCase):
sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
args1 = copy.copy(self.args)
args1.probs = args1.probs[idx, :, :].unsqueeze(0)
model1 = test_utils.TestModel.build_model(args1, self.d, self.d)
model1 = self.task.build_model(args1)
loss1, _, _ = crit(model1, sample1)
return loss1
@ -78,15 +79,15 @@ class TestLabelSmoothing(unittest.TestCase):
def test_reduction(self):
self.args.label_smoothing = 0.1
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
loss, _, logging_output = crit(self.model, self.sample, reduce=True)
unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
self.assertAlmostEqual(loss, unreduced_loss.sum())
def test_zero_eps(self):
self.args.label_smoothing = 0.0
nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
nll_crit = CrossEntropyCriterion(self.args, self.task)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
self.assertAlmostEqual(nll_loss, smooth_loss)


@ -80,10 +80,12 @@ class TestSequenceGenerator(unittest.TestCase):
]),
]
self.model = test_utils.TestModel.build_model(args, d, d)
task = test_utils.TestTranslationTask.setup_task(args, d, d)
self.model = task.build_model(args)
self.tgt_dict = task.target_dictionary
def test_with_normalization(self):
generator = SequenceGenerator([self.model])
generator = SequenceGenerator([self.model], self.tgt_dict)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@ -102,7 +104,7 @@ class TestSequenceGenerator(unittest.TestCase):
def test_without_normalization(self):
# Sentence 1: unchanged from the normalized case
# Sentence 2: beams swap order
generator = SequenceGenerator([self.model], normalize_scores=False)
generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@ -120,7 +122,7 @@ class TestSequenceGenerator(unittest.TestCase):
def test_with_lenpen_favoring_short_hypos(self):
lenpen = 0.6
generator = SequenceGenerator([self.model], len_penalty=lenpen)
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@ -138,7 +140,7 @@ class TestSequenceGenerator(unittest.TestCase):
def test_with_lenpen_favoring_long_hypos(self):
lenpen = 5.0
generator = SequenceGenerator([self.model], len_penalty=lenpen)
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@ -155,7 +157,7 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
def test_maxlen(self):
generator = SequenceGenerator([self.model], maxlen=2)
generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@ -172,7 +174,7 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])
def test_no_stop_early(self):
generator = SequenceGenerator([self.model], stop_early=False)
generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1


@ -83,8 +83,9 @@ class TestSequenceScorer(unittest.TestCase):
[0.3, 0.7], # sentence 3
]
model = test_utils.TestModel.build_model(args, d, d)
scorer = SequenceScorer([model])
task = test_utils.TestTranslationTask.setup_task(args, d, d)
model = task.build_model(args)
scorer = SequenceScorer([model], task.target_dictionary)
for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
self.assertHypoTokens(hypos[0], data[id]['target'])
self.assertHypoScore(hypos[0], expected_scores[id])


@ -8,23 +8,45 @@
import contextlib
from io import StringIO
import unittest
from unittest.mock import MagicMock, patch
import torch
from fairseq import data
import train
def mock_trainer(epoch, num_updates, end_of_epoch):
def mock_trainer(epoch, num_updates, iterations_in_epoch):
trainer = MagicMock()
trainer.load_checkpoint.return_value = {'epoch': epoch, 'end_of_epoch': end_of_epoch}
trainer.load_checkpoint.return_value = {
'train_iterator': {
'epoch': epoch,
'iterations_in_epoch': iterations_in_epoch,
'shuffle': False,
},
}
trainer.get_num_updates.return_value = num_updates
return trainer
def mock_loader(length):
loader = MagicMock()
loader.__next__.return_value = list(range(length))
return loader
def mock_dict():
d = MagicMock()
d.pad.return_value = 1
d.eos.return_value = 2
d.unk.return_value = 3
return d
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
tokens = torch.LongTensor(list(range(epoch_size)))
tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
epoch_itr = data.EpochBatchIterator(
dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
max_tokens=1,
)
return trainer, epoch_itr
class TestLoadCheckpoint(unittest.TestCase):
@ -40,29 +62,41 @@ class TestLoadCheckpoint(unittest.TestCase):
def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer = mock_trainer(2, 200, False)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 2)
self.assertEqual(next(ds), 50)
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
self.assertEqual(epoch_itr.iterations_in_epoch, 51)
def test_load_full_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer = mock_trainer(2, 300, True)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 3)
self.assertEqual(next(iter(ds)), 0)
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 3)
self.assertEqual(epoch_itr.iterations_in_epoch, 0)
self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)
def test_load_no_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer = mock_trainer(0, 0, False)
loader = mock_loader(150)
trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
self.patches['os.path.isfile'].return_value = False
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 1)
self.assertEqual(next(iter(ds)), 0)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 1)
self.assertEqual(epoch_itr.iterations_in_epoch, 0)
self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)
def tearDown(self):
patch.stopall()
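The behaviour exercised by these tests comes from EpochBatchIterator's position tracking; a condensed sketch of the intended usage, with the dataset built the same way as in get_trainer_and_epoch_itr above:

import torch
from fairseq import data

d = data.Dictionary()                 # default special symbols: pad=1, eos=2, unk=3
tokens = torch.LongTensor(list(range(4, 20)))
tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, d, shuffle=False)

epoch_itr = data.EpochBatchIterator(dataset=dataset, max_tokens=1)
itr = epoch_itr.next_epoch_itr(shuffle=False)
sample = next(itr)                    # one collated batch with 'id' and 'net_input'
print(epoch_itr.iterations_in_epoch)  # 1; this position is what checkpoints save and restore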


@ -8,18 +8,19 @@
import torch
from torch.autograd import Variable
from fairseq.data.language_pair_dataset import collate
from fairseq import utils
from fairseq.data import dictionary
from fairseq.data import Dictionary
from fairseq.data.language_pair_dataset import collate
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqModel,
)
from fairseq.tasks import FairseqTask
def dummy_dictionary(vocab_size, prefix='token_'):
d = dictionary.Dictionary()
d = Dictionary()
for i in range(vocab_size):
token = prefix + str(i)
d.add_symbol(token)
@ -46,14 +47,7 @@ def dummy_dataloader(
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
collate_fn=(
lambda samples: collate(
samples,
padding_idx,
eos_idx,
has_target=True,
)
),
collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)),
)
return iter(dataloader)
@ -71,14 +65,38 @@ class TestDataset(torch.utils.data.Dataset):
return len(self.data)
class TestTranslationTask(FairseqTask):
def __init__(self, args, src_dict, tgt_dict, model):
super().__init__(args)
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.model = model
@classmethod
def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
return cls(args, src_dict, tgt_dict, model)
def build_model(self, args):
return TestModel.build_model(args, self)
@property
def source_dictionary(self):
return self.src_dict
@property
def target_dictionary(self):
return self.tgt_dict
class TestModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@classmethod
def build_model(cls, args, src_dict, dst_dict):
encoder = TestEncoder(args, src_dict)
decoder = TestIncrementalDecoder(args, dst_dict)
def build_model(cls, args, task):
encoder = TestEncoder(args, task.source_dictionary)
decoder = TestIncrementalDecoder(args, task.target_dictionary)
return cls(encoder, decoder)

train.py (196 changed lines)

@ -7,14 +7,12 @@
# can be found in the PATENTS file in the same directory.
import collections
import itertools
import os
import math
import torch
from itertools import islice
from fairseq import criterions, models, options, progress_bar, utils
from fairseq.data import data_utils, data_loaders
from fairseq import data, distributed_utils, options, progress_bar, tasks, utils
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
@ -23,7 +21,6 @@ from fairseq.meters import AverageMeter, StopwatchMeter
def main(args):
if args.max_tokens is None:
args.max_tokens = 6000
print(args)
if not torch.cuda.is_available():
@ -31,27 +28,25 @@ def main(args):
torch.cuda.set_device(args.device_id)
torch.manual_seed(args.seed)
# Load dataset
splits = ['train', 'valid']
dataset = load_dataset(args, splits)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
for split in splits:
print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(args)
model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
# Load dataset splits
load_dataset_splits(args, task, ['train', 'valid'])
criterion = criterions.build_criterion(args, dataset.src_dict, dataset.dst_dict)
# Build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))
print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
# Build trainer
if args.fp16:
trainer = FP16Trainer(args, model, criterion)
trainer = FP16Trainer(args, task, model, criterion)
else:
if torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16')
trainer = Trainer(args, model, criterion)
trainer = Trainer(args, task, model, criterion)
print('| training on {} GPUs'.format(args.distributed_world_size))
print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
args.max_tokens,
@ -59,25 +54,24 @@ def main(args):
))
# Initialize dataloader
train_dataloader = dataset.train_dataloader_generator(
args.train_subset,
max_positions = trainer.get_model().max_positions()
epoch_itr = data.EpochBatchIterator(
dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=(
min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
min(args.max_target_positions, trainer.get_model().max_decoder_positions())
),
max_sentences=args.max_sentences_valid,
max_positions=max_positions,
ignore_invalid_inputs=True,
required_batch_size_multiple=8,
seed=args.seed,
sample_without_replacement=args.sample_without_replacement,
shard_id=args.distributed_rank,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
)
# Load the latest checkpoint if one is available
epoch, next_ds = load_checkpoint(args, trainer, train_dataloader)
load_checkpoint(args, trainer, epoch_itr)
# Send a dummy batch to warm the caching allocator
dummy_batch = data_utils.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
trainer.dummy_train_step(dummy_batch)
# Train until the learning rate gets too small
@ -88,58 +82,41 @@ def main(args):
train_meter.start()
valid_losses = [None]
valid_subsets = args.valid_subset.split(',')
while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
while lr > args.min_lr and epoch_itr.epoch <= max_epoch and trainer.get_num_updates() < max_update:
# train for one epoch
train(args, trainer, next_ds, epoch, dataset)
train(args, trainer, task, epoch_itr)
if epoch % args.validate_interval == 0:
valid_losses = validate(args, trainer, dataset, valid_subsets, epoch)
if epoch_itr.epoch % args.validate_interval == 0:
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
# only use first validation loss to update the learning rate
lr = trainer.lr_step(epoch, valid_losses[0])
lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
# save checkpoint
if epoch % args.save_interval == 0:
save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=valid_losses[0])
epoch += 1
next_ds = next(train_dataloader)
if epoch_itr.epoch % args.save_interval == 0:
save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
def load_dataset(args, splits):
is_raw = not data_utils.has_binary_files(args.data, splits)
dataset = data_loaders.load_dataset(args, splits, is_raw)
return dataset
def train(args, trainer, itr, epoch, dataset):
def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch."""
# Set seed based on args.seed and the epoch number so that we get
# reproducible results when resuming from checkpoints
seed = args.seed + epoch
torch.manual_seed(seed)
# reset training meters
for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
# Initialize data iterator
itr = epoch_itr.next_epoch_itr()
progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')
# update parameters every N batches
if epoch <= len(args.update_freq):
update_freq = args.update_freq[epoch - 1]
if epoch_itr.epoch <= len(args.update_freq):
update_freq = args.update_freq[epoch_itr.epoch - 1]
else:
update_freq = args.update_freq[-1]
extra_meters = collections.defaultdict(lambda: AverageMeter())
first_valid = args.valid_subset.split(',')[0]
max_update = args.max_update or math.inf
num_batches = len(itr)
progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
for i, sample in enumerate(progress):
num_batches = len(epoch_itr)
for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
if i < num_batches - 1 and (i + 1) % update_freq > 0:
# buffer updates according to --update-freq
trainer.train_step(sample, update_params=False)
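The hunk above shows only the buffering branch of --update-freq; the pattern it implements is gradient accumulation: every batch except the last in a group of `update_freq` calls train_step with update_params=False, and the final call applies the buffered updates. A sketch of that pattern follows (the update_params=True call is an assumption, as that branch is not shown in this hunk):

# Sketch of --update-freq gradient accumulation (illustrative, not the commit's code)
update_freq = 4                                            # e.g. --update-freq 4
for i, sample in enumerate(samples):
    if (i + 1) % update_freq > 0:
        trainer.train_step(sample, update_params=False)    # buffer gradients only
    else:
        trainer.train_step(sample, update_params=True)     # assumed: apply buffered updates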
@@ -165,8 +142,8 @@ def train(args, trainer, itr, epoch, dataset):
num_updates = trainer.get_num_updates()
if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
valid_losses = validate(args, trainer, dataset, [first_valid], epoch)
save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=valid_losses[0])
valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
if num_updates >= max_update:
break
@@ -177,6 +154,12 @@ def train(args, trainer, itr, epoch, dataset):
stats[k] = meter.avg
progress.print(stats)
# reset training meters
for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
def get_training_stats(trainer):
stats = collections.OrderedDict()
@@ -202,27 +185,24 @@ def get_training_stats(trainer):
return stats
def validate(args, trainer, dataset, subsets, epoch):
def validate(args, trainer, task, epoch_itr, subsets):
"""Evaluate the model on the validation set(s) and return the losses."""
valid_losses = []
for subset in subsets:
# Initialize dataloader
max_positions_valid = (
trainer.get_model().max_encoder_positions(),
trainer.get_model().max_decoder_positions(),
)
itr = dataset.eval_dataloader(
subset,
# Initialize data iterator
itr = data.EpochBatchIterator(
dataset=task.dataset(subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences_valid,
max_positions=max_positions_valid,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
descending=True, # largest batch first to warm the caching allocator
shard_id=args.distributed_rank,
max_positions=trainer.get_model().max_positions(),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
seed=args.seed,
num_shards=args.distributed_world_size,
)
shard_id=args.distributed_rank,
).next_epoch_itr(shuffle=False)
progress = progress_bar.build_progress_bar(
args, itr, epoch,
args, itr, epoch_itr.epoch,
prefix='valid on \'{}\' subset'.format(subset),
no_progress_bar='simple'
)
@@ -232,8 +212,8 @@ def validate(args, trainer, dataset, subsets, epoch):
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
extra_meters = collections.defaultdict(lambda: AverageMeter())
for sample in progress:
log_output = trainer.valid_step(sample)
@@ -274,9 +254,11 @@ def get_perplexity(loss):
return float('inf')
def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
if args.no_save or args.distributed_rank > 0:
def save_checkpoint(args, trainer, epoch_itr, val_loss):
if args.no_save or not distributed_utils.is_master(args):
return
epoch = epoch_itr.epoch
end_of_epoch = epoch_itr.end_of_epoch()
updates = trainer.get_num_updates()
checkpoint_conds = collections.OrderedDict()
@@ -298,11 +280,9 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
if val_loss is not None:
save_checkpoint.best = min(val_loss, prev_best)
extra_state = {
'best': prev_best,
'end_of_epoch': end_of_epoch,
'epoch': epoch,
'best': save_checkpoint.best,
'train_iterator': epoch_itr.state_dict(),
'val_loss': val_loss,
'wall_time': trainer.get_meter('wall').elapsed_time,
}
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
@@ -325,46 +305,36 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
os.remove(old_chk)
def load_checkpoint(args, trainer, train_dataloader):
def load_checkpoint(args, trainer, epoch_itr):
"""Load a checkpoint and replay dataloader to match."""
os.makedirs(args.save_dir, exist_ok=True)
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
epoch = 1
ds = None
if os.path.isfile(checkpoint_path):
extra_state = trainer.load_checkpoint(checkpoint_path)
if extra_state is not None:
epoch = extra_state['epoch']
end_of_epoch = extra_state.get('end_of_epoch', True)
trainer_updates = trainer.get_num_updates()
# replay train iterator to match checkpoint
epoch_itr.load_state_dict(extra_state['train_iterator'])
print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
trainer.lr_step(epoch_itr.epoch)
trainer.lr_step_update(trainer.get_num_updates())
if 'best' in extra_state:
save_checkpoint.best = extra_state['best']
print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
trainer.lr_step(epoch)
updates = 0
for i in range(epoch):
ds = next(train_dataloader)
updates += len(ds)
if not end_of_epoch and ds is not None and updates > trainer_updates:
completed_batches = len(ds) - (updates - trainer_updates)
assert completed_batches >= 0
ds = iter(ds)
print('| resuming from batch {}'.format(completed_batches + 1))
# consume completed batches
next(islice(ds, completed_batches, completed_batches), None)
else:
if not end_of_epoch:
print('| WARNING: checkpoint is not at end of epoch')
ds = next(train_dataloader)
epoch += 1
trainer.get_meter('wall').reset(init=extra_state.get('wall_time', 0))
return epoch, ds or next(train_dataloader)
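The net effect of the new load_checkpoint is that the dataloader position is restored from the checkpoint instead of being replayed batch by batch. A sketch of the round trip, using only calls that appear in this diff (the exact contents of the state dict are internal to EpochBatchIterator):

# Sketch of the save/restore round trip for dataloader position (illustrative)
state = epoch_itr.state_dict()        # stored as extra_state['train_iterator'] at save time
# ... checkpoint is written, training restarts ...
epoch_itr.load_state_dict(state)      # fast-forwards to the saved epoch and offset
itr = epoch_itr.next_epoch_itr()      # continues from the stored position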
def load_dataset_splits(args, task, splits):
for split in splits:
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
try:
task.load_dataset(split_k)
print('| {} {} {} examples'.format(args.data, split_k, len(task.dataset(split_k))))
except FileNotFoundError as e:
if k > 0:
break
raise e
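load_dataset_splits probes numbered shards of each split: for 'train' it loads 'train', 'train1', 'train2', ... until a shard is missing, and only re-raises the FileNotFoundError if the base split itself is absent. For example, with a hypothetical data directory containing train, train1 and valid:

# loads 'train' and 'train1' (stops at the missing 'train2'), then 'valid'
load_dataset_splits(args, task, ['train', 'valid'])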
if __name__ == '__main__':