fairseq/tests/test_positional_encoding.py
Sravya Popuri 40ff55abbe conformer (#2859)
Summary:
**This PR**

- Adds a conformer layer based on https://arxiv.org/pdf/2005.08100.pdf.
- The conformer implementation supports multihead attention with three different positional embedding types: absolute positional embedding, relative positional encoding, and rotary positional embedding (a usage sketch of the relative variant follows this list).
- Adds a conformer encoder: conv1d subsampling and positional embedding, followed by N conformer layers.
- Adds an S2T_Conformer model based on the conformer encoder and a transformer decoder.
- Adds conformer support in Wav2Vec2.
- Adds unit tests for the core modules.
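
As a quick illustration of the relative positional encoding interface (the call pattern and shapes mirror the unit test below; a sketch, not canonical usage):

```python
import torch
from fairseq.modules import RelPositionalEncoding

T, B, C = 3, 1, 2                    # time steps, batch size, channels
x = torch.randn(T, B, C)             # time-major input, as in the test below
rel_pos_enc = RelPositionalEncoding(max_len=4, d_model=C)
pos = rel_pos_enc(x)                 # encodings for relative offsets +(T-1) .. -(T-1)
print(pos.shape)                     # torch.Size([5, 1, 2]), i.e. (2*T - 1, B, C)
```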

**Verification**

- Verified the setup on MUST-C En-De S2T, CoVoST 2 Es-En S2T, and Librispeech ASR to ensure the implementation is correct.
- For the S2T setups, performance is similar to or better than the transformer-based models.
- Wav2vec2 pretraining and finetuning on Librispeech showed improvements over the corresponding transformer baselines.
- [WIP] Experiment log: https://docs.google.com/document/d/1QI-ROWVenUEXPJoHTaKD85Fq7T8ZXNc8bc54MzgwJjA/edit#

**Next steps**
- Add regression tests
- Add a README and open-source the checkpoints

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2859

Reviewed By: kahne

Differential Revision: D33434092

Pulled By: sravyapopuri388

fbshipit-source-id: 62f22b917a332481370750e04a439e05832a2282
2022-01-10 16:18:38 -08:00


import unittest

import numpy as np
import torch

from fairseq.modules import RelPositionalEncoding


class TestRelPositionalEncoding(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3  # time steps
        self.B = 1  # batch size
        self.C = 2  # channels (d_model)
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # time-major (T, B, C)
        self.rel_pos_enc = RelPositionalEncoding(max_len=4, d_model=self.C)
    def test_extend_pe(self):
        inp = self.sample.transpose(0, 1)
        self.rel_pos_enc.extend_pe(inp)
        # With max_len=4, the table covers 2 * max_len - 1 = 7 relative
        # offsets, from +3 down to -3.
        expected_pe = torch.tensor(
            [
                [
                    [0.1411, -0.9900],
                    [0.9093, -0.4161],
                    [0.8415, 0.5403],
                    [0.0000, 1.0000],
                    [-0.8415, 0.5403],
                    [-0.9093, -0.4161],
                    [-0.1411, -0.9900],
                ]
            ]
        )
        self.assertTrue(
            np.allclose(
                expected_pe.cpu().detach().numpy(),
                self.rel_pos_enc.pe.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
    def test_forward(self):
        pos_enc = self.rel_pos_enc(self.sample)
        # For a length-T input, the output covers 2 * T - 1 = 5 relative
        # offsets, from +2 down to -2, with shape (2 * T - 1, B, C).
        expected_pos_enc = torch.tensor(
            [
                [[0.9093, -0.4161]],
                [[0.8415, 0.5403]],
                [[0.0000, 1.0000]],
                [[-0.8415, 0.5403]],
                [[-0.9093, -0.4161]],
            ]
        )
        self.assertTrue(
            np.allclose(
                pos_enc.cpu().detach().numpy(),
                expected_pos_enc.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


if __name__ == "__main__":
    unittest.main()
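
The hard-coded expectations in both tests form an ordinary sinusoidal table evaluated at integer relative offsets: with d_model=2 the frequency term 10000^(2i/d_model) reduces to 1, so each row is simply [sin(k), cos(k)]. A standalone sketch (illustrative, separate from the test file) that regenerates the extend_pe table:

```python
import numpy as np

max_len = 4
# Relative offsets from +(max_len - 1) down to -(max_len - 1): 3, 2, 1, 0, -1, -2, -3
offsets = np.arange(max_len - 1, -max_len, -1)
table = np.stack([np.sin(offsets), np.cos(offsets)], axis=1)
print(np.round(table, 4))  # matches expected_pe above, row for row
```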