Add a diverse beam search variant to sequence_generator.py (#953)

Summary:
This PR implements a new generation strategy that we experimented with in project Pinocchio (https://github.com/fairinternal/Pinocchio), see the paper submission in: https://fburl.com/hduj2me7.

Specifically in this PR:
- added a Diverse Beam Search variant as described in https://arxiv.org/abs/1611.08562
- moved the Search object generation out of `sequence_generator.py`, which allows for limiting the number of kwargs passed around
- made sure the above changes are backward compatible based on grep - P124083926
- added test cases covering these scenarios
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/953

Test Plan:
- `python -m unittest tests.test_binaries -v` - including added test cases, see issues below for some details
- `python -m unittest tests.test_sequence_generator -v` - including added test cases
- tested locally in conjunction with the Pinocchio repo
- grepped for all instantiations of `SequenceGenerator`, made sure they're backward compatible

# Issues
- when I try to run all tests with the `python -m unittest tests.test_binaries -v` command, the execution gets stuck on `test_binaries.TestTranslation.test_generation` - the test otherwise passes without problems when run individually. Is this a known problem?
- discovered T59235948 - assigned to fairseq oncall

Reviewed By: myleott, fabiopetroni

Differential Revision: D19142394

Pulled By: ola13

fbshipit-source-id: d24543424c14a9537e7b6485951d9f841da62b07
This commit is contained in:
Aleksandra Piktus 2020-01-06 08:21:03 -08:00 committed by Facebook Github Bot
parent fb2d29d2aa
commit fab2e86e51
7 changed files with 193 additions and 58 deletions

View File

@ -501,6 +501,8 @@ def add_generation_args(parser):
help='number of groups for Diverse Beam Search')
group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N',
help='strength of diversity penalty for Diverse Beam Search')
group.add_argument('--diversity-rate', default=-1.0, type=float, metavar='N',
help='strength of diversity penalty for Diverse Siblings Search')
group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens')
group.add_argument('--print-step', action='store_true')

View File

@ -286,3 +286,68 @@ class Sampling(Search):
)
return self.scores_buf, self.indices_buf, self.beams_buf
class DiverseSiblingsSearch(Search):
    """
    Beam search with diverse siblings.

    See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details.
    https://arxiv.org/abs/1611.08562

    1/ Calculate hypotheses for each beam
    2/ Intra-sibling ordering
    3/ Rewrite scores
    4/ Choose top K hypotheses

    If diversity_rate == 0, no sibling is penalized and this is equivalent
    to BeamSearch.
    """

    def __init__(self, tgt_dict, diversity_rate):
        """
        Args:
            tgt_dict: target dictionary, forwarded to the base ``Search``
                and to the wrapped ``BeamSearch``.
            diversity_rate: penalty multiplied by a candidate's rank among
                its siblings (1 for the best sibling, 2 for the next, ...)
                and subtracted from that candidate's score.
        """
        super().__init__(tgt_dict)
        self.diversity_rate = diversity_rate
        # Plain beam search handles the very first decoding step.
        self.beam = BeamSearch(tgt_dict)

    def step(self, step, lprobs, scores):
        """Pick the candidates to continue with at decoding step *step*.

        Args:
            step: index of the current decoding step.
            lprobs: tensor unpacked as (bsz, beam_size, vocab_size) holding
                the log-probabilities for this step. It is modified in
                place: the cumulative score of each beam is added to it.
            scores: tensor indexed as ``scores[:, :, step - 1]`` to read the
                cumulative score of each beam after the previous step.

        Returns:
            A ``(scores, indices, beams)`` tuple. For ``step > 0`` each is of
            size (bsz, k), where ``k`` is at most ``2 * beam_size``:
            the penalized scores of the chosen candidates, their token ids
            and the index of the beam each one extends.
        """
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()

        # At the first step the work is delegated to plain beam search.
        # Return before allocating anything the sibling logic would need.
        if step == 0:
            return self.beam.step(step, lprobs, scores)

        k = min(
            # Take the best 2 x beam_size predictions. We'll choose the first
            # beam_size of these which don't predict eos to continue with.
            beam_size * 2,
            lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
        )
        # Penalty for the 1st, 2nd, ..., k-th best sibling of a beam, created
        # with the same dtype and device as lprobs.
        sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate

        lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        # 1/ Calculate hypotheses for each beam
        s_list = []
        i_list = []
        for i in range(beam_size):
            beam_scores, beam_tokens = torch.topk(lprobs[:, i, :].view(bsz, -1), k)
            beam_tokens.fmod_(vocab_size)
            # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores
            beam_scores.sub_(sibling_score)
            s_list.append(beam_scores)
            i_list.append(beam_tokens)

        # 4/ Choose top K hypotheses
        # Candidates are laid out beam after beam, k entries per beam, so a
        # flat position p belongs to beam p // k.
        indices = torch.stack(i_list, dim=1).view(bsz, -1)
        final_scores, final_indices = torch.topk(
            torch.stack(s_list, dim=1).view(bsz, -1),
            k,
        )
        # Integer floor division keeps the result a LongTensor; torch.div
        # performs true division on integer tensors in current PyTorch.
        final_beams = final_indices // k
        # Translate flat candidate positions back into token ids.
        final_indices = indices.gather(1, final_indices)
        return final_scores, final_indices, final_beams

View File

@ -24,14 +24,10 @@ class SequenceGenerator(object):
len_penalty=1.,
unk_penalty=0.,
retain_dropout=False,
sampling=False,
sampling_topk=-1,
sampling_topp=-1.0,
temperature=1.,
diverse_beam_groups=-1,
diverse_beam_strength=0.5,
match_source_len=False,
no_repeat_ngram_size=0,
search_strategy=None,
):
"""Generates translations of a given source sentence.
@ -50,18 +46,9 @@ class SequenceGenerator(object):
produces more unks, >0 produces fewer (default: 0.0)
retain_dropout (bool, optional): use dropout when generating
(default: False)
sampling (bool, optional): sample outputs instead of beam search
(default: False)
sampling_topk (int, optional): only sample among the top-k choices
at each step (default: -1)
sampling_topp (float, optional): only sample among the smallest set
of words whose cumulative probability mass exceeds p
at each step (default: -1.0)
temperature (float, optional): temperature, where values
>1.0 produce more uniform samples and values <1.0 produce
sharper samples (default: 1.0)
diverse_beam_groups/strength (float, optional): parameters for
Diverse Beam Search sampling
match_source_len (bool, optional): outputs should match the source
length (default: False)
"""
@ -82,20 +69,12 @@ class SequenceGenerator(object):
self.temperature = temperature
self.match_source_len = match_source_len
self.no_repeat_ngram_size = no_repeat_ngram_size
assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'
assert temperature > 0, '--temperature must be greater than 0'
if sampling:
self.search = search.Sampling(tgt_dict, sampling_topk, sampling_topp)
elif diverse_beam_groups > 0:
self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
elif match_source_len:
self.search = search.LengthConstrainedBeamSearch(
tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
)
else:
self.search = search.BeamSearch(tgt_dict)
self.search = (
search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
)
@torch.no_grad()
def generate(self, models, sample, **kwargs):

View File

@ -6,7 +6,7 @@
import numpy as np
import torch
from fairseq import tokenizer
from fairseq import search, tokenizer
from fairseq.data import (
data_utils,
FairseqDataset,
@ -202,30 +202,69 @@ class FairseqTask(object):
if getattr(args, 'score_reference', False):
from fairseq.sequence_scorer import SequenceScorer
return SequenceScorer(self.target_dictionary)
else:
from fairseq.sequence_generator import SequenceGenerator, SequenceGeneratorWithAlignment
if getattr(args, 'print_alignment', False):
seq_gen_cls = SequenceGeneratorWithAlignment
else:
seq_gen_cls = SequenceGenerator
return seq_gen_cls(
self.target_dictionary,
beam_size=getattr(args, 'beam', 5),
max_len_a=getattr(args, 'max_len_a', 0),
max_len_b=getattr(args, 'max_len_b', 200),
min_len=getattr(args, 'min_len', 1),
normalize_scores=(not getattr(args, 'unnormalized', False)),
len_penalty=getattr(args, 'lenpen', 1),
unk_penalty=getattr(args, 'unkpen', 0),
sampling=getattr(args, 'sampling', False),
sampling_topk=getattr(args, 'sampling_topk', -1),
sampling_topp=getattr(args, 'sampling_topp', -1.0),
temperature=getattr(args, 'temperature', 1.),
diverse_beam_groups=getattr(args, 'diverse_beam_groups', -1),
diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5),
match_source_len=getattr(args, 'match_source_len', False),
no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
from fairseq.sequence_generator import SequenceGenerator, SequenceGeneratorWithAlignment
# Choose search strategy. Defaults to Beam Search.
sampling = getattr(args, 'sampling', False)
sampling_topk = getattr(args, 'sampling_topk', -1)
sampling_topp = getattr(args, 'sampling_topp', -1.0)
diverse_beam_groups = getattr(args, 'diverse_beam_groups', -1)
# No trailing comma after this call: a comma would turn the strength into
# a 1-tuple, and that tuple would be handed to DiverseBeamSearch below.
diverse_beam_strength = getattr(args, 'diverse_beam_strength', 0.5)
match_source_len = getattr(args, 'match_source_len', False)
diversity_rate = getattr(args, 'diversity_rate', -1)

# At most one of the alternative strategies may be requested at a time.
# NOTE(review): this check counts diversity_rate only when it is > 0, while
# the selection below enables DiverseSiblingsSearch for any value > -1 --
# confirm whether a rate in (-1, 0] should also count as "requested".
requested = [
    sampling,
    diverse_beam_groups > 0,
    match_source_len,
    diversity_rate > 0,
]
if sum(int(cond) for cond in requested) > 1:
    raise ValueError('Provided Search parameters are mutually exclusive.')
assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'

if sampling:
    search_strategy = search.Sampling(self.target_dictionary, sampling_topk, sampling_topp)
elif diverse_beam_groups > 0:
    search_strategy = search.DiverseBeamSearch(
        self.target_dictionary, diverse_beam_groups, diverse_beam_strength)
elif match_source_len:
    # this is useful for tagging applications where the output
    # length should match the input length, so we hardcode the
    # length constraints for simplicity
    search_strategy = search.LengthConstrainedBeamSearch(
        self.target_dictionary, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
    )
elif diversity_rate > -1:
    search_strategy = search.DiverseSiblingsSearch(self.target_dictionary, diversity_rate)
else:
    search_strategy = search.BeamSearch(self.target_dictionary)
if getattr(args, 'print_alignment', False):
seq_gen_cls = SequenceGeneratorWithAlignment
else:
seq_gen_cls = SequenceGenerator
return seq_gen_cls(
self.target_dictionary,
beam_size=getattr(args, 'beam', 5),
max_len_a=getattr(args, 'max_len_a', 0),
max_len_b=getattr(args, 'max_len_b', 200),
min_len=getattr(args, 'min_len', 1),
normalize_scores=(not getattr(args, 'unnormalized', False)),
len_penalty=getattr(args, 'lenpen', 1),
unk_penalty=getattr(args, 'unkpen', 0),
temperature=getattr(args, 'temperature', 1.),
match_source_len=getattr(args, 'match_source_len', False),
no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
search_strategy=search_strategy,
)
def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
"""

View File

@ -47,7 +47,6 @@ class TestBacktranslationDataset(unittest.TestCase):
max_len_b=200,
beam_size=2,
unk_penalty=0,
sampling=False,
)
backtranslation_dataset = BacktranslationDataset(

View File

@ -111,6 +111,15 @@ class TestTranslation(unittest.TestCase):
'--beam', '2',
'--nbest', '2',
])
generate_main(data_dir, [
'--diversity-rate', '0.5',
'--beam', '6',
])
with self.assertRaises(ValueError):
generate_main(data_dir, [
'--diverse-beam-groups', '4',
'--match-source-len',
])
generate_main(data_dir, ['--prefix-size', '2'])
def test_lstm(self):

View File

@ -8,6 +8,7 @@ import unittest
import torch
from fairseq import search
from fairseq.sequence_generator import SequenceGenerator
import tests.utils as test_utils
@ -197,8 +198,9 @@ class TestDiverseBeamSearch(TestSequenceGeneratorBase):
self.tgt_dict = task.target_dictionary
def test_diverse_beam_search(self):
search_strategy = search.DiverseBeamSearch(self.tgt_dict, num_groups=2, diversity_strength=0.)
generator = SequenceGenerator(
self.tgt_dict, beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
self.tgt_dict, beam_size=2, search_strategy=search_strategy,
)
sample = {'net_input': {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}}
hypos = generator.generate([self.model], sample)
@ -217,6 +219,48 @@ class TestDiverseBeamSearch(TestSequenceGeneratorBase):
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.9])
class TestDiverseSiblingsSearch(TestDiverseBeamSearch):
    """Runs the diverse-beam fixture through ``search.DiverseSiblingsSearch``."""

    def assertHypoScore(
        self, hypo, pos_probs, sibling_rank, diversity_rate, normalized=True, lenpen=1.0
    ):
        """Check a hypothesis' positional scores and its overall score.

        The expected positional score at each step is ``log(prob)`` minus
        ``sibling_rank * diversity_rate`` for that step. The expected overall
        score is their sum, divided by ``length ** lenpen`` when
        ``normalized`` is set.
        """
        penalty = torch.Tensor(sibling_rank) * diversity_rate
        pos_scores = torch.FloatTensor(pos_probs).log() - penalty
        self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
        self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
        score = pos_scores.sum()
        if normalized:
            score = score / (pos_scores.numel() ** lenpen)
        self.assertLess(abs(score - hypo["score"]), 1e-6)

    def test_diverse_beam_search(self):
        """Generate with diversity_rate=0.5 and compare both beams of both sentences."""
        diversity_rate = 0.5
        generator = SequenceGenerator(
            self.tgt_dict,
            beam_size=2,
            search_strategy=search.DiverseSiblingsSearch(
                self.tgt_dict, diversity_rate=diversity_rate
            ),
        )
        net_input = {
            "src_tokens": self.src_tokens,
            "src_lengths": self.src_lengths,
        }
        hypos = generator.generate([self.model], {"net_input": net_input})
        eos, w1, w2 = self.eos, self.w1, self.w2

        # (sentence, beam, expected tokens, step probabilities, sibling ranks)
        expectations = [
            (0, 0, [w1, w1, eos], [0.9, 0.6, 1.0], [0, 1, 1]),
            (0, 1, [w1, w2, eos], [0.9, 0.4, 1.0], [0, 2, 1]),
            (1, 0, [w1, w2, eos], [0.7, 0.4, 0.9], [0, 1, 1]),
            (1, 1, [w1, w1, eos], [0.7, 0.35, 0.9], [0, 2, 1]),
        ]
        for sent, beam, tokens, probs, ranks in expectations:
            hypo = hypos[sent][beam]
            self.assertHypoTokens(hypo, tokens)
            self.assertHypoScore(hypo, probs, ranks, diversity_rate)
class TestTopPSamplingSearch(TestSequenceGeneratorBase):
def setUp(self):
@ -282,10 +326,9 @@ class TestTopPSamplingSearch(TestSequenceGeneratorBase):
# Given a prob low enough to top-P sampling, we expect only the top
# 1 token to be sampled, which always results in the same output.
low_sampling_topp = self.min_top1_prob/2.0
search_strategy = search.Sampling(self.tgt_dict, sampling_topp=low_sampling_topp)
generator = SequenceGenerator(
self.tgt_dict, beam_size=2, sampling=True,
sampling_topp=low_sampling_topp
)
self.tgt_dict, beam_size=2, search_strategy=search_strategy)
sample = {
'net_input': {
'src_tokens': self.src_tokens,
@ -311,10 +354,9 @@ class TestTopPSamplingSearch(TestSequenceGeneratorBase):
# Given a prob high enough to top-P sampling, any of the top 2
# tokens could be sampled. This can cause different outputs.
high_sampling_topp = (self.min_top1_prob+self.min_top2_prob)/2.0
search_strategy = search.Sampling(self.tgt_dict, sampling_topp=high_sampling_topp)
generator = SequenceGenerator(
self.tgt_dict, beam_size=2, sampling=True,
sampling_topp=high_sampling_topp
)
self.tgt_dict, beam_size=2, search_strategy=search_strategy)
sample = {
'net_input': {
'src_tokens': self.src_tokens,