Add attention matrix to output of SequenceGenerator

Louis Martin 2017-10-03 09:50:27 -07:00 committed by Myle Ott
parent 376c265f35
commit 84754894b9


@@ -108,8 +108,8 @@ class SequenceGenerator(object):
         tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
         tokens_buf = tokens.clone()
         tokens[:, 0] = self.eos
-        align = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(-1)
-        align_buf = align.clone()
+        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
+        attn_buf = attn.clone()
         # list of completed sentences
         finalized = [[] for i in range(bsz)]
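
Note: the old `align` buffer kept a single source index per output step, while the new `attn` buffer keeps a full attention row over the source tokens for each step. A minimal sketch of the shape change, with toy sizes standing in for the real batch dimensions:

```python
import torch

# Toy sizes; in SequenceGenerator these come from the batch and the model.
bsz, beam_size, src_len, maxlen = 2, 3, 5, 7

# Before: one alignment index (a source position) per output step.
align = torch.full((bsz * beam_size, maxlen + 2), -1, dtype=torch.long)

# After: a full attention distribution over source tokens per output step.
attn = torch.zeros(bsz * beam_size, src_len, maxlen + 2)

print(align.shape)  # torch.Size([6, 9])
print(attn.shape)   # torch.Size([6, 5, 9])
```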
@@ -177,10 +177,12 @@ class SequenceGenerator(object):
             def get_hypo():
                 hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
                 hypo[step] = self.eos
-                alignment = align[idx, 1:step+2].clone()
+                attention = attn[idx, :, 1:step+2].clone()
+                _, alignment = attention.max(dim=0)
                 return {
                     'tokens': hypo,
                     'score': score,
+                    'attention': attention,
                     'alignment': alignment,
                 }
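
Since the full attention matrix is now stored, the hard alignment no longer needs its own buffer: `get_hypo()` recovers it by taking the argmax over the source dimension. A small sketch of that step, with assumed toy sizes:

```python
import torch

src_len, hypo_len = 4, 3                   # toy sizes for illustration
attention = torch.rand(src_len, hypo_len)  # stands in for attn[idx, :, 1:step+2]

# max over dim 0 (source positions) yields, for each generated token,
# the source index with the highest attention weight.
_, alignment = attention.max(dim=0)
assert alignment.shape == (hypo_len,)
```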
@@ -224,9 +226,8 @@ class SequenceGenerator(object):
             probs.add_(scores.view(-1, 1))
             probs[:, self.pad] = -math.inf  # never select pad
-            # record alignment to source tokens, based on attention
-            _ignore_scores = buffer('_ignore_scores', type_of=scores)
-            avg_attn_scores.topk(1, out=(_ignore_scores, align[:, step+1].unsqueeze(1)))
+            # Record attention scores
+            attn[:, :, step+1].copy_(avg_attn_scores)
             # take the best 2 x beam_size predictions. We'll choose the first
             # beam_size of these which don't predict eos to continue with.
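
The per-step bookkeeping also becomes simpler: instead of a `topk(1)` call that kept only the argmax source index (and needed a scratch buffer for the discarded scores), the whole averaged attention row is copied into that step's slice. A sketch under assumed toy sizes:

```python
import torch

bsz_x_beam, src_len, maxlen, step = 6, 5, 7, 0     # toy sizes
attn = torch.zeros(bsz_x_beam, src_len, maxlen + 2)
avg_attn_scores = torch.rand(bsz_x_beam, src_len)  # beam-averaged attention

# Keep the full distribution for this step; the argmax can be derived later.
attn[:, :, step + 1].copy_(avg_attn_scores)
```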
@@ -290,17 +291,17 @@ class SequenceGenerator(object):
             cand_indices.gather(1, active_hypos,
                                 out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1])
-            # copy attention/alignment for active hypotheses
-            torch.index_select(align[:, :step+2], dim=0, index=active_bbsz_idx,
-                               out=align_buf[:, :step+2])
+            # copy attention for active hypotheses
+            torch.index_select(attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
+                               out=attn_buf[:, :, :step+2])
             # swap buffers
             old_tokens = tokens
             tokens = tokens_buf
             tokens_buf = old_tokens
-            old_align = align
-            align = align_buf
-            align_buf = old_align
+            old_attn = attn
+            attn = attn_buf
+            attn_buf = old_attn
             # reorder incremental state in decoder
             reorder_state = active_bbsz_idx
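
For reference, the reorder-and-swap pattern used here: surviving hypotheses are gathered by their flat beam index into the spare buffer, and the two buffers then trade roles so nothing is reallocated per step. A minimal sketch with assumed sizes, using plain assignment in place of the `out=` form from the diff:

```python
import torch

bsz_x_beam, src_len, steps, step = 6, 5, 9, 3       # toy sizes
attn = torch.rand(bsz_x_beam, src_len, steps)
attn_buf = attn.clone()
active_bbsz_idx = torch.tensor([0, 2, 1, 3, 5, 4])  # hypothetical survivors

# Gather rows of surviving hypotheses into the spare buffer, then swap.
attn_buf[:, :, :step + 2] = attn[:, :, :step + 2].index_select(0, active_bbsz_idx)
attn, attn_buf = attn_buf, attn
```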