Fix generation when vocabulary is small relative to beam size (fixes #7)

Myle Ott 2017-09-26 11:05:51 -07:00 committed by GitHub
parent 2d3161daa8
commit 03c4a71698
5 changed files with 40 additions and 26 deletions
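
Context for the change: when the requested beam is wider than the target vocabulary, `topk` over the per-hypothesis scores can be asked for more candidates than exist, and pad can leak into the beam. The diffs below cap the effective beam size at `len(dict) - 1` (pad is never selected) and mask pad out of the scores. A minimal sketch of the cap, using a made-up toy dictionary (not code from this commit):

```
# Toy illustration of the cap added in SequenceGenerator._generate;
# the dictionary and requested beam width here are invented for the example.
toy_dict = ['<pad>', '</s>', '<unk>', 'hello', 'world']   # 5 symbols
requested_beam = 10

# pad is never selected, so at most len(dict) - 1 distinct hypotheses exist
beam_size = min(requested_beam, len(toy_dict) - 1)
assert beam_size == 4
```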

View File

@@ -28,7 +28,7 @@ class FConvModel(nn.Module):
         decoder_out = self.decoder(input_tokens, input_positions, encoder_out)
         return decoder_out.view(-1, decoder_out.size(-1))
 
-    def make_generation_fast_(self, beam_size, use_beamable_mm=False):
+    def make_generation_fast_(self, use_beamable_mm=False):
         """Optimize model for faster generation.
 
         Optimizations include:
@@ -54,7 +54,7 @@ class FConvModel(nn.Module):
 
         # use BeamableMM in attention layers
         if use_beamable_mm:
-            self.decoder._use_beamable_mm(beam_size)
+            self.decoder._use_beamable_mm()
 
         def train(mode):
             if mode:
@@ -243,14 +243,14 @@ class Decoder(nn.Module):
             context += conv.kernel_size[0] - 1
         return context
 
-    def incremental_inference(self):
+    def incremental_inference(self, beam_size=None):
         """Context manager for incremental inference.
 
         This provides an optimized forward pass for incremental inference
         (i.e., it predicts one time step at a time). If the input order changes
         between time steps, call model.decoder.reorder_incremental_state to
         update the relevant buffers. To generate a fresh sequence, first call
-        model.decoder.clear_incremental_state.
+        model.decoder.start_fresh_sequence.
 
         Usage:
         ```
@@ -263,18 +263,19 @@ class Decoder(nn.Module):
         """
         class IncrementalInference(object):
 
-            def __init__(self, decoder):
+            def __init__(self, decoder, beam_size):
                 self.decoder = decoder
+                self.beam_size = beam_size
 
             def __enter__(self):
-                self.decoder._start_incremental_inference()
+                self.decoder._start_incremental_inference(self.beam_size)
 
            def __exit__(self, *args):
                 self.decoder._stop_incremental_inference()
 
-        return IncrementalInference(self)
+        return IncrementalInference(self, beam_size)
 
-    def _start_incremental_inference(self):
+    def _start_incremental_inference(self, beam_size):
         assert not self._is_inference_incremental, \
             'already performing incremental inference'
         self._is_inference_incremental = True
@@ -287,7 +288,7 @@ class Decoder(nn.Module):
         self.forward = self._incremental_forward
 
         # start a fresh sequence
-        self.clear_incremental_state()
+        self.start_fresh_sequence(beam_size)
 
     def _stop_incremental_inference(self):
         # restore original forward and convolution layers
@@ -348,17 +349,21 @@ class Decoder(nn.Module):
 
         return x, avg_attn_scores
 
-    def clear_incremental_state(self):
+    def start_fresh_sequence(self, beam_size=None):
         """Clear all state used for incremental generation.
 
         **For incremental inference only**
 
         This should be called before generating a fresh sequence.
+        beam_size is required if using BeamableMM.
         """
         if self._is_inference_incremental:
             self.prev_state = None
             for conv in self.convolutions:
                 conv.clear_buffer()
+            for attn in self.attention:
+                if isinstance(attn.bmm, BeamableMM):
+                    attn.bmm.set_beam_size(beam_size)
 
     def reorder_incremental_state(self, new_order):
         """Reorder buffered internal state (for incremental generation).
@@ -373,9 +378,9 @@ class Decoder(nn.Module):
         for conv in self.convolutions:
             conv.reorder_buffer(new_order)
 
-    def _use_beamable_mm(self, beam_size):
+    def _use_beamable_mm(self):
         """Replace torch.bmm with BeamableMM in attention layers."""
-        beamable_mm = BeamableMM(beam_size)
+        beamable_mm = BeamableMM()
         for attn in self.attention:
             attn.bmm = beamable_mm
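
Taken together, these decoder changes move the beam size from construction time to generation time: `make_generation_fast_` no longer takes it, and it is instead supplied when entering incremental inference (or when calling `start_fresh_sequence` directly), which pushes it into any BeamableMM attention layers. A hedged sketch of the resulting call order; the helper name and the `model`/`beam_size` arguments are placeholders, not code from this commit:

```
def fast_generation_setup(model, beam_size):
    """Hypothetical helper showing the updated call order (illustration only)."""
    # the beam size is no longer passed here; only the BeamableMM toggle remains
    model.make_generation_fast_(use_beamable_mm=True)

    # the beam size is now supplied when entering incremental inference;
    # __enter__ forwards it to _start_incremental_inference, which calls
    # start_fresh_sequence(beam_size) and, through it, BeamableMM.set_beam_size
    return model.decoder.incremental_inference(beam_size=beam_size)

# usage (sketch):
#     with fast_generation_setup(model, beam_size=5):
#         ...decode one step at a time...
#     model.decoder.start_fresh_sequence(5)   # before generating the next sequence
```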

View File

@@ -6,9 +6,9 @@
 # can be found in the PATENTS file in the same directory.
 #
 
-from .beamable_mm import *
-from .linearized_convolution import *
+from .beamable_mm import BeamableMM
+from .conv_tbc import ConvTBC
+from .linearized_convolution import LinearizedConvolution
 
 __all__ = [
     'BeamableMM', 'LinearizedConvolution', 'ConvTBC',

View File

@@ -18,16 +18,16 @@ class BeamableMM(nn.Module):
     inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
     with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
     """
 
-    def __init__(self, beam_size):
+    def __init__(self):
         super(BeamableMM, self).__init__()
-        self.beam_size = beam_size
+        self.beam_size = None
 
     def forward(self, input1, input2):
         if (
-            not self.training and   # test mode
-            self.beam_size > 0 and  # beam size is set
-            input1.dim() == 3 and   # only support batched input
-            input1.size(1) == 1     # single time step update
+            not self.training and           # test mode
+            self.beam_size is not None and  # beam size is set
+            input1.dim() == 3 and           # only support batched input
+            input1.size(1) == 1             # single time step update
         ):
             bsz, beam = input1.size(0), self.beam_size
@@ -45,3 +45,6 @@ class BeamableMM(nn.Module):
             return output.view(bsz, 1, -1)
         else:
             return input1.bmm(input2)
+
+    def set_beam_size(self, beam_size):
+        self.beam_size = beam_size
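
The BeamableMM change is a pure API shift: the beam width used to be fixed at construction and is now attached later, once the (possibly capped) beam size is known. A hedged usage sketch with illustrative shapes; the import path and shapes are assumptions, not taken from this diff:

```
import torch
from fairseq.modules import BeamableMM  # module path assumed from this repo layout

bmm = BeamableMM()        # previously: BeamableMM(beam_size)
bmm.eval()                # the beamable fast path only applies outside training
bmm.set_beam_size(4)      # in this commit, set via Decoder.start_fresh_sequence

bsz, beam, nhu, srclen = 2, 4, 16, 10
# one decoder step per hypothesis; encoder keys repeated across the beam
queries = torch.randn(bsz * beam, 1, nhu)
keys = torch.randn(bsz, 1, nhu, srclen).expand(bsz, beam, nhu, srclen)
keys = keys.contiguous().view(bsz * beam, nhu, srclen)
attn_scores = bmm(queries, keys)  # shape: (bsz * beam) x 1 x srclen
```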

View File

@@ -87,13 +87,16 @@ class SequenceGenerator(object):
     def _generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
         bsz = src_tokens.size(0)
-        beam_size = beam_size if beam_size is not None else self.beam_size
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
 
+        # the max beam size is the dictionary size - 1, since we never select pad
+        beam_size = beam_size if beam_size is not None else self.beam_size
+        beam_size = min(beam_size, len(self.dict) - 1)
+
         encoder_outs = []
         for model in self.models:
             model.eval()
-            model.decoder.clear_incremental_state()  # start a fresh sequence
+            model.decoder.start_fresh_sequence(beam_size)  # start a fresh sequence
 
             # compute the encoder output and expand to beam size
             encoder_out = model.encoder(src_tokens, src_positions)
@@ -172,7 +175,7 @@ class SequenceGenerator(object):
                 sents_seen.add(sent)
 
                 def get_hypo():
-                    hypo = tokens[idx, 1:step+2].clone()
+                    hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
                     hypo[step] = self.eos
                     alignment = align[idx, 1:step+2].clone()
                     return {
@@ -219,6 +222,7 @@ class SequenceGenerator(object):
             else:
                 # make probs contain cumulative scores for each hypothesis
                 probs.add_(scores.view(-1, 1))
+            probs[:, self.pad] = -math.inf  # never select pad
 
             # record alignment to source tokens, based on attention
             _ignore_scores = buffer('_ignore_scores', type_of=scores)
@@ -229,7 +233,9 @@ class SequenceGenerator(object):
             cand_scores = buffer('cand_scores', type_of=scores)
             cand_indices = buffer('cand_indices')
             cand_beams = buffer('cand_beams')
-            probs.view(bsz, -1).topk(cand_size, out=(cand_scores, cand_indices))
+            probs.view(bsz, -1).topk(
+                min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
+                out=(cand_scores, cand_indices))
             torch.div(cand_indices, self.vocab_size, out=cand_beams)
             cand_indices.fmod_(self.vocab_size)
@@ -256,7 +262,7 @@ class SequenceGenerator(object):
             # and values < cand_size indicate candidate active hypos.
             # After, the min values per row are the top candidate active hypos
             active_mask = buffer('active_mask')
-            torch.add((eos_mask*cand_size).type_as(cand_offsets), cand_offsets,
+            torch.add((eos_mask*cand_size).type_as(cand_offsets), cand_offsets[:eos_mask.size(1)],
                       out=active_mask)
 
             # get the top beam_size active hypotheses, which are just the hypos
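
The SequenceGenerator side of the fix has two cooperating pieces: pad is forced to -inf in the cumulative scores, and the topk request is capped so it can never exceed what the flattened probability rows contain. A toy, self-contained illustration of that interplay (not the generator's own code; all sizes below are invented):

```
import math
import torch

vocab_size, pad_idx, beam_size = 4, 0, 3     # deliberately tiny vocabulary
cand_size = 2 * beam_size                    # the generator keeps 2x beam candidates
probs = torch.randn(beam_size, vocab_size)   # one row of scores per hypothesis

probs[:, pad_idx] = -math.inf                # never select pad
flat = probs.view(1, -1)                     # bsz x (beam * vocab)
k = min(cand_size, flat.size(1) - 1)         # -1 so we never select pad
cand_scores, cand_indices = flat.topk(k)

cand_beams = cand_indices // vocab_size      # which hypothesis each candidate extends
cand_tokens = cand_indices % vocab_size      # which token it adds
assert (cand_tokens != pad_idx).all()        # pad can no longer be picked
```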

View File

@@ -47,7 +47,7 @@ def main():
 
     # Optimize model for generation
     for model in models:
-        model.make_generation_fast_(args.beam, not args.no_beamable_mm)
+        model.make_generation_fast_(not args.no_beamable_mm)
 
     # Initialize generator
     translator = SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam,