Pass encoder_input to generator, rather than src_tokens/src_lengths.

Stephen Roller, 2018-09-08 17:53:10 -04:00 (committed by Myle Ott)
commit bfeb773214, parent 8bd8ec8fa8
3 changed files with 42 additions and 21 deletions
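In short: callers now hand SequenceGenerator.generate a single encoder_input dictionary instead of positional src_tokens/src_lengths tensors. A minimal before/after sketch of the call site (the toy tensors and the commented-out generator calls are illustrative, not part of the diff):

import torch

src_tokens = torch.LongTensor([[4, 5, 2], [6, 7, 2]])  # (bsz, srclen), made-up token ids
src_lengths = torch.LongTensor([3, 3])

# Before this commit:
#     hypos = generator.generate(src_tokens, src_lengths, beam_size=2)
# After this commit, everything the encoder needs travels in one dict:
encoder_input = {'src_tokens': src_tokens, 'src_lengths': src_lengths}
#     hypos = generator.generate(encoder_input, beam_size=2)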

fairseq/sequence_generator.py

@@ -78,13 +78,18 @@ class SequenceGenerator(object):
             if 'net_input' not in s:
                 continue
             input = s['net_input']
-            srclen = input['src_tokens'].size(1)
+            # model.forward normally channels prev_output_tokens into the decoder
+            # separately, but SequenceGenerator directly calls model.encoder
+            encoder_input = {
+                k: v for k, v in input.items()
+                if k != 'prev_output_tokens'
+            }
+            srclen = encoder_input['src_tokens'].size(1)
             if timer is not None:
                 timer.start()
             with torch.no_grad():
                 hypos = self.generate(
-                    input['src_tokens'],
-                    input['src_lengths'],
+                    encoder_input,
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
                     prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
@@ -97,12 +102,23 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]
 
-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
-        """Generate a batch of translations."""
-        with torch.no_grad():
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
+    def generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
+        """Generate a batch of translations.
+
+        Args:
+            encoder_input: dictionary containing the inputs to
+                model.encoder.forward
+            beam_size: int overriding the beam size. Defaults to
+                self.beam_size
+            maxlen: maximum length of the generated sequence
+            prefix_tokens: force decoder to begin with these tokens
+        """
+        with torch.no_grad():
+            return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)
 
-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
+        """See generate"""
+        src_tokens = encoder_input['src_tokens']
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
@@ -121,10 +137,10 @@ class SequenceGenerator(object):
                 incremental_states[model] = None
 
             # compute the encoder output for each beam
-            encoder_out = model.encoder(
-                src_tokens.repeat(1, beam_size).view(-1, srclen),
-                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
-            )
+            encoder_out = model.encoder(**encoder_input)
+            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+            new_order = new_order.to(src_tokens.device)
+            encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
             encoder_outs.append(encoder_out)
 
         # initialize buffers
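A note on the beam expansion above: the old code tiled src_tokens and src_lengths beam_size times and ran the encoder on the enlarged batch; the new code runs the encoder once per sentence and then expands its output to the beam via reorder_encoder_out. The new_order index tensor is easy to sanity-check standalone (the sizes here are made up):

import torch

bsz, beam_size = 2, 3
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
print(new_order)  # tensor([0, 0, 0, 1, 1, 1]): sentence i's encoder output is repeated beam_size times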

interactive.py

@@ -145,9 +145,9 @@ def main(args):
                 tokens = tokens.cuda()
                 lengths = lengths.cuda()
+            encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
             translations = translator.generate(
-                tokens,
-                lengths,
+                encoder_input,
                 maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
             )
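This call site builds encoder_input directly because it never holds a full net_input. Where a net_input dict is available, the filtering from sequence_generator.py above applies: everything except prev_output_tokens is an encoder argument. A toy illustration (tensor contents made up):

import torch

net_input = {
    'src_tokens': torch.LongTensor([[4, 5, 2]]),
    'src_lengths': torch.LongTensor([3]),
    'prev_output_tokens': torch.LongTensor([[2, 4, 5]]),  # decoder-side input, not an encoder argument
}
encoder_input = {k: v for k, v in net_input.items() if k != 'prev_output_tokens'}
print(sorted(encoder_input))  # ['src_lengths', 'src_tokens']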

tests/test_sequence_generator.py

@@ -33,6 +33,10 @@ class TestSequenceGenerator(unittest.TestCase):
             [self.w1, self.w2, self.eos],
         ])
         self.src_lengths = torch.LongTensor([2, 2])
+        self.encoder_input = {
+            'src_tokens': self.src_tokens,
+            'src_lengths': self.src_lengths,
+        }
 
         args = argparse.Namespace()
         unk = 0.
@@ -85,7 +89,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_normalization(self):
         generator = SequenceGenerator([self.model], self.tgt_dict)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -104,7 +108,7 @@ class TestSequenceGenerator(unittest.TestCase):
         # Sentence 1: unchanged from the normalized case
         # Sentence 2: beams swap order
         generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -122,7 +126,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_short_hypos(self):
         lenpen = 0.6
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -140,7 +144,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_long_hypos(self):
         lenpen = 5.0
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
@@ -157,7 +161,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_maxlen(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -174,7 +178,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_no_stop_early(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -273,7 +277,8 @@ class TestDiverseBeamSearch(unittest.TestCase):
             [self.model], self.tgt_dict,
             beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
         )
-        hypos = generator.generate(self.src_tokens, self.src_lengths)
+        encoder_input = {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}
+        hypos = generator.generate(encoder_input)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, w1, eos])