diff --git a/tests/speech/__init__.py b/tests/speech/__init__.py
index 13b3ecf48..dba99e4d9 100644
--- a/tests/speech/__init__.py
+++ b/tests/speech/__init__.py
@@ -15,7 +15,7 @@ from fairseq.checkpoint_utils import load_model_ensemble_and_task
 from fairseq.scoring.wer import WerScorer
 from fairseq.scoring.bleu import SacrebleuScorer
 from fairseq import utils
-
+import zipfile
 
 S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"
 
@@ -95,19 +95,31 @@ class TestFairseqSpeech(unittest.TestCase):
         )
 
     def download_and_load_checkpoint(
-        self, checkpoint_filename: str, arg_overrides: Optional[Dict[str, str]] = None
+        self,
+        checkpoint_filename: str,
+        arg_overrides: Optional[Dict[str, str]] = None,
+        strict: bool = True,
     ):
         path = self.download(self.base_url, self.root, checkpoint_filename)
         _arg_overrides = arg_overrides or {}
         _arg_overrides["data"] = self.root.as_posix()
         models, cfg, task = load_model_ensemble_and_task(
-            [path.as_posix()], arg_overrides=_arg_overrides
+            [path.as_posix()], arg_overrides=_arg_overrides, strict=strict
         )
         if self.use_cuda:
             for model in models:
                 model.cuda()
-        generator = task.build_generator(models, cfg)
-        return models, cfg, task, generator
+
+        return models, cfg, task, self.build_generator(task, models, cfg)
+
+    def build_generator(
+        self,
+        task,
+        models,
+        cfg,
+    ):
+        """Build the generator for ``models``; subclasses may override this hook."""
+        return task.build_generator(models, cfg)
 
     @classmethod
     def get_batch_iterator(cls, task, test_split, max_tokens, max_positions):
@@ -141,35 +153,68 @@ class TestFairseqSpeech(unittest.TestCase):
         return SacrebleuScorer(Namespace(**scorer_args))
 
     @torch.no_grad()
-    def librispeech_s2t_test_base(self, ckpt_name, reference_wer):
-        models, cfg, task, generator = self.download_and_load_checkpoint(
-            ckpt_name,
-            arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
+    def base_test(
+        self,
+        ckpt_name,
+        reference_score,
+        score_delta=0.3,
+        dataset="librispeech_test-other",
+        max_tokens=65_536,
+        max_positions=(4_096, 1_024),
+        arg_overrides=None,
+        strict=True,
+        score_type="wer",
+    ):
+        """Run inference on ``dataset`` and compare the score to ``reference_score``.
+
+        Returns right after loading the checkpoint when ``self.use_cuda`` is false.
+        Otherwise ``score_type`` picks the scorer: "wer" or "bleu", else ``Exception``.
+        """
+        models, _, task, generator = self.download_and_load_checkpoint(
+            ckpt_name, arg_overrides=arg_overrides, strict=strict
         )
         if not self.use_cuda:
             return
 
         batch_iterator = self.get_batch_iterator(
-            task, "librispeech_test-other", 65_536, (4_096, 1_024)
+            task, dataset, max_tokens, max_positions
         )
-        scorer = self.get_wer_scorer()
+        if score_type == "bleu":
+            scorer = self.get_bleu_scorer()
+        elif score_type == "wer":
+            scorer = self.get_wer_scorer()
+        else:
+            raise Exception(f"Unsupported score type {score_type}")
+
         progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
         for batch_idx, sample in progress:
             sample = utils.move_to_cuda(sample) if self.use_cuda else sample
             hypo = task.inference_step(generator, models, sample)
             for i, sample_id in enumerate(sample["id"].tolist()):
-                tgt_tokens = (
-                    utils.strip_pad(sample["target"][i, :], task.tgt_dict.pad())
-                    .int()
-                    .cpu()
-                )
-                tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
-                hypo_str = task.tgt_dict.string(
-                    hypo[i][0]["tokens"].int().cpu(), "sentencepiece"
+                tgt_str, hypo_str = self.postprocess_tokens(
+                    task,
+                    sample["target"][i, :],
+                    hypo[i][0]["tokens"].int().cpu(),
                 )
                 if batch_idx == 0 and i < 3:
                     print(f"T-{sample_id} {tgt_str}")
                     print(f"H-{sample_id} {hypo_str}")
                 scorer.add_string(tgt_str, hypo_str)
-        print(scorer.result_string() + f" (reference: {reference_wer})")
-        self.assertAlmostEqual(scorer.score(), reference_wer, delta=0.3)
+
+        print(scorer.result_string() + f" (reference: {reference_score})")
+        self.assertAlmostEqual(scorer.score(), reference_score, delta=score_delta)
+
+    def postprocess_tokens(self, task, target, hypo_tokens):
+        """Convert the target and hypothesis tokens into strings."""
+        tgt_tokens = utils.strip_pad(target, task.tgt_dict.pad()).int().cpu()
+        tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
+        hypo_str = task.tgt_dict.string(hypo_tokens, "sentencepiece")
+        return tgt_str, hypo_str
+
+    def unzip_files(self, zip_file_name):
+        """Extract an archive under ``self.root`` into a folder named after it."""
+        zip_file_path = self.root / zip_file_name
+        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
+            # with_suffix("") drops only the final extension; str.strip(".zip")
+            # would also eat leading/trailing ".", "z", "i" and "p" characters.
+            zip_ref.extractall(zip_file_path.with_suffix(""))
diff --git a/tests/speech/test_s2s_transformer.py b/tests/speech/test_s2s_transformer.py
new file mode 100644
index 000000000..a4e71e246
--- /dev/null
+++ b/tests/speech/test_s2s_transformer.py
@@ -0,0 +1,47 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from tests.speech import TestFairseqSpeech
+from fairseq import utils
+
+S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"
+
+
+class TestS2STransformer(TestFairseqSpeech):
+    def setUp(self):
+        self._set_up(
+            "s2s",
+            "speech_tests/s2s",
+            [
+                "dev_shuf200.tsv",
+                "src_feat.zip",
+                "config_specaug_lb.yaml",
+                "config_letter_enc_unigram_dec.yaml",
+            ],
+        )
+
+    def test_s2s_transformer_checkpoint(self):
+        self.base_test(
+            ckpt_name="s2u_transformer_reduced_fisher.pt",
+            reference_score=38.3,
+            dataset="dev_shuf200",
+            arg_overrides={
+                "config_yaml": "config_specaug_lb.yaml",
+                "target_is_code": True,
+                "target_code_size": 100,
+            },
+            score_type="bleu",
+        )
+
+    def postprocess_tokens(self, task, target, hypo_tokens):
+        tgt_tokens = utils.strip_pad(target, task.tgt_dict.pad()).int().cpu()
+        tgt_str = task.tgt_dict.string(tgt_tokens)
+        hypo_str = task.tgt_dict.string(hypo_tokens)
+        return tgt_str, hypo_str
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/speech/test_s2t_conformer.py b/tests/speech/test_s2t_conformer.py
index 2c7c57445..5aaa4a0ed 100644
--- a/tests/speech/test_s2t_conformer.py
+++ b/tests/speech/test_s2t_conformer.py
@@ -12,7 +12,11 @@ class TestS2TConformer(TestFairseqSpeech):
         self.set_up_librispeech()
 
     def test_librispeech_s2t_conformer_s_checkpoint(self):
-        self.librispeech_s2t_test_base("librispeech_conformer_rel_pos_s.pt", 12)
+        self.base_test(
+            ckpt_name="librispeech_conformer_rel_pos_s.pt",
+            reference_score=12,
+            arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
+        )
 
 
 if __name__ == "__main__":
diff --git a/tests/speech/test_s2t_transformer.py b/tests/speech/test_s2t_transformer.py
index edf9f64c9..172f5484a 100644
--- a/tests/speech/test_s2t_transformer.py
+++ b/tests/speech/test_s2t_transformer.py
@@ -12,7 +12,11 @@ class TestS2TTransformer(TestFairseqSpeech):
         self.set_up_librispeech()
 
     def test_librispeech_s2t_transformer_s_checkpoint(self):
-        self.librispeech_s2t_test_base("librispeech_transformer_s.pt", 9)
+        self.base_test(
+            ckpt_name="librispeech_transformer_s.pt",
+            reference_score=9,
+            arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
+        )
 
 
 if __name__ == "__main__":
diff --git a/tests/speech/test_wav2vec2.py b/tests/speech/test_wav2vec2.py
new file mode 100644
index 000000000..eff6114c8
--- /dev/null
+++ b/tests/speech/test_wav2vec2.py
@@ -0,0 +1,90 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+import torch
+from tests.speech import TestFairseqSpeech
+from fairseq.data.data_utils import post_process
+from fairseq import utils
+from omegaconf import open_dict
+
+S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+class TestWav2Vec2(TestFairseqSpeech):
+    def setUp(self):
+        self._set_up(
+            "librispeech_w2v2",
+            "conformer/wav2vec2/librispeech",
+            [
+                "test_librispeech-other.ltr",
+                "test_librispeech-other.tsv",
+                "test_librispeech-other_small.ltr_100",
+                "test_librispeech-other_small.tsv",
+                "test-other.zip",
+                "dict.ltr.txt",
+                "dict.ltr_100.txt",
+            ],
+        )
+        self.unzip_files(
+            "test-other.zip",
+        )
+
+    def test_transformer_w2v2(self):
+        self.base_test(
+            ckpt_name="transformer_oss_small_100h.pt",
+            reference_score=38,
+            score_delta=1,
+            dataset="test_librispeech-other",
+            max_tokens=1000000,
+            max_positions=(700000, 1000),
+            arg_overrides={
+                "task": "audio_finetuning",
+                "labels": "ltr",
+                "nbest": 1,
+                "tpu": False,
+            },
+            strict=False,
+        )
+
+    def test_conformer_w2v2(self):
+        self.base_test(
+            ckpt_name="conformer_LS_PT_LS_FT_rope.pt",
+            reference_score=4.5,
+            score_delta=1,
+            dataset="test_librispeech-other_small",
+            max_tokens=1000000,
+            max_positions=(700000, 1000),
+            arg_overrides={
+                "task": "audio_finetuning",
+                "labels": "ltr_100",
+                "nbest": 1,
+                "tpu": False,
+            },
+            strict=True,
+        )
+
+    def build_generator(self, task, models, cfg):
+        try:
+            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
+        except Exception:
+            raise Exception("Cannot run this test without flashlight dependency")
+        with open_dict(cfg):
+            cfg.nbest = 1
+        return W2lViterbiDecoder(cfg, task.target_dictionary)
+
+    def postprocess_tokens(self, task, target, hypo_tokens):
+        tgt_tokens = utils.strip_pad(target, task.target_dictionary.pad()).int().cpu()
+        tgt_str = task.target_dictionary.string(tgt_tokens)
+        tgt_str = post_process(tgt_str, "letter")
+
+        hypo_pieces = task.target_dictionary.string(hypo_tokens)
+        hypo_str = post_process(hypo_pieces, "letter")
+        return tgt_str, hypo_str
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/speech/test_xm_transformer.py b/tests/speech/test_xm_transformer.py
index 60bb0dc6e..43d321bef 100644
--- a/tests/speech/test_xm_transformer.py
+++ b/tests/speech/test_xm_transformer.py
@@ -4,11 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
-
-import torch
-from tqdm import tqdm
-
-from fairseq import utils
 from tests.speech import TestFairseqSpeech
 
 
@@ -16,40 +11,17 @@ class TestXMTransformer(TestFairseqSpeech):
     def setUp(self):
         self.set_up_sotasty_es_en()
 
-    @torch.no_grad()
     def test_sotasty_es_en_600m_checkpoint(self):
-        models, cfg, task, generator = self.download_and_load_checkpoint(
-            "xm_transformer_600m_es_en_md.pt",
+        self.base_test(
+            ckpt_name="xm_transformer_600m_es_en_md.pt",
+            reference_score=30.42,
+            score_delta=0.2,
+            max_tokens=3_000_000,
+            max_positions=(1_000_000, 1_024),
+            dataset="sotasty_es_en_test_ted",
             arg_overrides={"config_yaml": "cfg_es_en.yaml"},
+            score_type="bleu",
         )
-        if not self.use_cuda:
-            return
-
-        batch_iterator = self.get_batch_iterator(
-            task, "sotasty_es_en_test_ted", 3_000_000, (1_000_000, 1_024)
-        )
-        scorer = self.get_bleu_scorer()
-        progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
-        for batch_idx, sample in progress:
-            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
-            hypo = task.inference_step(generator, models, sample)
-            for i, sample_id in enumerate(sample["id"].tolist()):
-                tgt_tokens = (
-                    utils.strip_pad(sample["target"][i, :], task.tgt_dict.pad())
-                    .int()
-                    .cpu()
-                )
-                tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
-                hypo_str = task.tgt_dict.string(
-                    hypo[i][0]["tokens"].int().cpu(), "sentencepiece"
-                )
-                if batch_idx == 0 and i < 3:
-                    print(f"T-{sample_id} {tgt_str}")
-                    print(f"H-{sample_id} {hypo_str}")
-                scorer.add_string(tgt_str, hypo_str)
-        reference_bleu = 31.7
-        print(f"{scorer.result_string()} (reference: {reference_bleu})")
-        self.assertAlmostEqual(scorer.score(), reference_bleu, delta=0.2)
 
 
 if __name__ == "__main__":