don't allow language tokens in output for 1:many decoding

Summary:
Because of the way language ID tokens are introduced when training 1:N MT models, the model sometimes produces language tokens in its output (see T119348697 for details). This change prevents that by simply zeroing out their probabilities during beam search.
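As an illustration, here is a minimal standalone sketch of what the probability suppression amounts to (suppress_tokens is a hypothetical helper, not part of the diff below; the real logic lives inside the beam-search step):

import math
import torch

# lprobs: (batch * beam, vocab) log-probabilities for one beam-search step.
# token_indices_to_suppress: vocabulary indices of the language ID tokens
# that should never be emitted.
def suppress_tokens(lprobs: torch.Tensor, token_indices_to_suppress: torch.Tensor) -> torch.Tensor:
    # Setting the log-probability to -inf guarantees these tokens are never
    # chosen during top-k selection / beam expansion.
    lprobs[:, token_indices_to_suppress] = -math.inf
    return lprobs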

Edit: it turns out that the majority of cases of the target language token appearing in the output are due to UNK replacement, where the language ID token in the source prefix was the "most-attended-to" source token at the step where the UNK was produced. This change therefore also zeros out the attention weights for prefix tokens in the source sequence.
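The attention part of the fix is not shown in the hunks below; a rough sketch of the idea, assuming a (batch, tgt_len, src_len) attention tensor and that the language ID occupies the first prefix_len source positions (both names are illustrative):

import torch

# attn: (batch, tgt_len, src_len) attention weights used for UNK replacement.
# prefix_len: number of leading source tokens (e.g. the language ID token)
# that must never be selected as the most-attended-to source word.
def zero_prefix_attention(attn: torch.Tensor, prefix_len: int) -> torch.Tensor:
    attn[:, :, :prefix_len] = 0.0
    return attn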

Reviewed By: theweiho

Differential Revision: D36282175

fbshipit-source-id: fb7b2bfd3a8c1c66563ea509e68ab742a831ba4a
James Cross 2022-06-15 16:48:07 -07:00 committed by Facebook GitHub Bot
parent a0ceabc287
commit d9c661bf4f


@@ -38,6 +38,7 @@ class SequenceGenerator(nn.Module):
symbols_to_strip_from_output=None,
lm_model=None,
lm_weight=1.0,
tokens_to_suppress=(),
):
"""Generates translations of a given source sentence.
@@ -77,6 +78,18 @@ class SequenceGenerator(nn.Module):
if symbols_to_strip_from_output is not None
else {self.eos}
)
self.token_indices_to_suppress: Optional[Tensor] = None
token_indices_to_suppress = []
for token_string in tokens_to_suppress:
token_index = tgt_dict.index(token_string)
assert token_index != self.unk
token_indices_to_suppress.append(token_index)
if len(token_indices_to_suppress) > 0:
self.token_indices_to_suppress = torch.Tensor(
token_indices_to_suppress
).long()
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
# the max beam size is the dictionary size - 1, since we never select pad
@@ -372,9 +385,13 @@ class SequenceGenerator(nn.Module):
lprobs, tokens, scores = self._prefix_tokens(
step, lprobs, scores, tokens, prefix_tokens, beam_size
)
elif step < self.min_len:
# minimum length constraint (does not apply if using prefix_tokens)
lprobs[:, self.eos] = -math.inf
else:
if step < self.min_len:
# minimum length constraint (does not apply if using prefix_tokens)
lprobs[:, self.eos] = -math.inf
if self.token_indices_to_suppress is not None:
lprobs[:, self.token_indices_to_suppress] = -math.inf
# Record attention scores, only support avg_attn_scores is a Tensor
if avg_attn_scores is not None:
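
A hypothetical usage sketch of the new tokens_to_suppress argument; the checkpoint path and the language ID strings ("__de__", "__fr__") are placeholders and depend on how the 1:N model was trained:

from fairseq import checkpoint_utils
from fairseq.sequence_generator import SequenceGenerator

# Load an ensemble of models together with their task (placeholder path).
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(["/path/to/model.pt"])

# Any token string passed here must exist in the target dictionary
# (the constructor asserts that its index is not <unk>).
generator = SequenceGenerator(
    models,
    task.target_dictionary,
    beam_size=5,
    tokens_to_suppress=("__de__", "__fr__"),
)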