mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-09-21 06:13:31 +03:00
d839d84f1e
Summary: Miscellaneous S2T & S2 bug fixes Reviewed By: yuntang Differential Revision: D33469556 fbshipit-source-id: 430c2cad01dd7ea862a6c1564ad609887d66b788
384 lines
12 KiB
Python
384 lines
12 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import csv
|
|
from pathlib import Path
|
|
import zipfile
|
|
from functools import reduce
|
|
from multiprocessing import cpu_count
|
|
from typing import Any, Dict, List, Optional, Union
|
|
import io
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import sentencepiece as sp
|
|
from fairseq.data.audio.audio_utils import (
|
|
convert_waveform, _get_kaldi_fbank, _get_torchaudio_fbank, is_npy_data,
|
|
is_sf_audio_data
|
|
)
|
|
import torch
|
|
import soundfile as sf
|
|
from tqdm import tqdm
|
|
|
|
|
|
# Special SentencePiece pieces and the IDs the trainer is told to assign them
# (see the --bos_id/--pad_id/--eos_id/--unk_id options in gen_vocab). Listed
# in ID order.
BOS_TOKEN = "<s>"
BOS_TOKEN_ID = 0
PAD_TOKEN = "<pad>"
PAD_TOKEN_ID = 1
EOS_TOKEN = "</s>"
EOS_TOKEN_ID = 2
UNK_TOKEN = "<unk>"
UNK_TOKEN_ID = 3
|
|
|
|
|
|
def gen_vocab(
    input_path: Path,
    output_path_prefix: Path,
    model_type="bpe",
    vocab_size=1000,
    special_symbols: Optional[List[str]] = None,
):
    """Train a SentencePiece model and export its vocabulary as a fairseq dictionary.

    The trainer is asked to place ``<unk>``/``<s>``/``</s>``/``<pad>`` at the IDs
    given by the module-level ``*_TOKEN_ID`` constants. The trained model is then
    loaded from ``<output_path_prefix>.model``. Every piece except those four is
    written, in ID order, to ``<output_path_prefix>.txt`` as ``"<piece> 1"``.

    Args:
        input_path: training text file for SentencePiece.
        output_path_prefix: path prefix for the trained model and the dictionary.
        model_type: value for SentencePiece ``--model_type`` (default ``"bpe"``).
        vocab_size: value for SentencePiece ``--vocab_size``.
        special_symbols: optional extra pieces passed as ``--user_defined_symbols``.

    Raises:
        AssertionError: if the trained model does not map the four special
            tokens to their expected IDs.
    """
    # Train SentencePiece Model
    arguments = [
        f"--input={input_path.as_posix()}",
        f"--model_prefix={output_path_prefix.as_posix()}",
        f"--model_type={model_type}",
        f"--vocab_size={vocab_size}",
        "--character_coverage=1.0",
        f"--num_threads={cpu_count()}",
        f"--unk_id={UNK_TOKEN_ID}",
        f"--bos_id={BOS_TOKEN_ID}",
        f"--eos_id={EOS_TOKEN_ID}",
        f"--pad_id={PAD_TOKEN_ID}",
    ]
    if special_symbols is not None:
        _special_symbols = ",".join(special_symbols)
        arguments.append(f"--user_defined_symbols={_special_symbols}")
    # NOTE(review): the options are joined into one whitespace-separated string,
    # so paths containing spaces are presumably not parsed correctly — confirm.
    sp.SentencePieceTrainer.Train(" ".join(arguments))

    # Export fairseq dictionary
    spm = sp.SentencePieceProcessor()
    spm.Load(output_path_prefix.as_posix() + ".model")
    vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
    assert (
        vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
        and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
        and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
        and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
    ), "SentencePiece special-token IDs do not match the expected layout"

    # The four special tokens are left out of the exported dictionary
    # (presumably because the fairseq Dictionary reserves them itself — confirm).
    special_tokens = {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
    # Pieces may be arbitrary Unicode (character_coverage=1.0), so write UTF-8
    # explicitly instead of depending on the platform's default encoding.
    with open(
        output_path_prefix.as_posix() + ".txt", "w", encoding="utf-8"
    ) as f_out:
        for _, piece in sorted(vocab.items()):
            if piece not in special_tokens:
                f_out.write(f"{piece} 1\n")
|
|
|
|
|
|
def extract_fbank_features(
    waveform: torch.FloatTensor,
    sample_rate: int,
    output_path: Optional[Path] = None,
    n_mel_bins: int = 80,
    overwrite: bool = False,
):
    """Compute filterbank features with ``n_mel_bins`` bins for ``waveform``.

    The waveform is first converted to mono and scaled to the 16-bit signed
    integer range. The pyKaldi extractor is tried first and the torchaudio one
    second.

    Args:
        waveform: input audio tensor.
        sample_rate: sampling rate of ``waveform``.
        output_path: if given, the features are also saved there with ``np.save``.
        n_mel_bins: number of filterbank bins.
        overwrite: recompute even if ``output_path`` already exists.

    Returns:
        The feature array, or ``None`` when ``output_path`` already exists and
        ``overwrite`` is False (nothing is computed in that case).

    Raises:
        ImportError: if neither extractor produced features.
    """
    already_cached = (
        output_path is not None and output_path.is_file() and not overwrite
    )
    if already_cached:
        return

    mono, _ = convert_waveform(waveform, sample_rate, to_mono=True)
    # Kaldi compliance: 16-bit signed integers
    samples = (mono * (2 ** 15)).numpy()

    # Each helper returns None when its backend is unavailable; take the first
    # one that produces features.
    for extractor in (_get_kaldi_fbank, _get_torchaudio_fbank):
        features = extractor(samples, sample_rate, n_mel_bins)
        if features is not None:
            break
    else:
        raise ImportError(
            "Please install pyKaldi or torchaudio to enable fbank feature extraction"
        )

    if output_path is not None:
        np.save(output_path.as_posix(), features)
    return features
|
|
|
|
|
|
def create_zip(data_root: Path, zip_path: Path):
    """Pack the ``*.npy`` and ``*.flac`` files directly under ``data_root`` into
    an uncompressed zip archive at ``zip_path``.

    Members are stored under their bare file names with ``ZIP_STORED``. No
    compression keeps each member's raw bytes contiguous in the archive, which
    is what offset-based readers such as ``get_zip_manifest`` rely on.
    """
    members = [*data_root.glob("*.npy"), *data_root.glob("*.flac")]
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as archive:
        for member in tqdm(members):
            archive.write(member, arcname=member.name)
|
|
|
|
|
|
def _get_zip_entry_data_offset(f, zip_info) -> int:
    """Return the absolute byte offset of ``zip_info``'s data in the archive ``f``.

    ``f`` must be the archive opened in binary mode. A local file header is 30
    fixed bytes, followed by the file name and an extra field. Their byte lengths
    are stored as little-endian uint16 at header offsets 26 and 28, and both are
    read from the file itself. ``len(zip_info.filename)`` counts characters, not
    encoded bytes, and the local extra field need not match the central
    directory's.
    """
    import struct

    f.seek(zip_info.header_offset)
    header = f.read(30)
    assert len(header) == 30 and header[:4] == b"PK\x03\x04", (
        f"invalid local file header for zip entry {zip_info.filename!r}"
    )
    name_len, extra_len = struct.unpack("<HH", header[26:30])
    return zip_info.header_offset + 30 + name_len + extra_len


def get_zip_manifest(
    zip_path: Path, zip_root: Optional[Path] = None, is_audio=False
):
    """Index an uncompressed zip archive (e.g. one built by ``create_zip``).

    Args:
        zip_path: archive path. It is recorded verbatim (POSIX form) in the
            returned paths.
        zip_root: optional directory that ``zip_path`` is joined onto when
            opening the archive.
        is_audio: if True, members are checked as soundfile-readable audio and
            their length is the number of frames. Otherwise they are checked as
            ``.npy`` data and their length is the size of the first axis.

    Returns:
        ``(paths, lengths)``: two dicts keyed by each member's file stem.
        ``paths`` maps to ``"<zip_path>:<byte offset>:<byte size>"``;
        ``lengths`` maps to the length described above.

    Raises:
        AssertionError: if a member is compressed, has a malformed local
            header, or does not look like the expected data type.
    """
    _zip_path = Path.joinpath(zip_root or Path(""), zip_path)
    with zipfile.ZipFile(_zip_path, mode="r") as zf:
        info = zf.infolist()

    paths, lengths = {}, {}
    # Open the archive once for raw reads instead of reopening it per member.
    with open(_zip_path, "rb") as f:
        for i in tqdm(info):
            # offset:size slicing only yields the member's bytes if they are
            # stored uncompressed.
            assert i.compress_type == zipfile.ZIP_STORED, i
            utt_id = Path(i.filename).stem
            offset = _get_zip_entry_data_offset(f, i)
            file_size = i.file_size
            paths[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}"

            f.seek(offset)
            byte_data = f.read(file_size)
            assert len(byte_data) > 1
            if is_audio:
                assert is_sf_audio_data(byte_data), i
                lengths[utt_id] = sf.info(io.BytesIO(byte_data)).frames
            else:
                assert is_npy_data(byte_data), i
                lengths[utt_id] = np.load(io.BytesIO(byte_data)).shape[0]
    return paths, lengths
|
|
|
|
|
|
def gen_config_yaml(
    manifest_root: Path,
    spm_filename: Optional[str] = None,
    vocab_name: Optional[str] = None,
    yaml_filename: str = "config.yaml",
    specaugment_policy: Optional[str] = "lb",
    prepend_tgt_lang_tag: bool = False,
    sampling_alpha: Optional[float] = None,
    input_channels: Optional[int] = 1,
    input_feat_per_channel: Optional[int] = 80,
    audio_root: str = "",
    cmvn_type: str = "utterance",
    gcmvn_path: Optional[Union[str, Path]] = None,
    extra=None,
):
    """Write an S2T data config to ``manifest_root / yaml_filename``.

    Args:
        manifest_root: output directory. It is made absolute, and the
            SentencePiece model path is written as an absolute path under it.
        spm_filename: SentencePiece model file name under ``manifest_root``.
            When given, a sentencepiece ``bpe_tokenizer`` entry is written.
        vocab_name: dictionary file name. Defaults to ``spm_filename`` with
            ``.model`` replaced by ``.txt``. At least one of ``spm_filename`` /
            ``vocab_name`` is required.
        yaml_filename: output file name inside ``manifest_root``.
        specaugment_policy: one of "lb", "ld", "sm", "ss", or None to disable
            SpecAugment on the ``_train`` split.
        prepend_tgt_lang_tag: if True, ``prepend_tgt_lang_tag: true`` is written.
        sampling_alpha: written when not None.
        input_channels: written when not None.
        input_feat_per_channel: written when not None.
        audio_root: written when non-empty (None is treated as empty).
        cmvn_type: "utterance" or "global"; selects the ``<cmvn_type>_cmvn``
            feature transform.
        gcmvn_path: global CMVN stats file (``str`` or ``Path``). Required when
            ``cmvn_type == "global"``.
        extra: optional dict merged into the top level of the config last.

    Raises:
        AssertionError: if neither ``spm_filename`` nor ``vocab_name`` is given.
        NotImplementedError: if ``cmvn_type`` is not "global" or "utterance".
        ValueError: if ``specaugment_policy`` is not None and not a known policy,
            or if ``cmvn_type`` is "global" and ``gcmvn_path`` is None.
    """
    # Reject bad arguments before building any configuration; nothing is
    # written to disk until writer.flush() at the end.
    assert spm_filename is not None or vocab_name is not None, (
        "Either spm_filename or vocab_name must be provided"
    )
    if cmvn_type not in ["global", "utterance"]:
        raise NotImplementedError
    if cmvn_type == "global" and gcmvn_path is None:
        raise ValueError("Please provide path of global cmvn file.")

    manifest_root = manifest_root.absolute()
    writer = S2TDataConfigWriter(manifest_root / yaml_filename)
    specaugment_setters = {
        "lb": writer.set_specaugment_lb_policy,
        "ld": writer.set_specaugment_ld_policy,
        "sm": writer.set_specaugment_sm_policy,
        "ss": writer.set_specaugment_ss_policy,
    }
    # An unknown policy used to be skipped silently while "specaugment" was
    # still added to the _train transforms below. Fail loudly instead.
    if (
        specaugment_policy is not None
        and specaugment_policy not in specaugment_setters
    ):
        raise ValueError(
            f"Unknown specaugment_policy {specaugment_policy!r}; expected one "
            f"of {sorted(specaugment_setters)} or None"
        )

    if vocab_name is None:
        vocab_name = spm_filename.replace(".model", ".txt")
    writer.set_vocab_filename(vocab_name)
    if input_channels is not None:
        writer.set_input_channels(input_channels)
    if input_feat_per_channel is not None:
        writer.set_input_feat_per_channel(input_feat_per_channel)
    if specaugment_policy is not None:
        specaugment_setters[specaugment_policy]()
    if spm_filename is not None:
        writer.set_bpe_tokenizer(
            {
                "bpe": "sentencepiece",
                "sentencepiece_model": (manifest_root / spm_filename).as_posix(),
            }
        )
    if prepend_tgt_lang_tag:
        writer.set_prepend_tgt_lang_tag(True)
    if sampling_alpha is not None:
        writer.set_sampling_alpha(sampling_alpha)

    if specaugment_policy is not None:
        writer.set_feature_transforms(
            "_train", [f"{cmvn_type}_cmvn", "specaugment"]
        )
    writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"])

    if cmvn_type == "global":
        writer.set_global_cmvn(Path(gcmvn_path).as_posix())

    if audio_root:
        writer.set_audio_root(audio_root)

    if extra is not None:
        writer.set_extra(extra)
    writer.flush()
|
|
|
|
|
|
def load_df_from_tsv(path: Union[str, Path]) -> pd.DataFrame:
    """Read a UTF-8, tab-separated file whose first row is a header.

    Quoting is disabled and backslash is the escape character, the same
    settings ``save_df_to_tsv`` writes with. NA detection is turned off, so
    empty cells come back as empty strings rather than NaN.
    """
    if not isinstance(path, str):
        path = path.as_posix()
    read_options = dict(
        sep="\t",
        header=0,
        encoding="utf-8",
        escapechar="\\",
        quoting=csv.QUOTE_NONE,
        na_filter=False,
    )
    return pd.read_csv(path, **read_options)
|
|
|
|
|
|
def save_df_to_tsv(dataframe, path: Union[str, Path]):
    """Write ``dataframe`` as UTF-8 TSV: header row, no index column.

    Quoting is disabled and backslash is the escape character, so the file can
    be read back with the matching options in ``load_df_from_tsv``.
    """
    target = path.as_posix() if not isinstance(path, str) else path
    dataframe.to_csv(
        target,
        sep="\t",
        quoting=csv.QUOTE_NONE,
        escapechar="\\",
        encoding="utf-8",
        header=True,
        index=False,
    )
|
|
|
|
|
|
def load_tsv_to_dicts(path: Union[str, Path]) -> List[dict]:
    """Read a tab-separated file with a header row into a list of row dicts.

    Quoting is disabled, so quote characters are kept literally in the values.
    Each row becomes a plain ``dict`` mapping header names to string values.
    """
    # UTF-8 matches load_df_from_tsv/save_df_to_tsv instead of depending on the
    # platform default encoding. newline="" lets the csv module handle line
    # endings itself, as the csv documentation recommends.
    with open(path, "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(
            f,
            delimiter="\t",
            quotechar=None,
            doublequote=False,
            lineterminator="\n",
            quoting=csv.QUOTE_NONE,
        )
        rows = [dict(e) for e in reader]
    return rows
|
|
|
|
|
|
def filter_manifest_df(
    df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
    """Drop manifest rows matched by any filter and print per-filter counts.

    Built-in filters flag rows with an empty ``audio`` value, fewer than
    ``min_n_frames`` frames, or an empty ``tgt_text``. Rows longer than
    ``max_n_frames`` are also flagged, but only when ``is_train_split`` is True.
    ``extra_filters`` (name -> boolean mask) is merged in after the built-ins
    and overrides any built-in with the same name.

    Returns:
        The rows of ``df`` not flagged by any filter.
    """
    filters = {}
    filters["no speech"] = df["audio"] == ""
    filters[f"short speech (<{min_n_frames} frames)"] = df["n_frames"] < min_n_frames
    filters["empty sentence"] = df["tgt_text"] == ""
    if is_train_split:
        filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
    if extra_filters is not None:
        filters.update(extra_filters)

    masks = list(filters.values())
    invalid = masks[0]
    for mask in masks[1:]:
        invalid = invalid | mask
    valid = ~invalid

    counts = ", ".join(f"{name}: {mask.sum()}" for name, mask in filters.items())
    print(f"| {counts}, total {invalid.sum()} filtered, {valid.sum()} remained.")
    return df[valid]
|
|
|
|
|
|
def cal_gcmvn_stats(features_list):
    """Compute global mean and standard deviation statistics for CMVN.

    Args:
        features_list: non-empty sequence of arrays that are concatenated along
            axis 0 (presumably ``(n_frames, n_features)`` each — confirm).
            Statistics are computed per position on the remaining axes.

    Returns:
        ``{"mean": ..., "std": ...}`` as float32 arrays. ``std`` is the
        population standard deviation, floored at ``sqrt(1e-8) = 1e-4``.
    """
    features = np.concatenate(features_list)
    # Accumulate in float64 and use the two-pass variance (mean squared
    # deviation) rather than E[x^2] - E[x]^2. The latter cancels
    # catastrophically in float32 when |mean| is large relative to std.
    mean = features.mean(axis=0, dtype=np.float64)
    var = features.var(axis=0, dtype=np.float64)
    std = np.sqrt(np.maximum(var, 1e-8))
    return {"mean": mean.astype("float32"), "std": std.astype("float32")}
|
|
|
|
|
|
class S2TDataConfigWriter(object):
    """Collect S2T data-config settings in ``self.config`` and dump them as YAML.

    Each ``set_*`` method records one entry in the config dict. Nothing is
    written to disk until ``flush()`` is called.
    """

    DEFAULT_VOCAB_FILENAME = "dict.txt"
    DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
    DEFAULT_INPUT_CHANNELS = 1

    def __init__(self, yaml_path: Path):
        """Prepare an empty config that ``flush()`` will write to ``yaml_path``.

        Raises:
            ImportError: if PyYAML is not installed.
        """
        try:
            import yaml
        except ImportError as err:
            # Fail here with an actionable message; without PyYAML the writer
            # cannot flush anything.
            raise ImportError(
                "Please install PyYAML for S2T data config YAML files"
            ) from err
        self.yaml = yaml
        self.yaml_path = yaml_path
        self.config = {}

    def flush(self):
        """Write the accumulated config to ``self.yaml_path``."""
        with open(self.yaml_path, "w", encoding="utf-8") as f:
            self.yaml.dump(self.config, f)

    def set_audio_root(self, audio_root=""):
        self.config["audio_root"] = audio_root

    def set_vocab_filename(self, vocab_filename: str = "dict.txt"):
        self.config["vocab_filename"] = vocab_filename

    def set_specaugment(
        self,
        time_wrap_w: int,
        freq_mask_n: int,
        freq_mask_f: int,
        time_mask_n: int,
        time_mask_t: int,
        time_mask_p: float,
    ):
        """Store SpecAugment parameters under the ``specaugment`` key."""
        # NOTE(review): the key is spelled "time_wrap_W" (time *warp* is the
        # usual term). Confirm which spelling the config reader expects before
        # changing it; all presets below set it to 0.
        self.config["specaugment"] = {
            "time_wrap_W": time_wrap_w,
            "freq_mask_N": freq_mask_n,
            "freq_mask_F": freq_mask_f,
            "time_mask_N": time_mask_n,
            "time_mask_T": time_mask_t,
            "time_mask_p": time_mask_p,
        }

    def set_specaugment_lb_policy(self):
        """SpecAugment preset "lb": 1 freq mask (F=27), 1 time mask (T=100, p=1.0)."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=1,
            freq_mask_f=27,
            time_mask_n=1,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_specaugment_ld_policy(self):
        """SpecAugment preset "ld": 2 freq masks (F=27), 2 time masks (T=100, p=1.0)."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=27,
            time_mask_n=2,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_specaugment_sm_policy(self):
        """SpecAugment preset "sm": 2 freq masks (F=15), 2 time masks (T=70, p=0.2)."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=15,
            time_mask_n=2,
            time_mask_t=70,
            time_mask_p=0.2,
        )

    def set_specaugment_ss_policy(self):
        """SpecAugment preset "ss": 2 freq masks (F=27), 2 time masks (T=70, p=0.2)."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=27,
            time_mask_n=2,
            time_mask_t=70,
            time_mask_p=0.2,
        )

    def set_input_channels(self, input_channels: int = 1):
        self.config["input_channels"] = input_channels

    def set_input_feat_per_channel(self, input_feat_per_channel: int = 80):
        self.config["input_feat_per_channel"] = input_feat_per_channel

    def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
        self.config["bpe_tokenizer"] = bpe_tokenizer

    def set_global_cmvn(self, stats_npz_path: str):
        self.config["global_cmvn"] = {"stats_npz_path": stats_npz_path}

    def set_feature_transforms(self, split: str, transforms: List[str]):
        """Set the transform list for ``split``, keeping other splits' entries."""
        if "transforms" not in self.config:
            self.config["transforms"] = {}
        self.config["transforms"][split] = transforms

    def set_prepend_tgt_lang_tag(self, flag: bool = True):
        self.config["prepend_tgt_lang_tag"] = flag

    def set_sampling_alpha(self, sampling_alpha: float = 1.0):
        self.config["sampling_alpha"] = sampling_alpha

    def set_extra(self, data):
        """Merge ``data`` into the top level of the config (overwrites keys)."""
        self.config.update(data)
|