From 16ebfa752cd653dc28ac0a1b9f6f60a041881ad0 Mon Sep 17 00:00:00 2001
From: Jingfei Du
Date: Tue, 14 Dec 2021 13:20:44 -0800
Subject: [PATCH] Revert prefix beamsearch fix (#2763)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Reverts the prefix beam search fix to resolve the issue reported [here](https://github.com/pytorch/fairseq/issues/3913). A follow-up PR will address the original issue properly.

## PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2763

Reviewed By: myleott

Differential Revision: D33000411

Pulled By: jingfeidu

fbshipit-source-id: 95a54cbdc612129a0eab4b5e6aa576a5bcf00588
---
 fairseq/sequence_generator.py    | 23 ++++++-------
 tests/test_sequence_generator.py | 56 --------------------------------
 2 files changed, 12 insertions(+), 67 deletions(-)

diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py
index bfa791a01..db730c624 100644
--- a/fairseq/sequence_generator.py
+++ b/fairseq/sequence_generator.py
@@ -350,6 +350,17 @@ class SequenceGenerator(nn.Module):
                 )
                 probs = probs[:, -1, :] * self.lm_weight
                 lprobs += probs
+
+            lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
+
+            lprobs[:, self.pad] = -math.inf  # never select pad
+            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
+
+            # handle max length constraint
+            if step >= max_len:
+                lprobs[:, : self.eos] = -math.inf
+                lprobs[:, self.eos + 1 :] = -math.inf
+
             # handle prefix tokens (possibly with different lengths)
             if (
                 prefix_tokens is not None
@@ -363,16 +374,6 @@ class SequenceGenerator(nn.Module):
                 # minimum length constraint (does not apply if using prefix_tokens)
                 lprobs[:, self.eos] = -math.inf
 
-            lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
-
-            lprobs[:, self.pad] = -math.inf  # never select pad
-            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
-
-            # handle max length constraint
-            if step >= max_len:
-                lprobs[:, : self.eos] = -math.inf
-                lprobs[:, self.eos + 1 :] = -math.inf
-
             # Record attention scores, only support avg_attn_scores is a Tensor
             if avg_attn_scores is not None:
                 if attn is None:
@@ -574,7 +575,7 @@ class SequenceGenerator(nn.Module):
         prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
         prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
         prefix_mask = prefix_toks.ne(self.pad)
-        lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
+        lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
         lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
             -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
         )
diff --git a/tests/test_sequence_generator.py b/tests/test_sequence_generator.py
index a14d73989..823c917d2 100644
--- a/tests/test_sequence_generator.py
+++ b/tests/test_sequence_generator.py
@@ -562,62 +562,6 @@ class TestDiverseSiblingsSearch(TestDiverseBeamSearch):
         self.assertHypoScore(hypos[1][1], [0.7, 0.35, 0.9], [0, 2, 1], 0.5)
 
 
-class TestPrefixBeamSearch(TestSequenceGeneratorBase):
-    def setUp(self):
-        # construct dummy dictionary
-        vocab_size = 10
-        d = test_utils.dummy_dictionary(vocab_size=vocab_size)
-        self.assertEqual(d.pad(), 1)
-        self.assertEqual(d.eos(), 2)
-        self.assertEqual(d.unk(), 3)
-        self.eos = d.eos()
-        self.w1 = 4
-        self.w2 = 5
-        self.beam_size = 3
-
-        # construct prefix data
-        self.tokens = torch.LongTensor(
-            [
-                [self.w1, self.w2, self.eos],
-            ]
-        )
-        self.token_lengths = torch.LongTensor([2])
-
-        args = argparse.Namespace()
-        unk = 0.0
-        args.beam_probs = [
-            # prefix step 0:
-            torch.FloatTensor(
-                [
-                    # eos
-                    [0.0, unk]
-                    + [1.0 / vocab_size] * vocab_size  # beam 1
-                ]
-                * self.beam_size
-            ),
-        ] * vocab_size
-
-        task = test_utils.TestTranslationTask.setup_task(args, d, d)
-        self.model = task.build_model(args)
-        self.tgt_dict = task.target_dictionary
-
-    def test_prefix_beam_search(self):
-        search_strategy = search.BeamSearch(self.tgt_dict)
-        generator = SequenceGenerator(
-            [self.model],
-            self.tgt_dict,
-            beam_size=self.beam_size,
-            search_strategy=search_strategy,
-        )
-        sample = {
-            "net_input": {
-                "src_tokens": self.tokens,
-                "src_lengths": self.token_lengths,
-            }
-        }
-        # make sure test sample doesn't break any assertion
-        generator.forward(sample, prefix_tokens=self.tokens[:, :-1])
-
 
 class TestTopPSamplingSearch(TestSequenceGeneratorBase):
     def setUp(self):
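For reference, the prefix-token handling restored by this revert (the `-math.inf` masking in `_prefix_tokens`, replacing the reverted `torch.min(prefix_lprobs) - 1` variant) works by first masking every candidate on a prefix-constrained beam to `-inf`, then scattering the model's original score for the prefix token back in, so beam search can only continue with the prefix. A minimal standalone sketch follows; the helper name `force_prefix` and the toy shapes are illustrative, not fairseq API.

```python
import math

import torch


def force_prefix(lprobs: torch.Tensor, prefix_toks: torch.Tensor, pad_idx: int) -> torch.Tensor:
    """Force each beam to continue with its prefix token.

    lprobs: (bsz * beam_size, vocab) log-probabilities for the current step
    prefix_toks: (bsz * beam_size,) prefix token per beam (pad where unconstrained)
    """
    # remember the model's own score for each beam's prefix token
    prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
    # beams whose prefix token is pad are unconstrained at this step
    prefix_mask = prefix_toks.ne(pad_idx)
    # restored behavior: mask every candidate on constrained beams to -inf ...
    lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
    # ... then scatter the original prefix-token score back in, so only the
    # prefix token has finite probability and must be selected by the search
    lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
        -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
    )
    return lprobs


# toy usage: 2 beams over a vocab of 6; beam 0 must emit token 4,
# beam 1 carries pad (1) and is left unconstrained
lprobs = torch.log_softmax(torch.randn(2, 6), dim=-1)
out = force_prefix(lprobs, torch.tensor([4, 1]), pad_idx=1)
```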