fairseq/tests/test_espnet_multihead_attention.py
Sravya Popuri 40ff55abbe conformer (#2859)
Summary:
**This PR**

- Adds a conformer layer based on https://arxiv.org/pdf/2005.08100.pdf.
- The conformer implementation supports multihead attention with 3 different positional embedding types: absolute positional embedding, relative positional encoding, and rotary positional embedding (see the sketch below).
- Adds a conformer encoder with conv1d subsampling and positional embedding, followed by N conformer layers.
- Adds an S2T_Conformer model based on the conformer encoder and a transformer decoder.
- Adds conformer support in Wav2Vec2.
- Adds unit tests for the core modules.
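
A minimal usage sketch of the three attention variants (constructor arguments and input shapes mirror the unit tests in this file; illustrative rather than a full conformer config):

```python
import torch

from fairseq.modules import (
    ESPNETMultiHeadedAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)

T, B, C = 3, 1, 2                    # time, batch, channels
x = torch.randn(T, B, C)             # fairseq convention: T x B x C
pos = torch.randn(B, 2 * T - 1, C)   # relative offsets -(T-1) .. T-1

abs_attn = ESPNETMultiHeadedAttention(C, 1, dropout=0)
rel_attn = RelPositionMultiHeadedAttention(C, 1, dropout=0)
rot_attn = RotaryPositionMultiHeadedAttention(C, 1, dropout=0, precision=None)

out_abs, _ = abs_attn(x, x, x)       # absolute positional embedding
out_rel, _ = rel_attn(x, x, x, pos)  # relative positional encoding
out_rot, _ = rot_attn(x, x, x)       # rotary positional embedding
```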

**Verification**

- Verified the setup on MUST-C En-De S2T, CoVoST 2 Es-En S2T, and Librispeech ASR to ensure the implementation is correct.
- For the S2T setups, performance is either similar to or better than that of the transformer-based models.
- Wav2Vec2 pretraining and finetuning on Librispeech showed improvements over the corresponding transformer baselines.
- [WIP] Experiment log: https://docs.google.com/document/d/1QI-ROWVenUEXPJoHTaKD85Fq7T8ZXNc8bc54MzgwJjA/edit#

**Next steps**
- Add regression tests
- Add README and open-source checkpoints

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2859

Reviewed By: kahne

Differential Revision: D33434092

Pulled By: sravyapopuri388

fbshipit-source-id: 62f22b917a332481370750e04a439e05832a2282
2022-01-10 16:18:38 -08:00


import torch
import numpy as np
import unittest

from fairseq.modules import (
    ESPNETMultiHeadedAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)

torch.use_deterministic_algorithms(True)
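# NOTE: deterministic kernels plus the fixed torch.manual_seed(0) calls in
# each setUp() keep the hard-coded expected tensors below stable across runs.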


class TestESPNETMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.sample_scores = torch.randn(self.B, 1, self.T, self.T)
        self.MHA = ESPNETMultiHeadedAttention(self.C, 1, dropout=0)
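
    # test_forward: full self-attention pass on a (T, B, C) input, checked
    # against precomputed reference outputs.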
    def test_forward(self):
        expected_scores = torch.tensor(
            [[[0.1713, -0.3776]], [[0.2263, -0.4486]], [[0.2243, -0.4538]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
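
    # test_forward_qkv: the Q/K/V projection helper expects (B, T, C) input
    # and returns per-head tensors of shape (B, heads, T, C / heads).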
    def test_forward_qkv(self):
        expected_query = torch.tensor(
            [[[[-1.0235, 0.0409], [0.4008, 1.3077], [0.5396, 2.0698]]]]
        )
        expected_key = torch.tensor(
            [[[[0.5053, -0.4965], [-0.3730, -0.9473], [-0.7019, -0.1935]]]]
        )
        expected_val = torch.tensor(
            [[[[-0.9940, 0.5403], [0.5924, -0.7619], [0.7504, -1.0892]]]]
        )
        sample_t = self.sample.transpose(0, 1)
        query, key, val = self.MHA.forward_qkv(sample_t, sample_t, sample_t)
        self.assertTrue(
            np.allclose(
                expected_query.cpu().detach().numpy(),
                query.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_key.cpu().detach().numpy(),
                key.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_val.cpu().detach().numpy(),
                val.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
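
    # test_forward_attention: applies softmax over raw (B, heads, T, T)
    # scores, weights the values, and projects back to (B, T, C).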
    def test_forward_attention(self):
        expected_scores = torch.tensor(
            [[[0.1627, -0.6249], [-0.2547, -0.6487], [-0.0711, -0.8545]]]
        )
        scores = self.MHA.forward_attention(
            self.sample.transpose(0, 1).view(self.B, 1, self.T, self.C),
            self.sample_scores,
            mask=None,
        )
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
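

# Relative positional encoding variant (Transformer-XL style): attention
# scores combine a content term with a position term computed from a
# (B, 2T - 1, C) positional embedding covering offsets -(T - 1) .. T - 1.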
class TestRelPositionMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.sample_x = torch.randn(self.B, 1, self.T, self.T * 2 - 1)
        self.sample_pos = torch.randn(self.B, self.T * 2 - 1, self.C)
        self.MHA = RelPositionMultiHeadedAttention(self.C, 1, dropout=0)
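
    # test_rel_shift: rel_shift realigns a (B, heads, T, 2T - 1) matrix of
    # per-offset position scores so that row i holds the scores for keys
    # relative to query i, yielding a (B, heads, T, T) tensor.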
    def test_rel_shift(self):
        expected_x = torch.tensor(
            [
                [
                    [
                        [-0.7193, -0.4033, -0.5966],
                        [-0.8567, 1.1006, -1.0712],
                        [-0.5663, 0.3731, -0.8920],
                    ]
                ]
            ]
        )
        x = self.MHA.rel_shift(self.sample_x)
        self.assertTrue(
            np.allclose(
                expected_x.cpu().detach().numpy(),
                x.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
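
    # test_forward: same end-to-end check as above, but the forward pass
    # additionally consumes the (B, 2T - 1, C) positional embedding.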
    def test_forward(self):
        expected_scores = torch.tensor(
            [[[-0.9609, -0.5020]], [[-0.9308, -0.4890]], [[-0.9473, -0.4948]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample, self.sample_pos)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
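

# Rotary positional embedding (RoPE) variant: positions are injected by
# rotating query/key feature pairs by position-dependent angles, so no
# separate positional tensor is passed to forward().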
class TestRotaryPositionMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.MHA = RotaryPositionMultiHeadedAttention(
            self.C, 1, dropout=0, precision=None
        )
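
    # test_forward: rotary attention takes the same (T, B, C) inputs as the
    # absolute-position variant; the rotation happens inside the module.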
    def test_forward(self):
        expected_scores = torch.tensor(
            [[[-0.3220, -0.4726]], [[-1.2813, -0.0979]], [[-0.3138, -0.4758]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


if __name__ == "__main__":
    unittest.main()