Pull out some code into separate methods (#3068)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Pulling some changes that are unrelated to xformers out of https://github.com/fairinternal/fairseq-py/pull/2263, to make that PR cleaner.
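
Roughly, this factors the bias/zero-attention handling in `MultiheadAttention.forward` into three helpers: `_add_bias` and `_append_zero_attn` each append one extra key/value position, and both call `_pad_masks` to pad the masks accordingly. A minimal usage sketch of `_add_bias` (not taken from the diff; the constructor arguments and tensor sizes below are illustrative and assume fairseq is importable):

```python
import torch
from fairseq.modules.multihead_attention import MultiheadAttention

# add_bias_kv=True so that bias_k / bias_v exist; the sizes are arbitrary.
mha = MultiheadAttention(
    embed_dim=8, num_heads=2, dropout=0.0, add_bias_kv=True, add_zero_attn=True
)

src_len, bsz, embed_dim = 64, 2, 8
k = torch.rand(src_len, bsz, embed_dim)
v = torch.rand(src_len, bsz, embed_dim)
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
attn_mask = torch.zeros(src_len, src_len)

# Appends bias_k / bias_v as one extra key/value position and pads both masks.
k, v, key_padding_mask, attn_mask = mha._add_bias(k, v, key_padding_mask, attn_mask, bsz)
print(k.shape, key_padding_mask.shape, attn_mask.shape)
# torch.Size([65, 2, 8]) torch.Size([2, 65]) torch.Size([64, 65])
```

`_append_zero_attn` does the analogous appending later in `forward`, after k and v have been reshaped to `(bsz * num_heads, src_len, head_dim)`, which is why it concatenates along `dim=-2`.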

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

X-link: https://github.com/fairinternal/fairseq-py/pull/3068

Reviewed By: blefaudeux

Differential Revision: D34149016

Pulled By: dianaml0

fbshipit-source-id: 6442a5f451d56cc47106227298a624516b19a9ad
Diana Liskovich 2022-04-27 16:54:02 -07:00 committed by Facebook GitHub Bot
parent caac187386
commit 72d3408481
2 changed files with 141 additions and 30 deletions

fairseq/modules/multihead_attention.py

@@ -241,6 +241,57 @@ class MultiheadAttention(nn.Module):
def _set_skip_embed_dim_check(self):
self.skip_embed_dim_check = True
def _pad_masks(
self,
key_padding_mask: Optional[Tensor],
attn_mask: Optional[Tensor],
) -> Tuple[Optional[Tensor], Optional[Tensor]]:
if attn_mask is not None:
shape = attn_mask.size()[:-1] + torch.Size([1])
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
if key_padding_mask is not None:
shape = key_padding_mask.size()[:-1] + torch.Size([1])
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(shape),
],
dim=-1,
)
return key_padding_mask, attn_mask
def _add_bias(
self,
k: Tensor,
v: Tensor,
key_padding_mask: Optional[Tensor],
attn_mask: Optional[Tensor],
bsz: int,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
assert self.bias_k is not None
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
return k, v, key_padding_mask, attn_mask
def _append_zero_attn(
self,
k: Tensor,
v: Tensor,
key_padding_mask: Optional[Tensor],
attn_mask: Optional[Tensor],
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2
)
key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
return k, v, key_padding_mask, attn_mask
def forward(
self,
query,
@@ -371,20 +422,9 @@ class MultiheadAttention(nn.Module):
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
k, v, key_padding_mask, attn_mask = self._add_bias(
    k, v, key_padding_mask, attn_mask, bsz
)
q = (
q.contiguous()
@@ -466,22 +506,9 @@ class MultiheadAttention(nn.Module):
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
k, v, key_padding_mask, attn_mask = self._append_zero_attn(
k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask
)
if self.encoder_decoder_attention and bsz != kv_bsz:
attn_weights = torch.einsum(

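One subtlety in the hunk above: by the time `add_zero_attn` is handled, k and v have already been reshaped so that the heads are folded into the batch dimension, i.e. `(bsz * num_heads, src_len, head_dim)`, so `_append_zero_attn` concatenates the zero position along `dim=-2`. A standalone sketch of that shape effect (plain torch; the sizes are illustrative, not from the diff):

```python
import torch

bsz, num_heads, src_len, head_dim = 2, 2, 64, 4

# After the head reshape in forward(), k and v are (bsz * num_heads, src_len, head_dim).
k = torch.rand(bsz * num_heads, src_len, head_dim)

# Same construction as in _append_zero_attn: keep the batch and feature dims,
# insert one all-zero key/value position along the source-length dimension.
zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2)

print(k.shape)  # torch.Size([4, 65, 4])
```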
tests/test_multihead_attention.py

@@ -9,6 +9,90 @@ import torch
from fairseq.modules.multihead_attention import MultiheadAttention
def test_mask_padding_parity():
def old_padding_code(key_padding_mask, attn_mask):
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
return key_padding_mask, attn_mask
# values don't matter for this test.
mha = MultiheadAttention(
embed_dim=8,
num_heads=2,
dropout=0.0,
add_bias_kv=True,
add_zero_attn=True,
)
key_padding_mask = torch.rand((8, 64))
attn_mask = torch.rand((64, 64))
kp_mask_orig, a_mask_orig = old_padding_code(key_padding_mask, attn_mask)
kp_mask_new, a_mask_new = mha._pad_masks(key_padding_mask, attn_mask)
assert kp_mask_orig.size() == kp_mask_new.size()
assert a_mask_orig.size() == a_mask_new.size()
assert torch.equal(kp_mask_orig, kp_mask_new)
assert torch.equal(a_mask_orig, a_mask_new)
def test_add_bias_parity():
# values don't matter for this test.
mha = MultiheadAttention(
embed_dim=8,
num_heads=2,
dropout=0.0,
add_bias_kv=True,
add_zero_attn=True,
)
def old_bias_code(k, v, key_padding_mask, attn_mask, bsz):
k = torch.cat([k, mha.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, mha.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
return k, v, key_padding_mask, attn_mask
seq_len = 64
bsz = 8
embedding = 8
key_padding_mask = torch.rand((bsz, seq_len))
attn_mask = torch.rand((seq_len, seq_len))
k = torch.rand((seq_len, bsz, embedding))
v = torch.rand((seq_len, bsz, embedding))
k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(k, v, key_padding_mask, attn_mask, bsz)
k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(k, v, key_padding_mask, attn_mask, bsz)
assert torch.equal(k_orig, k_new)
assert torch.equal(v_orig, v_new)
assert torch.equal(kp_mask_orig, kp_mask_new)
assert torch.equal(a_mask_orig, a_mask_new)
class TestMultiheadAttention(unittest.TestCase):
def test_append_prev_key_padding_mask(self):
bsz = 1