speech integration tests (batch 1)

Summary:
Adding the first batch of speech integration tests (based on test set scores on pre-trained checkpoints) for
- S2T transformer
- TTS transformer

Reviewed By: yuntang

Differential Revision: D33050653

fbshipit-source-id: fb5bb9f46e8e17cb705971ca1990c8e1cb99d5f9
This commit is contained in:
Changhan Wang 2021-12-14 17:40:59 -08:00 committed by Facebook GitHub Bot
parent c2b771b1be
commit ee833ed49d
3 changed files with 187 additions and 0 deletions

54
tests/speech/__init__.py Normal file
View File

@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from pathlib import Path
import unittest
import torch
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"
class TestFairseqSpeech(unittest.TestCase):
@classmethod
def download(cls, base_url: str, out_root: Path, filename: str):
url = f"{base_url}/{filename}"
path = out_root / filename
if not path.exists():
torch.hub.download_url_to_file(url, path.as_posix(), progress=True)
return path
def set_up_librispeech(self):
self.use_cuda = torch.cuda.is_available()
self.root = Path.home() / ".cache" / "fairseq" / "librispeech"
self.root.mkdir(exist_ok=True, parents=True)
os.chdir(self.root)
self.data_filenames = [
"cfg_librispeech.yaml",
"spm_librispeech_unigram10000.model",
"spm_librispeech_unigram10000.txt",
"librispeech_test-other.tsv",
"librispeech_test-other.zip"
]
self.base_url = f"{S3_BASE_URL}/s2t/librispeech"
for filename in self.data_filenames:
self.download(self.base_url, self.root, filename)
def set_up_ljspeech(self):
self.use_cuda = torch.cuda.is_available()
self.root = Path.home() / ".cache" / "fairseq" / "ljspeech"
self.root.mkdir(exist_ok=True, parents=True)
os.chdir(self.root)
self.data_filenames = [
"cfg_ljspeech_g2p.yaml",
"ljspeech_g2p_gcmvn_stats.npz",
"ljspeech_g2p.txt",
"ljspeech_test.tsv",
"ljspeech_test.zip"
]
self.base_url = f"{S3_BASE_URL}/s2/ljspeech"
for filename in self.data_filenames:
self.download(self.base_url, self.root, filename)

View File

@ -0,0 +1,66 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import Namespace
import unittest
import torch
from tqdm import tqdm
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq import utils
from fairseq.scoring.wer import WerScorer
from tests.speech import TestFairseqSpeech
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestS2TTransformer(TestFairseqSpeech):
def setUp(self):
self.set_up_librispeech()
@torch.no_grad()
def test_librispeech_s2t_transformer_s_checkpoint(self):
checkpoint_filename = "librispeech_transformer_s.pt"
path = self.download(self.base_url, self.root, checkpoint_filename)
models, cfg, task = load_model_ensemble_and_task(
[path.as_posix()], arg_overrides={
"data": self.root.as_posix(),
"config_yaml": "cfg_librispeech.yaml"
}
)
if self.use_cuda:
for model in models:
model.cuda()
generator = task.build_generator(models, cfg)
test_split = "librispeech_test-other"
task.load_dataset(test_split)
batch_iterator = task.get_batch_iterator(
dataset=task.dataset(test_split),
max_tokens=65_536, max_positions=(4_096, 1_024), num_workers=1
).next_epoch_itr(shuffle=False)
scorer_args = {"wer_tokenizer": "none", "wer_lowercase": False,
"wer_remove_punct": False, "wer_char_level": False}
scorer = WerScorer(Namespace(**scorer_args))
progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
for batch_idx, sample in progress:
sample = utils.move_to_cuda(sample) if self.use_cuda else sample
hypo = task.inference_step(generator, models, sample)
for i, sample_id in enumerate(sample["id"].tolist()):
tgt_tokens = utils.strip_pad(sample["target"][i, :], task.tgt_dict.pad()).int().cpu()
tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
hypo_str = task.tgt_dict.string(hypo[i][0]["tokens"].int().cpu(), "sentencepiece")
if batch_idx == 0 and i < 3:
print(f"T-{sample_id} {tgt_str}")
print(f"H-{sample_id} {hypo_str}")
scorer.add_string(tgt_str, hypo_str)
reference_wer = 9.0
print(scorer.result_string() + f" (reference: {reference_wer})")
self.assertAlmostEqual(scorer.score(), reference_wer, delta=0.3)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from tqdm import tqdm
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq import utils
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
from tests.speech import TestFairseqSpeech
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestTTSTransformer(TestFairseqSpeech):
def setUp(self):
self.set_up_ljspeech()
@torch.no_grad()
def test_ljspeech_tts_transformer_checkpoint(self):
checkpoint_filename = "ljspeech_transformer_g2p.pt"
path = self.download(self.base_url, self.root, checkpoint_filename)
models, cfg, task = load_model_ensemble_and_task(
[path.as_posix()], arg_overrides={
"data": self.root.as_posix(),
"config_yaml": "cfg_ljspeech_g2p.yaml",
"vocoder": "griffin_lim",
"fp16": False
}
)
if self.use_cuda:
for model in models:
model.cuda()
test_split = "ljspeech_test"
task.load_dataset(test_split)
batch_iterator = task.get_batch_iterator(
dataset=task.dataset(test_split),
max_tokens=65_536, max_positions=768, num_workers=1
).next_epoch_itr(shuffle=False)
progress = tqdm(batch_iterator, total=len(batch_iterator))
generator = task.build_generator(models, cfg)
mcd, n_samples = 0., 0
for sample in progress:
sample = utils.move_to_cuda(sample) if self.use_cuda else sample
hypos = generator.generate(models[0], sample, has_targ=True)
rets = batch_mel_cepstral_distortion(
[hypo["targ_waveform"] for hypo in hypos],
[hypo["waveform"] for hypo in hypos],
sr=task.sr
)
mcd += sum(d.item() for d, _ in rets)
n_samples += len(sample["id"].tolist())
mcd = round(mcd / n_samples, 1)
reference_mcd = 3.3
print(f"MCD: {mcd} (reference: {reference_mcd})")
self.assertAlmostEqual(mcd, reference_mcd, delta=0.1)
if __name__ == "__main__":
unittest.main()