Summary: [fairseq-py] add TTS

Reviewed By: wnhsu

Differential Revision: D30720666

fbshipit-source-id: b5288acec72bea1d3a9f3884a4ed51b616c7a403
Changhan Wang 2021-09-13 18:12:38 -07:00 committed by Facebook GitHub Bot
parent 32b31173aa
commit 0ac3f3270c
26 changed files with 3801 additions and 8 deletions


@@ -0,0 +1,320 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from pathlib import Path
from typing import Optional, List, Dict
import zipfile
import tempfile
from dataclasses import dataclass
from itertools import groupby
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from examples.speech_to_text.data_utils import load_tsv_to_dicts
from fairseq.data.audio.audio_utils import TTSSpectrogram, TTSMelScale
def trim_or_pad_to_target_length(
data_1d_or_2d: np.ndarray, target_length: int
) -> np.ndarray:
assert len(data_1d_or_2d.shape) in {1, 2}
delta = data_1d_or_2d.shape[0] - target_length
if delta >= 0: # trim if being longer
data_1d_or_2d = data_1d_or_2d[: target_length]
else: # pad if being shorter
if len(data_1d_or_2d.shape) == 1:
data_1d_or_2d = np.concatenate(
[data_1d_or_2d, np.zeros(-delta)], axis=0
)
else:
data_1d_or_2d = np.concatenate(
[data_1d_or_2d, np.zeros((-delta, data_1d_or_2d.shape[1]))],
axis=0
)
return data_1d_or_2d
def extract_logmel_spectrogram(
waveform: torch.Tensor, sample_rate: int,
output_path: Optional[Path] = None, win_length: int = 1024,
hop_length: int = 256, n_fft: int = 1024,
win_fn: callable = torch.hann_window, n_mels: int = 80,
f_min: float = 0., f_max: float = 8000, eps: float = 1e-5,
overwrite: bool = False, target_length: Optional[int] = None
):
if output_path is not None and output_path.is_file() and not overwrite:
return
spectrogram_transform = TTSSpectrogram(
n_fft=n_fft, win_length=win_length, hop_length=hop_length,
window_fn=win_fn
)
mel_scale_transform = TTSMelScale(
n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
n_stft=n_fft // 2 + 1
)
spectrogram = spectrogram_transform(waveform)
mel_spec = mel_scale_transform(spectrogram)
logmel_spec = torch.clamp(mel_spec, min=eps).log()
assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1
logmel_spec = logmel_spec.squeeze().t() # D x T -> T x D
if target_length is not None:
logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length)
if output_path is not None:
np.save(output_path.as_posix(), logmel_spec)
else:
return logmel_spec
def extract_pitch(
waveform: torch.Tensor, sample_rate: int,
output_path: Optional[Path] = None, hop_length: int = 256,
log_scale: bool = True, phoneme_durations: Optional[List[int]] = None
):
if output_path is not None and output_path.is_file():
return
try:
import pyworld
except ImportError:
raise ImportError("Please install PyWORLD: pip install pyworld")
_waveform = waveform.squeeze(0).double().numpy()
pitch, t = pyworld.dio(
_waveform, sample_rate, frame_period=hop_length / sample_rate * 1000
)
pitch = pyworld.stonemask(_waveform, pitch, t, sample_rate)
if phoneme_durations is not None:
pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations))
try:
from scipy.interpolate import interp1d
except ImportError:
raise ImportError("Please install SciPy: pip install scipy")
nonzero_ids = np.where(pitch != 0)[0]
interp_fn = interp1d(
nonzero_ids,
pitch[nonzero_ids],
fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
bounds_error=False,
)
pitch = interp_fn(np.arange(0, len(pitch)))
d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
pitch = np.array(
[
np.mean(pitch[d_cumsum[i-1]: d_cumsum[i]])
for i in range(1, len(d_cumsum))
]
)
assert len(pitch) == len(phoneme_durations)
if log_scale:
pitch = np.log(pitch + 1)
if output_path is not None:
np.save(output_path.as_posix(), pitch)
else:
return pitch
def extract_energy(
waveform: torch.Tensor, output_path: Optional[Path] = None,
hop_length: int = 256, n_fft: int = 1024, log_scale: bool = True,
phoneme_durations: Optional[List[int]] = None
):
if output_path is not None and output_path.is_file():
return
assert len(waveform.shape) == 2 and waveform.shape[0] == 1
waveform = waveform.view(1, 1, waveform.shape[1])
waveform = F.pad(
waveform.unsqueeze(1), [n_fft // 2, n_fft // 2, 0, 0],
mode="reflect"
)
waveform = waveform.squeeze(1)
fourier_basis = np.fft.fft(np.eye(n_fft))
cutoff = int((n_fft / 2 + 1))
fourier_basis = np.vstack(
[np.real(fourier_basis[:cutoff, :]),
np.imag(fourier_basis[:cutoff, :])]
)
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
forward_transform = F.conv1d(
waveform, forward_basis, stride=hop_length, padding=0
)
real_part = forward_transform[:, :cutoff, :]
imag_part = forward_transform[:, cutoff:, :]
magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
energy = torch.norm(magnitude, dim=1).squeeze(0).numpy()
if phoneme_durations is not None:
energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations))
d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
energy = np.array(
[
np.mean(energy[d_cumsum[i - 1]: d_cumsum[i]])
for i in range(1, len(d_cumsum))
]
)
assert len(energy) == len(phoneme_durations)
if log_scale:
energy = np.log(energy + 1)
if output_path is not None:
np.save(output_path.as_posix(), energy)
else:
return energy
def get_global_cmvn(feature_root: Path, output_path: Optional[Path] = None):
mean_x, mean_x2, n_frames = None, None, 0
feature_paths = feature_root.glob("*.npy")
for p in tqdm(feature_paths):
with open(p, 'rb') as f:
frames = np.load(f).squeeze()
n_frames += frames.shape[0]
cur_mean_x = frames.sum(axis=0)
if mean_x is None:
mean_x = cur_mean_x
else:
mean_x += cur_mean_x
cur_mean_x2 = (frames ** 2).sum(axis=0)
if mean_x2 is None:
mean_x2 = cur_mean_x2
else:
mean_x2 += cur_mean_x2
mean_x /= n_frames
mean_x2 /= n_frames
var_x = mean_x2 - mean_x ** 2
std_x = np.sqrt(np.maximum(var_x, 1e-10))
if output_path is not None:
with open(output_path, 'wb') as f:
np.savez(f, mean=mean_x, std=std_x)
else:
return {"mean": mean_x, "std": std_x}
def ipa_phonemize(text, lang="en-us", use_g2p=False):
if use_g2p:
assert lang == "en-us", "g2pE phonemizer only works for en-us"
try:
from g2p_en import G2p
g2p = G2p()
return " ".join("|" if p == " " else p for p in g2p(text))
except ImportError:
raise ImportError(
"Please install phonemizer: pip install g2p_en"
)
else:
try:
from phonemizer import phonemize
from phonemizer.separator import Separator
return phonemize(
text, backend='espeak', language=lang,
separator=Separator(word="| ", phone=" ")
)
except ImportError:
raise ImportError(
"Please install phonemizer: pip install phonemizer"
)
@dataclass
class ForceAlignmentInfo(object):
tokens: List[str]
frame_durations: List[int]
start_sec: Optional[float]
end_sec: Optional[float]
def get_mfa_alignment_by_sample_id(
textgrid_zip_path: str, sample_id: str, sample_rate: int,
hop_length: int, silence_phones: List[str] = ("sil", "sp", "spn")
) -> ForceAlignmentInfo:
try:
import tgt
except ImportError:
raise ImportError("Please install TextGridTools: pip install tgt")
filename = f"{sample_id}.TextGrid"
out_root = Path(tempfile.gettempdir())
tgt_path = out_root / filename
with zipfile.ZipFile(textgrid_zip_path) as f_zip:
f_zip.extract(filename, path=out_root)
textgrid = tgt.io.read_textgrid(tgt_path.as_posix())
os.remove(tgt_path)
phones, frame_durations = [], []
start_sec, end_sec, end_idx = 0, 0, 0
for t in textgrid.get_tier_by_name("phones")._objects:
s, e, p = t.start_time, t.end_time, t.text
# Trim leading silences
if len(phones) == 0:
if p in silence_phones:
continue
else:
start_sec = s
phones.append(p)
if p not in silence_phones:
end_sec = e
end_idx = len(phones)
r = sample_rate / hop_length
frame_durations.append(int(np.round(e * r) - np.round(s * r)))
# Trim tailing silences
phones = phones[:end_idx]
frame_durations = frame_durations[:end_idx]
return ForceAlignmentInfo(
tokens=phones, frame_durations=frame_durations, start_sec=start_sec,
end_sec=end_sec
)
def get_mfa_alignment(
textgrid_zip_path: str, sample_ids: List[str], sample_rate: int,
hop_length: int
) -> Dict[str, ForceAlignmentInfo]:
return {
i: get_mfa_alignment_by_sample_id(
textgrid_zip_path, i, sample_rate, hop_length
) for i in tqdm(sample_ids)
}
def get_unit_alignment(
id_to_unit_tsv_path: str, sample_ids: List[str]
) -> Dict[str, ForceAlignmentInfo]:
id_to_units = {
e["id"]: e["units"] for e in load_tsv_to_dicts(id_to_unit_tsv_path)
}
id_to_units = {i: id_to_units[i].split() for i in sample_ids}
id_to_units_collapsed = {
i: [uu for uu, _ in groupby(u)] for i, u in id_to_units.items()
}
id_to_durations = {
i: [len(list(g)) for _, g in groupby(u)] for i, u in id_to_units.items()
}
return {
i: ForceAlignmentInfo(
tokens=id_to_units_collapsed[i], frame_durations=id_to_durations[i],
start_sec=None, end_sec=None
)
for i in sample_ids
}
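
A minimal usage sketch for the feature-extraction helpers above (extract_logmel_spectrogram, extract_pitch, get_global_cmvn). It assumes those functions are importable from wherever this file lands in the tree (the path is not shown in this view) and that torchaudio loads a mono clip; the file names below are placeholders.

from pathlib import Path

import torchaudio

# Placeholder input and output locations.
wav_path = "sample.wav"
feat_dir = Path("logmelspec80")
feat_dir.mkdir(exist_ok=True)

waveform, sample_rate = torchaudio.load(wav_path)  # waveform: 1 x n_samples

# Write an 80-band log-Mel spectrogram to logmelspec80/sample.npy; the
# function returns None when output_path is given and the array otherwise.
extract_logmel_spectrogram(
    waveform, sample_rate, output_path=feat_dir / "sample.npy",
    n_mels=80, n_fft=1024, win_length=1024, hop_length=256,
)

# Frame-level log-F0 via PyWORLD; pass phoneme_durations to get per-phoneme
# averages instead of per-frame values.
pitch = extract_pitch(waveform, sample_rate, output_path=None, hop_length=256)

# Global mean/variance statistics over every .npy file in feat_dir.
stats = get_global_cmvn(feat_dir)
print(pitch.shape, stats["mean"].shape)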


@@ -0,0 +1,191 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import soundfile as sf
import sys
import torch
import torchaudio
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.logging import progress_bar
from fairseq.tasks.text_to_speech import plot_tts_output
from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDataset
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def make_parser():
parser = options.get_speech_generation_parser()
parser.add_argument("--dump-features", action="store_true")
parser.add_argument("--dump-waveforms", action="store_true")
parser.add_argument("--dump-attentions", action="store_true")
parser.add_argument("--dump-eos-probs", action="store_true")
parser.add_argument("--dump-plots", action="store_true")
parser.add_argument("--dump-target", action="store_true")
parser.add_argument("--output-sample-rate", default=22050, type=int)
parser.add_argument("--teacher-forcing", action="store_true")
parser.add_argument(
"--audio-format", type=str, default="wav", choices=["wav", "flac"]
)
return parser
def postprocess_results(
dataset: TextToSpeechDataset, sample, hypos, resample_fn, dump_target
):
def to_np(x):
return None if x is None else x.detach().cpu().numpy()
sample_ids = [dataset.ids[i] for i in sample["id"].tolist()]
texts = sample["src_texts"]
attns = [to_np(hypo["attn"]) for hypo in hypos]
eos_probs = [to_np(hypo.get("eos_prob", None)) for hypo in hypos]
feat_preds = [to_np(hypo["feature"]) for hypo in hypos]
wave_preds = [to_np(resample_fn(h["waveform"])) for h in hypos]
if dump_target:
feat_targs = [to_np(hypo["targ_feature"]) for hypo in hypos]
wave_targs = [to_np(resample_fn(h["targ_waveform"])) for h in hypos]
else:
feat_targs = [None for _ in hypos]
wave_targs = [None for _ in hypos]
return zip(sample_ids, texts, attns, eos_probs, feat_preds, wave_preds,
feat_targs, wave_targs)
def dump_result(
is_na_model,
args,
vocoder,
sample_id,
text,
attn,
eos_prob,
feat_pred,
wave_pred,
feat_targ,
wave_targ,
):
sample_rate = args.output_sample_rate
out_root = Path(args.results_path)
if args.dump_features:
feat_dir = out_root / "feat"
feat_dir.mkdir(exist_ok=True, parents=True)
np.save(feat_dir / f"{sample_id}.npy", feat_pred)
if args.dump_target:
feat_tgt_dir = out_root / "feat_tgt"
feat_tgt_dir.mkdir(exist_ok=True, parents=True)
np.save(feat_tgt_dir / f"{sample_id}.npy", feat_targ)
if args.dump_attentions:
attn_dir = out_root / "attn"
attn_dir.mkdir(exist_ok=True, parents=True)
np.save(attn_dir / f"{sample_id}.npy", attn.numpy())
if args.dump_eos_probs and not is_na_model:
eos_dir = out_root / "eos"
eos_dir.mkdir(exist_ok=True, parents=True)
np.save(eos_dir / f"{sample_id}.npy", eos_prob)
if args.dump_plots:
images = [feat_pred.T] if is_na_model else [feat_pred.T, attn]
names = ["output"] if is_na_model else ["output", "alignment"]
if feat_targ is not None:
images = [feat_targ.T] + images
names = [f"target (idx={sample_id})"] + names
if is_na_model:
plot_tts_output(images, names, attn, "alignment", suptitle=text)
else:
plot_tts_output(images, names, eos_prob, "eos prob", suptitle=text)
plot_dir = out_root / "plot"
plot_dir.mkdir(exist_ok=True, parents=True)
plt.savefig(plot_dir / f"{sample_id}.png")
plt.close()
if args.dump_waveforms:
ext = args.audio_format
if wave_pred is not None:
wav_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}"
wav_dir.mkdir(exist_ok=True, parents=True)
sf.write(wav_dir / f"{sample_id}.{ext}", wave_pred, sample_rate)
if args.dump_target and wave_targ is not None:
wav_tgt_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}_tgt"
wav_tgt_dir.mkdir(exist_ok=True, parents=True)
sf.write(wav_tgt_dir / f"{sample_id}.{ext}", wave_targ, sample_rate)
def main(args):
assert(args.dump_features or args.dump_waveforms or args.dump_attentions
or args.dump_eos_probs or args.dump_plots)
if args.max_tokens is None and args.batch_size is None:
args.max_tokens = 8000
logger.info(args)
use_cuda = torch.cuda.is_available() and not args.cpu
task = tasks.setup_task(args)
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
[args.path],
task=task,
)
model = models[0].cuda() if use_cuda else models[0]
# use the original n_frames_per_step
task.args.n_frames_per_step = saved_cfg.task.n_frames_per_step
task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)
data_cfg = task.data_cfg
sample_rate = data_cfg.config.get("features", {}).get("sample_rate", 22050)
resample_fn = {
False: lambda x: x,
True: lambda x: torchaudio.sox_effects.apply_effects_tensor(
x.detach().cpu().unsqueeze(0), sample_rate,
[['rate', str(args.output_sample_rate)]]
)[0].squeeze(0)
}.get(args.output_sample_rate != sample_rate)
if args.output_sample_rate != sample_rate:
logger.info(f"resampling to {args.output_sample_rate}Hz")
generator = task.build_generator([model], args)
itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_positions=(sys.maxsize, sys.maxsize),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
data_buffer_size=args.data_buffer_size,
).next_epoch_itr(shuffle=False)
Path(args.results_path).mkdir(exist_ok=True, parents=True)
is_na_model = getattr(model, "NON_AUTOREGRESSIVE", False)
dataset = task.dataset(args.gen_subset)
vocoder = task.args.vocoder
with progress_bar.build_progress_bar(args, itr) as t:
for sample in t:
sample = utils.move_to_cuda(sample) if use_cuda else sample
hypos = generator.generate(model, sample, has_targ=args.dump_target)
for result in postprocess_results(
dataset, sample, hypos, resample_fn, args.dump_target
):
dump_result(is_na_model, args, vocoder, *result)
def cli_main():
parser = make_parser()
args = options.parse_args_and_arch(parser)
main(args)
if __name__ == "__main__":
cli_main()
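
For reference, a sketch of driving this script programmatically instead of from the shell. The positional data directory, --path, --gen-subset, --max-tokens and --results-path come from fairseq's standard speech-generation parser; the --dump-* flags are defined in make_parser above. The paths and the task name "text_to_speech" are assumptions made for illustration, not values taken from this commit.

def example_invocation():
    # Illustrative arguments only; replace the paths with a prepared data
    # directory and a trained TTS checkpoint.
    argv = [
        "/path/to/data-bin",
        "--task", "text_to_speech",        # assumed task name
        "--path", "/path/to/checkpoint.pt",
        "--gen-subset", "test",
        "--max-tokens", "8000",
        "--results-path", "/path/to/results",
        "--dump-waveforms",                # at least one --dump-* flag is required
        "--dump-plots",
    ]
    parser = make_parser()
    args = options.parse_args_and_arch(parser, input_args=argv)
    main(args)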


@@ -0,0 +1,101 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from scipy.interpolate import interp1d
import torchaudio
from fairseq.tasks.text_to_speech import (
batch_compute_distortion, compute_rms_dist
)
def batch_mel_spectral_distortion(
y1, y2, sr, normalize_type="path", mel_fn=None
):
"""
https://arxiv.org/pdf/2011.03568.pdf
Same as Mel Cepstral Distortion, but computed on log-mel spectrograms.
"""
if mel_fn is None or mel_fn.sample_rate != sr:
mel_fn = torchaudio.transforms.MelSpectrogram(
sr, n_fft=int(0.05 * sr), win_length=int(0.05 * sr),
hop_length=int(0.0125 * sr), f_min=20, n_mels=80,
window_fn=torch.hann_window
).to(y1[0].device)
offset = 1e-6
return batch_compute_distortion(
y1, y2, sr, lambda y: torch.log(mel_fn(y) + offset).transpose(-1, -2),
compute_rms_dist, normalize_type
)
# This code is based on
# "https://github.com/bastibe/MAPS-Scripts/blob/master/helper.py"
def _same_t_in_true_and_est(func):
def new_func(true_t, true_f, est_t, est_f):
assert type(true_t) is np.ndarray
assert type(true_f) is np.ndarray
assert type(est_t) is np.ndarray
assert type(est_f) is np.ndarray
interpolated_f = interp1d(
est_t, est_f, bounds_error=False, kind='nearest', fill_value=0
)(true_t)
return func(true_t, true_f, true_t, interpolated_f)
return new_func
@_same_t_in_true_and_est
def gross_pitch_error(true_t, true_f, est_t, est_f):
"""The relative frequency in percent of pitch estimates that are
outside a threshold around the true pitch. Only frames that are
considered pitched by both the ground truth and the estimator (if
applicable) are considered.
"""
correct_frames = _true_voiced_frames(true_t, true_f, est_t, est_f)
gross_pitch_error_frames = _gross_pitch_error_frames(
true_t, true_f, est_t, est_f
)
return np.sum(gross_pitch_error_frames) / np.sum(correct_frames)
def _gross_pitch_error_frames(true_t, true_f, est_t, est_f, eps=1e-8):
voiced_frames = _true_voiced_frames(true_t, true_f, est_t, est_f)
true_f_p_eps = [x + eps for x in true_f]
pitch_error_frames = np.abs(est_f / true_f_p_eps - 1) > 0.2
return voiced_frames & pitch_error_frames
def _true_voiced_frames(true_t, true_f, est_t, est_f):
return (est_f != 0) & (true_f != 0)
def _voicing_decision_error_frames(true_t, true_f, est_t, est_f):
return (est_f != 0) != (true_f != 0)
@_same_t_in_true_and_est
def f0_frame_error(true_t, true_f, est_t, est_f):
gross_pitch_error_frames = _gross_pitch_error_frames(
true_t, true_f, est_t, est_f
)
voicing_decision_error_frames = _voicing_decision_error_frames(
true_t, true_f, est_t, est_f
)
return (np.sum(gross_pitch_error_frames) +
np.sum(voicing_decision_error_frames)) / (len(true_t))
@_same_t_in_true_and_est
def voicing_decision_error(true_t, true_f, est_t, est_f):
voicing_decision_error_frames = _voicing_decision_error_frames(
true_t, true_f, est_t, est_f
)
return np.sum(voicing_decision_error_frames) / (len(true_t))
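A self-contained sanity check of the pitch metrics above on two short synthetic tracks (the functions are assumed to be in scope). Unvoiced frames are marked with zeros, matching the convention the helpers rely on.

import numpy as np

# Ten frames at 10 ms hops; zeros mark unvoiced frames.
true_t = np.arange(10) * 0.01
true_f = np.array([0, 0, 220, 220, 220, 230, 230, 0, 0, 0], dtype=float)
est_t = true_t.copy()
est_f = np.array([0, 0, 220, 280, 222, 231, 0, 0, 0, 0], dtype=float)

# GPE: among frames voiced in both tracks, the fraction whose estimate is
# off by more than 20% (here 1 of 4 -> 0.25).
print("GPE:", gross_pitch_error(true_t, true_f, est_t, est_f))
# VDE: fraction of frames with a voiced/unvoiced mismatch (1 of 10 -> 0.1).
print("VDE:", voicing_decision_error(true_t, true_f, est_t, est_f))
# FFE: gross pitch errors plus voicing errors over all frames (2 of 10 -> 0.2).
print("FFE:", f0_frame_error(true_t, true_f, est_t, est_f))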


@@ -0,0 +1,125 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from typing import List, Dict, Any
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import lengths_to_mask
from fairseq.models.fairseq_model import FairseqEncoderModel
@dataclass
class FastSpeech2CriterionConfig(FairseqDataclass):
ctc_weight: float = field(
default=0.0, metadata={"help": "weight for CTC loss"}
)
@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
class FastSpeech2Loss(FairseqCriterion):
def __init__(self, task, ctc_weight):
super().__init__(task)
self.ctc_weight = ctc_weight
def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
src_tokens = sample["net_input"]["src_tokens"]
src_lens = sample["net_input"]["src_lengths"]
tgt_lens = sample["target_lengths"]
_feat_out, _, log_dur_out, pitch_out, energy_out = model(
src_tokens=src_tokens,
src_lengths=src_lens,
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None,
target_lengths=tgt_lens,
speaker=sample["speaker"],
durations=sample["durations"],
pitches=sample["pitches"],
energies=sample["energies"]
)
src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
tgt_mask = lengths_to_mask(sample["target_lengths"])
pitches, energies = sample["pitches"], sample["energies"]
pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
energy_out, energies = energy_out[src_mask], energies[src_mask]
feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)
log_dur_out = log_dur_out[src_mask]
dur = sample["durations"].float()
dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur
log_dur = torch.log(dur + 1)[src_mask]
dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)
ctc_loss = torch.tensor(0.).type_as(l1_loss)
if self.ctc_weight > 0.:
lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
lprobs = lprobs.transpose(0, 1) # T x B x C
src_mask = lengths_to_mask(src_lens)
src_tokens_flat = src_tokens.masked_select(src_mask)
ctc_loss = F.ctc_loss(
lprobs, src_tokens_flat, tgt_lens, src_lens,
reduction=reduction, zero_infinity=True
) * self.ctc_weight
loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss
sample_size = sample["nsentences"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"dur_loss": utils.item(dur_loss.data),
"pitch_loss": utils.item(pitch_loss.data),
"energy_loss": utils.item(energy_loss.data),
"ctc_loss": utils.item(ctc_loss.data),
}
return loss, sample_size, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
ns = [log.get("sample_size", 0) for log in logging_outputs]
ntot = sum(ns)
ws = [n / (ntot + 1e-8) for n in ns]
for key in [
"loss", "l1_loss", "dur_loss", "pitch_loss", "energy_loss",
"ctc_loss"
]:
vals = [log.get(key, 0) for log in logging_outputs]
val = sum(val * w for val, w in zip(vals, ws))
metrics.log_scalar(key, val, ntot, round=3)
metrics.log_scalar("sample_size", ntot, len(logging_outputs))
# inference metrics
if "targ_frames" not in logging_outputs[0]:
return
n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
return False
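
A toy illustration, with hypothetical numbers, of the masked selection the criterion above applies before its element-wise losses: predictions and targets are gathered over valid, non-padded positions only, so padding never contributes. The mask shown is what fairseq's lengths_to_mask produces.

import torch
import torch.nn.functional as F

# Two sequences of lengths 3 and 2, padded to length 3.
src_lens = torch.tensor([3, 2])
src_mask = torch.arange(3).unsqueeze(0) < src_lens.unsqueeze(1)  # B x T, True on valid tokens

pitch_out = torch.randn(2, 3)   # per-token pitch predictions
pitches = torch.randn(2, 3)     # per-token pitch targets (padding ignored)

# Same pattern as pitch_out[src_mask], pitches[src_mask] in forward():
# both sides flatten to the 5 valid entries before the MSE.
pitch_loss = F.mse_loss(pitch_out[src_mask], pitches[src_mask], reduction="mean")
print(pitch_out[src_mask].shape, pitch_loss)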


@@ -0,0 +1,210 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
from typing import Any, Dict, List
from functools import lru_cache
from dataclasses import dataclass, field
import torch
from omegaconf import II
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import lengths_to_mask
import torch.nn.functional as F
logger = logging.getLogger(__name__)
@dataclass
class Tacotron2CriterionConfig(FairseqDataclass):
bce_pos_weight: float = field(
default=1.0,
metadata={"help": "weight of positive examples for BCE loss"},
)
n_frames_per_step: int = field(
default=0,
metadata={"help": "Number of frames per decoding step"},
)
use_guided_attention_loss: bool = field(
default=False,
metadata={"help": "use guided attention loss"},
)
guided_attention_loss_sigma: float = field(
default=0.4,
metadata={"help": "sigma of the guided attention loss"},
)
ctc_weight: float = field(
default=0.0, metadata={"help": "weight for CTC loss"}
)
sentence_avg: bool = II("optimization.sentence_avg")
class GuidedAttentionLoss(torch.nn.Module):
"""
Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
Networks with Guided Attention (https://arxiv.org/abs/1710.08969)
"""
def __init__(self, sigma):
super().__init__()
self.sigma = sigma
@staticmethod
@lru_cache(maxsize=8)
def _get_weight(s_len, t_len, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len))
grid_x = grid_x.to(s_len.device)
grid_y = grid_y.to(s_len.device)
w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
return 1.0 - torch.exp(-w / (2 * (sigma ** 2)))
def _get_weights(self, src_lens, tgt_lens):
bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
weights = torch.zeros((bsz, max_t_len, max_s_len))
for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len,
self.sigma)
return weights
@staticmethod
def _get_masks(src_lens, tgt_lens):
in_masks = lengths_to_mask(src_lens)
out_masks = lengths_to_mask(tgt_lens)
return out_masks.unsqueeze(2) & in_masks.unsqueeze(1)
def forward(self, attn, src_lens, tgt_lens, reduction="mean"):
weights = self._get_weights(src_lens, tgt_lens).to(attn.device)
masks = self._get_masks(src_lens, tgt_lens).to(attn.device)
loss = (weights * attn.transpose(1, 2)).masked_select(masks)
loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss)
return loss
@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
class Tacotron2Criterion(FairseqCriterion):
def __init__(self, task, sentence_avg, n_frames_per_step,
use_guided_attention_loss, guided_attention_loss_sigma,
bce_pos_weight, ctc_weight):
super().__init__(task)
self.sentence_avg = sentence_avg
self.n_frames_per_step = n_frames_per_step
self.bce_pos_weight = bce_pos_weight
self.guided_attn = None
if use_guided_attention_loss:
self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma)
self.ctc_weight = ctc_weight
def forward(self, model, sample, reduction="mean"):
bsz, max_len, _ = sample["target"].size()
feat_tgt = sample["target"]
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
eos_tgt = torch.arange(max_len).to(sample["target"].device)
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
eos_tgt = (eos_tgt == (feat_len - 1)).float()
src_tokens = sample["net_input"]["src_tokens"]
src_lens = sample["net_input"]["src_lengths"]
tgt_lens = sample["target_lengths"]
feat_out, eos_out, extra = model(
src_tokens=src_tokens,
src_lengths=src_lens,
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None,
target_lengths=tgt_lens,
speaker=sample["speaker"]
)
l1_loss, mse_loss, eos_loss = self.compute_loss(
extra["feature_out"], feat_out, eos_out, feat_tgt, eos_tgt,
tgt_lens, reduction,
)
attn_loss = torch.tensor(0.).type_as(l1_loss)
if self.guided_attn is not None:
attn_loss = self.guided_attn(extra['attn'], src_lens, tgt_lens, reduction)
ctc_loss = torch.tensor(0.).type_as(l1_loss)
if self.ctc_weight > 0.:
net_output = (feat_out, eos_out, extra)
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.transpose(0, 1) # T x B x C
src_mask = lengths_to_mask(src_lens)
src_tokens_flat = src_tokens.masked_select(src_mask)
ctc_loss = F.ctc_loss(
lprobs, src_tokens_flat, tgt_lens, src_lens,
reduction=reduction, zero_infinity=True
) * self.ctc_weight
loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss
sample_size = sample["nsentences"] if self.sentence_avg \
else sample["ntokens"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"mse_loss": utils.item(mse_loss.data),
"eos_loss": utils.item(eos_loss.data),
"attn_loss": utils.item(attn_loss.data),
"ctc_loss": utils.item(ctc_loss.data),
}
return loss, sample_size, logging_output
def compute_loss(self, feat_out, feat_out_post, eos_out, feat_tgt,
eos_tgt, tgt_lens, reduction="mean"):
mask = lengths_to_mask(tgt_lens)
_eos_out = eos_out[mask].squeeze()
_eos_tgt = eos_tgt[mask]
_feat_tgt = feat_tgt[mask]
_feat_out = feat_out[mask]
_feat_out_post = feat_out_post[mask]
l1_loss = (
F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) +
F.l1_loss(_feat_out_post, _feat_tgt, reduction=reduction)
)
mse_loss = (
F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) +
F.mse_loss(_feat_out_post, _feat_tgt, reduction=reduction)
)
eos_loss = F.binary_cross_entropy_with_logits(
_eos_out, _eos_tgt, pos_weight=torch.tensor(self.bce_pos_weight),
reduction=reduction
)
return l1_loss, mse_loss, eos_loss
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
ns = [log.get("sample_size", 0) for log in logging_outputs]
ntot = sum(ns)
ws = [n / (ntot + 1e-8) for n in ns]
for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]:
vals = [log.get(key, 0) for log in logging_outputs]
val = sum(val * w for val, w in zip(vals, ws))
metrics.log_scalar(key, val, ntot, round=3)
metrics.log_scalar("sample_size", ntot, len(logging_outputs))
# inference metrics
if "targ_frames" not in logging_outputs[0]:
return
n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
return False
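
The guided attention term above penalises attention mass that sits far from the source-target diagonal. A short sketch of its behaviour on a toy batch, assuming the GuidedAttentionLoss class defined in this file is in scope; the attention map is shaped B x T_src x T_tgt to match the transpose inside forward().

import torch

src_lens = torch.tensor([5, 3])   # source tokens per utterance
tgt_lens = torch.tensor([8, 6])   # output frames per utterance

guided_attn = GuidedAttentionLoss(sigma=0.4)

# A perfectly uniform attention map gets a relatively large penalty; a
# roughly diagonal alignment would score lower.
uniform_attn = torch.full((2, 5, 8), 1.0 / 5)
print(guided_attn(uniform_attn, src_lens, tgt_lens, reduction="mean"))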


@@ -0,0 +1,207 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import csv
import logging
import os.path as op
from typing import List, Optional
import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig
)
from fairseq.data.audio.text_to_speech_dataset import (
TextToSpeechDataset, TextToSpeechDatasetCreator
)
logger = logging.getLogger(__name__)
class FrmTextToSpeechDataset(TextToSpeechDataset):
def __init__(
self,
split: str,
is_train_split: bool,
data_cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
do_chunk=False,
chunk_bound=-1,
chunk_init=50,
chunk_incr=5,
add_eos=True,
dedup=True,
ref_fpu=-1
):
# It assumes texts are encoded at a fixed frame-rate
super().__init__(
split=split,
is_train_split=is_train_split,
data_cfg=data_cfg,
audio_paths=audio_paths,
n_frames=n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id
)
self.do_chunk = do_chunk
self.chunk_bound = chunk_bound
self.chunk_init = chunk_init
self.chunk_incr = chunk_incr
self.add_eos = add_eos
self.dedup = dedup
self.ref_fpu = ref_fpu
self.chunk_size = -1
if do_chunk:
assert self.chunk_incr >= 0
assert self.pre_tokenizer is None
def __getitem__(self, index):
index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
if target[-1].item() == self.tgt_dict.eos_index:
target = target[:-1]
fpu = source.size(0) / target.size(0) # frame-per-unit
fps = self.n_frames_per_step
assert (
self.ref_fpu == -1 or
abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
), f"{fpu*fps} != {self.ref_fpu}"
# only chunk training split
if self.is_train_split and self.do_chunk and self.chunk_size > 0:
lang = target[:int(self.data_cfg.prepend_tgt_lang_tag)]
text = target[int(self.data_cfg.prepend_tgt_lang_tag):]
size = len(text)
chunk_size = min(self.chunk_size, size)
chunk_start = np.random.randint(size - chunk_size + 1)
text = text[chunk_start:chunk_start+chunk_size]
target = torch.cat((lang, text), 0)
f_size = int(np.floor(chunk_size * fpu))
f_start = int(np.floor(chunk_start * fpu))
assert(f_size > 0)
source = source[f_start:f_start+f_size, :]
if self.dedup:
target = torch.unique_consecutive(target)
if self.add_eos:
eos_idx = self.tgt_dict.eos_index
target = torch.cat((target, torch.LongTensor([eos_idx])), 0)
return index, source, target, speaker_id
def set_epoch(self, epoch):
if self.is_train_split and self.do_chunk:
old = self.chunk_size
self.chunk_size = self.chunk_init + epoch * self.chunk_incr
if self.chunk_bound > 0:
self.chunk_size = min(self.chunk_size, self.chunk_bound)
logger.info((
f"{self.split}: setting chunk size "
f"from {old} to {self.chunk_size}"
))
class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
# inherit for key names
@classmethod
def from_tsv(
cls,
root: str,
data_cfg: S2TDataConfig,
split: str,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split: bool,
n_frames_per_step: int,
speaker_to_id,
do_chunk: bool = False,
chunk_bound: int = -1,
chunk_init: int = 50,
chunk_incr: int = 5,
add_eos: bool = True,
dedup: bool = True,
ref_fpu: float = -1
) -> FrmTextToSpeechDataset:
tsv_path = op.join(root, f"{split}.tsv")
if not op.isfile(tsv_path):
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
with open(tsv_path) as f:
reader = csv.DictReader(
f,
delimiter="\t",
quotechar=None,
doublequote=False,
lineterminator="\n",
quoting=csv.QUOTE_NONE,
)
s = [dict(e) for e in reader]
assert len(s) > 0
ids = [ss[cls.KEY_ID] for ss in s]
audio_paths = [
op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s
]
n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]
return FrmTextToSpeechDataset(
split=split,
is_train_split=is_train_split,
data_cfg=data_cfg,
audio_paths=audio_paths,
n_frames=n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
do_chunk=do_chunk,
chunk_bound=chunk_bound,
chunk_init=chunk_init,
chunk_incr=chunk_incr,
add_eos=add_eos,
dedup=dedup,
ref_fpu=ref_fpu
)
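
The chunked-training schedule in set_epoch above grows the text chunk linearly with the epoch and caps it at chunk_bound. A tiny sketch with the defaults from from_tsv (chunk_init=50, chunk_incr=5) and an illustrative bound of 80; chunk_bound defaults to -1, i.e. unbounded.

chunk_init, chunk_incr, chunk_bound = 50, 5, 80  # bound is illustrative

for epoch in range(1, 9):
    chunk_size = chunk_init + epoch * chunk_incr
    if chunk_bound > 0:
        chunk_size = min(chunk_size, chunk_bound)
    print(epoch, chunk_size)  # 55, 60, 65, 70, 75, 80, 80, 80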


@@ -0,0 +1,215 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
import numpy as np
import torch
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig,
_collate_frames, get_features_or_waveform
)
from fairseq.data import Dictionary, data_utils as fairseq_data_utils
@dataclass
class TextToSpeechDatasetItem(object):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
speaker_id: Optional[int] = None
duration: Optional[torch.Tensor] = None
pitch: Optional[torch.Tensor] = None
energy: Optional[torch.Tensor] = None
class TextToSpeechDataset(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
durations: Optional[List[List[int]]] = None,
pitches: Optional[List[str]] = None,
energies: Optional[List[str]] = None
):
super(TextToSpeechDataset, self).__init__(
split, is_train_split, cfg, audio_paths, n_frames,
src_texts=src_texts, tgt_texts=tgt_texts, speakers=speakers,
src_langs=src_langs, tgt_langs=tgt_langs, ids=ids,
tgt_dict=tgt_dict, pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer, n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id
)
self.durations = durations
self.pitches = pitches
self.energies = energies
def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
s2t_item = super().__getitem__(index)
duration, pitch, energy = None, None, None
if self.durations is not None:
duration = torch.tensor(
self.durations[index] + [0], dtype=torch.long # pad 0 for EOS
)
if self.pitches is not None:
pitch = get_features_or_waveform(self.pitches[index])
pitch = torch.from_numpy(
np.concatenate((pitch, [0])) # pad 0 for EOS
).float()
if self.energies is not None:
energy = get_features_or_waveform(self.energies[index])
energy = torch.from_numpy(
np.concatenate((energy, [0])) # pad 0 for EOS
).float()
return TextToSpeechDatasetItem(
index=index, source=s2t_item.source, target=s2t_item.target,
speaker_id=s2t_item.speaker_id, duration=duration, pitch=pitch,
energy=energy
)
def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
if len(samples) == 0:
return {}
src_lengths, order = torch.tensor(
[s.target.shape[0] for s in samples], dtype=torch.long
).sort(descending=True)
id_ = torch.tensor([s.index for s in samples],
dtype=torch.long).index_select(0, order)
feat = _collate_frames(
[s.source for s in samples], self.cfg.use_audio_input
).index_select(0, order)
target_lengths = torch.tensor(
[s.source.shape[0] for s in samples], dtype=torch.long
).index_select(0, order)
src_tokens = fairseq_data_utils.collate_tokens(
[s.target for s in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
).index_select(0, order)
speaker = None
if self.speaker_to_id is not None:
speaker = torch.tensor(
[s.speaker_id for s in samples], dtype=torch.long
).index_select(0, order).view(-1, 1)
bsz, _, d = feat.size()
prev_output_tokens = torch.cat(
(feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
)
durations, pitches, energies = None, None, None
if self.durations is not None:
durations = fairseq_data_utils.collate_tokens(
[s.duration for s in samples], 0
).index_select(0, order)
assert src_tokens.shape[1] == durations.shape[1]
if self.pitches is not None:
pitches = _collate_frames([s.pitch for s in samples], True)
pitches = pitches.index_select(0, order)
assert src_tokens.shape[1] == pitches.shape[1]
if self.energies is not None:
energies = _collate_frames([s.energy for s in samples], True)
energies = energies.index_select(0, order)
assert src_tokens.shape[1] == energies.shape[1]
src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
return {
"id": id_,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens,
},
"speaker": speaker,
"target": feat,
"durations": durations,
"pitches": pitches,
"energies": energies,
"target_lengths": target_lengths,
"ntokens": sum(target_lengths).item(),
"nsentences": len(samples),
"src_texts": src_texts,
}
class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
KEY_DURATION = "duration"
KEY_PITCH = "pitch"
KEY_ENERGY = "energy"
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
cfg: S2TDataConfig,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id
) -> TextToSpeechDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
durations = [s.get(cls.KEY_DURATION, None) for s in samples]
durations = [
None if dd is None else [int(d) for d in dd.split(" ")]
for dd in durations
]
durations = None if any(dd is None for dd in durations) else durations
pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
pitches = [
None if pp is None else (audio_root / pp).as_posix()
for pp in pitches
]
pitches = None if any(pp is None for pp in pitches) else pitches
energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
energies = [
None if ee is None else (audio_root / ee).as_posix()
for ee in energies]
energies = None if any(ee is None for ee in energies) else energies
return TextToSpeechDataset(
split_name, is_train_split, cfg, audio_paths, n_frames,
src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict,
pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id,
durations, pitches, energies
)
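
For orientation, one manifest row as _from_list above consumes it. The duration/pitch/energy columns are defined in this file; the remaining column names come from the parent SpeechToTextDatasetCreator and the literal strings below are assumptions (they are not shown in this diff), with placeholder values.

# Hypothetical TSV row, already parsed into a dict by the loader.
sample = {
    "id": "LJ001-0001",                       # cls.KEY_ID (assumed literal)
    "audio": "logmelspec80/LJ001-0001.npy",   # relative to cfg.audio_root
    "n_frames": "745",
    "tgt_text": "P R IH1 N T IH0 NG",
    "speaker": "ljspeech",
    # Optional FastSpeech 2 supervision added by this dataset:
    "duration": "3 5 4 7 2",                  # per-token frame counts
    "pitch": "pitch/LJ001-0001.npy",          # per-token pitch values
    "energy": "energy/LJ001-0001.npy",        # per-token energy values
}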


@@ -92,7 +92,8 @@ class Dictionary:
)
extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
extra_symbols_to_ignore.add(self.eos())
if not include_eos:
extra_symbols_to_ignore.add(self.eos())
def token_string(i):
if i == self.unk():


@@ -5,5 +5,5 @@
from .berard import * # noqa
from .convtransformer import * # noqa
from .s2t_transformer import * # noqa
from .xm_transformer import * # noqa
from .s2t_transformer import * # noqa
from .xm_transformer import * # noqa


@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#


@@ -0,0 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .tacotron2 import * # noqa
from .tts_transformer import * # noqa
from .fastspeech2 import * # noqa


@@ -0,0 +1,352 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
from torch import nn
from fairseq.models import (FairseqEncoder, FairseqEncoderModel, register_model,
register_model_architecture)
from fairseq.modules import (
LayerNorm, PositionalEmbedding, FairseqDropout, MultiheadAttention
)
from fairseq import utils
from fairseq.data.data_utils import lengths_to_padding_mask
logger = logging.getLogger(__name__)
def model_init(m):
if isinstance(m, nn.Conv1d):
nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu"))
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
return m
class PositionwiseFeedForward(nn.Module):
def __init__(self, in_dim, hidden_dim, kernel_size, dropout):
super().__init__()
self.ffn = nn.Sequential(
nn.Conv1d(in_dim, hidden_dim, kernel_size=kernel_size,
padding=(kernel_size - 1) // 2),
nn.ReLU(),
nn.Conv1d(hidden_dim, in_dim, kernel_size=kernel_size,
padding=(kernel_size - 1) // 2)
)
self.layer_norm = LayerNorm(in_dim)
self.dropout = self.dropout_module = FairseqDropout(
p=dropout, module_name=self.__class__.__name__
)
def forward(self, x):
# B x T x C
residual = x
x = self.ffn(x.transpose(1, 2)).transpose(1, 2)
x = self.dropout(x)
return self.layer_norm(x + residual)
class FFTLayer(torch.nn.Module):
def __init__(
self, embed_dim, n_heads, hidden_dim, kernel_size, dropout,
attention_dropout
):
super().__init__()
self.self_attn = MultiheadAttention(
embed_dim, n_heads, dropout=attention_dropout, self_attention=True
)
self.layer_norm = LayerNorm(embed_dim)
self.ffn = PositionwiseFeedForward(
embed_dim, hidden_dim, kernel_size, dropout=dropout
)
def forward(self, x, padding_mask=None):
# B x T x C
residual = x
x = x.transpose(0, 1)
x, _ = self.self_attn(
query=x, key=x, value=x, key_padding_mask=padding_mask,
need_weights=False
)
x = x.transpose(0, 1)
x = self.layer_norm(x + residual)
return self.ffn(x)
class LengthRegulator(nn.Module):
def forward(self, x, durations):
# x: B x T x C
out_lens = durations.sum(dim=1)
max_len = out_lens.max()
bsz, seq_len, dim = x.size()
out = x.new_zeros((bsz, max_len, dim))
for b in range(bsz):
indices = []
for t in range(seq_len):
indices.extend([t] * utils.item(durations[b, t]))
indices = torch.tensor(indices, dtype=torch.long).to(x.device)
out_len = utils.item(out_lens[b])
out[b, :out_len] = x[b].index_select(0, indices)
return out, out_lens
class VariancePredictor(nn.Module):
def __init__(self, args):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv1d(
args.encoder_embed_dim, args.var_pred_hidden_dim,
kernel_size=args.var_pred_kernel_size,
padding=(args.var_pred_kernel_size - 1) // 2
),
nn.ReLU()
)
self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim)
self.dropout_module = FairseqDropout(
p=args.var_pred_dropout, module_name=self.__class__.__name__
)
self.conv2 = nn.Sequential(
nn.Conv1d(
args.var_pred_hidden_dim, args.var_pred_hidden_dim,
kernel_size=args.var_pred_kernel_size, padding=1
),
nn.ReLU()
)
self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim)
self.proj = nn.Linear(args.var_pred_hidden_dim, 1)
def forward(self, x):
# Input: B x T x C; Output: B x T
x = self.conv1(x.transpose(1, 2)).transpose(1, 2)
x = self.dropout_module(self.ln1(x))
x = self.conv2(x.transpose(1, 2)).transpose(1, 2)
x = self.dropout_module(self.ln2(x))
return self.proj(x).squeeze(dim=2)
class VarianceAdaptor(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
self.length_regulator = LengthRegulator()
self.duration_predictor = VariancePredictor(args)
self.pitch_predictor = VariancePredictor(args)
self.energy_predictor = VariancePredictor(args)
n_bins, steps = self.args.var_pred_n_bins, self.args.var_pred_n_bins - 1
self.pitch_bins = torch.linspace(args.pitch_min, args.pitch_max, steps)
self.embed_pitch = Embedding(n_bins, args.encoder_embed_dim)
self.energy_bins = torch.linspace(args.energy_min, args.energy_max, steps)
self.embed_energy = Embedding(n_bins, args.encoder_embed_dim)
def get_pitch_emb(self, x, tgt=None, factor=1.0):
out = self.pitch_predictor(x)
bins = self.pitch_bins.to(x.device)
if tgt is None:
out = out * factor
emb = self.embed_pitch(torch.bucketize(out, bins))
else:
emb = self.embed_pitch(torch.bucketize(tgt, bins))
return out, emb
def get_energy_emb(self, x, tgt=None, factor=1.0):
out = self.energy_predictor(x)
bins = self.energy_bins.to(x.device)
if tgt is None:
out = out * factor
emb = self.embed_energy(torch.bucketize(out, bins))
else:
emb = self.embed_energy(torch.bucketize(tgt, bins))
return out, emb
def forward(
self, x, padding_mask, durations=None, pitches=None, energies=None,
d_factor=1.0, p_factor=1.0, e_factor=1.0
):
# x: B x T x C
log_dur_out = self.duration_predictor(x)
dur_out = torch.clamp(
torch.round((torch.exp(log_dur_out) - 1) * d_factor).long(), min=0
)
dur_out.masked_fill_(padding_mask, 0)
pitch_out, pitch_emb = self.get_pitch_emb(x, pitches, p_factor)
x = x + pitch_emb
energy_out, energy_emb = self.get_energy_emb(x, energies, e_factor)
x = x + energy_emb
x, out_lens = self.length_regulator(
x, dur_out if durations is None else durations
)
return x, out_lens, log_dur_out, pitch_out, energy_out
class FastSpeech2Encoder(FairseqEncoder):
def __init__(self, args, src_dict, embed_speaker):
super().__init__(src_dict)
self.args = args
self.padding_idx = src_dict.pad()
self.n_frames_per_step = args.n_frames_per_step
self.out_dim = args.output_frame_dim * args.n_frames_per_step
self.embed_speaker = embed_speaker
self.spk_emb_proj = None
if embed_speaker is not None:
self.spk_emb_proj = nn.Linear(
args.encoder_embed_dim + args.speaker_embed_dim,
args.encoder_embed_dim
)
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.embed_tokens = Embedding(
len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
)
self.embed_positions = PositionalEmbedding(
args.max_source_positions, args.encoder_embed_dim, self.padding_idx
)
self.pos_emb_alpha = nn.Parameter(torch.ones(1))
self.dec_pos_emb_alpha = nn.Parameter(torch.ones(1))
self.encoder_fft_layers = nn.ModuleList(
FFTLayer(
args.encoder_embed_dim, args.encoder_attention_heads,
args.fft_hidden_dim, args.fft_kernel_size,
dropout=args.dropout, attention_dropout=args.attention_dropout
)
for _ in range(args.encoder_layers)
)
self.var_adaptor = VarianceAdaptor(args)
self.decoder_fft_layers = nn.ModuleList(
FFTLayer(
args.decoder_embed_dim, args.decoder_attention_heads,
args.fft_hidden_dim, args.fft_kernel_size,
dropout=args.dropout, attention_dropout=args.attention_dropout
)
for _ in range(args.decoder_layers)
)
self.out_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)
self.apply(model_init)
def forward(self, src_tokens, src_lengths=None, speaker=None,
durations=None, pitches=None, energies=None, **kwargs):
x = self.embed_tokens(src_tokens)
enc_padding_mask = src_tokens.eq(self.padding_idx)
x += self.pos_emb_alpha * self.embed_positions(enc_padding_mask)
x = self.dropout_module(x)
for layer in self.encoder_fft_layers:
x = layer(x, enc_padding_mask)
if self.embed_speaker is not None:
bsz, seq_len, _ = x.size()
emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1)
x = self.spk_emb_proj(torch.cat([x, emb], dim=2))
x, out_lens, log_dur_out, pitch_out, energy_out = \
self.var_adaptor(x, enc_padding_mask, durations, pitches, energies)
dec_padding_mask = lengths_to_padding_mask(out_lens)
x += self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask)
for layer in self.decoder_fft_layers:
x = layer(x, dec_padding_mask)
x = self.out_proj(x)
return x, out_lens, log_dur_out, pitch_out, energy_out
@register_model("fastspeech2")
class FastSpeech2Model(FairseqEncoderModel):
"""
Implementation for https://arxiv.org/abs/2006.04558
"""
NON_AUTOREGRESSIVE = True
@staticmethod
def add_args(parser):
parser.add_argument("--dropout", type=float)
parser.add_argument("--output-frame-dim", type=int)
parser.add_argument("--speaker-embed-dim", type=int)
# FFT blocks
parser.add_argument("--fft-hidden-dim", type=int)
parser.add_argument("--fft-kernel-size", type=int)
parser.add_argument("--attention-dropout", type=float)
parser.add_argument("--encoder-layers", type=int)
parser.add_argument("--encoder-embed-dim", type=int)
parser.add_argument("--encoder-attention-heads", type=int)
parser.add_argument("--decoder-layers", type=int)
parser.add_argument("--decoder-embed-dim", type=int)
parser.add_argument("--decoder-attention-heads", type=int)
# variance predictor
parser.add_argument("--var-pred-n-bins", type=int)
parser.add_argument("--var-pred-hidden-dim", type=int)
parser.add_argument("--var-pred-kernel-size", type=int)
parser.add_argument("--var-pred-dropout", type=float)
def __init__(self, encoder, args, src_dict):
super().__init__(encoder)
self._num_updates = 0
out_dim = args.output_frame_dim * args.n_frames_per_step
self.ctc_proj = None
if getattr(args, "ctc_weight", 0.) > 0.:
self.ctc_proj = nn.Linear(out_dim, len(src_dict))
@classmethod
def build_model(cls, args, task):
embed_speaker = task.get_speaker_embeddings(args)
encoder = FastSpeech2Encoder(args, task.src_dict, embed_speaker)
return cls(encoder, args, task.src_dict)
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
self._num_updates = num_updates
def get_normalized_probs(self, net_output, log_probs, sample=None):
logits = self.ctc_proj(net_output[0])
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
@register_model_architecture("fastspeech2", "fastspeech2")
def base_architecture(args):
args.dropout = getattr(args, "dropout", 0.2)
args.output_frame_dim = getattr(args, "output_frame_dim", 80)
args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64)
# FFT blocks
args.fft_hidden_dim = getattr(args, "fft_hidden_dim", 1024)
args.fft_kernel_size = getattr(args, "fft_kernel_size", 9)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.encoder_layers = getattr(args, "encoder_layers", 4)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2)
args.decoder_layers = getattr(args, "decoder_layers", 4)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2)
# variance predictor
args.var_pred_n_bins = getattr(args, "var_pred_n_bins", 256)
args.var_pred_hidden_dim = getattr(args, "var_pred_hidden_dim", 256)
args.var_pred_kernel_size = getattr(args, "var_pred_kernel_size", 3)
args.var_pred_dropout = getattr(args, "var_pred_dropout", 0.5)
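
The LengthRegulator above copies each encoder state durations[t] times along the time axis. A self-contained sketch of the same expansion for a single sequence using torch.repeat_interleave; the module additionally zero-pads every batch element to the longest expanded length.

import torch

x = torch.arange(12, dtype=torch.float).view(1, 4, 3)  # B x T x C
durations = torch.tensor([[2, 0, 3, 1]])                # B x T; a zero drops that token

# Single-sequence equivalent of LengthRegulator.forward:
expanded = torch.repeat_interleave(x[0], durations[0], dim=0)
print(expanded.shape)  # torch.Size([6, 3]); 6 == durations.sum()

# For a batch, out[b, :durations[b].sum()] matches this expansion and the
# remaining positions stay zero, with out_lens == durations.sum(dim=1).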


@@ -0,0 +1,173 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
LRELU_SLOPE = 0.1
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return (kernel_size * dilation - dilation) // 2
class ResBlock(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock, self).__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for layer in self.convs1:
remove_weight_norm(layer)
for layer in self.convs2:
remove_weight_norm(layer)
class Generator(torch.nn.Module):
def __init__(self, cfg):
super(Generator, self).__init__()
self.num_kernels = len(cfg["resblock_kernel_sizes"])
self.num_upsamples = len(cfg["upsample_rates"])
self.conv_pre = weight_norm(
Conv1d(80, cfg["upsample_initial_channel"], 7, 1, padding=3)
)
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(
zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"])
):
self.ups.append(
weight_norm(
ConvTranspose1d(
cfg["upsample_initial_channel"] // (2 ** i),
cfg["upsample_initial_channel"] // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = cfg["upsample_initial_channel"] // (2 ** (i + 1))
for k, d in zip(
cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"]
):
self.resblocks.append(ResBlock(ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
def forward(self, x):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print("Removing weight norm...")
for layer in self.ups:
remove_weight_norm(layer)
for layer in self.resblocks:
layer.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
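
A usage sketch for the generator above with HiFi-GAN V1-style hyper-parameters; the config values are common defaults assumed for illustration and are not taken from this commit. The upsample rates multiply to 256, so each 80-band mel frame becomes 256 output samples.

import torch

cfg = {
    "upsample_rates": [8, 8, 2, 2],                 # 8*8*2*2 = 256x upsampling
    "upsample_kernel_sizes": [16, 16, 4, 4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
}

vocoder = Generator(cfg)        # the Generator defined above
vocoder.eval()
vocoder.remove_weight_norm()    # typically applied once before inference

mel = torch.randn(1, 80, 100)   # B x n_mels x T
with torch.no_grad():
    wav = vocoder(mel)          # B x 1 x (T * 256)
print(wav.shape)                # torch.Size([1, 1, 25600])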


@@ -0,0 +1,350 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
from torch import nn
from torch.nn import functional as F
from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
FairseqIncrementalDecoder, register_model,
register_model_architecture)
from fairseq.modules import LSTMCellWithZoneOut, LocationAttention
logger = logging.getLogger(__name__)
def encoder_init(m):
if isinstance(m, nn.Conv1d):
nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu"))
class Tacotron2Encoder(FairseqEncoder):
def __init__(self, args, src_dict, embed_speaker):
super().__init__(src_dict)
self.padding_idx = src_dict.pad()
self.embed_speaker = embed_speaker
self.spk_emb_proj = None
if embed_speaker is not None:
self.spk_emb_proj = nn.Linear(
args.encoder_embed_dim + args.speaker_embed_dim,
args.encoder_embed_dim
)
self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim,
padding_idx=self.padding_idx)
assert(args.encoder_conv_kernel_size % 2 == 1)
self.convolutions = nn.ModuleList(
nn.Sequential(
nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim,
kernel_size=args.encoder_conv_kernel_size,
padding=((args.encoder_conv_kernel_size - 1) // 2)),
nn.BatchNorm1d(args.encoder_embed_dim),
nn.ReLU(),
nn.Dropout(args.encoder_dropout)
)
for _ in range(args.encoder_conv_layers)
)
self.lstm = nn.LSTM(args.encoder_embed_dim, args.encoder_embed_dim // 2,
num_layers=args.encoder_lstm_layers,
batch_first=True, bidirectional=True)
self.apply(encoder_init)
def forward(self, src_tokens, src_lengths=None, speaker=None, **kwargs):
x = self.embed_tokens(src_tokens)
x = x.transpose(1, 2).contiguous() # B x T x C -> B x C x T
for conv in self.convolutions:
x = conv(x)
x = x.transpose(1, 2).contiguous() # B x C x T -> B x T x C
src_lengths = src_lengths.cpu().long()
x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)
x = self.lstm(x)[0]
x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0]
encoder_padding_mask = src_tokens.eq(self.padding_idx)
if self.embed_speaker is not None:
            bsz, seq_len, _ = x.size()  # x is batch-first (B x T x C) at this point
            emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1)
x = self.spk_emb_proj(torch.cat([x, emb], dim=2))
return {
"encoder_out": [x], # B x T x C
"encoder_padding_mask": encoder_padding_mask, # B x T
}
class Prenet(nn.Module):
def __init__(self, in_dim, n_layers, n_units, dropout):
super().__init__()
self.layers = nn.ModuleList(
nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units),
nn.ReLU())
for i in range(n_layers)
)
self.dropout = dropout
def forward(self, x):
for layer in self.layers:
x = F.dropout(layer(x), p=self.dropout) # always applies dropout
return x
class Postnet(nn.Module):
def __init__(self, in_dim, n_channels, kernel_size, n_layers, dropout):
super(Postnet, self).__init__()
self.convolutions = nn.ModuleList()
assert(kernel_size % 2 == 1)
for i in range(n_layers):
cur_layers = [
nn.Conv1d(in_dim if i == 0 else n_channels,
n_channels if i < n_layers - 1 else in_dim,
kernel_size=kernel_size,
padding=((kernel_size - 1) // 2)),
nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim)
] + ([nn.Tanh()] if i < n_layers - 1 else []) + [nn.Dropout(dropout)]
nn.init.xavier_uniform_(
cur_layers[0].weight,
torch.nn.init.calculate_gain(
"tanh" if i < n_layers - 1 else "linear"
)
)
self.convolutions.append(nn.Sequential(*cur_layers))
def forward(self, x):
x = x.transpose(1, 2) # B x T x C -> B x C x T
for conv in self.convolutions:
x = conv(x)
return x.transpose(1, 2)
def decoder_init(m):
if isinstance(m, torch.nn.Conv1d):
nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh"))
class Tacotron2Decoder(FairseqIncrementalDecoder):
def __init__(self, args, src_dict):
super().__init__(None)
self.args = args
self.n_frames_per_step = args.n_frames_per_step
self.out_dim = args.output_frame_dim * args.n_frames_per_step
self.prenet = Prenet(self.out_dim, args.prenet_layers, args.prenet_dim,
args.prenet_dropout)
# take prev_context, prev_frame, (speaker embedding) as input
self.attention_lstm = LSTMCellWithZoneOut(
args.zoneout,
args.prenet_dim + args.encoder_embed_dim,
args.decoder_lstm_dim
)
# take attention_lstm output, attention_state, encoder_out as input
self.attention = LocationAttention(
args.attention_dim, args.encoder_embed_dim, args.decoder_lstm_dim,
(1 + int(args.attention_use_cumprob)),
args.attention_conv_dim, args.attention_conv_kernel_size
)
# take attention_lstm output, context, (gated_latent) as input
self.lstm = nn.ModuleList(
LSTMCellWithZoneOut(
args.zoneout,
args.encoder_embed_dim + args.decoder_lstm_dim,
args.decoder_lstm_dim
)
for i in range(args.decoder_lstm_layers)
)
proj_in_dim = args.encoder_embed_dim + args.decoder_lstm_dim
self.feat_proj = nn.Linear(proj_in_dim, self.out_dim)
self.eos_proj = nn.Linear(proj_in_dim, 1)
self.postnet = Postnet(self.out_dim, args.postnet_conv_dim,
args.postnet_conv_kernel_size,
args.postnet_layers, args.postnet_dropout)
self.ctc_proj = None
if getattr(args, "ctc_weight", 0.) > 0.:
self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))
self.apply(decoder_init)
def _get_states(self, incremental_state, enc_out):
bsz, in_len, _ = enc_out.size()
alstm_h = self.get_incremental_state(incremental_state, "alstm_h")
if alstm_h is None:
alstm_h = enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
alstm_c = self.get_incremental_state(incremental_state, "alstm_c")
if alstm_c is None:
alstm_c = enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
lstm_h = self.get_incremental_state(incremental_state, "lstm_h")
if lstm_h is None:
lstm_h = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
for _ in range(self.args.decoder_lstm_layers)]
lstm_c = self.get_incremental_state(incremental_state, "lstm_c")
if lstm_c is None:
lstm_c = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
for _ in range(self.args.decoder_lstm_layers)]
attn_w = self.get_incremental_state(incremental_state, "attn_w")
if attn_w is None:
attn_w = enc_out.new_zeros(bsz, in_len)
attn_w_cum = self.get_incremental_state(incremental_state, "attn_w_cum")
if attn_w_cum is None:
attn_w_cum = enc_out.new_zeros(bsz, in_len)
return alstm_h, alstm_c, lstm_h, lstm_c, attn_w, attn_w_cum
def _get_init_attn_c(self, enc_out, enc_mask):
bsz = enc_out.size(0)
if self.args.init_attn_c == "zero":
return enc_out.new_zeros(bsz, self.args.encoder_embed_dim)
elif self.args.init_attn_c == "avg":
enc_w = (~enc_mask).type(enc_out.type())
enc_w = enc_w / enc_w.sum(dim=1, keepdim=True)
return torch.sum(enc_out * enc_w.unsqueeze(2), dim=1)
else:
raise ValueError(f"{self.args.init_attn_c} not supported")
def forward(self, prev_output_tokens, encoder_out=None,
incremental_state=None, target_lengths=None, **kwargs):
enc_mask = encoder_out["encoder_padding_mask"]
enc_out = encoder_out["encoder_out"][0]
in_len = enc_out.size(1)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:, :]
bsz, out_len, _ = prev_output_tokens.size()
prenet_out = self.prenet(prev_output_tokens)
(alstm_h, alstm_c, lstm_h, lstm_c,
attn_w, attn_w_cum) = self._get_states(incremental_state, enc_out)
attn_ctx = self._get_init_attn_c(enc_out, enc_mask)
attn_out = enc_out.new_zeros(bsz, in_len, out_len)
feat_out = enc_out.new_zeros(bsz, out_len, self.out_dim)
eos_out = enc_out.new_zeros(bsz, out_len)
for t in range(out_len):
alstm_in = torch.cat((attn_ctx, prenet_out[:, t, :]), dim=1)
alstm_h, alstm_c = self.attention_lstm(alstm_in, (alstm_h, alstm_c))
attn_state = attn_w.unsqueeze(1)
if self.args.attention_use_cumprob:
attn_state = torch.stack((attn_w, attn_w_cum), dim=1)
attn_ctx, attn_w = self.attention(
enc_out, enc_mask, alstm_h, attn_state
)
attn_w_cum = attn_w_cum + attn_w
attn_out[:, :, t] = attn_w
for i, cur_lstm in enumerate(self.lstm):
if i == 0:
lstm_in = torch.cat((attn_ctx, alstm_h), dim=1)
else:
lstm_in = torch.cat((attn_ctx, lstm_h[i - 1]), dim=1)
lstm_h[i], lstm_c[i] = cur_lstm(lstm_in, (lstm_h[i], lstm_c[i]))
proj_in = torch.cat((attn_ctx, lstm_h[-1]), dim=1)
feat_out[:, t, :] = self.feat_proj(proj_in)
eos_out[:, t] = self.eos_proj(proj_in).squeeze(1)
self.attention.clear_cache()
self.set_incremental_state(incremental_state, "alstm_h", alstm_h)
self.set_incremental_state(incremental_state, "alstm_c", alstm_c)
self.set_incremental_state(incremental_state, "lstm_h", lstm_h)
self.set_incremental_state(incremental_state, "lstm_c", lstm_c)
self.set_incremental_state(incremental_state, "attn_w", attn_w)
self.set_incremental_state(incremental_state, "attn_w_cum", attn_w_cum)
post_feat_out = feat_out + self.postnet(feat_out)
eos_out = eos_out.view(bsz, out_len, 1)
return post_feat_out, eos_out, {"attn": attn_out, "feature_out": feat_out}
@register_model("tacotron_2")
class Tacotron2Model(FairseqEncoderDecoderModel):
"""
Implementation for https://arxiv.org/pdf/1712.05884.pdf
"""
@staticmethod
def add_args(parser):
# encoder
parser.add_argument("--encoder-dropout", type=float)
parser.add_argument("--encoder-embed-dim", type=int)
parser.add_argument("--encoder-conv-layers", type=int)
parser.add_argument("--encoder-conv-kernel-size", type=int)
parser.add_argument("--encoder-lstm-layers", type=int)
# decoder
parser.add_argument("--attention-dim", type=int)
parser.add_argument("--attention-conv-dim", type=int)
parser.add_argument("--attention-conv-kernel-size", type=int)
parser.add_argument("--prenet-dropout", type=float)
parser.add_argument("--prenet-layers", type=int)
parser.add_argument("--prenet-dim", type=int)
parser.add_argument("--postnet-dropout", type=float)
parser.add_argument("--postnet-layers", type=int)
parser.add_argument("--postnet-conv-dim", type=int)
parser.add_argument("--postnet-conv-kernel-size", type=int)
parser.add_argument("--init-attn-c", type=str)
parser.add_argument("--attention-use-cumprob", action='store_true')
parser.add_argument("--zoneout", type=float)
parser.add_argument("--decoder-lstm-layers", type=int)
parser.add_argument("--decoder-lstm-dim", type=int)
parser.add_argument("--output-frame-dim", type=int)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_updates = 0
@classmethod
def build_model(cls, args, task):
embed_speaker = task.get_speaker_embeddings(args)
encoder = Tacotron2Encoder(args, task.src_dict, embed_speaker)
decoder = Tacotron2Decoder(args, task.src_dict)
return cls(encoder, decoder)
def forward_encoder(self, src_tokens, src_lengths, **kwargs):
return self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
self._num_updates = num_updates
@register_model_architecture("tacotron_2", "tacotron_2")
def base_architecture(args):
# encoder
args.encoder_dropout = getattr(args, "encoder_dropout", 0.5)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_conv_layers = getattr(args, "encoder_conv_layers", 3)
args.encoder_conv_kernel_size = getattr(args, "encoder_conv_kernel_size", 5)
args.encoder_lstm_layers = getattr(args, "encoder_lstm_layers", 1)
# decoder
args.attention_dim = getattr(args, "attention_dim", 128)
args.attention_conv_dim = getattr(args, "attention_conv_dim", 32)
args.attention_conv_kernel_size = getattr(args,
"attention_conv_kernel_size", 15)
args.prenet_dropout = getattr(args, "prenet_dropout", 0.5)
args.prenet_layers = getattr(args, "prenet_layers", 2)
args.prenet_dim = getattr(args, "prenet_dim", 256)
args.postnet_dropout = getattr(args, "postnet_dropout", 0.5)
args.postnet_layers = getattr(args, "postnet_layers", 5)
args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512)
args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5)
args.init_attn_c = getattr(args, "init_attn_c", "zero")
args.attention_use_cumprob = getattr(args, "attention_use_cumprob", True)
args.zoneout = getattr(args, "zoneout", 0.1)
args.decoder_lstm_layers = getattr(args, "decoder_lstm_layers", 2)
args.decoder_lstm_dim = getattr(args, "decoder_lstm_dim", 1024)
args.output_frame_dim = getattr(args, "output_frame_dim", 80)
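# Shape sketch for the pre-/post-nets above, using base_architecture defaults
# (illustrative only): the prenet compresses previous output frames before the
# attention LSTM, and the postnet predicts a residual refinement of feat_out.
prenet = Prenet(in_dim=80, n_layers=2, n_units=256, dropout=0.5)
postnet = Postnet(in_dim=80, n_channels=512, kernel_size=5, n_layers=5, dropout=0.5)
frames = torch.randn(4, 37, 80)                 # B x T_out x output_frame_dim
assert prenet(frames).shape == (4, 37, 256)     # fed to the attention LSTM
assert postnet(frames).shape == (4, 37, 80)     # added back onto feat_out in forward()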


@ -0,0 +1,371 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import List, Optional
import torch
from torch import nn
from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
FairseqIncrementalDecoder, register_model,
register_model_architecture)
from fairseq.modules import (
TransformerEncoderLayer, TransformerDecoderLayer
)
from fairseq.models.text_to_speech.tacotron2 import Prenet, Postnet
from fairseq.modules import LayerNorm, PositionalEmbedding, FairseqDropout
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq import utils
logger = logging.getLogger(__name__)
def encoder_init(m):
if isinstance(m, nn.Conv1d):
nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu"))
def Embedding(num_embeddings, embedding_dim):
m = nn.Embedding(num_embeddings, embedding_dim)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
return m
class TTSTransformerEncoder(FairseqEncoder):
def __init__(self, args, src_dict, embed_speaker):
super().__init__(src_dict)
self.padding_idx = src_dict.pad()
self.embed_speaker = embed_speaker
self.spk_emb_proj = None
if embed_speaker is not None:
self.spk_emb_proj = nn.Linear(
args.encoder_embed_dim + args.speaker_embed_dim,
args.encoder_embed_dim
)
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim,
padding_idx=self.padding_idx)
assert(args.encoder_conv_kernel_size % 2 == 1)
self.prenet = nn.ModuleList(
nn.Sequential(
nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim,
kernel_size=args.encoder_conv_kernel_size,
padding=((args.encoder_conv_kernel_size - 1) // 2)),
nn.BatchNorm1d(args.encoder_embed_dim),
nn.ReLU(),
nn.Dropout(args.encoder_dropout),
)
for _ in range(args.encoder_conv_layers)
)
self.prenet_proj = nn.Linear(
args.encoder_embed_dim, args.encoder_embed_dim
)
self.embed_positions = PositionalEmbedding(
args.max_source_positions, args.encoder_embed_dim, self.padding_idx
)
self.pos_emb_alpha = nn.Parameter(torch.ones(1))
self.transformer_layers = nn.ModuleList(
TransformerEncoderLayer(args)
for _ in range(args.encoder_transformer_layers)
)
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(args.encoder_embed_dim)
else:
self.layer_norm = None
self.apply(encoder_init)
def forward(self, src_tokens, src_lengths=None, speaker=None, **kwargs):
x = self.embed_tokens(src_tokens)
x = x.transpose(1, 2).contiguous() # B x T x C -> B x C x T
for conv in self.prenet:
x = conv(x)
x = x.transpose(1, 2).contiguous() # B x C x T -> B x T x C
x = self.prenet_proj(x)
padding_mask = src_tokens.eq(self.padding_idx)
positions = self.embed_positions(padding_mask)
x += self.pos_emb_alpha * positions
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
for layer in self.transformer_layers:
x = layer(x, padding_mask)
if self.layer_norm is not None:
x = self.layer_norm(x)
if self.embed_speaker is not None:
seq_len, bsz, _ = x.size()
emb = self.embed_speaker(speaker).transpose(0, 1)
emb = emb.expand(seq_len, bsz, -1)
x = self.spk_emb_proj(torch.cat([x, emb], dim=2))
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [padding_mask] if padding_mask.any() else [], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
}
def decoder_init(m):
if isinstance(m, torch.nn.Conv1d):
nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh"))
class TTSTransformerDecoder(FairseqIncrementalDecoder):
def __init__(self, args, src_dict):
super().__init__(None)
self._future_mask = torch.empty(0)
self.args = args
self.padding_idx = src_dict.pad()
self.n_frames_per_step = args.n_frames_per_step
self.out_dim = args.output_frame_dim * args.n_frames_per_step
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.embed_positions = PositionalEmbedding(
args.max_target_positions, args.decoder_embed_dim, self.padding_idx
)
self.pos_emb_alpha = nn.Parameter(torch.ones(1))
self.prenet = nn.Sequential(
Prenet(self.out_dim, args.prenet_layers, args.prenet_dim,
args.prenet_dropout),
nn.Linear(args.prenet_dim, args.decoder_embed_dim),
)
self.n_transformer_layers = args.decoder_transformer_layers
self.transformer_layers = nn.ModuleList(
TransformerDecoderLayer(args)
for _ in range(self.n_transformer_layers)
)
if args.decoder_normalize_before:
self.layer_norm = LayerNorm(args.decoder_embed_dim)
else:
self.layer_norm = None
self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)
self.eos_proj = nn.Linear(args.decoder_embed_dim, 1)
self.postnet = Postnet(self.out_dim, args.postnet_conv_dim,
args.postnet_conv_kernel_size,
args.postnet_layers, args.postnet_dropout)
self.ctc_proj = None
if getattr(args, "ctc_weight", 0.) > 0.:
self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))
self.apply(decoder_init)
def extract_features(
self, prev_outputs, encoder_out=None, incremental_state=None,
target_lengths=None, speaker=None, **kwargs
):
alignment_layer = self.n_transformer_layers - 1
self_attn_padding_mask = lengths_to_padding_mask(target_lengths)
positions = self.embed_positions(
self_attn_padding_mask, incremental_state=incremental_state
)
if incremental_state is not None:
prev_outputs = prev_outputs[:, -1:, :]
self_attn_padding_mask = self_attn_padding_mask[:, -1:]
if positions is not None:
positions = positions[:, -1:]
x = self.prenet(prev_outputs)
x += self.pos_emb_alpha * positions
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if not self_attn_padding_mask.any():
self_attn_padding_mask = None
attn: Optional[torch.Tensor] = None
inner_states: List[Optional[torch.Tensor]] = [x]
for idx, transformer_layer in enumerate(self.transformer_layers):
if incremental_state is None:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn, _ = transformer_layer(
x,
encoder_out["encoder_out"][0]
if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
else None,
encoder_out["encoder_padding_mask"][0]
if (
encoder_out is not None
and len(encoder_out["encoder_padding_mask"]) > 0
)
else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=bool((idx == alignment_layer)),
need_head_weights=bool((idx == alignment_layer)),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if attn is not None:
# average probabilities over heads, transpose to
# (B, src_len, tgt_len)
attn = attn.mean(dim=0).transpose(2, 1)
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
return x, {"attn": attn, "inner_states": inner_states}
def forward(self, prev_output_tokens, encoder_out=None,
incremental_state=None, target_lengths=None, speaker=None,
**kwargs):
x, extra = self.extract_features(
prev_output_tokens, encoder_out=encoder_out,
incremental_state=incremental_state, target_lengths=target_lengths,
speaker=speaker, **kwargs
)
attn = extra["attn"]
feat_out = self.feat_proj(x)
bsz, seq_len, _ = x.size()
eos_out = self.eos_proj(x)
post_feat_out = feat_out + self.postnet(feat_out)
return post_feat_out, eos_out, {"attn": attn, "feature_out": feat_out}
def get_normalized_probs(self, net_output, log_probs, sample):
logits = self.ctc_proj(net_output[2]["feature_out"])
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
if (
self._future_mask.size(0) == 0
or (not self._future_mask.device == tensor.device)
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
)
self._future_mask = self._future_mask.to(tensor)
return self._future_mask[:dim, :dim]
@register_model("tts_transformer")
class TTSTransformerModel(FairseqEncoderDecoderModel):
"""
Implementation for https://arxiv.org/pdf/1809.08895.pdf
"""
@staticmethod
def add_args(parser):
parser.add_argument("--dropout", type=float)
parser.add_argument("--output-frame-dim", type=int)
parser.add_argument("--speaker-embed-dim", type=int)
# encoder prenet
parser.add_argument("--encoder-dropout", type=float)
parser.add_argument("--encoder-conv-layers", type=int)
parser.add_argument("--encoder-conv-kernel-size", type=int)
# encoder transformer layers
parser.add_argument("--encoder-transformer-layers", type=int)
parser.add_argument("--encoder-embed-dim", type=int)
parser.add_argument("--encoder-ffn-embed-dim", type=int)
parser.add_argument("--encoder-normalize-before", action="store_true")
parser.add_argument("--encoder-attention-heads", type=int)
parser.add_argument("--attention-dropout", type=float)
parser.add_argument("--activation-dropout", "--relu-dropout", type=float)
parser.add_argument("--activation-fn", type=str, default="relu")
# decoder prenet
parser.add_argument("--prenet-dropout", type=float)
parser.add_argument("--prenet-layers", type=int)
parser.add_argument("--prenet-dim", type=int)
# decoder postnet
parser.add_argument("--postnet-dropout", type=float)
parser.add_argument("--postnet-layers", type=int)
parser.add_argument("--postnet-conv-dim", type=int)
parser.add_argument("--postnet-conv-kernel-size", type=int)
# decoder transformer layers
parser.add_argument("--decoder-transformer-layers", type=int)
parser.add_argument("--decoder-embed-dim", type=int)
parser.add_argument("--decoder-ffn-embed-dim", type=int)
parser.add_argument("--decoder-normalize-before", action="store_true")
parser.add_argument("--decoder-attention-heads", type=int)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_updates = 0
@classmethod
def build_model(cls, args, task):
embed_speaker = task.get_speaker_embeddings(args)
encoder = TTSTransformerEncoder(args, task.src_dict, embed_speaker)
decoder = TTSTransformerDecoder(args, task.src_dict)
return cls(encoder, decoder)
def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs):
return self.encoder(src_tokens, src_lengths=src_lengths,
speaker=speaker, **kwargs)
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
self._num_updates = num_updates
@register_model_architecture("tts_transformer", "tts_transformer")
def base_architecture(args):
args.dropout = getattr(args, "dropout", 0.1)
args.output_frame_dim = getattr(args, "output_frame_dim", 80)
args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64)
# encoder prenet
args.encoder_dropout = getattr(args, "encoder_dropout", 0.5)
args.encoder_conv_layers = getattr(args, "encoder_conv_layers", 3)
args.encoder_conv_kernel_size = getattr(args, "encoder_conv_kernel_size", 5)
# encoder transformer layers
args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.activation_fn = getattr(args, "activation_fn", "relu")
# decoder prenet
args.prenet_dropout = getattr(args, "prenet_dropout", 0.5)
args.prenet_layers = getattr(args, "prenet_layers", 2)
args.prenet_dim = getattr(args, "prenet_dim", 256)
# decoder postnet
args.postnet_dropout = getattr(args, "postnet_dropout", 0.5)
args.postnet_layers = getattr(args, "postnet_layers", 5)
args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512)
args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5)
# decoder transformer layers
args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
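# Sketch of the causal mask produced by buffered_future_mask above (illustrative):
# decoder position t may only attend to positions <= t, enforced via -inf scores.
dim = 3
future_mask = torch.triu(utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1)
# tensor([[0., -inf, -inf],
#         [0.,   0., -inf],
#         [0.,   0.,   0.]])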


@ -0,0 +1,197 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import json
from typing import Dict
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from fairseq.data.audio.audio_utils import (
get_window, get_fourier_basis, get_mel_filters, TTSSpectrogram
)
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.models.text_to_speech.hifigan import Generator as HiFiGANModel
logger = logging.getLogger(__name__)
class PseudoInverseMelScale(torch.nn.Module):
def __init__(self, n_stft, n_mels, sample_rate, f_min, f_max) -> None:
super(PseudoInverseMelScale, self).__init__()
self.n_mels = n_mels
basis = get_mel_filters(
sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max
)
basis = torch.pinverse(basis) # F x F_mel
self.register_buffer('basis', basis)
def forward(self, melspec: torch.Tensor) -> torch.Tensor:
# pack batch
shape = melspec.shape # B_1 x ... x B_K x F_mel x T
n_mels, time = shape[-2], shape[-1]
melspec = melspec.view(-1, n_mels, time)
freq, _ = self.basis.size() # F x F_mel
assert self.n_mels == n_mels, (self.n_mels, n_mels)
specgram = self.basis.matmul(melspec).clamp(min=0)
# unpack batch
specgram = specgram.view(shape[:-2] + (freq, time))
return specgram
class GriffinLim(torch.nn.Module):
def __init__(
self, n_fft: int, win_length: int, hop_length: int, n_iter: int,
window_fn=torch.hann_window
):
super(GriffinLim, self).__init__()
self.transform = TTSSpectrogram(
n_fft, win_length, hop_length, return_phase=True
)
basis = get_fourier_basis(n_fft)
basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :]
basis *= get_window(window_fn, n_fft, win_length)
self.register_buffer('basis', basis)
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.n_iter = n_iter
self.tiny = 1.1754944e-38
@classmethod
def get_window_sum_square(
cls, n_frames, hop_length, win_length, n_fft,
window_fn=torch.hann_window
) -> torch.Tensor:
w_sq = get_window(window_fn, n_fft, win_length) ** 2
n = n_fft + hop_length * (n_frames - 1)
x = torch.zeros(n, dtype=torch.float32)
for i in range(n_frames):
ofst = i * hop_length
x[ofst: min(n, ofst + n_fft)] += w_sq[:max(0, min(n_fft, n - ofst))]
return x
def inverse(self, magnitude: torch.Tensor, phase) -> torch.Tensor:
x = torch.cat(
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)],
dim=1
)
x = F.conv_transpose1d(x, self.basis, stride=self.hop_length)
win_sum_sq = self.get_window_sum_square(
magnitude.shape[-1], hop_length=self.hop_length,
win_length=self.win_length, n_fft=self.n_fft
).to(magnitude.device)
# remove modulation effects
approx_nonzero_indices = win_sum_sq > self.tiny
x[:, :, approx_nonzero_indices] /= win_sum_sq[approx_nonzero_indices]
x *= self.n_fft / self.hop_length
x = x[:, :, self.n_fft // 2:]
        x = x[:, :, :-self.n_fft // 2]
return x
def forward(self, specgram: torch.Tensor) -> torch.Tensor:
angles = np.angle(np.exp(2j * np.pi * np.random.rand(*specgram.shape)))
angles = torch.from_numpy(angles).to(specgram)
_specgram = specgram.view(-1, specgram.shape[-2], specgram.shape[-1])
waveform = self.inverse(_specgram, angles).squeeze(1)
for _ in range(self.n_iter):
_, angles = self.transform(waveform)
waveform = self.inverse(_specgram, angles).squeeze(1)
return waveform.squeeze(0)
class GriffinLimVocoder(nn.Module):
def __init__(self, sample_rate, win_size, hop_size, n_fft,
n_mels, f_min, f_max, window_fn,
spec_bwd_max_iter=32,
fp16=False):
super().__init__()
self.inv_mel_transform = PseudoInverseMelScale(
n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate,
f_min=f_min, f_max=f_max
)
self.gl_transform = GriffinLim(
n_fft=n_fft, win_length=win_size, hop_length=hop_size,
window_fn=window_fn, n_iter=spec_bwd_max_iter
)
if fp16:
self.half()
self.inv_mel_transform.half()
self.gl_transform.half()
else:
self.float()
self.inv_mel_transform.float()
self.gl_transform.float()
def forward(self, x):
# x: (B x) T x D -> (B x) 1 x T
# NOTE: batched forward produces noisier waveform. recommend running
# one utterance at a time
self.eval()
x = x.exp().transpose(-1, -2)
x = self.inv_mel_transform(x)
x = self.gl_transform(x)
return x
@classmethod
def from_data_cfg(cls, args, data_cfg: S2TDataConfig):
feat_cfg = data_cfg.config["features"]
window_fn = getattr(torch, feat_cfg["window_fn"] + "_window")
return cls(
sample_rate=feat_cfg["sample_rate"],
win_size=int(feat_cfg["win_len_t"] * feat_cfg["sample_rate"]),
hop_size=int(feat_cfg["hop_len_t"] * feat_cfg["sample_rate"]),
n_fft=feat_cfg["n_fft"], n_mels=feat_cfg["n_mels"],
f_min=feat_cfg["f_min"], f_max=feat_cfg["f_max"],
window_fn=window_fn, spec_bwd_max_iter=args.spec_bwd_max_iter,
fp16=args.fp16
)
class HiFiGANVocoder(nn.Module):
def __init__(
self, checkpoint_path: str, model_cfg: Dict[str, str],
fp16: bool = False
) -> None:
super().__init__()
self.model = HiFiGANModel(model_cfg)
state_dict = torch.load(checkpoint_path)
self.model.load_state_dict(state_dict["generator"])
if fp16:
self.model.half()
logger.info(f"loaded HiFiGAN checkpoint from {checkpoint_path}")
def forward(self, x: torch.Tensor) -> torch.Tensor:
# (B x) T x D -> (B x) 1 x T
model = self.model.eval()
if len(x.shape) == 2:
return model(x.unsqueeze(0).transpose(1, 2)).detach().squeeze(0)
else:
return model(x.transpose(-1, -2)).detach()
@classmethod
def from_data_cfg(cls, args, data_cfg: S2TDataConfig):
vocoder_cfg = data_cfg.vocoder
assert vocoder_cfg.get("type", "griffin_lim") == "hifigan"
with open(vocoder_cfg["config"]) as f:
model_cfg = json.load(f)
return cls(vocoder_cfg["checkpoint"], model_cfg, fp16=args.fp16)
def get_vocoder(args, data_cfg: S2TDataConfig):
if args.vocoder == "griffin_lim":
return GriffinLimVocoder.from_data_cfg(args, data_cfg)
elif args.vocoder == "hifigan":
return HiFiGANVocoder.from_data_cfg(args, data_cfg)
else:
raise ValueError("Unknown vocoder")


@ -25,6 +25,8 @@ from .layer_norm import Fp32LayerNorm, LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
from .linearized_convolution import LinearizedConvolution
from .location_attention import LocationAttention
from .lstm_cell_with_zoneout import LSTMCellWithZoneOut
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
from .same_pad import SamePad
@ -63,6 +65,8 @@ __all__ = [
"LightweightConv1dTBC",
"LightweightConv",
"LinearizedConvolution",
"LocationAttention",
"LSTMCellWithZoneOut",
"MultiheadAttention",
"PositionalEmbedding",
"SamePad",


@ -0,0 +1,72 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
import torch
import torch.nn.functional as F
class LocationAttention(nn.Module):
"""
Attention-Based Models for Speech Recognition
https://arxiv.org/pdf/1506.07503.pdf
:param int encoder_dim: # projection-units of encoder
:param int decoder_dim: # units of decoder
:param int attn_dim: attention dimension
:param int conv_dim: # channels of attention convolution
:param int conv_kernel_size: filter size of attention convolution
"""
def __init__(self, attn_dim, encoder_dim, decoder_dim,
attn_state_kernel_size, conv_dim, conv_kernel_size,
scaling=2.0):
super(LocationAttention, self).__init__()
self.attn_dim = attn_dim
self.decoder_dim = decoder_dim
self.scaling = scaling
self.proj_enc = nn.Linear(encoder_dim, attn_dim)
self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False)
self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False)
self.conv = nn.Conv1d(attn_state_kernel_size, conv_dim,
2 * conv_kernel_size + 1,
padding=conv_kernel_size, bias=False)
self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1))
self.proj_enc_out = None # cache
def clear_cache(self):
self.proj_enc_out = None
def forward(self, encoder_out, encoder_padding_mask, decoder_h, attn_state):
"""
:param torch.Tensor encoder_out: padded encoder hidden state B x T x D
:param torch.Tensor encoder_padding_mask: encoder padding mask
:param torch.Tensor decoder_h: decoder hidden state B x D
:param torch.Tensor attn_prev: previous attention weight B x K x T
:return: attention weighted encoder state (B, D)
:rtype: torch.Tensor
:return: previous attention weights (B x T)
:rtype: torch.Tensor
"""
bsz, seq_len, _ = encoder_out.size()
if self.proj_enc_out is None:
self.proj_enc_out = self.proj_enc(encoder_out)
# B x K x T -> B x C x T
attn = self.conv(attn_state)
# B x C x T -> B x T x C -> B x T x D
attn = self.proj_attn(attn.transpose(1, 2))
if decoder_h is None:
decoder_h = encoder_out.new_zeros(bsz, self.decoder_dim)
dec_h = self.proj_dec(decoder_h).view(bsz, 1, self.attn_dim)
out = self.proj_out(attn + self.proj_enc_out + dec_h).squeeze(2)
out.masked_fill_(encoder_padding_mask, -float("inf"))
w = F.softmax(self.scaling * out, dim=1)
c = torch.sum(encoder_out * w.view(bsz, seq_len, 1), dim=1)
return c, w
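# Shape sketch for a single attention query (illustrative values): attn_state
# stacks K channels of attention history, e.g. previous plus cumulative weights
# when the Tacotron 2 decoder sets attention_use_cumprob.
loc_attn = LocationAttention(attn_dim=128, encoder_dim=512, decoder_dim=1024,
                             attn_state_kernel_size=2, conv_dim=32,
                             conv_kernel_size=15)
bsz, src_len = 3, 17
enc_out = torch.randn(bsz, src_len, 512)                 # B x T x D_enc
enc_mask = torch.zeros(bsz, src_len, dtype=torch.bool)   # no padding in this toy batch
dec_h = torch.randn(bsz, 1024)                           # B x D_dec
attn_state = enc_out.new_zeros(bsz, 2, src_len)          # B x K x T
ctx, weights = loc_attn(enc_out, enc_mask, dec_h, attn_state)
assert ctx.shape == (bsz, 512) and weights.shape == (bsz, src_len)
loc_attn.clear_cache()  # reset the cached encoder projection between utterances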


@ -0,0 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
class LSTMCellWithZoneOut(nn.Module):
"""
Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations
https://arxiv.org/abs/1606.01305
"""
def __init__(self, prob: float, input_size: int, hidden_size: int,
bias: bool = True):
super(LSTMCellWithZoneOut, self).__init__()
self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
self.prob = prob
if prob > 1.0 or prob < 0.0:
raise ValueError("zoneout probability must be in the range from "
"0.0 to 1.0.")
def zoneout(self, h, next_h, prob):
if isinstance(h, tuple):
return tuple(
[self.zoneout(h[i], next_h[i], prob) for i in range(len(h))]
)
if self.training:
mask = h.new_zeros(*h.size()).bernoulli_(prob)
return mask * h + (1 - mask) * next_h
return prob * h + (1 - prob) * next_h
def forward(self, x, h):
return self.zoneout(h, self.lstm_cell(x, h), self.prob)
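# Behaviour sketch (illustrative): during training zoneout randomly keeps entries
# of the previous (h, c) with probability `prob`; in eval mode it interpolates
# the two states deterministically with the same weight.
import torch
zo_cell = LSTMCellWithZoneOut(prob=0.1, input_size=8, hidden_size=16)
zo_cell.eval()
x = torch.randn(2, 8)
state = (torch.zeros(2, 16), torch.zeros(2, 16))
h_next, c_next = zo_cell(x, state)  # each of shape 2 x 16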


@ -0,0 +1,86 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List
from omegaconf import II
from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
@dataclass
class StepLRScheduleConfig(FairseqDataclass):
warmup_updates: int = field(
default=0,
metadata={"help": "warmup the learning rate linearly for the first N updates"},
)
warmup_init_lr: float = field(
default=-1,
metadata={
"help": "initial learning rate during warmup phase; default is cfg.lr"
},
)
lr: List[float] = field(
default=II("optimization.lr"),
metadata={"help": "max learning rate, must be more than cfg.min_lr"},
)
min_lr: float = field(default=0.0, metadata={"help": "min learning rate"})
lr_deacy_period: int = field(default=25000, metadata={"help": "decay period"})
lr_decay: float = field(default=0.5, metadata={"help": "decay factor"})
@register_lr_scheduler("step", dataclass=StepLRScheduleConfig)
class StepLRSchedule(FairseqLRScheduler):
"""Decay learning rate every k updates by a fixed factor
"""
def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer):
super().__init__(cfg, fairseq_optimizer)
self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
self.min_lr = cfg.min_lr
self.lr_deacy_period = cfg.lr_deacy_period
self.lr_decay = cfg.lr_decay
self.warmup_updates = cfg.warmup_updates
self.warmup_init_lr = (
cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr
)
assert(self.lr_deacy_period > 0)
assert(self.lr_decay <= 1)
assert(self.min_lr >= 0)
assert(self.max_lr > self.min_lr)
if cfg.warmup_updates > 0:
# linearly warmup for the first cfg.warmup_updates
self.warmup_lr_step = (
(self.max_lr - self.warmup_init_lr) / self.warmup_updates
)
else:
self.warmup_lr_step = 1
# initial learning rate
self.lr = self.warmup_init_lr
self.optimizer.set_lr(self.lr)
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
# we don't change the learning rate at epoch boundaries
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if num_updates < self.cfg.warmup_updates:
self.lr = self.warmup_init_lr + num_updates * self.warmup_lr_step
else:
curr_updates = num_updates - self.cfg.warmup_updates
lr_mult = self.lr_decay ** (curr_updates // self.lr_deacy_period)
self.lr = max(self.max_lr * lr_mult, self.min_lr)
self.optimizer.set_lr(self.lr)
return self.lr
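# Worked example of the schedule above as plain arithmetic (max_lr and the warmup
# length are assumed values; decay factor, period and min_lr are the dataclass
# defaults): linear warmup to max_lr, then multiply by lr_decay every
# lr_deacy_period updates, floored at min_lr.
def example_step_lr(t, max_lr=1e-3, min_lr=0.0, lr_decay=0.5, period=25000,
                    warmup_updates=1000, warmup_init_lr=0.0):
    if t < warmup_updates:
        return warmup_init_lr + t * (max_lr - warmup_init_lr) / warmup_updates
    return max(max_lr * lr_decay ** ((t - warmup_updates) // period), min_lr)

assert abs(example_step_lr(500) - 5e-4) < 1e-9    # halfway through warmup
assert abs(example_step_lr(26000) - 5e-4) < 1e-9  # one decay period after warmup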


@ -56,6 +56,14 @@ def get_generation_parser(interactive=False, default_task="translation"):
return parser
def get_speech_generation_parser(default_task="text_to_speech"):
parser = get_parser("Speech Generation", default_task)
add_dataset_args(parser, gen=True)
add_distributed_training_args(parser, default_world_size=1)
add_speech_generation_args(parser)
return parser
def get_interactive_generation_parser(default_task="translation"):
return get_generation_parser(interactive=True, default_task=default_task)
@ -344,6 +352,16 @@ def add_generation_args(parser):
return group
def add_speech_generation_args(parser):
group = parser.add_argument_group("Speech Generation")
add_common_eval_args(group) # NOTE: remove_bpe is not needed
# fmt: off
group.add_argument('--eos_prob_threshold', default=0.5, type=float,
help='terminate when eos probability exceeds this')
# fmt: on
return group
def add_interactive_args(parser):
group = parser.add_argument_group("Interactive")
gen_parser_from_dataclass(group, InteractiveConfig())

fairseq/speech_generator.py

@ -0,0 +1,219 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
class SpeechGenerator(object):
def __init__(self, model, vocoder, data_cfg: S2TDataConfig):
self.model = model
self.vocoder = vocoder
stats_npz_path = data_cfg.global_cmvn_stats_npz
self.gcmvn_stats = None
if stats_npz_path is not None:
self.gcmvn_stats = np.load(stats_npz_path)
def gcmvn_denormalize(self, x):
# x: B x T x C
if self.gcmvn_stats is None:
return x
mean = torch.from_numpy(self.gcmvn_stats["mean"]).to(x)
std = torch.from_numpy(self.gcmvn_stats["std"]).to(x)
assert len(x.shape) == 3 and mean.shape[0] == std.shape[0] == x.shape[2]
x = x * std.view(1, 1, -1).expand_as(x)
return x + mean.view(1, 1, -1).expand_as(x)
def get_waveform(self, feat):
# T x C -> T
return None if self.vocoder is None else self.vocoder(feat).squeeze(0)
class AutoRegressiveSpeechGenerator(SpeechGenerator):
def __init__(
self, model, vocoder, data_cfg, max_iter: int = 6000,
eos_prob_threshold: float = 0.5,
):
super().__init__(model, vocoder, data_cfg)
self.max_iter = max_iter
self.eos_prob_threshold = eos_prob_threshold
@torch.no_grad()
def generate(self, model, sample, has_targ=False, **kwargs):
model.eval()
src_tokens = sample["net_input"]["src_tokens"]
src_lengths = sample["net_input"]["src_lengths"]
bsz, src_len = src_tokens.size()
n_frames_per_step = model.decoder.n_frames_per_step
out_dim = model.decoder.out_dim
raw_dim = out_dim // n_frames_per_step
# initialize
encoder_out = model.forward_encoder(src_tokens, src_lengths,
speaker=sample["speaker"])
incremental_state = {}
feat, attn, eos_prob = [], [], []
finished = src_tokens.new_zeros((bsz,)).bool()
out_lens = src_lengths.new_zeros((bsz,)).long().fill_(self.max_iter)
prev_feat_out = encoder_out["encoder_out"][0].new_zeros(bsz, 1, out_dim)
for step in range(self.max_iter):
cur_out_lens = out_lens.clone()
cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1)
_, cur_eos_out, cur_extra = model.forward_decoder(
prev_feat_out, encoder_out=encoder_out,
incremental_state=incremental_state,
target_lengths=cur_out_lens, speaker=sample["speaker"], **kwargs
)
cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2)
feat.append(cur_extra['feature_out'])
attn.append(cur_extra['attn'])
eos_prob.append(cur_eos_prob)
cur_finished = (cur_eos_prob.squeeze(1) > self.eos_prob_threshold)
out_lens.masked_fill_((~finished) & cur_finished, step + 1)
finished = finished | cur_finished
if finished.sum().item() == bsz:
break
prev_feat_out = cur_extra['feature_out']
feat = torch.cat(feat, dim=1)
feat = model.decoder.postnet(feat) + feat
eos_prob = torch.cat(eos_prob, dim=1)
attn = torch.cat(attn, dim=2)
alignment = attn.max(dim=1)[1]
feat = feat.reshape(bsz, -1, raw_dim)
feat = self.gcmvn_denormalize(feat)
eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1)
attn = attn.repeat_interleave(n_frames_per_step, dim=2)
alignment = alignment.repeat_interleave(n_frames_per_step, dim=1)
out_lens = out_lens * n_frames_per_step
finalized = [
{
'feature': feat[b, :out_len],
'eos_prob': eos_prob[b, :out_len],
'attn': attn[b, :, :out_len],
'alignment': alignment[b, :out_len],
'waveform': self.get_waveform(feat[b, :out_len]),
}
for b, out_len in zip(range(bsz), out_lens)
]
if has_targ:
assert sample["target"].size(-1) == out_dim
tgt_feats = sample["target"].view(bsz, -1, raw_dim)
tgt_feats = self.gcmvn_denormalize(tgt_feats)
tgt_lens = sample["target_lengths"] * n_frames_per_step
for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)):
finalized[b]["targ_feature"] = f[:l]
finalized[b]["targ_waveform"] = self.get_waveform(f[:l])
return finalized
class NonAutoregressiveSpeechGenerator(SpeechGenerator):
@torch.no_grad()
def generate(self, model, sample, has_targ=False, **kwargs):
model.eval()
bsz, max_src_len = sample["net_input"]["src_tokens"].size()
n_frames_per_step = model.encoder.n_frames_per_step
out_dim = model.encoder.out_dim
raw_dim = out_dim // n_frames_per_step
feat, out_lens, log_dur_out, _, _ = model(
src_tokens=sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None,
target_lengths=sample["target_lengths"],
speaker=sample["speaker"]
)
feat = feat.view(bsz, -1, raw_dim)
feat = self.gcmvn_denormalize(feat)
dur_out = torch.clamp(
torch.round(torch.exp(log_dur_out) - 1).long(), min=0
)
def get_dur_plot_data(d):
r = []
for i, dd in enumerate(d):
r += [i + 1] * dd.item()
return r
out_lens = out_lens * n_frames_per_step
finalized = [
{
'feature': feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]),
'waveform': self.get_waveform(
feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim])
),
'attn': feat.new_tensor(get_dur_plot_data(dur_out[b])),
}
for b, l in zip(range(bsz), out_lens)
]
if has_targ:
tgt_feats = sample["target"].view(bsz, -1, raw_dim)
tgt_feats = self.gcmvn_denormalize(tgt_feats)
tgt_lens = sample["target_lengths"] * n_frames_per_step
for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)):
finalized[b]["targ_feature"] = f[:l]
finalized[b]["targ_waveform"] = self.get_waveform(f[:l])
return finalized
class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator):
@torch.no_grad()
def generate(self, model, sample, has_targ=False, **kwargs):
model.eval()
src_tokens = sample["net_input"]["src_tokens"]
src_lens = sample["net_input"]["src_lengths"]
prev_out_tokens = sample["net_input"]["prev_output_tokens"]
tgt_lens = sample["target_lengths"]
n_frames_per_step = model.decoder.n_frames_per_step
raw_dim = model.decoder.out_dim // n_frames_per_step
bsz = src_tokens.shape[0]
feat, eos_prob, extra = model(
src_tokens, src_lens, prev_out_tokens, incremental_state=None,
target_lengths=tgt_lens, speaker=sample["speaker"]
)
attn = extra["attn"] # B x T_s x T_t
alignment = attn.max(dim=1)[1]
feat = feat.reshape(bsz, -1, raw_dim)
feat = self.gcmvn_denormalize(feat)
eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1)
attn = attn.repeat_interleave(n_frames_per_step, dim=2)
alignment = alignment.repeat_interleave(n_frames_per_step, dim=1)
tgt_lens = sample["target_lengths"] * n_frames_per_step
finalized = [
{
'feature': feat[b, :tgt_len],
'eos_prob': eos_prob[b, :tgt_len],
'attn': attn[b, :, :tgt_len],
'alignment': alignment[b, :tgt_len],
'waveform': self.get_waveform(feat[b, :tgt_len]),
}
for b, tgt_len in zip(range(bsz), tgt_lens)
]
if has_targ:
tgt_feats = sample["target"].view(bsz, -1, raw_dim)
tgt_feats = self.gcmvn_denormalize(tgt_feats)
for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)):
finalized[b]["targ_feature"] = f[:l]
finalized[b]["targ_waveform"] = self.get_waveform(f[:l])
return finalized
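# Bookkeeping sketch for n_frames_per_step (reduction factor r, illustrative):
# the decoder emits r stacked frames per step, and the generators above unstack
# them back into raw frames with a single reshape.
r, raw_dim, n_steps = 2, 80, 5
stacked = torch.randn(1, n_steps, raw_dim * r)  # decoder output: B x T_step x (r * D)
raw = stacked.reshape(1, -1, raw_dim)           # B x (T_step * r) x D, as in generate()
assert raw.shape == (1, n_steps * r, raw_dim)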


@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from fairseq.data.audio.frm_text_to_speech_dataset import FrmTextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.text_to_speech import TextToSpeechTask
logging.basicConfig(
format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO
)
logger = logging.getLogger(__name__)
@register_task('frm_text_to_speech')
class FrmTextToSpeechTask(TextToSpeechTask):
@staticmethod
def add_args(parser):
TextToSpeechTask.add_args(parser)
parser.add_argument(
"--do_chunk", action="store_true", help="train on chunks"
)
parser.add_argument("--chunk_bound", default=-1, type=int)
parser.add_argument("--chunk_init", default=50, type=int)
parser.add_argument("--chunk_incr", default=5, type=int)
parser.add_argument("--add_eos", action="store_true")
parser.add_argument("--dedup", action="store_true")
parser.add_argument("--ref_fpu", default=-1, type=float)
def load_dataset(self, split, **unused_kwargs):
is_train_split = split.startswith("train")
pre_tokenizer = self.build_tokenizer(self.args)
bpe_tokenizer = self.build_bpe(self.args)
self.datasets[split] = FrmTextToSpeechDatasetCreator.from_tsv(
self.args.data,
self.data_cfg,
split,
self.src_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split=is_train_split,
n_frames_per_step=self.args.n_frames_per_step,
speaker_to_id=self.speaker_to_id,
do_chunk=self.args.do_chunk,
chunk_bound=self.args.chunk_bound,
chunk_init=self.args.chunk_init,
chunk_incr=self.args.chunk_incr,
add_eos=self.args.add_eos,
dedup=self.args.dedup,
ref_fpu=self.args.ref_fpu
)


@ -50,6 +50,16 @@ class SpeechToTextTask(LegacyFairseqTask):
super().__init__(args)
self.tgt_dict = tgt_dict
self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
self.speaker_to_id = self._get_speaker_to_id()
def _get_speaker_to_id(self):
speaker_to_id = None
speaker_set_filename = self.data_cfg.config.get("speaker_set_filename")
if speaker_set_filename is not None:
speaker_set_path = Path(self.args.data) / speaker_set_filename
with open(speaker_set_path) as f:
speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
return speaker_to_id
@classmethod
def setup_task(cls, args, **kwargs):
@ -91,6 +101,7 @@ class SpeechToTextTask(LegacyFairseqTask):
is_train_split=is_train_split,
epoch=epoch,
seed=self.args.seed,
speaker_to_id=self.speaker_to_id
)
@property
@ -107,6 +118,7 @@ class SpeechToTextTask(LegacyFairseqTask):
def build_model(self, args):
args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
args.input_channels = self.data_cfg.input_channels
args.speaker_to_id = self.speaker_to_id
return super(SpeechToTextTask, self).build_model(args)
def build_generator(
@ -126,12 +138,13 @@ class SpeechToTextTask(LegacyFairseqTask):
for s, i in self.tgt_dict.indices.items()
if SpeechToTextDataset.is_lang_tag(s)
}
        if extra_gen_cls_kwargs is None:
            extra_gen_cls_kwargs = {}
        extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids
        return super().build_generator(
            models, args, seq_gen_cls=None,
            extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )
def build_tokenizer(self, args):


@ -0,0 +1,467 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import os.path as op
import torch
import torch.nn.functional as F
import numpy as np
from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.speech_generator import (
AutoRegressiveSpeechGenerator, NonAutoregressiveSpeechGenerator,
TeacherForcingAutoRegressiveSpeechGenerator
)
logging.basicConfig(
format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO
)
logger = logging.getLogger(__name__)
try:
from tensorboardX import SummaryWriter
except ImportError:
logger.info("Please install tensorboardX: pip install tensorboardX")
SummaryWriter = None
@register_task('text_to_speech')
class TextToSpeechTask(SpeechToTextTask):
@staticmethod
def add_args(parser):
parser.add_argument('data', help='manifest root path')
parser.add_argument(
'--config-yaml', type=str, default='config.yaml',
help='Configuration YAML filename (under manifest root)'
)
parser.add_argument('--max-source-positions', default=1024, type=int,
metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1200, type=int,
metavar='N',
help='max number of tokens in the target sequence')
parser.add_argument("--n-frames-per-step", type=int, default=1)
parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
parser.add_argument("--eval-inference", action="store_true")
parser.add_argument("--eval-tb-nsample", type=int, default=8)
parser.add_argument("--vocoder", type=str, default="griffin_lim")
parser.add_argument("--spec-bwd-max-iter", type=int, default=8)
def __init__(self, args, src_dict):
super().__init__(args, src_dict)
self.src_dict = src_dict
self.sr = self.data_cfg.config.get("features").get("sample_rate")
self.tensorboard_writer = None
self.tensorboard_dir = ""
if args.tensorboard_logdir and SummaryWriter is not None:
self.tensorboard_dir = os.path.join(args.tensorboard_logdir,
"valid_extra")
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
is_train_split = split.startswith('train')
pre_tokenizer = self.build_tokenizer(self.args)
bpe_tokenizer = self.build_bpe(self.args)
self.datasets[split] = TextToSpeechDatasetCreator.from_tsv(
self.args.data, self.data_cfg, split, self.src_dict,
pre_tokenizer, bpe_tokenizer, is_train_split=is_train_split,
epoch=epoch, seed=self.args.seed,
n_frames_per_step=self.args.n_frames_per_step,
speaker_to_id=self.speaker_to_id
)
@property
def target_dictionary(self):
return None
@property
def source_dictionary(self):
return self.src_dict
def get_speaker_embeddings_path(self):
speaker_emb_path = None
if self.data_cfg.config.get("speaker_emb_filename") is not None:
speaker_emb_path = op.join(
self.args.data, self.data_cfg.config.get("speaker_emb_filename")
)
return speaker_emb_path
@classmethod
def get_speaker_embeddings(cls, args):
embed_speaker = None
if args.speaker_to_id is not None:
if args.speaker_emb_path is None:
embed_speaker = torch.nn.Embedding(
len(args.speaker_to_id), args.speaker_embed_dim
)
else:
speaker_emb_mat = np.load(args.speaker_emb_path)
assert speaker_emb_mat.shape[1] == args.speaker_embed_dim
embed_speaker = torch.nn.Embedding.from_pretrained(
torch.from_numpy(speaker_emb_mat), freeze=True,
)
logger.info(
f"load speaker embeddings from {args.speaker_emb_path}. "
f"train embedding? {embed_speaker.weight.requires_grad}\n"
f"embeddings:\n{speaker_emb_mat}"
)
return embed_speaker
def build_model(self, cfg):
cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None)
cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None)
cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None)
cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None)
cfg.speaker_emb_path = self.get_speaker_embeddings_path()
model = super().build_model(cfg)
self.generator = None
if getattr(cfg, "eval_inference", False):
self.generator = self.build_generator([model], cfg)
return model
def build_generator(self, models, cfg, vocoder=None, **unused):
if vocoder is None:
vocoder = self.build_default_vocoder()
model = models[0]
if getattr(model, "NON_AUTOREGRESSIVE", False):
return NonAutoregressiveSpeechGenerator(
model, vocoder, self.data_cfg
)
else:
generator = AutoRegressiveSpeechGenerator
if getattr(cfg, "teacher_forcing", False):
generator = TeacherForcingAutoRegressiveSpeechGenerator
logger.info("Teacher forcing mode for generation")
return generator(
model, vocoder, self.data_cfg,
max_iter=self.args.max_target_positions,
eos_prob_threshold=self.args.eos_prob_threshold
)
def build_default_vocoder(self):
from fairseq.models.text_to_speech.vocoder import get_vocoder
vocoder = get_vocoder(self.args, self.data_cfg)
if torch.cuda.is_available() and not self.args.cpu:
vocoder = vocoder.cuda()
else:
vocoder = vocoder.cpu()
return vocoder
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(
sample, model, criterion
)
if getattr(self.args, "eval_inference", False):
hypos, inference_losses = self.valid_step_with_inference(
sample, model, self.generator
)
for k, v in inference_losses.items():
assert(k not in logging_output)
logging_output[k] = v
picked_id = 0
if self.tensorboard_dir and (sample["id"] == picked_id).any():
self.log_tensorboard(
sample,
hypos[:self.args.eval_tb_nsample],
model._num_updates,
is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False)
)
return loss, sample_size, logging_output
def valid_step_with_inference(self, sample, model, generator):
hypos = generator.generate(model, sample, has_targ=True)
losses = {
"mcd_loss": 0.,
"targ_frames": 0.,
"pred_frames": 0.,
"nins": 0.,
"ndel": 0.,
}
rets = batch_mel_cepstral_distortion(
[hypo["targ_waveform"] for hypo in hypos],
[hypo["waveform"] for hypo in hypos],
self.sr,
normalize_type=None
)
for d, extra in rets:
pathmap = extra[-1]
losses["mcd_loss"] += d.item()
losses["targ_frames"] += pathmap.size(0)
losses["pred_frames"] += pathmap.size(1)
losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()
return hypos, losses
def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False):
if self.tensorboard_writer is None:
self.tensorboard_writer = SummaryWriter(self.tensorboard_dir)
tb_writer = self.tensorboard_writer
for b in range(len(hypos)):
idx = sample["id"][b]
text = sample["src_texts"][b]
targ = hypos[b]["targ_feature"]
pred = hypos[b]["feature"]
attn = hypos[b]["attn"]
if is_na_model:
data = plot_tts_output(
[targ.transpose(0, 1), pred.transpose(0, 1)],
[f"target (idx={idx})", "output"], attn,
"alignment", ret_np=True, suptitle=text,
)
else:
eos_prob = hypos[b]["eos_prob"]
data = plot_tts_output(
[targ.transpose(0, 1), pred.transpose(0, 1), attn],
[f"target (idx={idx})", "output", "alignment"], eos_prob,
"eos prob", ret_np=True, suptitle=text,
)
tb_writer.add_image(
f"inference_sample_{b}", data, num_updates,
dataformats="HWC"
)
if hypos[b]["waveform"] is not None:
targ_wave = hypos[b]["targ_waveform"].detach().cpu().float()
pred_wave = hypos[b]["waveform"].detach().cpu().float()
tb_writer.add_audio(
f"inference_targ_{b}",
targ_wave,
num_updates,
sample_rate=self.sr
)
tb_writer.add_audio(
f"inference_pred_{b}",
pred_wave,
num_updates,
sample_rate=self.sr
)
def save_figure_to_numpy(fig):
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated for binary data
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
return data
DEFAULT_V_MIN = np.log(1e-5)
def plot_tts_output(
data_2d, title_2d, data_1d, title_1d, figsize=(24, 4),
v_min=DEFAULT_V_MIN, v_max=3, ret_np=False, suptitle=""
):
try:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
except ImportError:
raise ImportError("Please install Matplotlib: pip install matplotlib")
data_2d = [
x.detach().cpu().float().numpy()
if isinstance(x, torch.Tensor) else x for x in data_2d
]
fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize)
if suptitle:
fig.suptitle(suptitle[:400]) # capped at 400 chars
axes = [axes] if len(data_2d) == 0 else axes
for ax, x, name in zip(axes, data_2d, title_2d):
ax.set_title(name)
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
im = ax.imshow(
x, origin="lower", aspect="auto", vmin=max(x.min(), v_min),
vmax=min(x.max(), v_max)
)
fig.colorbar(im, cax=cax, orientation='vertical')
if isinstance(data_1d, torch.Tensor):
data_1d = data_1d.detach().cpu().numpy()
axes[-1].plot(data_1d)
axes[-1].set_title(title_1d)
plt.tight_layout()
if ret_np:
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close(fig)
return data
def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None):
"""
for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4, outputs
offset=2 (1, 1),
offset=3 (2, 1), (1, 2)
offset=4 (2, 2), (1, 3)
offset=5 (2, 3)
constraints:
i + j = offset
min_j <= j < max_j
min_i <= offset - j < max_i
"""
if max_i is None:
max_i = offset + 1
if max_j is None:
max_j = offset + 1
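    # clamp the j-range so that both j and i = offset - j respect the
    # [min_j, max_j) and [min_i, max_i) bounds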
min_j = max(min_j, offset - max_i + 1, 0)
max_j = min(max_j, offset - min_i + 1, offset + 1)
j = torch.arange(min_j, max_j)
i = offset - j
return torch.stack([i, j])
def batch_dynamic_time_warping(distance, shapes=None):
"""full batched DTW without any constraints
distance: (batchsize, max_M, max_N) matrix
shapes: (batchsize,) vector specifying (M, N) for each entry
"""
# ptr: 0=left, 1=up-left, 2=up
ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)}
bsz, m, n = distance.size()
cumdist = torch.zeros_like(distance)
backptr = torch.zeros_like(distance).type(torch.int32) - 1
# initialize
cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1)
cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1)
backptr[:, 0, :] = 0
backptr[:, :, 0] = 2
# DP with optimized anti-diagonal parallelization, O(M+N) steps
for offset in range(2, m + n - 1):
ind = antidiag_indices(offset, 1, m, 1, n)
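        # candidate predecessor costs for every cell on this anti-diagonal,
        # stacked in ptr2dij order: 0=left (i, j-1), 1=up-left (i-1, j-1),
        # 2=up (i-1, j)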
c = torch.stack(
[cumdist[:, ind[0], ind[1] - 1], cumdist[:, ind[0] - 1, ind[1] - 1],
cumdist[:, ind[0] - 1, ind[1]], ],
dim=2
)
        v, b = c.min(dim=-1)
backptr[:, ind[0], ind[1]] = b.int()
cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]]
# backtrace
pathmap = torch.zeros_like(backptr)
for b in range(bsz):
i = m - 1 if shapes is None else (shapes[b][0] - 1).item()
j = n - 1 if shapes is None else (shapes[b][1] - 1).item()
dtwpath = [(i, j)]
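        # follow back-pointers to (0, 0); the length cap guards against
        # looping forever on a malformed back-pointer table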
while (i != 0 or j != 0) and len(dtwpath) < 10000:
assert (i >= 0 and j >= 0)
di, dj = ptr2dij[backptr[b, i, j].item()]
i, j = i + di, j + dj
dtwpath.append((i, j))
dtwpath = dtwpath[::-1]
indices = torch.from_numpy(np.array(dtwpath))
pathmap[b, indices[:, 0], indices[:, 1]] = 1
return cumdist, backptr, pathmap
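# A minimal usage sketch (hypothetical tensors of shapes (M, d) and (N, d);
# not part of the original code):
#   dist = compute_l2_dist(x1, x2).unsqueeze(0)  # (1, M, N)
#   cumdist, backptr, pathmap = batch_dynamic_time_warping(dist)
#   total_alignment_cost = cumdist[0, -1, -1]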
def compute_l2_dist(x1, x2):
"""compute an (m, n) L2 distance matrix from (m, d) and (n, d) matrices"""
return torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2)
def compute_rms_dist(x1, x2):
l2_dist = compute_l2_dist(x1, x2)
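    # average the squared per-dimension differences, then take the square root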
return (l2_dist / x1.size(1)).pow(0.5)
def get_divisor(pathmap, normalize_type):
if normalize_type is None:
return 1
elif normalize_type == "len1":
return pathmap.size(0)
elif normalize_type == "len2":
return pathmap.size(1)
elif normalize_type == "path":
return pathmap.sum().item()
else:
raise ValueError(f"normalize_type {normalize_type} not supported")
def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type):
d, s, x1, x2 = [], [], [], []
for cur_y1, cur_y2 in zip(y1, y2):
assert (cur_y1.ndim == 1 and cur_y2.ndim == 1)
cur_x1 = feat_fn(cur_y1)
cur_x2 = feat_fn(cur_y2)
x1.append(cur_x1)
x2.append(cur_x2)
cur_d = dist_fn(cur_x1, cur_x2)
d.append(cur_d)
s.append(d[-1].size())
max_m = max(ss[0] for ss in s)
max_n = max(ss[1] for ss in s)
d = torch.stack(
[F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d]
)
s = torch.LongTensor(s).to(d.device)
cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s)
rets = []
itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps)
for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr:
cumdist = cumdist[:m, :n]
backptr = backptr[:m, :n]
pathmap = pathmap[:m, :n]
divisor = get_divisor(pathmap, normalize_type)
distortion = cumdist[-1, -1] / divisor
ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap)
rets.append(ret)
return rets
def batch_mel_cepstral_distortion(
y1, y2, sr, normalize_type="path", mfcc_fn=None
):
"""
https://arxiv.org/pdf/2011.03568.pdf
The root mean squared error computed on 13-dimensional MFCC using DTW for
alignment. MFCC features are computed from an 80-channel log-mel
spectrogram using a 50ms Hann window and hop of 12.5ms.
y1: list of waveforms
y2: list of waveforms
sr: sampling rate
"""
try:
import torchaudio
except ImportError:
raise ImportError("Please install torchaudio: pip install torchaudio")
if mfcc_fn is None or mfcc_fn.sample_rate != sr:
melkwargs = {
"n_fft": int(0.05 * sr), "win_length": int(0.05 * sr),
"hop_length": int(0.0125 * sr), "f_min": 20,
"n_mels": 80, "window_fn": torch.hann_window
}
mfcc_fn = torchaudio.transforms.MFCC(
sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs
).to(y1[0].device)
return batch_compute_distortion(
y1, y2, sr, lambda y: mfcc_fn(y).transpose(-1, -2), compute_rms_dist,
normalize_type
)
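# A minimal usage sketch (hypothetical 1-D float waveform tensors ref_wav and
# syn_wav at 16 kHz; not part of the original code):
#   rets = batch_mel_cepstral_distortion([ref_wav], [syn_wav], sr=16000)
#   mcd, (x1, x2, dist, cumdist, backptr, pathmap) = rets[0]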

View File

@ -210,6 +210,7 @@ def do_setup(package_data):
"torch",
"tqdm",
"bitarray",
"torchaudio>=0.8.0",
],
dependency_links=dependency_links,
packages=find_packages(