mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-09-21 14:17:25 +03:00
Add global cmvn for mustc data preparation (#1660)
Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1660 Reviewed By: jmp84, kahne Differential Revision: D26708521 Pulled By: xutaima fbshipit-source-id: c53e9052298c559706ceffeb359dadfede2f1a09
This commit is contained in:
parent
3100d0b8e5
commit
12e21b9a6e
@ -126,7 +126,9 @@ def gen_config_yaml(
|
||||
specaugment_policy: str = "lb",
|
||||
prepend_tgt_lang_tag: bool = False,
|
||||
sampling_alpha: float = 1.0,
|
||||
audio_root: str = ""
|
||||
audio_root: str = "",
|
||||
cmvn_type: str = "utterance",
|
||||
gcmvn_path: Optional[Path] = None,
|
||||
):
|
||||
manifest_root = manifest_root.absolute()
|
||||
writer = S2TDataConfigWriter(manifest_root / yaml_filename)
|
||||
@ -151,8 +153,19 @@ def gen_config_yaml(
|
||||
if prepend_tgt_lang_tag:
|
||||
writer.set_prepend_tgt_lang_tag(True)
|
||||
writer.set_sampling_alpha(sampling_alpha)
|
||||
writer.set_feature_transforms("_train", ["utterance_cmvn", "specaugment"])
|
||||
writer.set_feature_transforms("*", ["utterance_cmvn"])
|
||||
|
||||
if cmvn_type not in ["global", "utterance"]:
|
||||
raise NotImplementedError
|
||||
|
||||
writer.set_feature_transforms("_train", [f"{cmvn_type}_cmvn", "specaugment"])
|
||||
writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"])
|
||||
|
||||
if cmvn_type == "global":
|
||||
assert gcmvn_path is not None, (
|
||||
'Please provide path of global cmvn file.'
|
||||
)
|
||||
writer.set_global_cmvn(gcmvn_path)
|
||||
|
||||
if len(audio_root) > 0:
|
||||
writer.set_audio_root(audio_root)
|
||||
writer.flush()
|
||||
@ -206,6 +219,16 @@ def filter_manifest_df(
|
||||
return df[valid]
|
||||
|
||||
|
||||
def cal_gcmvn_stats(features_list):
|
||||
features = np.concatenate(features_list)
|
||||
square_sums = (features ** 2).sum(axis=0)
|
||||
mean = features.mean(axis=0)
|
||||
features = np.subtract(features, mean)
|
||||
var = square_sums / features.shape[0] - mean ** 2
|
||||
std = np.sqrt(np.maximum(var, 1e-8))
|
||||
return {"mean": mean.astype("float32"), "std": std.astype("float32")}
|
||||
|
||||
|
||||
class S2TDataConfigWriter(object):
|
||||
DEFAULT_VOCAB_FILENAME = "dict.txt"
|
||||
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
|
||||
@ -297,6 +320,9 @@ class S2TDataConfigWriter(object):
|
||||
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
|
||||
self.config["bpe_tokenizer"] = bpe_tokenizer
|
||||
|
||||
def set_global_cmvn(self, stats_npz_path: str):
|
||||
self.config["stats_npz_path"] = stats_npz_path
|
||||
|
||||
def set_feature_transforms(self, split: str, transforms: List[str]):
|
||||
if "transforms" not in self.config:
|
||||
self.config["transforms"] = {}
|
||||
|
@ -13,6 +13,7 @@ from itertools import groupby
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torchaudio
|
||||
from examples.speech_to_text.data_utils import (
|
||||
@ -24,6 +25,7 @@ from examples.speech_to_text.data_utils import (
|
||||
get_zip_manifest,
|
||||
load_df_from_tsv,
|
||||
save_df_to_tsv,
|
||||
cal_gcmvn_stats,
|
||||
)
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
@ -111,10 +113,28 @@ def process(args):
|
||||
print(f"Fetching split {split}...")
|
||||
dataset = MUSTC(root.as_posix(), lang, split)
|
||||
print("Extracting log mel filter bank features...")
|
||||
if split == 'train' and args.cmvn_type == "global":
|
||||
print("And estimating cepstral mean and variance stats...")
|
||||
gcmvn_feature_list = []
|
||||
|
||||
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
|
||||
extract_fbank_features(
|
||||
waveform, sample_rate, feature_root / f"{utt_id}.npy"
|
||||
features = extract_fbank_features(waveform, sample_rate)
|
||||
|
||||
np.save(
|
||||
(feature_root / f"{utt_id}.npy").as_posix(),
|
||||
features
|
||||
)
|
||||
|
||||
if split == 'train' and args.cmvn_type == "global":
|
||||
if len(gcmvn_feature_list) < args.gcmvn_max_num:
|
||||
gcmvn_feature_list.append(features)
|
||||
|
||||
if split == 'train' and args.cmvn_type == "global":
|
||||
# Estimate and save cmv
|
||||
stats = cal_gcmvn_stats(gcmvn_feature_list)
|
||||
with open(cur_root / "gcmvn.npz", "wb") as f:
|
||||
np.savez(f, mean=stats["mean"], std=stats["std"])
|
||||
|
||||
# Pack features into ZIP
|
||||
zip_path = cur_root / "fbank80.zip"
|
||||
print("ZIPing features...")
|
||||
@ -158,6 +178,11 @@ def process(args):
|
||||
spm_filename_prefix + ".model",
|
||||
yaml_filename=f"config_{args.task}.yaml",
|
||||
specaugment_policy="lb",
|
||||
cmvn_type=args.cmvn_type,
|
||||
gcmvn_cmvn_path=(
|
||||
cur_root / "gcmvn.npz" if args.cmvn_type == "global"
|
||||
else None
|
||||
),
|
||||
)
|
||||
# Clean up
|
||||
shutil.rmtree(feature_root)
|
||||
@ -216,6 +241,14 @@ def main():
|
||||
parser.add_argument("--vocab-size", default=8000, type=int)
|
||||
parser.add_argument("--task", type=str, choices=["asr", "st"])
|
||||
parser.add_argument("--joint", action="store_true", help="")
|
||||
parser.add_argument("--cmvn-type", default="utterance",
|
||||
choices=["global", "utterance"],
|
||||
help="The type of cepstral mean and variance normalization")
|
||||
parser.add_argument("--gcmvn-max-num", default=150000, type=int,
|
||||
help=(
|
||||
"Maximum number of sentences to use to estimate"
|
||||
"global mean and variance"
|
||||
))
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.joint:
|
||||
|
Loading…
Reference in New Issue
Block a user