add xm_transformer test; refactor speech tests

Summary: add xm_transformer test; refactor speech tests

Reviewed By: sravyapopuri388

Differential Revision: D33312231

fbshipit-source-id: a2b2695fc3c10d5420abbe23a4a3005777aa2ae1
Changhan Wang 2021-12-31 12:29:14 -08:00 committed by Facebook GitHub Bot
parent 59b3ada2e2
commit ee177fc4fa
5 changed files with 164 additions and 97 deletions

View File

@@ -3,12 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import Namespace
import os
import unittest
from pathlib import Path
from typing import List, Dict, Optional
import torch
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.scoring.wer import WerScorer
from fairseq.scoring.bleu import SacrebleuScorer
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"
@@ -21,34 +27,96 @@ class TestFairseqSpeech(unittest.TestCase):
torch.hub.download_url_to_file(url, path.as_posix(), progress=True)
return path
def set_up_librispeech(self):
def _set_up(self, dataset_id: str, s3_dir: str, data_filenames: List[str]):
self.use_cuda = torch.cuda.is_available()
self.root = Path.home() / ".cache" / "fairseq" / "librispeech"
self.root = Path.home() / ".cache" / "fairseq" / dataset_id
self.root.mkdir(exist_ok=True, parents=True)
os.chdir(self.root)
self.data_filenames = [
"cfg_librispeech.yaml",
"spm_librispeech_unigram10000.model",
"spm_librispeech_unigram10000.txt",
"librispeech_test-other.tsv",
"librispeech_test-other.zip",
]
self.base_url = f"{S3_BASE_URL}/s2t/librispeech"
for filename in self.data_filenames:
self.base_url = f"{S3_BASE_URL}/{s3_dir}"
for filename in data_filenames:
self.download(self.base_url, self.root, filename)
def set_up_librispeech(self):
self._set_up(
"librispeech",
"s2t/librispeech",
[
"cfg_librispeech.yaml",
"spm_librispeech_unigram10000.model",
"spm_librispeech_unigram10000.txt",
"librispeech_test-other.tsv",
"librispeech_test-other.zip",
],
)
def set_up_ljspeech(self):
self.use_cuda = torch.cuda.is_available()
self.root = Path.home() / ".cache" / "fairseq" / "ljspeech"
self.root.mkdir(exist_ok=True, parents=True)
os.chdir(self.root)
self.data_filenames = [
"cfg_ljspeech_g2p.yaml",
"ljspeech_g2p_gcmvn_stats.npz",
"ljspeech_g2p.txt",
"ljspeech_test.tsv",
"ljspeech_test.zip",
]
self.base_url = f"{S3_BASE_URL}/s2/ljspeech"
for filename in self.data_filenames:
self.download(self.base_url, self.root, filename)
self._set_up(
"ljspeech",
"s2/ljspeech",
[
"cfg_ljspeech_g2p.yaml",
"ljspeech_g2p_gcmvn_stats.npz",
"ljspeech_g2p.txt",
"ljspeech_test.tsv",
"ljspeech_test.zip",
],
)
def set_up_sotasty_es_en(self):
self._set_up(
"sotasty_es_en",
"s2t/big/es-en",
[
"cfg_es_en.yaml",
"spm_bpe32768_es_en.model",
"spm_bpe32768_es_en.txt",
"sotasty_es_en_test_ted.tsv",
"sotasty_es_en_test_ted.zip",
],
)
def download_and_load_checkpoint(
self, checkpoint_filename: str, arg_overrides: Optional[Dict[str, str]] = None
):
path = self.download(self.base_url, self.root, checkpoint_filename)
_arg_overrides = arg_overrides or {}
_arg_overrides["data"] = self.root.as_posix()
models, cfg, task = load_model_ensemble_and_task(
[path.as_posix()], arg_overrides=_arg_overrides
)
if self.use_cuda:
for model in models:
model.cuda()
generator = task.build_generator(models, cfg)
return models, cfg, task, generator
@classmethod
def get_batch_iterator(cls, task, test_split, max_tokens, max_positions):
task.load_dataset(test_split)
return task.get_batch_iterator(
dataset=task.dataset(test_split),
max_tokens=max_tokens,
max_positions=max_positions,
num_workers=1,
).next_epoch_itr(shuffle=False)
@classmethod
def get_wer_scorer(
cls, tokenizer="none", lowercase=False, remove_punct=False, char_level=False
):
scorer_args = {
"wer_tokenizer": tokenizer,
"wer_lowercase": lowercase,
"wer_remove_punct": remove_punct,
"wer_char_level": char_level,
}
return WerScorer(Namespace(**scorer_args))
@classmethod
def get_bleu_scorer(cls, tokenizer="13a", lowercase=False, char_level=False):
scorer_args = {
"sacrebleu_tokenizer": tokenizer,
"sacrebleu_lowercase": lowercase,
"sacrebleu_char_level": char_level,
}
return SacrebleuScorer(Namespace(**scorer_args))
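The per-model tests in the files below all compose these helpers the same way. As a rough sketch, a hypothetical new test built on this base class could look like the following; the dataset id, S3 sub-directory, filenames, split name and checkpoint name are placeholders for illustration, not real assets:

import torch
from tqdm import tqdm
from fairseq import utils
from tests.speech import TestFairseqSpeech

class TestToySpeechModel(TestFairseqSpeech):
    def setUp(self):
        # Placeholder dataset id, S3 sub-directory and filenames; _set_up would
        # try to download each file from S3_BASE_URL/s3_dir into ~/.cache/fairseq/toy_dataset.
        self._set_up(
            "toy_dataset",
            "s2t/toy_dataset",
            ["cfg_toy.yaml", "toy_dataset_test.tsv", "toy_dataset_test.zip"],
        )

    @torch.no_grad()
    def test_toy_checkpoint(self):
        # Placeholder checkpoint filename; the helper fills in the "data" override,
        # so only task-specific overrides need to be passed here.
        models, cfg, task, generator = self.download_and_load_checkpoint(
            "toy_model.pt", arg_overrides={"config_yaml": "cfg_toy.yaml"}
        )
        batch_iterator = self.get_batch_iterator(task, "toy_dataset_test", 65_536, 4_096)
        scorer = self.get_wer_scorer()
        for sample in tqdm(batch_iterator, total=len(batch_iterator)):
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            hypos = task.inference_step(generator, models, sample)
            # ... score hypos against sample["target"] via scorer.add_string(...)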

View File

@@ -9,7 +9,6 @@ import torch
from tqdm import tqdm
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
from tests.speech import TestFairseqSpeech
@@ -21,33 +20,17 @@ class TestFastSpeech2(TestFairseqSpeech):
@torch.no_grad()
def test_ljspeech_fastspeech2_checkpoint(self):
checkpoint_filename = "ljspeech_fastspeech2_g2p.pt"
path = self.download(self.base_url, self.root, checkpoint_filename)
models, cfg, task = load_model_ensemble_and_task(
[path.as_posix()],
models, cfg, task, generator = self.download_and_load_checkpoint(
"ljspeech_fastspeech2_g2p.pt",
arg_overrides={
"data": self.root.as_posix(),
"config_yaml": "cfg_ljspeech_g2p.yaml",
"vocoder": "griffin_lim",
"fp16": False,
},
)
if self.use_cuda:
for model in models:
model.cuda()
test_split = "ljspeech_test"
task.load_dataset(test_split)
batch_iterator = task.get_batch_iterator(
dataset=task.dataset(test_split),
max_tokens=65_536,
max_positions=4_096,
num_workers=1,
).next_epoch_itr(shuffle=False)
batch_iterator = self.get_batch_iterator(task, "ljspeech_test", 65_536, 4_096)
progress = tqdm(batch_iterator, total=len(batch_iterator))
generator = task.build_generator(models, cfg)
mcd, n_samples = 0.0, 0
for sample in progress:
sample = utils.move_to_cuda(sample) if self.use_cuda else sample

View File

@@ -4,54 +4,31 @@
# LICENSE file in the root directory of this source tree.
import unittest
from argparse import Namespace
import torch
from tqdm import tqdm
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.scoring.wer import WerScorer
from tests.speech import TestFairseqSpeech
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestS2TTransformer(TestFairseqSpeech):
def setUp(self):
self.set_up_librispeech()
@torch.no_grad()
def test_librispeech_s2t_transformer_s_checkpoint(self):
checkpoint_filename = "librispeech_transformer_s.pt"
path = self.download(self.base_url, self.root, checkpoint_filename)
models, cfg, task = load_model_ensemble_and_task(
[path.as_posix()],
arg_overrides={
"data": self.root.as_posix(),
"config_yaml": "cfg_librispeech.yaml",
},
models, cfg, task, generator = self.download_and_load_checkpoint(
"librispeech_transformer_s.pt",
arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
)
if self.use_cuda:
for model in models:
model.cuda()
generator = task.build_generator(models, cfg)
test_split = "librispeech_test-other"
task.load_dataset(test_split)
batch_iterator = task.get_batch_iterator(
dataset=task.dataset(test_split),
max_tokens=65_536,
max_positions=(4_096, 1_024),
num_workers=1,
).next_epoch_itr(shuffle=False)
if not self.use_cuda:
return
scorer_args = {
"wer_tokenizer": "none",
"wer_lowercase": False,
"wer_remove_punct": False,
"wer_char_level": False,
}
scorer = WerScorer(Namespace(**scorer_args))
batch_iterator = self.get_batch_iterator(
task, "librispeech_test-other", 65_536, (4_096, 1_024)
)
scorer = self.get_wer_scorer()
progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
for batch_idx, sample in progress:
sample = utils.move_to_cuda(sample) if self.use_cuda else sample

View File

@@ -9,7 +9,6 @@ import torch
from tqdm import tqdm
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
from tests.speech import TestFairseqSpeech
@@ -21,33 +20,17 @@ class TestTTSTransformer(TestFairseqSpeech):
@torch.no_grad()
def test_ljspeech_tts_transformer_checkpoint(self):
checkpoint_filename = "ljspeech_transformer_g2p.pt"
path = self.download(self.base_url, self.root, checkpoint_filename)
models, cfg, task = load_model_ensemble_and_task(
[path.as_posix()],
models, cfg, task, generator = self.download_and_load_checkpoint(
"ljspeech_transformer_g2p.pt",
arg_overrides={
"data": self.root.as_posix(),
"config_yaml": "cfg_ljspeech_g2p.yaml",
"vocoder": "griffin_lim",
"fp16": False,
},
)
if self.use_cuda:
for model in models:
model.cuda()
test_split = "ljspeech_test"
task.load_dataset(test_split)
batch_iterator = task.get_batch_iterator(
dataset=task.dataset(test_split),
max_tokens=65_536,
max_positions=768,
num_workers=1,
).next_epoch_itr(shuffle=False)
batch_iterator = self.get_batch_iterator(task, "ljspeech_test", 65_536, 1024)
progress = tqdm(batch_iterator, total=len(batch_iterator))
generator = task.build_generator(models, cfg)
mcd, n_samples = 0.0, 0
for sample in progress:
sample = utils.move_to_cuda(sample) if self.use_cuda else sample

View File

@@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from tqdm import tqdm

from fairseq import utils
from tests.speech import TestFairseqSpeech


class TestXMTransformer(TestFairseqSpeech):
    def setUp(self):
        self.set_up_sotasty_es_en()

    @torch.no_grad()
    def test_sotasty_es_en_600m_checkpoint(self):
        models, cfg, task, generator = self.download_and_load_checkpoint(
            "xm_transformer_600m_es_en_md.pt",
            arg_overrides={"config_yaml": "cfg_es_en.yaml"},
        )
        if not self.use_cuda:
            return
        batch_iterator = self.get_batch_iterator(
            task, "sotasty_es_en_test_ted", 3_000_000, (1_000_000, 1_024)
        )
        scorer = self.get_bleu_scorer()
        progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
        for batch_idx, sample in progress:
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            hypo = task.inference_step(generator, models, sample)
            for i, sample_id in enumerate(sample["id"].tolist()):
                tgt_tokens = (
                    utils.strip_pad(sample["target"][i, :], task.tgt_dict.pad())
                    .int()
                    .cpu()
                )
                tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
                hypo_str = task.tgt_dict.string(
                    hypo[i][0]["tokens"].int().cpu(), "sentencepiece"
                )
                if batch_idx == 0 and i < 3:
                    print(f"T-{sample_id} {tgt_str}")
                    print(f"H-{sample_id} {hypo_str}")
                scorer.add_string(tgt_str, hypo_str)

        reference_bleu = 31.7
        print(f"{scorer.result_string()} (reference: {reference_bleu})")
        self.assertAlmostEqual(scorer.score(), reference_bleu, delta=0.2)


if __name__ == "__main__":
    unittest.main()
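Because the decoding loop sits behind the "if not self.use_cuda: return" guard, running this test on a CPU-only machine only exercises the data download and checkpoint loading; the BLEU comparison against the 31.7 reference runs only on GPU hosts.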