fairseq/tests/test_sparse_multihead_attention.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch

from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class TestSparseMultiheadAttention(unittest.TestCase):
    def test_sparse_multihead_attention(self):
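        # Dummy attention weights; buffered_sparse_mask takes them as the
        # reference tensor when building the mask (presumably to match
        # dtype/device), so their values do not matter here.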
        attn_weights = torch.randn(1, 8, 8)
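        # Expected mask for the bidirectional fixed sparse pattern with
        # stride=4 and expressivity=1: 0 keeps a position, -inf masks it.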
        bidirectional_sparse_mask = torch.tensor([
            [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
            [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
            [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
            [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
        ])
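        # embed_dim=16, num_heads=1; stride and expressivity define the fixed
        # sparse pattern from "Generating Long Sequences with Sparse
        # Transformers" (Child et al., 2019).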
        bidirectional_attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=True)
        bidirectional_attention_sparse_mask = bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
        # Assert the comparison result; the original call discarded it, so the
        # test could never fail.
        self.assertTrue(
            torch.all(torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask))
        )
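        # Expected mask for the unidirectional (causal) variant: the same
        # sparse pattern, but no query may attend to a future position, so the
        # upper triangle is -inf.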
        sparse_mask = torch.tensor([
            [0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'),
             float('-inf'), float('-inf')],
            [0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
            [0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
            [0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
            [0, 0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf')],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, float('-inf'), float('-inf')],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, float('-inf')],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
        ])
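        # Same configuration as above, but causal (is_bidirectional=False).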
        attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=False)
        attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)
        self.assertTrue(torch.all(torch.eq(attention_sparse_mask, sparse_mask)))


if __name__ == '__main__':
    unittest.main()