diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py
index 495a506ff..fd33ba446 100644
--- a/fairseq/modules/multihead_attention.py
+++ b/fairseq/modules/multihead_attention.py
@@ -272,7 +272,9 @@ class MultiheadAttention(nn.Module):
         assert self.bias_v is not None
         k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
         v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
-        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+        key_padding_mask, attn_mask = self._pad_masks(
+            key_padding_mask=key_padding_mask, attn_mask=attn_mask
+        )
         return k, v, key_padding_mask, attn_mask
 
     def _append_zero_attn(
@@ -289,7 +291,9 @@ class MultiheadAttention(nn.Module):
         v = torch.cat(
             [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2
         )
-        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+        key_padding_mask, attn_mask = self._pad_masks(
+            key_padding_mask=key_padding_mask, attn_mask=attn_mask
+        )
         return k, v, key_padding_mask, attn_mask
 
     def forward(
diff --git a/tests/test_multihead_attention.py b/tests/test_multihead_attention.py
index 4e8f32c16..ebed9c903 100644
--- a/tests/test_multihead_attention.py
+++ b/tests/test_multihead_attention.py
@@ -6,11 +6,11 @@
 import unittest
 
 import torch
+
 from fairseq.modules.multihead_attention import MultiheadAttention
 
 
 def test_mask_padding_parity():
-
     def old_padding_code(key_padding_mask, attn_mask):
         if attn_mask is not None:
             attn_mask = torch.cat(
@@ -20,9 +20,7 @@ def test_mask_padding_parity():
             key_padding_mask = torch.cat(
                 [
                     key_padding_mask,
-                    torch.zeros(key_padding_mask.size(0), 1).type_as(
-                        key_padding_mask
-                    ),
+                    torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                 ],
                 dim=1,
             )
@@ -84,8 +82,12 @@ def test_add_bias_parity():
     k = torch.rand((seq_len, bsz, embedding))
     v = torch.rand((seq_len, bsz, embedding))
 
-    k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(k, v, key_padding_mask, attn_mask, bsz)
-    k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(k, v, key_padding_mask, attn_mask, bsz)
+    k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(
+        k, v, key_padding_mask, attn_mask, bsz
+    )
+    k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(
+        k, v, key_padding_mask, attn_mask, bsz
+    )
 
     assert torch.equal(k_orig, k_new)
     assert torch.equal(v_orig, v_new)