Factor out search logic in SequenceGenerator

Author: Myle Ott
Date: 2018-08-09 13:19:14 -04:00
parent 75e12a27fb
commit ef43da72d3
4 changed files with 210 additions and 70 deletions

fairseq/search.py (new file, 165 lines)

@@ -0,0 +1,165 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch


class Search(object):

    def __init__(self, tgt_dict):
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.scores_buf = None
        self.indices_buf = None
        self.beams_buf = None

    def _init_buffers(self, t):
        if self.scores_buf is None:
            self.scores_buf = t.new()
            self.indices_buf = torch.LongTensor().to(device=t.device)
            self.beams_buf = torch.LongTensor().to(device=t.device)

    def step(self, step, lprobs, scores):
        """Take a single search step.

        Args:
            step: the current search step, starting at 0
            lprobs: (bsz x input_beam_size x vocab_size)
                the model's log-probabilities over the vocabulary at the current step
            scores: (bsz x input_beam_size x step)
                the historical model scores of each hypothesis up to this point

        Return: A tuple of (scores, indices, beams) where:
            scores: (bsz x output_beam_size)
                the scores of the chosen elements; output_beam_size can be
                larger than input_beam_size, e.g., we may return
                2*input_beam_size to account for EOS
            indices: (bsz x output_beam_size)
                the indices of the chosen elements
            beams: (bsz x output_beam_size)
                the hypothesis ids of the chosen elements,
                in the range [0, input_beam_size)
        """
        raise NotImplementedError


class BeamSearch(Search):

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    def step(self, step, lprobs, scores):
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        torch.topk(
            lprobs.view(bsz, -1),
            k=min(
                # Take the best 2 x beam_size predictions. We'll choose the first
                # beam_size of these which don't predict eos to continue with.
                beam_size * 2,
                lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
            ),
            out=(self.scores_buf, self.indices_buf),
        )
        torch.div(self.indices_buf, vocab_size, out=self.beams_buf)
        self.indices_buf.fmod_(vocab_size)
        return self.scores_buf, self.indices_buf, self.beams_buf


class Sampling(Search):

    def __init__(self, tgt_dict, sampling_topk=-1, sampling_temperature=1.):
        super().__init__(tgt_dict)
        self.sampling_topk = sampling_topk
        self.sampling_temperature = sampling_temperature

    def step(self, step, lprobs, scores):
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()

        # we exclude the first two vocab items, one of which is pad
        assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
        lprobs_nopad = lprobs[:, :, 2:]

        # only sample from top-k candidates
        if self.sampling_topk > 0:
            lprobs_nopad, topk_indices = lprobs_nopad.topk(self.sampling_topk)

        # sampling temperature
        if self.sampling_temperature != 1.:
            lprobs_nopad = lprobs_nopad.div_(self.sampling_temperature)

        # sample
        probs_nopad = lprobs_nopad.exp_()
        if step == 0:
            self.indices_buf = torch.multinomial(
                probs_nopad.view(bsz, -1),
                beam_size,
                replacement=True,
                out=self.indices_buf,
            ).view(bsz, beam_size)
        else:
            self.indices_buf = torch.multinomial(
                probs_nopad.view(bsz * beam_size, -1),
                1,
                replacement=True,
                out=self.indices_buf,
            ).view(bsz, beam_size)

        if step == 0:
            # expand to beam size
            probs_nopad = probs_nopad.expand(bsz, beam_size, -1)

        # gather scores
        torch.gather(
            probs_nopad,
            dim=2,
            index=self.indices_buf.unsqueeze(-1),
            out=self.scores_buf,
        )
        self.scores_buf = self.scores_buf.log_().view(bsz, -1)

        # remap indices if using top-k sampling
        if self.sampling_topk > 0:
            self.indices_buf = torch.gather(
                topk_indices.expand(bsz, beam_size, -1),
                dim=2,
                index=self.indices_buf.unsqueeze(-1),
            ).squeeze(2)

        # remap indices since we excluded the first two vocab items
        self.indices_buf.add_(2)

        if step == 0:
            self.beams_buf = self.indices_buf.new_zeros(bsz, beam_size)
        else:
            self.beams_buf = torch.arange(0, beam_size, out=self.beams_buf).repeat(bsz, 1)
            # make scores cumulative
            self.scores_buf.add_(
                torch.gather(
                    scores[:, :, step - 1],
                    dim=1,
                    index=self.beams_buf,
                )
            )

        return self.scores_buf, self.indices_buf, self.beams_buf
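
The heart of BeamSearch.step is an index trick: for each sentence the (beam x vocab) score matrix is flattened, a single top-k is taken over it, and integer division and modulo by vocab_size recover which hypothesis each candidate extends and which token it appends. A minimal self-contained sketch of that arithmetic (not fairseq code; written for current PyTorch, with // and % standing in for the in-place torch.div/fmod_ calls above):

import torch

bsz, beam_size, vocab_size = 2, 3, 10
# cumulative log-probabilities for every (hypothesis, next-token) pair
lprobs = torch.randn(bsz, beam_size, vocab_size)

# one top-k per sentence over the flattened (beam * vocab) axis;
# 2 * beam_size candidates leave room to later drop those that end in EOS
cand_size = 2 * beam_size
cand_scores, flat_idx = lprobs.view(bsz, -1).topk(cand_size, dim=1)

cand_beams = flat_idx // vocab_size   # which hypothesis each candidate extends
cand_indices = flat_idx % vocab_size  # which token it appends

# each candidate score is exactly the score of its (beam, token) pair
assert torch.equal(
    cand_scores,
    lprobs[torch.arange(bsz).unsqueeze(1), cand_beams, cand_indices],
)
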

fairseq/sequence_generator.py

@@ -9,7 +9,7 @@ import math
 import torch
-from fairseq import utils
+from fairseq import search, utils
 from fairseq.models import FairseqIncrementalDecoder
@@ -43,9 +43,13 @@ class SequenceGenerator(object):
         self.len_penalty = len_penalty
         self.unk_penalty = unk_penalty
         self.retain_dropout = retain_dropout
-        self.sampling = sampling
-        self.sampling_topk = sampling_topk
-        self.sampling_temperature = sampling_temperature
+
+        assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
+        if sampling:
+            self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
+        else:
+            self.search = search.BeamSearch(tgt_dict)
 
     def cuda(self):
         for model in self.models:
@@ -273,19 +277,10 @@ class SequenceGenerator(object):
                         model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                     encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
 
-            probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
-            if step == 0:
-                # at the first step all hypotheses are equally likely, so use
-                # only the first beam
-                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
-                scores = scores.type_as(probs)
-                scores_buf = scores_buf.type_as(probs)
-            elif not self.sampling:
-                # make probs contain cumulative scores for each hypothesis
-                probs.add_(scores[:, step - 1].view(-1, 1))
+            lprobs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
 
-            probs[:, self.pad] = -math.inf  # never select pad
-            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty
+            lprobs[:, self.pad] = -math.inf  # never select pad
+            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
 
             # Record attention scores
             if avg_attn_scores is not None:
@@ -295,74 +290,33 @@ class SequenceGenerator(object):
                     nonpad_idxs = src_tokens.ne(self.pad)
                 attn[:, :, step + 1].copy_(avg_attn_scores)
 
-            cand_scores = buffer('cand_scores', type_of=scores)
-            cand_indices = buffer('cand_indices')
-            cand_beams = buffer('cand_beams')
+            scores = scores.type_as(lprobs)
+            scores_buf = scores_buf.type_as(lprobs)
             eos_bbsz_idx = buffer('eos_bbsz_idx')
             eos_scores = buffer('eos_scores', type_of=scores)
             if step < maxlen:
                 if prefix_tokens is not None and step < prefix_tokens.size(1):
-                    probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
+                    probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
                     cand_scores = torch.gather(
                         probs_slice, dim=1,
                         index=prefix_tokens[:, step].view(-1, 1).data
                     ).expand(-1, cand_size)
                     cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
-                    cand_beams.resize_as_(cand_indices).fill_(0)
-                elif self.sampling:
-                    assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
-                    if self.sampling_topk > 0:
-                        values, indices = probs[:, 2:].topk(self.sampling_topk)
-                        exp_probs = values.div_(self.sampling_temperature).exp()
-                        if step == 0:
-                            torch.multinomial(exp_probs, beam_size, replacement=True, out=cand_indices)
-                        else:
-                            torch.multinomial(exp_probs, 1, replacement=True, out=cand_indices)
-                        torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
-                        torch.gather(indices, dim=1, index=cand_indices, out=cand_indices)
-                        cand_indices.add_(2)
-                    else:
-                        exp_probs = probs.div_(self.sampling_temperature).exp_().view(-1, self.vocab_size)
-                        if step == 0:
-                            # we exclude the first two vocab items, one of which is pad
-                            torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
-                        else:
-                            torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
-                        cand_indices.add_(2)
-                        torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
-                    cand_scores.log_()
-                    cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
-                    cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
-                    if step == 0:
-                        cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
-                    else:
-                        cand_beams = torch.arange(0, beam_size).repeat(bsz, 2).type_as(cand_indices)
-                        # make scores cumulative
-                        cand_scores.add_(
-                            torch.gather(
-                                scores[:, step - 1].view(bsz, beam_size), dim=1,
-                                index=cand_beams,
-                            )
-                        )
+                    cand_beams = torch.zeros_like(cand_indices)
                 else:
-                    # take the best 2 x beam_size predictions. We'll choose the first
-                    # beam_size of these which don't predict eos to continue with.
-                    torch.topk(
-                        probs.view(bsz, -1),
-                        k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
-                        out=(cand_scores, cand_indices),
+                    cand_scores, cand_indices, cand_beams = self.search.step(
+                        step,
+                        lprobs.view(bsz, -1, self.vocab_size),
+                        scores.view(bsz, beam_size, -1)[:, :, :step],
                     )
-                    torch.div(cand_indices, self.vocab_size, out=cand_beams)
-                    cand_indices.fmod_(self.vocab_size)
             else:
+                # make probs contain cumulative scores for each hypothesis
+                lprobs.add_(scores[:, step - 1].unsqueeze(-1))
+
                 # finalize all active hypotheses once we hit maxlen
                 # pick the hypothesis with the highest prob of EOS right now
                 torch.sort(
-                    probs[:, self.eos],
+                    lprobs[:, self.eos],
                     descending=True,
                     out=(eos_scores, eos_bbsz_idx),
                 )
@@ -406,7 +360,7 @@ class SequenceGenerator(object):
                 new_bsz = bsz - len(finalized_sents)
 
                 # construct batch_idxs which holds indices of batches to keep for the next pass
-                batch_mask = torch.ones(bsz).type_as(cand_indices)
+                batch_mask = cand_indices.new_ones(bsz)
                 batch_mask[cand_indices.new(finalized_sents)] = 0
                 batch_idxs = batch_mask.nonzero().squeeze(-1)
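
The inline sampling branch removed above is what now lives in search.Sampling. Its recipe — keep only the k most likely tokens, rescale their log-probabilities by a temperature, exponentiate, and draw from torch.multinomial, then map the drawn positions back to vocabulary ids — can be sketched independently of fairseq (the function and variable names below are invented for illustration):

import torch

def sample_topk(lprobs, k, temperature=1.0, num_samples=1):
    # lprobs: (batch, vocab) log-probabilities; keep the k most likely tokens per row
    topk_lprobs, topk_indices = lprobs.topk(k, dim=-1)
    # temperature > 1 flattens the distribution, < 1 sharpens it
    weights = (topk_lprobs / temperature).exp()
    # multinomial normalizes the non-negative weights internally
    choices = torch.multinomial(weights, num_samples, replacement=True)
    # map positions within the top-k back to vocabulary ids
    return topk_indices.gather(-1, choices)

lprobs = torch.log_softmax(torch.randn(2, 10), dim=-1)
print(sample_topk(lprobs, k=3, temperature=2.0, num_samples=2))
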

generate.py

@@ -75,6 +75,7 @@ def main(args):
         stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
         len_penalty=args.lenpen, unk_penalty=args.unkpen,
         sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
+        sampling_temperature=args.sampling_temperature,
     )
 
     if use_cuda:
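
The sampling temperature now passed through here rescales the log-probabilities before sampling (the div_(self.sampling_temperature) calls above): values above 1 flatten the distribution and make output more diverse, values below 1 concentrate probability on the top tokens. A tiny standalone illustration with toy logits (not fairseq code):

import torch

logits = torch.tensor([2.0, 1.0, 0.5])
for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, [round(p, 3) for p in probs.tolist()])
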

tests/test_binaries.py

@@ -58,6 +58,26 @@ class TestTranslation(unittest.TestCase):
                 train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--update-freq', '3'])
                 generate_main(data_dir)
 
+    def test_generation(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_sampling') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(data_dir, 'fconv_iwslt_de_en')
+                generate_main(data_dir, [
+                    '--sampling',
+                    '--sampling-temperature', '2',
+                    '--beam', '2',
+                    '--nbest', '2',
+                ])
+                generate_main(data_dir, [
+                    '--sampling',
+                    '--sampling-topk', '3',
+                    '--beam', '2',
+                    '--nbest', '2',
+                ])
+                generate_main(data_dir, ['--prefix-size', '2'])
+
     def test_lstm(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_lstm') as data_dir: