fairseq/tests/test_multihead_attention.py
alexeib 995c204337 Data2vec prelim (#2929)
Summary:
Preliminaries for the data2vec release, including some minor improvements and bug fixes.

The most important change is that we now default to raising an exception when fields in the config do not have a corresponding field in the model dataclass.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2929

Reviewed By: wnhsu

Differential Revision: D33649708

Pulled By: alexeib

fbshipit-source-id: 629bdb4c361550740b451c570c2005bb956c6fcb
2022-01-20 00:02:16 -08:00

89 lines
2.9 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from fairseq.modules.multihead_attention import MultiheadAttention
class TestMultiheadAttention(unittest.TestCase):
    """Unit tests for ``MultiheadAttention`` helper methods."""

    def test_append_prev_key_padding_mask(self):
        """Check how the current and previous key padding masks are combined.

        Each case is ``(key_padding_mask, prev_key_padding_mask, expected)``.
        The first two entries are passed, in that order, to
        ``MultiheadAttention._append_prev_key_padding_mask`` together with
        ``static_kv=False``. A non-``None`` result must have shape
        ``(bsz, src_len)`` and match ``expected`` element-wise; a ``None``
        result is only accepted when ``expected`` is ``None`` as well.
        """
        bsz = 1
        src_len = 4

        cases = [
            # no padding mask
            (None, None, None),
            # current padding mask only
            (
                torch.tensor([[1]]).bool(),
                None,
                torch.tensor([[0, 0, 0, 1]]).bool(),
            ),
            # previous padding mask only
            (
                None,
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 0]]).bool(),
            ),
            # both padding masks
            (
                torch.tensor([[1]]).bool(),
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            # current key_padding_mask already spans src_len
            (
                torch.tensor([[0, 1, 0, 1]]).bool(),
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            # prev_key_padding_mask already spans src_len
            (
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
        ]

        for current, previous, expected in cases:
            # subTest labels a failure with the inputs of the offending case.
            with self.subTest(current=current, previous=previous):
                key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                    current,
                    previous,
                    batch_size=bsz,
                    src_len=src_len,
                    static_kv=False,
                )

                if key_padding_mask is None:
                    self.assertIsNone(expected)
                    continue

                # Guard first: torch.eq(tensor, None) would raise a TypeError
                # instead of producing a readable assertion failure.
                self.assertIsNotNone(
                    expected,
                    f"Expected no key padding mask but got: {key_padding_mask}"
                    f" given current: {current} and previous: {previous}",
                )
                # Validate the shape before comparing values so a wrong-sized
                # result is reported as such.
                self.assertEqual(key_padding_mask.size(0), bsz)
                self.assertEqual(key_padding_mask.size(1), src_len)
                self.assertTrue(
                    torch.all(torch.eq(key_padding_mask, expected)),
                    f"Unexpected resultant key padding mask: {key_padding_mask}"
                    f" given current: {current} and previous: {previous}",
                )

    def test_pruning_heads(self):
        """Prune a 12-head attention module down to 8 heads and run it.

        After pruning, the module must still accept an input whose last
        dimension is the original ``embed_dim``, keep its per-head dimension
        at ``embed_dim // num_heads`` and report ``num_heads_to_keep`` heads.
        """
        embed_dim = 768
        num_heads = 12
        num_heads_to_keep = 8
        # presumably laid out as (seq_len, batch, embed_dim) -- TODO confirm
        dummy_input = torch.randn(32, 2, embed_dim)

        mha = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
        reserve_head_index = mha._get_reserve_head_index(
            num_heads_to_keep=num_heads_to_keep
        )
        mha._adaptive_prune_heads(reserve_head_index=reserve_head_index)
        # NOTE(review): appears to disable the embed-dim consistency check so
        # the pruned module can run a forward pass -- confirm in the module.
        mha._set_skip_embed_dim_check()

        # Only checks that the forward pass runs; the output is not inspected.
        mha(query=dummy_input, key=dummy_input, value=dummy_input)

        # Integer division keeps the comparison exact (no int-vs-float compare).
        self.assertEqual(mha.head_dim, embed_dim // num_heads)
        self.assertEqual(mha.num_heads, num_heads_to_keep)
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()