Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-09-22 06:39:29 +03:00)
Add attention matrix to output of SequenceGenerator
commit 84754894b9
parent 376c265f35
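
This change drops the separate hard-alignment bookkeeping (the `align`/`align_buf` buffers) and instead carries the full averaged attention matrix through the beam search, deriving the alignment from it only when a hypothesis is finalized. Each finalized hypothesis dict gains an 'attention' entry. As a sketch (not part of the diff), with field names taken from the diff below and shapes inferred from the buffer allocation; indexing `finalized[i][0]` as the top hypothesis is an assumption:

    # for the top hypothesis of sentence i after generation (hypothetical usage):
    hypo = finalized[i][0]
    hypo['tokens']     # (tgt_len,)          output tokens, ending in EOS
    hypo['score']      # float               hypothesis score
    hypo['attention']  # (src_len, tgt_len)  averaged attention weights (new in this commit)
    hypo['alignment']  # (tgt_len,)          argmax of attention over source positions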
@@ -108,8 +108,8 @@ class SequenceGenerator(object):
         tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
         tokens_buf = tokens.clone()
         tokens[:, 0] = self.eos
-        align = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(-1)
-        align_buf = align.clone()
+        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
+        attn_buf = attn.clone()

         # list of completed sentences
         finalized = [[] for i in range(bsz)]
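
For orientation, a minimal standalone sketch (not from the diff) of the new buffer's geometry, using made-up sizes for bsz, beam_size, src_len, and maxlen; `new_zeros` stands in for the uninitialized `scores.new(...)` allocation:

    import torch

    bsz, beam_size, src_len, maxlen = 2, 3, 7, 10  # illustrative sizes
    scores = torch.zeros(bsz * beam_size, maxlen + 1)

    # one row of attention weights per source position, one column per target step,
    # mirroring `attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)`
    attn = scores.new_zeros(bsz * beam_size, src_len, maxlen + 2)
    attn_buf = attn.clone()
    assert attn.shape == (bsz * beam_size, src_len, maxlen + 2)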
@@ -177,10 +177,12 @@ class SequenceGenerator(object):
             def get_hypo():
                 hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
                 hypo[step] = self.eos
-                alignment = align[idx, 1:step+2].clone()
+                attention = attn[idx, :, 1:step+2].clone()
+                _, alignment = attention.max(dim=0)
                 return {
                     'tokens': hypo,
                     'score': score,
+                    'attention': attention,
                     'alignment': alignment,
                 }

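get_hypo() now derives the hard alignment from the stored attention instead of reading a separate buffer: a max over the source dimension gives, for each target step, the most-attended source position. A self-contained illustration with random weights and made-up sizes:

    import torch

    src_len, tgt_len = 7, 5                       # illustrative sizes
    attention = torch.rand(src_len, tgt_len)      # what get_hypo() slices out of attn
    attention = attention / attention.sum(dim=0)  # column-normalize like attention weights

    # the same reduction as `_, alignment = attention.max(dim=0)` in the diff
    _, alignment = attention.max(dim=0)
    assert alignment.shape == (tgt_len,)          # one source index per target step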
@@ -224,9 +226,8 @@ class SequenceGenerator(object):
             probs.add_(scores.view(-1, 1))
             probs[:, self.pad] = -math.inf  # never select pad

-            # record alignment to source tokens, based on attention
-            _ignore_scores = buffer('_ignore_scores', type_of=scores)
-            avg_attn_scores.topk(1, out=(_ignore_scores, align[:, step+1].unsqueeze(1)))
+            # Record attention scores
+            attn[:, :, step+1].copy_(avg_attn_scores)

             # take the best 2 x beam_size predictions. We'll choose the first
             # beam_size of these which don't predict eos to continue with.
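Instead of reducing the averaged attention to a single argmax index per step (the removed topk call), the whole distribution is now copied into column step+1 of the attention buffer. A sketch of that write, with made-up sizes:

    import torch

    bsz_x_beam, src_len, maxlen, step = 6, 7, 10, 3  # illustrative sizes
    attn = torch.zeros(bsz_x_beam, src_len, maxlen + 2)
    avg_attn_scores = torch.rand(bsz_x_beam, src_len)  # decoder's averaged attention at this step

    # the same write as `attn[:, :, step+1].copy_(avg_attn_scores)` in the diff
    attn[:, :, step + 1].copy_(avg_attn_scores)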
@@ -290,17 +291,17 @@ class SequenceGenerator(object):
             cand_indices.gather(1, active_hypos,
                 out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1])

-            # copy attention/alignment for active hypotheses
-            torch.index_select(align[:, :step+2], dim=0, index=active_bbsz_idx,
-                out=align_buf[:, :step+2])
+            # copy attention for active hypotheses
+            torch.index_select(attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
+                out=attn_buf[:, :, :step+2])

             # swap buffers
             old_tokens = tokens
             tokens = tokens_buf
             tokens_buf = old_tokens
-            old_align = align
-            align = align_buf
-            align_buf = old_align
+            old_attn = attn
+            attn = attn_buf
+            attn_buf = old_attn

             # reorder incremental state in decoder
             reorder_state = active_bbsz_idx
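
The attention history follows the same gather-then-swap pattern already used for tokens: surviving beams are gathered into the spare buffer with index_select, then the two tensors trade roles so the search never allocates per step. A compact sketch of the pattern, with made-up sizes and reordering; the tuple swap is equivalent to the three-assignment swap in the diff:

    import torch

    bsz_x_beam, src_len, steps = 6, 7, 5           # illustrative sizes
    attn = torch.rand(bsz_x_beam, src_len, steps)
    attn_buf = attn.clone()
    active_bbsz_idx = torch.tensor([0, 0, 2, 3, 3, 5])  # illustrative beam reordering

    # gather surviving hypotheses' attention into the spare buffer, as in the diff
    torch.index_select(attn, dim=0, index=active_bbsz_idx, out=attn_buf)

    # swap roles: the gathered copy becomes current; old storage is reused next step
    attn, attn_buf = attn_buf, attn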