diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py
index 4fcfaff92..495a506ff 100644
--- a/fairseq/modules/multihead_attention.py
+++ b/fairseq/modules/multihead_attention.py
@@ -241,6 +241,57 @@ class MultiheadAttention(nn.Module):
     def _set_skip_embed_dim_check(self):
         self.skip_embed_dim_check = True
 
+    def _pad_masks(
+        self,
+        key_padding_mask: Optional[Tensor],
+        attn_mask: Optional[Tensor],
+    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
+        if attn_mask is not None:
+            shape = attn_mask.size()[:-1] + torch.Size([1])
+            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
+        if key_padding_mask is not None:
+            shape = key_padding_mask.size()[:-1] + torch.Size([1])
+            key_padding_mask = torch.cat(
+                [
+                    key_padding_mask,
+                    key_padding_mask.new_zeros(shape),
+                ],
+                dim=-1,
+            )
+        return key_padding_mask, attn_mask
+
+    def _add_bias(
+        self,
+        k: Tensor,
+        v: Tensor,
+        key_padding_mask: Optional[Tensor],
+        attn_mask: Optional[Tensor],
+        bsz: int,
+    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+        assert self.bias_k is not None
+        assert self.bias_v is not None
+        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+        return k, v, key_padding_mask, attn_mask
+
+    def _append_zero_attn(
+        self,
+        k: Tensor,
+        v: Tensor,
+        key_padding_mask: Optional[Tensor],
+        attn_mask: Optional[Tensor],
+    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+        zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
+        k = torch.cat(
+            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2
+        )
+        v = torch.cat(
+            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2
+        )
+        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+        return k, v, key_padding_mask, attn_mask
+
     def forward(
         self,
         query,
@@ -371,20 +422,9 @@ class MultiheadAttention(nn.Module):
 
         if self.bias_k is not None:
             assert self.bias_v is not None
-            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
-            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
-            if attn_mask is not None:
-                attn_mask = torch.cat(
-                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
-                )
-            if key_padding_mask is not None:
-                key_padding_mask = torch.cat(
-                    [
-                        key_padding_mask,
-                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
-                    ],
-                    dim=1,
-                )
+            k, v, key_padding_mask, attn_mask = self._add_bias(
+                k, v, key_padding_mask, attn_mask, bsz
+            )
 
         q = (
             q.contiguous()
@@ -466,22 +506,9 @@ class MultiheadAttention(nn.Module):
         if self.add_zero_attn:
             assert v is not None
             src_len += 1
-            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
-            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
-            if attn_mask is not None:
-                attn_mask = torch.cat(
-                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
-                )
-            if key_padding_mask is not None:
-                key_padding_mask = torch.cat(
-                    [
-                        key_padding_mask,
-                        torch.zeros(key_padding_mask.size(0), 1).type_as(
-                            key_padding_mask
-                        ),
-                    ],
-                    dim=1,
-                )
+            k, v, key_padding_mask, attn_mask = self._append_zero_attn(
+                k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask
+            )
 
         if self.encoder_decoder_attention and bsz != kv_bsz:
             attn_weights = torch.einsum(
diff --git a/tests/test_multihead_attention.py b/tests/test_multihead_attention.py
index 59864f968..4e8f32c16 100644
--- a/tests/test_multihead_attention.py
+++ b/tests/test_multihead_attention.py
@@ -9,6 +9,90 @@ import torch
 from fairseq.modules.multihead_attention import MultiheadAttention
 
 
+def test_mask_padding_parity():
+
+    def old_padding_code(key_padding_mask, attn_mask):
+        if attn_mask is not None:
+            attn_mask = torch.cat(
+                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+            )
+        if key_padding_mask is not None:
+            key_padding_mask = torch.cat(
+                [
+                    key_padding_mask,
+                    torch.zeros(key_padding_mask.size(0), 1).type_as(
+                        key_padding_mask
+                    ),
+                ],
+                dim=1,
+            )
+        return key_padding_mask, attn_mask
+
+    # values don't matter for this test.
+    mha = MultiheadAttention(
+        embed_dim=8,
+        num_heads=2,
+        dropout=0.0,
+        add_bias_kv=True,
+        add_zero_attn=True,
+    )
+
+    key_padding_mask = torch.rand((8, 64))
+    attn_mask = torch.rand((64, 64))
+
+    kp_mask_orig, a_mask_orig = old_padding_code(key_padding_mask, attn_mask)
+    kp_mask_new, a_mask_new = mha._pad_masks(key_padding_mask, attn_mask)
+
+    assert kp_mask_orig.size() == kp_mask_new.size()
+    assert a_mask_orig.size() == a_mask_new.size()
+    assert torch.equal(kp_mask_orig, kp_mask_new)
+    assert torch.equal(a_mask_orig, a_mask_new)
+
+
+def test_add_bias_parity():
+    # values don't matter for this test.
+    mha = MultiheadAttention(
+        embed_dim=8,
+        num_heads=2,
+        dropout=0.0,
+        add_bias_kv=True,
+        add_zero_attn=True,
+    )
+
+    def old_bias_code(k, v, key_padding_mask, attn_mask, bsz):
+        k = torch.cat([k, mha.bias_k.repeat(1, bsz, 1)])
+        v = torch.cat([v, mha.bias_v.repeat(1, bsz, 1)])
+        if attn_mask is not None:
+            attn_mask = torch.cat(
+                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+            )
+        if key_padding_mask is not None:
+            key_padding_mask = torch.cat(
+                [
+                    key_padding_mask,
+                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
+                ],
+                dim=1,
+            )
+        return k, v, key_padding_mask, attn_mask
+
+    seq_len = 64
+    bsz = 8
+    embedding = 8
+    key_padding_mask = torch.rand((bsz, seq_len))
+    attn_mask = torch.rand((seq_len, seq_len))
+    k = torch.rand((seq_len, bsz, embedding))
+    v = torch.rand((seq_len, bsz, embedding))
+
+    k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(k, v, key_padding_mask, attn_mask, bsz)
+    k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(k, v, key_padding_mask, attn_mask, bsz)
+
+    assert torch.equal(k_orig, k_new)
+    assert torch.equal(v_orig, v_new)
+    assert torch.equal(kp_mask_orig, kp_mask_new)
+    assert torch.equal(a_mask_orig, a_mask_new)
+
+
 class TestMultiheadAttention(unittest.TestCase):
     def test_append_prev_key_padding_mask(self):
         bsz = 1
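
Reviewer note, not part of the diff above: a minimal standalone sketch of the padding behavior that both the old inline code and the new _pad_masks helper implement. The tensor sizes below are made up for illustration; the point is that attn_mask (tgt_len, src_len) and key_padding_mask (bsz, src_len) each gain one zero-filled key position on the last dimension, matching the extra bias/zero-attention slot appended to k and v.

    import torch

    attn_mask = torch.rand(4, 6)          # (tgt_len, src_len), illustrative sizes
    key_padding_mask = torch.rand(2, 6)   # (bsz, src_len), illustrative sizes

    # Same last-dimension padding that _pad_masks applies: append one zero column.
    shape = attn_mask.size()[:-1] + torch.Size([1])
    attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)

    shape = key_padding_mask.size()[:-1] + torch.Size([1])
    key_padding_mask = torch.cat(
        [key_padding_mask, key_padding_mask.new_zeros(shape)], dim=-1
    )

    assert attn_mask.shape == (4, 7)
    assert key_padding_mask.shape == (2, 7)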