Refactor speech tests and add missing regression tests (#3001)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/3001

Reviewed By: kahne

Differential Revision: D33904550

Pulled By: sravyapopuri388

fbshipit-source-id: f55f8121d83e5abebdfcf7ac90dcba39f65cafaf
This commit is contained in:
Sravya Popuri 2022-02-04 14:33:48 -08:00 committed by Facebook GitHub Bot
parent 6b7a7d6457
commit 11b2830d29
6 changed files with 211 additions and 59 deletions

View File

@ -15,7 +15,7 @@ from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.scoring.wer import WerScorer
from fairseq.scoring.bleu import SacrebleuScorer
from fairseq import utils
import zipfile
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"
@ -95,19 +95,30 @@ class TestFairseqSpeech(unittest.TestCase):
)
def download_and_load_checkpoint(
self, checkpoint_filename: str, arg_overrides: Optional[Dict[str, str]] = None
self,
checkpoint_filename: str,
arg_overrides: Optional[Dict[str, str]] = None,
strict: bool = True,
):
path = self.download(self.base_url, self.root, checkpoint_filename)
_arg_overrides = arg_overrides or {}
_arg_overrides["data"] = self.root.as_posix()
models, cfg, task = load_model_ensemble_and_task(
[path.as_posix()], arg_overrides=_arg_overrides
[path.as_posix()], arg_overrides=_arg_overrides, strict=strict
)
if self.use_cuda:
for model in models:
model.cuda()
generator = task.build_generator(models, cfg)
return models, cfg, task, generator
return models, cfg, task, self.build_generator(task, models, cfg)
# Build the generator used for inference by delegating to the task.
# download_and_load_checkpoint calls this hook to produce the generator it
# returns, so a subclass can override it to supply a different generator.
def build_generator(
self,
task,
models,
cfg,
):
return task.build_generator(models, cfg)
@classmethod
def get_batch_iterator(cls, task, test_split, max_tokens, max_positions):
@ -141,35 +152,59 @@ class TestFairseqSpeech(unittest.TestCase):
return SacrebleuScorer(Namespace(**scorer_args))
@torch.no_grad()
def librispeech_s2t_test_base(self, ckpt_name, reference_wer):
models, cfg, task, generator = self.download_and_load_checkpoint(
ckpt_name,
arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
def base_test(
self,
ckpt_name,
reference_score,
score_delta=0.3,
dataset="librispeech_test-other",
max_tokens=65_536,
max_positions=(4_096, 1_024),
arg_overrides=None,
strict=True,
score_type="wer",
):
models, _, task, generator = self.download_and_load_checkpoint(
ckpt_name, arg_overrides=arg_overrides, strict=strict
)
if not self.use_cuda:
return
batch_iterator = self.get_batch_iterator(
task, "librispeech_test-other", 65_536, (4_096, 1_024)
task, dataset, max_tokens, max_positions
)
scorer = self.get_wer_scorer()
if score_type == "bleu":
scorer = self.get_bleu_scorer()
elif score_type == "wer":
scorer = self.get_wer_scorer()
else:
raise Exception(f"Unsupported score type {score_type}")
progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
for batch_idx, sample in progress:
sample = utils.move_to_cuda(sample) if self.use_cuda else sample
hypo = task.inference_step(generator, models, sample)
for i, sample_id in enumerate(sample["id"].tolist()):
tgt_tokens = (
utils.strip_pad(sample["target"][i, :], task.tgt_dict.pad())
.int()
.cpu()
)
tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
hypo_str = task.tgt_dict.string(
hypo[i][0]["tokens"].int().cpu(), "sentencepiece"
tgt_str, hypo_str = self.postprocess_tokens(
task,
sample["target"][i, :],
hypo[i][0]["tokens"].int().cpu(),
)
if batch_idx == 0 and i < 3:
print(f"T-{sample_id} {tgt_str}")
print(f"H-{sample_id} {hypo_str}")
scorer.add_string(tgt_str, hypo_str)
print(scorer.result_string() + f" (reference: {reference_wer})")
self.assertAlmostEqual(scorer.score(), reference_wer, delta=0.3)
print(scorer.result_string() + f" (reference: {reference_score})")
self.assertAlmostEqual(scorer.score(), reference_score, delta=score_delta)
# Turn one target row and one hypothesis into the strings passed to the scorer.
# Padding is stripped from the target using the task's tgt_dict pad index;
# both sequences are then rendered with task.tgt_dict.string using the
# "sentencepiece" argument. Returns the pair (tgt_str, hypo_str).
def postprocess_tokens(self, task, target, hypo_tokens):
tgt_tokens = utils.strip_pad(target, task.tgt_dict.pad()).int().cpu()
tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
hypo_str = task.tgt_dict.string(hypo_tokens, "sentencepiece")
return tgt_str, hypo_str
def unzip_files(self, zip_file_name):
    """Extract an archive located under ``self.root`` into a directory.

    The archive ``self.root / zip_file_name`` is unpacked into a directory
    under ``self.root`` whose name is ``zip_file_name`` with a trailing
    ``.zip`` suffix removed.

    Args:
        zip_file_name: file name of the archive, relative to ``self.root``.
    """
    zip_file_path = self.root / zip_file_name
    # str.strip(".zip") treats its argument as a set of characters and removes
    # any leading/trailing '.', 'z', 'i' or 'p' (e.g. "pizza.zip" -> "a").
    # Only the literal suffix must be dropped.
    suffix = ".zip"
    if zip_file_name.endswith(suffix):
        target_name = zip_file_name[: -len(suffix)]
    else:
        target_name = zip_file_name
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        zip_ref.extractall(self.root / target_name)

View File

@ -0,0 +1,47 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from tests.speech import TestFairseqSpeech
from fairseq import utils
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"
# Regression test that scores the "s2u_transformer_reduced_fisher.pt"
# checkpoint with BLEU on the "dev_shuf200" split via the shared base_test.
class TestS2STransformer(TestFairseqSpeech):
# Delegates fixture preparation to the base-class _set_up helper (not visible
# here); presumably the list names the files the test needs -- TODO confirm.
# NOTE(review): "src_feat.zip" is listed but this class never calls
# unzip_files -- confirm the archive is meant to be consumed as-is.
def setUp(self):
self._set_up(
"s2s",
"speech_tests/s2s",
[
"dev_shuf200.tsv",
"src_feat.zip",
"config_specaug_lb.yaml",
"config_letter_enc_unigram_dec.yaml",
],
)
# Expects a BLEU score of 38.3; score_delta is not passed, so base_test's
# default tolerance applies.
def test_s2s_transformer_checkpoint(self):
self.base_test(
ckpt_name="s2u_transformer_reduced_fisher.pt",
reference_score=38.3,
dataset="dev_shuf200",
arg_overrides={
"config_yaml": "config_specaug_lb.yaml",
"target_is_code": True,
"target_code_size": 100,
},
score_type="bleu",
)
# Override of the base hook: strips padding from the target and renders both
# sequences with task.tgt_dict.string, without the "sentencepiece" argument
# that the base implementation passes. Returns (tgt_str, hypo_str).
def postprocess_tokens(self, task, target, hypo_tokens):
tgt_tokens = utils.strip_pad(target, task.tgt_dict.pad()).int().cpu()
tgt_str = task.tgt_dict.string(tgt_tokens)
hypo_str = task.tgt_dict.string(hypo_tokens)
return tgt_str, hypo_str
if __name__ == "__main__":
unittest.main()

View File

@ -12,7 +12,11 @@ class TestS2TConformer(TestFairseqSpeech):
self.set_up_librispeech()
def test_librispeech_s2t_conformer_s_checkpoint(self):
self.librispeech_s2t_test_base("librispeech_conformer_rel_pos_s.pt", 12)
self.base_test(
ckpt_name="librispeech_conformer_rel_pos_s.pt",
reference_score=12,
arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
)
if __name__ == "__main__":

View File

@ -12,7 +12,11 @@ class TestS2TTransformer(TestFairseqSpeech):
self.set_up_librispeech()
def test_librispeech_s2t_transformer_s_checkpoint(self):
self.librispeech_s2t_test_base("librispeech_transformer_s.pt", 9)
self.base_test(
ckpt_name="librispeech_transformer_s.pt",
reference_score=9,
arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
)
if __name__ == "__main__":

View File

@ -0,0 +1,90 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from tests.speech import TestFairseqSpeech
from fairseq.data.data_utils import post_process
from fairseq import utils
from omegaconf import open_dict
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"
# Regression tests for two wav2vec 2.0 checkpoints, scored through the shared
# base_test. The whole class is skipped when CUDA is unavailable.
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestWav2Vec2(TestFairseqSpeech):
# Delegates fixture preparation to the base-class _set_up helper (not visible
# here), then extracts "test-other.zip" with unzip_files.
def setUp(self):
self._set_up(
"librispeech_w2v2",
"conformer/wav2vec2/librispeech",
[
"test_librispeech-other.ltr",
"test_librispeech-other.tsv",
"test_librispeech-other_small.ltr_100",
"test_librispeech-other_small.tsv",
"test-other.zip",
"dict.ltr.txt",
"dict.ltr_100.txt",
],
)
self.unzip_files(
"test-other.zip",
)
# Transformer checkpoint: expects a score of 38 within +/-1. score_type is
# not passed, so base_test's default scorer is used. The checkpoint is loaded
# with strict=False.
def test_transformer_w2v2(self):
self.base_test(
ckpt_name="transformer_oss_small_100h.pt",
reference_score=38,
score_delta=1,
dataset="test_librispeech-other",
max_tokens=1000000,
max_positions=(700000, 1000),
arg_overrides={
"task": "audio_finetuning",
"labels": "ltr",
"nbest": 1,
"tpu": False,
},
strict=False,
)
# Conformer checkpoint: expects a score of 4.5 within +/-1 on the "_small"
# split, using the "ltr_100" labels and strict checkpoint loading.
def test_conformer_w2v2(self):
self.base_test(
ckpt_name="conformer_LS_PT_LS_FT_rope.pt",
reference_score=4.5,
score_delta=1,
dataset="test_librispeech-other_small",
max_tokens=1000000,
max_positions=(700000, 1000),
arg_overrides={
"task": "audio_finetuning",
"labels": "ltr_100",
"nbest": 1,
"tpu": False,
},
strict=True,
)
# Override of the base hook: returns a W2lViterbiDecoder built from cfg and
# the task's target dictionary instead of calling task.build_generator.
# The decoder is imported lazily; any failure of that import is re-raised as
# a generic Exception naming the flashlight dependency.
def build_generator(self, task, models, cfg):
try:
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
except Exception:
raise Exception("Cannot run this test without flashlight dependency")
# open_dict is used so that nbest can be assigned on cfg before decoding.
with open_dict(cfg):
cfg.nbest = 1
return W2lViterbiDecoder(cfg, task.target_dictionary)
# Override of the base hook: strips padding from the target, renders both
# sequences with task.target_dictionary.string, and runs each result through
# post_process(..., "letter"). Returns (tgt_str, hypo_str).
def postprocess_tokens(self, task, target, hypo_tokens):
tgt_tokens = utils.strip_pad(target, task.target_dictionary.pad()).int().cpu()
tgt_str = task.target_dictionary.string(tgt_tokens)
tgt_str = post_process(tgt_str, "letter")
hypo_pieces = task.target_dictionary.string(hypo_tokens)
hypo_str = post_process(hypo_pieces, "letter")
return tgt_str, hypo_str
if __name__ == "__main__":
unittest.main()

View File

@ -4,11 +4,6 @@
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from tqdm import tqdm
from fairseq import utils
from tests.speech import TestFairseqSpeech
@ -16,40 +11,17 @@ class TestXMTransformer(TestFairseqSpeech):
def setUp(self):
self.set_up_sotasty_es_en()
@torch.no_grad()
def test_sotasty_es_en_600m_checkpoint(self):
models, cfg, task, generator = self.download_and_load_checkpoint(
"xm_transformer_600m_es_en_md.pt",
self.base_test(
ckpt_name="xm_transformer_600m_es_en_md.pt",
reference_score=30.42,
score_delta=0.2,
max_tokens=3_000_000,
max_positions=(1_000_000, 1_024),
dataset="sotasty_es_en_test_ted",
arg_overrides={"config_yaml": "cfg_es_en.yaml"},
score_type="bleu",
)
if not self.use_cuda:
return
batch_iterator = self.get_batch_iterator(
task, "sotasty_es_en_test_ted", 3_000_000, (1_000_000, 1_024)
)
scorer = self.get_bleu_scorer()
progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
for batch_idx, sample in progress:
sample = utils.move_to_cuda(sample) if self.use_cuda else sample
hypo = task.inference_step(generator, models, sample)
for i, sample_id in enumerate(sample["id"].tolist()):
tgt_tokens = (
utils.strip_pad(sample["target"][i, :], task.tgt_dict.pad())
.int()
.cpu()
)
tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
hypo_str = task.tgt_dict.string(
hypo[i][0]["tokens"].int().cpu(), "sentencepiece"
)
if batch_idx == 0 and i < 3:
print(f"T-{sample_id} {tgt_str}")
print(f"H-{sample_id} {hypo_str}")
scorer.add_string(tgt_str, hypo_str)
reference_bleu = 31.7
print(f"{scorer.result_string()} (reference: {reference_bleu})")
self.assertAlmostEqual(scorer.score(), reference_bleu, delta=0.2)
if __name__ == "__main__":