fbshipit-source-id: 6a835d32f9dc5e0de118f1b46d365d0e0cc85e11

myleott 2018-09-30 12:25:59 -07:00
parent 864b89d044
commit f8377a704c
5 changed files with 236 additions and 15 deletions

fairseq/data/noising.py Normal file

@@ -0,0 +1,126 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import numpy as np
class WordNoising(object):
"""Generate a noisy version of a sentence, without changing words themselves."""
def __init__(self, dictionary, bpe_cont_marker="@@"):
self.dictionary = dictionary
self.bpe_end = np.array([
not self.dictionary[i].endswith(bpe_cont_marker)
for i in range(len(self.dictionary))
])
def noising(self, x, lengths, noising_prob=0.0):
raise NotImplementedError()
def _get_bpe_word_idx(self, x):
# x: (T x B)
bpe_end = self.bpe_end[x]
# do a reduce front sum to generate word ids
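        # e.g. for ["he@@", "llo", "how"] bpe_end is [0, 1, 1]; the reverse
        # cumulative sum gives [2, 2, 1], and subtracting it from its max
        # yields word ids [0, 0, 1], so both pieces of "hello" share id 0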
word_idx = bpe_end[::-1].cumsum(0)[::-1]
word_idx = word_idx.max(0)[None, :] - word_idx
return word_idx
class WordDropout(WordNoising):
"""Randomly drop input words. If not passing blank_idx (default is None),
then dropped words will be removed. Otherwise, it will be replaced by the
blank_idx."""
def __init__(self, dictionary):
super().__init__(dictionary)
def noising(self, x, lengths, dropout_prob=0.1, blank_idx=None):
# x: (T x B), lengths: B
if dropout_prob == 0:
return x, lengths
assert 0 < dropout_prob < 1
# be sure to drop entire words
word_idx = self._get_bpe_word_idx(x)
sentences = []
modified_lengths = []
for i in range(lengths.size(0)):
            # Since dropout probabilities need to apply over non-pad tokens,
            # it is not trivial to generate the keep mask without considering
            # input lengths; otherwise, this could be done outside the loop
keep = np.random.rand(lengths[i] - 1) >= dropout_prob
# ith example: [x0, x1, ..., eos, pad, ..., pad]
assert x[lengths[i] - 1, i] == self.dictionary.eos()
words = x[:lengths[i], i].tolist()
# TODO: speed up the following loop
# drop words from the input according to keep
new_s = [
w if keep[word_idx[j, i]] else blank_idx
for j, w in enumerate(words)
]
new_s = [w for w in new_s if w is not None]
# we need to have at least one word in the sentence (more than the
# start / end sentence symbols)
if len(new_s) == 1:
new_s.append(words[np.random.randint(0, len(words))])
assert (
len(new_s) >= 2
and new_s[-1] == self.dictionary.eos()
), "New sentence is invalid."
sentences.append(new_s)
modified_lengths.append(len(new_s))
# re-construct input
modified_lengths = torch.LongTensor(modified_lengths)
modified_x = torch.LongTensor(
modified_lengths.max(),
modified_lengths.size(0)
).fill_(self.dictionary.pad())
for i in range(modified_lengths.size(0)):
modified_x[:modified_lengths[i], i].copy_(torch.LongTensor(sentences[i]))
return modified_x, modified_lengths
class WordShuffle(WordNoising):
"""Shuffle words by no more than k positions."""
def __init__(self, dictionary):
super().__init__(dictionary)
def noising(self, x, lengths, max_shuffle_distance=3):
# x: (T x B), lengths: B
if max_shuffle_distance == 0:
return x, lengths
        # max_shuffle_distance <= 1 would leave the sequence unchanged
        assert max_shuffle_distance > 1
# define noise word scores
noise = np.random.uniform(
0,
max_shuffle_distance,
size=(x.size(0) - 1, x.size(1)),
)
noise[0] = -1 # do not move start sentence symbol
# be sure to shuffle entire words
word_idx = self._get_bpe_word_idx(x)
x2 = x.clone()
for i in range(lengths.size(0)):
# generate a random permutation
scores = word_idx[:lengths[i] - 1, i] + noise[word_idx[:lengths[i] - 1, i], i]
# ensure no reordering inside a word
scores += 1e-6 * np.arange(lengths[i] - 1)
permutation = scores.argsort()
# shuffle words
x2[:lengths[i] - 1, i].copy_(
x2[:lengths[i] - 1, i][torch.from_numpy(permutation)]
)
return x2, lengths
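
The classes above operate on a (T x B) batch of token indices plus per-sentence lengths, and both keep the BPE continuation pieces ("@@"-marked tokens) of a word together. A minimal usage sketch, assuming vocab is a fairseq Dictionary and x / lengths are a padded (T x B) LongTensor and a length tensor (placeholder names; see tests/test_noising.py below for a concrete setup):

from fairseq.data import noising

word_dropout = noising.WordDropout(vocab)
# drop whole words (all of their BPE pieces) with probability 0.1
x_drop, lengths_drop = word_dropout.noising(x, lengths, dropout_prob=0.1)
# blank words with <unk> instead of removing them
x_blank, lengths_blank = word_dropout.noising(x, lengths, 0.1, blank_idx=vocab.unk())

word_shuffle = noising.WordShuffle(vocab)
# move words by at most 3 positions, never splitting a word's BPE pieces
x_shuf, lengths_shuf = word_shuffle.noising(x, lengths, max_shuffle_distance=3)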

fairseq/optim/fp16_optimizer.py

@@ -133,7 +133,9 @@ class FP16Optimizer(optim.FairseqOptimizer):
        self.scaler.update_scale(overflow)
        if overflow:
            if self.scaler.loss_scale <= self.args.min_loss_scale:
-                raise Exception((
+                # Use FloatingPointError as an uncommon error that parent
+                # functions can safely catch to stop training.
+                raise FloatingPointError((
                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
                    'Try lowering the learning rate, using gradient clipping or '
                    'increasing the batch size.'
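
Because the optimizer now raises FloatingPointError rather than a bare Exception, a caller can catch exactly this condition and stop training cleanly while letting unexpected errors propagate. A minimal sketch, assuming a trainer object with a train_step method (a stand-in for whatever entry point drives FP16Optimizer, not the actual fairseq training loop):

def train_until_loss_scale_bottoms_out(trainer, samples):
    for sample in samples:
        try:
            trainer.train_step(sample)
        except FloatingPointError as e:
            # minimum loss scale reached; stop instead of crashing
            print('Stopping training: {}'.format(e))
            break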

fairseq/sequence_generator.py

@@ -480,21 +480,17 @@ class SequenceGenerator(object):
        if len(self.models) == 1:
            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)

-        avg_probs = None
+        log_probs = []
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
-            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
-            if avg_probs is None:
-                avg_probs = probs
-            else:
-                avg_probs.add_(probs)
+            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=True)
+            log_probs.append(probs)
            if attn is not None:
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
-        avg_probs.div_(len(self.models))
-        avg_probs.log_()
+        avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(len(self.models))
        if avg_attn is not None:
            avg_attn.div_(len(self.models))
        return avg_probs, avg_attn
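
The change above averages the ensemble in log space: instead of summing probabilities (log_probs=False), dividing by the number of models and taking the log, it stacks per-model log-probabilities and computes logsumexp minus log(N). This is the same quantity, but it avoids an exp/log round-trip on very small probabilities. A small standalone check of that identity (illustrative only, not fairseq code):

import math
import torch

# two "models", three candidate tokens: per-model log-probabilities
log_probs = torch.tensor([[0.7, 0.2, 0.1],
                          [0.5, 0.3, 0.2]]).log()
# old formulation: average the probabilities, then take the log
naive = torch.exp(log_probs).mean(dim=0).log()
# new formulation: logsumexp over the model dimension minus log(N)
stable = torch.logsumexp(log_probs, dim=0) - math.log(log_probs.size(0))
print(torch.allclose(naive, stable))  # True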

fairseq/trainer.py

@@ -45,7 +45,15 @@ class Trainer(object):
        else:
            self._model = model.cuda()

-        # initialize meters
+        self._dummy_batch = dummy_batch
+        self._num_updates = 0
+        self._optim_history = None
+        self._optimizer = None
+        self._wrapped_model = None
+
+        self.init_meters(args)
+
+    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
@@ -63,11 +71,6 @@ class Trainer(object):
        self.meters['wall'] = TimeMeter()  # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

-        self._dummy_batch = dummy_batch
-        self._num_updates = 0
-        self._optim_history = None
-        self._optimizer = None
-        self._wrapped_model = None
-
    @property
    def model(self):

tests/test_noising.py Normal file

@@ -0,0 +1,94 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import unittest
from fairseq.data import data_utils, Dictionary, noising
class TestDataNoising(unittest.TestCase):
def _get_test_data(self):
vocab = Dictionary()
vocab.add_symbol("he@@")
vocab.add_symbol("llo")
vocab.add_symbol("how")
vocab.add_symbol("are")
vocab.add_symbol("y@@")
vocab.add_symbol("ou")
vocab.add_symbol("n@@")
vocab.add_symbol("ew")
vocab.add_symbol("or@@")
vocab.add_symbol("k")
src_tokens = [
["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"],
["how", "are", "y@@", "ou"],
]
src_len = [len(x) for x in src_tokens]
x = torch.LongTensor(len(src_tokens), max(src_len) + 1).fill_(vocab.pad())
for i in range(len(src_tokens)):
for j in range(len(src_tokens[i])):
x[i][j] = vocab.index(src_tokens[i][j])
x[i][j + 1] = vocab.eos()
x = x.transpose(1, 0)
return vocab, x, torch.LongTensor([i + 1 for i in src_len])
def test_word_dropout(self):
vocab, x, x_len = self._get_test_data()
with data_utils.numpy_seed(1234):
noising_gen = noising.WordDropout(vocab)
x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            # Expect only the first word (2 bpe tokens) of the first example
            # to be dropped out
self.assertEqual(x_len[0] - 2, l_noised[0])
for i in range(l_noised[0]):
self.assertEqual(x_noised[i][0], x[i+2][0])
def test_word_blank(self):
vocab, x, x_len = self._get_test_data()
with data_utils.numpy_seed(1234):
noising_gen = noising.WordDropout(vocab)
x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            # Expect only the first word (2 bpe tokens) of the first example
            # to be blanked out
self.assertEqual(x_len[0], l_noised[0])
for i in range(l_noised[0]):
if i < 2:
self.assertEqual(x_noised[i][0], vocab.unk())
else:
self.assertEqual(x_noised[i][0], x[i][0])
def test_word_shuffle(self):
vocab, x, x_len = self._get_test_data()
with data_utils.numpy_seed(1234):
word_shuffle = noising.WordShuffle(vocab)
x_noised, l_noised = word_shuffle.noising(x, x_len, 0)
for i in range(len(x_len)):
for j in range(x_len[i]):
self.assertEqual(x[j][i], x_noised[j][i])
self.assertEqual(x_len[0], l_noised[0])
x_noised, l_noised = word_shuffle.noising(x, x_len, 3)
            # Expect the second example to have its last three tokens shuffled:
            # 6, 7, 8, 9 => 6, 8, 9, 7, where (8, 9) is a word
for i in range(x_len[0]):
self.assertEqual(x[i][0], x_noised[i][0])
shuffle_map = {0: 0, 1: 3, 2: 1, 3: 2}
for k, v in shuffle_map.items():
self.assertEqual(x[k][1], x_noised[v][1])
self.assertEqual(x_len[0], l_noised[0])
self.assertEqual(x_len[1], l_noised[1])
if __name__ == '__main__':
unittest.main()
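
The new tests run under the standard unittest runner; assuming fairseq is importable from the working directory, for example:

python tests/test_noising.py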