Add global cmvn for mustc data preparation (#1660)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding!

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1660

Reviewed By: jmp84, kahne

Differential Revision: D26708521

Pulled By: xutaima

fbshipit-source-id: c53e9052298c559706ceffeb359dadfede2f1a09
Xutai Ma 2021-03-02 13:28:53 -08:00 committed by Facebook GitHub Bot
parent 3100d0b8e5
commit 12e21b9a6e
2 changed files with 64 additions and 5 deletions

examples/speech_to_text/data_utils.py

@@ -126,7 +126,9 @@ def gen_config_yaml(
     specaugment_policy: str = "lb",
     prepend_tgt_lang_tag: bool = False,
     sampling_alpha: float = 1.0,
-    audio_root: str = ""
+    audio_root: str = "",
+    cmvn_type: str = "utterance",
+    gcmvn_path: Optional[Path] = None,
 ):
     manifest_root = manifest_root.absolute()
     writer = S2TDataConfigWriter(manifest_root / yaml_filename)
@@ -151,8 +153,19 @@ def gen_config_yaml(
     if prepend_tgt_lang_tag:
         writer.set_prepend_tgt_lang_tag(True)
     writer.set_sampling_alpha(sampling_alpha)
-    writer.set_feature_transforms("_train", ["utterance_cmvn", "specaugment"])
-    writer.set_feature_transforms("*", ["utterance_cmvn"])
+
+    if cmvn_type not in ["global", "utterance"]:
+        raise NotImplementedError
+
+    writer.set_feature_transforms("_train", [f"{cmvn_type}_cmvn", "specaugment"])
+    writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"])
+
+    if cmvn_type == "global":
+        assert gcmvn_path is not None, (
+            'Please provide path of global cmvn file.'
+        )
+        writer.set_global_cmvn(gcmvn_path)
+
     if len(audio_root) > 0:
         writer.set_audio_root(audio_root)
     writer.flush()
@@ -206,6 +219,16 @@ def filter_manifest_df(
     return df[valid]
 
 
+def cal_gcmvn_stats(features_list):
+    features = np.concatenate(features_list)
+    square_sums = (features ** 2).sum(axis=0)
+    mean = features.mean(axis=0)
+    features = np.subtract(features, mean)
+    var = square_sums / features.shape[0] - mean ** 2
+    std = np.sqrt(np.maximum(var, 1e-8))
+    return {"mean": mean.astype("float32"), "std": std.astype("float32")}
+
+
 class S2TDataConfigWriter(object):
     DEFAULT_VOCAB_FILENAME = "dict.txt"
     DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
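For readers skimming the diff: cal_gcmvn_stats pools every collected (num_frames, dim) feature matrix and computes per-dimension statistics via Var[x] = E[x^2] - E[x]^2, flooring the variance at 1e-8 before taking the square root. A minimal sketch of how the result could be sanity-checked against NumPy's own population estimator (the random inputs below are purely illustrative):

```python
import numpy as np
from examples.speech_to_text.data_utils import cal_gcmvn_stats

# Two hypothetical utterances, each a (num_frames, 80) log-mel feature matrix.
rng = np.random.default_rng(0)
features_list = [rng.normal(size=(120, 80)), rng.normal(size=(95, 80))]

stats = cal_gcmvn_stats(features_list)
stacked = np.concatenate(features_list)

# The returned float32 stats should agree with NumPy's mean/std over all frames.
assert np.allclose(stats["mean"], stacked.mean(axis=0), atol=1e-5)
assert np.allclose(stats["std"], stacked.std(axis=0), atol=1e-5)
```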
@@ -297,6 +320,9 @@ class S2TDataConfigWriter(object):
     def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
         self.config["bpe_tokenizer"] = bpe_tokenizer
 
+    def set_global_cmvn(self, stats_npz_path: str):
+        self.config["stats_npz_path"] = stats_npz_path
+
     def set_feature_transforms(self, split: str, transforms: List[str]):
         if "transforms" not in self.config:
             self.config["transforms"] = {}
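These two setters are what gen_config_yaml drives when global CMVN is selected: the train split gets global CMVN plus SpecAugment, every other split gets global CMVN only, and the stats file path is recorded in the config. A minimal sketch of the equivalent writer calls (the paths are placeholders; the keys written simply follow the setters shown in this diff):

```python
from pathlib import Path
from examples.speech_to_text.data_utils import S2TDataConfigWriter

writer = S2TDataConfigWriter(Path("config_st.yaml"))
# Mirror gen_config_yaml with cmvn_type == "global".
writer.set_feature_transforms("_train", ["global_cmvn", "specaugment"])
writer.set_feature_transforms("*", ["global_cmvn"])
writer.set_global_cmvn("gcmvn.npz")  # stats file produced during data prep
writer.flush()
```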

examples/speech_to_text/prep_mustc_data.py

@@ -13,6 +13,7 @@ from itertools import groupby
 from tempfile import NamedTemporaryFile
 from typing import Tuple
 
+import numpy as np
 import pandas as pd
 import torchaudio
 from examples.speech_to_text.data_utils import (
@@ -24,6 +25,7 @@ from examples.speech_to_text.data_utils import (
     get_zip_manifest,
     load_df_from_tsv,
     save_df_to_tsv,
+    cal_gcmvn_stats,
 )
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -111,10 +113,28 @@ def process(args):
         print(f"Fetching split {split}...")
         dataset = MUSTC(root.as_posix(), lang, split)
         print("Extracting log mel filter bank features...")
+        if split == 'train' and args.cmvn_type == "global":
+            print("And estimating cepstral mean and variance stats...")
+            gcmvn_feature_list = []
+
         for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
-            extract_fbank_features(
-                waveform, sample_rate, feature_root / f"{utt_id}.npy"
+            features = extract_fbank_features(waveform, sample_rate)
+
+            np.save(
+                (feature_root / f"{utt_id}.npy").as_posix(),
+                features
             )
+
+            if split == 'train' and args.cmvn_type == "global":
+                if len(gcmvn_feature_list) < args.gcmvn_max_num:
+                    gcmvn_feature_list.append(features)
+
+        if split == 'train' and args.cmvn_type == "global":
+            # Estimate and save cmvn
+            stats = cal_gcmvn_stats(gcmvn_feature_list)
+            with open(cur_root / "gcmvn.npz", "wb") as f:
+                np.savez(f, mean=stats["mean"], std=stats["std"])
+
     # Pack features into ZIP
     zip_path = cur_root / "fbank80.zip"
     print("ZIPing features...")
@@ -158,6 +178,11 @@ def process(args):
         spm_filename_prefix + ".model",
         yaml_filename=f"config_{args.task}.yaml",
         specaugment_policy="lb",
+        cmvn_type=args.cmvn_type,
+        gcmvn_path=(
+            cur_root / "gcmvn.npz" if args.cmvn_type == "global"
+            else None
+        ),
     )
     # Clean up
     shutil.rmtree(feature_root)
@@ -216,6 +241,14 @@ def main():
     parser.add_argument("--vocab-size", default=8000, type=int)
     parser.add_argument("--task", type=str, choices=["asr", "st"])
     parser.add_argument("--joint", action="store_true", help="")
+    parser.add_argument("--cmvn-type", default="utterance",
+                        choices=["global", "utterance"],
+                        help="The type of cepstral mean and variance normalization")
+    parser.add_argument("--gcmvn-max-num", default=150000, type=int,
+                        help=(
+                            "Maximum number of sentences to use to estimate "
+                            "global mean and variance"
+                        ))
     args = parser.parse_args()
 
     if args.joint:
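Taken together, a global-CMVN preparation run of the updated script would look roughly like `python examples/speech_to_text/prep_mustc_data.py --data-root ${MUSTC_ROOT} --task st --cmvn-type global` (flags other than the two added in this diff belong to the existing script and may differ by version). The statistics are estimated from up to --gcmvn-max-num training utterances, written to gcmvn.npz next to the extracted features, and referenced by the generated config for the global_cmvn transform.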