Conv LM implementation

This implements the convolutional language model from "Language Modeling with Gated Convolutional Networks" (https://arxiv.org/pdf/1612.08083.pdf)

There are three modes for constructing batches (illustrated in the sketch after the list):

- token block: fill each sample with a fixed number of tokens, ignoring sentence boundaries - this is what the paper used for training
- complete: fill each sample with up to the specified number of tokens, but only with complete sentences (if the next sentence would go over the token limit, it moves to the next sample) - this is what the paper used for evaluation
- eos: one sentence per sample (blank lines are skipped)
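
A minimal sketch of how the three modes chunk a toy token stream (sizes and block limit invented for illustration; this mirrors the logic of the TokenBlockDataset added in this commit, not its API):

import math

tokens = list(range(12))   # a stream of 12 tokens
sizes = [3, 4, 2, 3]       # sentence lengths; they sum to len(tokens)
block_size = 5

# token block ('none'): fixed-size blocks, ignoring sentence boundaries
none_blocks = [(i * block_size, min((i + 1) * block_size, len(tokens)))
               for i in range(math.ceil(len(tokens) / block_size))]

# 'complete': pack whole sentences until the next one would overflow
complete_blocks, start, cur = [], 0, 0
for sz in sizes:
    if cur > 0 and cur + sz > block_size:
        complete_blocks.append((start, start + cur))
        start, cur = start + cur, 0
    cur += sz
complete_blocks.append((start, start + cur))

# 'eos': one sentence per sample
eos_blocks, start = [], 0
for sz in sizes:
    eos_blocks.append((start, start + sz))
    start += sz

print(none_blocks)      # [(0, 5), (5, 10), (10, 12)]
print(complete_blocks)  # [(0, 3), (3, 7), (7, 12)]
print(eos_blocks)       # [(0, 3), (3, 7), (7, 9), (9, 12)]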

Some results:

Model    | Dataset | Perplexity
GCNN-13  | GBW     | 37.46
GCNN-14B | GBW     | 33.88
GCNN-8   | Wiki103 | 43.76
GCNN-14  | Wiki103 | 35.66

train:

python train.py /private/home/abaevski/data/wiki103 --save-dir /tmp --fp16 --max-epoch 35 --save-interval 1 --save-interval-updates 1000 --keep-interval-updates 25 --arch fconv_lm --optimizer nag --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 --decoder-embed-dim 280 --decoder-layers '[(850, 6)] * 3 + [(850,1)] + [(850,5)] * 4 + [(850,1)] + [(850,4)] * 3 + [(1024,4)] + [(2048, 4)]' --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion cross_entropy --max-tokens 1024 --max-target-positions 1024 --seed 1 --log-format json --log-interval 500
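
Note that --decoder-layers is a Python expression; build_model eval()s it into a list of (out_channels, kernel_size) tuples, one per convolutional block. For the command above:

# the value of --decoder-layers, eval()'d the way build_model does it
layers = eval('[(850, 6)] * 3 + [(850,1)] + [(850,5)] * 4'
              ' + [(850,1)] + [(850,4)] * 3 + [(1024,4)] + [(2048, 4)]')
print(len(layers))  # 14 convolutional blocks
print(layers[0])    # (850, 6): 850 output channels, kernel width 6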

eval:

python eval_lm.py ~abaevski/data/wiki103 --path '/checkpoint02/abaevski/2018-04-27/lm_wiki.fp16.mxup300000.fconv.adam.lrs=reduce_lr_on_plateau.emb280.layers(850,6)*3+(850,1)+(850,5)*4+(850,1)+(850,4)*3+(1024,1)+(2048,4).lr0.0005.clp0.1.drp0.3.wd0.0.crt=cross_entropy.mxtk2048.smptk256.seed1.ngpu8/checkpoint_last.pt'
alexeib 2018-05-25 14:43:37 +01:00 committed by Myle Ott
parent 4e1ec2d883
commit 4c2ef2de74
34 changed files with 1557 additions and 675 deletions

eval_lm.py (new file)

@ -0,0 +1,84 @@
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from fairseq import options, utils, progress_bar
from fairseq.data import data_utils, data_loaders
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
def main(args):
assert args.path is not None, '--path required for evaluation!'
print(args)
if args.max_target_positions is None:
args.max_target_positions = 1024
use_cuda = torch.cuda.is_available() and not args.cpu
dataset = data_loaders.load_dataset(args, [args.gen_subset], False)
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
print('| Dictionary: {} types'.format(len(dataset.src_dict)))
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
model.make_generation_fast_()
model.src_dict = dataset.src_dict
model.dst_dict = dataset.dst_dict
itr = dataset.eval_dataloader(
args.gen_subset,
max_sentences=args.max_sentences or 4,
max_positions=args.max_target_positions or 1024,
descending=True,
)
if args.num_shards > 1:
if args.shard_id < 0 or args.shard_id >= args.num_shards:
raise ValueError('--shard-id must be between 0 and num_shards')
itr = data_utils.sharded_iterator(itr, args.num_shards, args.shard_id)
gen_timer = StopwatchMeter()
scorer = SequenceScorer(models)
if use_cuda:
scorer.cuda()
score_sum = 0.
count = 0
with progress_bar.build_progress_bar(args, itr) as t:
results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
wps_meter = TimeMeter()
for _, src_tokens, __, hypos in results:
for hypo in hypos:
pos_scores = hypo['positional_scores']
inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
if inf_scores.any():
print('| Skipping tokens with inf scores:',
dataset.src_dict.string(hypo['tokens'][inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum()
count += pos_scores.numel()
wps_meter.update(src_tokens.size(0))
t.log({'wps': round(wps_meter.avg)})
avg_nll_loss = -score_sum / count
print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))
if __name__ == '__main__':
parser = options.get_eval_lm_parser()
args = parser.parse_args()
main(args)
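
For reference, the arithmetic behind the final print: positional_scores hold per-token log-probabilities (natural log), so the reported perplexity is exp of the average negative log-likelihood. A toy example with invented scores:

import numpy as np

pos_scores = np.array([-2.3, -1.1, -3.0, -0.7])  # invented per-token log-probs
avg_nll = -pos_scores.sum() / pos_scores.size
print('Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll, np.exp(avg_nll)))
# Loss: 1.7750, Perplexity: 5.90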

fairseq/criterions/adaptive_loss.py (new file)

@ -0,0 +1,73 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch.nn.functional as F
from fairseq import utils
from . import FairseqCriterion, register_criterion
@register_criterion('adaptive_loss')
class AdaptiveLoss(FairseqCriterion):
"""This is an implementation of the loss function accompanying the adaptive softmax approximation for
graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""
def __init__(self, args, src_dict, dst_dict):
super().__init__(args, src_dict, dst_dict)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert hasattr(model.decoder, 'adaptive_softmax') and model.decoder.adaptive_softmax is not None
adaptive_softmax = model.decoder.adaptive_softmax
net_output = model(**sample['net_input'])
target = model.get_targets(sample, net_output).view(-1)
bsz = target.size(0)
logits, target = adaptive_softmax(net_output[0], target)
assert len(target) == len(logits)
loss = net_output[0].new(1 if reduce else bsz).zero_()
for i in range(len(target)):
if target[i] is not None:
assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1))
loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'sample_size': sample_size,
}
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output
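
A minimal sketch of the cluster-wise summation in forward() above, assuming adaptive_softmax returned a head cluster and one tail cluster (shapes and values invented; written against the modern reduction= API rather than the deprecated size_average/reduce flags used above):

import torch
import torch.nn.functional as F

# head: every target gets a logit row over the most frequent words;
# tail: only the targets that fall into the rare-word cluster
logits = [torch.randn(8, 1000), torch.randn(3, 5000)]
target = [torch.randint(0, 1000, (8,)), torch.randint(0, 5000, (3,))]

loss = sum(F.cross_entropy(l, t, reduction='sum')
           for l, t in zip(logits, target))

Note that aggregate_logging_outputs divides the summed loss by math.log(2), so the aggregated value is reported in bits rather than nats.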

fairseq/data.py (deleted file)

@ -1,458 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
import itertools
import glob
import math
import numbers
import numpy as np
import os
import torch
from torch.autograd import Variable
import torch.utils.data
from fairseq.dictionary import Dictionary
from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
def has_binary_files(data_dir, splits):
for split in splits:
if len(glob.glob(os.path.join(data_dir, '{}.*-*.*.bin'.format(split)))) < 2:
return False
return True
def infer_language_pair(path, splits):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
for filename in os.listdir(path):
parts = filename.split('.')
for split in splits:
if parts[0] == split and parts[-1] == 'idx':
src, dst = parts[1].split('-')
break
return src, dst
def load_dictionaries(path, src_lang, dst_lang):
"""Load dictionaries for a given language pair."""
src_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(src_lang)))
dst_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(dst_lang)))
return src_dict, dst_dict
def load_dataset(path, load_splits, src=None, dst=None):
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
if src is None and dst is None:
# find language pair automatically
src, dst = infer_language_pair(path, load_splits)
assert src is not None and dst is not None, 'Source and target languages should be provided'
src_dict, dst_dict = load_dictionaries(path, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
# Load dataset from binary files
def all_splits_exist(src, dst, lang):
for split in load_splits:
filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
if not os.path.exists(os.path.join(path, filename)):
return False
return True
# infer langcode
if all_splits_exist(src, dst, src):
langcode = '{}-{}'.format(src, dst)
elif all_splits_exist(dst, src, src):
langcode = '{}-{}'.format(dst, src)
else:
raise Exception('Dataset cannot be loaded from path: ' + path)
def fmt_path(fmt, *args):
return os.path.join(path, fmt.format(*args))
for split in load_splits:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
dst_path = fmt_path('{}.{}.{}', prefix, langcode, dst)
if not IndexedInMemoryDataset.exists(src_path):
break
target_dataset = None
if IndexedInMemoryDataset.exists(dst_path):
target_dataset = IndexedInMemoryDataset(dst_path)
dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path),
target_dataset,
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
def load_raw_text_dataset(path, load_splits, src=None, dst=None):
"""Loads specified data splits (e.g., test, train or valid) from raw text
files in the specified folder."""
if src is None and dst is None:
# find language pair automatically
src, dst = infer_language_pair(path, load_splits)
assert src is not None and dst is not None, 'Source and target languages should be provided'
src_dict, dst_dict = load_dictionaries(path, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
# Load dataset from raw text files
for split in load_splits:
src_path = os.path.join(path, '{}.{}'.format(split, src))
dst_path = os.path.join(path, '{}.{}'.format(split, dst))
dataset.splits[split] = LanguagePairDataset(
IndexedRawTextDataset(src_path, src_dict),
IndexedRawTextDataset(dst_path, dst_dict),
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
class LanguageDatasets(object):
def __init__(self, src, dst, src_dict, dst_dict):
self.src = src
self.dst = dst
self.src_dict = src_dict
self.dst_dict = dst_dict
self.splits = {}
assert self.src_dict.pad() == self.dst_dict.pad()
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
def train_dataloader_generator(
self, split, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), seed=None, sample_without_replacement=0,
shard_id=0, num_shards=1
):
dataset = self.splits[split]
with numpy_seed(seed):
batches = uneven_batches_by_size(
dataset.src, dataset.dst, max_tokens=max_tokens,
max_sentences=max_sentences, max_positions=max_positions,
# FP16: during training keep the batch size a multiple of 8
required_batch_size_multiple=8,
)
frozen_batches = tuple(batches) # freeze
def dataloader(b):
b = mask_batches(b, shard_id=shard_id, num_shards=num_shards) # shard dataset
return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=b)
for epoch in itertools.count(1):
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with numpy_seed(seed + epoch):
batches = list(frozen_batches) # copy
np.random.shuffle(batches)
if sample_without_replacement > 0:
# emit sub-epoch dataloaders
while len(batches) >= sample_without_replacement:
sampled_batches = batches[:sample_without_replacement]
remaining_batches = batches[sample_without_replacement:]
yield dataloader(sampled_batches)
batches = remaining_batches
if len(batches) > 0:
yield dataloader(batches)
else:
# emit full dataloader
yield dataloader(batches)
def eval_dataloader(self, split, num_workers=0, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False,
descending=False, shard_id=0, num_shards=1):
dataset = self.splits[split]
batch_sampler = batches_by_size(
dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending,
allow_different_src_lens=True)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
class sharded_iterator(object):
def __init__(self, itr, num_shards, shard_id):
assert shard_id >= 0 and shard_id < num_shards
self.itr = itr
self.num_shards = num_shards
self.shard_id = shard_id
def __len__(self):
return len(self.itr)
def __iter__(self):
for i, v in enumerate(self.itr):
if i % self.num_shards == self.shard_id:
yield v
class LanguagePairDataset(torch.utils.data.Dataset):
# padding constants
LEFT_PAD_SOURCE = True
LEFT_PAD_TARGET = False
def __init__(self, src, dst, pad_idx, eos_idx):
self.src = src
self.dst = dst
self.pad_idx = pad_idx
self.eos_idx = eos_idx
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
res = {'id': i, 'source': source}
if self.dst:
res['target'] = self.dst[i].long() - 1
return res
def __len__(self):
return len(self.src)
def collater(self, samples):
return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
@staticmethod
def collate(samples, pad_idx, eos_idx, has_target=True):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return LanguagePairDataset.collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
ntokens = None
if has_target:
target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
},
'target': target,
}
@staticmethod
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
assert src[-1] == eos_idx
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
if left_pad:
copy_tensor(v, res[i][size-len(v):])
else:
copy_tensor(v, res[i][:len(v)])
return res
def _valid_size(src_size, dst_size, max_positions):
if isinstance(max_positions, numbers.Number):
max_src_positions, max_dst_positions = max_positions, max_positions
else:
max_src_positions, max_dst_positions = max_positions
if src_size < 1 or src_size > max_src_positions:
return False
if dst_size is not None and (dst_size < 1 or dst_size > max_dst_positions):
return False
return True
def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=False, allow_different_src_lens=False,
required_batch_size_multiple=1):
batch = []
mult = required_batch_size_multiple
def yield_batch(next_idx, num_tokens):
if len(batch) == 0:
return False
if len(batch) == max_sentences:
return True
if num_tokens > max_tokens:
return True
if not allow_different_src_lens and \
(src.sizes[batch[0]] != src.sizes[next_idx]):
return True
return False
sample_len = 0
sample_lens = []
ignored = []
for idx in map(int, indices):
src_size = src.sizes[idx]
dst_size = dst.sizes[idx] if dst else src_size
if not _valid_size(src_size, dst_size, max_positions):
if ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception((
"Sample #{} has size (src={}, dst={}) but max size is {}."
" Skip this example with --skip-invalid-size-inputs-valid-test"
).format(idx, src_size, dst_size, max_positions))
sample_lens.append(max(src_size, dst_size))
sample_len = max(sample_len, sample_lens[-1])
num_tokens = (len(batch) + 1) * sample_len
if yield_batch(idx, num_tokens):
mod8_len = max(mult * (len(batch) // mult), len(batch) % mult)
yield batch[:mod8_len]
batch = batch[mod8_len:]
sample_lens = sample_lens[mod8_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
yield batch
if len(ignored) > 0:
print("Warning! {} samples are either too short or too long "
"and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False,
descending=False, required_batch_size_multiple=1, allow_different_src_lens=False):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.argsort(src.sizes, kind='mergesort')
if descending:
indices = np.flip(indices, 0)
return list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=allow_different_src_lens,
required_batch_size_multiple=required_batch_size_multiple,
))
def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024),
required_batch_size_multiple=1):
"""Returns batches of indices bucketed by size. Batches may contain
sequences of different lengths."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.random.permutation(len(src))
# sort by sizes
indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
batches = list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=True, allow_different_src_lens=True,
required_batch_size_multiple=required_batch_size_multiple,
))
return batches
def mask_batches(batch_sampler, shard_id, num_shards):
if num_shards == 1:
return batch_sampler
res = [
batch
for i, batch in enumerate(batch_sampler)
if i % num_shards == shard_id
]
expected_length = int(math.ceil(len(batch_sampler) / num_shards))
return res + [[]] * (expected_length - len(res))
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
if seed is None:
yield
return
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
def get_dummy_batch(ntokens, src_dict, dst_dict, src_len=128, tgt_len=128):
bsz = int(ntokens / max(src_len, tgt_len))
bsz = math.ceil(bsz / 8) * 8
assert src_dict.pad() == dst_dict.pad()
pad_idx = src_dict.pad()
src_vocab, dst_vocab = len(src_dict), len(dst_dict)
dummy_batch = {}
dummy_batch['id'] = Variable(torch.arange(bsz).long().cuda())
dummy_batch['ntokens'] = tgt_len * bsz
dummy_batch['target'] = Variable(torch.Tensor(bsz, tgt_len).uniform_(pad_idx + 1, dst_vocab - 1).long().cuda())
input = {}
input['prev_output_tokens'] = Variable(dummy_batch['target'].data.clone())
input['src_lengths'] = Variable(torch.LongTensor(bsz).fill_(src_len).cuda())
input['src_tokens'] = Variable(torch.Tensor(bsz, src_len).uniform_(pad_idx + 1, src_vocab - 1).long().cuda())
dummy_batch['net_input'] = input
return dummy_batch

fairseq/data/__init__.py (new file)

@ -0,0 +1,13 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .dictionary import Dictionary
from .token_block_dataset import TokenBlockDataset
from .language_dataset import LanguageDatasets
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .offset_dataset import OffsetDataset

fairseq/data/consts.py (new file)

@ -0,0 +1,10 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# padding constants
LEFT_PAD_SOURCE = True
LEFT_PAD_TARGET = False

fairseq/data/data_loaders.py (new file)

@ -0,0 +1,24 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from fairseq.data import LanguagePairDataset, MonolingualDataset
from fairseq.data.data_utils import infer_language_pair
def load_dataset(args, splits, is_raw):
""" Detect if we have a multi language dataset, or a single language dataset """
if args.source_lang is None and args.target_lang is None:
# find language pair automatically
args.source_lang, args.target_lang = infer_language_pair(args.data, splits)
if args.source_lang is None and args.target_lang is None and all(
os.path.exists(os.path.join(args.data, '{}.bin'.format(split))) for split in splits):
cls = MonolingualDataset
else:
cls = LanguagePairDataset
return cls.create_dataset(args, splits, is_raw)

fairseq/data/data_utils.py (new file)

@ -0,0 +1,243 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
import glob
import math
import numbers
import numpy as np
import os
import torch
from torch.autograd import Variable
import torch.utils.data
from fairseq.data.dictionary import Dictionary
from fairseq.data.indexed_dataset import SizedDataset
def has_binary_files(data_dir, splits):
for split in splits:
if len(glob.glob(os.path.join(data_dir, '{}*.bin'.format(split)))) == 0:
return False
return True
def infer_language_pair(path, splits):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
for filename in os.listdir(path):
parts = filename.split('.')
for split in splits:
if len(parts) >= 3 and parts[0] == split and parts[-1] == 'idx':
src, dst = parts[1].split('-')
break
return src, dst
def load_dictionaries(path, src_lang, dst_lang):
"""Load dictionaries for a given language pair."""
src_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(src_lang)))
dst_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(dst_lang)))
return src_dict, dst_dict
def fmt_path(path, fmt, *args):
return os.path.join(path, fmt.format(*args))
class sharded_iterator(object):
def __init__(self, itr, num_shards, shard_id):
assert shard_id >= 0 and shard_id < num_shards
self.itr = itr
self.num_shards = num_shards
self.shard_id = shard_id
def __len__(self):
return len(self.itr)
def __iter__(self):
for i, v in enumerate(self.itr):
if i % self.num_shards == self.shard_id:
yield v
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
assert src[-1] == eos_idx
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
if left_pad:
copy_tensor(v, res[i][size - len(v):])
else:
copy_tensor(v, res[i][:len(v)])
return res
def _valid_size(src_size, dst_size, max_positions):
if isinstance(max_positions, numbers.Number):
max_src_positions, max_dst_positions = max_positions, max_positions
else:
max_src_positions, max_dst_positions = max_positions
if src_size < 1 or src_size > max_src_positions:
return False
if dst_size is not None and (dst_size < 1 or dst_size > max_dst_positions):
return False
return True
def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=False, allow_different_src_lens=False,
required_batch_size_multiple=1):
batch = []
mult = required_batch_size_multiple
def yield_batch(next_idx, num_tokens):
if len(batch) == 0:
return False
if len(batch) == max_sentences:
return True
if num_tokens > max_tokens:
return True
if not allow_different_src_lens and \
(src.sizes[batch[0]] != src.sizes[next_idx]):
return True
return False
sample_len = 0
sample_lens = []
ignored = []
for idx in map(int, indices):
src_size = src.sizes[idx]
dst_size = dst.sizes[idx] if dst else src_size
if not _valid_size(src_size, dst_size, max_positions):
if ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception((
"Sample #{} has size (src={}, dst={}) but max size is {}."
" Skip this example with --skip-invalid-size-inputs-valid-test"
).format(idx, src_size, dst_size, max_positions))
sample_lens.append(max(src_size, dst_size))
sample_len = max(sample_len, sample_lens[-1])
num_tokens = (len(batch) + 1) * sample_len
if yield_batch(idx, num_tokens):
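# emit the largest multiple of `mult` that fits in the current batch
# (or the whole batch, if it is smaller than `mult`)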
mod8_len = max(mult * (len(batch) // mult), len(batch) % mult)
yield batch[:mod8_len]
batch = batch[mod8_len:]
sample_lens = sample_lens[mod8_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
yield batch
if len(ignored) > 0:
print("Warning! {} samples are either too short or too long "
"and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False,
descending=False, required_batch_size_multiple=1, allow_different_src_lens=False):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, SizedDataset) and (dst is None or isinstance(dst, SizedDataset))
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.argsort(src.sizes, kind='mergesort')
if descending:
indices = np.flip(indices, 0)
return list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=allow_different_src_lens,
required_batch_size_multiple=required_batch_size_multiple,
))
def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024),
required_batch_size_multiple=1):
"""Returns batches of indices bucketed by size. Batches may contain
sequences of different lengths."""
assert isinstance(src, SizedDataset) and isinstance(dst, SizedDataset)
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.random.permutation(len(src))
# sort by sizes
indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
batches = list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=True, allow_different_src_lens=True,
required_batch_size_multiple=required_batch_size_multiple,
))
return batches
def mask_batches(batch_sampler, shard_id, num_shards):
if num_shards == 1:
return batch_sampler
res = [
batch
for i, batch in enumerate(batch_sampler)
if i % num_shards == shard_id
]
expected_length = int(math.ceil(len(batch_sampler) / num_shards))
return res + [[]] * (expected_length - len(res))
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
if seed is None:
yield
return
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
def get_dummy_batch(ntokens, src_dict, dst_dict, src_len=128, tgt_len=128):
bsz = int(ntokens / max(src_len, tgt_len))
bsz = math.ceil(bsz / 8) * 8
assert src_dict.pad() == dst_dict.pad()
pad_idx = src_dict.pad()
src_vocab, dst_vocab = len(src_dict), len(dst_dict)
dummy_batch = {}
dummy_batch['id'] = Variable(torch.arange(bsz).long().cuda())
dummy_batch['ntokens'] = tgt_len * bsz
dummy_batch['target'] = Variable(torch.Tensor(bsz, tgt_len).uniform_(pad_idx + 1, dst_vocab - 1).long().cuda())
input = {}
input['prev_output_tokens'] = Variable(dummy_batch['target'].data.clone())
input['src_lengths'] = Variable(torch.LongTensor(bsz).fill_(src_len).cuda())
input['src_tokens'] = Variable(torch.Tensor(bsz, src_len).uniform_(pad_idx + 1, src_vocab - 1).long().cuda())
dummy_batch['net_input'] = input
return dummy_batch
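
A self-contained sketch of the sharding behaviour of mask_batches above (logic copied inline so the snippet runs standalone): each shard keeps every num_shards-th batch, and short shards are padded with empty batches so that all workers take the same number of steps.

import math

def mask_batches(batch_sampler, shard_id, num_shards):
    # inline copy of the function above, for illustration only
    if num_shards == 1:
        return batch_sampler
    res = [b for i, b in enumerate(batch_sampler) if i % num_shards == shard_id]
    expected_length = int(math.ceil(len(batch_sampler) / num_shards))
    return res + [[]] * (expected_length - len(res))

batches = [[0, 1], [2, 3], [4, 5], [6], [7, 8]]
print(mask_batches(batches, shard_id=0, num_shards=2))  # [[0, 1], [4, 5], [7, 8]]
print(mask_batches(batches, shard_id=1, num_shards=2))  # [[2, 3], [6], []]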

fairseq/data/indexed_dataset.py

@ -9,6 +9,7 @@ import numpy as np
import os
import struct
import torch
import torch.utils.data
from fairseq.tokenizer import Tokenizer
@ -48,10 +49,20 @@ def data_file_path(prefix_path):
return prefix_path + '.bin'
class IndexedDataset(object):
class SizedDataset(torch.utils.data.Dataset):
def __init__(self):
self._sizes = None
@property
def sizes(self):
return self._sizes
class IndexedDataset(SizedDataset):
"""Loader for TorchNet IndexedDataset"""
def __init__(self, path):
super().__init__()
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == b'TNTIDX\x00\x00'
@ -62,7 +73,7 @@ class IndexedDataset(object):
self.size, self.s = struct.unpack('<QQ', f.read(16))
self.dim_offsets = read_longs(f, self.size + 1)
self.data_offsets = read_longs(f, self.size + 1)
self.sizes = read_longs(f, self.s)
self._sizes = read_longs(f, self.s)
self.read_data(path)
def read_data(self, path):
@ -121,7 +132,7 @@ class IndexedRawTextDataset(IndexedDataset):
def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
self.tokens_list = []
self.lines = []
self.sizes = []
self._sizes = []
self.append_eos = append_eos
self.reverse_order = reverse_order
self.read_data(path, dictionary)
@ -136,8 +147,8 @@ class IndexedRawTextDataset(IndexedDataset):
append_eos=self.append_eos, reverse_order=self.reverse_order,
) + 1 # +1 for Lua compatibility
self.tokens_list.append(tokens)
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
self._sizes.append(len(tokens))
self._sizes = np.array(self._sizes)
def __getitem__(self, i):
self.check_index(i)
@ -155,10 +166,9 @@ class IndexedRawTextDataset(IndexedDataset):
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,

fairseq/data/language_dataset.py (new file)

@ -0,0 +1,80 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import torch
from fairseq.data.data_utils import numpy_seed, uneven_batches_by_size, mask_batches, batches_by_size
class LanguageDatasets(object):
def __init__(self, src, dst, src_dict, dst_dict):
self.src = src
self.dst = dst
self.src_dict = src_dict
self.dst_dict = dst_dict
self.splits = {}
assert self.src_dict.pad() == self.dst_dict.pad()
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
def train_dataloader_generator(
self, split, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), seed=None, sample_without_replacement=0,
shard_id=0, num_shards=1
):
dataset = self.splits[split]
with numpy_seed(seed):
batches = uneven_batches_by_size(
dataset.src, dataset.dst, max_tokens=max_tokens,
max_sentences=max_sentences, max_positions=max_positions,
# FP16: during training keep the batch size a multiple of 8
required_batch_size_multiple=8,
)
frozen_batches = tuple(batches) # freeze
def dataloader(b):
b = mask_batches(b, shard_id=shard_id, num_shards=num_shards) # shard dataset
return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=b)
for epoch in itertools.count(1):
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with numpy_seed(seed + epoch):
batches = list(frozen_batches) # copy
np.random.shuffle(batches)
if sample_without_replacement > 0:
# emit sub-epoch dataloaders
while len(batches) >= sample_without_replacement:
sampled_batches = batches[:sample_without_replacement]
remaining_batches = batches[sample_without_replacement:]
yield dataloader(sampled_batches)
batches = remaining_batches
if len(batches) > 0:
yield dataloader(batches)
else:
# emit full dataloader
yield dataloader(batches)
def eval_dataloader(self, split, num_workers=0, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False,
descending=False, shard_id=0, num_shards=1):
dataset = self.splits[split]
batch_sampler = batches_by_size(
dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending,
allow_different_src_lens=True)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)

fairseq/data/language_pair_dataset.py (new file)

@ -0,0 +1,154 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import os
import torch
import torch.utils
from fairseq.data import LanguageDatasets
from fairseq.data.consts import LEFT_PAD_TARGET, LEFT_PAD_SOURCE
from fairseq.data.data_utils import fmt_path, load_dictionaries, collate_tokens
from fairseq.data.indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
def collate(samples, pad_idx, eos_idx, has_target):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=LEFT_PAD_SOURCE)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
ntokens = None
if has_target:
target = merge('target', left_pad=LEFT_PAD_TARGET)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=LEFT_PAD_TARGET,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
},
'target': target,
}
class LanguagePairDataset(torch.utils.data.Dataset):
def __init__(self, src, dst, pad_idx, eos_idx):
self.src = src
self.dst = dst
self.pad_idx = pad_idx
self.eos_idx = eos_idx
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
res = {'id': i, 'source': source}
if self.dst:
res['target'] = self.dst[i].long() - 1
return res
def __len__(self):
return len(self.src)
def collater(self, samples):
return collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
@staticmethod
def create_dataset(args, splits, is_raw):
src, dst = args.source_lang, args.target_lang
assert src is not None and dst is not None, 'Source and target languages should be provided'
src_dict, dst_dict = load_dictionaries(args.data, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
def create_raw_dataset():
"""Loads specified data splits (e.g., test, train or valid) from raw text
files in the specified folder."""
# Load dataset from raw text files
for split in splits:
src_path = os.path.join(args.data, '{}.{}'.format(split, src))
dst_path = os.path.join(args.data, '{}.{}'.format(split, dst))
dataset.splits[split] = LanguagePairDataset(
IndexedRawTextDataset(src_path, src_dict),
IndexedRawTextDataset(dst_path, dst_dict),
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
def create_binary_dataset():
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
# Load dataset from binary files
def all_splits_exist(src, dst, lang):
for split in splits:
filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
if not os.path.exists(os.path.join(args.data, filename)):
return False
return True
# infer langcode
if all_splits_exist(src, dst, src):
langcode = '{}-{}'.format(src, dst)
elif all_splits_exist(dst, src, src):
langcode = '{}-{}'.format(dst, src)
else:
raise Exception('Dataset cannot be loaded from path: ' + args.data)
for split in splits:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
src_path = fmt_path(args.data, '{}.{}.{}', prefix, langcode, src)
dst_path = fmt_path(args.data, '{}.{}.{}', prefix, langcode, dst)
if not IndexedInMemoryDataset.exists(src_path):
break
target_dataset = None
if IndexedInMemoryDataset.exists(dst_path):
target_dataset = IndexedInMemoryDataset(dst_path)
dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path),
target_dataset,
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
return create_raw_dataset() if is_raw else create_binary_dataset()

fairseq/data/monolingual_dataset.py (new file)

@ -0,0 +1,121 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import os
import torch
from torch.utils.data import Dataset
from fairseq.data import TokenBlockDataset, Dictionary, LanguageDatasets
from fairseq.data.indexed_dataset import IndexedInMemoryDataset
from fairseq.data.data_utils import fmt_path, collate_tokens
def collate(samples, pad_idx, eos_idx, has_target):
if len(samples) == 0:
return {}
def merge(key):
return collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad=False, move_eos_to_beginning=False,
)
id = torch.LongTensor([s['id'] for s in samples])
# language models only have a decoder, which is not padding-aware, so don't left-pad for them
src_tokens = merge('source')
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
_, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
target = None
ntokens = None
if has_target:
target = merge('target')
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
},
'target': target,
}
class MonolingualDataset(Dataset):
def __init__(self, tokens, sizes, token_block_size, break_mode, pad_idx, eos_idx, next_token_is_target):
if next_token_is_target:
self.src = TokenBlockDataset(tokens, token_block_size, sizes, offset=1, break_mode=break_mode)
self.dst = TokenBlockDataset(tokens, token_block_size, sizes, offset=0, break_mode=break_mode)
else:
self.src = TokenBlockDataset(tokens, token_block_size, sizes, offset=0, break_mode=break_mode)
self.dst = None
self.pad_idx = pad_idx
self.eos_idx = eos_idx
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
res = {'id': i, 'source': source}
if self.dst:
res['target'] = self.dst[i].long() - 1
return res
def __len__(self):
return len(self.src)
def collater(self, samples):
return collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
@staticmethod
def create_dataset(args, splits, is_raw):
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
if is_raw:
raise Exception('raw text single language data sets are currently not supported')
assert args.sample_break_mode == 'eos' or args.max_target_positions is not None
path = args.data
dict = Dictionary.load(os.path.join(path, 'dict.txt'))
dataset = LanguageDatasets(None, None, dict, dict)
assert all(os.path.exists(os.path.join(path, '{}.bin'.format(split))) for split in splits)
for split in splits:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
split_path = fmt_path(path, '{}', prefix)
if not IndexedInMemoryDataset.exists(split_path):
break
ds = IndexedInMemoryDataset(split_path)
tokens = torch.from_numpy(ds.buffer)
dataset.splits[prefix] = MonolingualDataset(
tokens,
ds.sizes,
args.max_target_positions,
args.sample_break_mode,
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
next_token_is_target=True,
)
return dataset

fairseq/data/offset_dataset.py (new file)

@ -0,0 +1,32 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.utils.data import Dataset
class OffsetDataset(Dataset):
""" Wraps an existing dataset, but starts iterating from a particular offset """
def __init__(self, dataset, offset):
"""
Args:
dataset: Dataset to wrap
offset: An integer. offset from which to start iterating
"""
super().__init__()
assert len(dataset) >= offset
self.dataset = dataset
self.offset = offset
def __getitem__(self, i):
return self.dataset[i + self.offset]
def __len__(self):
return len(self.dataset) - self.offset
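
A toy usage sketch (the RangeDataset stand-in is hypothetical and exists only for this example):

from torch.utils.data import Dataset

class RangeDataset(Dataset):
    # hypothetical 5-element dataset: item i is just the integer i
    def __len__(self):
        return 5
    def __getitem__(self, i):
        return i

ds = OffsetDataset(RangeDataset(), offset=2)
assert len(ds) == 3
assert ds[0] == 2  # iteration starts at the third element of the base dataset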

fairseq/data/token_block_dataset.py (new file)

@ -0,0 +1,85 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import numpy as np
import torch
from fairseq.data.indexed_dataset import SizedDataset
class TokenBlockDataset(SizedDataset):
"""Given a 1d tensor of tokens, this dataset will break tokens into blocks based on parameters. The blocks are
fetched from the original tensor so no additional memory is allocated"""
def __init__(self, tokens, block_size, sizes, offset=0, break_mode=None):
"""
Args:
tokens: torch tensor of tokens to break into blocks
block_size: An integer. The maximum size of each block (this has no effect in 'eos' break mode)
sizes: A list of integers. Sizes of the sentences in the stream; they should sum to the
length of tokens
offset: An integer. Rotates the tokens by this much before computing blocks; useful for language model targets
break_mode: A string. If None/'none', breaks tokens into equally sized blocks of size block_size.
If 'complete', breaks tokens into blocks of up to block_size such that each block
contains only complete sentences (block_size may be exceeded if a single sentence exceeds it).
If 'eos', each block contains a single sentence; does not respect block_size."""
super().__init__()
self.tokens = tokens
self.offset = offset
self.slice_indices = []
if break_mode is None or break_mode == 'none':
length = math.ceil(tokens.numel() / block_size)
def block_at(i):
start = i * block_size
end = min(start + block_size, len(tokens))
return (start, end)
self.slice_indices = [block_at(i) for i in np.arange(length)]
elif break_mode == 'complete':
tok_idx = 0
sz_idx = 0
curr_size = 0
while sz_idx < len(sizes):
if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
curr_size += sizes[sz_idx]
sz_idx += 1
else:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
tok_idx += curr_size
curr_size = 0
if curr_size > 0:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
elif break_mode == 'eos':
curr = 0
for sz in sizes:
# skip sentences with just 1 token (which would be just the eos token)
if sz > 1:
self.slice_indices.append((curr, curr + sz))
curr += sz
else:
raise Exception('invalid break_mode. Supported values: none, complete, eos')
self._sizes = np.array([e - s for s, e in self.slice_indices])
def _slice(self, s, e):
# when offset > 0, only a block that wraps past the start of the stream needs a copy
# (via torch.cat); rotating the whole tensor with torch.cat() up front would copy every block
if s < self.offset:
return torch.cat([self.tokens[s - self.offset:], self.tokens[s:e - self.offset]])
return self.tokens[s - self.offset:e - self.offset]
def __getitem__(self, i):
return self._slice(*self.slice_indices[i])
def __len__(self):
return len(self.slice_indices)
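
A sketch of how MonolingualDataset pairs two of these datasets to build language-model inputs and targets (token values invented, with 2 standing in for eos; the -1 index shift done in MonolingualDataset.__getitem__ is omitted here): offset=1 rotates the stream right by one token, so position t of the source lines up with the token to predict at position t of the target.

import torch

tokens = torch.LongTensor([10, 11, 12, 2, 20, 21, 2])  # two sentences, eos = 2
sizes = [4, 3]

src = TokenBlockDataset(tokens, None, sizes, offset=1, break_mode='eos')
dst = TokenBlockDataset(tokens, None, sizes, offset=0, break_mode='eos')

print(src[0])  # tensor([ 2, 10, 11, 12]) -- previous eos, then the sentence shifted right
print(dst[0])  # tensor([10, 11, 12,  2]) -- the next token at each source position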

fairseq/models/__init__.py

@ -11,8 +11,7 @@ import os
from .fairseq_decoder import FairseqDecoder # noqa: F401
from .fairseq_encoder import FairseqEncoder # noqa: F401
from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
from .fairseq_model import FairseqModel # noqa: F401
from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel # noqa: F401
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
@ -29,8 +28,8 @@ def register_model(name):
def register_model_cls(cls):
if name in MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model ({})'.format(name))
if not issubclass(cls, FairseqModel):
raise ValueError('Model ({}: {}) must extend FairseqModel'.format(name, cls.__name__))
if not issubclass(cls, BaseFairseqModel):
raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__))
MODEL_REGISTRY[name] = cls
return cls

fairseq/models/fairseq_decoder.py

@ -19,7 +19,7 @@ class FairseqDecoder(nn.Module):
def forward(self, prev_output_tokens, encoder_out):
raise NotImplementedError
def get_normalized_probs(self, net_output, log_probs):
def get_normalized_probs(self, net_output, log_probs, _):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output[0].float()
if log_probs:

fairseq/models/fairseq_model.py

@ -5,12 +5,78 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
from . import FairseqDecoder, FairseqEncoder
class FairseqModel(nn.Module):
class BaseFairseqModel(nn.Module):
"""Base class for fairseq models."""
def __init__(self):
super().__init__()
self._is_generation_fast = False
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
pass
def get_targets(self, sample, net_output):
"""Get targets from either the sample or the net's output."""
return sample['target']
def load_state_dict(self, state_dict, strict=True):
"""Copies parameters and buffers from state_dict into this module and
its descendants.
Overrides the method in nn.Module; compared with that method this
additionally "upgrades" state_dicts from old checkpoints.
"""
self.upgrade_state_dict(state_dict)
super().load_state_dict(state_dict, strict)
def upgrade_state_dict(self, state_dict):
assert state_dict is not None
def do_upgrade(m):
if m != self and hasattr(m, 'upgrade_state_dict'):
m.upgrade_state_dict(state_dict)
self.apply(do_upgrade)
def make_generation_fast_(self, **kwargs):
"""Optimize model for faster generation."""
if self._is_generation_fast:
return # only apply once
self._is_generation_fast = True
# remove weight norm from all modules in the network
def apply_remove_weight_norm(module):
try:
nn.utils.remove_weight_norm(module)
except ValueError: # this module didn't have weight norm
return
self.apply(apply_remove_weight_norm)
def apply_make_generation_fast_(module):
if module != self and hasattr(module, 'make_generation_fast_'):
module.make_generation_fast_(**kwargs)
self.apply(apply_make_generation_fast_)
def train(mode):
if mode:
raise RuntimeError('cannot train after make_generation_fast')
# this model should no longer be used for training
self.eval()
self.train = train
class FairseqModel(BaseFairseqModel):
"""Base class for encoder-decoder models."""
def __init__(self, encoder, decoder):
@ -27,13 +93,6 @@ class FairseqModel(nn.Module):
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
self._is_generation_fast = False
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
pass
@classmethod
def build_model(cls, args, src_dict, dst_dict):
"""Build a new model instance."""
@ -44,13 +103,9 @@ class FairseqModel(nn.Module):
decoder_out = self.decoder(prev_output_tokens, encoder_out)
return decoder_out
def get_normalized_probs(self, net_output, log_probs):
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.decoder.get_normalized_probs(net_output, log_probs)
def get_targets(self, sample, net_output):
"""Get targets from either the sample or the net's output."""
return sample['target']
return self.decoder.get_normalized_probs(net_output, log_probs, sample)
def max_encoder_positions(self):
"""Maximum input length supported by the encoder."""
@ -60,44 +115,25 @@ class FairseqModel(nn.Module):
"""Maximum output length supported by the decoder."""
return self.decoder.max_positions()
def load_state_dict(self, state_dict, strict=True):
"""Copies parameters and buffers from state_dict into this module and
its descendants.
Overrides the method in nn.Module; compared with that method this
additionally "upgrades" state_dicts from old checkpoints.
"""
state_dict = self.upgrade_state_dict(state_dict)
super().load_state_dict(state_dict, strict)
class FairseqLanguageModel(BaseFairseqModel):
"""Base class for decoder-only models."""
def upgrade_state_dict(self, state_dict):
state_dict = self.encoder.upgrade_state_dict(state_dict)
state_dict = self.decoder.upgrade_state_dict(state_dict)
return state_dict
def __init__(self, decoder):
super().__init__()
self.decoder = decoder
assert isinstance(self.decoder, FairseqDecoder)
def make_generation_fast_(self, **kwargs):
"""Optimize model for faster generation."""
if self._is_generation_fast:
return # only apply once
self._is_generation_fast = True
def forward(self, src_tokens, **unused):
return self.decoder(src_tokens)
# remove weight norm from all modules in the network
def apply_remove_weight_norm(module):
try:
nn.utils.remove_weight_norm(module)
except ValueError: # this module didn't have weight norm
return
self.apply(apply_remove_weight_norm)
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.decoder.get_normalized_probs(net_output, log_probs, sample)
def apply_make_generation_fast_(module):
if module != self and hasattr(module, 'make_generation_fast_'):
module.make_generation_fast_(**kwargs)
self.apply(apply_make_generation_fast_)
def max_decoder_positions(self):
"""Maximum output length supported by the decoder."""
return self.decoder.max_positions()
def train(mode):
if mode:
raise RuntimeError('cannot train after make_generation_fast')
# this model should no longer be used for training
self.eval()
self.train = train
def max_encoder_positions(self):
return self.max_decoder_positions()

fairseq/models/fconv.py

@ -11,10 +11,11 @@ import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution
from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution, AdaptiveSoftmax
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, register_model_architecture
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, FairseqLanguageModel, register_model, \
register_model_architecture
@register_model('fconv')
@ -44,6 +45,8 @@ class FConvModel(FairseqModel):
help='decoder output embedding dimension')
parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]')
parser.add_argument('--normalization-constant', type=float, default=0.5, metavar='D',
help='multiplies the result of the residual block by sqrt(value)')
parser.add_argument('--share-input-output-embed', action='store_true',
help='share input and output embeddings (requires'
' --decoder-out-embed-dim and --decoder-embed-dim'
@ -75,6 +78,7 @@ class FConvModel(FairseqModel):
convolutions=eval(args.encoder_layers),
dropout=args.dropout,
max_positions=args.max_source_positions,
normalization_constant=args.normalization_constant,
)
decoder = FConvDecoder(
dst_dict,
@ -85,17 +89,72 @@ class FConvModel(FairseqModel):
attention=eval(args.decoder_attention),
dropout=args.dropout,
max_positions=args.max_target_positions,
share_embed=args.share_input_output_embed
share_embed=args.share_input_output_embed,
normalization_constant=args.normalization_constant,
)
return FConvModel(encoder, decoder)
@register_model('fconv_lm')
class FConvLanguageModel(FairseqLanguageModel):
def __init__(self, decoder):
super().__init__(decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
help='decoder layers [(dim, kernel_size), ...]')
parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
help='decoder output embedding dimension')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion')
parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]')
parser.add_argument('--normalization-constant', type=float, default=0.5, metavar='D',
help='multiplies the result of the residual block by sqrt(value)')
@classmethod
def build_model(cls, args, dict, *_):
"""Build a new model instance."""
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = args.max_positions
args.max_target_positions = args.max_positions
# make sure all arguments are present in older models
base_lm_architecture(args)
decoder = FConvDecoder(
dict,
embed_dim=args.decoder_embed_dim,
convolutions=eval(args.decoder_layers),
out_embed_dim=args.decoder_embed_dim,
attention=eval(args.decoder_attention),
dropout=args.dropout,
max_positions=args.max_target_positions,
share_embed=False,
positional_embeddings=False,
adaptive_softmax_cutoff=list(
map(int, args.adaptive_softmax_cutoff.split(','))) if args.adaptive_softmax_cutoff else None,
normalization_constant=args.normalization_constant,
)
return FConvLanguageModel(decoder)
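# Sketch (illustrative values, not defaults) of how build_model above interprets
# the two EXPR-style options:
layers = eval('[(850, 6)] * 3 + [(1024, 4)]')
# -> [(850, 6), (850, 6), (850, 6), (1024, 4)]   one (out_channels, kernel_size) per block
cutoff = list(map(int, '10000,50000,200000'.split(',')))
# -> [10000, 50000, 200000], passed to AdaptiveSoftmax as cutoff points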
class FConvEncoder(FairseqEncoder):
"""Convolutional encoder"""
def __init__(self, dictionary, embed_dim=512, embed_dict=None,
max_positions=1024, convolutions=((512, 3),) * 20, dropout=0.1):
max_positions=1024, convolutions=((512, 3),) * 20, dropout=0.1,
normalization_constant=0.5):
super().__init__(dictionary)
self.dropout = dropout
self.normalization_constant = normalization_constant
self.num_attention_layers = None
num_embeddings = len(dictionary)
@ -108,16 +167,24 @@ class FConvEncoder(FairseqEncoder):
max_positions,
embed_dim,
self.padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
left_pad=LEFT_PAD_SOURCE,
)
convolutions = extend_conv_spec(convolutions)
in_channels = convolutions[0][0]
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
self.projections = nn.ModuleList()
self.convolutions = nn.ModuleList()
for (out_channels, kernel_size) in convolutions:
self.projections.append(Linear(in_channels, out_channels)
if in_channels != out_channels else None)
self.residuals = []
layer_in_channels = [in_channels]
for i, (out_channels, kernel_size, residual) in enumerate(convolutions):
if residual == 0:
residual_dim = out_channels
else:
residual_dim = layer_in_channels[-residual]
self.projections.append(Linear(residual_dim, out_channels)
if residual_dim != out_channels else None)
if kernel_size % 2 == 1:
padding = kernel_size // 2
else:
@ -126,7 +193,9 @@ class FConvEncoder(FairseqEncoder):
ConvTBC(in_channels, out_channels * 2, kernel_size,
dropout=dropout, padding=padding)
)
self.residuals.append(residual)
in_channels = out_channels
layer_in_channels.append(out_channels)
self.fc2 = Linear(in_channels, embed_dim)
def forward(self, src_tokens, src_lengths):
@ -146,9 +215,14 @@ class FConvEncoder(FairseqEncoder):
# B x T x C -> T x B x C
x = x.transpose(0, 1)
residuals = [x]
# temporal convolutions
for proj, conv in zip(self.projections, self.convolutions):
residual = x if proj is None else proj(x)
for proj, conv, res_layer in zip(self.projections, self.convolutions, self.residuals):
if res_layer > 0:
residual = residuals[-res_layer]
residual = residual if proj is None else proj(residual)
else:
residual = None
if encoder_padding_mask is not None:
x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)
@ -163,7 +237,10 @@ class FConvEncoder(FairseqEncoder):
x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
x = conv(x)
x = F.glu(x, dim=2)
x = (x + residual) * math.sqrt(0.5)
if residual is not None:
x = (x + residual) * math.sqrt(self.normalization_constant)
residuals.append(x)
# T x B x C -> B x T x C
x = x.transpose(1, 0)
@ -179,7 +256,7 @@ class FConvEncoder(FairseqEncoder):
x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))
# add output to input embedding for attention
y = (x + input_embedding) * math.sqrt(0.5)
y = (x + input_embedding) * math.sqrt(self.normalization_constant)
return {
'encoder_out': (x, y),
@ -192,8 +269,9 @@ class FConvEncoder(FairseqEncoder):
class AttentionLayer(nn.Module):
def __init__(self, conv_channels, embed_dim, bmm=None):
def __init__(self, conv_channels, embed_dim, normalization_constant=0.5, bmm=None):
super().__init__()
self.normalization_constant = normalization_constant
# projects from output of convolution to embedding dimension
self.in_projection = Linear(conv_channels, embed_dim)
# projects from embedding dimension to convolution size
@ -205,7 +283,7 @@ class AttentionLayer(nn.Module):
residual = x
# attention
x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
x = (self.in_projection(x) + target_embedding) * math.sqrt(self.normalization_constant)
x = self.bmm(x, encoder_out[0])
# don't attend over padding
@ -233,7 +311,7 @@ class AttentionLayer(nn.Module):
x = x * (s * s.rsqrt())
# project back
x = (self.out_projection(x) + residual) * math.sqrt(0.5)
x = (self.out_projection(x) + residual) * math.sqrt(self.normalization_constant)
return x, attn_scores
def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
@ -245,14 +323,17 @@ class AttentionLayer(nn.Module):
class FConvDecoder(FairseqIncrementalDecoder):
"""Convolutional decoder"""
def __init__(self, dictionary, embed_dim=512,
embed_dict=None, out_embed_dim=256,
def __init__(self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
max_positions=1024, convolutions=((512, 3),) * 20,
attention=True, dropout=0.1, share_embed=False):
attention=True, dropout=0.1, share_embed=False, positional_embeddings=True,
adaptive_softmax_cutoff=None, normalization_constant=0.5):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([2]))
self.dropout = dropout
self.normalization_constant = normalization_constant
convolutions = extend_conv_spec(convolutions)
in_channels = convolutions[0][0]
if isinstance(attention, bool):
# expand True into [True, True, ...] and do the same with False
@ -271,45 +352,69 @@ class FConvDecoder(FairseqIncrementalDecoder):
max_positions,
embed_dim,
padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
)
left_pad=LEFT_PAD_TARGET,
) if positional_embeddings else None
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
self.projections = nn.ModuleList()
self.convolutions = nn.ModuleList()
self.attention = nn.ModuleList()
for i, (out_channels, kernel_size) in enumerate(convolutions):
self.projections.append(Linear(in_channels, out_channels)
if in_channels != out_channels else None)
self.residuals = []
layer_in_channels = [in_channels]
for i, (out_channels, kernel_size, residual) in enumerate(convolutions):
if residual == 0:
residual_dim = out_channels
else:
residual_dim = layer_in_channels[-residual]
self.projections.append(Linear(residual_dim, out_channels)
if residual_dim != out_channels else None)
self.convolutions.append(
LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
padding=(kernel_size - 1), dropout=dropout)
)
self.attention.append(AttentionLayer(out_channels, embed_dim)
self.attention.append(AttentionLayer(out_channels, embed_dim, self.normalization_constant)
if attention[i] else None)
self.residuals.append(residual)
in_channels = out_channels
self.fc2 = Linear(in_channels, out_embed_dim)
if share_embed:
assert out_embed_dim == embed_dim, \
"Shared embed weights implies same dimensions " \
" out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim)
self.fc3 = nn.Linear(out_embed_dim, num_embeddings)
self.fc3.weight = self.embed_tokens.weight
layer_in_channels.append(out_channels)
self.adaptive_softmax = None
self.fc2 = self.fc3 = None
if adaptive_softmax_cutoff is not None:
assert not share_embed
self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, in_channels, adaptive_softmax_cutoff,
dropout=dropout)
else:
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
self.fc2 = Linear(in_channels, out_embed_dim)
if share_embed:
assert out_embed_dim == embed_dim, \
"Shared embed weights implies same dimensions " \
" out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim)
self.fc3 = nn.Linear(out_embed_dim, num_embeddings)
self.fc3.weight = self.embed_tokens.weight
else:
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
encoder_out = encoder_out_dict['encoder_out']
encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None):
if encoder_out_dict is not None:
encoder_out = encoder_out_dict['encoder_out']
encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
# split and transpose encoder outputs
encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)
# split and transpose encoder outputs
encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)
if self.embed_positions is not None:
pos_embed = self.embed_positions(prev_output_tokens, incremental_state)
else:
pos_embed = 0
# embed tokens and combine with positional embeddings
pos_embed = self.embed_positions(prev_output_tokens, incremental_state)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
x = self._embed_tokens(prev_output_tokens, incremental_state)
# embed tokens and combine with positional embeddings
x += pos_embed
x = F.dropout(x, p=self.dropout, training=self.training)
target_embedding = x
@ -323,8 +428,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
# temporal convolutions
avg_attn_scores = None
num_attn_layers = len(self.attention)
for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
residual = x if proj is None else proj(x)
residuals = [x]
for proj, conv, attention, res_layer in zip(self.projections, self.convolutions, self.attention,
self.residuals):
if res_layer > 0:
residual = residuals[-res_layer]
residual = residual if proj is None else proj(residual)
else:
residual = None
x = F.dropout(x, p=self.dropout, training=self.training)
x = conv(x, incremental_state)
@ -344,18 +455,31 @@ class FConvDecoder(FairseqIncrementalDecoder):
x = self._transpose_if_training(x, incremental_state)
# residual
x = (x + residual) * math.sqrt(0.5)
if residual is not None:
x = (x + residual) * math.sqrt(self.normalization_constant)
residuals.append(x)
# T x B x C -> B x T x C
x = self._transpose_if_training(x, incremental_state)
# project back to size of vocabulary
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc3(x)
# project back to size of vocabulary if not using adaptive softmax
if self.fc2 is not None and self.fc3 is not None:
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc3(x)
return x, avg_attn_scores
def get_normalized_probs(self, net_output, log_probs, sample):
"""Get normalized probabilities (or log probs) from a net's output."""
if self.adaptive_softmax is not None:
assert sample is not None and 'target' in sample
out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
return out.exp_() if not log_probs else out
else:
return super().get_normalized_probs(net_output, log_probs, sample)
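# Note: the adaptive softmax branch needs sample['target'] because exact tail
# scores are only computed for the clusters that actually occur in the target,
# which is why `sample` is threaded through get_normalized_probs.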
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
@ -371,7 +495,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.embed_positions.max_positions()
return self.embed_positions.max_positions() if self.embed_positions is not None else float('inf')
def upgrade_state_dict(self, state_dict):
if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
@ -413,6 +537,23 @@ class FConvDecoder(FairseqIncrementalDecoder):
return x
def extend_conv_spec(convolutions):
"""
Extends a convolutional spec, given as a list of tuples of 2 or 3 parameters
(out_channels, kernel_size, and optionally how many layers back to look for
the residual connection), filling in the default residual distance where it
is not specified
"""
extended = []
for spec in convolutions:
if len(spec) == 3:
extended.append(spec)
elif len(spec) == 2:
extended.append(spec + (1,))
else:
raise Exception('invalid number of parameters in convolution spec ' + str(spec) + '. expected 2 or 3')
return tuple(extended)
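# For example (hypothetical spec mixing 2- and 3-tuples):
#   extend_conv_spec([(512, 3), (512, 3, 2), (1024, 1, 0)])
#   -> ((512, 3, 1), (512, 3, 2), (1024, 1, 0))
# residual=1 (the default) adds the previous layer's output; residual=2 reaches
# two layers back; residual=0 disables the residual connection entirely.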
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal(m.weight, 0, 0.1)
@ -430,8 +571,8 @@ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
def Linear(in_features, out_features, dropout=0):
"""Weight-normalized Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features)
m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
m.bias.data.zero_()
nn.init.normal(m.weight, mean=0, std=math.sqrt((1 - dropout) / in_features))
nn.init.constant(m.bias, 0)
return nn.utils.weight_norm(m)
@ -439,8 +580,8 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs
"""Weight-normalized Conv1d layer optimized for decoding"""
m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
m.weight.data.normal_(mean=0, std=std)
m.bias.data.zero_()
nn.init.normal(m.weight, mean=0, std=std)
nn.init.constant(m.bias, 0)
return nn.utils.weight_norm(m, dim=2)
@ -449,11 +590,19 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
from fairseq.modules import ConvTBC
m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
m.weight.data.normal_(mean=0, std=std)
m.bias.data.zero_()
nn.init.normal(m.weight, mean=0, std=std)
nn.init.constant(m.bias, 0)
return nn.utils.weight_norm(m, dim=2)
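# Both conv initializers above use std = sqrt(4 * (1 - dropout) / (kernel_size *
# in_channels)); the factor 4 compensates for GLU halving the channels. E.g.
# dropout=0.1, kernel_size=3, in_channels=512 gives std = sqrt(3.6 / 1536) ~ 0.048.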
@register_model_architecture('fconv_lm', 'fconv_lm')
def base_lm_architecture(args):
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 128)
args.decoder_layers = getattr(args, 'decoder_layers', '[(1268, 4)] * 13')
args.decoder_attention = getattr(args, 'decoder_attention', 'False')
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
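# With `--arch fconv_lm` and no overrides, this trains a 13-layer GCNN of
# 1268-unit, width-4 blocks with no attention; flags given on the command line
# take precedence over these getattr defaults.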
@register_model_architecture('fconv', 'fconv')
def base_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
@ -485,7 +634,7 @@ def fconv_wmt_en_ro(args):
@register_model_architecture('fconv', 'fconv_wmt_en_de')
def fconv_wmt_en_de(args):
convs = '[(512, 3)] * 9' # first 9 layers have 512 units
convs = '[(512, 3)] * 9' # first 9 layers have 512 units
convs += ' + [(1024, 3)] * 4' # next 4 layers have 1024 units
convs += ' + [(2048, 1)] * 2' # final 2 layers use 1x1 convolutions
@ -499,8 +648,8 @@ def fconv_wmt_en_de(args):
@register_model_architecture('fconv', 'fconv_wmt_en_fr')
def fconv_wmt_en_fr(args):
convs = '[(512, 3)] * 6' # first 6 layers have 512 units
convs += ' + [(768, 3)] * 4' # next 4 layers have 768 units
convs = '[(512, 3)] * 6' # first 6 layers have 512 units
convs += ' + [(768, 3)] * 4' # next 4 layers have 768 units
convs += ' + [(1024, 3)] * 3' # next 3 layers have 1024 units
convs += ' + [(2048, 1)] * 1' # next 1 layer uses 1x1 convolutions
convs += ' + [(4096, 1)] * 1' # final 1 layer uses 1x1 convolutions


@ -11,7 +11,7 @@ import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.data import consts
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, register_model_architecture
@ -117,7 +117,7 @@ class LSTMEncoder(FairseqEncoder):
def __init__(
self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
dropout_in=0.1, dropout_out=0.1, bidirectional=False,
left_pad_source=LanguagePairDataset.LEFT_PAD_SOURCE,
left_pad_source=consts.LEFT_PAD_SOURCE,
pretrained_embed=None,
padding_value=0.,
):


@ -11,7 +11,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import LanguagePairDataset
from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
from fairseq.modules import (
LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding,
@ -108,7 +108,7 @@ class TransformerEncoder(FairseqEncoder):
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, self.padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
left_pad=LEFT_PAD_SOURCE,
learned=args.encoder_learned_pos,
)
@ -169,7 +169,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
left_pad=LEFT_PAD_TARGET,
learned=args.decoder_learned_pos,
)
@ -181,7 +181,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
nn.init.normal(self.embed_out, mean=0, std=embed_dim**-0.5)
nn.init.normal(self.embed_out, mean=0, std=embed_dim ** -0.5)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
# embed positions
@ -363,7 +363,7 @@ class TransformerDecoderLayer(nn.Module):
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal(m.weight, mean=0, std=embedding_dim**-0.5)
nn.init.normal(m.weight, mean=0, std=embedding_dim ** -0.5)
return m
@ -382,7 +382,7 @@ def Linear(in_features, out_features, bias=True):
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
if learned:
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
nn.init.normal(m.weight, mean=0, std=embedding_dim**-0.5)
nn.init.normal(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)


@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM
from .conv_tbc import ConvTBC
from .grad_multiply import GradMultiply
@ -14,6 +15,7 @@ from .multihead_attention import MultiheadAttention
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
__all__ = [
'AdaptiveSoftmax',
'BeamableMM',
'ConvTBC',
'GradMultiply',


@ -0,0 +1,121 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn.functional as F
from torch import nn
class AdaptiveSoftmax(nn.Module):
"""This is an implementation of the efficient softmax approximation for graphical processing units (GPU),
described in the paper "Efficient softmax approximation for GPUs" (http://arxiv.org/abs/1609.04309)."""
def __init__(self, vocab_size, input_dim, cutoff, dropout):
super().__init__()
if vocab_size > cutoff[-1]:
cutoff = cutoff + [vocab_size]
output_dim = cutoff[0] + len(cutoff) - 1
self.vocab_size = vocab_size
self.cutoff = cutoff
self.dropout = dropout
self.lsm = nn.LogSoftmax(dim=1)
self.head = nn.Linear(input_dim, output_dim, bias=False)
self.tail = nn.ModuleList()
for i in range(len(cutoff) - 1):
self.tail.append(
nn.Sequential(
nn.Linear(input_dim, input_dim // 4 ** i, bias=False),
nn.Dropout(dropout),
nn.Linear(input_dim // 4 ** i, cutoff[i + 1] - cutoff[i], bias=False)
)
)
def init_weights(m):
if hasattr(m, 'weight'):
nn.init.xavier_uniform(m.weight)
self.apply(init_weights)
def adapt_target(self, target):
"""In order to be efficient, the AdaptiveSoftMax does not compute the scores for all the word of the
vocabulary for all the examples.It is thus necessary to call the method adapt_target of the AdaptiveSoftMax
layer inside each forward pass."""
target = target.view(-1)
new_target = [target.clone()]
target_idxs = []
for i in range(len(self.cutoff) - 1):
mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
new_target[0][mask] = self.cutoff[0] + i - 1
if mask.any():
target_idxs.append(mask.nonzero().squeeze(1))
new_target.append(target[mask].add(-self.cutoff[i]))
else:
target_idxs.append(None)
new_target.append(None)
return new_target, target_idxs
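# Tracing adapt_target as written, with e.g. cutoff = [4, 10, 20] (clusters
# [4, 10) and [10, 20)): a target id 7 falls in cluster 0, so its head target
# becomes cutoff[0] + 0 - 1 = 3 and its within-cluster target 7 - 4 = 3;
# ids below cutoff[0] = 4 keep their original head index.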
def forward(self, input, target):
""" accepts input (b x t x d) and target (b x t) and returns
2 lists: output for each cutoff section and new targets by cut off """
input = input.contiguous().view(-1, input.size(-1))
input = F.dropout(input, p=self.dropout, training=self.training)
new_target, target_idxs = self.adapt_target(target)
output = [self.head(input)]
for i in range(len(target_idxs)):
if target_idxs[i] is not None:
output.append(self.tail[i](input.index_select(0, target_idxs[i])))
else:
output.append(None)
return output, new_target
def get_log_prob(self, input, target):
"""computes the log probabilities for all the words of the vocabulary, given a 2D tensor of hidden vectors"""
bsz, length, dim = input.size()
input = input.contiguous().view(-1, dim)
if target is not None:
_, target_idxs = self.adapt_target(target)
else:
target_idxs = None
head_y = self.head(input)
log_probs = head_y.new_zeros(input.size(0), self.vocab_size)
head_sz = self.cutoff[0] + len(self.tail)
log_probs[:, :head_sz] = self.lsm(head_y)
tail_priors = log_probs[:, self.cutoff[0] - 1: head_sz - 1].clone()
for i in range(len(self.tail)):
start = self.cutoff[i]
end = self.cutoff[i + 1]
if target_idxs is None:
tail_out = log_probs[:, start:end]
tail_out.copy_(self.tail[i](input))
log_probs[:, start:end] = self.lsm(tail_out).add_(tail_priors[:, i, None])
elif target_idxs[i] is not None:
idxs = target_idxs[i]
tail_out = log_probs[idxs, start:end]
tail_out.copy_(self.tail[i](input[idxs]))
log_probs[idxs, start:end] = self.lsm(tail_out).add_(tail_priors[idxs, i, None])
log_probs = log_probs.view(bsz, length, -1)
return log_probs
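# A minimal usage sketch (not part of the commit; shapes and cutoffs are
# illustrative):
import torch
m = AdaptiveSoftmax(vocab_size=20000, input_dim=512, cutoff=[2000, 10000], dropout=0.1)
x = torch.randn(8, 16, 512)                              # B x T x C hidden states
t = torch.randint(0, 20000, (8, 16), dtype=torch.long)   # B x T target ids
outputs, new_targets = m(x, t)     # head scores plus tail scores per populated cluster
log_probs = m.get_log_prob(x, t)   # B x T x vocab_size log-probabilities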


@ -43,6 +43,13 @@ def _eval_float_list(x):
return [float(x)]
def get_eval_lm_parser():
parser = get_parser('Evaluate Language Model')
add_dataset_args(parser, gen=True)
add_eval_lm_args(parser)
return parser
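# Hypothetical usage, mirroring eval_lm.py:
#   parser = options.get_eval_lm_parser()
#   args = options.parse_args_and_arch(parser)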
def parse_args_and_arch(parser, input_args=None):
# The parser doesn't know about model/criterion/optimizer-specific args, so
# we parse twice. First we parse the model/criterion/optimizer, then we
@ -102,7 +109,7 @@ def add_dataset_args(parser, train=False, gen=False):
help='target language')
group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
group.add_argument('--max-target-positions', '--tokens-per-sample', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='Ignore too long or too short lines in valid and test set')
@ -110,6 +117,12 @@ def add_dataset_args(parser, train=False, gen=False):
help='maximum number of tokens in a batch')
group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
help='maximum number of sentences in a batch')
group.add_argument('--sample-break-mode', metavar='VAL',
choices=['none', 'complete', 'eos'],
help='Used only for LM datasets. If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at sentence boundaries, but may include '
'multiple sentences per sample. If set to "eos", includes only one sentence per sample')
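# Illustration (not part of the commit): with --tokens-per-sample 8 and the
# token stream "a b </s> c d e </s> f </s>":
#   none:     | a b </s> c d e </s> f | </s> ...   fixed-size blocks; sentences may split
#   complete: | a b </s> c d e </s> | f </s> |     only whole sentences, packed up to the limit
#   eos:      | a b </s> | c d e </s> | f </s> |   exactly one sentence per sample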
if train:
group.add_argument('--train-subset', default='train', metavar='SPLIT',
choices=['train', 'valid', 'test'],
@ -216,10 +229,24 @@ def add_checkpoint_args(parser):
return group
def add_generation_args(parser):
group = parser.add_argument_group('Generation')
def add_common_eval_args(group):
group.add_argument('--path', metavar='FILE', action='append',
help='path(s) to model file(s)')
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--quiet', action='store_true',
help='only print final scores')
def add_eval_lm_args(parser):
group = parser.add_argument_group('LM Evaluation')
add_common_eval_args(group)
def add_generation_args(parser):
group = parser.add_argument_group('Generation')
add_common_eval_args(group)
group.add_argument('--beam', default=5, type=int, metavar='N',
help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N',
@ -230,15 +257,12 @@ def add_generation_args(parser):
group.add_argument('--max-len-b', default=200, type=int, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring')
group.add_argument('--no-early-stop', action='store_true',
help=('continue searching even after finalizing k=beam '
'hypotheses; this is more correct, but increases '
'generation time by 50%%'))
group.add_argument('--unnormalized', action='store_true',
help='compare unnormalized hypothesis scores')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--no-beamable-mm', action='store_true',
help='don\'t use BeamableMM in attention layers')
group.add_argument('--lenpen', default=1, type=float,
@ -247,8 +271,6 @@ def add_generation_args(parser):
help='unknown word penalty: <0 produces more unks, >0 produces fewer')
group.add_argument('--replace-unk', nargs='?', const=True, default=None,
help='perform unknown replacement (optionally with alignment dictionary)')
group.add_argument('--quiet', action='store_true',
help='only print final scores')
group.add_argument('--score-reference', action='store_true',
help='just score the reference translation')
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',


@ -9,7 +9,6 @@ import math
import torch
from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.models import FairseqIncrementalDecoder


@ -28,8 +28,6 @@ class SequenceScorer(object):
if timer is not None:
timer.start()
pos_scores, attn = self.score(s)
if timer is not None:
timer.stop(s['ntokens'])
for i, id in enumerate(s['id'].data):
# remove padding from ref
src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
@ -37,8 +35,11 @@ class SequenceScorer(object):
tgt_len = ref.numel()
pos_scores_i = pos_scores[i][:tgt_len]
score_i = pos_scores_i.sum() / tgt_len
attn_i = attn[i]
_, alignment = attn_i.max(dim=0)
if attn is not None:
attn_i = attn[i]
_, alignment = attn_i.max(dim=0)
else:
attn_i = alignment = None
hypos = [{
'tokens': ref,
'score': score_i,
@ -46,6 +47,8 @@ class SequenceScorer(object):
'alignment': alignment,
'positional_scores': pos_scores_i,
}]
if timer is not None:
timer.stop(s['ntokens'])
# return results in the same format as SequenceGenerator
yield id, src, ref, hypos
@ -59,16 +62,10 @@ class SequenceScorer(object):
for model in self.models:
with utils.maybe_no_grad():
model.eval()
encoder_out = model.encoder(
net_input['src_tokens'],
net_input['src_lengths'],
)
decoder_out = model.decoder(
net_input['prev_output_tokens'],
encoder_out,
)
decoder_out = model.forward(**net_input)
attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=False).data
probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data
if avg_probs is None:
avg_probs = probs
else:


@ -68,7 +68,7 @@ def load_model_state(filename, model):
return None, [], None
state = torch.load(filename)
state = _upgrade_state_dict(state)
state['model'] = model.upgrade_state_dict(state['model'])
model.upgrade_state_dict(state['model'])
# load model parameters
try:
@ -134,7 +134,8 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
{'arg_name': arg} -- to override model args that were used during model
training
"""
from fairseq import data, models
from fairseq import models
from fairseq.data import data_utils
# load model architectures and weights
states = []
@ -150,7 +151,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
if src_dict is None or dst_dict is None:
assert data_dir is not None
src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)
src_dict, dst_dict = data_utils.load_dictionaries(data_dir, args.source_lang, args.target_lang)
# build ensemble
ensemble = []


@ -8,7 +8,8 @@
import torch
from fairseq import bleu, data, options, progress_bar, tokenizer, utils
from fairseq import bleu, options, progress_bar, tokenizer, utils
from fairseq.data import data_utils, data_loaders
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
@ -27,23 +28,7 @@ def main(args):
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset
if args.replace_unk is None:
dataset = data.load_dataset(
args.data,
[args.gen_subset],
args.source_lang,
args.target_lang,
)
else:
dataset = data.load_raw_text_dataset(
args.data,
[args.gen_subset],
args.source_lang,
args.target_lang,
)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args
args.source_lang, args.target_lang = dataset.src, dataset.dst
dataset = data_loaders.load_dataset(args, [args.gen_subset], args.replace_unk is not None)
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
@ -75,7 +60,7 @@ def main(args):
if args.num_shards > 1:
if args.shard_id < 0 or args.shard_id >= args.num_shards:
raise ValueError('--shard-id must be between 0 and num_shards')
itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)
itr = data_utils.sharded_iterator(itr, args.num_shards, args.shard_id)
# Initialize generator
gen_timer = StopwatchMeter()


@ -13,7 +13,8 @@ from collections import namedtuple
from torch.autograd import Variable
from fairseq import options, tokenizer, utils
from fairseq.data import LanguagePairDataset
from fairseq.data.data_utils import collate_tokens
from fairseq.data.consts import LEFT_PAD_SOURCE
from fairseq.sequence_generator import SequenceGenerator
Batch = namedtuple('Batch', 'srcs tokens lengths')
@ -41,9 +42,8 @@ def make_batches(lines, batch_size, src_dict):
batches = np.array_split(indices, num_batches)
for batch_idxs in batches:
batch_toks = [tokens[i] for i in batch_idxs]
batch_toks = LanguagePairDataset.collate_tokens(batch_toks, src_dict.pad(), src_dict.eos(),
LanguagePairDataset.LEFT_PAD_SOURCE,
move_eos_to_beginning=False)
batch_toks = collate_tokens(batch_toks, src_dict.pad(), src_dict.eos(), LEFT_PAD_SOURCE,
move_eos_to_beginning=False)
yield Batch(
srcs=[lines[i] for i in batch_idxs],
tokens=batch_toks,


@ -12,7 +12,7 @@ from itertools import zip_longest
import os
import shutil
from fairseq import dictionary, indexed_dataset
from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
@ -54,33 +54,53 @@ def main(args):
Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
return d
def train_path(lang):
return '{}{}'.format(args.trainpref, ('.' + lang) if lang else '')
def file_name(prefix, lang):
fname = prefix
if lang is not None:
fname += f'.{lang}'
return fname
def dest_path(prefix, lang):
return os.path.join(args.destdir, file_name(prefix, lang))
def dict_path(lang):
return dest_path('dict', lang) + '.txt'
def dataset_dest_path(output_prefix, lang, extension):
base = f'{args.destdir}/{output_prefix}'
lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
return f'{base}{lang_part}.{extension}'
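# Illustration (hypothetical values): with destdir=/data-bin, source_lang=en,
# target_lang=de:
#   dict_path('en')                           -> /data-bin/dict.en.txt
#   dataset_dest_path('train', 'en', 'bin')   -> /data-bin/train.en-de.en.bin
#   dataset_dest_path('train', None, 'idx')   -> /data-bin/train.idx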
if args.joined_dictionary:
assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
src_dict = build_dictionary([
'{}.{}'.format(args.trainpref, lang)
src_dict = build_dictionary(set([
train_path(lang)
for lang in [args.source_lang, args.target_lang]
])
]))
tgt_dict = src_dict
else:
if args.srcdict:
src_dict = dictionary.Dictionary.load(args.srcdict)
else:
assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.source_lang)])
src_dict = build_dictionary([train_path(args.source_lang)])
if target:
if args.tgtdict:
tgt_dict = dictionary.Dictionary.load(args.tgtdict)
else:
assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.target_lang)])
tgt_dict = build_dictionary([train_path(args.target_lang)])
src_dict.finalize(
threshold=args.thresholdsrc,
nwords=args.nwordssrc,
padding_factor=args.padding_factor,
)
src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
src_dict.save(dict_path(args.source_lang))
if target:
if not args.joined_dictionary:
tgt_dict.finalize(
@ -88,36 +108,31 @@ def main(args):
nwords=args.nwordstgt,
padding_factor=args.padding_factor,
)
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)))
tgt_dict.save(dict_path(args.target_lang))
def make_binary_dataset(input_prefix, output_prefix, lang):
dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
dict = dictionary.Dictionary.load(dict_path(lang))
print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))
ds = indexed_dataset.IndexedDatasetBuilder(
'{}/{}.{}-{}.{}.bin'.format(args.destdir, output_prefix, args.source_lang,
args.target_lang, lang)
)
ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_path(output_prefix, lang, 'bin'))
def consumer(tensor):
ds.add_item(tensor)
input_file = '{}.{}'.format(input_prefix, lang)
input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
res = Tokenizer.binarize(input_file, dict, consumer)
print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
lang, input_file, res['nseq'], res['ntok'],
100 * res['nunk'] / res['ntok'], dict.unk_word))
ds.finalize('{}/{}.{}-{}.{}.idx'.format(
args.destdir, output_prefix,
args.source_lang, args.target_lang, lang))
ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))
def make_dataset(input_prefix, output_prefix, lang, output_format='binary'):
if output_format == 'binary':
make_binary_dataset(input_prefix, output_prefix, lang)
elif output_format == 'raw':
# Copy original text file to destination folder
output_text_file = os.path.join(args.destdir, '{}.{}'.format(output_prefix, lang))
shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)
output_text_file = dest_path(output_prefix, lang)
shutil.copyfile(file_name(input_prefix, lang), output_text_file)
def make_all(args, make_dataset, lang):
if args.trainpref:
@ -139,10 +154,10 @@ def main(args):
if args.alignfile:
assert args.trainpref, "--trainpref must be set if --alignfile is specified"
src_file_name = '{}.{}'.format(args.trainpref, args.source_lang)
tgt_file_name = '{}.{}'.format(args.trainpref, args.target_lang)
src_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)))
src_file_name = train_path(args.source_lang)
tgt_file_name = train_path(args.target_lang)
src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
freq_map = {}
with open(args.alignfile, 'r') as align_file:
with open(src_file_name, 'r') as src_file:


@ -11,7 +11,8 @@ import argparse
import os
import sys
from fairseq import bleu, dictionary, tokenizer
from fairseq import bleu, tokenizer
from fairseq.data import dictionary
def main():


@ -84,6 +84,8 @@ class TestBinaries(unittest.TestCase):
'--max-epoch', '1',
'--no-progress-bar',
'--distributed-world-size', '1',
'--source-lang', 'in',
'--target-lang', 'out',
],
)
train.main(train_args)

71
tests/test_train.py Normal file

@ -0,0 +1,71 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import unittest
from unittest.mock import MagicMock, patch
import train
def mock_trainer(epoch, num_updates):
trainer = MagicMock()
trainer.load_checkpoint.return_value = {'epoch': epoch}
trainer.get_num_updates.return_value = num_updates
return trainer
def mock_loader(length):
ds = MagicMock()
ds.__len__.return_value = length
loader = MagicMock()
loader.__next__.return_value = ds
return loader
class TestLoadCheckpoint(unittest.TestCase):
def setUp(self):
self.patches = {
'os.makedirs': MagicMock(),
'os.path.join': MagicMock(),
'os.path.isfile': MagicMock(return_value=True),
}
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
[p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self):
trainer = mock_trainer(2, 200)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 2)
self.assertEqual(len(ds), 50)
self.assertNotIsInstance(ds, MagicMock)
def test_load_full_checkpoint(self):
trainer = mock_trainer(2, 150)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 2)
self.assertEqual(len(ds), 150)
self.assertIsInstance(ds, MagicMock)
def test_load_no_checkpoint(self):
trainer = mock_trainer(0, 0)
loader = mock_loader(150)
self.patches['os.path.isfile'].return_value = False
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 1)
self.assertEqual(len(ds), 150)
self.assertIsInstance(ds, MagicMock)
def tearDown(self):
patch.stopall()
if __name__ == '__main__':
unittest.main()


@ -8,7 +8,9 @@
import torch
from torch.autograd import Variable
from fairseq import data, dictionary, utils
from fairseq.data.language_pair_dataset import collate
from fairseq import utils
from fairseq.data import dictionary
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
@ -45,10 +47,11 @@ def dummy_dataloader(
dataset,
batch_size=batch_size,
collate_fn=(
lambda samples: data.LanguagePairDataset.collate(
lambda samples: collate(
samples,
padding_idx,
eos_idx,
has_target=True,
)
),
)
@ -134,7 +137,7 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
return Variable(probs), Variable(attn)
def get_normalized_probs(self, net_output, log_probs):
def get_normalized_probs(self, net_output, log_probs, _):
# the decoder returns probabilities directly
probs = net_output[0]
if log_probs:


@ -11,7 +11,8 @@ import os
import math
import torch
from fairseq import criterions, data, models, options, progress_bar
from fairseq import criterions, models, options, progress_bar
from fairseq.data import data_utils, data_loaders, OffsetDataset
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
@ -72,10 +73,10 @@ def main(args):
)
# Load the latest checkpoint if one is available
epoch = load_checkpoint(args, trainer, train_dataloader)
epoch, next_ds = load_checkpoint(args, trainer, train_dataloader)
# Send a dummy batch to warm the caching allocator
dummy_batch = data.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
dummy_batch = data_utils.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
trainer.dummy_train_step(dummy_batch)
# Train until the learning rate gets too small
@ -87,7 +88,7 @@ def main(args):
train_meter.start()
while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
# train for one epoch
train(args, trainer, next(train_dataloader), epoch, dataset)
train(args, trainer, next_ds, epoch, dataset)
if epoch % args.validate_interval == 0:
first_val_loss = val_loss(args, trainer, dataset, epoch)
@ -100,19 +101,14 @@ def main(args):
save_checkpoint(trainer, args, epoch, end_of_epoch=True, val_loss=first_val_loss)
epoch += 1
next_ds = next(train_dataloader)
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
def load_dataset(args, splits):
if data.has_binary_files(args.data, splits):
dataset = data.load_dataset(args.data, splits, args.source_lang, args.target_lang)
else:
dataset = data.load_raw_text_dataset(args.data, splits, args.source_lang, args.target_lang)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args, so that it's saved in checkpoints
args.source_lang, args.target_lang = dataset.src, dataset.dst
is_raw = not data_utils.has_binary_files(args.data, splits)
dataset = data_loaders.load_dataset(args, splits, is_raw)
return dataset
@ -311,17 +307,29 @@ def load_checkpoint(args, trainer, train_dataloader):
os.makedirs(args.save_dir, exist_ok=True)
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
epoch = 1
ds = None
if os.path.isfile(checkpoint_path):
extra_state = trainer.load_checkpoint(checkpoint_path)
if extra_state is not None:
epoch = extra_state['epoch']
print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
trainer_updates = trainer.get_num_updates()
print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(checkpoint_path, epoch, trainer_updates))
trainer.lr_step(epoch)
updates = 0
for i in range(epoch):
_ = next(train_dataloader)
epoch += 1
ds = next(train_dataloader)
updates += len(ds)
if ds is not None and updates > trainer_updates:
ds = OffsetDataset(ds, updates - trainer_updates)
else:
ds = next(train_dataloader)
epoch += 1
trainer.get_meter('wall').reset(init=extra_state.get('wall_time', 0))
return epoch
return epoch, ds or next(train_dataloader)
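# Resume sketch: after replaying `epoch` datasets from the dataloader, `updates`
# counts the batches a completed run would have consumed; any surplus over
# trainer.get_num_updates() is handed to OffsetDataset (imported above, not shown
# in this diff) so training restarts mid-epoch rather than repeating batches.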
if __name__ == '__main__':