speech integration tests for jointly trained models

Summary: Add test for DualInputS2TTransformerModel at examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py

Reviewed By: kahne

Differential Revision: D33284188

fbshipit-source-id: c02b697fc7734425661e00bbb606852b5d94a587
Author: Yun Tang, 2022-01-07 12:44:15 -08:00 (committed by Facebook GitHub Bot)
parent c9a8bea83f
commit e69f1fa37f
2 changed files with 129 additions and 1 deletion
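To exercise just this new test locally, something like the following works (a sketch, not part of the commit; the dotted module path is assumed from the imports in the new file below, and the test skips itself without CUDA):

import unittest

# Load and run the new test module by dotted name, from a fairseq checkout root.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.speech.test_dualinput_s2t_transformer"
)
unittest.TextTestRunner(verbosity=2).run(suite)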

tests/speech/__init__.py

@@ -5,6 +5,7 @@
 from argparse import Namespace
 import os
+import re
 import unittest
 from pathlib import Path
 from typing import List, Dict, Optional
@@ -32,7 +33,9 @@ class TestFairseqSpeech(unittest.TestCase):
         self.root = Path.home() / ".cache" / "fairseq" / dataset_id
         self.root.mkdir(exist_ok=True, parents=True)
         os.chdir(self.root)
-        self.base_url = f"{S3_BASE_URL}/{s3_dir}"
+        self.base_url = (
+            s3_dir if re.search("^https:", s3_dir) else f"{S3_BASE_URL}/{s3_dir}"
+        )
         for filename in data_filenames:
             self.download(self.base_url, self.root, filename)
@@ -75,6 +78,21 @@
             ],
         )
 
+    def set_up_mustc_de_fbank(self):
+        self._set_up(
+            "mustc_de_fbank",
+            "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de",
+            [
+                "config.yaml",
+                "spm.model",
+                "dict.txt",
+                "src_dict.txt",
+                "tgt_dict.txt",
+                "tst-COMMON.tsv",
+                "tst-COMMON.zip",
+            ],
+        )
+
     def download_and_load_checkpoint(
         self, checkpoint_filename: str, arg_overrides: Optional[Dict[str, str]] = None
     ):
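The `_set_up` change above makes the base URL flexible: callers can keep passing a bare S3 subdirectory, or pass a full download URL as the new `set_up_mustc_de_fbank` does. A standalone sketch of that selection logic; the `S3_BASE_URL` value here is an assumption for illustration:

import re

S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"  # assumed value for this sketch

def resolve_base_url(s3_dir: str) -> str:
    # Full https URLs pass through unchanged; bare directory names are
    # appended to the default S3 base URL.
    return s3_dir if re.search("^https:", s3_dir) else f"{S3_BASE_URL}/{s3_dir}"

assert resolve_base_url("s2t") == "https://dl.fbaipublicfiles.com/fairseq/s2t"
assert resolve_base_url(
    "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de"
) == "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de"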

tests/speech/test_dualinput_s2t_transformer.py (new file)

@@ -0,0 +1,110 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from argparse import Namespace
from collections import namedtuple
from pathlib import Path

import torch
from tqdm import tqdm

import fairseq
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.scoring.bleu import SacrebleuScorer
from fairseq.tasks import import_tasks
from tests.speech import TestFairseqSpeech


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestDualInputS2TTransformer(TestFairseqSpeech):
    def setUp(self):
        self.set_up_mustc_de_fbank()

    def import_user_module(self):
        user_dir = (
            Path(fairseq.__file__).parent.parent / "examples/speech_text_joint_to_text"
        )
        Arg = namedtuple("Arg", ["user_dir"])
        arg = Arg(user_dir.__str__())
        utils.import_user_module(arg)

    @torch.no_grad()
    def test_mustc_de_fbank_dualinput_s2t_transformer_checkpoint(self):
        self.import_user_module()
        checkpoint_filename = "checkpoint_ave_10.pt"
        path = self.download(self.base_url, self.root, checkpoint_filename)
        models, cfg, task = load_model_ensemble_and_task(
            [path.as_posix()],
            arg_overrides={
                "data": self.root.as_posix(),
                "config_yaml": "config.yaml",
                "load_pretrain_speech_encoder": "",
                "load_pretrain_text_encoder_last": "",
                "load_pretrain_decoder": "",
                "beam": 10,
                "nbest": 1,
                "lenpen": 1.0,
                "load_speech_only": True,
            },
        )
        if self.use_cuda:
            for model in models:
                model.cuda()
        generator = task.build_generator(models, cfg)
        test_split = "tst-COMMON"
        task.load_dataset(test_split)
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(test_split),
            max_tokens=250_000,
            max_positions=(10_000, 1_024),
            num_workers=1,
        ).next_epoch_itr(shuffle=False)

        tokenizer = task.build_tokenizer(cfg.tokenizer)
        bpe = task.build_bpe(cfg.bpe)

        def decode_fn(x):
            if bpe is not None:
                x = bpe.decode(x)
            if tokenizer is not None:
                x = tokenizer.decode(x)
            return x

        scorer_args = {
            "sacrebleu_tokenizer": "13a",
            "sacrebleu_lowercase": False,
            "sacrebleu_char_level": False,
        }
        scorer = SacrebleuScorer(Namespace(**scorer_args))
        progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
        for batch_idx, sample in progress:
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            hypo = task.inference_step(generator, models, sample)
            for i, sample_id in enumerate(sample["id"].tolist()):
                tgt_tokens = (
                    utils.strip_pad(sample["target"][i, :], task.tgt_dict.pad())
                    .int()
                    .cpu()
                )
                tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
                hypo_str = task.tgt_dict.string(
                    hypo[i][0]["tokens"].int().cpu(), "sentencepiece"
                )
                if batch_idx == 0 and i < 3:
                    print(f"T-{sample_id} {tgt_str}")
                    print(f"D-{sample_id} {hypo_str}")
                scorer.add_string(tgt_str, hypo_str)

        reference_bleu = 27.3
        result = scorer.result_string()
        print(result + f" (reference: {reference_bleu})")
        res_bleu = float(result.split()[2])
        self.assertAlmostEqual(res_bleu, reference_bleu, delta=0.3)


if __name__ == "__main__":
    unittest.main()
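A note on the final assertion: it relies on `SacrebleuScorer.result_string()` beginning with "BLEU = <score>", so the third whitespace-separated token is the corpus BLEU. A minimal sketch of that parsing, using a made-up result string rather than real output:

# Hypothetical result string in sacrebleu's usual format; the real one comes
# from scorer.result_string() in the test above.
result = "BLEU = 27.30 58.1/33.2/21.5/14.4 (BP = 0.980 ratio = 0.980 hyp_len = 980 ref_len = 1000)"
res_bleu = float(result.split()[2])  # split() -> ["BLEU", "=", "27.30", ...]
assert abs(res_bleu - 27.3) <= 0.3  # mirrors assertAlmostEqual(..., delta=0.3)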