This commit is contained in:
Sergey Edunov 2018-03-07 10:01:56 -08:00 committed by Myle Ott
parent 2ee3e8c137
commit 7d19e36dc4
4 changed files with 36 additions and 3 deletions

View File

@ -228,6 +228,8 @@ def add_generation_args(parser):
help='just score the reference translation')
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
help=('initialize generation by target prefix of given length'))
group.add_argument('--sampling', action='store_true',
help='sample hypotheses instead of using beam search')
return group

View File

@ -15,7 +15,7 @@ from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
stop_early=True, normalize_scores=True, len_penalty=1,
unk_penalty=0, retain_dropout=False):
unk_penalty=0, retain_dropout=False, sampling=False):
"""Generates translations of a given source sentence.
Args:
@ -44,6 +44,7 @@ class SequenceGenerator(object):
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
self.retain_dropout = retain_dropout
self.sampling = sampling
def cuda(self):
for model in self.models:
@ -255,9 +256,10 @@ class SequenceGenerator(object):
probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
scores = scores.type_as(probs)
scores_buf = scores_buf.type_as(probs)
else:
elif not self.sampling:
# make probs contain cumulative scores for each hypothesis
probs.add_(scores[:, step-1].view(-1, 1))
probs[:, self.pad] = -math.inf # never select pad
probs[:, self.unk] -= self.unk_penalty # apply unk penalty
@ -278,6 +280,31 @@ class SequenceGenerator(object):
).expand(-1, cand_size)
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
cand_beams.resize_as_(cand_indices).fill_(0)
elif self.sampling:
assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
exp_probs = probs.exp_().view(-1, self.vocab_size)
if step == 0:
# we exclude the first two vocab items, one of which is pad
torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
cand_indices.add_(2)
else:
torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
cand_indices.add_(2)
torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
cand_scores.log_()
cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
if step == 0:
cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
else:
cand_beams = torch.arange(0, beam_size).repeat(bsz, 2).type_as(cand_indices)
# make scores cumulative
cand_scores.add_(
torch.gather(
scores[:, step-1].view(bsz, beam_size), dim=1,
index=cand_beams,
)
)
else:
# take the best 2 x beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.

View File

@ -16,6 +16,8 @@ from fairseq.sequence_scorer import SequenceScorer
def main(args):
print(args)
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
use_cuda = torch.cuda.is_available() and not args.cpu
@ -77,7 +79,7 @@ def main(args):
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen)
unk_penalty=args.unkpen, sampling=args.sampling)
if use_cuda:
translator.cuda()

View File

@ -16,6 +16,8 @@ from fairseq.sequence_generator import SequenceGenerator
def main(args):
print(args)
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
use_cuda = torch.cuda.is_available() and not args.cpu