Fix LM generation and add unit test

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/896

Differential Revision: D18250948

Pulled By: myleott

fbshipit-source-id: 7a515311e18795670b29f5e24eeba7619a625da7
Authored by Myle Ott on 2019-11-13 14:35:39 -08:00; committed by Facebook Github Bot
parent 096d7d301e
commit e26ee47a8c
4 changed files with 19 additions and 9 deletions

fairseq/data/token_block_dataset.py

@@ -65,6 +65,8 @@ class TokenBlockDataset(FairseqDataset):
         if isinstance(sizes, list):
             sizes = np.array(sizes, dtype=np.int64)
         else:
+            if torch.is_tensor(sizes):
+                sizes = sizes.numpy()
             sizes = sizes.astype(np.int64)
 
         break_mode = break_mode if break_mode is not None else 'none'
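
The hunk above lets TokenBlockDataset accept sizes given as a torch tensor, which has no .astype() method. A minimal standalone sketch of the same normalization (outside of fairseq; the helper name is illustrative):

import numpy as np
import torch

def normalize_sizes(sizes):
    # Standalone sketch of the hunk above (not the fairseq API): accept a
    # Python list, a NumPy array, or a torch.Tensor and always return an
    # int64 NumPy array. torch tensors lack .astype(), hence the extra branch.
    if isinstance(sizes, list):
        sizes = np.array(sizes, dtype=np.int64)
    else:
        if torch.is_tensor(sizes):
            sizes = sizes.numpy()
        sizes = sizes.astype(np.int64)
    return sizes

print(normalize_sizes(torch.tensor([3, 5, 2])).dtype)  # int64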

fairseq/sequence_generator.py

@@ -195,7 +195,7 @@ class SequenceGenerator(object):
             possible score among unfinalized hypotheses.
             """
             assert len(finalized[sent]) <= beam_size
-            if len(finalized[sent]) == beam_size:
+            if len(finalized[sent]) == beam_size or step == max_len:
                 return True
             return False
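
With the eos-score fix further below, a sentence can reach the final step with fewer than beam_size finalized hypotheses, so the termination check also has to fire once step == max_len. A toy sketch of the updated predicate (illustrative, not the fairseq API):

def is_finished(num_finalized, beam_size, step, max_len):
    # Standalone sketch of the termination test in the hunk above: a sentence
    # is done once it has beam_size finalized hypotheses, or unconditionally
    # at the final decoding step, since nothing can be extended past max_len.
    assert num_finalized <= beam_size
    return num_finalized == beam_size or step == max_len

print(is_finished(num_finalized=3, beam_size=5, step=200, max_len=200))  # True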
@@ -298,21 +298,19 @@ class SequenceGenerator(object):
             lprobs[:, self.pad] = -math.inf  # never select pad
             lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
 
-            # handle min and max length constraints
+            # handle max length constraint
             if step >= max_len:
                 lprobs[:, :self.eos] = -math.inf
                 lprobs[:, self.eos + 1:] = -math.inf
-            elif step < self.min_len:
-                lprobs[:, self.eos] = -math.inf
 
             # handle prefix tokens (possibly with different lengths)
-            if prefix_tokens is not None and step < prefix_tokens.size(1):
+            if prefix_tokens is not None and step < prefix_tokens.size(1) and step < max_len:
                 prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
                 prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
                 prefix_mask = prefix_toks.ne(self.pad)
                 lprobs[prefix_mask] = -math.inf
                 lprobs[prefix_mask] = lprobs[prefix_mask].scatter_(
-                    -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs
+                    -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
                 )
                 # if prefix includes eos, then we should make sure tokens and
                 # scores are the same across all beams
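
Two things change in this hunk. Guarding the prefix block with `and step < max_len` keeps prefix forcing from undoing the forced-EOS mask at the last step, and `prefix_lprobs[prefix_mask]` makes the scatter_ source line up row-for-row with the masked rows; with the old code, batches whose prefixes have different lengths (so some rows are padded) could scatter scores taken from the wrong rows. A small self-contained sketch of the fixed prefix forcing, with toy sizes:

import math
import torch

# Toy sketch (not the fairseq API) of how the prefix block above forces each
# beam to follow its prefix token: every entry of that row is masked to -inf
# except the prefix token, which keeps its original log-prob. Rows whose
# prefix is pad (mask == False) are left untouched.
pad, vocab = 1, 8
lprobs = torch.randn(4, vocab).log_softmax(-1)   # (bsz * beam, vocab)
prefix_toks = torch.tensor([5, 2, pad, 7])        # one prefix token per row

prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
prefix_mask = prefix_toks.ne(pad)
lprobs[prefix_mask] = -math.inf
lprobs[prefix_mask] = lprobs[prefix_mask].scatter_(
    -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
)
print(lprobs.argmax(-1))  # rows 0, 1, 3 are forced to tokens 5, 2, 7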
@@ -333,6 +331,9 @@ class SequenceGenerator(object):
                     tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
                     scores = replicate_first_beam(scores, eos_mask_batch_dim)
                     lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim)
+            elif step < self.min_len:
+                # minimum length constraint (does not apply if using prefix_tokens)
+                lprobs[:, self.eos] = -math.inf
 
             if self.no_repeat_ngram_size > 0:
                 # for each beam and batch sentence, generate a list of previous ngrams
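
The minimum-length constraint moves here onto the elif, so it is skipped while prefix tokens are being forced (a prefix may itself contain EOS, which the old placement would have masked to -inf). A toy illustration of the constraint on its own:

import math
import torch

# Toy sketch of the relocated minimum-length constraint: until min_len steps
# have been generated, the EOS column is masked so no beam can terminate early.
# On the elif branch it no longer runs while prefix tokens are being forced.
eos, min_len, vocab = 2, 5, 8
step = 3
lprobs = torch.randn(4, vocab).log_softmax(-1)
if step < min_len:
    lprobs[:, eos] = -math.inf
assert not lprobs.argmax(-1).eq(eos).any()  # EOS cannot win before min_len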
@@ -383,8 +384,9 @@ class SequenceGenerator(object):
             # and dimensions: [bsz, cand_size]
             cand_bbsz_idx = cand_beams.add(bbsz_offsets)
 
-            # finalize hypotheses that end in eos (except for blacklisted ones)
-            eos_mask = cand_indices.eq(self.eos)
+            # finalize hypotheses that end in eos, except for blacklisted ones
+            # or candidates with a score of -inf
+            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
             eos_mask[:, :beam_size][blacklist] = 0
 
             # only consider eos when it's among the top beam_size indices
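
At the final step every beam is forced onto EOS, so some EOS candidates can come out of the top-k with a score of -inf (for example when most of a row has been masked); finalizing those would produce garbage hypotheses, and the tightened mask drops them. A toy example of the difference:

import math
import torch

# Toy illustration of the tightened eos_mask: an EOS candidate whose score is
# -inf must not be finalized as a hypothesis, even though its token index
# equals eos.
eos = 2
cand_indices = torch.tensor([[2, 4], [2, 2]])
cand_scores = torch.tensor([[-1.5, -2.0], [-math.inf, -0.7]])

old_mask = cand_indices.eq(eos)
new_mask = cand_indices.eq(eos) & cand_scores.ne(-math.inf)
print(old_mask)  # [[True, False], [True, True]]
print(new_mask)  # [[True, False], [False, True]] -- the -inf candidate is dropped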

fairseq/tasks/language_modeling.py

@@ -223,8 +223,9 @@ class LanguageModelingTask(FairseqTask):
     def inference_step(self, generator, models, sample, prefix_tokens=None):
         with torch.no_grad():
             if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
-                # note: EOS has already been removed in build_dataset_for_inference
                 prefix_tokens = sample["net_input"]["src_tokens"]
+                if prefix_tokens[:, 0].eq(self.source_dictionary.eos()).all():
+                    prefix_tokens = prefix_tokens[:, 1:]
             return generator.generate(models, sample, prefix_tokens=prefix_tokens)
 
     @property
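
The source tokens handed to inference_step may start with the EOS index (used as the beginning-of-sentence marker for LM data), and the generator already conditions on that symbol itself, so a leading EOS in the prefix would be predicted twice; the fix checks for it and strips it. A standalone sketch of the clean-up:

import torch

# Standalone sketch of the prefix clean-up added above: if every prefix in the
# batch starts with the eos index, drop that first column, since the generator
# already conditions on it itself. Values here are toy data.
eos = 2
prefix_tokens = torch.tensor([
    [2, 11, 12, 13],
    [2, 21, 22,  1],   # 1 = pad
])
if prefix_tokens[:, 0].eq(eos).all():
    prefix_tokens = prefix_tokens[:, 1:]
print(prefix_tokens)  # the leading eos column is gone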

tests/test_binaries.py

@@ -384,6 +384,11 @@ class TestLanguageModeling(unittest.TestCase):
                     data_dir, 'transformer_lm', ['--add-bos-token'], run_validation=True,
                 )
                 eval_lm_main(data_dir)
+                generate_main(data_dir, [
+                    '--task', 'language_modeling',
+                    '--sample-break-mode', 'eos',
+                    '--tokens-per-sample', '500',
+                ])
 
 
 class TestMaskedLanguageModel(unittest.TestCase):
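
For reference, one way to run just the language modeling binary tests exercised by this commit (assumes a fairseq checkout with the test dependencies installed; the tests.test_binaries module path is inferred from the classes above):

import unittest

# Load and run only the TestLanguageModeling cases, including the new
# generation check added in this commit.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.test_binaries.TestLanguageModeling"
)
unittest.TextTestRunner(verbosity=2).run(suite)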