Mirror of https://github.com/facebookresearch/fairseq.git
Summary:

**This PR**
- Adds a conformer layer based on https://arxiv.org/pdf/2005.08100.pdf.
- The conformer implementation supports multihead attention with 3 different positional embedding types: absolute positional embedding, relative positional encoding, and rotary positional embedding (a minimal usage sketch of the three variants follows this summary).
- Adds a conformer encoder with conv1d subsampling and positional embedding, followed by N conformer layers.
- Adds the S2T_Conformer model, based on the conformer encoder and a transformer decoder.
- Adds conformer support in Wav2Vec2.
- Adds unit tests for the core modules.

**Verification**
- Verified the setup on MuST-C En-De S2T, CoVoST 2 Es-En S2T, and Librispeech ASR to ensure the implementation is correct.
- For the S2T setups, performance is either similar to or better than that of the transformer-based models.
- Wav2vec2 pretraining and finetuning on Librispeech showed improvements over the corresponding transformer baselines.
- [WIP] Experiment log: https://docs.google.com/document/d/1QI-ROWVenUEXPJoHTaKD85Fq7T8ZXNc8bc54MzgwJjA/edit#

**Next steps**
- Add regression tests
- Add README and open source checkpoints

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2859

Reviewed By: kahne

Differential Revision: D33434092

Pulled By: sravyapopuri388

fbshipit-source-id: 62f22b917a332481370750e04a439e05832a2282
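As a minimal usage sketch (not part of the diff), here is how the three attention variants are constructed and called. The constructors, forward signatures, and the time-first (T, B, C) input layout are taken directly from the unit tests in this file; the tensor contents are illustrative.

```python
import torch

from fairseq.modules import (
    ESPNETMultiHeadedAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)

T, B, C = 3, 1, 2  # time, batch, channels; inputs are time-first (T, B, C)
x = torch.randn(T, B, C)

# Absolute positional embedding: the module sees only q, k, v (position
# information is assumed to be added to the input further upstream).
abs_attn = ESPNETMultiHeadedAttention(C, 1, dropout=0)  # feature dim C, 1 head
out, _ = abs_attn(x, x, x)

# Relative positional encoding: forward additionally takes embeddings for the
# 2T - 1 possible relative offsets between query and key positions.
rel_attn = RelPositionMultiHeadedAttention(C, 1, dropout=0)
pos_emb = torch.randn(B, 2 * T - 1, C)
out, _ = rel_attn(x, x, x, pos_emb)

# Rotary positional embedding: positions are applied inside the module by
# rotating the query/key features, so the call matches the absolute variant.
rot_attn = RotaryPositionMultiHeadedAttention(C, 1, dropout=0, precision=None)
out, _ = rot_attn(x, x, x)
```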
177 lines
5.4 KiB
Python
import unittest

import numpy as np
import torch

from fairseq.modules import (
    ESPNETMultiHeadedAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)

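# Use deterministic kernels so that, together with the fixed manual seeds in
# each setUp below, the hard-coded expected tensors are reproducible.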
torch.use_deterministic_algorithms(True)


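# Absolute-position variant: exercises the fused forward pass as well as the
# forward_qkv projection and forward_attention steps in isolation.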
class TestESPNETMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.sample_scores = torch.randn(self.B, 1, self.T, self.T)
        self.MHA = ESPNETMultiHeadedAttention(self.C, 1, dropout=0)

    def test_forward(self):
        expected_scores = torch.tensor(
            [[[0.1713, -0.3776]], [[0.2263, -0.4486]], [[0.2243, -0.4538]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_forward_qkv(self):
        expected_query = torch.tensor(
            [[[[-1.0235, 0.0409], [0.4008, 1.3077], [0.5396, 2.0698]]]]
        )
        expected_key = torch.tensor(
            [[[[0.5053, -0.4965], [-0.3730, -0.9473], [-0.7019, -0.1935]]]]
        )
        expected_val = torch.tensor(
            [[[[-0.9940, 0.5403], [0.5924, -0.7619], [0.7504, -1.0892]]]]
        )
        sample_t = self.sample.transpose(0, 1)
        query, key, val = self.MHA.forward_qkv(sample_t, sample_t, sample_t)
        self.assertTrue(
            np.allclose(
                expected_query.cpu().detach().numpy(),
                query.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_key.cpu().detach().numpy(),
                key.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_val.cpu().detach().numpy(),
                val.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_forward_attention(self):
        expected_scores = torch.tensor(
            [[[0.1627, -0.6249], [-0.2547, -0.6487], [-0.0711, -0.8545]]]
        )
        scores = self.MHA.forward_attention(
            self.sample.transpose(0, 1).view(self.B, 1, self.T, self.C),
            self.sample_scores,
            mask=None,
        )
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


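# Relative-position variant (Transformer-XL style): rel_shift rearranges scores
# computed over the 2T - 1 relative offsets into a (T, T) matrix, and forward
# takes the relative position embeddings as an extra argument.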
class TestRelPositionMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.sample_x = torch.randn(self.B, 1, self.T, self.T * 2 - 1)
        self.sample_pos = torch.randn(self.B, self.T * 2 - 1, self.C)
        self.MHA = RelPositionMultiHeadedAttention(self.C, 1, dropout=0)

    def test_rel_shift(self):
        expected_x = torch.tensor(
            [
                [
                    [
                        [-0.7193, -0.4033, -0.5966],
                        [-0.8567, 1.1006, -1.0712],
                        [-0.5663, 0.3731, -0.8920],
                    ]
                ]
            ]
        )
        x = self.MHA.rel_shift(self.sample_x)
        self.assertTrue(
            np.allclose(
                expected_x.cpu().detach().numpy(),
                x.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_forward(self):
        expected_scores = torch.tensor(
            [
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
            ]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample, self.sample_pos)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


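# Rotary variant: positions are encoded by rotating the query/key features
# inside the module, so the forward signature matches the absolute variant.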
class TestRotaryPositionMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.MHA = RotaryPositionMultiHeadedAttention(
            self.C, 1, dropout=0, precision=None
        )

    def test_forward(self):
        expected_scores = torch.tensor(
            [[[-0.3220, -0.4726]], [[-1.2813, -0.0979]], [[-0.3138, -0.4758]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


if __name__ == "__main__":
    unittest.main()