Fix bug in LM sampling when using bos token

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/945

Differential Revision: D19084265

Pulled By: myleott

fbshipit-source-id: d7788b4311ce9d1ef94e479f7dc3017dd888426c
Author: Myle Ott, 2019-12-16 19:16:15 -08:00 (committed by Facebook Github Bot)
Commit: 78d86dcc6c (parent: be3515b289)
3 changed files with 17 additions and 17 deletions

fairseq/search.py

@@ -225,31 +225,27 @@ class Sampling(Search):
             # only the first beam
             lprobs = lprobs[:, ::beam_size, :].contiguous()
 
-        # we exclude the first two vocab items, one of which is pad
-        assert self.pad <= 1, 'sampling assumes the first two symbols can be ignored'
-        lprobs_nopad = lprobs[:, :, 2:]
-
         if self.sampling_topp > 0:
             # only sample from the smallest set of words whose cumulative probability mass exceeds p
-            probs_nopad, top_indices = self._sample_topp(lprobs_nopad)
+            probs, top_indices = self._sample_topp(lprobs)
         elif self.sampling_topk > 0:
             # only sample from top-k candidates
-            lprobs_nopad, top_indices = lprobs_nopad.topk(self.sampling_topk)
-            probs_nopad = lprobs_nopad.exp_()
+            lprobs, top_indices = lprobs.topk(self.sampling_topk)
+            probs = lprobs.exp_()
         else:
-            probs_nopad = lprobs_nopad.exp_()
+            probs = lprobs.exp_()
 
         # sample
         if step == 0:
             self.indices_buf = torch.multinomial(
-                probs_nopad.view(bsz, -1),
+                probs.view(bsz, -1),
                 beam_size,
                 replacement=True,
                 out=self.indices_buf,
             ).view(bsz, beam_size)
         else:
             self.indices_buf = torch.multinomial(
-                probs_nopad.view(bsz * beam_size, -1),
+                probs.view(bsz * beam_size, -1),
                 1,
                 replacement=True,
                 out=self.indices_buf,
@@ -257,11 +253,11 @@ class Sampling(Search):
 
         if step == 0:
             # expand to beam size
-            probs_nopad = probs_nopad.expand(bsz, beam_size, -1)
+            probs = probs.expand(bsz, beam_size, -1)
 
         # gather scores
         torch.gather(
-            probs_nopad,
+            probs,
             dim=2,
             index=self.indices_buf.unsqueeze(-1),
             out=self.scores_buf,
@@ -276,9 +272,6 @@ class Sampling(Search):
                 index=self.indices_buf.unsqueeze(-1),
             ).squeeze(2)
 
-        # remap indices since we excluded the first two vocab items
-        self.indices_buf.add_(2)
-
         if step == 0:
             self.beams_buf = self.indices_buf.new_zeros(bsz, beam_size)
         else:
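For reference, a minimal standalone sketch (not fairseq code; the toy batch and vocabulary sizes are made up) contrasting the removed "nopad" sampling path with the full-vocabulary sampling this change switches to:

# Illustrative sketch only: compares the old restricted sampling
# (drop the first two vocab columns, sample, then shift indices back by 2)
# with the new path that samples over the full vocabulary.
import torch

torch.manual_seed(0)
bsz, vocab_size = 2, 10
lprobs = torch.log_softmax(torch.randn(bsz, vocab_size), dim=-1)

# removed path: exclude columns 0 and 1, then remap sampled indices with add_(2)
probs_nopad = lprobs[:, 2:].exp()
old_indices = torch.multinomial(probs_nopad, 1, replacement=True).add_(2)

# new path: sample directly from the full distribution, no remapping needed
probs = lprobs.exp()
new_indices = torch.multinomial(probs, 1, replacement=True)

print(old_indices.squeeze(1).tolist(), new_indices.squeeze(1).tolist())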

fairseq/sequence_generator.py

@@ -147,6 +147,7 @@ class SequenceGenerator(object):
             # exclude the EOS marker
             model.max_decoder_positions() - 1,
         )
+        assert self.min_len <= max_len, 'min_len cannot be larger than max_len, please adjust these!'
 
         # compute the encoder output for each beam
         encoder_outs = model.forward_encoder(encoder_input)
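A toy illustration (values invented) of the misconfiguration the new assert is meant to catch early, before generation starts:

# Illustrative only: the assert fails fast when the requested minimum
# generation length exceeds the computed length budget.
min_len = 10        # e.g. a user-supplied minimum length
max_len = 512 - 1   # e.g. the model's position limit minus the EOS marker
assert min_len <= max_len, 'min_len cannot be larger than max_len, please adjust these!'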

fairseq/tasks/language_modeling.py

@@ -197,7 +197,7 @@ class LanguageModelingTask(FairseqTask):
                 self.target_dictionary,
                 add_eos_for_other_targets=False,
                 shuffle=False,
-                add_bos_token=self.args.add_bos_token,
+                add_bos_token=False,  # we handle this in inference_step
             ),
             eos=self.source_dictionary.eos(),
             # remove EOS since this will be used as a prefix for generation
@@ -211,7 +211,13 @@ class LanguageModelingTask(FairseqTask):
         prefix_tokens = sample["net_input"]["src_tokens"]
         if prefix_tokens[:, 0].eq(self.source_dictionary.eos()).all():
             prefix_tokens = prefix_tokens[:, 1:]
-        return generator.generate(models, sample, prefix_tokens=prefix_tokens)
+        if getattr(self.args, 'add_bos_token', False):
+            bos_token = self.source_dictionary.bos()
+        else:
+            bos_token = None
+        return generator.generate(
+            models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token,
+        )
 
     @property
     def source_dictionary(self):
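Taken together, the two hunks above move BOS handling out of the dataset (which is now built with add_bos_token=False) and into inference_step, which resolves the BOS token once and passes it to the generator. A rough sketch of that new flow, with `args`, `dictionary`, and `generator` as stand-ins for the real fairseq objects:

# Sketch only, not a drop-in for fairseq: mirrors the new inference_step logic.
def pick_bos_token(args, dictionary):
    # the dataset no longer prepends BOS; the decision happens here instead
    if getattr(args, 'add_bos_token', False):
        return dictionary.bos()
    return None

# hypothetical usage:
# bos = pick_bos_token(args, task.source_dictionary)
# hypos = generator.generate(models, sample, prefix_tokens=prefix_tokens, bos_token=bos)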