diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py
index 0d7c03441..fa0d45961 100644
--- a/examples/speech_to_text/data_utils.py
+++ b/examples/speech_to_text/data_utils.py
@@ -126,7 +126,9 @@ def gen_config_yaml(
     specaugment_policy: str = "lb",
     prepend_tgt_lang_tag: bool = False,
     sampling_alpha: float = 1.0,
-    audio_root: str = ""
+    audio_root: str = "",
+    cmvn_type: str = "utterance",
+    gcmvn_path: Optional[Path] = None,
 ):
     manifest_root = manifest_root.absolute()
     writer = S2TDataConfigWriter(manifest_root / yaml_filename)
@@ -151,8 +153,19 @@ def gen_config_yaml(
     if prepend_tgt_lang_tag:
         writer.set_prepend_tgt_lang_tag(True)
     writer.set_sampling_alpha(sampling_alpha)
-    writer.set_feature_transforms("_train", ["utterance_cmvn", "specaugment"])
-    writer.set_feature_transforms("*", ["utterance_cmvn"])
+
+    if cmvn_type not in ["global", "utterance"]:
+        raise NotImplementedError
+
+    writer.set_feature_transforms("_train", [f"{cmvn_type}_cmvn", "specaugment"])
+    writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"])
+
+    if cmvn_type == "global":
+        # Explicit check instead of assert, which is stripped under -O.
+        if gcmvn_path is None:
+            raise ValueError("Please provide path of global cmvn file.")
+        writer.set_global_cmvn(Path(gcmvn_path).as_posix())
+
     if len(audio_root) > 0:
         writer.set_audio_root(audio_root)
     writer.flush()
@@ -206,6 +219,16 @@ def filter_manifest_df(
     return df[valid]
 
 
+def cal_gcmvn_stats(features_list):
+    """Return float32 mean/std (axis 0) over the concatenated features."""
+    features = np.concatenate(features_list)
+    square_sums = (features ** 2).sum(axis=0)
+    mean = features.mean(axis=0)
+    var = square_sums / features.shape[0] - mean ** 2
+    std = np.sqrt(np.maximum(var, 1e-8))
+    return {"mean": mean.astype("float32"), "std": std.astype("float32")}
+
+
 class S2TDataConfigWriter(object):
     DEFAULT_VOCAB_FILENAME = "dict.txt"
     DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
@@ -297,6 +320,9 @@ class S2TDataConfigWriter(object):
     def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
         self.config["bpe_tokenizer"] = bpe_tokenizer
 
+    def set_global_cmvn(self, stats_npz_path: str):
+        self.config["stats_npz_path"] = stats_npz_path
+
     def set_feature_transforms(self, split: str, transforms: List[str]):
         if "transforms" not in self.config:
             self.config["transforms"] = {}
diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py
index 520968401..4e410bcb1 100644
--- a/examples/speech_to_text/prep_mustc_data.py
+++ b/examples/speech_to_text/prep_mustc_data.py
@@ -13,6 +13,7 @@ from itertools import groupby
 from tempfile import NamedTemporaryFile
 from typing import Tuple
 
+import numpy as np
 import pandas as pd
 import torchaudio
 from examples.speech_to_text.data_utils import (
@@ -24,6 +25,7 @@ from examples.speech_to_text.data_utils import (
     get_zip_manifest,
     load_df_from_tsv,
     save_df_to_tsv,
+    cal_gcmvn_stats,
 )
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -111,10 +113,28 @@ def process(args):
             print(f"Fetching split {split}...")
             dataset = MUSTC(root.as_posix(), lang, split)
             print("Extracting log mel filter bank features...")
+            if split == 'train' and args.cmvn_type == "global":
+                print("And estimating cepstral mean and variance stats...")
+                gcmvn_feature_list = []
+
             for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
-                extract_fbank_features(
-                    waveform, sample_rate, feature_root / f"{utt_id}.npy"
+                features = extract_fbank_features(waveform, sample_rate)
+
+                np.save(
+                    (feature_root / f"{utt_id}.npy").as_posix(),
+                    features
                 )
+
+                if split == 'train' and args.cmvn_type == "global":
+                    if len(gcmvn_feature_list) < args.gcmvn_max_num:
+                        gcmvn_feature_list.append(features)
+
+            if split == 'train' and args.cmvn_type == "global":
+                # Estimate and save cmvn stats
+                stats = cal_gcmvn_stats(gcmvn_feature_list)
+                with open(cur_root / "gcmvn.npz", "wb") as f:
+                    np.savez(f, mean=stats["mean"], std=stats["std"])
+
         # Pack features into ZIP
         zip_path = cur_root / "fbank80.zip"
         print("ZIPing features...")
@@ -158,6 +178,11 @@ def process(args):
             spm_filename_prefix + ".model",
             yaml_filename=f"config_{args.task}.yaml",
             specaugment_policy="lb",
+            cmvn_type=args.cmvn_type,
+            gcmvn_path=(
+                cur_root / "gcmvn.npz" if args.cmvn_type == "global"
+                else None
+            ),
         )
         # Clean up
         shutil.rmtree(feature_root)
@@ -216,6 +241,14 @@ def main():
     parser.add_argument("--vocab-size", default=8000, type=int)
     parser.add_argument("--task", type=str, choices=["asr", "st"])
     parser.add_argument("--joint", action="store_true", help="")
+    parser.add_argument("--cmvn-type", default="utterance",
+                        choices=["global", "utterance"],
+                        help="The type of cepstral mean and variance normalization")
+    parser.add_argument("--gcmvn-max-num", default=150000, type=int,
+                        help=(
+                            "Maximum number of sentences to use to estimate "
+                            "global mean and variance"
+                        ))
     args = parser.parse_args()
 
     if args.joint: