fairseq/tests/test_rotary_positional_embedding.py
Sravya Popuri 40ff55abbe conformer (#2859)
Summary:
**This PR**

- Adds a conformer layer based on https://arxiv.org/pdf/2005.08100.pdf.
- The conformer implementation supports multihead attention with three different positional embedding types: absolute positional embedding, relative positional encoding, and rotary positional embedding (see the sketch after this list).
- Adds a conformer encoder with Conv1d subsampling and positional embedding, followed by N conformer layers.
- Adds an S2T_Conformer model based on the conformer encoder and a transformer decoder.
- Adds conformer support in Wav2Vec2.
- Adds unit tests for core modules.
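
For reference, rotary positional embedding rotates each query/key feature pair by an angle proportional to its position, so relative offsets surface as phase differences inside the attention dot product. Below is a minimal, self-contained sketch of that mechanism in the rotate-half convention; `rotate_half` and `apply_rope` are illustrative names rather than this PR's API, though the expected values in the unit test below are consistent with this formulation.

```python
import torch

def rotate_half(x):
    # Split the last dimension into two halves and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # Rotating q and k by position-dependent angles makes the q.k dot
    # product depend only on their relative offset, not absolute positions.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```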

**Verification**

- Verified the setup on MUST-C En-De S2T, CoVoST 2 Es-En S2T, and Librispeech ASR to ensure the implementation is correct.
- For the S2T setups, performance matches or exceeds that of the transformer-based models.
- Wav2vec2 pretraining and finetuning on Librispeech showed improvements over the corresponding transformer baselines.
- [WIP] Experiment log: https://docs.google.com/document/d/1QI-ROWVenUEXPJoHTaKD85Fq7T8ZXNc8bc54MzgwJjA/edit#

**Next steps**
- Add regression tests
- Add README and open source checkpoints

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2859

Reviewed By: kahne

Differential Revision: D33434092

Pulled By: sravyapopuri388

fbshipit-source-id: 62f22b917a332481370750e04a439e05832a2282
2022-01-10 16:18:38 -08:00


import torch
import numpy as np
import unittest

from fairseq.modules.rotary_positional_embedding import apply_rotary_pos_emb
from fairseq.modules import RotaryPositionalEmbedding


class TestRotaryPositionalEmbedding(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.rope_pos_emd = RotaryPositionalEmbedding(dim=self.C)

    def test_forward(self):
        # With dim=2 there is a single rotation frequency of 1.0, so the
        # expected tables are just cos(m) and sin(m) for positions m = 0, 1, 2.
        expected_cos = torch.tensor(
            [[[[1.0000, 1.0000]]], [[[0.5403, 0.5403]]], [[[-0.4161, -0.4161]]]]
        )
        expected_sin = torch.tensor(
            [[[[0.0000, 0.0000]]], [[[0.8415, 0.8415]]], [[[0.9093, 0.9093]]]]
        )
        cos, sin = self.rope_pos_emd(self.sample, self.T)
        self.assertTrue(
            np.allclose(
                expected_cos.cpu().detach().numpy(),
                cos.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_sin.cpu().detach().numpy(),
                sin.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_apply_rotary_pos_emb(self):
        cos, sin = self.rope_pos_emd(self.sample, self.T)
        # Reshape to (T, B, num_heads, head_dim) with a single attention head.
        query = self.sample.view(self.T, self.B, 1, self.C)
        # Position 0 is left unrotated (cos=1, sin=0); later positions are
        # rotated by m radians.
        expected_query = torch.tensor(
            [[[[1.5410, -0.2934]]], [[[-1.6555, -1.5263]]], [[[1.7231, -0.4041]]]]
        )
        new_query, new_key = apply_rotary_pos_emb(query, query, cos, sin)
        self.assertTrue(
            np.allclose(
                expected_query.cpu().detach().numpy(),
                new_query.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        # The same tensor is passed as both query and key, so the rotated
        # key must match the rotated query exactly.
        self.assertTrue(
            np.allclose(
                expected_query.cpu().detach().numpy(),
                new_key.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


if __name__ == "__main__":
    unittest.main()
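
To exercise just this file from a fairseq checkout (a standard unittest invocation; the module path follows the header above, assuming the repository root is the working directory):

    python -m unittest tests.test_rotary_positional_embedding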