mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-10-26 17:32:57 +03:00)

commit 0ac3f3270c (parent 32b31173aa)

add TTS

Summary: [fairseq-py] add TTS

Reviewed By: wnhsu

Differential Revision: D30720666

fbshipit-source-id: b5288acec72bea1d3a9f3884a4ed51b616c7a403
320 examples/speech_synthesis/data_utils.py Normal file
@@ -0,0 +1,320 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from pathlib import Path
from typing import Optional, List, Dict
import zipfile
import tempfile
from dataclasses import dataclass
from itertools import groupby

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from examples.speech_to_text.data_utils import load_tsv_to_dicts
from fairseq.data.audio.audio_utils import TTSSpectrogram, TTSMelScale


def trim_or_pad_to_target_length(
    data_1d_or_2d: np.ndarray, target_length: int
) -> np.ndarray:
    assert len(data_1d_or_2d.shape) in {1, 2}
    delta = data_1d_or_2d.shape[0] - target_length
    if delta >= 0:  # trim if longer
        data_1d_or_2d = data_1d_or_2d[:target_length]
    else:  # pad if shorter
        if len(data_1d_or_2d.shape) == 1:
            data_1d_or_2d = np.concatenate(
                [data_1d_or_2d, np.zeros(-delta)], axis=0
            )
        else:
            data_1d_or_2d = np.concatenate(
                [data_1d_or_2d, np.zeros((-delta, data_1d_or_2d.shape[1]))],
                axis=0
            )
    return data_1d_or_2d
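For example, the trim/pad behavior can be checked directly (an illustrative sketch, not part of the commit):

    import numpy as np
    x = np.ones(5)
    assert trim_or_pad_to_target_length(x, 3).shape == (3,)   # trimmed to length 3
    assert trim_or_pad_to_target_length(x, 8).shape == (8,)   # padded to length 8
    assert trim_or_pad_to_target_length(x, 8)[5:].sum() == 0  # padding is zeros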
def extract_logmel_spectrogram(
    waveform: torch.Tensor, sample_rate: int,
    output_path: Optional[Path] = None, win_length: int = 1024,
    hop_length: int = 256, n_fft: int = 1024,
    win_fn: callable = torch.hann_window, n_mels: int = 80,
    f_min: float = 0., f_max: float = 8000, eps: float = 1e-5,
    overwrite: bool = False, target_length: Optional[int] = None
):
    if output_path is not None and output_path.is_file() and not overwrite:
        return

    spectrogram_transform = TTSSpectrogram(
        n_fft=n_fft, win_length=win_length, hop_length=hop_length,
        window_fn=win_fn
    )
    mel_scale_transform = TTSMelScale(
        n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
        n_stft=n_fft // 2 + 1
    )
    spectrogram = spectrogram_transform(waveform)
    mel_spec = mel_scale_transform(spectrogram)
    logmel_spec = torch.clamp(mel_spec, min=eps).log()
    assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1
    logmel_spec = logmel_spec.squeeze().t()  # D x T -> T x D
    if target_length is not None:
        # keep the result (the original discarded the trimmed/padded array)
        logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length)

    if output_path is not None:
        np.save(output_path.as_posix(), logmel_spec)
    else:
        return logmel_spec


def extract_pitch(
    waveform: torch.Tensor, sample_rate: int,
    output_path: Optional[Path] = None, hop_length: int = 256,
    log_scale: bool = True, phoneme_durations: Optional[List[int]] = None
):
    if output_path is not None and output_path.is_file():
        return

    try:
        import pyworld
    except ImportError:
        raise ImportError("Please install PyWORLD: pip install pyworld")

    _waveform = waveform.squeeze(0).double().numpy()
    pitch, t = pyworld.dio(
        _waveform, sample_rate, frame_period=hop_length / sample_rate * 1000
    )
    pitch = pyworld.stonemask(_waveform, pitch, t, sample_rate)

    if phoneme_durations is not None:
        pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations))
        try:
            from scipy.interpolate import interp1d
        except ImportError:
            raise ImportError("Please install SciPy: pip install scipy")
        nonzero_ids = np.where(pitch != 0)[0]
        interp_fn = interp1d(
            nonzero_ids,
            pitch[nonzero_ids],
            fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
            bounds_error=False,
        )
        pitch = interp_fn(np.arange(0, len(pitch)))
        d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
        pitch = np.array(
            [
                np.mean(pitch[d_cumsum[i - 1]: d_cumsum[i]])
                for i in range(1, len(d_cumsum))
            ]
        )
        assert len(pitch) == len(phoneme_durations)

    if log_scale:
        pitch = np.log(pitch + 1)

    if output_path is not None:
        np.save(output_path.as_posix(), pitch)
    else:
        return pitch


def extract_energy(
    waveform: torch.Tensor, output_path: Optional[Path] = None,
    hop_length: int = 256, n_fft: int = 1024, log_scale: bool = True,
    phoneme_durations: Optional[List[int]] = None
):
    if output_path is not None and output_path.is_file():
        return

    assert len(waveform.shape) == 2 and waveform.shape[0] == 1
    waveform = waveform.view(1, 1, waveform.shape[1])
    waveform = F.pad(
        waveform.unsqueeze(1), [n_fft // 2, n_fft // 2, 0, 0],
        mode="reflect"
    )
    waveform = waveform.squeeze(1)

    fourier_basis = np.fft.fft(np.eye(n_fft))
    cutoff = int((n_fft / 2 + 1))
    fourier_basis = np.vstack(
        [np.real(fourier_basis[:cutoff, :]),
         np.imag(fourier_basis[:cutoff, :])]
    )

    forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
    forward_transform = F.conv1d(
        waveform, forward_basis, stride=hop_length, padding=0
    )

    real_part = forward_transform[:, :cutoff, :]
    imag_part = forward_transform[:, cutoff:, :]
    magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
    energy = torch.norm(magnitude, dim=1).squeeze(0).numpy()

    if phoneme_durations is not None:
        energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations))
        d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
        energy = np.array(
            [
                np.mean(energy[d_cumsum[i - 1]: d_cumsum[i]])
                for i in range(1, len(d_cumsum))
            ]
        )
        assert len(energy) == len(phoneme_durations)

    if log_scale:
        energy = np.log(energy + 1)

    if output_path is not None:
        np.save(output_path.as_posix(), energy)
    else:
        return energy


def get_global_cmvn(feature_root: Path, output_path: Optional[Path] = None):
    mean_x, mean_x2, n_frames = None, None, 0
    feature_paths = feature_root.glob("*.npy")
    for p in tqdm(feature_paths):
        with open(p, 'rb') as f:
            frames = np.load(f).squeeze()

        n_frames += frames.shape[0]

        cur_mean_x = frames.sum(axis=0)
        if mean_x is None:
            mean_x = cur_mean_x
        else:
            mean_x += cur_mean_x

        cur_mean_x2 = (frames ** 2).sum(axis=0)
        if mean_x2 is None:
            mean_x2 = cur_mean_x2
        else:
            mean_x2 += cur_mean_x2

    mean_x /= n_frames
    mean_x2 /= n_frames
    var_x = mean_x2 - mean_x ** 2
    std_x = np.sqrt(np.maximum(var_x, 1e-10))

    if output_path is not None:
        with open(output_path, 'wb') as f:
            np.savez(f, mean=mean_x, std=std_x)
    else:
        return {"mean": mean_x, "std": std_x}


def ipa_phonemize(text, lang="en-us", use_g2p=False):
    if use_g2p:
        assert lang == "en-us", "g2pE phonemizer only works for en-us"
        try:
            from g2p_en import G2p
            g2p = G2p()
            return " ".join("|" if p == " " else p for p in g2p(text))
        except ImportError:
            raise ImportError(
                "Please install g2pE phonemizer: pip install g2p_en"
            )
    else:
        try:
            from phonemizer import phonemize
            from phonemizer.separator import Separator
            return phonemize(
                text, backend='espeak', language=lang,
                separator=Separator(word="| ", phone=" ")
            )
        except ImportError:
            raise ImportError(
                "Please install phonemizer: pip install phonemizer"
            )


@dataclass
class ForceAlignmentInfo(object):
    tokens: List[str]
    frame_durations: List[int]
    start_sec: Optional[float]
    end_sec: Optional[float]


def get_mfa_alignment_by_sample_id(
    textgrid_zip_path: str, sample_id: str, sample_rate: int,
    hop_length: int, silence_phones: List[str] = ("sil", "sp", "spn")
) -> ForceAlignmentInfo:
    try:
        import tgt
    except ImportError:
        raise ImportError("Please install TextGridTools: pip install tgt")

    filename = f"{sample_id}.TextGrid"
    out_root = Path(tempfile.gettempdir())
    tgt_path = out_root / filename
    with zipfile.ZipFile(textgrid_zip_path) as f_zip:
        f_zip.extract(filename, path=out_root)
    textgrid = tgt.io.read_textgrid(tgt_path.as_posix())
    os.remove(tgt_path)

    phones, frame_durations = [], []
    start_sec, end_sec, end_idx = 0, 0, 0
    for t in textgrid.get_tier_by_name("phones")._objects:
        s, e, p = t.start_time, t.end_time, t.text
        # Trim leading silences
        if len(phones) == 0:
            if p in silence_phones:
                continue
            else:
                start_sec = s
        phones.append(p)
        if p not in silence_phones:
            end_sec = e
            end_idx = len(phones)
        r = sample_rate / hop_length
        frame_durations.append(int(np.round(e * r) - np.round(s * r)))
    # Trim trailing silences
    phones = phones[:end_idx]
    frame_durations = frame_durations[:end_idx]

    return ForceAlignmentInfo(
        tokens=phones, frame_durations=frame_durations, start_sec=start_sec,
        end_sec=end_sec
    )


def get_mfa_alignment(
    textgrid_zip_path: str, sample_ids: List[str], sample_rate: int,
    hop_length: int
) -> Dict[str, ForceAlignmentInfo]:
    return {
        i: get_mfa_alignment_by_sample_id(
            textgrid_zip_path, i, sample_rate, hop_length
        ) for i in tqdm(sample_ids)
    }


def get_unit_alignment(
    id_to_unit_tsv_path: str, sample_ids: List[str]
) -> Dict[str, ForceAlignmentInfo]:
    id_to_units = {
        e["id"]: e["units"] for e in load_tsv_to_dicts(id_to_unit_tsv_path)
    }
    id_to_units = {i: id_to_units[i].split() for i in sample_ids}
    id_to_units_collapsed = {
        i: [uu for uu, _ in groupby(u)] for i, u in id_to_units.items()
    }
    id_to_durations = {
        i: [len(list(g)) for _, g in groupby(u)] for i, u in id_to_units.items()
    }

    return {
        i: ForceAlignmentInfo(
            tokens=id_to_units_collapsed[i], frame_durations=id_to_durations[i],
            start_sec=None, end_sec=None
        )
        for i in sample_ids
    }
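A minimal sketch of how these helpers chain together for feature extraction. The audio path is a placeholder and torchaudio.load is assumed for reading the waveform; extract_pitch additionally requires pyworld:

    import torchaudio

    waveform, sample_rate = torchaudio.load("sample.wav")      # placeholder path
    logmel = extract_logmel_spectrogram(waveform, sample_rate)  # T x D log-mel
    pitch = extract_pitch(waveform, sample_rate)                # per-frame F0
    energy = extract_energy(waveform)                           # per-frame energy
    phones = ipa_phonemize("Hello world", lang="en-us", use_g2p=True)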
191 examples/speech_synthesis/generate_waveform.py Normal file
@@ -0,0 +1,191 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import logging
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import soundfile as sf
import sys
import torch
import torchaudio

from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.logging import progress_bar
from fairseq.tasks.text_to_speech import plot_tts_output
from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDataset


logging.basicConfig()
logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def make_parser():
    parser = options.get_speech_generation_parser()
    parser.add_argument("--dump-features", action="store_true")
    parser.add_argument("--dump-waveforms", action="store_true")
    parser.add_argument("--dump-attentions", action="store_true")
    parser.add_argument("--dump-eos-probs", action="store_true")
    parser.add_argument("--dump-plots", action="store_true")
    parser.add_argument("--dump-target", action="store_true")
    parser.add_argument("--output-sample-rate", default=22050, type=int)
    parser.add_argument("--teacher-forcing", action="store_true")
    parser.add_argument(
        "--audio-format", type=str, default="wav", choices=["wav", "flac"]
    )
    return parser


def postprocess_results(
    dataset: TextToSpeechDataset, sample, hypos, resample_fn, dump_target
):
    def to_np(x):
        return None if x is None else x.detach().cpu().numpy()

    sample_ids = [dataset.ids[i] for i in sample["id"].tolist()]
    texts = sample["src_texts"]
    attns = [to_np(hypo["attn"]) for hypo in hypos]
    eos_probs = [to_np(hypo.get("eos_prob", None)) for hypo in hypos]
    feat_preds = [to_np(hypo["feature"]) for hypo in hypos]
    wave_preds = [to_np(resample_fn(h["waveform"])) for h in hypos]
    if dump_target:
        feat_targs = [to_np(hypo["targ_feature"]) for hypo in hypos]
        wave_targs = [to_np(resample_fn(h["targ_waveform"])) for h in hypos]
    else:
        feat_targs = [None for _ in hypos]
        wave_targs = [None for _ in hypos]

    return zip(sample_ids, texts, attns, eos_probs, feat_preds, wave_preds,
               feat_targs, wave_targs)


def dump_result(
    is_na_model,
    args,
    vocoder,
    sample_id,
    text,
    attn,
    eos_prob,
    feat_pred,
    wave_pred,
    feat_targ,
    wave_targ,
):
    sample_rate = args.output_sample_rate
    out_root = Path(args.results_path)
    if args.dump_features:
        feat_dir = out_root / "feat"
        feat_dir.mkdir(exist_ok=True, parents=True)
        np.save(feat_dir / f"{sample_id}.npy", feat_pred)
        if args.dump_target:
            feat_tgt_dir = out_root / "feat_tgt"
            feat_tgt_dir.mkdir(exist_ok=True, parents=True)
            np.save(feat_tgt_dir / f"{sample_id}.npy", feat_targ)
    if args.dump_attentions:
        attn_dir = out_root / "attn"
        attn_dir.mkdir(exist_ok=True, parents=True)
        np.save(attn_dir / f"{sample_id}.npy", attn.numpy())
    if args.dump_eos_probs and not is_na_model:
        eos_dir = out_root / "eos"
        eos_dir.mkdir(exist_ok=True, parents=True)
        np.save(eos_dir / f"{sample_id}.npy", eos_prob)

    if args.dump_plots:
        images = [feat_pred.T] if is_na_model else [feat_pred.T, attn]
        names = ["output"] if is_na_model else ["output", "alignment"]
        if feat_targ is not None:
            images = [feat_targ.T] + images
            names = [f"target (idx={sample_id})"] + names
        if is_na_model:
            plot_tts_output(images, names, attn, "alignment", suptitle=text)
        else:
            plot_tts_output(images, names, eos_prob, "eos prob", suptitle=text)
        plot_dir = out_root / "plot"
        plot_dir.mkdir(exist_ok=True, parents=True)
        plt.savefig(plot_dir / f"{sample_id}.png")
        plt.close()

    if args.dump_waveforms:
        ext = args.audio_format
        if wave_pred is not None:
            wav_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}"
            wav_dir.mkdir(exist_ok=True, parents=True)
            sf.write(wav_dir / f"{sample_id}.{ext}", wave_pred, sample_rate)
        if args.dump_target and wave_targ is not None:
            wav_tgt_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}_tgt"
            wav_tgt_dir.mkdir(exist_ok=True, parents=True)
            sf.write(wav_tgt_dir / f"{sample_id}.{ext}", wave_targ, sample_rate)


def main(args):
    assert(args.dump_features or args.dump_waveforms or args.dump_attentions
           or args.dump_eos_probs or args.dump_plots)
    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 8000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    task = tasks.setup_task(args)
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        task=task,
    )
    model = models[0].cuda() if use_cuda else models[0]
    # use the original n_frames_per_step
    task.args.n_frames_per_step = saved_cfg.task.n_frames_per_step
    task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    data_cfg = task.data_cfg
    sample_rate = data_cfg.config.get("features", {}).get("sample_rate", 22050)
    resample_fn = {
        False: lambda x: x,
        True: lambda x: torchaudio.sox_effects.apply_effects_tensor(
            x.detach().cpu().unsqueeze(0), sample_rate,
            [['rate', str(args.output_sample_rate)]]
        )[0].squeeze(0)
    }.get(args.output_sample_rate != sample_rate)
    if args.output_sample_rate != sample_rate:
        logger.info(f"resampling to {args.output_sample_rate}Hz")

    generator = task.build_generator([model], args)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=(sys.maxsize, sys.maxsize),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)

    Path(args.results_path).mkdir(exist_ok=True, parents=True)
    is_na_model = getattr(model, "NON_AUTOREGRESSIVE", False)
    dataset = task.dataset(args.gen_subset)
    vocoder = task.args.vocoder
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            hypos = generator.generate(model, sample, has_targ=args.dump_target)
            for result in postprocess_results(
                    dataset, sample, hypos, resample_fn, args.dump_target
            ):
                dump_result(is_na_model, args, vocoder, *result)


def cli_main():
    parser = make_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
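A hypothetical invocation (the positional data root and --config-yaml follow fairseq's standard speech generation parser; the paths and token budget are placeholders):

    python -m examples.speech_synthesis.generate_waveform ${FEATURE_MANIFEST_ROOT} \
      --config-yaml config.yaml --task text_to_speech --gen-subset test \
      --path ${SAVE_DIR}/checkpoint_best.pt --max-tokens 8000 \
      --results-path ${OUT_DIR} --dump-waveforms --dump-plots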
101 examples/speech_synthesis/utils.py Normal file
@@ -0,0 +1,101 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from scipy.interpolate import interp1d
import torchaudio

from fairseq.tasks.text_to_speech import (
    batch_compute_distortion, compute_rms_dist
)


def batch_mel_spectral_distortion(
    y1, y2, sr, normalize_type="path", mel_fn=None
):
    """
    https://arxiv.org/pdf/2011.03568.pdf

    Same as Mel Cepstral Distortion, but computed on log-mel spectrograms.
    """
    if mel_fn is None or mel_fn.sample_rate != sr:
        mel_fn = torchaudio.transforms.MelSpectrogram(
            sr, n_fft=int(0.05 * sr), win_length=int(0.05 * sr),
            hop_length=int(0.0125 * sr), f_min=20, n_mels=80,
            window_fn=torch.hann_window
        ).to(y1[0].device)
    offset = 1e-6
    return batch_compute_distortion(
        y1, y2, sr, lambda y: torch.log(mel_fn(y) + offset).transpose(-1, -2),
        compute_rms_dist, normalize_type
    )


# This code is based on
# "https://github.com/bastibe/MAPS-Scripts/blob/master/helper.py"
def _same_t_in_true_and_est(func):
    def new_func(true_t, true_f, est_t, est_f):
        assert type(true_t) is np.ndarray
        assert type(true_f) is np.ndarray
        assert type(est_t) is np.ndarray
        assert type(est_f) is np.ndarray

        interpolated_f = interp1d(
            est_t, est_f, bounds_error=False, kind='nearest', fill_value=0
        )(true_t)
        return func(true_t, true_f, true_t, interpolated_f)

    return new_func


@_same_t_in_true_and_est
def gross_pitch_error(true_t, true_f, est_t, est_f):
    """The relative frequency in percent of pitch estimates that are
    outside a threshold around the true pitch. Only frames that are
    considered pitched by both the ground truth and the estimator (if
    applicable) are considered.
    """

    correct_frames = _true_voiced_frames(true_t, true_f, est_t, est_f)
    gross_pitch_error_frames = _gross_pitch_error_frames(
        true_t, true_f, est_t, est_f
    )
    return np.sum(gross_pitch_error_frames) / np.sum(correct_frames)


def _gross_pitch_error_frames(true_t, true_f, est_t, est_f, eps=1e-8):
    voiced_frames = _true_voiced_frames(true_t, true_f, est_t, est_f)
    true_f_p_eps = [x + eps for x in true_f]
    pitch_error_frames = np.abs(est_f / true_f_p_eps - 1) > 0.2
    return voiced_frames & pitch_error_frames


def _true_voiced_frames(true_t, true_f, est_t, est_f):
    return (est_f != 0) & (true_f != 0)


def _voicing_decision_error_frames(true_t, true_f, est_t, est_f):
    return (est_f != 0) != (true_f != 0)


@_same_t_in_true_and_est
def f0_frame_error(true_t, true_f, est_t, est_f):
    gross_pitch_error_frames = _gross_pitch_error_frames(
        true_t, true_f, est_t, est_f
    )
    voicing_decision_error_frames = _voicing_decision_error_frames(
        true_t, true_f, est_t, est_f
    )
    return (np.sum(gross_pitch_error_frames) +
            np.sum(voicing_decision_error_frames)) / (len(true_t))


@_same_t_in_true_and_est
def voicing_decision_error(true_t, true_f, est_t, est_f):
    voicing_decision_error_frames = _voicing_decision_error_frames(
        true_t, true_f, est_t, est_f
    )
    return np.sum(voicing_decision_error_frames) / (len(true_t))
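For instance, with synthetic pitch tracks (0 marks an unvoiced frame), one estimate off by more than 20% over four mutually voiced frames gives a GPE of 0.25 (an illustrative sketch):

    import numpy as np

    t = np.arange(5, dtype=float)
    true_f = np.array([100., 100., 0., 200., 200.])
    est_f = np.array([102., 150., 0., 201., 199.])   # frame 1 is > 20% off
    print(gross_pitch_error(t, true_f, t, est_f))    # 1 bad frame / 4 voiced = 0.25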
125 fairseq/criterions/fastspeech2_loss.py Normal file
@@ -0,0 +1,125 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from typing import List, Dict, Any
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import lengths_to_mask
from fairseq.models.fairseq_model import FairseqEncoderModel


@dataclass
class FastSpeech2CriterionConfig(FairseqDataclass):
    ctc_weight: float = field(
        default=0.0, metadata={"help": "weight for CTC loss"}
    )


@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
class FastSpeech2Loss(FairseqCriterion):
    def __init__(self, task, ctc_weight):
        super().__init__(task)
        self.ctc_weight = ctc_weight

    def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
        src_tokens = sample["net_input"]["src_tokens"]
        src_lens = sample["net_input"]["src_lengths"]
        tgt_lens = sample["target_lengths"]
        _feat_out, _, log_dur_out, pitch_out, energy_out = model(
            src_tokens=src_tokens,
            src_lengths=src_lens,
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            incremental_state=None,
            target_lengths=tgt_lens,
            speaker=sample["speaker"],
            durations=sample["durations"],
            pitches=sample["pitches"],
            energies=sample["energies"]
        )

        src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
        tgt_mask = lengths_to_mask(sample["target_lengths"])

        pitches, energies = sample["pitches"], sample["energies"]
        pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
        energy_out, energies = energy_out[src_mask], energies[src_mask]

        feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
        l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)

        pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
        energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)

        log_dur_out = log_dur_out[src_mask]
        dur = sample["durations"].float()
        dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur
        log_dur = torch.log(dur + 1)[src_mask]
        dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)

        ctc_loss = torch.tensor(0.).type_as(l1_loss)
        if self.ctc_weight > 0.:
            lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
            lprobs = lprobs.transpose(0, 1)  # T x B x C
            src_mask = lengths_to_mask(src_lens)
            src_tokens_flat = src_tokens.masked_select(src_mask)
            ctc_loss = F.ctc_loss(
                lprobs, src_tokens_flat, tgt_lens, src_lens,
                reduction=reduction, zero_infinity=True
            ) * self.ctc_weight

        loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss

        sample_size = sample["nsentences"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "dur_loss": utils.item(dur_loss.data),
            "pitch_loss": utils.item(pitch_loss.data),
            "energy_loss": utils.item(energy_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
        }
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        ns = [log.get("sample_size", 0) for log in logging_outputs]
        ntot = sum(ns)
        ws = [n / (ntot + 1e-8) for n in ns]
        for key in [
            "loss", "l1_loss", "dur_loss", "pitch_loss", "energy_loss",
            "ctc_loss"
        ]:
            vals = [log.get(key, 0) for log in logging_outputs]
            val = sum(val * w for val, w in zip(vals, ws))
            metrics.log_scalar(key, val, ntot, round=3)
        metrics.log_scalar("sample_size", ntot, len(logging_outputs))

        # inference metrics
        if "targ_frames" not in logging_outputs[0]:
            return
        n = sum(log.get("targ_frames", 0) for log in logging_outputs)
        for key, new_key in [
            ("mcd_loss", "mcd_loss"),
            ("pred_frames", "pred_ratio"),
            ("nins", "ins_rate"),
            ("ndel", "del_rate"),
        ]:
            val = sum(log.get(key, 0) for log in logging_outputs)
            metrics.log_scalar(new_key, val / n, n, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        return False
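The criterion leans on lengths_to_mask to select only valid (non-padding) positions before computing each loss term; a minimal re-implementation of that helper's behavior for illustration (not the fairseq source itself):

    import torch

    def lengths_to_mask_sketch(lens: torch.Tensor) -> torch.Tensor:
        # True at valid positions, False at padding
        return torch.arange(int(lens.max())).unsqueeze(0) < lens.unsqueeze(1)

    print(lengths_to_mask_sketch(torch.tensor([2, 3])))
    # tensor([[ True,  True, False],
    #         [ True,  True,  True]])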
210 fairseq/criterions/tacotron2_loss.py Normal file
@@ -0,0 +1,210 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
from typing import Any, Dict, List
from functools import lru_cache
from dataclasses import dataclass, field

import torch
from omegaconf import II

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import lengths_to_mask
import torch.nn.functional as F


logger = logging.getLogger(__name__)


@dataclass
class Tacotron2CriterionConfig(FairseqDataclass):
    bce_pos_weight: float = field(
        default=1.0,
        metadata={"help": "weight of positive examples for BCE loss"},
    )
    n_frames_per_step: int = field(
        default=0,
        metadata={"help": "Number of frames per decoding step"},
    )
    use_guided_attention_loss: bool = field(
        default=False,
        metadata={"help": "use guided attention loss"},
    )
    guided_attention_loss_sigma: float = field(
        default=0.4,
        metadata={"help": "sigma for the guided attention loss"},
    )
    ctc_weight: float = field(
        default=0.0, metadata={"help": "weight for CTC loss"}
    )
    sentence_avg: bool = II("optimization.sentence_avg")


class GuidedAttentionLoss(torch.nn.Module):
    """
    Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
    Networks with Guided Attention (https://arxiv.org/abs/1710.08969)
    """

    def __init__(self, sigma):
        super().__init__()
        self.sigma = sigma

    @staticmethod
    @lru_cache(maxsize=8)
    def _get_weight(s_len, t_len, sigma):
        grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len))
        grid_x = grid_x.to(s_len.device)
        grid_y = grid_y.to(s_len.device)
        w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
        return 1.0 - torch.exp(-w / (2 * (sigma ** 2)))

    def _get_weights(self, src_lens, tgt_lens):
        bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
        weights = torch.zeros((bsz, max_t_len, max_s_len))
        for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
            weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len,
                                                          self.sigma)
        return weights

    @staticmethod
    def _get_masks(src_lens, tgt_lens):
        in_masks = lengths_to_mask(src_lens)
        out_masks = lengths_to_mask(tgt_lens)
        return out_masks.unsqueeze(2) & in_masks.unsqueeze(1)

    def forward(self, attn, src_lens, tgt_lens, reduction="mean"):
        weights = self._get_weights(src_lens, tgt_lens).to(attn.device)
        masks = self._get_masks(src_lens, tgt_lens).to(attn.device)
        loss = (weights * attn.transpose(1, 2)).masked_select(masks)
        loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss)
        return loss


@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
class Tacotron2Criterion(FairseqCriterion):
    def __init__(self, task, sentence_avg, n_frames_per_step,
                 use_guided_attention_loss, guided_attention_loss_sigma,
                 bce_pos_weight, ctc_weight):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.n_frames_per_step = n_frames_per_step
        self.bce_pos_weight = bce_pos_weight

        self.guided_attn = None
        if use_guided_attention_loss:
            self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma)
        self.ctc_weight = ctc_weight

    def forward(self, model, sample, reduction="mean"):
        bsz, max_len, _ = sample["target"].size()
        feat_tgt = sample["target"]
        feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
        eos_tgt = torch.arange(max_len).to(sample["target"].device)
        eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
        eos_tgt = (eos_tgt == (feat_len - 1)).float()
        src_tokens = sample["net_input"]["src_tokens"]
        src_lens = sample["net_input"]["src_lengths"]
        tgt_lens = sample["target_lengths"]

        feat_out, eos_out, extra = model(
            src_tokens=src_tokens,
            src_lengths=src_lens,
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            incremental_state=None,
            target_lengths=tgt_lens,
            speaker=sample["speaker"]
        )

        l1_loss, mse_loss, eos_loss = self.compute_loss(
            extra["feature_out"], feat_out, eos_out, feat_tgt, eos_tgt,
            tgt_lens, reduction,
        )
        attn_loss = torch.tensor(0.).type_as(l1_loss)
        if self.guided_attn is not None:
            attn_loss = self.guided_attn(extra['attn'], src_lens, tgt_lens, reduction)
        ctc_loss = torch.tensor(0.).type_as(l1_loss)
        if self.ctc_weight > 0.:
            net_output = (feat_out, eos_out, extra)
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.transpose(0, 1)  # T x B x C
            src_mask = lengths_to_mask(src_lens)
            src_tokens_flat = src_tokens.masked_select(src_mask)
            ctc_loss = F.ctc_loss(
                lprobs, src_tokens_flat, tgt_lens, src_lens,
                reduction=reduction, zero_infinity=True
            ) * self.ctc_weight
        loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss

        sample_size = sample["nsentences"] if self.sentence_avg \
            else sample["ntokens"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "mse_loss": utils.item(mse_loss.data),
            "eos_loss": utils.item(eos_loss.data),
            "attn_loss": utils.item(attn_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
        }
        return loss, sample_size, logging_output

    def compute_loss(self, feat_out, feat_out_post, eos_out, feat_tgt,
                     eos_tgt, tgt_lens, reduction="mean"):
        mask = lengths_to_mask(tgt_lens)
        _eos_out = eos_out[mask].squeeze()
        _eos_tgt = eos_tgt[mask]
        _feat_tgt = feat_tgt[mask]
        _feat_out = feat_out[mask]
        _feat_out_post = feat_out_post[mask]

        l1_loss = (
            F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) +
            F.l1_loss(_feat_out_post, _feat_tgt, reduction=reduction)
        )
        mse_loss = (
            F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) +
            F.mse_loss(_feat_out_post, _feat_tgt, reduction=reduction)
        )
        eos_loss = F.binary_cross_entropy_with_logits(
            _eos_out, _eos_tgt, pos_weight=torch.tensor(self.bce_pos_weight),
            reduction=reduction
        )
        return l1_loss, mse_loss, eos_loss

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        ns = [log.get("sample_size", 0) for log in logging_outputs]
        ntot = sum(ns)
        ws = [n / (ntot + 1e-8) for n in ns]
        for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]:
            vals = [log.get(key, 0) for log in logging_outputs]
            val = sum(val * w for val, w in zip(vals, ws))
            metrics.log_scalar(key, val, ntot, round=3)
        metrics.log_scalar("sample_size", ntot, len(logging_outputs))

        # inference metrics
        if "targ_frames" not in logging_outputs[0]:
            return
        n = sum(log.get("targ_frames", 0) for log in logging_outputs)
        for key, new_key in [
            ("mcd_loss", "mcd_loss"),
            ("pred_frames", "pred_ratio"),
            ("nins", "ins_rate"),
            ("ndel", "del_rate"),
        ]:
            val = sum(log.get(key, 0) for log in logging_outputs)
            metrics.log_scalar(new_key, val / n, n, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        return False
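A standalone sketch of the guided-attention weight matrix computed by _get_weight above: entries near the diagonal (monotonic text-to-frame alignment) are penalized little, while off-diagonal entries approach 1:

    import torch

    def guided_attn_weight(s_len: int, t_len: int, sigma: float = 0.4) -> torch.Tensor:
        grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len))
        w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
        return 1.0 - torch.exp(-w / (2 * sigma ** 2))

    print(guided_attn_weight(4, 4))  # ~0 on the diagonal, larger off it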
207 fairseq/data/audio/frm_text_to_speech_dataset.py Normal file
@@ -0,0 +1,207 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import csv
import logging
import os.path as op
from typing import List, Optional

import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig
)
from fairseq.data.audio.text_to_speech_dataset import (
    TextToSpeechDataset, TextToSpeechDatasetCreator
)

logger = logging.getLogger(__name__)


class FrmTextToSpeechDataset(TextToSpeechDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        do_chunk=False,
        chunk_bound=-1,
        chunk_init=50,
        chunk_incr=5,
        add_eos=True,
        dedup=True,
        ref_fpu=-1
    ):
        # Assumes texts are encoded at a fixed frame rate
        super().__init__(
            split=split,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id
        )

        self.do_chunk = do_chunk
        self.chunk_bound = chunk_bound
        self.chunk_init = chunk_init
        self.chunk_incr = chunk_incr
        self.add_eos = add_eos
        self.dedup = dedup
        self.ref_fpu = ref_fpu

        self.chunk_size = -1

        if do_chunk:
            assert self.chunk_incr >= 0
            assert self.pre_tokenizer is None

    def __getitem__(self, index):
        index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
        if target[-1].item() == self.tgt_dict.eos_index:
            target = target[:-1]

        fpu = source.size(0) / target.size(0)  # frames per unit
        fps = self.n_frames_per_step
        assert (
            self.ref_fpu == -1 or
            abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
        ), f"{fpu*fps} != {self.ref_fpu}"

        # only chunk the training split
        if self.is_train_split and self.do_chunk and self.chunk_size > 0:
            lang = target[:int(self.data_cfg.prepend_tgt_lang_tag)]
            text = target[int(self.data_cfg.prepend_tgt_lang_tag):]
            size = len(text)
            chunk_size = min(self.chunk_size, size)
            chunk_start = np.random.randint(size - chunk_size + 1)
            text = text[chunk_start:chunk_start + chunk_size]
            target = torch.cat((lang, text), 0)

            f_size = int(np.floor(chunk_size * fpu))
            f_start = int(np.floor(chunk_start * fpu))
            assert f_size > 0
            source = source[f_start:f_start + f_size, :]

        if self.dedup:
            target = torch.unique_consecutive(target)

        if self.add_eos:
            eos_idx = self.tgt_dict.eos_index
            target = torch.cat((target, torch.LongTensor([eos_idx])), 0)

        return index, source, target, speaker_id

    def set_epoch(self, epoch):
        if self.is_train_split and self.do_chunk:
            old = self.chunk_size
            self.chunk_size = self.chunk_init + epoch * self.chunk_incr
            if self.chunk_bound > 0:
                self.chunk_size = min(self.chunk_size, self.chunk_bound)
            logger.info((
                f"{self.split}: setting chunk size "
                f"from {old} to {self.chunk_size}"
            ))


class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
    # inherit for key names
    @classmethod
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2TDataConfig,
        split: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        n_frames_per_step: int,
        speaker_to_id,
        do_chunk: bool = False,
        chunk_bound: int = -1,
        chunk_init: int = 50,
        chunk_incr: int = 5,
        add_eos: bool = True,
        dedup: bool = True,
        ref_fpu: float = -1
    ) -> FrmTextToSpeechDataset:
        tsv_path = op.join(root, f"{split}.tsv")
        if not op.isfile(tsv_path):
            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
        with open(tsv_path) as f:
            reader = csv.DictReader(
                f,
                delimiter="\t",
                quotechar=None,
                doublequote=False,
                lineterminator="\n",
                quoting=csv.QUOTE_NONE,
            )
            s = [dict(e) for e in reader]
            assert len(s) > 0

        ids = [ss[cls.KEY_ID] for ss in s]
        audio_paths = [
            op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s
        ]
        n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
        tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
        src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
        speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
        src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
        tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]

        return FrmTextToSpeechDataset(
            split=split,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
            do_chunk=do_chunk,
            chunk_bound=chunk_bound,
            chunk_init=chunk_init,
            chunk_incr=chunk_incr,
            add_eos=add_eos,
            dedup=dedup,
            ref_fpu=ref_fpu
        )
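The chunk curriculum in set_epoch grows linearly with the epoch and is capped by chunk_bound; for the defaults above (an illustrative run with a cap of 100):

    chunk_init, chunk_incr, chunk_bound = 50, 5, 100  # chunk_bound=-1 means uncapped
    for epoch in (0, 5, 10, 12):
        size = chunk_init + epoch * chunk_incr
        if chunk_bound > 0:
            size = min(size, chunk_bound)
        print(epoch, size)  # 0 -> 50, 5 -> 75, 10 -> 100, 12 -> 100 (capped)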
215 fairseq/data/audio/text_to_speech_dataset.py Normal file
@@ -0,0 +1,215 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass

import numpy as np
import torch

from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig,
    _collate_frames, get_features_or_waveform
)
from fairseq.data import Dictionary, data_utils as fairseq_data_utils


@dataclass
class TextToSpeechDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    speaker_id: Optional[int] = None
    duration: Optional[torch.Tensor] = None
    pitch: Optional[torch.Tensor] = None
    energy: Optional[torch.Tensor] = None


class TextToSpeechDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        durations: Optional[List[List[int]]] = None,
        pitches: Optional[List[str]] = None,
        energies: Optional[List[str]] = None
    ):
        super(TextToSpeechDataset, self).__init__(
            split, is_train_split, cfg, audio_paths, n_frames,
            src_texts=src_texts, tgt_texts=tgt_texts, speakers=speakers,
            src_langs=src_langs, tgt_langs=tgt_langs, ids=ids,
            tgt_dict=tgt_dict, pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer, n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id
        )
        self.durations = durations
        self.pitches = pitches
        self.energies = energies

    def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
        s2t_item = super().__getitem__(index)

        duration, pitch, energy = None, None, None
        if self.durations is not None:
            duration = torch.tensor(
                self.durations[index] + [0], dtype=torch.long  # pad 0 for EOS
            )
        if self.pitches is not None:
            pitch = get_features_or_waveform(self.pitches[index])
            pitch = torch.from_numpy(
                np.concatenate((pitch, [0]))  # pad 0 for EOS
            ).float()
        if self.energies is not None:
            energy = get_features_or_waveform(self.energies[index])
            energy = torch.from_numpy(
                np.concatenate((energy, [0]))  # pad 0 for EOS
            ).float()
        return TextToSpeechDatasetItem(
            index=index, source=s2t_item.source, target=s2t_item.target,
            speaker_id=s2t_item.speaker_id, duration=duration, pitch=pitch,
            energy=energy
        )

    def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
        if len(samples) == 0:
            return {}

        src_lengths, order = torch.tensor(
            [s.target.shape[0] for s in samples], dtype=torch.long
        ).sort(descending=True)
        id_ = torch.tensor([s.index for s in samples],
                           dtype=torch.long).index_select(0, order)
        feat = _collate_frames(
            [s.source for s in samples], self.cfg.use_audio_input
        ).index_select(0, order)
        target_lengths = torch.tensor(
            [s.source.shape[0] for s in samples], dtype=torch.long
        ).index_select(0, order)

        src_tokens = fairseq_data_utils.collate_tokens(
            [s.target for s in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        ).index_select(0, order)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = torch.tensor(
                [s.speaker_id for s in samples], dtype=torch.long
            ).index_select(0, order).view(-1, 1)

        bsz, _, d = feat.size()
        prev_output_tokens = torch.cat(
            (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
        )

        durations, pitches, energies = None, None, None
        if self.durations is not None:
            durations = fairseq_data_utils.collate_tokens(
                [s.duration for s in samples], 0
            ).index_select(0, order)
            assert src_tokens.shape[1] == durations.shape[1]
        if self.pitches is not None:
            pitches = _collate_frames([s.pitch for s in samples], True)
            pitches = pitches.index_select(0, order)
            assert src_tokens.shape[1] == pitches.shape[1]
        if self.energies is not None:
            energies = _collate_frames([s.energy for s in samples], True)
            energies = energies.index_select(0, order)
            assert src_tokens.shape[1] == energies.shape[1]
        src_texts = [self.tgt_dict.string(samples[i].target) for i in order]

        return {
            "id": id_,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": prev_output_tokens,
            },
            "speaker": speaker,
            "target": feat,
            "durations": durations,
            "pitches": pitches,
            "energies": energies,
            "target_lengths": target_lengths,
            "ntokens": sum(target_lengths).item(),
            "nsentences": len(samples),
            "src_texts": src_texts,
        }


class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
    KEY_DURATION = "duration"
    KEY_PITCH = "pitch"
    KEY_ENERGY = "energy"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id
    ) -> TextToSpeechDataset:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        durations = [s.get(cls.KEY_DURATION, None) for s in samples]
        durations = [
            None if dd is None else [int(d) for d in dd.split(" ")]
            for dd in durations
        ]
        durations = None if any(dd is None for dd in durations) else durations

        pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
        pitches = [
            None if pp is None else (audio_root / pp).as_posix()
            for pp in pitches
        ]
        pitches = None if any(pp is None for pp in pitches) else pitches

        energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
        energies = [
            None if ee is None else (audio_root / ee).as_posix()
            for ee in energies
        ]
        energies = None if any(ee is None for ee in energies) else energies

        return TextToSpeechDataset(
            split_name, is_train_split, cfg, audio_paths, n_frames,
            src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict,
            pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id,
            durations, pitches, energies
        )
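The collater builds teacher-forcing decoder inputs by prepending a zero frame and dropping the last target frame; a self-contained sketch of that shift:

    import torch

    feat = torch.arange(6, dtype=torch.float).view(1, 3, 2)  # B x T x C target frames
    bsz, _, d = feat.size()
    prev_output_tokens = torch.cat(
        (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
    )
    print(prev_output_tokens)  # frame 0 is zeros; the rest are targets shifted right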
@@ -92,7 +92,8 @@ class Dictionary:
         )

         extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
-        extra_symbols_to_ignore.add(self.eos())
+        if not include_eos:
+            extra_symbols_to_ignore.add(self.eos())

         def token_string(i):
             if i == self.unk():
@@ -5,5 +5,5 @@

 from .berard import *  # noqa
 from .convtransformer import *  # noqa
 from .s2t_transformer import *  # noqa
 from .xm_transformer import *  # noqa
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright (c) 2017-present, Facebook, Inc.
 # All rights reserved.
 #
8 fairseq/models/text_to_speech/__init__.py Normal file
@@ -0,0 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .tacotron2 import *  # noqa
from .tts_transformer import *  # noqa
from .fastspeech2 import *  # noqa
352
fairseq/models/text_to_speech/fastspeech2.py
Normal file
352
fairseq/models/text_to_speech/fastspeech2.py
Normal file
@ -0,0 +1,352 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from fairseq.models import (FairseqEncoder, FairseqEncoderModel, register_model,
|
||||
register_model_architecture)
|
||||
from fairseq.modules import (
|
||||
LayerNorm, PositionalEmbedding, FairseqDropout, MultiheadAttention
|
||||
)
|
||||
from fairseq import utils
|
||||
from fairseq.data.data_utils import lengths_to_padding_mask
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def model_init(m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu"))
|
||||
|
||||
|
||||
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
|
||||
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
||||
return m
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
def __init__(self, in_dim, hidden_dim, kernel_size, dropout):
|
||||
super().__init__()
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Conv1d(in_dim, hidden_dim, kernel_size=kernel_size,
|
||||
padding=(kernel_size - 1) // 2),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(hidden_dim, in_dim, kernel_size=kernel_size,
|
||||
padding=(kernel_size - 1) // 2)
|
||||
)
|
||||
self.layer_norm = LayerNorm(in_dim)
|
||||
self.dropout = self.dropout_module = FairseqDropout(
|
||||
p=dropout, module_name=self.__class__.__name__
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# B x T x C
|
||||
residual = x
|
||||
x = self.ffn(x.transpose(1, 2)).transpose(1, 2)
|
||||
x = self.dropout(x)
|
||||
return self.layer_norm(x + residual)
|
||||
|
||||
|
||||
class FFTLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, embed_dim, n_heads, hidden_dim, kernel_size, dropout,
|
||||
attention_dropout
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = MultiheadAttention(
|
||||
embed_dim, n_heads, dropout=attention_dropout, self_attention=True
|
||||
)
|
||||
self.layer_norm = LayerNorm(embed_dim)
|
||||
self.ffn = PositionwiseFeedForward(
|
||||
embed_dim, hidden_dim, kernel_size, dropout=dropout
|
||||
)
|
||||
|
||||
def forward(self, x, padding_mask=None):
|
||||
# B x T x C
|
||||
residual = x
|
||||
x = x.transpose(0, 1)
|
||||
x, _ = self.self_attn(
|
||||
query=x, key=x, value=x, key_padding_mask=padding_mask,
|
||||
need_weights=False
|
||||
)
|
||||
x = x.transpose(0, 1)
|
||||
x = self.layer_norm(x + residual)
|
||||
return self.ffn(x)
|
||||
|
||||
|
||||
class LengthRegulator(nn.Module):
|
||||
def forward(self, x, durations):
|
||||
# x: B x T x C
|
||||
out_lens = durations.sum(dim=1)
|
||||
max_len = out_lens.max()
|
||||
bsz, seq_len, dim = x.size()
|
||||
out = x.new_zeros((bsz, max_len, dim))
|
||||
|
||||
for b in range(bsz):
|
||||
indices = []
|
||||
for t in range(seq_len):
|
||||
indices.extend([t] * utils.item(durations[b, t]))
|
||||
indices = torch.tensor(indices, dtype=torch.long).to(x.device)
|
||||
out_len = utils.item(out_lens[b])
|
||||
out[b, :out_len] = x[b].index_select(0, indices)
|
||||
|
||||
return out, out_lens
|
||||
|
||||
|
||||
class VariancePredictor(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(
                args.encoder_embed_dim, args.var_pred_hidden_dim,
                kernel_size=args.var_pred_kernel_size,
                padding=(args.var_pred_kernel_size - 1) // 2
            ),
            nn.ReLU()
        )
        self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim)
        self.dropout_module = FairseqDropout(
            p=args.var_pred_dropout, module_name=self.__class__.__name__
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(
                args.var_pred_hidden_dim, args.var_pred_hidden_dim,
                kernel_size=args.var_pred_kernel_size,
                # "same" padding for any odd kernel size
                padding=(args.var_pred_kernel_size - 1) // 2
            ),
            nn.ReLU()
        )
        self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim)
        self.proj = nn.Linear(args.var_pred_hidden_dim, 1)

    def forward(self, x):
        # Input: B x T x C; Output: B x T
        x = self.conv1(x.transpose(1, 2)).transpose(1, 2)
        x = self.dropout_module(self.ln1(x))
        x = self.conv2(x.transpose(1, 2)).transpose(1, 2)
        x = self.dropout_module(self.ln2(x))
        return self.proj(x).squeeze(dim=2)


class VarianceAdaptor(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.length_regulator = LengthRegulator()
        self.duration_predictor = VariancePredictor(args)
        self.pitch_predictor = VariancePredictor(args)
        self.energy_predictor = VariancePredictor(args)

        n_bins, steps = self.args.var_pred_n_bins, self.args.var_pred_n_bins - 1
        self.pitch_bins = torch.linspace(args.pitch_min, args.pitch_max, steps)
        self.embed_pitch = Embedding(n_bins, args.encoder_embed_dim)
        self.energy_bins = torch.linspace(args.energy_min, args.energy_max, steps)
        self.embed_energy = Embedding(n_bins, args.encoder_embed_dim)

    def get_pitch_emb(self, x, tgt=None, factor=1.0):
        out = self.pitch_predictor(x)
        bins = self.pitch_bins.to(x.device)
        if tgt is None:
            out = out * factor
            emb = self.embed_pitch(torch.bucketize(out, bins))
        else:
            emb = self.embed_pitch(torch.bucketize(tgt, bins))
        return out, emb

    def get_energy_emb(self, x, tgt=None, factor=1.0):
        out = self.energy_predictor(x)
        bins = self.energy_bins.to(x.device)
        if tgt is None:
            out = out * factor
            emb = self.embed_energy(torch.bucketize(out, bins))
        else:
            emb = self.embed_energy(torch.bucketize(tgt, bins))
        return out, emb

    def forward(
        self, x, padding_mask, durations=None, pitches=None, energies=None,
        d_factor=1.0, p_factor=1.0, e_factor=1.0
    ):
        # x: B x T x C
        log_dur_out = self.duration_predictor(x)
        dur_out = torch.clamp(
            torch.round((torch.exp(log_dur_out) - 1) * d_factor).long(), min=0
        )
        dur_out.masked_fill_(padding_mask, 0)

        pitch_out, pitch_emb = self.get_pitch_emb(x, pitches, p_factor)
        x = x + pitch_emb
        energy_out, energy_emb = self.get_energy_emb(x, energies, e_factor)
        x = x + energy_emb

        x, out_lens = self.length_regulator(
            x, dur_out if durations is None else durations
        )

        return x, out_lens, log_dur_out, pitch_out, energy_out


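# Editor's sketch (illustration only, not part of the original commit): how
# continuous pitch/energy values are quantized into embedding indices.
# torch.bucketize over n_bins - 1 boundaries yields ids in [0, n_bins - 1],
# so every id is a valid row of an Embedding with n_bins rows.
def _bucketize_demo():
    n_bins = 4
    bins = torch.linspace(0.0, 3.0, n_bins - 1)  # boundaries 0.0, 1.5, 3.0
    values = torch.tensor([-1.0, 0.5, 2.0, 5.0])
    ids = torch.bucketize(values, bins)
    assert ids.tolist() == [0, 1, 2, 3]  # out-of-range values clamp to the edges

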
class FastSpeech2Encoder(FairseqEncoder):
    def __init__(self, args, src_dict, embed_speaker):
        super().__init__(src_dict)
        self.args = args
        self.padding_idx = src_dict.pad()
        self.n_frames_per_step = args.n_frames_per_step
        self.out_dim = args.output_frame_dim * args.n_frames_per_step

        self.embed_speaker = embed_speaker
        self.spk_emb_proj = None
        if embed_speaker is not None:
            self.spk_emb_proj = nn.Linear(
                args.encoder_embed_dim + args.speaker_embed_dim,
                args.encoder_embed_dim
            )

        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        self.embed_tokens = Embedding(
            len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
        )

        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
        )
        self.pos_emb_alpha = nn.Parameter(torch.ones(1))
        self.dec_pos_emb_alpha = nn.Parameter(torch.ones(1))

        self.encoder_fft_layers = nn.ModuleList(
            FFTLayer(
                args.encoder_embed_dim, args.encoder_attention_heads,
                args.fft_hidden_dim, args.fft_kernel_size,
                dropout=args.dropout, attention_dropout=args.attention_dropout
            )
            for _ in range(args.encoder_layers)
        )

        self.var_adaptor = VarianceAdaptor(args)

        self.decoder_fft_layers = nn.ModuleList(
            FFTLayer(
                args.decoder_embed_dim, args.decoder_attention_heads,
                args.fft_hidden_dim, args.fft_kernel_size,
                dropout=args.dropout, attention_dropout=args.attention_dropout
            )
            for _ in range(args.decoder_layers)
        )

        self.out_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)

        self.apply(model_init)

    def forward(self, src_tokens, src_lengths=None, speaker=None,
                durations=None, pitches=None, energies=None, **kwargs):
        x = self.embed_tokens(src_tokens)

        enc_padding_mask = src_tokens.eq(self.padding_idx)
        x += self.pos_emb_alpha * self.embed_positions(enc_padding_mask)
        x = self.dropout_module(x)

        for layer in self.encoder_fft_layers:
            x = layer(x, enc_padding_mask)

        if self.embed_speaker is not None:
            bsz, seq_len, _ = x.size()
            emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1)
            x = self.spk_emb_proj(torch.cat([x, emb], dim=2))

        x, out_lens, log_dur_out, pitch_out, energy_out = \
            self.var_adaptor(x, enc_padding_mask, durations, pitches, energies)

        dec_padding_mask = lengths_to_padding_mask(out_lens)
        x += self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask)
        for layer in self.decoder_fft_layers:
            x = layer(x, dec_padding_mask)

        x = self.out_proj(x)

        return x, out_lens, log_dur_out, pitch_out, energy_out


@register_model("fastspeech2")
class FastSpeech2Model(FairseqEncoderModel):
    """
    Implementation of FastSpeech 2 (https://arxiv.org/abs/2006.04558)
    """

    NON_AUTOREGRESSIVE = True

    @staticmethod
    def add_args(parser):
        parser.add_argument("--dropout", type=float)
        parser.add_argument("--output-frame-dim", type=int)
        parser.add_argument("--speaker-embed-dim", type=int)
        # FFT blocks
        parser.add_argument("--fft-hidden-dim", type=int)
        parser.add_argument("--fft-kernel-size", type=int)
        parser.add_argument("--attention-dropout", type=float)
        parser.add_argument("--encoder-layers", type=int)
        parser.add_argument("--encoder-embed-dim", type=int)
        parser.add_argument("--encoder-attention-heads", type=int)
        parser.add_argument("--decoder-layers", type=int)
        parser.add_argument("--decoder-embed-dim", type=int)
        parser.add_argument("--decoder-attention-heads", type=int)
        # variance predictor
        parser.add_argument("--var-pred-n-bins", type=int)
        parser.add_argument("--var-pred-hidden-dim", type=int)
        parser.add_argument("--var-pred-kernel-size", type=int)
        parser.add_argument("--var-pred-dropout", type=float)

    def __init__(self, encoder, args, src_dict):
        super().__init__(encoder)
        self._num_updates = 0

        out_dim = args.output_frame_dim * args.n_frames_per_step
        self.ctc_proj = None
        if getattr(args, "ctc_weight", 0.) > 0.:
            self.ctc_proj = nn.Linear(out_dim, len(src_dict))

    @classmethod
    def build_model(cls, args, task):
        embed_speaker = task.get_speaker_embeddings(args)
        encoder = FastSpeech2Encoder(args, task.src_dict, embed_speaker)
        return cls(encoder, args, task.src_dict)

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)
        self._num_updates = num_updates

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        logits = self.ctc_proj(net_output[0])
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)


@register_model_architecture("fastspeech2", "fastspeech2")
def base_architecture(args):
    args.dropout = getattr(args, "dropout", 0.2)
    args.output_frame_dim = getattr(args, "output_frame_dim", 80)
    args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64)
    # FFT blocks
    args.fft_hidden_dim = getattr(args, "fft_hidden_dim", 1024)
    args.fft_kernel_size = getattr(args, "fft_kernel_size", 9)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.encoder_layers = getattr(args, "encoder_layers", 4)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2)
    args.decoder_layers = getattr(args, "decoder_layers", 4)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2)
    # variance predictor
    args.var_pred_n_bins = getattr(args, "var_pred_n_bins", 256)
    args.var_pred_hidden_dim = getattr(args, "var_pred_hidden_dim", 256)
    args.var_pred_kernel_size = getattr(args, "var_pred_kernel_size", 3)
    args.var_pred_dropout = getattr(args, "var_pred_dropout", 0.5)
173 fairseq/models/text_to_speech/hifigan.py Normal file
@@ -0,0 +1,173 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm

LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return (kernel_size * dilation - dilation) // 2


class ResBlock(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        # dilated convolutions (one per dilation rate) ...
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        # ... each followed by a plain (dilation=1) convolution
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)


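# Editor's sketch (illustration only, not part of the original commit):
# get_padding computes "same" padding for odd kernels, so every dilated
# convolution in ResBlock preserves sequence length.
def _same_padding_demo():
    for kernel_size, dilation in [(3, 1), (3, 5), (7, 11)]:
        conv = Conv1d(1, 1, kernel_size, 1, dilation=dilation,
                      padding=get_padding(kernel_size, dilation))
        x = torch.randn(1, 1, 100)
        assert conv(x).shape == x.shape  # length is preserved

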
class Generator(torch.nn.Module):
    def __init__(self, cfg):
        super(Generator, self).__init__()
        self.num_kernels = len(cfg["resblock_kernel_sizes"])
        self.num_upsamples = len(cfg["upsample_rates"])
        self.conv_pre = weight_norm(
            Conv1d(80, cfg["upsample_initial_channel"], 7, 1, padding=3)
        )

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(
            zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"])
        ):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        cfg["upsample_initial_channel"] // (2 ** i),
                        cfg["upsample_initial_channel"] // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = cfg["upsample_initial_channel"] // (2 ** (i + 1))
            for k, d in zip(
                cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"]
            ):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
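

# Editor's sketch (illustration only; the rates below are one common HiFi-GAN
# configuration, not a value fixed by this commit): the product of
# upsample_rates must equal the spectrogram hop size, since the Generator
# expands each mel frame into hop_size waveform samples.
def _upsample_factor_demo():
    upsample_rates = [8, 8, 2, 2]  # hypothetical cfg["upsample_rates"]
    hop_size = 256
    factor = 1
    for u in upsample_rates:
        factor *= u
    assert factor == hop_size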
350 fairseq/models/text_to_speech/tacotron2.py Normal file
@@ -0,0 +1,350 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from torch import nn
from torch.nn import functional as F

from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
                            FairseqIncrementalDecoder, register_model,
                            register_model_architecture)
from fairseq.modules import LSTMCellWithZoneOut, LocationAttention


logger = logging.getLogger(__name__)


def encoder_init(m):
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu"))


class Tacotron2Encoder(FairseqEncoder):
    def __init__(self, args, src_dict, embed_speaker):
        super().__init__(src_dict)
        self.padding_idx = src_dict.pad()
        self.embed_speaker = embed_speaker
        self.spk_emb_proj = None
        if embed_speaker is not None:
            self.spk_emb_proj = nn.Linear(
                args.encoder_embed_dim + args.speaker_embed_dim,
                args.encoder_embed_dim
            )

        self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim,
                                         padding_idx=self.padding_idx)

        assert args.encoder_conv_kernel_size % 2 == 1
        self.convolutions = nn.ModuleList(
            nn.Sequential(
                nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim,
                          kernel_size=args.encoder_conv_kernel_size,
                          padding=((args.encoder_conv_kernel_size - 1) // 2)),
                nn.BatchNorm1d(args.encoder_embed_dim),
                nn.ReLU(),
                nn.Dropout(args.encoder_dropout)
            )
            for _ in range(args.encoder_conv_layers)
        )

        self.lstm = nn.LSTM(args.encoder_embed_dim, args.encoder_embed_dim // 2,
                            num_layers=args.encoder_lstm_layers,
                            batch_first=True, bidirectional=True)

        self.apply(encoder_init)

    def forward(self, src_tokens, src_lengths=None, speaker=None, **kwargs):
        x = self.embed_tokens(src_tokens)
        x = x.transpose(1, 2).contiguous()  # B x T x C -> B x C x T
        for conv in self.convolutions:
            x = conv(x)
        x = x.transpose(1, 2).contiguous()  # B x C x T -> B x T x C

        src_lengths = src_lengths.cpu().long()
        x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)
        x = self.lstm(x)[0]
        x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0]

        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        if self.embed_speaker is not None:
            bsz, seq_len, _ = x.size()
            emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1)
            x = self.spk_emb_proj(torch.cat([x, emb], dim=2))

        return {
            "encoder_out": [x],  # B x T x C
            "encoder_padding_mask": encoder_padding_mask,  # B x T
        }


class Prenet(nn.Module):
    def __init__(self, in_dim, n_layers, n_units, dropout):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units),
                          nn.ReLU())
            for i in range(n_layers)
        )
        self.dropout = dropout

    def forward(self, x):
        for layer in self.layers:
            # F.dropout defaults to training=True, so dropout stays active
            # even in eval mode, as prescribed for the Tacotron 2 prenet
            x = F.dropout(layer(x), p=self.dropout)
        return x


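# Editor's sketch (illustration only, not part of the original commit): the
# prenet stays stochastic even under model.eval(), following the Tacotron 2
# recipe of keeping prenet dropout on at inference to introduce output
# variation; two eval-mode passes over the same input generally differ.
def _prenet_eval_dropout_demo():
    torch.manual_seed(0)
    prenet = Prenet(in_dim=4, n_layers=2, n_units=8, dropout=0.5).eval()
    x = torch.ones(1, 4)
    y1, y2 = prenet(x), prenet(x)
    assert not torch.equal(y1, y2)  # dropout is still active in eval mode

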
class Postnet(nn.Module):
    def __init__(self, in_dim, n_channels, kernel_size, n_layers, dropout):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()
        assert kernel_size % 2 == 1
        for i in range(n_layers):
            cur_layers = [
                nn.Conv1d(in_dim if i == 0 else n_channels,
                          n_channels if i < n_layers - 1 else in_dim,
                          kernel_size=kernel_size,
                          padding=((kernel_size - 1) // 2)),
                nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim)
            ] + ([nn.Tanh()] if i < n_layers - 1 else []) + [nn.Dropout(dropout)]
            nn.init.xavier_uniform_(
                cur_layers[0].weight,
                torch.nn.init.calculate_gain(
                    "tanh" if i < n_layers - 1 else "linear"
                )
            )
            self.convolutions.append(nn.Sequential(*cur_layers))

    def forward(self, x):
        x = x.transpose(1, 2)  # B x T x C -> B x C x T
        for conv in self.convolutions:
            x = conv(x)
        return x.transpose(1, 2)


def decoder_init(m):
    if isinstance(m, torch.nn.Conv1d):
        nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh"))


class Tacotron2Decoder(FairseqIncrementalDecoder):
    def __init__(self, args, src_dict):
        super().__init__(None)
        self.args = args
        self.n_frames_per_step = args.n_frames_per_step
        self.out_dim = args.output_frame_dim * args.n_frames_per_step

        self.prenet = Prenet(self.out_dim, args.prenet_layers, args.prenet_dim,
                             args.prenet_dropout)

        # takes prev_context, prev_frame, (speaker embedding) as input
        self.attention_lstm = LSTMCellWithZoneOut(
            args.zoneout,
            args.prenet_dim + args.encoder_embed_dim,
            args.decoder_lstm_dim
        )

        # takes attention_lstm output, attention_state, encoder_out as input
        self.attention = LocationAttention(
            args.attention_dim, args.encoder_embed_dim, args.decoder_lstm_dim,
            (1 + int(args.attention_use_cumprob)),
            args.attention_conv_dim, args.attention_conv_kernel_size
        )

        # takes attention_lstm output, context, (gated_latent) as input
        self.lstm = nn.ModuleList(
            LSTMCellWithZoneOut(
                args.zoneout,
                args.encoder_embed_dim + args.decoder_lstm_dim,
                args.decoder_lstm_dim
            )
            for _ in range(args.decoder_lstm_layers)
        )

        proj_in_dim = args.encoder_embed_dim + args.decoder_lstm_dim
        self.feat_proj = nn.Linear(proj_in_dim, self.out_dim)
        self.eos_proj = nn.Linear(proj_in_dim, 1)

        self.postnet = Postnet(self.out_dim, args.postnet_conv_dim,
                               args.postnet_conv_kernel_size,
                               args.postnet_layers, args.postnet_dropout)

        self.ctc_proj = None
        if getattr(args, "ctc_weight", 0.) > 0.:
            self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))

        self.apply(decoder_init)

    def _get_states(self, incremental_state, enc_out):
        bsz, in_len, _ = enc_out.size()
        alstm_h = self.get_incremental_state(incremental_state, "alstm_h")
        if alstm_h is None:
            alstm_h = enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
        alstm_c = self.get_incremental_state(incremental_state, "alstm_c")
        if alstm_c is None:
            alstm_c = enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)

        lstm_h = self.get_incremental_state(incremental_state, "lstm_h")
        if lstm_h is None:
            lstm_h = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
                      for _ in range(self.args.decoder_lstm_layers)]
        lstm_c = self.get_incremental_state(incremental_state, "lstm_c")
        if lstm_c is None:
            lstm_c = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
                      for _ in range(self.args.decoder_lstm_layers)]

        attn_w = self.get_incremental_state(incremental_state, "attn_w")
        if attn_w is None:
            attn_w = enc_out.new_zeros(bsz, in_len)
        attn_w_cum = self.get_incremental_state(incremental_state, "attn_w_cum")
        if attn_w_cum is None:
            attn_w_cum = enc_out.new_zeros(bsz, in_len)
        return alstm_h, alstm_c, lstm_h, lstm_c, attn_w, attn_w_cum

    def _get_init_attn_c(self, enc_out, enc_mask):
        bsz = enc_out.size(0)
        if self.args.init_attn_c == "zero":
            return enc_out.new_zeros(bsz, self.args.encoder_embed_dim)
        elif self.args.init_attn_c == "avg":
            enc_w = (~enc_mask).type(enc_out.type())
            enc_w = enc_w / enc_w.sum(dim=1, keepdim=True)
            return torch.sum(enc_out * enc_w.unsqueeze(2), dim=1)
        else:
            raise ValueError(f"{self.args.init_attn_c} not supported")

    def forward(self, prev_output_tokens, encoder_out=None,
                incremental_state=None, target_lengths=None, **kwargs):
        enc_mask = encoder_out["encoder_padding_mask"]
        enc_out = encoder_out["encoder_out"][0]
        in_len = enc_out.size(1)

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:, :]
        bsz, out_len, _ = prev_output_tokens.size()

        prenet_out = self.prenet(prev_output_tokens)
        (alstm_h, alstm_c, lstm_h, lstm_c,
         attn_w, attn_w_cum) = self._get_states(incremental_state, enc_out)
        attn_ctx = self._get_init_attn_c(enc_out, enc_mask)

        attn_out = enc_out.new_zeros(bsz, in_len, out_len)
        feat_out = enc_out.new_zeros(bsz, out_len, self.out_dim)
        eos_out = enc_out.new_zeros(bsz, out_len)
        for t in range(out_len):
            alstm_in = torch.cat((attn_ctx, prenet_out[:, t, :]), dim=1)
            alstm_h, alstm_c = self.attention_lstm(alstm_in, (alstm_h, alstm_c))

            attn_state = attn_w.unsqueeze(1)
            if self.args.attention_use_cumprob:
                attn_state = torch.stack((attn_w, attn_w_cum), dim=1)
            attn_ctx, attn_w = self.attention(
                enc_out, enc_mask, alstm_h, attn_state
            )
            attn_w_cum = attn_w_cum + attn_w
            attn_out[:, :, t] = attn_w

            for i, cur_lstm in enumerate(self.lstm):
                if i == 0:
                    lstm_in = torch.cat((attn_ctx, alstm_h), dim=1)
                else:
                    lstm_in = torch.cat((attn_ctx, lstm_h[i - 1]), dim=1)
                lstm_h[i], lstm_c[i] = cur_lstm(lstm_in, (lstm_h[i], lstm_c[i]))

            proj_in = torch.cat((attn_ctx, lstm_h[-1]), dim=1)
            feat_out[:, t, :] = self.feat_proj(proj_in)
            eos_out[:, t] = self.eos_proj(proj_in).squeeze(1)
        self.attention.clear_cache()

        self.set_incremental_state(incremental_state, "alstm_h", alstm_h)
        self.set_incremental_state(incremental_state, "alstm_c", alstm_c)
        self.set_incremental_state(incremental_state, "lstm_h", lstm_h)
        self.set_incremental_state(incremental_state, "lstm_c", lstm_c)
        self.set_incremental_state(incremental_state, "attn_w", attn_w)
        self.set_incremental_state(incremental_state, "attn_w_cum", attn_w_cum)

        post_feat_out = feat_out + self.postnet(feat_out)
        eos_out = eos_out.view(bsz, out_len, 1)
        return post_feat_out, eos_out, {"attn": attn_out, "feature_out": feat_out}


@register_model("tacotron_2")
class Tacotron2Model(FairseqEncoderDecoderModel):
    """
    Implementation of Tacotron 2 (https://arxiv.org/pdf/1712.05884.pdf)
    """

    @staticmethod
    def add_args(parser):
        # encoder
        parser.add_argument("--encoder-dropout", type=float)
        parser.add_argument("--encoder-embed-dim", type=int)
        parser.add_argument("--encoder-conv-layers", type=int)
        parser.add_argument("--encoder-conv-kernel-size", type=int)
        parser.add_argument("--encoder-lstm-layers", type=int)
        # decoder
        parser.add_argument("--attention-dim", type=int)
        parser.add_argument("--attention-conv-dim", type=int)
        parser.add_argument("--attention-conv-kernel-size", type=int)
        parser.add_argument("--prenet-dropout", type=float)
        parser.add_argument("--prenet-layers", type=int)
        parser.add_argument("--prenet-dim", type=int)
        parser.add_argument("--postnet-dropout", type=float)
        parser.add_argument("--postnet-layers", type=int)
        parser.add_argument("--postnet-conv-dim", type=int)
        parser.add_argument("--postnet-conv-kernel-size", type=int)
        parser.add_argument("--init-attn-c", type=str)
        parser.add_argument("--attention-use-cumprob", action='store_true')
        parser.add_argument("--zoneout", type=float)
        parser.add_argument("--decoder-lstm-layers", type=int)
        parser.add_argument("--decoder-lstm-dim", type=int)
        parser.add_argument("--output-frame-dim", type=int)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._num_updates = 0

    @classmethod
    def build_model(cls, args, task):
        embed_speaker = task.get_speaker_embeddings(args)
        encoder = Tacotron2Encoder(args, task.src_dict, embed_speaker)
        decoder = Tacotron2Decoder(args, task.src_dict)
        return cls(encoder, decoder)

    def forward_encoder(self, src_tokens, src_lengths, **kwargs):
        return self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)
        self._num_updates = num_updates


@register_model_architecture("tacotron_2", "tacotron_2")
def base_architecture(args):
    # encoder
    args.encoder_dropout = getattr(args, "encoder_dropout", 0.5)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_conv_layers = getattr(args, "encoder_conv_layers", 3)
    args.encoder_conv_kernel_size = getattr(args, "encoder_conv_kernel_size", 5)
    args.encoder_lstm_layers = getattr(args, "encoder_lstm_layers", 1)
    # decoder
    args.attention_dim = getattr(args, "attention_dim", 128)
    args.attention_conv_dim = getattr(args, "attention_conv_dim", 32)
    args.attention_conv_kernel_size = getattr(args,
                                              "attention_conv_kernel_size", 15)
    args.prenet_dropout = getattr(args, "prenet_dropout", 0.5)
    args.prenet_layers = getattr(args, "prenet_layers", 2)
    args.prenet_dim = getattr(args, "prenet_dim", 256)
    args.postnet_dropout = getattr(args, "postnet_dropout", 0.5)
    args.postnet_layers = getattr(args, "postnet_layers", 5)
    args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512)
    args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5)
    args.init_attn_c = getattr(args, "init_attn_c", "zero")
    args.attention_use_cumprob = getattr(args, "attention_use_cumprob", True)
    args.zoneout = getattr(args, "zoneout", 0.1)
    args.decoder_lstm_layers = getattr(args, "decoder_lstm_layers", 2)
    args.decoder_lstm_dim = getattr(args, "decoder_lstm_dim", 1024)
    args.output_frame_dim = getattr(args, "output_frame_dim", 80)
371 fairseq/models/text_to_speech/tts_transformer.py Normal file
@@ -0,0 +1,371 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import List, Optional

import torch
from torch import nn

from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
                            FairseqIncrementalDecoder, register_model,
                            register_model_architecture)
from fairseq.modules import (
    LayerNorm, PositionalEmbedding, FairseqDropout,
    TransformerEncoderLayer, TransformerDecoderLayer
)
from fairseq.models.text_to_speech.tacotron2 import Prenet, Postnet
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq import utils

logger = logging.getLogger(__name__)


def encoder_init(m):
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu"))


def Embedding(num_embeddings, embedding_dim):
    m = nn.Embedding(num_embeddings, embedding_dim)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    return m


class TTSTransformerEncoder(FairseqEncoder):
    def __init__(self, args, src_dict, embed_speaker):
        super().__init__(src_dict)
        self.padding_idx = src_dict.pad()
        self.embed_speaker = embed_speaker
        self.spk_emb_proj = None
        if embed_speaker is not None:
            self.spk_emb_proj = nn.Linear(
                args.encoder_embed_dim + args.speaker_embed_dim,
                args.encoder_embed_dim
            )

        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim,
                                         padding_idx=self.padding_idx)
        assert args.encoder_conv_kernel_size % 2 == 1
        self.prenet = nn.ModuleList(
            nn.Sequential(
                nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim,
                          kernel_size=args.encoder_conv_kernel_size,
                          padding=((args.encoder_conv_kernel_size - 1) // 2)),
                nn.BatchNorm1d(args.encoder_embed_dim),
                nn.ReLU(),
                nn.Dropout(args.encoder_dropout),
            )
            for _ in range(args.encoder_conv_layers)
        )
        self.prenet_proj = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
        )
        self.pos_emb_alpha = nn.Parameter(torch.ones(1))

        self.transformer_layers = nn.ModuleList(
            TransformerEncoderLayer(args)
            for _ in range(args.encoder_transformer_layers)
        )
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)
        else:
            self.layer_norm = None

        self.apply(encoder_init)

    def forward(self, src_tokens, src_lengths=None, speaker=None, **kwargs):
        x = self.embed_tokens(src_tokens)
        x = x.transpose(1, 2).contiguous()  # B x T x C -> B x C x T
        for conv in self.prenet:
            x = conv(x)
        x = x.transpose(1, 2).contiguous()  # B x C x T -> B x T x C
        x = self.prenet_proj(x)

        padding_mask = src_tokens.eq(self.padding_idx)
        positions = self.embed_positions(padding_mask)
        x += self.pos_emb_alpha * positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        for layer in self.transformer_layers:
            x = layer(x, padding_mask)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if self.embed_speaker is not None:
            seq_len, bsz, _ = x.size()
            emb = self.embed_speaker(speaker).transpose(0, 1)
            emb = emb.expand(seq_len, bsz, -1)
            x = self.spk_emb_proj(torch.cat([x, emb], dim=2))

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask] if padding_mask.any() else [],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }


def decoder_init(m):
    if isinstance(m, torch.nn.Conv1d):
        nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh"))


class TTSTransformerDecoder(FairseqIncrementalDecoder):
    def __init__(self, args, src_dict):
        super().__init__(None)
        self._future_mask = torch.empty(0)

        self.args = args
        self.padding_idx = src_dict.pad()
        self.n_frames_per_step = args.n_frames_per_step
        self.out_dim = args.output_frame_dim * args.n_frames_per_step

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions, args.decoder_embed_dim, self.padding_idx
        )
        self.pos_emb_alpha = nn.Parameter(torch.ones(1))
        self.prenet = nn.Sequential(
            Prenet(self.out_dim, args.prenet_layers, args.prenet_dim,
                   args.prenet_dropout),
            nn.Linear(args.prenet_dim, args.decoder_embed_dim),
        )

        self.n_transformer_layers = args.decoder_transformer_layers
        self.transformer_layers = nn.ModuleList(
            TransformerDecoderLayer(args)
            for _ in range(self.n_transformer_layers)
        )
        if args.decoder_normalize_before:
            self.layer_norm = LayerNorm(args.decoder_embed_dim)
        else:
            self.layer_norm = None

        self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)
        self.eos_proj = nn.Linear(args.decoder_embed_dim, 1)

        self.postnet = Postnet(self.out_dim, args.postnet_conv_dim,
                               args.postnet_conv_kernel_size,
                               args.postnet_layers, args.postnet_dropout)

        self.ctc_proj = None
        if getattr(args, "ctc_weight", 0.) > 0.:
            self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))

        self.apply(decoder_init)

    def extract_features(
        self, prev_outputs, encoder_out=None, incremental_state=None,
        target_lengths=None, speaker=None, **kwargs
    ):
        alignment_layer = self.n_transformer_layers - 1
        self_attn_padding_mask = lengths_to_padding_mask(target_lengths)
        positions = self.embed_positions(
            self_attn_padding_mask, incremental_state=incremental_state
        )

        if incremental_state is not None:
            prev_outputs = prev_outputs[:, -1:, :]
            self_attn_padding_mask = self_attn_padding_mask[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        x = self.prenet(prev_outputs)
        x += self.pos_emb_alpha * positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if not self_attn_padding_mask.any():
            self_attn_padding_mask = None

        attn: Optional[torch.Tensor] = None
        inner_states: List[Optional[torch.Tensor]] = [x]
        for idx, transformer_layer in enumerate(self.transformer_layers):
            if incremental_state is None:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = transformer_layer(
                x,
                encoder_out["encoder_out"][0]
                if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
                else None,
                encoder_out["encoder_padding_mask"][0]
                if (
                    encoder_out is not None
                    and len(encoder_out["encoder_padding_mask"]) > 0
                )
                else None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool(idx == alignment_layer),
                need_head_weights=bool(idx == alignment_layer),
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

        if attn is not None:
            # average probabilities over heads, transpose to
            # (B, src_len, tgt_len)
            attn = attn.mean(dim=0).transpose(2, 1)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"attn": attn, "inner_states": inner_states}

    def forward(self, prev_output_tokens, encoder_out=None,
                incremental_state=None, target_lengths=None, speaker=None,
                **kwargs):
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out=encoder_out,
            incremental_state=incremental_state, target_lengths=target_lengths,
            speaker=speaker, **kwargs
        )
        attn = extra["attn"]
        feat_out = self.feat_proj(x)
        eos_out = self.eos_proj(x)
        post_feat_out = feat_out + self.postnet(feat_out)
        return post_feat_out, eos_out, {"attn": attn, "feature_out": feat_out}

    def get_normalized_probs(self, net_output, log_probs, sample):
        logits = self.ctc_proj(net_output[2]["feature_out"])
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        # self._future_mask.device != tensor.device is not working in
        # TorchScript. This is a workaround.
        if (
            self._future_mask.size(0) == 0
            or (not self._future_mask.device == tensor.device)
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
            )
        self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]


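# Editor's sketch (illustration only, not part of the original commit): the
# buffered future mask is upper triangular with -inf above the diagonal, so
# position i can only attend to positions <= i during parallel
# (teacher-forced) decoding.
def _future_mask_demo():
    dim = 3
    mask = torch.triu(utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1)
    assert mask[0, 0] == 0 and mask[1, 0] == 0  # attending backwards is allowed
    assert mask[0, 1] == float("-inf")  # attending forwards is blocked

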
@register_model("tts_transformer")
class TTSTransformerModel(FairseqEncoderDecoderModel):
    """
    Implementation of the Transformer TTS model
    (https://arxiv.org/pdf/1809.08895.pdf)
    """

    @staticmethod
    def add_args(parser):
        parser.add_argument("--dropout", type=float)
        parser.add_argument("--output-frame-dim", type=int)
        parser.add_argument("--speaker-embed-dim", type=int)
        # encoder prenet
        parser.add_argument("--encoder-dropout", type=float)
        parser.add_argument("--encoder-conv-layers", type=int)
        parser.add_argument("--encoder-conv-kernel-size", type=int)
        # encoder transformer layers
        parser.add_argument("--encoder-transformer-layers", type=int)
        parser.add_argument("--encoder-embed-dim", type=int)
        parser.add_argument("--encoder-ffn-embed-dim", type=int)
        parser.add_argument("--encoder-normalize-before", action="store_true")
        parser.add_argument("--encoder-attention-heads", type=int)
        parser.add_argument("--attention-dropout", type=float)
        parser.add_argument("--activation-dropout", "--relu-dropout", type=float)
        parser.add_argument("--activation-fn", type=str, default="relu")
        # decoder prenet
        parser.add_argument("--prenet-dropout", type=float)
        parser.add_argument("--prenet-layers", type=int)
        parser.add_argument("--prenet-dim", type=int)
        # decoder postnet
        parser.add_argument("--postnet-dropout", type=float)
        parser.add_argument("--postnet-layers", type=int)
        parser.add_argument("--postnet-conv-dim", type=int)
        parser.add_argument("--postnet-conv-kernel-size", type=int)
        # decoder transformer layers
        parser.add_argument("--decoder-transformer-layers", type=int)
        parser.add_argument("--decoder-embed-dim", type=int)
        parser.add_argument("--decoder-ffn-embed-dim", type=int)
        parser.add_argument("--decoder-normalize-before", action="store_true")
        parser.add_argument("--decoder-attention-heads", type=int)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._num_updates = 0

    @classmethod
    def build_model(cls, args, task):
        embed_speaker = task.get_speaker_embeddings(args)
        encoder = TTSTransformerEncoder(args, task.src_dict, embed_speaker)
        decoder = TTSTransformerDecoder(args, task.src_dict)
        return cls(encoder, decoder)

    def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs):
        return self.encoder(src_tokens, src_lengths=src_lengths,
                            speaker=speaker, **kwargs)

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)
        self._num_updates = num_updates


@register_model_architecture("tts_transformer", "tts_transformer")
def base_architecture(args):
    args.dropout = getattr(args, "dropout", 0.1)
    args.output_frame_dim = getattr(args, "output_frame_dim", 80)
    args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64)
    # encoder prenet
    args.encoder_dropout = getattr(args, "encoder_dropout", 0.5)
    args.encoder_conv_layers = getattr(args, "encoder_conv_layers", 3)
    args.encoder_conv_kernel_size = getattr(args, "encoder_conv_kernel_size", 5)
    # encoder transformer layers
    args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    # decoder prenet
    args.prenet_dropout = getattr(args, "prenet_dropout", 0.5)
    args.prenet_layers = getattr(args, "prenet_layers", 2)
    args.prenet_dim = getattr(args, "prenet_dim", 256)
    # decoder postnet
    args.postnet_dropout = getattr(args, "postnet_dropout", 0.5)
    args.postnet_layers = getattr(args, "postnet_layers", 5)
    args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512)
    args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5)
    # decoder transformer layers
    args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
197 fairseq/models/text_to_speech/vocoder.py Normal file
@@ -0,0 +1,197 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import json
from typing import Dict

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from fairseq.data.audio.audio_utils import (
    get_window, get_fourier_basis, get_mel_filters, TTSSpectrogram
)
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.models.text_to_speech.hifigan import Generator as HiFiGANModel

logger = logging.getLogger(__name__)


class PseudoInverseMelScale(torch.nn.Module):
    def __init__(self, n_stft, n_mels, sample_rate, f_min, f_max) -> None:
        super(PseudoInverseMelScale, self).__init__()
        self.n_mels = n_mels
        basis = get_mel_filters(
            sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max
        )
        basis = torch.pinverse(basis)  # F x F_mel
        self.register_buffer('basis', basis)

    def forward(self, melspec: torch.Tensor) -> torch.Tensor:
        # pack batch
        shape = melspec.shape  # B_1 x ... x B_K x F_mel x T
        n_mels, time = shape[-2], shape[-1]
        melspec = melspec.view(-1, n_mels, time)

        freq, _ = self.basis.size()  # F x F_mel
        assert self.n_mels == n_mels, (self.n_mels, n_mels)
        specgram = self.basis.matmul(melspec).clamp(min=0)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram


class GriffinLim(torch.nn.Module):
    def __init__(
        self, n_fft: int, win_length: int, hop_length: int, n_iter: int,
        window_fn=torch.hann_window
    ):
        super(GriffinLim, self).__init__()
        self.transform = TTSSpectrogram(
            n_fft, win_length, hop_length, return_phase=True
        )

        basis = get_fourier_basis(n_fft)
        basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :]
        basis *= get_window(window_fn, n_fft, win_length)
        self.register_buffer('basis', basis)

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_iter = n_iter

        self.tiny = 1.1754944e-38

    @classmethod
    def get_window_sum_square(
        cls, n_frames, hop_length, win_length, n_fft,
        window_fn=torch.hann_window
    ) -> torch.Tensor:
        w_sq = get_window(window_fn, n_fft, win_length) ** 2
        n = n_fft + hop_length * (n_frames - 1)
        x = torch.zeros(n, dtype=torch.float32)
        for i in range(n_frames):
            ofst = i * hop_length
            x[ofst: min(n, ofst + n_fft)] += w_sq[:max(0, min(n_fft, n - ofst))]
        return x

    def inverse(self, magnitude: torch.Tensor, phase) -> torch.Tensor:
        x = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)],
            dim=1
        )
        x = F.conv_transpose1d(x, self.basis, stride=self.hop_length)
        win_sum_sq = self.get_window_sum_square(
            magnitude.shape[-1], hop_length=self.hop_length,
            win_length=self.win_length, n_fft=self.n_fft
        ).to(magnitude.device)
        # remove modulation effects
        approx_nonzero_indices = win_sum_sq > self.tiny
        x[:, :, approx_nonzero_indices] /= win_sum_sq[approx_nonzero_indices]
        x *= self.n_fft / self.hop_length
        # trim n_fft // 2 samples of padding on both ends
        x = x[:, :, self.n_fft // 2:]
        x = x[:, :, :-self.n_fft // 2]
        return x

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:
        angles = np.angle(np.exp(2j * np.pi * np.random.rand(*specgram.shape)))
        angles = torch.from_numpy(angles).to(specgram)
        _specgram = specgram.view(-1, specgram.shape[-2], specgram.shape[-1])
        waveform = self.inverse(_specgram, angles).squeeze(1)
        for _ in range(self.n_iter):
            _, angles = self.transform(waveform)
            waveform = self.inverse(_specgram, angles).squeeze(1)
        return waveform.squeeze(0)


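# Editor's sketch (illustration only, not part of the original commit; it uses
# torch.stft/istft rather than the conv-based transforms above): one
# Griffin-Lim iteration keeps the target magnitude fixed and re-estimates
# phase from the previous waveform estimate.
def _griffin_lim_step(magnitude, waveform, n_fft=1024, hop=256):
    spec = torch.stft(waveform, n_fft, hop, return_complex=True)
    phase = spec / (spec.abs() + 1e-8)  # unit-magnitude phase estimate
    return torch.istft(magnitude * phase, n_fft, hop)

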
class GriffinLimVocoder(nn.Module):
    def __init__(self, sample_rate, win_size, hop_size, n_fft,
                 n_mels, f_min, f_max, window_fn,
                 spec_bwd_max_iter=32,
                 fp16=False):
        super().__init__()
        self.inv_mel_transform = PseudoInverseMelScale(
            n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate,
            f_min=f_min, f_max=f_max
        )
        self.gl_transform = GriffinLim(
            n_fft=n_fft, win_length=win_size, hop_length=hop_size,
            window_fn=window_fn, n_iter=spec_bwd_max_iter
        )
        if fp16:
            self.half()
            self.inv_mel_transform.half()
            self.gl_transform.half()
        else:
            self.float()
            self.inv_mel_transform.float()
            self.gl_transform.float()

    def forward(self, x):
        # x: (B x) T x D -> (B x) 1 x T
        # NOTE: batched forward produces noisier waveforms; running one
        # utterance at a time is recommended
        self.eval()
        x = x.exp().transpose(-1, -2)
        x = self.inv_mel_transform(x)
        x = self.gl_transform(x)
        return x

    @classmethod
    def from_data_cfg(cls, args, data_cfg: S2TDataConfig):
        feat_cfg = data_cfg.config["features"]
        window_fn = getattr(torch, feat_cfg["window_fn"] + "_window")
        return cls(
            sample_rate=feat_cfg["sample_rate"],
            win_size=int(feat_cfg["win_len_t"] * feat_cfg["sample_rate"]),
            hop_size=int(feat_cfg["hop_len_t"] * feat_cfg["sample_rate"]),
            n_fft=feat_cfg["n_fft"], n_mels=feat_cfg["n_mels"],
            f_min=feat_cfg["f_min"], f_max=feat_cfg["f_max"],
            window_fn=window_fn, spec_bwd_max_iter=args.spec_bwd_max_iter,
            fp16=args.fp16
        )


class HiFiGANVocoder(nn.Module):
    def __init__(
        self, checkpoint_path: str, model_cfg: Dict[str, str],
        fp16: bool = False
    ) -> None:
        super().__init__()
        self.model = HiFiGANModel(model_cfg)
        state_dict = torch.load(checkpoint_path)
        self.model.load_state_dict(state_dict["generator"])
        if fp16:
            self.model.half()
        logger.info(f"loaded HiFiGAN checkpoint from {checkpoint_path}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B x) T x D -> (B x) 1 x T
        model = self.model.eval()
        if len(x.shape) == 2:
            return model(x.unsqueeze(0).transpose(1, 2)).detach().squeeze(0)
        else:
            return model(x.transpose(-1, -2)).detach()

    @classmethod
    def from_data_cfg(cls, args, data_cfg: S2TDataConfig):
        vocoder_cfg = data_cfg.vocoder
        assert vocoder_cfg.get("type", "griffin_lim") == "hifigan"
        with open(vocoder_cfg["config"]) as f:
            model_cfg = json.load(f)
        return cls(vocoder_cfg["checkpoint"], model_cfg, fp16=args.fp16)


def get_vocoder(args, data_cfg: S2TDataConfig):
    if args.vocoder == "griffin_lim":
        return GriffinLimVocoder.from_data_cfg(args, data_cfg)
    elif args.vocoder == "hifigan":
        return HiFiGANVocoder.from_data_cfg(args, data_cfg)
    else:
        raise ValueError("Unknown vocoder")
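

# Editor's sketch (illustration only, not part of the original commit; the
# Namespace fields mirror the attributes get_vocoder reads, and the config
# path is hypothetical):
#
#   from argparse import Namespace
#   args = Namespace(vocoder="griffin_lim", spec_bwd_max_iter=32, fp16=False)
#   vocoder = get_vocoder(args, data_cfg)  # data_cfg: an S2TDataConfig
#   wave = vocoder(logmel_spectrogram)     # T x D log-mel input -> waveform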
@@ -25,6 +25,8 @@ from .layer_norm import Fp32LayerNorm, LayerNorm
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
 from .linearized_convolution import LinearizedConvolution
+from .location_attention import LocationAttention
+from .lstm_cell_with_zoneout import LSTMCellWithZoneOut
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
 from .same_pad import SamePad
@@ -63,6 +65,8 @@ __all__ = [
     "LightweightConv1dTBC",
     "LightweightConv",
     "LinearizedConvolution",
+    "LocationAttention",
+    "LSTMCellWithZoneOut",
     "MultiheadAttention",
     "PositionalEmbedding",
     "SamePad",
72 fairseq/modules/location_attention.py Normal file
@@ -0,0 +1,72 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch
import torch.nn.functional as F


class LocationAttention(nn.Module):
    """
    Attention-Based Models for Speech Recognition
    https://arxiv.org/pdf/1506.07503.pdf

    :param int attn_dim: attention dimension
    :param int encoder_dim: # projection-units of encoder
    :param int decoder_dim: # units of decoder
    :param int attn_state_kernel_size: # channels of the attention state
    :param int conv_dim: # channels of attention convolution
    :param int conv_kernel_size: filter size of attention convolution
    """

    def __init__(self, attn_dim, encoder_dim, decoder_dim,
                 attn_state_kernel_size, conv_dim, conv_kernel_size,
                 scaling=2.0):
        super(LocationAttention, self).__init__()
        self.attn_dim = attn_dim
        self.decoder_dim = decoder_dim
        self.scaling = scaling
        self.proj_enc = nn.Linear(encoder_dim, attn_dim)
        self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False)
        self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False)
        self.conv = nn.Conv1d(attn_state_kernel_size, conv_dim,
                              2 * conv_kernel_size + 1,
                              padding=conv_kernel_size, bias=False)
        self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1))

        self.proj_enc_out = None  # cache

    def clear_cache(self):
        self.proj_enc_out = None

    def forward(self, encoder_out, encoder_padding_mask, decoder_h, attn_state):
        """
        :param torch.Tensor encoder_out: padded encoder hidden state B x T x D
        :param torch.Tensor encoder_padding_mask: encoder padding mask
        :param torch.Tensor decoder_h: decoder hidden state B x D
        :param torch.Tensor attn_state: previous attention weights B x K x T
        :return: attention weighted encoder state (B, D)
        :rtype: torch.Tensor
        :return: attention weights (B x T)
        :rtype: torch.Tensor
        """
        bsz, seq_len, _ = encoder_out.size()
        if self.proj_enc_out is None:
            self.proj_enc_out = self.proj_enc(encoder_out)

        # B x K x T -> B x C x T
        attn = self.conv(attn_state)
        # B x C x T -> B x T x C -> B x T x D
        attn = self.proj_attn(attn.transpose(1, 2))

        if decoder_h is None:
            decoder_h = encoder_out.new_zeros(bsz, self.decoder_dim)
        dec_h = self.proj_dec(decoder_h).view(bsz, 1, self.attn_dim)

        out = self.proj_out(attn + self.proj_enc_out + dec_h).squeeze(2)
        out.masked_fill_(encoder_padding_mask, -float("inf"))

        w = F.softmax(self.scaling * out, dim=1)
        c = torch.sum(encoder_out * w.view(bsz, seq_len, 1), dim=1)
        return c, w
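

# Editor's sketch (illustration only, not part of the original commit):
# exercising LocationAttention with random tensors. K = attn_state_kernel_size
# rows carry the previous (and, optionally, cumulative) attention weights.
def _location_attention_demo():
    bsz, in_len = 2, 7
    attn = LocationAttention(attn_dim=16, encoder_dim=8, decoder_dim=4,
                             attn_state_kernel_size=2, conv_dim=3,
                             conv_kernel_size=5)
    enc_out = torch.randn(bsz, in_len, 8)
    enc_mask = torch.zeros(bsz, in_len, dtype=torch.bool)
    dec_h = torch.randn(bsz, 4)
    attn_state = torch.zeros(bsz, 2, in_len)
    ctx, w = attn(enc_out, enc_mask, dec_h, attn_state)
    assert ctx.shape == (bsz, 8) and w.shape == (bsz, in_len)
    assert torch.allclose(w.sum(dim=1), torch.ones(bsz))  # softmax over time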
37 fairseq/modules/lstm_cell_with_zoneout.py Normal file
@@ -0,0 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn


class LSTMCellWithZoneOut(nn.Module):
    """
    Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations
    https://arxiv.org/abs/1606.01305
    """

    def __init__(self, prob: float, input_size: int, hidden_size: int,
                 bias: bool = True):
        super(LSTMCellWithZoneOut, self).__init__()
        if prob > 1.0 or prob < 0.0:
            raise ValueError("zoneout probability must be in the range from "
                             "0.0 to 1.0.")
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
        self.prob = prob

    def zoneout(self, h, next_h, prob):
        if isinstance(h, tuple):
            return tuple(
                [self.zoneout(h[i], next_h[i], prob) for i in range(len(h))]
            )

        if self.training:
            mask = h.new_zeros(*h.size()).bernoulli_(prob)
            return mask * h + (1 - mask) * next_h

        # at inference time, use the expected value instead of sampling
        return prob * h + (1 - prob) * next_h

    def forward(self, x, h):
        return self.zoneout(h, self.lstm_cell(x, h), self.prob)
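

# Editor's sketch (illustration only, not part of the original commit): with
# prob=1.0 the previous state is always kept (the cell is frozen); with
# prob=0.0 the module behaves like a plain LSTMCell.
def _zoneout_demo():
    import torch
    torch.manual_seed(0)
    cell = LSTMCellWithZoneOut(prob=1.0, input_size=3, hidden_size=5)
    x = torch.randn(2, 3)
    h = (torch.randn(2, 5), torch.randn(2, 5))
    new_h = cell(x, h)
    assert torch.equal(new_h[0], h[0]) and torch.equal(new_h[1], h[1])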
86 fairseq/optim/lr_scheduler/step_lr_scheduler.py Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class StepLRScheduleConfig(FairseqDataclass):
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    warmup_init_lr: float = field(
        default=-1,
        metadata={
            "help": "initial learning rate during warmup phase; default is cfg.lr"
        },
    )
    lr: List[float] = field(
        default=II("optimization.lr"),
        metadata={"help": "max learning rate, must be more than cfg.min_lr"},
    )
    min_lr: float = field(default=0.0, metadata={"help": "min learning rate"})
    lr_decay_period: int = field(default=25000, metadata={"help": "decay period"})
    lr_decay: float = field(default=0.5, metadata={"help": "decay factor"})


@register_lr_scheduler("step", dataclass=StepLRScheduleConfig)
class StepLRSchedule(FairseqLRScheduler):
    """Decay learning rate every k updates by a fixed factor"""

    def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer):
        super().__init__(cfg, fairseq_optimizer)
        self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
        self.min_lr = cfg.min_lr
        self.lr_decay_period = cfg.lr_decay_period
        self.lr_decay = cfg.lr_decay
        self.warmup_updates = cfg.warmup_updates
        self.warmup_init_lr = (
            cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr
        )

        assert self.lr_decay_period > 0
        assert self.lr_decay <= 1
        assert self.min_lr >= 0
        assert self.max_lr > self.min_lr

        if cfg.warmup_updates > 0:
            # linearly warmup for the first cfg.warmup_updates
            self.warmup_lr_step = (
                (self.max_lr - self.warmup_init_lr) / self.warmup_updates
            )
        else:
            self.warmup_lr_step = 1

        # initial learning rate
        self.lr = self.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.cfg.warmup_updates:
            self.lr = self.warmup_init_lr + num_updates * self.warmup_lr_step
        else:
            curr_updates = num_updates - self.cfg.warmup_updates
            lr_mult = self.lr_decay ** (curr_updates // self.lr_decay_period)
            self.lr = max(self.max_lr * lr_mult, self.min_lr)

        self.optimizer.set_lr(self.lr)
        return self.lr
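The resulting schedule, traced in plain Python. This is a sketch mirroring the logic above with illustrative hyperparameters, not values from any recipe:

warmup_updates, warmup_init_lr = 100, 0.0
max_lr, min_lr = 1e-3, 1e-5
lr_decay, lr_decay_period = 0.5, 25000

def lr_at(num_updates):
    if num_updates < warmup_updates:
        step = (max_lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + num_updates * step
    curr = num_updates - warmup_updates
    return max(max_lr * lr_decay ** (curr // lr_decay_period), min_lr)

print(lr_at(50), lr_at(100), lr_at(25100), lr_at(10 ** 7))
# 0.0005 0.001 0.0005 1e-05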
@ -56,6 +56,14 @@ def get_generation_parser(interactive=False, default_task="translation"):
    return parser


def get_speech_generation_parser(default_task="text_to_speech"):
    parser = get_parser("Speech Generation", default_task)
    add_dataset_args(parser, gen=True)
    add_distributed_training_args(parser, default_world_size=1)
    add_speech_generation_args(parser)
    return parser


def get_interactive_generation_parser(default_task="translation"):
    return get_generation_parser(interactive=True, default_task=default_task)
@ -344,6 +352,16 @@ def add_generation_args(parser):
    return group


def add_speech_generation_args(parser):
    group = parser.add_argument_group("Speech Generation")
    add_common_eval_args(group)  # NOTE: remove_bpe is not needed
    # fmt: off
    group.add_argument('--eos_prob_threshold', default=0.5, type=float,
                       help='terminate when eos probability exceeds this')
    # fmt: on
    return group


def add_interactive_args(parser):
    group = parser.add_argument_group("Interactive")
    gen_parser_from_dataclass(group, InteractiveConfig())
219
fairseq/speech_generator.py
Normal file
@ -0,0 +1,219 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np

from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig


class SpeechGenerator(object):
    def __init__(self, model, vocoder, data_cfg: S2TDataConfig):
        self.model = model
        self.vocoder = vocoder
        stats_npz_path = data_cfg.global_cmvn_stats_npz
        self.gcmvn_stats = None
        if stats_npz_path is not None:
            self.gcmvn_stats = np.load(stats_npz_path)

    def gcmvn_denormalize(self, x):
        # x: B x T x C
        if self.gcmvn_stats is None:
            return x
        mean = torch.from_numpy(self.gcmvn_stats["mean"]).to(x)
        std = torch.from_numpy(self.gcmvn_stats["std"]).to(x)
        assert len(x.shape) == 3 and mean.shape[0] == std.shape[0] == x.shape[2]
        x = x * std.view(1, 1, -1).expand_as(x)
        return x + mean.view(1, 1, -1).expand_as(x)

    def get_waveform(self, feat):
        # T x C -> T
        return None if self.vocoder is None else self.vocoder(feat).squeeze(0)
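A small numeric check of the contract of gcmvn_denormalize: it inverts global CMVN feature normalization. The stats below are synthetic, for illustration only:

import torch

mean = torch.tensor([1.0, -2.0, 0.5])
std = torch.tensor([2.0, 0.5, 1.0])
feat = torch.randn(4, 10, 3)                     # B x T x C, normalized
denorm = feat * std.view(1, 1, -1) + mean.view(1, 1, -1)
renorm = (denorm - mean) / std
print(torch.allclose(renorm, feat, atol=1e-6))   # True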

class AutoRegressiveSpeechGenerator(SpeechGenerator):
    def __init__(
        self, model, vocoder, data_cfg, max_iter: int = 6000,
        eos_prob_threshold: float = 0.5,
    ):
        super().__init__(model, vocoder, data_cfg)
        self.max_iter = max_iter
        self.eos_prob_threshold = eos_prob_threshold

    @torch.no_grad()
    def generate(self, model, sample, has_targ=False, **kwargs):
        model.eval()

        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()
        n_frames_per_step = model.decoder.n_frames_per_step
        out_dim = model.decoder.out_dim
        raw_dim = out_dim // n_frames_per_step

        # initialize
        encoder_out = model.forward_encoder(src_tokens, src_lengths,
                                            speaker=sample["speaker"])
        incremental_state = {}
        feat, attn, eos_prob = [], [], []
        finished = src_tokens.new_zeros((bsz,)).bool()
        out_lens = src_lengths.new_zeros((bsz,)).long().fill_(self.max_iter)

        prev_feat_out = encoder_out["encoder_out"][0].new_zeros(bsz, 1, out_dim)
        for step in range(self.max_iter):
            cur_out_lens = out_lens.clone()
            cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1)
            _, cur_eos_out, cur_extra = model.forward_decoder(
                prev_feat_out, encoder_out=encoder_out,
                incremental_state=incremental_state,
                target_lengths=cur_out_lens, speaker=sample["speaker"], **kwargs
            )
            cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2)
            feat.append(cur_extra['feature_out'])
            attn.append(cur_extra['attn'])
            eos_prob.append(cur_eos_prob)

            cur_finished = (cur_eos_prob.squeeze(1) > self.eos_prob_threshold)
            out_lens.masked_fill_((~finished) & cur_finished, step + 1)
            finished = finished | cur_finished
            if finished.sum().item() == bsz:
                break
            prev_feat_out = cur_extra['feature_out']

        feat = torch.cat(feat, dim=1)
        feat = model.decoder.postnet(feat) + feat
        eos_prob = torch.cat(eos_prob, dim=1)
        attn = torch.cat(attn, dim=2)
        alignment = attn.max(dim=1)[1]

        feat = feat.reshape(bsz, -1, raw_dim)
        feat = self.gcmvn_denormalize(feat)

        eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1)
        attn = attn.repeat_interleave(n_frames_per_step, dim=2)
        alignment = alignment.repeat_interleave(n_frames_per_step, dim=1)
        out_lens = out_lens * n_frames_per_step

        finalized = [
            {
                'feature': feat[b, :out_len],
                'eos_prob': eos_prob[b, :out_len],
                'attn': attn[b, :, :out_len],
                'alignment': alignment[b, :out_len],
                'waveform': self.get_waveform(feat[b, :out_len]),
            }
            for b, out_len in zip(range(bsz), out_lens)
        ]

        if has_targ:
            assert sample["target"].size(-1) == out_dim
            tgt_feats = sample["target"].view(bsz, -1, raw_dim)
            tgt_feats = self.gcmvn_denormalize(tgt_feats)
            tgt_lens = sample["target_lengths"] * n_frames_per_step
            for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)):
                finalized[b]["targ_feature"] = f[:l]
                finalized[b]["targ_waveform"] = self.get_waveform(f[:l])
        return finalized
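The stopping rule in isolation, as a sketch: a sequence is marked finished the first time its EOS probability crosses the threshold, and its length is frozen at that step (the probabilities below are made up):

import torch

max_iter, threshold = 5, 0.5
finished = torch.zeros(3, dtype=torch.bool)
out_lens = torch.full((3,), max_iter, dtype=torch.long)
eos_probs = torch.tensor([[0.1, 0.2, 0.6, 0.7, 0.9],
                          [0.7, 0.1, 0.1, 0.1, 0.1],
                          [0.1, 0.1, 0.1, 0.1, 0.1]])
for step in range(max_iter):
    cur_finished = eos_probs[:, step] > threshold
    out_lens.masked_fill_((~finished) & cur_finished, step + 1)
    finished = finished | cur_finished
    if finished.all():
        break
print(out_lens.tolist())  # [3, 1, 5]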

class NonAutoregressiveSpeechGenerator(SpeechGenerator):
    @torch.no_grad()
    def generate(self, model, sample, has_targ=False, **kwargs):
        model.eval()

        bsz, max_src_len = sample["net_input"]["src_tokens"].size()
        n_frames_per_step = model.encoder.n_frames_per_step
        out_dim = model.encoder.out_dim
        raw_dim = out_dim // n_frames_per_step

        feat, out_lens, log_dur_out, _, _ = model(
            src_tokens=sample["net_input"]["src_tokens"],
            src_lengths=sample["net_input"]["src_lengths"],
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            incremental_state=None,
            target_lengths=sample["target_lengths"],
            speaker=sample["speaker"]
        )

        feat = feat.view(bsz, -1, raw_dim)
        feat = self.gcmvn_denormalize(feat)

        dur_out = torch.clamp(
            torch.round(torch.exp(log_dur_out) - 1).long(), min=0
        )

        def get_dur_plot_data(d):
            r = []
            for i, dd in enumerate(d):
                r += [i + 1] * dd.item()
            return r

        out_lens = out_lens * n_frames_per_step
        finalized = [
            {
                'feature': feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]),
                'waveform': self.get_waveform(
                    feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim])
                ),
                'attn': feat.new_tensor(get_dur_plot_data(dur_out[b])),
            }
            for b, l in zip(range(bsz), out_lens)
        ]

        if has_targ:
            tgt_feats = sample["target"].view(bsz, -1, raw_dim)
            tgt_feats = self.gcmvn_denormalize(tgt_feats)
            tgt_lens = sample["target_lengths"] * n_frames_per_step
            for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)):
                finalized[b]["targ_feature"] = f[:l]
                finalized[b]["targ_waveform"] = self.get_waveform(f[:l])
        return finalized
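Duration decoding above in miniature: the predicted log-durations are exponentiated, shifted by one, rounded, and clamped to be non-negative (the inputs here are made-up values):

import torch

log_dur_out = torch.tensor([[-0.1, 0.8, 1.6]])
dur = torch.clamp(torch.round(torch.exp(log_dur_out) - 1).long(), min=0)
print(dur.tolist())  # [[0, 1, 4]]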

class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator):
    @torch.no_grad()
    def generate(self, model, sample, has_targ=False, **kwargs):
        model.eval()

        src_tokens = sample["net_input"]["src_tokens"]
        src_lens = sample["net_input"]["src_lengths"]
        prev_out_tokens = sample["net_input"]["prev_output_tokens"]
        tgt_lens = sample["target_lengths"]
        n_frames_per_step = model.decoder.n_frames_per_step
        raw_dim = model.decoder.out_dim // n_frames_per_step
        bsz = src_tokens.shape[0]

        feat, eos_prob, extra = model(
            src_tokens, src_lens, prev_out_tokens, incremental_state=None,
            target_lengths=tgt_lens, speaker=sample["speaker"]
        )

        attn = extra["attn"]  # B x T_s x T_t
        alignment = attn.max(dim=1)[1]
        feat = feat.reshape(bsz, -1, raw_dim)
        feat = self.gcmvn_denormalize(feat)
        eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1)
        attn = attn.repeat_interleave(n_frames_per_step, dim=2)
        alignment = alignment.repeat_interleave(n_frames_per_step, dim=1)
        tgt_lens = sample["target_lengths"] * n_frames_per_step

        finalized = [
            {
                'feature': feat[b, :tgt_len],
                'eos_prob': eos_prob[b, :tgt_len],
                'attn': attn[b, :, :tgt_len],
                'alignment': alignment[b, :tgt_len],
                'waveform': self.get_waveform(feat[b, :tgt_len]),
            }
            for b, tgt_len in zip(range(bsz), tgt_lens)
        ]

        if has_targ:
            tgt_feats = sample["target"].view(bsz, -1, raw_dim)
            tgt_feats = self.gcmvn_denormalize(tgt_feats)
            for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)):
                finalized[b]["targ_feature"] = f[:l]
                finalized[b]["targ_waveform"] = self.get_waveform(f[:l])
        return finalized
56
fairseq/tasks/frm_text_to_speech.py
Normal file
@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from fairseq.data.audio.frm_text_to_speech_dataset import FrmTextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.text_to_speech import TextToSpeechTask


logging.basicConfig(
    format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO
)
logger = logging.getLogger(__name__)


@register_task('frm_text_to_speech')
class FrmTextToSpeechTask(TextToSpeechTask):
    @staticmethod
    def add_args(parser):
        TextToSpeechTask.add_args(parser)
        parser.add_argument(
            "--do_chunk", action="store_true", help="train on chunks"
        )
        parser.add_argument("--chunk_bound", default=-1, type=int)
        parser.add_argument("--chunk_init", default=50, type=int)
        parser.add_argument("--chunk_incr", default=5, type=int)
        parser.add_argument("--add_eos", action="store_true")
        parser.add_argument("--dedup", action="store_true")
        parser.add_argument("--ref_fpu", default=-1, type=float)

    def load_dataset(self, split, **unused_kwargs):
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = FrmTextToSpeechDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            n_frames_per_step=self.args.n_frames_per_step,
            speaker_to_id=self.speaker_to_id,
            do_chunk=self.args.do_chunk,
            chunk_bound=self.args.chunk_bound,
            chunk_init=self.args.chunk_init,
            chunk_incr=self.args.chunk_incr,
            add_eos=self.args.add_eos,
            dedup=self.args.dedup,
            ref_fpu=self.args.ref_fpu
        )
@ -50,6 +50,16 @@ class SpeechToTextTask(LegacyFairseqTask):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        self.speaker_to_id = self._get_speaker_to_id()

    def _get_speaker_to_id(self):
        speaker_to_id = None
        speaker_set_filename = self.data_cfg.config.get("speaker_set_filename")
        if speaker_set_filename is not None:
            speaker_set_path = Path(self.args.data) / speaker_set_filename
            with open(speaker_set_path) as f:
                speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
        return speaker_to_id

    @classmethod
    def setup_task(cls, args, **kwargs):
@ -91,6 +101,7 @@ class SpeechToTextTask(LegacyFairseqTask):
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            speaker_to_id=self.speaker_to_id
        )

    @property
@ -107,6 +118,7 @@ class SpeechToTextTask(LegacyFairseqTask):
    def build_model(self, args):
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_channels
        args.speaker_to_id = self.speaker_to_id
        return super(SpeechToTextTask, self).build_model(args)

    def build_generator(
@ -126,12 +138,13 @@ class SpeechToTextTask(LegacyFairseqTask):
            for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }

        if extra_gen_cls_kwargs is None:
            extra_gen_cls_kwargs = {}
        extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids
        return super().build_generator(
            models, args, seq_gen_cls=None,
            extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )

    def build_tokenizer(self, args):
467
fairseq/tasks/text_to_speech.py
Normal file
@ -0,0 +1,467 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import os.path as op

import torch
import torch.nn.functional as F
import numpy as np

from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.speech_generator import (
    AutoRegressiveSpeechGenerator, NonAutoregressiveSpeechGenerator,
    TeacherForcingAutoRegressiveSpeechGenerator
)

logging.basicConfig(
    format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO
)
logger = logging.getLogger(__name__)


try:
    from tensorboardX import SummaryWriter
except ImportError:
    logger.info("Please install tensorboardX: pip install tensorboardX")
    SummaryWriter = None


@register_task('text_to_speech')
class TextToSpeechTask(SpeechToTextTask):
    @staticmethod
    def add_args(parser):
        parser.add_argument('data', help='manifest root path')
        parser.add_argument(
            '--config-yaml', type=str, default='config.yaml',
            help='Configuration YAML filename (under manifest root)'
        )
        parser.add_argument('--max-source-positions', default=1024, type=int,
                            metavar='N',
                            help='max number of tokens in the source sequence')
        parser.add_argument('--max-target-positions', default=1200, type=int,
                            metavar='N',
                            help='max number of tokens in the target sequence')
        parser.add_argument("--n-frames-per-step", type=int, default=1)
        parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
        parser.add_argument("--eval-inference", action="store_true")
        parser.add_argument("--eval-tb-nsample", type=int, default=8)
        parser.add_argument("--vocoder", type=str, default="griffin_lim")
        parser.add_argument("--spec-bwd-max-iter", type=int, default=8)

    def __init__(self, args, src_dict):
        super().__init__(args, src_dict)
        self.src_dict = src_dict
        self.sr = self.data_cfg.config.get("features").get("sample_rate")

        self.tensorboard_writer = None
        self.tensorboard_dir = ""
        if args.tensorboard_logdir and SummaryWriter is not None:
            self.tensorboard_dir = os.path.join(args.tensorboard_logdir,
                                                "valid_extra")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        is_train_split = split.startswith('train')
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = TextToSpeechDatasetCreator.from_tsv(
            self.args.data, self.data_cfg, split, self.src_dict,
            pre_tokenizer, bpe_tokenizer, is_train_split=is_train_split,
            epoch=epoch, seed=self.args.seed,
            n_frames_per_step=self.args.n_frames_per_step,
            speaker_to_id=self.speaker_to_id
        )

    @property
    def target_dictionary(self):
        return None

    @property
    def source_dictionary(self):
        return self.src_dict

    def get_speaker_embeddings_path(self):
        speaker_emb_path = None
        if self.data_cfg.config.get("speaker_emb_filename") is not None:
            speaker_emb_path = op.join(
                self.args.data, self.data_cfg.config.get("speaker_emb_filename")
            )
        return speaker_emb_path

    @classmethod
    def get_speaker_embeddings(cls, args):
        embed_speaker = None
        if args.speaker_to_id is not None:
            if args.speaker_emb_path is None:
                embed_speaker = torch.nn.Embedding(
                    len(args.speaker_to_id), args.speaker_embed_dim
                )
            else:
                speaker_emb_mat = np.load(args.speaker_emb_path)
                assert speaker_emb_mat.shape[1] == args.speaker_embed_dim
                embed_speaker = torch.nn.Embedding.from_pretrained(
                    torch.from_numpy(speaker_emb_mat), freeze=True,
                )
                logger.info(
                    f"load speaker embeddings from {args.speaker_emb_path}. "
                    f"train embedding? {embed_speaker.weight.requires_grad}\n"
                    f"embeddings:\n{speaker_emb_mat}"
                )
        return embed_speaker

    def build_model(self, cfg):
        cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None)
        cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None)
        cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None)
        cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None)
        cfg.speaker_emb_path = self.get_speaker_embeddings_path()
        model = super().build_model(cfg)
        self.generator = None
        if getattr(cfg, "eval_inference", False):
            self.generator = self.build_generator([model], cfg)
        return model

    def build_generator(self, models, cfg, vocoder=None, **unused):
        if vocoder is None:
            vocoder = self.build_default_vocoder()
        model = models[0]
        if getattr(model, "NON_AUTOREGRESSIVE", False):
            return NonAutoregressiveSpeechGenerator(
                model, vocoder, self.data_cfg
            )
        else:
            generator = AutoRegressiveSpeechGenerator
            if getattr(cfg, "teacher_forcing", False):
                generator = TeacherForcingAutoRegressiveSpeechGenerator
                logger.info("Teacher forcing mode for generation")
            return generator(
                model, vocoder, self.data_cfg,
                max_iter=self.args.max_target_positions,
                eos_prob_threshold=self.args.eos_prob_threshold
            )

    def build_default_vocoder(self):
        from fairseq.models.text_to_speech.vocoder import get_vocoder
        vocoder = get_vocoder(self.args, self.data_cfg)
        if torch.cuda.is_available() and not self.args.cpu:
            vocoder = vocoder.cuda()
        else:
            vocoder = vocoder.cpu()
        return vocoder

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(
            sample, model, criterion
        )

        if getattr(self.args, "eval_inference", False):
            hypos, inference_losses = self.valid_step_with_inference(
                sample, model, self.generator
            )
            for k, v in inference_losses.items():
                assert k not in logging_output
                logging_output[k] = v

            picked_id = 0
            if self.tensorboard_dir and (sample["id"] == picked_id).any():
                self.log_tensorboard(
                    sample,
                    hypos[:self.args.eval_tb_nsample],
                    model._num_updates,
                    is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False)
                )
        return loss, sample_size, logging_output

    def valid_step_with_inference(self, sample, model, generator):
        hypos = generator.generate(model, sample, has_targ=True)

        losses = {
            "mcd_loss": 0.,
            "targ_frames": 0.,
            "pred_frames": 0.,
            "nins": 0.,
            "ndel": 0.,
        }
        rets = batch_mel_cepstral_distortion(
            [hypo["targ_waveform"] for hypo in hypos],
            [hypo["waveform"] for hypo in hypos],
            self.sr,
            normalize_type=None
        )
        for d, extra in rets:
            pathmap = extra[-1]
            losses["mcd_loss"] += d.item()
            losses["targ_frames"] += pathmap.size(0)
            losses["pred_frames"] += pathmap.size(1)
            losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
            losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()

        return hypos, losses
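    # Toy illustration of the nins/ndel counts above (a sketch, not part of
    # the class): rows of the DTW path map are target frames, columns are
    # predicted frames. A row covered more than once means extra predicted
    # frames (insertions); a column covered more than once means deletions.
    #
    #   pathmap = torch.tensor([[1, 1, 0],
    #                           [0, 1, 0],
    #                           [0, 0, 1]])
    #   (pathmap.sum(dim=1) - 1).sum()  # -> 1 insertion
    #   (pathmap.sum(dim=0) - 1).sum()  # -> 1 deletion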

    def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False):
        if self.tensorboard_writer is None:
            self.tensorboard_writer = SummaryWriter(self.tensorboard_dir)
        tb_writer = self.tensorboard_writer
        for b in range(len(hypos)):
            idx = sample["id"][b]
            text = sample["src_texts"][b]
            targ = hypos[b]["targ_feature"]
            pred = hypos[b]["feature"]
            attn = hypos[b]["attn"]

            if is_na_model:
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1)],
                    [f"target (idx={idx})", "output"], attn,
                    "alignment", ret_np=True, suptitle=text,
                )
            else:
                eos_prob = hypos[b]["eos_prob"]
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1), attn],
                    [f"target (idx={idx})", "output", "alignment"], eos_prob,
                    "eos prob", ret_np=True, suptitle=text,
                )

            tb_writer.add_image(
                f"inference_sample_{b}", data, num_updates,
                dataformats="HWC"
            )

            if hypos[b]["waveform"] is not None:
                targ_wave = hypos[b]["targ_waveform"].detach().cpu().float()
                pred_wave = hypos[b]["waveform"].detach().cpu().float()
                tb_writer.add_audio(
                    f"inference_targ_{b}",
                    targ_wave,
                    num_updates,
                    sample_rate=self.sr
                )
                tb_writer.add_audio(
                    f"inference_pred_{b}",
                    pred_wave,
                    num_updates,
                    sample_rate=self.sr
                )


def save_figure_to_numpy(fig):
    # np.frombuffer replaces the deprecated np.fromstring for raw byte buffers
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data

DEFAULT_V_MIN = np.log(1e-5)


def plot_tts_output(
        data_2d, title_2d, data_1d, title_1d, figsize=(24, 4),
        v_min=DEFAULT_V_MIN, v_max=3, ret_np=False, suptitle=""
):
    try:
        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import make_axes_locatable
    except ImportError:
        raise ImportError("Please install Matplotlib: pip install matplotlib")

    data_2d = [
        x.detach().cpu().float().numpy()
        if isinstance(x, torch.Tensor) else x for x in data_2d
    ]
    fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize)
    if suptitle:
        fig.suptitle(suptitle[:400])  # capped at 400 chars
    axes = [axes] if len(data_2d) == 0 else axes
    for ax, x, name in zip(axes, data_2d, title_2d):
        ax.set_title(name)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        im = ax.imshow(
            x, origin="lower", aspect="auto", vmin=max(x.min(), v_min),
            vmax=min(x.max(), v_max)
        )
        fig.colorbar(im, cax=cax, orientation='vertical')

    if isinstance(data_1d, torch.Tensor):
        data_1d = data_1d.detach().cpu().numpy()
    axes[-1].plot(data_1d)
    axes[-1].set_title(title_1d)
    plt.tight_layout()

    if ret_np:
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
        plt.close(fig)
        return data

def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None):
    """
    for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4, outputs

    offset=2 (1, 1),
    offset=3 (2, 1), (1, 2)
    offset=4 (2, 2), (1, 3)
    offset=5 (2, 3)

    constraints:
        i + j = offset
        min_j <= j < max_j
        min_i <= offset - j < max_i
    """
    if max_i is None:
        max_i = offset + 1
    if max_j is None:
        max_j = offset + 1
    min_j = max(min_j, offset - max_i + 1, 0)
    max_j = min(max_j, offset - min_i + 1, offset + 1)
    j = torch.arange(min_j, max_j)
    i = offset - j
    return torch.stack([i, j])
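The docstring's example can be checked directly (this sketch assumes the function above is in scope):

for offset in range(2, 6):
    print(offset, antidiag_indices(offset, 1, 3, 1, 4).t().tolist())
# 2 [[1, 1]]
# 3 [[2, 1], [1, 2]]
# 4 [[2, 2], [1, 3]]
# 5 [[2, 3]]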

def batch_dynamic_time_warping(distance, shapes=None):
    """full batched DTW without any constraints

    distance: (batchsize, max_M, max_N) matrix
    shapes: (batchsize,) vector specifying (M, N) for each entry
    """
    # ptr: 0=left, 1=up-left, 2=up
    ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)}

    bsz, m, n = distance.size()
    cumdist = torch.zeros_like(distance)
    backptr = torch.zeros_like(distance).type(torch.int32) - 1

    # initialize
    cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1)
    cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1)
    backptr[:, 0, :] = 0
    backptr[:, :, 0] = 2

    # DP with optimized anti-diagonal parallelization, O(M+N) steps
    for offset in range(2, m + n - 1):
        ind = antidiag_indices(offset, 1, m, 1, n)
        c = torch.stack(
            [cumdist[:, ind[0], ind[1] - 1],
             cumdist[:, ind[0] - 1, ind[1] - 1],
             cumdist[:, ind[0] - 1, ind[1]]],
            dim=2
        )
        v, b = c.min(axis=-1)
        backptr[:, ind[0], ind[1]] = b.int()
        cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]]

    # backtrace
    pathmap = torch.zeros_like(backptr)
    for b in range(bsz):
        i = m - 1 if shapes is None else (shapes[b][0] - 1).item()
        j = n - 1 if shapes is None else (shapes[b][1] - 1).item()
        dtwpath = [(i, j)]
        while (i != 0 or j != 0) and len(dtwpath) < 10000:
            assert (i >= 0 and j >= 0)
            di, dj = ptr2dij[backptr[b, i, j].item()]
            i, j = i + di, j + dj
            dtwpath.append((i, j))
        dtwpath = dtwpath[::-1]
        indices = torch.from_numpy(np.array(dtwpath))
        pathmap[b, indices[:, 0], indices[:, 1]] = 1

    return cumdist, backptr, pathmap
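A tiny sanity check (a sketch assuming the two functions above are in scope): with a distance matrix that is cheap on the diagonal, DTW recovers the diagonal path at zero cost.

import torch

d = torch.tensor([[[0., 1., 1.],
                   [1., 0., 1.],
                   [1., 1., 0.]]])
cumdist, backptr, pathmap = batch_dynamic_time_warping(d)
print(cumdist[0, -1, -1].item())  # 0.0
print(pathmap[0].tolist())        # [[1, 0, 0], [0, 1, 0], [0, 0, 1]]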

def compute_l2_dist(x1, x2):
    """compute an (m, n) L2 distance matrix from (m, d) and (n, d) matrices"""
    return torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2)


def compute_rms_dist(x1, x2):
    l2_dist = compute_l2_dist(x1, x2)
    return (l2_dist / x1.size(1)).pow(0.5)
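A quick shape and value check for the two distance helpers (assumes the functions above; inputs are synthetic):

import torch

x1 = torch.zeros(2, 3)
x2 = torch.ones(4, 3)
print(compute_l2_dist(x1, x2).shape)    # torch.Size([2, 4]); every entry 3.0
print(compute_rms_dist(x1, x2)[0, 0])   # tensor(1.) since sqrt(3 / 3) = 1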

def get_divisor(pathmap, normalize_type):
    if normalize_type is None:
        return 1
    elif normalize_type == "len1":
        return pathmap.size(0)
    elif normalize_type == "len2":
        return pathmap.size(1)
    elif normalize_type == "path":
        return pathmap.sum().item()
    else:
        raise ValueError(f"normalize_type {normalize_type} not supported")


def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type):
    d, s, x1, x2 = [], [], [], []
    for cur_y1, cur_y2 in zip(y1, y2):
        assert (cur_y1.ndim == 1 and cur_y2.ndim == 1)
        cur_x1 = feat_fn(cur_y1)
        cur_x2 = feat_fn(cur_y2)
        x1.append(cur_x1)
        x2.append(cur_x2)

        cur_d = dist_fn(cur_x1, cur_x2)
        d.append(cur_d)
        s.append(d[-1].size())
    max_m = max(ss[0] for ss in s)
    max_n = max(ss[1] for ss in s)
    d = torch.stack(
        [F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d]
    )
    s = torch.LongTensor(s).to(d.device)
    cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s)

    rets = []
    itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps)
    for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr:
        cumdist = cumdist[:m, :n]
        backptr = backptr[:m, :n]
        pathmap = pathmap[:m, :n]
        divisor = get_divisor(pathmap, normalize_type)

        distortion = cumdist[-1, -1] / divisor
        ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap)
        rets.append(ret)
    return rets


def batch_mel_cepstral_distortion(
        y1, y2, sr, normalize_type="path", mfcc_fn=None
):
    """
    https://arxiv.org/pdf/2011.03568.pdf

    The root mean squared error computed on 13-dimensional MFCC using DTW for
    alignment. MFCC features are computed from an 80-channel log-mel
    spectrogram using a 50ms Hann window and hop of 12.5ms.

    y1: list of waveforms
    y2: list of waveforms
    sr: sampling rate
    """

    try:
        import torchaudio
    except ImportError:
        raise ImportError("Please install torchaudio: pip install torchaudio")

    if mfcc_fn is None or mfcc_fn.sample_rate != sr:
        melkwargs = {
            "n_fft": int(0.05 * sr), "win_length": int(0.05 * sr),
            "hop_length": int(0.0125 * sr), "f_min": 20,
            "n_mels": 80, "window_fn": torch.hann_window
        }
        mfcc_fn = torchaudio.transforms.MFCC(
            sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs
        ).to(y1[0].device)
    return batch_compute_distortion(
        y1, y2, sr, lambda y: mfcc_fn(y).transpose(-1, -2), compute_rms_dist,
        normalize_type
    )
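A usage sketch for the MCD helper above, with synthetic waveforms (requires torchaudio; the 220 Hz tones and one-second length are arbitrary choices for illustration):

import math
import torch

sr = 16000
t = torch.arange(sr).float() / sr
y1 = [torch.sin(2 * math.pi * 220 * t)]              # 1 s tone
y2 = [torch.sin(2 * math.pi * 220 * t[: sr // 2])]   # same tone, half length
rets = batch_mel_cepstral_distortion(y1, y2, sr)
distortion, (x1, x2, dist, cumdist, backptr, pathmap) = rets[0]
print(distortion.item(), pathmap.shape)  # low distortion; (M, N) path map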