Revert prefix beam search fix (#2763)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Reverts the prefix beam search fix in order to resolve the issue reported [here](https://github.com/pytorch/fairseq/issues/3913). The original issue will be fixed in a separate follow-up PR.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2763

Reviewed By: myleott

Differential Revision: D33000411

Pulled By: jingfeidu

fbshipit-source-id: 95a54cbdc612129a0eab4b5e6aa576a5bcf00588
This commit is contained in:
Jingfei Du 2021-12-14 13:20:44 -08:00 committed by Facebook GitHub Bot
parent 771f85025e
commit 16ebfa752c
2 changed files with 12 additions and 67 deletions

View File

@ -350,6 +350,17 @@ class SequenceGenerator(nn.Module):
)
probs = probs[:, -1, :] * self.lm_weight
lprobs += probs
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
# handle max length constraint
if step >= max_len:
lprobs[:, : self.eos] = -math.inf
lprobs[:, self.eos + 1 :] = -math.inf
# handle prefix tokens (possibly with different lengths)
if (
prefix_tokens is not None
@ -363,16 +374,6 @@ class SequenceGenerator(nn.Module):
# minimum length constraint (does not apply if using prefix_tokens)
lprobs[:, self.eos] = -math.inf
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
# handle max length constraint
if step >= max_len:
lprobs[:, : self.eos] = -math.inf
lprobs[:, self.eos + 1 :] = -math.inf
# Record attention scores, only support avg_attn_scores is a Tensor
if avg_attn_scores is not None:
if attn is None:
@ -574,7 +575,7 @@ class SequenceGenerator(nn.Module):
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
prefix_mask = prefix_toks.ne(self.pad)
lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
)

View File

@ -562,62 +562,6 @@ class TestDiverseSiblingsSearch(TestDiverseBeamSearch):
self.assertHypoScore(hypos[1][1], [0.7, 0.35, 0.9], [0, 2, 1], 0.5)
class TestPrefixBeamSearch(TestSequenceGeneratorBase):
    """Smoke test that beam search driven by ``prefix_tokens`` runs to completion."""

    def setUp(self):
        """Create the dummy dictionary, a single prefix sample, and the test model."""
        vocab_size = 10
        dictionary = test_utils.dummy_dictionary(vocab_size=vocab_size)
        # Sanity-check the special-symbol ids the rest of the fixture relies on.
        self.assertEqual(dictionary.pad(), 1)
        self.assertEqual(dictionary.eos(), 2)
        self.assertEqual(dictionary.unk(), 3)
        self.eos = dictionary.eos()
        self.w1 = 4
        self.w2 = 5
        self.beam_size = 3

        # Prefix data: one sentence made of two ordinary tokens followed by eos.
        self.tokens = torch.LongTensor([[self.w1, self.w2, self.eos]])
        self.token_lengths = torch.LongTensor([2])

        # One probability row per beam: an eos entry, an unk entry, then
        # ``vocab_size`` uniform entries.
        unk_prob = 0.0
        beam_row = [0.0, unk_prob] + [1.0 / vocab_size] * vocab_size
        step_probs = torch.FloatTensor([beam_row] * self.beam_size)

        args = argparse.Namespace()
        # The same per-step tensor is reused for ``vocab_size`` decoding steps.
        args.beam_probs = [step_probs] * vocab_size

        task = test_utils.TestTranslationTask.setup_task(args, dictionary, dictionary)
        self.model = task.build_model(args)
        self.tgt_dict = task.target_dictionary

    def test_prefix_beam_search(self):
        """Run generation with a forced prefix; passing means no assertion fired."""
        generator = SequenceGenerator(
            [self.model],
            self.tgt_dict,
            beam_size=self.beam_size,
            search_strategy=search.BeamSearch(self.tgt_dict),
        )
        net_input = {
            "src_tokens": self.tokens,
            "src_lengths": self.token_lengths,
        }
        # Everything except the trailing eos is used as the forced prefix.
        forced_prefix = self.tokens[:, :-1]
        # make sure test sample doesn't break any assertion
        generator.forward({"net_input": net_input}, prefix_tokens=forced_prefix)
class TestTopPSamplingSearch(TestSequenceGeneratorBase):
def setUp(self):