Fix and generalize --temperature option (#508)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/508

The previous version applied the temperature after the softmax. Fix that, and
also generalize it so that it works with other search approaches (e.g. beam
search), not just sampling.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/694

Differential Revision: D15175160

Pulled By: myleott

fbshipit-source-id: cc87ff0e97a8a1dd37f9983163f58a8641155ab0
Authored by Myle Ott on 2019-05-04 16:33:48 -07:00; committed by Facebook Github Bot
parent fc1a19a38d
commit 96ac28d33d
5 changed files with 31 additions and 19 deletions
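
For reference, a minimal sketch of the behavior this change establishes (plain PyTorch, not the fairseq code itself): the temperature divides the raw decoder logits before the (log-)softmax, so the resulting distribution stays properly normalized and can feed any search strategy, not only sampling.

import torch
import torch.nn.functional as F

def temperature_log_probs(logits, temperature=1.0):
    # Scale the raw logits *before* normalizing: T > 1 flattens the
    # distribution, T < 1 sharpens it, T == 1 leaves it unchanged.
    assert temperature > 0
    return F.log_softmax(logits / temperature, dim=-1)

logits = torch.tensor([[2.0, 1.0, 0.1]])
print(temperature_log_probs(logits, temperature=2.0).exp())  # flatter
print(temperature_log_probs(logits, temperature=0.5).exp())  # sharper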

fairseq/options.py

@@ -422,8 +422,8 @@ def add_generation_args(parser):
                        help='sample hypotheses instead of using beam search')
     group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
                        help='sample from top K likely next words instead of all words')
-    group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
-                       help='temperature for random sampling')
+    group.add_argument('--temperature', default=1., type=float, metavar='N',
+                       help='temperature for generation')
     group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
                        help='number of groups for Diverse Beam Search')
     group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N',

fairseq/search.py

@@ -168,10 +168,9 @@ class DiverseBeamSearch(Search):
 
 class Sampling(Search):
 
-    def __init__(self, tgt_dict, sampling_topk=-1, sampling_temperature=1.):
+    def __init__(self, tgt_dict, sampling_topk=-1):
         super().__init__(tgt_dict)
         self.sampling_topk = sampling_topk
-        self.sampling_temperature = sampling_temperature
 
     def step(self, step, lprobs, scores):
         super()._init_buffers(lprobs)
@@ -190,10 +189,6 @@ class Sampling(Search):
         if self.sampling_topk > 0:
             lprobs_nopad, topk_indices = lprobs_nopad.topk(self.sampling_topk)
 
-        # sampling temperature
-        if self.sampling_temperature != 1.:
-            lprobs_nopad = lprobs_nopad.div_(self.sampling_temperature)
-
         # sample
         probs_nopad = lprobs_nopad.exp_()
         if step == 0:
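
With the temperature handled upstream, the sampling search only has to restrict to the top-k candidates and draw from the renormalized distribution. A simplified sketch of that step (assumed names; the real Sampling.step also handles padding and beam bookkeeping), where lprobs are log-probabilities that already carry the temperature applied by the decoder:

import torch

def sample_step(lprobs, sampling_topk=-1):
    # lprobs: (batch, vocab) temperature-scaled log-probabilities.
    if sampling_topk > 0:
        lprobs, topk_indices = lprobs.topk(sampling_topk)
    probs = lprobs.exp()
    # torch.multinomial renormalizes the non-negative weights per row.
    choices = torch.multinomial(probs, num_samples=1)
    if sampling_topk > 0:
        # Map positions inside the top-k subset back to vocabulary indices.
        choices = topk_indices.gather(-1, choices)
    return choices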

fairseq/sequence_generator.py

@@ -28,7 +28,7 @@ class SequenceGenerator(object):
         retain_dropout=False,
         sampling=False,
         sampling_topk=-1,
-        sampling_temperature=1.,
+        temperature=1.,
         diverse_beam_groups=-1,
         diverse_beam_strength=0.5,
         match_source_len=False,
@@ -58,9 +58,9 @@ class SequenceGenerator(object):
                 (default: False)
             sampling_topk (int, optional): only sample among the top-k choices
                 at each step (default: -1)
-            sampling_temperature (float, optional): temperature for sampling,
-                where values >1.0 produces more uniform sampling and values
-                <1.0 produces sharper sampling (default: 1.0)
+            temperature (float, optional): temperature, where values
+                >1.0 produce more uniform samples and values <1.0 produce
+                sharper samples (default: 1.0)
             diverse_beam_groups/strength (float, optional): parameters for
                 Diverse Beam Search sampling
             match_source_len (bool, optional): outputs should match the source
@@ -81,13 +81,15 @@ class SequenceGenerator(object):
         self.len_penalty = len_penalty
         self.unk_penalty = unk_penalty
         self.retain_dropout = retain_dropout
+        self.temperature = temperature
         self.match_source_len = match_source_len
         self.no_repeat_ngram_size = no_repeat_ngram_size
 
         assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
+        assert temperature > 0, '--temperature must be greater than 0'
 
         if sampling:
-            self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
+            self.search = search.Sampling(tgt_dict, sampling_topk)
         elif diverse_beam_groups > 0:
             self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
         elif match_source_len:
@@ -304,7 +306,9 @@ class SequenceGenerator(object):
                 model.reorder_incremental_state(reorder_state)
                 model.reorder_encoder_out(encoder_outs, reorder_state)
 
-            lprobs, avg_attn_scores = model.forward_decoder(tokens[:, :step + 1], encoder_outs)
+            lprobs, avg_attn_scores = model.forward_decoder(
+                tokens[:, :step + 1], encoder_outs, temperature=self.temperature,
+            )
 
             lprobs[:, self.pad] = -math.inf  # never select pad
             lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
@@ -547,7 +551,7 @@ class EnsembleModel(torch.nn.Module):
         return [model.encoder(**encoder_input) for model in self.models]
 
     @torch.no_grad()
-    def forward_decoder(self, tokens, encoder_outs):
+    def forward_decoder(self, tokens, encoder_outs, temperature=1.):
         if len(self.models) == 1:
             return self._decode_one(
                 tokens,
@@ -555,12 +559,20 @@ class EnsembleModel(torch.nn.Module):
                 encoder_outs[0] if self.has_encoder() else None,
                 self.incremental_states,
                 log_probs=True,
+                temperature=temperature,
             )
 
         log_probs = []
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            probs, attn = self._decode_one(tokens, model, encoder_out, self.incremental_states, log_probs=True)
+            probs, attn = self._decode_one(
+                tokens,
+                model,
+                encoder_out,
+                self.incremental_states,
+                log_probs=True,
+                temperature=temperature,
+            )
             log_probs.append(probs)
             if attn is not None:
                 if avg_attn is None:
@@ -572,12 +584,17 @@ class EnsembleModel(torch.nn.Module):
             avg_attn.div_(len(self.models))
         return avg_probs, avg_attn
 
-    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
+    def _decode_one(
+        self, tokens, model, encoder_out, incremental_states, log_probs,
+        temperature=1.,
+    ):
         if self.incremental_states is not None:
             decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=self.incremental_states[model]))
         else:
             decoder_out = list(model.decoder(tokens, encoder_out))
         decoder_out[0] = decoder_out[0][:, -1:, :]
+        if temperature != 1.:
+            decoder_out[0].div_(temperature)
         attn = decoder_out[1]
         if type(attn) is dict:
             attn = attn['attn']
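
In the ensemble case, every model's last-step logits are divided by the same temperature before normalization, and the per-model log-probabilities are then averaged in probability space. A condensed sketch of that combination (hypothetical helper name; the real logic is spread across forward_decoder and _decode_one):

import math
import torch
import torch.nn.functional as F

def ensemble_log_probs(per_model_logits, temperature=1.0):
    # per_model_logits: list of (batch, vocab) raw last-step logits, one per model.
    log_probs = [F.log_softmax(logits / temperature, dim=-1) for logits in per_model_logits]
    # log(mean_i p_i), computed stably via logsumexp.
    return torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(len(per_model_logits))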

fairseq/tasks/fairseq_task.py

@@ -197,7 +197,7 @@ class FairseqTask(object):
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
-           sampling_temperature=args.sampling_temperature,
+           temperature=args.temperature,
            diverse_beam_groups=args.diverse_beam_groups,
            diverse_beam_strength=args.diverse_beam_strength,
            match_source_len=args.match_source_len,
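
For callers that build the generator themselves, only the keyword changes. A hypothetical wiring sketch mirroring what the task does with the parsed flags (the parser/task setup shown here is illustrative and not part of this diff):

from fairseq import options, tasks
from fairseq.sequence_generator import SequenceGenerator

# e.g. invoked as: python my_generate.py DATA_DIR --path model.pt --temperature 0.7
parser = options.get_generation_parser()
args = options.parse_args_and_arch(parser)
task = tasks.setup_task(args)

generator = SequenceGenerator(
    task.target_dictionary,
    beam_size=args.beam,
    sampling=args.sampling,
    sampling_topk=args.sampling_topk,
    temperature=args.temperature,  # previously: sampling_temperature=args.sampling_temperature
)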

tests/test_binaries.py

@@ -94,7 +94,7 @@ class TestTranslation(unittest.TestCase):
                 train_translation_model(data_dir, 'fconv_iwslt_de_en')
                 generate_main(data_dir, [
                     '--sampling',
-                    '--sampling-temperature', '2',
+                    '--temperature', '2',
                     '--beam', '2',
                     '--nbest', '2',
                 ])
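
Since the temperature is now applied inside the decoder forward pass rather than in the sampling search, it also affects plain beam search. A hypothetical companion check in the same style (this flag combination is illustrative and not part of this commit):

                # Sharpen beam search scores with a low temperature; no --sampling needed.
                generate_main(data_dir, [
                    '--temperature', '0.5',
                    '--beam', '2',
                    '--nbest', '2',
                ])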