mirror of https://github.com/facebookresearch/fairseq.git
Fix generation when vocabulary is small relative to beam size (fixes #7)
This commit is contained in:
parent 2d3161daa8
commit 03c4a71698
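At a high level, this change stops baking the beam size into the model at optimization time and instead clamps the effective beam against the target dictionary when generation runs: the generator uses min(beam_size, len(dict) - 1), caps the topk request, and masks pad, so beam search never asks a small vocabulary for more candidates than it can supply. A minimal, runnable sketch of the clamping idea (a toy function with made-up numbers, not fairseq code):

```python
# Sketch only: why the beam must be clamped when the target vocabulary is tiny.
def effective_beam(beam_size, vocab_size):
    # pad is never selected, so at most vocab_size - 1 distinct tokens
    # are available at each step
    return min(beam_size, vocab_size - 1)

assert effective_beam(5, 100) == 5   # normal case: beam unchanged
assert effective_beam(5, 4) == 3     # tiny vocab: beam shrinks to fit
```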
@@ -28,7 +28,7 @@ class FConvModel(nn.Module):
         decoder_out = self.decoder(input_tokens, input_positions, encoder_out)
         return decoder_out.view(-1, decoder_out.size(-1))
 
-    def make_generation_fast_(self, beam_size, use_beamable_mm=False):
+    def make_generation_fast_(self, use_beamable_mm=False):
         """Optimize model for faster generation.
 
         Optimizations include:
@@ -54,7 +54,7 @@ class FConvModel(nn.Module):
 
         # use BeamableMM in attention layers
         if use_beamable_mm:
-            self.decoder._use_beamable_mm(beam_size)
+            self.decoder._use_beamable_mm()
 
     def train(mode):
         if mode:
@@ -243,14 +243,14 @@ class Decoder(nn.Module):
             context += conv.kernel_size[0] - 1
         return context
 
-    def incremental_inference(self):
+    def incremental_inference(self, beam_size=None):
         """Context manager for incremental inference.
 
         This provides an optimized forward pass for incremental inference
         (i.e., it predicts one time step at a time). If the input order changes
         between time steps, call model.decoder.reorder_incremental_state to
         update the relevant buffers. To generate a fresh sequence, first call
-        model.decoder.clear_incremental_state.
+        model.decoder.start_fresh_sequence.
 
         Usage:
         ```
@@ -263,18 +263,19 @@ class Decoder(nn.Module):
         """
         class IncrementalInference(object):
 
-            def __init__(self, decoder):
+            def __init__(self, decoder, beam_size):
                 self.decoder = decoder
+                self.beam_size = beam_size
 
             def __enter__(self):
-                self.decoder._start_incremental_inference()
+                self.decoder._start_incremental_inference(self.beam_size)
 
             def __exit__(self, *args):
                 self.decoder._stop_incremental_inference()
 
-        return IncrementalInference(self)
+        return IncrementalInference(self, beam_size)
 
-    def _start_incremental_inference(self):
+    def _start_incremental_inference(self, beam_size):
         assert not self._is_inference_incremental, \
             'already performing incremental inference'
         self._is_inference_incremental = True
@@ -287,7 +288,7 @@ class Decoder(nn.Module):
         self.forward = self._incremental_forward
 
         # start a fresh sequence
-        self.clear_incremental_state()
+        self.start_fresh_sequence(beam_size)
 
     def _stop_incremental_inference(self):
         # restore original forward and convolution layers
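The hunks above rework the incremental-inference context manager so the beam size is threaded through __enter__ into the decoder's fresh-sequence setup. For orientation, a hedged usage sketch of the reworked API; the helper name, the loop structure, and the decoder call signature are assumptions for illustration, not taken from the diff:

```python
def incremental_generate(model, step_inputs, positions, encoder_out, beam_size):
    # Hedged sketch only: `step_inputs`, `positions`, `encoder_out` and the
    # decoder forward signature are placeholders, not from the diff.
    with model.decoder.incremental_inference(beam_size=beam_size):
        # entering the context calls _start_incremental_inference(beam_size),
        # which calls start_fresh_sequence(beam_size) and, when BeamableMM is
        # installed, set_beam_size(beam_size) on each attention layer
        outputs = []
        for step_input in step_inputs:
            # one time step per call; conv/attention buffers carry state across steps
            outputs.append(model.decoder(step_input, positions, encoder_out))
            # if hypotheses are reordered between steps:
            # model.decoder.reorder_incremental_state(new_order)
        return outputs
```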
@@ -348,17 +349,21 @@ class Decoder(nn.Module):
 
         return x, avg_attn_scores
 
-    def clear_incremental_state(self):
+    def start_fresh_sequence(self, beam_size=None):
         """Clear all state used for incremental generation.
 
         **For incremental inference only**
 
         This should be called before generating a fresh sequence.
+        beam_size is required if using BeamableMM.
         """
         if self._is_inference_incremental:
             self.prev_state = None
             for conv in self.convolutions:
                 conv.clear_buffer()
+            for attn in self.attention:
+                if isinstance(attn.bmm, BeamableMM):
+                    attn.bmm.set_beam_size(beam_size)
 
     def reorder_incremental_state(self, new_order):
         """Reorder buffered internal state (for incremental generation).
@@ -373,9 +378,9 @@ class Decoder(nn.Module):
         for conv in self.convolutions:
             conv.reorder_buffer(new_order)
 
-    def _use_beamable_mm(self, beam_size):
+    def _use_beamable_mm(self):
         """Replace torch.bmm with BeamableMM in attention layers."""
-        beamable_mm = BeamableMM(beam_size)
+        beamable_mm = BeamableMM()
         for attn in self.attention:
             attn.bmm = beamable_mm
 
@@ -6,9 +6,9 @@
 # can be found in the PATENTS file in the same directory.
 #
 
-from .beamable_mm import *
-from .linearized_convolution import *
+from .beamable_mm import BeamableMM
+from .conv_tbc import ConvTBC
+from .linearized_convolution import LinearizedConvolution
 
 __all__ = [
     'BeamableMM', 'LinearizedConvolution', 'ConvTBC',
@@ -18,16 +18,16 @@ class BeamableMM(nn.Module):
     inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
     with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
     """
-    def __init__(self, beam_size):
+    def __init__(self):
         super(BeamableMM, self).__init__()
-        self.beam_size = beam_size
+        self.beam_size = None
 
     def forward(self, input1, input2):
         if (
-            not self.training and   # test mode
-            self.beam_size > 0 and  # beam size is set
-            input1.dim() == 3 and   # only support batched input
-            input1.size(1) == 1     # single time step update
+            not self.training and           # test mode
+            self.beam_size is not None and  # beam size is set
+            input1.dim() == 3 and           # only support batched input
+            input1.size(1) == 1             # single time step update
         ):
             bsz, beam = input1.size(0), self.beam_size
 
@@ -45,3 +45,6 @@ class BeamableMM(nn.Module):
             return output.view(bsz, 1, -1)
         else:
             return input1.bmm(input2)
+
+    def set_beam_size(self, beam_size):
+        self.beam_size = beam_size
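BeamableMM now starts without a beam size and is configured later through set_beam_size. Its docstring describes folding the per-hypothesis queries into a (bsz/beam x beam x nhu) batch; the short sketch below is my own illustration of why that folding matches the naive batched matmul (made-up shapes, not the fairseq implementation):

```python
import torch

# During beam search every hypothesis of a sentence attends over the same
# source states, so per-hypothesis queries can be folded per sentence and
# multiplied against a single copy of the keys.
bsz_per_beam, beam, nhu, src_len = 2, 5, 8, 7
bsz = bsz_per_beam * beam

queries = torch.randn(bsz, 1, nhu)                   # one step per hypothesis
keys = torch.randn(bsz_per_beam, nhu, src_len)       # shared per sentence
keys_expanded = keys.repeat_interleave(beam, dim=0)  # what naive bmm would use

naive = torch.bmm(queries, keys_expanded)                        # (bsz, 1, src_len)
folded = torch.bmm(queries.view(bsz_per_beam, beam, nhu), keys)  # (bsz/beam, beam, src_len)

assert torch.allclose(naive.view(bsz_per_beam, beam, src_len), folded, atol=1e-6)
```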
@@ -87,13 +87,16 @@ class SequenceGenerator(object):
 
     def _generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
         bsz = src_tokens.size(0)
-        beam_size = beam_size if beam_size is not None else self.beam_size
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
 
+        # the max beam size is the dictionary size - 1, since we never select pad
+        beam_size = beam_size if beam_size is not None else self.beam_size
+        beam_size = min(beam_size, len(self.dict) - 1)
+
         encoder_outs = []
         for model in self.models:
             model.eval()
-            model.decoder.clear_incremental_state()  # start a fresh sequence
+            model.decoder.start_fresh_sequence(beam_size)  # start a fresh sequence
 
             # compute the encoder output and expand to beam size
             encoder_out = model.encoder(src_tokens, src_positions)
@@ -172,7 +175,7 @@ class SequenceGenerator(object):
             sents_seen.add(sent)
 
             def get_hypo():
-                hypo = tokens[idx, 1:step+2].clone()
+                hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
                 hypo[step] = self.eos
                 alignment = align[idx, 1:step+2].clone()
                 return {
@@ -219,6 +222,7 @@ class SequenceGenerator(object):
             else:
                 # make probs contain cumulative scores for each hypothesis
                 probs.add_(scores.view(-1, 1))
+            probs[:, self.pad] = -math.inf  # never select pad
 
             # record alignment to source tokens, based on attention
             _ignore_scores = buffer('_ignore_scores', type_of=scores)
@@ -229,7 +233,9 @@ class SequenceGenerator(object):
             cand_scores = buffer('cand_scores', type_of=scores)
             cand_indices = buffer('cand_indices')
             cand_beams = buffer('cand_beams')
-            probs.view(bsz, -1).topk(cand_size, out=(cand_scores, cand_indices))
+            probs.view(bsz, -1).topk(
+                min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
+                out=(cand_scores, cand_indices))
             torch.div(cand_indices, self.vocab_size, out=cand_beams)
             cand_indices.fmod_(self.vocab_size)
 
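The topk change above is the heart of the fix for small vocabularies: early in the search the flattened probs row only has vocab_size entries, so requesting cand_size candidates (roughly twice the beam) can exceed what exists. A toy reproduction of that failure and the clamp, with illustrative numbers rather than fairseq code:

```python
import torch

vocab_size, beam_size = 5, 4
probs = torch.randn(1, vocab_size)  # one sentence's scores at the first step
cand_size = 2 * beam_size           # 8 candidates requested, only 5 entries exist

try:
    probs.topk(cand_size)           # raises: k out of range for this dimension
except RuntimeError as e:
    print('topk failed:', e)

k = min(cand_size, probs.size(1) - 1)  # clamp, keeping pad unreachable
scores, indices = probs.topk(k)        # succeeds with k == 4
```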
@@ -256,7 +262,7 @@ class SequenceGenerator(object):
             # and values < cand_size indicate candidate active hypos.
             # After, the min values per row are the top candidate active hypos
             active_mask = buffer('active_mask')
-            torch.add((eos_mask*cand_size).type_as(cand_offsets), cand_offsets,
+            torch.add((eos_mask*cand_size).type_as(cand_offsets), cand_offsets[:eos_mask.size(1)],
                       out=active_mask)
 
             # get the top beam_size active hypotheses, which are just the hypos
@@ -47,7 +47,7 @@ def main():
 
     # Optimize model for generation
     for model in models:
-        model.make_generation_fast_(args.beam, not args.no_beamable_mm)
+        model.make_generation_fast_(not args.no_beamable_mm)
 
     # Initialize generator
     translator = SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam,
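Finally, generate.py drops args.beam from make_generation_fast_, since the beam size now reaches BeamableMM through SequenceGenerator at generation time. An illustrative call sequence under the new signatures; the helper name below is hypothetical, and SequenceGenerator arguments beyond those visible in the diff are omitted:

```python
def build_translator(models, dataset, args):
    # Hypothetical helper for illustration; `models`, `dataset`, `args` come
    # from the surrounding generate.py, which the diff shows only in part.
    for model in models:
        model.make_generation_fast_(use_beamable_mm=not args.no_beamable_mm)
    # beam size is no longer baked into the model; _generate clamps it and
    # passes it to decoder.start_fresh_sequence, which configures BeamableMM
    return SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam)
```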