From 0f33ccf7cfae3fdc918d40a8c8388eea00f5ea2a Mon Sep 17 00:00:00 2001 From: Felix Kreuk Date: Mon, 12 Dec 2022 16:00:01 +0200 Subject: [PATCH] Emotion Conversion Paper Open Source (#4895) --- examples/emotion_conversion/README.md | 214 +++++++ .../emotion_models/__init__.py | 0 .../emotion_models/duration_predictor.py | 243 ++++++++ .../emotion_models/duration_predictor.yaml | 48 ++ .../emotion_models/pitch_predictor.py | 559 ++++++++++++++++++ .../emotion_models/pitch_predictor.yaml | 64 ++ .../emotion_models/utils.py | 78 +++ .../fairseq_models/__init__.py | 226 +++++++ .../emotion_conversion/preprocess/__init__.py | 0 .../preprocess/build_hifigan_manifest.py | 38 ++ .../preprocess/build_translation_manifests.py | 258 ++++++++ .../preprocess/create_core_manifest.py | 91 +++ .../preprocess/extract_f0.py | 57 ++ .../preprocess/process_km.py | 40 ++ .../preprocess/split_emov_km_tsv_by_uttid.py | 70 +++ .../emotion_conversion/preprocess/split_km.py | 50 ++ .../preprocess/split_km_tsv.py | 65 ++ examples/emotion_conversion/requirements.txt | 11 + examples/emotion_conversion/synthesize.py | 322 ++++++++++ 19 files changed, 2434 insertions(+) create mode 100644 examples/emotion_conversion/README.md create mode 100644 examples/emotion_conversion/emotion_models/__init__.py create mode 100644 examples/emotion_conversion/emotion_models/duration_predictor.py create mode 100644 examples/emotion_conversion/emotion_models/duration_predictor.yaml create mode 100644 examples/emotion_conversion/emotion_models/pitch_predictor.py create mode 100644 examples/emotion_conversion/emotion_models/pitch_predictor.yaml create mode 100644 examples/emotion_conversion/emotion_models/utils.py create mode 100644 examples/emotion_conversion/fairseq_models/__init__.py create mode 100644 examples/emotion_conversion/preprocess/__init__.py create mode 100644 examples/emotion_conversion/preprocess/build_hifigan_manifest.py create mode 100644 examples/emotion_conversion/preprocess/build_translation_manifests.py create mode 100644 examples/emotion_conversion/preprocess/create_core_manifest.py create mode 100644 examples/emotion_conversion/preprocess/extract_f0.py create mode 100644 examples/emotion_conversion/preprocess/process_km.py create mode 100644 examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py create mode 100644 examples/emotion_conversion/preprocess/split_km.py create mode 100644 examples/emotion_conversion/preprocess/split_km_tsv.py create mode 100644 examples/emotion_conversion/requirements.txt create mode 100644 examples/emotion_conversion/synthesize.py diff --git a/examples/emotion_conversion/README.md b/examples/emotion_conversion/README.md new file mode 100644 index 00000000..caf22bef --- /dev/null +++ b/examples/emotion_conversion/README.md @@ -0,0 +1,214 @@ +# Textless speech emotion conversion using decomposed and discrete representations +[Felix Kreuk](https://felixkreuk.github.io), Adam Polyak, Jade Copet, Eugene Kharitonov, Tu-Anh Nguyen, Morgane Rivière, Wei-Ning Hsu, Abdelrahman Mohamed, Emmanuel Dupoux, [Yossi Adi](https://adiyoss.github.io) + +_abstract_: Speech emotion conversion is the task of modifying the perceived emotion of a speech utterance while preserving the lexical content and speaker identity. In this study, we cast the problem of emotion conversion as a spoken language translation task. We decompose speech into discrete and disentangled learned representations, consisting of content units, F0, speaker, and emotion. First, we modify the speech content by translating the content units to a target emotion, and then predict the prosodic features based on these units. Finally, the speech waveform is generated by feeding the predicted representations into a neural vocoder. Such a paradigm allows us to go beyond spectral and parametric changes of the signal, and model non-verbal vocalizations, such as laughter insertion, yawning removal, etc. We demonstrate objectively and subjectively that the proposed method is superior to the baselines in terms of perceived emotion and audio quality. We rigorously evaluate all components of such a complex system and conclude with an extensive model analysis and ablation study to better emphasize the architectural choices, strengths and weaknesses of the proposed method. Samples and code will be publicly available under the following link: https://speechbot.github.io/emotion. + +## Installation +First, create a conda virtual environment and activate it: +``` +conda create -n emotion python=3.8 -y +conda activate emotion +``` + +Then, clone this repository: +``` +git clone https://github.com/facebookresearch/fairseq.git +cd fairseq/examples/emotion_conversion +git clone https://github.com/felixkreuk/speech-resynthesis +``` + +Next, download the EmoV discrete tokens: +``` +wget https://dl.fbaipublicfiles.com/textless_nlp/emotion_conversion/data.tar.gz # (still in fairseq/examples/emotion_conversion) +tar -xzvf data.tar.gz +``` + +Your `fairseq/examples/emotion_conversion` directory should like this: +``` +drwxrwxr-x 3 felixkreuk felixkreuk 0 Feb 6 2022 data +drwxrwxr-x 3 felixkreuk felixkreuk 0 Sep 28 10:41 emotion_models +drwxr-xr-x 3 felixkreuk felixkreuk 0 Jun 29 05:43 fairseq_models +drwxr-xr-x 3 felixkreuk felixkreuk 0 Sep 28 10:41 preprocess +-rw-rw-r-- 1 felixkreuk felixkreuk 11K Dec 5 09:00 README.md +-rw-rw-r-- 1 felixkreuk felixkreuk 88 Mar 6 2022 requirements.txt +-rw-rw-r-- 1 felixkreuk felixkreuk 13K Jun 29 06:26 synthesize.py +``` + +Lastly, install fairseq and the other packages: +``` +pip install --editable ./ +pip install -r examples/emotion_conversion/requirements.txt +``` + +## Data preprocessing + +### Convert your audio to discrete representations +Please follow the steps described [here](https://github.com/pytorch/fairseq/tree/main/examples/hubert/simple_kmeans). +To generate the same discrete representations please use the following: +1. [HuBERT checkpoint](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) +2. k-means model at `data/hubert_base_ls960_layer9_clusters200/data_hubert_base_ls960_layer9_clusters200.bin` + +### Construct data splits +This step will use the discrete representations from the previous step and split them to train/valid/test sets for 3 tasks: +1. Translation model pre-training (BART language denoising) +2. Translation model training (content units emotion translation mechanism) +3. HiFiGAN model training (for synthesizing audio from discrete representations) + +Your processed data should be at `data/`: +1. `hubert_base_ls960_layer9_clusters200` - discrete representations extracted using HuBERT layer 9, clustered into 200 clusters. +2. `data.tsv` - a tsv file pointing to the EmoV dataset in your environment (Please edit the first line of this file according to your path). + +The following command will create the above splits: +``` +python examples/emotion_conversion/preprocess/create_core_manifest.py \ + --tsv data/data.tsv \ + --emov-km data/hubert_base_ls960_layer9_clusters200/data.km \ + --km data/hubert_base_ls960_layer9_clusters200/vctk.km \ + --dict data/hubert_base_ls960_layer9_clusters200/dict.txt \ + --manifests-dir $DATA +``` +* Set `$DATA` as the directory that will contain the processed data. + +### Extract F0 +To train the HiFiGAN vocoder we need to first extract the F0 curves: +``` +python examples/emotion_conversion/preprocess/extract_f0.py \ + --tsv data/data.tsv \ + --extractor pyaapt \ +``` + +## HiFiGAN training +Now we are all set to train the HiFiGAN vocoder: +``` +python examples/emotion_conversion/speech-resynthesis/train.py + --checkpoint_path \ + --config examples/emotion_conversion/speech-resynthesis/configs/EmoV/emov_hubert-layer9-cluster200_fixed-spkr-embedder_f0-raw_gst.json +``` + +## Translation Pre-training +Before translating emotions, we first need to pre-train the translation model as a denoising autoencoder (similarly to BART). +``` +python train.py \ + $DATA/fairseq-data/emov_multilingual_denoising_cross-speaker_dedup_nonzeroshot/tokenized \ + --save-dir \ + --tensorboard-logdir \ + --langs neutral,amused,angry,sleepy,disgusted,vctk.km \ + --dataset-impl mmap \ + --task multilingual_denoising \ + --arch transformer_small --criterion cross_entropy \ + --multilang-sampling-alpha 1.0 --sample-break-mode eos --max-tokens 16384 \ + --update-freq 1 --max-update 3000000 \ + --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.0 \ + --optimizer adam --weight-decay 0.01 --adam-eps 1e-06 \ + --clip-norm 0.1 --lr-scheduler polynomial_decay --lr 0.0003 \ + --total-num-update 3000000 --warmup-updates 10000 --fp16 \ + --poisson-lambda 3.5 --mask 0.3 --mask-length span-poisson --replace-length 1 --rotate 0 --mask-random 0.1 --insert 0 --permute-sentences 1.0 \ + --skip-invalid-size-inputs-valid-test \ + --user-dir examples/emotion_conversion/fairseq_models +``` + +## Translation Training +Now we are ready to train our emotion translation model: +``` +python train.py \ + --distributed-world-size 1 \ + $DATA/fairseq-data/emov_multilingual_translation_cross-speaker_dedup/tokenized/ \ + --save-dir \ + --tensorboard-logdir \ + --arch multilingual_small --task multilingual_translation \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --lang-pairs neutral-amused,neutral-sleepy,neutral-disgusted,neutral-angry,amused-sleepy,amused-disgusted,amused-neutral,amused-angry,angry-amused,angry-sleepy,angry-disgusted,angry-neutral,disgusted-amused,disgusted-sleepy,disgusted-neutral,disgusted-angry,sleepy-amused,sleepy-neutral,sleepy-disgusted,sleepy-angry \ + --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ + --lr 1e-05 --clip-norm 0 --dropout 0.1 --attention-dropout 0.1 \ + --weight-decay 0.01 --warmup-updates 2000 --lr-scheduler inverse_sqrt \ + --max-tokens 4096 --update-freq 1 --max-update 100000 \ + --required-batch-size-multiple 8 --fp16 --num-workers 4 \ + --seed 2 --log-format json --log-interval 25 --save-interval-updates 1000 \ + --no-epoch-checkpoints --keep-best-checkpoints 1 --keep-interval-updates 1 \ + --finetune-from-model \ + --user-dir examples/emotion_conversion/fairseq_models +``` +* To share encoders/decoders use the `--share-encoders` and `--share-decoders` flags. +* To add source/target emotion tokens use the `--encoder-langtok {'src'|'tgt'}` and `--decoder-langtok` flags. + +## F0-predictor Training +The following command trains the F0 prediction module: +``` +cd examples/emotion_conversion +python -m emotion_models.pitch_predictor n_tokens=200 \ + train_tsv="$DATA/denoising/emov/train.tsv" \ + train_km="$DATA/denoising/emov/train.km" \ + valid_tsv="$DATA/denoising/emov/valid.tsv" \ + valid_km="$DATA/denoising/emov/valid.km" +``` +* See `hyra.run.dir` to configure directory for saving models. + +## Duration-predictor Training +The following command trains the duration prediction modules: +``` +cd examples/emotion_conversion +for emotion in "neutral" "amused" "angry" "disgusted" "sleepy"; do + python -m emotion_models.duration_predictor n_tokens=200 substring=$emotion \ + train_tsv="$DATA/denoising/emov/train.tsv" \ + train_km="$DATA/denoising/emov/train.km" \ + valid_tsv="$DATA/denoising/emov/valid.tsv" \ + valid_km="$DATA/denoising/emov/valid.km" +done +``` +* See `hyra.run.dir` to configure directory for saving models. +* After the above command you should have 5 duration models in your checkpoint directory: +``` +❯ ll duration_predictor/ +total 21M +-rw-rw-r-- 1 felixkreuk felixkreuk 4.1M Nov 15 2021 amused.ckpt +-rw-rw-r-- 1 felixkreuk felixkreuk 4.1M Nov 15 2021 angry.ckpt +-rw-rw-r-- 1 felixkreuk felixkreuk 4.1M Nov 15 2021 disgusted.ckpt +-rw-rw-r-- 1 felixkreuk felixkreuk 4.1M Nov 15 2021 neutral.ckpt +-rw-rw-r-- 1 felixkreuk felixkreuk 4.1M Nov 15 2021 sleepy.ckpt +``` + +## Token Generation +The following command uses `fairseq-generate` to generate the token sequences based on the source and target emotions. +``` +fairseq-generate \ + $DATA/fairseq-data/emov_multilingual_translation_cross-speaker_dedup/tokenized/ \ + --task multilingual_translation \ + --gen-subset test \ + --path \ + --beam 5 \ + --batch-size 4 --max-len-a 1.8 --max-len-b 10 --lenpen 1 --min-len 1 \ + --skip-invalid-size-inputs-valid-test --distributed-world-size 1 \ + --source-lang neutral --target-lang amused \ + --lang-pairs neutral-amused,neutral-sleepy,neutral-disgusted,neutral-angry,amused-sleepy,amused-disgusted,amused-neutral,amused-angry,angry-amused,angry-sleepy,angry-disgusted,angry-neutral,disgusted-amused,disgusted-sleepy,disgusted-neutral,disgusted-angry,sleepy-amused,sleepy-neutral,sleepy-disgusted,sleepy-angry \ + --results-path \ + --user-dir examples/emotion_conversion/fairseq_models +``` +* Modify `--source-lang` and `--target-lang` to control for the source and target emotions. +* See [fairseq documentation](https://fairseq.readthedocs.io/en/latest/command_line_tools.html#fairseq-generate) for a full overview of generation parameters (e.g., top-k/top-p sampling). + +## Waveform Synthesis +Using the output of the above command, the HiFiGAN vocoder, and the prosody prediction modules (F0 and duration) we can now generate the output waveforms: +``` +python examples/emotion_conversion/synthesize.py \ + --result-path /generate-test.txt \ + --data $DATA/fairseq-data/emov_multilingual_translation_cross-speaker_dedup/neutral-amused \ + --orig-tsv examples/emotion_conversion/data/data.tsv \ + --orig-km examples/emotion_conversion/data/hubert_base_ls960_layer9_clusters200/data.km \ + --checkpoint-file /g_00400000 \ + --dur-model duration_predictor/ \ + --f0-model pitch_predictor/pitch_predictor.ckpt \ + -s neutral -t amused \ + --outdir ~/tmp/emotion_results/wavs/neutral-amused +``` +* Please make sure the source and target emotions here match those of the previous command. + +# Citation +If you find this useful in your research, please use the following BibTeX entry for citation. +``` +@article{kreuk2021textless, + title={Textless speech emotion conversion using decomposed and discrete representations}, + author={Kreuk, Felix and Polyak, Adam and Copet, Jade and Kharitonov, Eugene and Nguyen, Tu-Anh and Rivi{\`e}re, Morgane and Hsu, Wei-Ning and Mohamed, Abdelrahman and Dupoux, Emmanuel and Adi, Yossi}, + journal={Conference on Empirical Methods in Natural Language Processing (EMNLP)}, + year={2022} +} +``` diff --git a/examples/emotion_conversion/emotion_models/__init__.py b/examples/emotion_conversion/emotion_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/emotion_conversion/emotion_models/duration_predictor.py b/examples/emotion_conversion/emotion_models/duration_predictor.py new file mode 100644 index 00000000..eb47df0a --- /dev/null +++ b/examples/emotion_conversion/emotion_models/duration_predictor.py @@ -0,0 +1,243 @@ +import logging +import os + +import hydra +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.layers.torch import Rearrange +from torch.utils.data import DataLoader, Dataset + +from .utils import Accuracy + +logger = logging.getLogger(__name__) + + +def save_ckpt(model, path, model_class): + ckpt = { + "state_dict": model.state_dict(), + "padding_token": model.padding_token, + "model_class": model_class, + } + torch.save(ckpt, path) + + +def load_ckpt(path): + ckpt = torch.load(path) + ckpt["model_class"]["_target_"] = "emotion_models.duration_predictor.CnnPredictor" + model = hydra.utils.instantiate(ckpt["model_class"]) + model.load_state_dict(ckpt["state_dict"]) + model.padding_token = ckpt["padding_token"] + model = model.cpu() + model.eval() + return model + + +class Collator: + def __init__(self, padding_idx): + self.padding_idx = padding_idx + + def __call__(self, batch): + x = [item[0] for item in batch] + lengths = [len(item) for item in x] + x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=self.padding_idx) + y = [item[1] for item in batch] + y = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=self.padding_idx) + mask = (x != self.padding_idx) + return x, y, mask, lengths + + +class Predictor(nn.Module): + def __init__(self, n_tokens, emb_dim): + super(Predictor, self).__init__() + self.n_tokens = n_tokens + self.emb_dim = emb_dim + self.padding_token = n_tokens + # add 1 extra embedding for padding token, set the padding index to be the last token + # (tokens from the clustering start at index 0) + self.emb = nn.Embedding(n_tokens + 1, emb_dim, padding_idx=self.padding_token) + + def inflate_input(self, batch): + """ get a sequence of tokens, predict their durations + and inflate them accordingly """ + batch_durs = self.forward(batch) + batch_durs = torch.exp(batch_durs) - 1 + batch_durs = batch_durs.round() + output = [] + for seq, durs in zip(batch, batch_durs): + inflated_seq = [] + for token, n in zip(seq, durs): + if token == self.padding_token: + break + n = int(n.item()) + token = int(token.item()) + inflated_seq.extend([token for _ in range(n)]) + output.append(inflated_seq) + output = torch.LongTensor(output) + return output + + +class CnnPredictor(Predictor): + def __init__(self, n_tokens, emb_dim, channels, kernel, output_dim, dropout, n_layers): + super(CnnPredictor, self).__init__(n_tokens=n_tokens, emb_dim=emb_dim) + layers = [ + Rearrange("b t c -> b c t"), + nn.Conv1d(emb_dim, channels, kernel_size=kernel, padding=(kernel - 1) // 2), + Rearrange("b c t -> b t c"), + nn.ReLU(), + nn.LayerNorm(channels), + nn.Dropout(dropout), + ] + for _ in range(n_layers-1): + layers += [ + Rearrange("b t c -> b c t"), + nn.Conv1d(channels, channels, kernel_size=kernel, padding=(kernel - 1) // 2), + Rearrange("b c t -> b t c"), + nn.ReLU(), + nn.LayerNorm(channels), + nn.Dropout(dropout), + ] + self.conv_layer = nn.Sequential(*layers) + self.proj = nn.Linear(channels, output_dim) + + def forward(self, x): + x = self.emb(x) + x = self.conv_layer(x) + x = self.proj(x) + x = x.squeeze(-1) + return x + + +def l2_log_loss(input, target): + return F.mse_loss( + input=input.float(), + target=torch.log(target.float() + 1), + reduce=False + ) + + +class DurationDataset(Dataset): + def __init__(self, tsv_path, km_path, substring=""): + lines = open(tsv_path, "r").readlines() + self.root, self.tsv = lines[0], lines[1:] + self.km = open(km_path, "r").readlines() + logger.info(f"loaded {len(self.km)} files") + + if substring != "": + tsv, km = [], [] + for tsv_line, km_line in zip(self.tsv, self.km): + if substring.lower() in tsv_line.lower(): + tsv.append(tsv_line) + km.append(km_line) + self.tsv, self.km = tsv, km + logger.info(f"after filtering: {len(self.km)} files") + + def __len__(self): + return len(self.km) + + def __getitem__(self, i): + x = self.km[i] + x = x.split(" ") + x = list(map(int, x)) + + y = [] + xd = [] + count = 1 + for x1, x2 in zip(x[:-1], x[1:]): + if x1 == x2: + count += 1 + continue + else: + y.append(count) + xd.append(x1) + count = 1 + + xd = torch.LongTensor(xd) + y = torch.LongTensor(y) + return xd, y + + +def train(cfg): + device = "cuda:0" + model = hydra.utils.instantiate(cfg[cfg.model]).to(device) + optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters()) + # add 1 extra embedding for padding token, set the padding index to be the last token + # (tokens from the clustering start at index 0) + collate_fn = Collator(padding_idx=model.padding_token) + logger.info(f"data: {cfg.train_tsv}") + train_ds = DurationDataset(cfg.train_tsv, cfg.train_km, substring=cfg.substring) + valid_ds = DurationDataset(cfg.valid_tsv, cfg.valid_km, substring=cfg.substring) + train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn) + valid_dl = DataLoader(valid_ds, batch_size=32, shuffle=False, collate_fn=collate_fn) + + best_loss = float("inf") + for epoch in range(cfg.epochs): + train_loss, train_loss_scaled = train_epoch(model, train_dl, l2_log_loss, optimizer, device) + valid_loss, valid_loss_scaled, *acc = valid_epoch(model, valid_dl, l2_log_loss, device) + acc0, acc1, acc2, acc3 = acc + if valid_loss_scaled < best_loss: + path = f"{os.getcwd()}/{cfg.substring}.ckpt" + save_ckpt(model, path, cfg[cfg.model]) + best_loss = valid_loss_scaled + logger.info(f"saved checkpoint: {path}") + logger.info(f"[epoch {epoch}] train loss: {train_loss:.3f}, train scaled: {train_loss_scaled:.3f}") + logger.info(f"[epoch {epoch}] valid loss: {valid_loss:.3f}, valid scaled: {valid_loss_scaled:.3f}") + logger.info(f"acc: {acc0,acc1,acc2,acc3}") + + +def train_epoch(model, loader, criterion, optimizer, device): + model.train() + epoch_loss = 0 + epoch_loss_scaled = 0 + for x, y, mask, _ in loader: + x, y, mask = x.to(device), y.to(device), mask.to(device) + yhat = model(x) + loss = criterion(yhat, y) * mask + loss = torch.mean(loss) + loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + epoch_loss += loss.item() + # get normal scale loss + yhat_scaled = torch.exp(yhat) - 1 + yhat_scaled = torch.round(yhat_scaled) + scaled_loss = torch.mean(torch.abs(yhat_scaled - y) * mask) + epoch_loss_scaled += scaled_loss.item() + return epoch_loss / len(loader), epoch_loss_scaled / len(loader) + + +def valid_epoch(model, loader, criterion, device): + model.eval() + epoch_loss = 0 + epoch_loss_scaled = 0 + acc = Accuracy() + for x, y, mask, _ in loader: + x, y, mask = x.to(device), y.to(device), mask.to(device) + yhat = model(x) + loss = criterion(yhat, y) * mask + loss = torch.mean(loss) + epoch_loss += loss.item() + # get normal scale loss + yhat_scaled = torch.exp(yhat) - 1 + yhat_scaled = torch.round(yhat_scaled) + scaled_loss = torch.sum(torch.abs(yhat_scaled - y) * mask) / mask.sum() + acc.update(yhat_scaled[mask].view(-1).float(), y[mask].view(-1).float()) + epoch_loss_scaled += scaled_loss.item() + logger.info(f"example y: {y[0, :10].tolist()}") + logger.info(f"example yhat: {yhat_scaled[0, :10].tolist()}") + acc0 = acc.acc(tol=0) + acc1 = acc.acc(tol=1) + acc2 = acc.acc(tol=2) + acc3 = acc.acc(tol=3) + logger.info(f"accs: {acc0,acc1,acc2,acc3}") + return epoch_loss / len(loader), epoch_loss_scaled / len(loader), acc0, acc1, acc2, acc3 + + +@hydra.main(config_path=".", config_name="duration_predictor.yaml") +def main(cfg): + logger.info(f"{cfg}") + train(cfg) + + +if __name__ == "__main__": + main() diff --git a/examples/emotion_conversion/emotion_models/duration_predictor.yaml b/examples/emotion_conversion/emotion_models/duration_predictor.yaml new file mode 100644 index 00000000..0e976f48 --- /dev/null +++ b/examples/emotion_conversion/emotion_models/duration_predictor.yaml @@ -0,0 +1,48 @@ +train_tsv: "/denoising/emov/train.tsv" +train_km: "/denoising/emov/train.km" +valid_tsv: "/denoising/emov/valid.tsv" +valid_km: "/denoising/emov/valid.km" + +n_tokens: 200 +batch_size: 32 +lr: 0.0001 +epochs: 300 +model: "cnn" +substring: "" + +rnn: + _target_: emotion_models.duration_predictor.RnnPredictor + n_tokens: ${n_tokens} + emb_dim: 128 + rnn_hidden: 128 + output_dim: 1 + dropout: 0 + n_layers: 1 + +optimizer: + _target_: torch.optim.Adam + lr: ${lr} + betas: [0.9, 0.98] + eps: 0.000000001 + weight_decay: 0 + +cnn: + _target_: emotion_models.duration_predictor.CnnPredictor + n_tokens: ${n_tokens} + emb_dim: 128 + channels: 256 + kernel: 3 + output_dim: 1 + dropout: 0.5 + n_layers: 1 + +hydra: + run: + dir: /checkpoint/felixkreuk/experiments/duration_predictor/${hydra.job.override_dirname} + job: + config: + # configuration for the ${hydra.job.override_dirname} runtime variable + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: ['train_tsv', 'train_km', 'valid_tsv', 'valid_km'] diff --git a/examples/emotion_conversion/emotion_models/pitch_predictor.py b/examples/emotion_conversion/emotion_models/pitch_predictor.py new file mode 100644 index 00000000..43144699 --- /dev/null +++ b/examples/emotion_conversion/emotion_models/pitch_predictor.py @@ -0,0 +1,559 @@ +import logging +import os +import random +import sys +from collections import defaultdict + +import hydra +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from einops.layers.torch import Rearrange +from scipy.io.wavfile import read +from scipy.ndimage import gaussian_filter1d +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +dir_path = os.path.dirname(__file__) +resynth_path = os.path.dirname(dir_path) + "/speech-resynthesis" +sys.path.append(resynth_path) +from dataset import parse_speaker, parse_style +from .utils import F0Stat + +MAX_WAV_VALUE = 32768.0 +logger = logging.getLogger(__name__) + + +def quantize_f0(speaker_to_f0, nbins, normalize, log): + f0_all = [] + for speaker, f0 in speaker_to_f0.items(): + f0 = f0.raw_data + if log: + f0 = f0.log() + mean = speaker_to_f0[speaker].mean_log if log else speaker_to_f0[speaker].mean + std = speaker_to_f0[speaker].std_log if log else speaker_to_f0[speaker].std + if normalize == "mean": + f0 = f0 - mean + elif normalize == "meanstd": + f0 = (f0 - mean) / std + f0_all.extend(f0.tolist()) + + hist, bin_x = np.histogram(f0_all, 100000) + cum_hist = np.cumsum(hist) / len(f0_all) * 100 + + bin_offset = [] + bin_size = 100 / nbins + threshold = bin_size + for i in range(nbins - 1): + index = (np.abs(cum_hist - threshold)).argmin() + bin_offset.append(bin_x[index]) + threshold += bin_size + bins = np.array(bin_offset) + bins = torch.FloatTensor(bins) + + return bins + + +def save_ckpt(model, path, model_class, f0_min, f0_max, f0_bins, speaker_stats): + ckpt = { + "state_dict": model.state_dict(), + "padding_token": model.padding_token, + "model_class": model_class, + "speaker_stats": speaker_stats, + "f0_min": f0_min, + "f0_max": f0_max, + "f0_bins": f0_bins, + } + torch.save(ckpt, path) + + +def load_ckpt(path): + ckpt = torch.load(path) + ckpt["model_class"]["_target_"] = "emotion_models.pitch_predictor.CnnPredictor" + model = hydra.utils.instantiate(ckpt["model_class"]) + model.load_state_dict(ckpt["state_dict"]) + model.setup_f0_stats( + ckpt["f0_min"], + ckpt["f0_max"], + ckpt["f0_bins"], + ckpt["speaker_stats"], + ) + return model + + +def freq2bin(f0, f0_min, f0_max, bins): + f0 = f0.clone() + f0[f0 < f0_min] = f0_min + f0[f0 > f0_max] = f0_max + f0 = torch.bucketize(f0, bins) + return f0 + + +def bin2freq(x, f0_min, f0_max, bins, mode): + n_bins = len(bins) + 1 + assert x.shape[-1] == n_bins + bins = torch.cat([torch.tensor([f0_min]), bins]).to(x.device) + if mode == "mean": + f0 = (x * bins).sum(-1, keepdims=True) / x.sum(-1, keepdims=True) + elif mode == "argmax": + idx = F.one_hot(x.argmax(-1), num_classes=n_bins) + f0 = (idx * bins).sum(-1, keepdims=True) + else: + raise NotImplementedError() + return f0[..., 0] + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def l1_loss(input, target): + return F.l1_loss(input=input.float(), target=target.float(), reduce=False) + + +def l2_loss(input, target): + return F.mse_loss(input=input.float(), target=target.float(), reduce=False) + + +class Collator: + def __init__(self, padding_idx): + self.padding_idx = padding_idx + + def __call__(self, batch): + tokens = [item[0] for item in batch] + lengths = [len(item) for item in tokens] + tokens = torch.nn.utils.rnn.pad_sequence( + tokens, batch_first=True, padding_value=self.padding_idx + ) + f0 = [item[1] for item in batch] + f0 = torch.nn.utils.rnn.pad_sequence( + f0, batch_first=True, padding_value=self.padding_idx + ) + f0_raw = [item[2] for item in batch] + f0_raw = torch.nn.utils.rnn.pad_sequence( + f0_raw, batch_first=True, padding_value=self.padding_idx + ) + spk = [item[3] for item in batch] + spk = torch.LongTensor(spk) + gst = [item[4] for item in batch] + gst = torch.LongTensor(gst) + mask = tokens != self.padding_idx + return tokens, f0, f0_raw, spk, gst, mask, lengths + + +class CnnPredictor(nn.Module): + def __init__( + self, + n_tokens, + emb_dim, + channels, + kernel, + dropout, + n_layers, + spk_emb, + gst_emb, + n_bins, + f0_pred, + f0_log, + f0_norm, + ): + super(CnnPredictor, self).__init__() + self.n_tokens = n_tokens + self.emb_dim = emb_dim + self.f0_log = f0_log + self.f0_pred = f0_pred + self.padding_token = n_tokens + self.f0_norm = f0_norm + # add 1 extra embedding for padding token, set the padding index to be the last token + # (tokens from the clustering start at index 0) + self.token_emb = nn.Embedding( + n_tokens + 1, emb_dim, padding_idx=self.padding_token + ) + + self.spk_emb = spk_emb + self.gst_emb = nn.Embedding(20, gst_emb) + self.setup = False + + feats = emb_dim + gst_emb + # feats = emb_dim + gst_emb + (256 if spk_emb else 0) + layers = [ + nn.Sequential( + Rearrange("b t c -> b c t"), + nn.Conv1d( + feats, channels, kernel_size=kernel, padding=(kernel - 1) // 2 + ), + Rearrange("b c t -> b t c"), + nn.ReLU(), + nn.LayerNorm(channels), + nn.Dropout(dropout), + ) + ] + for _ in range(n_layers - 1): + layers += [ + nn.Sequential( + Rearrange("b t c -> b c t"), + nn.Conv1d( + channels, + channels, + kernel_size=kernel, + padding=(kernel - 1) // 2, + ), + Rearrange("b c t -> b t c"), + nn.ReLU(), + nn.LayerNorm(channels), + nn.Dropout(dropout), + ) + ] + self.conv_layer = nn.ModuleList(layers) + self.proj = nn.Linear(channels, n_bins) + + def forward(self, x, gst=None): + x = self.token_emb(x) + feats = [x] + + if gst is not None: + gst = self.gst_emb(gst) + gst = rearrange(gst, "b c -> b c 1") + gst = F.interpolate(gst, x.shape[1]) + gst = rearrange(gst, "b c t -> b t c") + feats.append(gst) + + x = torch.cat(feats, dim=-1) + + for i, conv in enumerate(self.conv_layer): + if i != 0: + x = conv(x) + x + else: + x = conv(x) + + x = self.proj(x) + x = x.squeeze(-1) + + if self.f0_pred == "mean": + x = torch.sigmoid(x) + elif self.f0_pred == "argmax": + x = torch.softmax(x, dim=-1) + else: + raise NotImplementedError + return x + + def setup_f0_stats(self, f0_min, f0_max, f0_bins, speaker_stats): + self.f0_min = f0_min + self.f0_max = f0_max + self.f0_bins = f0_bins + self.speaker_stats = speaker_stats + self.setup = True + + def inference(self, x, spk_id=None, gst=None): + assert ( + self.setup == True + ), "make sure that `setup_f0_stats` was called before inference!" + probs = self(x, gst) + f0 = bin2freq(probs, self.f0_min, self.f0_max, self.f0_bins, self.f0_pred) + for i in range(f0.shape[0]): + mean = ( + self.speaker_stats[spk_id[i].item()].mean_log + if self.f0_log + else self.speaker_stats[spk_id[i].item()].mean + ) + std = ( + self.speaker_stats[spk_id[i].item()].std_log + if self.f0_log + else self.speaker_stats[spk_id[i].item()].std + ) + if self.f0_norm == "mean": + f0[i] = f0[i] + mean + if self.f0_norm == "meanstd": + f0[i] = (f0[i] * std) + mean + if self.f0_log: + f0 = f0.exp() + return f0 + + +class PitchDataset(Dataset): + def __init__( + self, + tsv_path, + km_path, + substring, + spk, + spk2id, + gst, + gst2id, + f0_bins, + f0_bin_type, + f0_smoothing, + f0_norm, + f0_log, + ): + lines = open(tsv_path, "r").readlines() + self.root, self.tsv = lines[0], lines[1:] + self.root = self.root.strip() + self.km = open(km_path, "r").readlines() + print(f"loaded {len(self.km)} files") + + self.spk = spk + self.spk2id = spk2id + self.gst = gst + self.gst2id = gst2id + + self.f0_bins = f0_bins + self.f0_smoothing = f0_smoothing + self.f0_norm = f0_norm + self.f0_log = f0_log + + if substring != "": + tsv, km = [], [] + for tsv_line, km_line in zip(self.tsv, self.km): + if substring.lower() in tsv_line.lower(): + tsv.append(tsv_line) + km.append(km_line) + self.tsv, self.km = tsv, km + print(f"after filtering: {len(self.km)} files") + + self.speaker_stats = self._compute_f0_stats() + self.f0_min, self.f0_max = self._compute_f0_minmax() + if f0_bin_type == "adaptive": + self.f0_bins = quantize_f0( + self.speaker_stats, self.f0_bins, self.f0_norm, self.f0_log + ) + elif f0_bin_type == "uniform": + self.f0_bins = torch.linspace(self.f0_min, self.f0_max, self.f0_bins + 1)[ + 1:-1 + ] + else: + raise NotImplementedError + print(f"f0 min: {self.f0_min}, f0 max: {self.f0_max}") + print(f"bins: {self.f0_bins} (shape: {self.f0_bins.shape})") + + def __len__(self): + return len(self.km) + + def _load_f0(self, tsv_line): + tsv_line = tsv_line.split("\t")[0] + f0 = self.root + "/" + tsv_line.replace(".wav", ".yaapt.f0.npy") + f0 = np.load(f0) + f0 = torch.FloatTensor(f0) + return f0 + + def _preprocess_f0(self, f0, spk): + mask = f0 != -999999 # process all frames + # mask = (f0 != 0) # only process voiced frames + mean = ( + self.speaker_stats[spk].mean_log + if self.f0_log + else self.speaker_stats[spk].mean + ) + std = ( + self.speaker_stats[spk].std_log + if self.f0_log + else self.speaker_stats[spk].std + ) + if self.f0_log: + f0[f0 == 0] = 1e-5 + f0[mask] = f0[mask].log() + if self.f0_norm == "mean": + f0[mask] = f0[mask] - mean + if self.f0_norm == "meanstd": + f0[mask] = (f0[mask] - mean) / std + return f0 + + def _compute_f0_minmax(self): + f0_min, f0_max = float("inf"), -float("inf") + for tsv_line in tqdm(self.tsv, desc="computing f0 minmax"): + spk = self.spk2id[parse_speaker(tsv_line, self.spk)] + f0 = self._load_f0(tsv_line) + f0 = self._preprocess_f0(f0, spk) + f0_min = min(f0_min, f0.min().item()) + f0_max = max(f0_max, f0.max().item()) + return f0_min, f0_max + + def _compute_f0_stats(self): + from functools import partial + + speaker_stats = defaultdict(partial(F0Stat, True)) + for tsv_line in tqdm(self.tsv, desc="computing speaker stats"): + spk = self.spk2id[parse_speaker(tsv_line, self.spk)] + f0 = self._load_f0(tsv_line) + mask = f0 != 0 + f0 = f0[mask] # compute stats only on voiced parts + speaker_stats[spk].update(f0) + return speaker_stats + + def __getitem__(self, i): + x = self.km[i] + x = x.split(" ") + x = list(map(int, x)) + x = torch.LongTensor(x) + + gst = parse_style(self.tsv[i], self.gst) + gst = self.gst2id[gst] + spk = parse_speaker(self.tsv[i], self.spk) + spk = self.spk2id[spk] + + f0_raw = self._load_f0(self.tsv[i]) + f0 = self._preprocess_f0(f0_raw.clone(), spk) + + f0 = F.interpolate(f0.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0] + f0_raw = F.interpolate(f0_raw.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0] + + f0 = freq2bin(f0, f0_min=self.f0_min, f0_max=self.f0_max, bins=self.f0_bins) + f0 = F.one_hot(f0.long(), num_classes=len(self.f0_bins) + 1).float() + if self.f0_smoothing > 0: + f0 = torch.tensor( + gaussian_filter1d(f0.float().numpy(), sigma=self.f0_smoothing) + ) + return x, f0, f0_raw, spk, gst + + +def train(cfg): + device = "cuda:0" + # add 1 extra embedding for padding token, set the padding index to be the last token + # (tokens from the clustering start at index 0) + padding_token = cfg.n_tokens + collate_fn = Collator(padding_idx=padding_token) + train_ds = PitchDataset( + cfg.train_tsv, + cfg.train_km, + substring=cfg.substring, + spk=cfg.spk, + spk2id=cfg.spk2id, + gst=cfg.gst, + gst2id=cfg.gst2id, + f0_bins=cfg.f0_bins, + f0_bin_type=cfg.f0_bin_type, + f0_smoothing=cfg.f0_smoothing, + f0_norm=cfg.f0_norm, + f0_log=cfg.f0_log, + ) + valid_ds = PitchDataset( + cfg.valid_tsv, + cfg.valid_km, + substring=cfg.substring, + spk=cfg.spk, + spk2id=cfg.spk2id, + gst=cfg.gst, + gst2id=cfg.gst2id, + f0_bins=cfg.f0_bins, + f0_bin_type=cfg.f0_bin_type, + f0_smoothing=cfg.f0_smoothing, + f0_norm=cfg.f0_norm, + f0_log=cfg.f0_log, + ) + train_dl = DataLoader( + train_ds, + num_workers=0, + batch_size=cfg.batch_size, + shuffle=True, + collate_fn=collate_fn, + ) + valid_dl = DataLoader( + valid_ds, num_workers=0, batch_size=16, shuffle=False, collate_fn=collate_fn + ) + + f0_min = train_ds.f0_min + f0_max = train_ds.f0_max + f0_bins = train_ds.f0_bins + speaker_stats = train_ds.speaker_stats + + model = hydra.utils.instantiate(cfg["model"]).to(device) + model.setup_f0_stats(f0_min, f0_max, f0_bins, speaker_stats) + + optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters()) + + best_loss = float("inf") + for epoch in range(cfg.epochs): + train_loss, train_l2_loss, train_l2_voiced_loss = run_epoch( + model, train_dl, optimizer, device, cfg, mode="train" + ) + valid_loss, valid_l2_loss, valid_l2_voiced_loss = run_epoch( + model, valid_dl, None, device, cfg, mode="valid" + ) + print( + f"[epoch {epoch}] train loss: {train_loss:.3f}, l2 loss: {train_l2_loss:.3f}, l2 voiced loss: {train_l2_voiced_loss:.3f}" + ) + print( + f"[epoch {epoch}] valid loss: {valid_loss:.3f}, l2 loss: {valid_l2_loss:.3f}, l2 voiced loss: {valid_l2_voiced_loss:.3f}" + ) + if valid_l2_voiced_loss < best_loss: + path = f"{os.getcwd()}/pitch_predictor.ckpt" + save_ckpt(model, path, cfg["model"], f0_min, f0_max, f0_bins, speaker_stats) + best_loss = valid_l2_voiced_loss + print(f"saved checkpoint: {path}") + print(f"[epoch {epoch}] best loss: {best_loss:.3f}") + + +def run_epoch(model, loader, optimizer, device, cfg, mode): + if mode == "train": + model.train() + else: + model.eval() + + epoch_loss = 0 + l1 = 0 + l1_voiced = 0 + for x, f0_bin, f0_raw, spk_id, gst, mask, _ in tqdm(loader): + x, f0_bin, f0_raw, spk_id, gst, mask = ( + x.to(device), + f0_bin.to(device), + f0_raw.to(device), + spk_id.to(device), + gst.to(device), + mask.to(device), + ) + b, t, n_bins = f0_bin.shape + yhat = model(x, gst) + nonzero_mask = (f0_raw != 0).logical_and(mask) + yhat_raw = model.inference(x, spk_id, gst) + expanded_mask = mask.unsqueeze(-1).expand(-1, -1, n_bins) + if cfg.f0_pred == "mean": + loss = F.binary_cross_entropy( + yhat[expanded_mask], f0_bin[expanded_mask] + ).mean() + elif cfg.f0_pred == "argmax": + loss = F.cross_entropy( + rearrange(yhat, "b t d -> (b t) d"), + rearrange(f0_bin.argmax(-1), "b t -> (b t)"), + reduce=False, + ) + loss = rearrange(loss, "(b t) -> b t", b=b, t=t) + loss = (loss * mask).sum() / mask.float().sum() + else: + raise NotImplementedError + l1 += F.l1_loss(yhat_raw[mask], f0_raw[mask]).item() + l1_voiced += F.l1_loss(yhat_raw[nonzero_mask], f0_raw[nonzero_mask]).item() + epoch_loss += loss.item() + + if mode == "train": + loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + + print(f"{mode} example y: {f0_bin.argmax(-1)[0, 50:60].tolist()}") + print(f"{mode} example yhat: {yhat.argmax(-1)[0, 50:60].tolist()}") + print(f"{mode} example y: {f0_raw[0, 50:60].round().tolist()}") + print(f"{mode} example yhat: {yhat_raw[0, 50:60].round().tolist()}") + return epoch_loss / len(loader), l1 / len(loader), l1_voiced / len(loader) + + +@hydra.main(config_path=dir_path, config_name="pitch_predictor.yaml") +def main(cfg): + np.random.seed(1) + random.seed(1) + torch.manual_seed(1) + from hydra.core.hydra_config import HydraConfig + + overrides = { + x.split("=")[0]: x.split("=")[1] + for x in HydraConfig.get().overrides.task + if "/" not in x + } + print(f"{cfg}") + train(cfg) + + +if __name__ == "__main__": + main() diff --git a/examples/emotion_conversion/emotion_models/pitch_predictor.yaml b/examples/emotion_conversion/emotion_models/pitch_predictor.yaml new file mode 100644 index 00000000..d2dbb862 --- /dev/null +++ b/examples/emotion_conversion/emotion_models/pitch_predictor.yaml @@ -0,0 +1,64 @@ +train_tsv: "/denoising/emov/train.tsv" +train_km: "/denoising/emov/train.km" +valid_tsv: "/denoising/emov/valid.tsv" +valid_km: "/denoising/emov/valid.km" + +n_tokens: 200 +batch_size: 64 +lr: 0.0001 +epochs: 1000 + +substring: "" +loss: "l2" +spk: "parent_parent_name" +gst: "emotion" + +f0_bins: 50 +f0_pred: "mean" # [argmax, mean] +f0_smoothing: 0.1 +f0_norm: "mean" +f0_log: false +f0_bin_type: "adaptive" # [uniform, adaptive] + +spk2id: + bea: 0 + jenie: 1 + josh: 2 + sam: 3 + +gst2id: + amused: 0 + angry: 1 + disgusted: 2 + neutral: 3 + sleepy: 4 + +optimizer: + _target_: torch.optim.Adam + lr: ${lr} + +model: + _target_: emotion_models.pitch_predictor.CnnPredictor + n_tokens: ${n_tokens} + emb_dim: 256 + channels: 256 + kernel: 5 + dropout: 0.1 + n_layers: 6 + spk_emb: true + gst_emb: 8 + n_bins: ${f0_bins} + f0_pred: ${f0_pred} + f0_log: ${f0_log} + f0_norm: ${f0_norm} + +hydra: + run: + dir: /checkpoint/felixkreuk/experiments/pitch_predictor/${hydra.job.override_dirname} + job: + config: + # configuration for the ${hydra.job.override_dirname} runtime variable + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: ['train_tsv', 'train_km', 'valid_tsv', 'valid_km'] diff --git a/examples/emotion_conversion/emotion_models/utils.py b/examples/emotion_conversion/emotion_models/utils.py new file mode 100644 index 00000000..4199c310 --- /dev/null +++ b/examples/emotion_conversion/emotion_models/utils.py @@ -0,0 +1,78 @@ +import torch + + +class Stat: + def __init__(self, keep_raw=False): + self.x = 0.0 + self.x2 = 0.0 + self.z = 0.0 # z = logx + self.z2 = 0.0 + self.n = 0.0 + self.u = 0.0 + self.keep_raw = keep_raw + self.raw = [] + + def update(self, new_x): + new_z = new_x.log() + + self.x += new_x.sum() + self.x2 += (new_x**2).sum() + self.z += new_z.sum() + self.z2 += (new_z**2).sum() + self.n += len(new_x) + self.u += 1 + + if self.keep_raw: + self.raw.append(new_x) + + @property + def mean(self): + return self.x / self.n + + @property + def std(self): + return (self.x2 / self.n - self.mean**2) ** 0.5 + + @property + def mean_log(self): + return self.z / self.n + + @property + def std_log(self): + return (self.z2 / self.n - self.mean_log**2) ** 0.5 + + @property + def n_frms(self): + return self.n + + @property + def n_utts(self): + return self.u + + @property + def raw_data(self): + assert self.keep_raw, "does not support storing raw data!" + return torch.cat(self.raw) + + +class F0Stat(Stat): + def update(self, new_x): + # assume unvoiced frames are 0 and consider only voiced frames + if new_x is not None: + super().update(new_x[new_x != 0]) + + +class Accuracy: + def __init__(self): + self.y, self.yhat = [], [] + + def update(self, yhat, y): + self.yhat.append(yhat) + self.y.append(y) + + def acc(self, tol): + yhat = torch.cat(self.yhat) + y = torch.cat(self.y) + acc = torch.abs(yhat - y) <= tol + acc = acc.float().mean().item() + return acc diff --git a/examples/emotion_conversion/fairseq_models/__init__.py b/examples/emotion_conversion/fairseq_models/__init__.py new file mode 100644 index 00000000..441bc03d --- /dev/null +++ b/examples/emotion_conversion/fairseq_models/__init__.py @@ -0,0 +1,226 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from fairseq.models import ( + FairseqMultiModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import ( + Embedding, + base_architecture, +) +from fairseq.models.multilingual_transformer import ( + MultilingualTransformerModel, + base_multilingual_architecture, +) +from fairseq.utils import safe_hasattr +from collections import OrderedDict + + +@register_model("multilingual_transformer_from_mbart") +class MultilingualTransformerModelFromMbart(MultilingualTransformerModel): + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + from fairseq.tasks.multilingual_translation import MultilingualTranslationTask + + assert isinstance(task, MultilingualTranslationTask) + + # make sure all arguments are present in older models + base_multilingual_architecture(args) + + if not safe_hasattr(args, "max_source_positions"): + args.max_source_positions = 1024 + if not safe_hasattr(args, "max_target_positions"): + args.max_target_positions = 1024 + + src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs] + tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs] + + if args.share_encoders: + args.share_encoder_embeddings = True + if args.share_decoders: + args.share_decoder_embeddings = True + + def build_embedding(dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + # build shared embeddings (if applicable) + shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None + if args.share_all_embeddings: + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=task.langs, + embed_dim=args.encoder_embed_dim, + build_embedding=build_embedding, + pretrained_embed_path=args.encoder_embed_path, + ) + shared_decoder_embed_tokens = shared_encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + if args.share_encoder_embeddings: + shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=src_langs, + embed_dim=args.encoder_embed_dim, + build_embedding=build_embedding, + pretrained_embed_path=args.encoder_embed_path, + ) + if args.share_decoder_embeddings: + shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=tgt_langs, + embed_dim=args.decoder_embed_dim, + build_embedding=build_embedding, + pretrained_embed_path=args.decoder_embed_path, + ) + + # encoders/decoders for each language + lang_encoders, lang_decoders = {}, {} + + def get_encoder(lang): + if lang not in lang_encoders: + if shared_encoder_embed_tokens is not None: + encoder_embed_tokens = shared_encoder_embed_tokens + else: + encoder_embed_tokens = build_embedding( + task.dicts[lang], + args.encoder_embed_dim, + args.encoder_embed_path, + ) + lang_encoders[lang] = MultilingualTransformerModel._get_module_class( + True, args, task.dicts[lang], encoder_embed_tokens, src_langs + ) + return lang_encoders[lang] + + def get_decoder(lang): + if lang not in lang_decoders: + if shared_decoder_embed_tokens is not None: + decoder_embed_tokens = shared_decoder_embed_tokens + else: + decoder_embed_tokens = build_embedding( + task.dicts[lang], + args.decoder_embed_dim, + args.decoder_embed_path, + ) + lang_decoders[lang] = MultilingualTransformerModel._get_module_class( + False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs + ) + return lang_decoders[lang] + + # shared encoders/decoders (if applicable) + shared_encoder, shared_decoder = None, None + if args.share_encoders: + shared_encoder = get_encoder(src_langs[0]) + if args.share_decoders: + shared_decoder = get_decoder(tgt_langs[0]) + + encoders, decoders = OrderedDict(), OrderedDict() + for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs): + encoders[lang_pair] = ( + shared_encoder if shared_encoder is not None else get_encoder(src) + ) + decoders[lang_pair] = ( + shared_decoder if shared_decoder is not None else get_decoder(tgt) + ) + + return MultilingualTransformerModelFromMbart(encoders, decoders) + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + state_dict_subset = state_dict.copy() + lang_pairs = set([x.split(".")[1] for x in state_dict.keys()]) + finetune_mode = not any("neutral" in lp for lp in lang_pairs) + + if finetune_mode: + # load a pre-trained mBART/BART model + # we need this code because mBART/BART are not of type FairseqMultiModel but FairseqModel + # so we hackishly load the weights by replicating them for all lang pairs + print("loading pre-trained BART") + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + for lang_pair in self.models: + new_key = k if "models." in k else f"models.{lang_pair}.{k}" + # print(new_key) + if self_state_dict[new_key].shape == v.shape: + state_dict_subset[new_key] = v + elif any( + w in k + for w in [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "decoder.output_projection.weight", + ] + ): + # why vocab_size - 5? because there are `vocab_size` tokens from the language + # and 5 additional tokens in the denoising task: eos,bos,pad,unk,mask. + # but in the translation task there are only `vocab_size` + 4 (no mask). + print( + f"{k}: {self_state_dict[new_key].shape} != {v.shape}", + end="", + flush=True, + ) + vocab_size = v.shape[0] - 5 + state_dict_subset[new_key] = self_state_dict[new_key] + state_dict_subset[new_key] = v[: vocab_size + 4] + print(f" => fixed by using first {vocab_size + 4} dims") + else: + raise ValueError("unable to load model due to mimatched dims!") + del state_dict_subset[k] + else: + print("loading pre-trained emotion translation model") + for k, _ in state_dict.items(): + assert k.startswith("models.") + lang_pair = k.split(".")[1] + if lang_pair not in self.models: + del state_dict_subset[k] + + super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg) + + +@register_model_architecture("transformer", "transformer_small") +def transformer_small(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 3) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 3) + base_architecture(args) + + +@register_model_architecture( + "multilingual_transformer_from_mbart", "multilingual_small" +) +def multilingual_small(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 3) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 3) + base_multilingual_architecture(args) diff --git a/examples/emotion_conversion/preprocess/__init__.py b/examples/emotion_conversion/preprocess/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/emotion_conversion/preprocess/build_hifigan_manifest.py b/examples/emotion_conversion/preprocess/build_hifigan_manifest.py new file mode 100644 index 00000000..29c0d79c --- /dev/null +++ b/examples/emotion_conversion/preprocess/build_hifigan_manifest.py @@ -0,0 +1,38 @@ +import torchaudio +import argparse +import json + +def main(): + parser = argparse.ArgumentParser(description="example: python create_hifigan_manifest.py --tsv /checkpoint/felixkreuk/datasets/vctk/splits/vctk_16khz/train.tsv --km /checkpoint/felixkreuk/experiments/hubert/hubert_feats/vctk_16khz_km_100/train.km --km_type hubert_100km > ~/tmp/tmp_mani.txt") + parser.add_argument("--tsv", required=True, help="path to fairseq tsv file") + parser.add_argument("--km", required=True, help="path to a km file generated by HuBERT clustering") + parser.add_argument("--km_type", required=True, help="name of the codes in the output json (for example: 'cpc_100km')") + args = parser.parse_args() + + km_lines = open(args.km, "r").readlines() + tsv_lines = open(args.tsv, "r").readlines() + assert len(km_lines) == len(tsv_lines) - 1, "tsv and km files are not of the same length!" + + wav_root = tsv_lines[0].strip() + tsv_lines = tsv_lines[1:] + + for tsv_line, km_line in zip(tsv_lines, km_lines): + tsv_line, km_line = tsv_line.strip(), km_line.strip() + wav_basename, wav_num_frames = tsv_line.split("\t") + wav_path = wav_root + "/" + wav_basename + wav_info = torchaudio.info(wav_path) + assert int(wav_num_frames) == wav_info.num_frames, "tsv duration and actual duration don't match!" + wav_duration = wav_info.num_frames / wav_info.sample_rate + manifest_line = {"audio": wav_path, "duration": wav_duration, args.km_type: km_line} + print(json.dumps(manifest_line)) + +if __name__ == "__main__": + """ + usage: + python create_hifigan_manifest.py \ + --tsv /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/valid.tsv \ + --km /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/hubert_km_100/valid.km \ + --km_type hubert \ + > /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/hubert_km_100/hifigan_valid_manifest.txt + """ + main() diff --git a/examples/emotion_conversion/preprocess/build_translation_manifests.py b/examples/emotion_conversion/preprocess/build_translation_manifests.py new file mode 100644 index 00000000..d38454a7 --- /dev/null +++ b/examples/emotion_conversion/preprocess/build_translation_manifests.py @@ -0,0 +1,258 @@ +from glob import glob +import argparse +from collections import defaultdict, Counter +from itertools import combinations, product, groupby +from pathlib import Path +import os +from sklearn.utils import shuffle +import numpy as np +import random +from shutil import copy +from subprocess import check_call + +np.random.seed(42) +random.seed(42) + + +def get_fname(s): + return s.split("\t")[0] + +def get_emotion(s): + return get_fname(s).split("_")[0].split("/")[1].lower() + +def get_utt_id(s): + return get_fname(s).split(".")[0].split("_")[-1] + +def dedup(seq): + """ >> remove_repetitions("1 2 2 3 100 2 2 1") + '1 2 3 100 2 1' """ + seq = seq.strip().split(" ") + result = seq[:1] + reps = [] + rep_counter = 1 + for k in seq[1:]: + if k != result[-1]: + result += [k] + reps += [rep_counter] + rep_counter = 1 + else: + rep_counter += 1 + reps += [rep_counter] + assert len(reps) == len(result) and sum(reps) == len(seq) + return " ".join(result) + "\n" #, reps + +def remove_under_k(seq, k): + """ remove tokens that repeat less then k times in a row + >> remove_under_k("a a a a b c c c", 1) ==> a a a a c c c """ + seq = seq.strip().split(" ") + result = [] + + freqs = [(k,len(list(g))) for k, g in groupby(seq)] + for c, f in freqs: + if f > k: + result += [c for _ in range(f)] + return " ".join(result) + "\n" #, reps + + +def call(cmd): + print(cmd) + check_call(cmd, shell=True) + + +def denoising_preprocess(path, lang, dict): + bin = 'fairseq-preprocess' + cmd = [ + bin, + f'--trainpref {path}/train.{lang} --validpref {path}/valid.{lang} --testpref {path}/test.{lang}', + f'--destdir {path}/tokenized/{lang}', + '--only-source', + '--task multilingual_denoising', + '--workers 40', + ] + if dict != "": + cmd += [f'--srcdict {dict}'] + cmd = " ".join(cmd) + call(cmd) + + +def translation_preprocess(path, src_lang, trg_lang, dict, only_train=False): + bin = 'fairseq-preprocess' + cmd = [ + bin, + f'--source-lang {src_lang} --target-lang {trg_lang}', + f'--trainpref {path}/train', + f'--destdir {path}/tokenized', + '--workers 40', + ] + if not only_train: + cmd += [f'--validpref {path}/valid --testpref {path}/test'] + if dict != "": + cmd += [ + f'--srcdict {dict}', + f'--tgtdict {dict}', + ] + cmd = " ".join(cmd) + call(cmd) + + +def load_tsv_km(tsv_path, km_path): + assert tsv_path.exists() and km_path.exists() + tsv_lines = open(tsv_path, "r").readlines() + root, tsv_lines = tsv_lines[0], tsv_lines[1:] + km_lines = open(km_path, "r").readlines() + assert len(tsv_lines) == len(km_lines), ".tsv and .km should be the same length!" + return root, tsv_lines, km_lines + + +def main(): + desc = """ + this script takes as input .tsv and .km files for EMOV dataset, and a pairs of emotions. + it generates parallel .tsv and .km files for these emotions. for exmaple: + ❯ python build_emov_translation_manifests.py \ + /checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/train.tsv \ + /checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/emov_16khz_km_100/train.km \ + ~/tmp/emov_pairs \ + --src-emotion amused --trg-emotion neutral \ + --dedup --shuffle --cross-speaker --dry-run + """ + parser = argparse.ArgumentParser(description=desc) + parser.add_argument("data", type=Path, help="path to a dir containing .tsv and .km files containing emov dataset") + parser.add_argument("output_path", type=Path, help="output directory with the manifests will be created") + parser.add_argument("-cs", "--cross-speaker", action='store_true', help="if set then translation will occur also between speakers, meaning the same sentence can be translated between different speakers (default: false)") + parser.add_argument("-dd", "--dedup", action='store_true', help="remove repeated tokens (example: 'aaabc=>abc')") + parser.add_argument("-sh", "--shuffle", action='store_true', help="shuffle the data") + parser.add_argument("-ae", "--autoencode", action='store_true', help="include training pairs from the same emotion (this includes examples of the same sentence uttered by different people and examples where the src and trg are the exact same seq)") + parser.add_argument("-dr", "--dry-run", action='store_true', help="don't write anything to disk") + parser.add_argument("-zs", "--zero-shot", action='store_true', help="if true, the denoising task will train on the same splits as the translation task (split by utterance id). if false, the denoising task will train on randomly sampled splits (not split by utterance id)") + parser.add_argument("--km-ext", default="km", help="") + parser.add_argument("--dict", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/fairseq.dict.txt", help="") + args = parser.parse_args() + SPEAKERS = ["bea", "jenie", "josh", "sam", "SAME"] + EMOTIONS = ['neutral', 'amused', 'angry', 'disgusted', 'sleepy'] + + suffix = "" + if args.cross_speaker: suffix += "_cross-speaker" + if args.dedup: suffix += "_dedup" + translation_suffix = "" + if args.autoencode: translation_suffix += "_autoencode" + denoising_suffix = "" + denoising_suffix += "_zeroshot" if args.zero_shot else "_nonzeroshot" + + translation_dir = Path(args.output_path) / ("emov_multilingual_translation" + suffix + translation_suffix) + os.makedirs(translation_dir, exist_ok=True) + denoising_dir = Path(args.output_path) / ("emov_multilingual_denoising" + suffix + denoising_suffix) + os.makedirs(denoising_dir, exist_ok=True) + + denoising_data = [p.name for p in (args.data / "denoising").glob("*") if "emov" not in p.name] + + for split in ["train", "valid", "test"]: + root, tsv_lines, km_lines = load_tsv_km( + tsv_path = args.data / "denoising" / "emov" / f"{split}.tsv", + km_path = args.data / "denoising" / "emov" / f"{split}.{args.km_ext}" + ) + + # generate data for the multilingual denoising task + for EMOTION in EMOTIONS: + print("---") + print(split) + print(f"denoising: {EMOTION}") + emotion_tsv, emotion_km = [], [] + for tsv_line, km_line in zip(tsv_lines, km_lines): + if EMOTION.lower() in tsv_line.lower(): + km_line = km_line if not args.dedup else dedup(km_line) + emotion_tsv.append(tsv_line) + emotion_km.append(km_line) + print(f"{len(emotion_km)} samples") + open(denoising_dir / f"files.{split}.{EMOTION}", "w").writelines([root] + emotion_tsv) + open(denoising_dir / f"{split}.{EMOTION}", "w").writelines(emotion_km) + + for data in denoising_data: + with open(args.data / "denoising" / data / f"{split}.{args.km_ext}", "r") as f1: + with open(denoising_dir / f"{split}.{data}", "w") as f2: + f2.writelines([l if not args.dedup else dedup(l) for l in f1.readlines()]) + + # start of translation preprocessing + root, tsv_lines, km_lines = load_tsv_km( + tsv_path = args.data / "translation" / f"{split}.tsv", + km_path = args.data / "translation" / f"{split}.{args.km_ext}" + ) + + # generate data for the multilingual translation task + for SRC_EMOTION in EMOTIONS: + TRG_EMOTIONS = EMOTIONS if args.autoencode else set(EMOTIONS) - set([SRC_EMOTION]) + for TRG_EMOTION in TRG_EMOTIONS: + # when translating back to the same emotion - we dont want these emotion + # pairs to be part of the validation/test sets (because its not really emotion conversino) + # if SRC_EMOTION == TRG_EMOTION and split in ["valid", "test"]: continue + print("---") + print(split) + print(f"src emotions: {SRC_EMOTION}\ntrg emotions: {TRG_EMOTION}") + + # create a dictionary with the following structure: + # output[SPEAKER][UTT_ID] = list with indexes of line from the tsv file + # that match the speaker and utterance id. for exmaple: + # output = {'sam': {'0493': [875, 1608, 1822], ...}, ...} + # meaning, for speaker 'sam', utterance id '0493', the indexes in tsv_lines + # are 875, 1608, 1822 + spkr2utts = defaultdict(lambda: defaultdict(list)) + for i, tsv_line in enumerate(tsv_lines): + speaker = tsv_line.split("/")[0] + if args.cross_speaker: speaker = "SAME" + assert speaker in SPEAKERS, "unknown speaker! make sure the .tsv contains EMOV data" + utt_id = get_utt_id(tsv_line) + spkr2utts[speaker][utt_id].append(i) + + # create a tsv and km files with all the combinations for translation + src_tsv, trg_tsv, src_km, trg_km = [], [], [], [] + for speaker, utt_ids in spkr2utts.items(): + for utt_id, indices in utt_ids.items(): + # generate all pairs + pairs = [(x,y) for x in indices for y in indices] + # self-translation + if SRC_EMOTION == TRG_EMOTION: + pairs = [(x,y) for (x,y) in pairs if x == y] + # filter according to src and trg emotions + pairs = [(x,y) for (x,y) in pairs + if get_emotion(tsv_lines[x]) == SRC_EMOTION and get_emotion(tsv_lines[y]) == TRG_EMOTION] + + for idx1, idx2 in pairs: + assert get_utt_id(tsv_lines[idx1]) == get_utt_id(tsv_lines[idx2]) + src_tsv.append(tsv_lines[idx1]) + trg_tsv.append(tsv_lines[idx2]) + km_line_idx1 = km_lines[idx1] + km_line_idx2 = km_lines[idx2] + km_line_idx1 = km_line_idx1 if not args.dedup else dedup(km_line_idx1) + km_line_idx2 = km_line_idx2 if not args.dedup else dedup(km_line_idx2) + src_km.append(km_line_idx1) + trg_km.append(km_line_idx2) + assert len(src_tsv) == len(trg_tsv) == len(src_km) == len(trg_km) + print(f"{len(src_tsv)} pairs") + + if len(src_tsv) == 0: + raise Exception("ERROR: generated 0 pairs!") + + if args.dry_run: continue + + # create files + os.makedirs(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}", exist_ok=True) + open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"files.{split}.{SRC_EMOTION}", "w").writelines([root] + src_tsv) + open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"files.{split}.{TRG_EMOTION}", "w").writelines([root] + trg_tsv) + open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"{split}.{SRC_EMOTION}", "w").writelines(src_km) + open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"{split}.{TRG_EMOTION}", "w").writelines(trg_km) + + + # fairseq-preprocess the denoising data + for EMOTION in EMOTIONS + denoising_data: + denoising_preprocess(denoising_dir, EMOTION, args.dict) + os.system(f"cp {args.dict} {denoising_dir}/tokenized/dict.txt") + + # fairseq-preprocess the translation data + os.makedirs(translation_dir / "tokenized", exist_ok=True) + for SRC_EMOTION in EMOTIONS: + TRG_EMOTIONS = EMOTIONS if args.autoencode else set(EMOTIONS) - set([SRC_EMOTION]) + for TRG_EMOTION in TRG_EMOTIONS: + translation_preprocess(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}", SRC_EMOTION, TRG_EMOTION, args.dict)#, only_train=SRC_EMOTION==TRG_EMOTION) + os.system(f"cp -rf {translation_dir}/**/tokenized/* {translation_dir}/tokenized") + +if __name__ == "__main__": + main() diff --git a/examples/emotion_conversion/preprocess/create_core_manifest.py b/examples/emotion_conversion/preprocess/create_core_manifest.py new file mode 100644 index 00000000..b55740e0 --- /dev/null +++ b/examples/emotion_conversion/preprocess/create_core_manifest.py @@ -0,0 +1,91 @@ +from pathlib import Path +import os +import sys +import subprocess +import argparse +from datetime import datetime +import logging + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + handlers=[logging.FileHandler('debug.log'), logging.StreamHandler()] +) +logger = logging.getLogger(__name__) + + +def verify_dict_size(km, dict): + logger.info(f"verifying: {km}") + dict_size = len(open(dict, "r").readlines()) + km_vocab = set(open(km, "r").read().replace("\n", " ").split(" ")) + if "" in km_vocab: km_vocab.remove("") + km_vocab_size = len(km_vocab) + return dict_size == km_vocab_size + + +def verify_files_exist(l): + for f in l: + if not f.exists(): + logging.error(f"{f} doesn't exist!") + return False + return True + + +def run_cmd(cmd, print_output=True): + try: + out = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True, shell=True) + if print_output: + logger.info(f"command output:\n{out}") + return out + except subprocess.CalledProcessError as grepexc: + logger.info(f"error executing command!:\n{cmd}") + logger.info(grepexc.output) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--tsv", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/data.tsv", type=Path) + parser.add_argument("--emov-km", required=True, type=Path) + parser.add_argument("--km", nargs='+', required=True, type=Path) + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--dict", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/fairseq.dict.txt") + parser.add_argument("--manifests-dir", type=Path, default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz") + args = parser.parse_args() + + manifests_dir = args.manifests_dir + date = datetime.now().strftime('%d%m%y') + outdir = manifests_dir / f"{date}" + + # verify input and create folders + all_kms = args.km + [args.emov_km] + assert verify_files_exist(all_kms), "make sure the km dir contains: train-clean-all.km, blizzard2013.km, data.km" + for codes in all_kms: + assert verify_dict_size(codes, args.dict), "dict argument doesn't match the vocabulary of the km file!" + assert not outdir.exists(), "data dir already exists!" + outdir.mkdir(parents=True, exist_ok=True) + + logger.info("generating denoising split (emov)") + run_cmd(f"python preprocess/split_km_tsv.py {args.tsv} {args.emov_km} --destdir {outdir}/denoising/emov -sh --seed {args.seed}") + for codes in args.km: + codes_name = os.path.basename(codes) + run_cmd(f"python preprocess/split_km.py {codes} --destdir {outdir}/denoising/{codes_name} -sh --seed {args.seed}") + + logger.info("generating translation split") + run_cmd(f"python preprocess/split_emov_km_tsv_by_uttid.py {args.tsv} {args.emov_km} --destdir {outdir}/translation --seed {args.seed}") + + emov_code_name = os.path.basename(args.emov_km) + logger.info("generating hifigan split") + run_cmd( + f"mkdir -p {outdir}/hifigan &&" + f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/train.tsv --km {outdir}/denoising/emov/train.km > {outdir}/hifigan/train.txt &&" + f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/valid.tsv --km {outdir}/denoising/emov/valid.km > {outdir}/hifigan/valid.txt &&" + f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/test.tsv --km {outdir}/denoising/emov/test.km > {outdir}/hifigan/test.txt" + ) + + logger.info("generating fairseq manifests") + run_cmd(f"python preprocess/build_translation_manifests.py {outdir} {outdir}/fairseq-data -dd -cs --dict {args.dict}") + + logger.info(f"finished processing data at:\n{outdir}") + + +if __name__ == "__main__": + main() diff --git a/examples/emotion_conversion/preprocess/extract_f0.py b/examples/emotion_conversion/preprocess/extract_f0.py new file mode 100644 index 00000000..4204aa4d --- /dev/null +++ b/examples/emotion_conversion/preprocess/extract_f0.py @@ -0,0 +1,57 @@ +import argparse +from tqdm import tqdm +from multiprocessing import Manager, Pool + +from scipy.io.wavfile import read +from librosa.util import normalize +import numpy as np +import amfm_decompy.pYAAPT as pYAAPT +import amfm_decompy.basic_tools as basic + +MAX_WAV_VALUE = 32768.0 + +parser = argparse.ArgumentParser(description="") +parser.add_argument("tsv", help="") +parser.add_argument("--extractor", choices=["crepe", "pyaapt"], default="pyaapt", help="") +parser.add_argument("--interp", action="store_true", help="") +parser.add_argument("--n_workers", type=int, default=40, help="") +args = parser.parse_args() + +tsv_lines = open(args.tsv, "r").readlines() +root, tsv_lines = tsv_lines[0].strip(), tsv_lines[1:] + + +def extract_f0(tsv_line): + wav_path, _ = tsv_line.split("\t") + wav_path = root.strip() + "/" + wav_path + sr, wav = read(wav_path) + wav = wav / MAX_WAV_VALUE + wav = normalize(wav) * 0.95 + + if args.extractor == "pyaapt": + frame_length = 20.0 + pad = int(frame_length / 1000 * sr) // 2 + wav = np.pad(wav.squeeze(), (pad, pad), "constant", constant_values=0) + signal = basic.SignalObj(wav, sr) + pitch = pYAAPT.yaapt( + signal, + **{ + 'frame_length': frame_length, + 'frame_space': 5.0, + 'nccf_thresh1': 0.25, + 'tda_frame_length': 25.0 + }) + pitch = pitch.samp_interp[None, None, :] if args.interp else pitch.samp_values[None, None, :] + pitch = pitch[0, 0] + f0_path = wav_path.replace(".wav", ".yaapt") + f0_path += ".interp.f0" if args.interp else ".f0" + np.save(f0_path, pitch) + + +def main(): + with Pool(args.n_workers) as p: + r = list(tqdm(p.imap(extract_f0, tsv_lines), total=len(tsv_lines))) + + +if __name__ == "__main__": + main() diff --git a/examples/emotion_conversion/preprocess/process_km.py b/examples/emotion_conversion/preprocess/process_km.py new file mode 100644 index 00000000..864a0221 --- /dev/null +++ b/examples/emotion_conversion/preprocess/process_km.py @@ -0,0 +1,40 @@ +import sys +import argparse +from tqdm import tqdm +from build_emov_translation_manifests import dedup, remove_under_k + + +if __name__ == "__main__": + """ + this is a standalone script to process a km file + specifically, to dedup or remove tokens that repeat less + than k times in a row + """ + parser = argparse.ArgumentParser(description="") + parser.add_argument("km", type=str, help="path to km file") + parser.add_argument("--dedup", action='store_true') + parser.add_argument("--remove-under-k", type=int, default=0) + parser.add_argument("--output", default=None) + args = parser.parse_args() + + if not args.dedup and args.remove_under_k == 0: + print("nothing to do! quitting...") + sys.exit(0) + + km = open(args.km, "r").readlines() + out = [] + for line in tqdm(km): + if args.remove_under_k > 0: + line = remove_under_k(line, args.remove_under_k) + if args.dedup: + line = dedup(line) + out.append(line) + + path = args.km if args.output is None else args.output + if args.remove_under_k > 0: + path = path.replace(".km", f"-k{args.remove_under_k}.km") + if args.dedup: + path = path.replace(".km", f"-deduped.km") + + open(path, "w").writelines(out) + print(f"written to {path}") diff --git a/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py b/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py new file mode 100644 index 00000000..94221afb --- /dev/null +++ b/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py @@ -0,0 +1,70 @@ +from pathlib import Path +import os +import sys +import argparse +import random +import numpy as np +from tqdm import tqdm +from sklearn.model_selection import train_test_split +from build_translation_manifests import get_utt_id + + +def train_val_test_split(tsv_lines, km_lines, valid_percent, test_percent, seed=42): + utt_ids = list(sorted(set([get_utt_id(x) for x in tsv_lines]))) + utt_ids, valid_utt_ids, _, _ = train_test_split(utt_ids, utt_ids, test_size=valid_percent, shuffle=True, random_state=seed) + train_utt_ids, test_utt_ids, _, _ = train_test_split(utt_ids, utt_ids, test_size=test_percent, shuffle=True, random_state=seed) + + train_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in train_utt_ids] + valid_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in valid_utt_ids] + test_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in test_utt_ids] + + train_tsv, train_km = [tsv_lines[i] for i in train_idx], [km_lines[i] for i in train_idx] + valid_tsv, valid_km = [tsv_lines[i] for i in valid_idx], [km_lines[i] for i in valid_idx] + test_tsv, test_km = [tsv_lines[i] for i in test_idx], [km_lines[i] for i in test_idx] + + print(f"train {len(train_km)}") + print(f"valid {len(valid_km)}") + print(f"test {len(test_km)}") + + return train_tsv, train_km, valid_tsv, valid_km, test_tsv, test_km + + +if __name__ == "__main__": + """ + this is a standalone script to process a km file + specifically, to dedup or remove tokens that repeat less + than k times in a row + """ + parser = argparse.ArgumentParser(description="") + parser.add_argument("tsv", type=str, help="path to tsv file") + parser.add_argument("km", type=str, help="path to km file") + parser.add_argument("--destdir", required=True, type=str) + parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set") + parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set") + parser.add_argument("--seed", type=int, default=42, help="") + args = parser.parse_args() + + np.random.seed(args.seed) + random.seed(args.seed) + + os.makedirs(args.destdir, exist_ok=True) + km = open(args.km, "r").readlines() + tsv = open(args.tsv, "r").readlines() + root, tsv = tsv[0], tsv[1:] + + assert args.tsv.endswith(".tsv") and args.km.endswith(".km") + assert len(tsv) == len(km) + + train_tsv, train_km, valid_tsv, valid_km, test_tsv, test_km = train_val_test_split(tsv, km, args.valid_percent, args.test_percent, args.seed) + + assert len(train_tsv) + len(valid_tsv) + len(test_tsv) == len(tsv) + assert len(train_tsv) == len(train_km) and len(valid_tsv) == len(valid_km) and len(test_tsv) == len(test_km) + + dir = Path(args.destdir) + open(dir / f"train.tsv", "w").writelines([root] + train_tsv) + open(dir / f"valid.tsv", "w").writelines([root] + valid_tsv) + open(dir / f"test.tsv", "w").writelines([root] + test_tsv) + open(dir / f"train.km", "w").writelines(train_km) + open(dir / f"valid.km", "w").writelines(valid_km) + open(dir / f"test.km", "w").writelines(test_km) + print("done") diff --git a/examples/emotion_conversion/preprocess/split_km.py b/examples/emotion_conversion/preprocess/split_km.py new file mode 100644 index 00000000..d145fc2b --- /dev/null +++ b/examples/emotion_conversion/preprocess/split_km.py @@ -0,0 +1,50 @@ +from pathlib import Path +import os +import argparse +import random +import numpy as np +from sklearn.utils import shuffle + + +if __name__ == "__main__": + """ + this is a standalone script to process a km file + specifically, to dedup or remove tokens that repeat less + than k times in a row + """ + parser = argparse.ArgumentParser(description="") + parser.add_argument("km", type=str, help="path to km file") + parser.add_argument("--destdir", required=True, type=str) + parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set") + parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set") + parser.add_argument("-sh", "--shuffle", action="store_true", help="path to km file") + parser.add_argument("--seed", type=int, default=42, help="") + args = parser.parse_args() + + np.random.seed(args.seed) + random.seed(args.seed) + + os.makedirs(args.destdir, exist_ok=True) + km = open(args.km, "r").readlines() + + if args.shuffle: + km = shuffle(km) + print(f"shuffled") + + N = len(km) + N_tt = int(N * args.test_percent) + N_cv = int(N * args.valid_percent) + N_tr = N - N_tt - N_cv + + train_km = km[:N_tr] + valid_km = km[N_tr:N_tr + N_cv] + test_km = km[N_tr + N_cv:] + + dir = Path(args.destdir) + open(dir / f"train.km", "w").writelines(train_km) + open(dir / f"valid.km", "w").writelines(valid_km) + open(dir / f"test.km", "w").writelines(test_km) + print(f"train: {len(train_km)}") + print(f"valid: {len(valid_km)}") + print(f"test: {len(test_km)}") + print("done") diff --git a/examples/emotion_conversion/preprocess/split_km_tsv.py b/examples/emotion_conversion/preprocess/split_km_tsv.py new file mode 100644 index 00000000..2113aa71 --- /dev/null +++ b/examples/emotion_conversion/preprocess/split_km_tsv.py @@ -0,0 +1,65 @@ +from pathlib import Path +import os +import argparse +import random +import numpy as np +from sklearn.utils import shuffle + + +if __name__ == "__main__": + """ + this is a standalone script to process a km file + specifically, to dedup or remove tokens that repeat less + than k times in a row + """ + parser = argparse.ArgumentParser(description="") + parser.add_argument("tsv", type=str, help="path to tsv file") + parser.add_argument("km", type=str, help="path to km file") + parser.add_argument("--destdir", required=True, type=str) + parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set") + parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set") + parser.add_argument("-sh", "--shuffle", action="store_true", help="path to km file") + parser.add_argument("--seed", type=int, default=42, help="") + args = parser.parse_args() + + np.random.seed(args.seed) + random.seed(args.seed) + + os.makedirs(args.destdir, exist_ok=True) + km = open(args.km, "r").readlines() + tsv = open(args.tsv, "r").readlines() + root, tsv = tsv[0], tsv[1:] + + assert args.tsv.endswith(".tsv") and args.km.endswith(".km") + assert len(tsv) == len(km) + + if args.shuffle: + tsv, km = shuffle(tsv, km) + print(f"shuffled") + + N = len(tsv) + N_tt = int(N * args.test_percent) + N_cv = int(N * args.valid_percent) + N_tr = N - N_tt - N_cv + + train_tsv = tsv[:N_tr] + valid_tsv = tsv[N_tr:N_tr + N_cv] + test_tsv = tsv[N_tr + N_cv:] + train_km = km[:N_tr] + valid_km = km[N_tr:N_tr + N_cv] + test_km = km[N_tr + N_cv:] + + assert len(train_tsv) + len(valid_tsv) + len(test_tsv) == len(tsv) + assert len(train_tsv) == len(train_km) and len(valid_tsv) == len(valid_km) and len(test_tsv) == len(test_km) + + dir = Path(args.destdir) + open(dir / f"train.tsv", "w").writelines([root] + train_tsv) + open(dir / f"valid.tsv", "w").writelines([root] + valid_tsv) + open(dir / f"test.tsv", "w").writelines([root] + test_tsv) + open(dir / f"train.km", "w").writelines(train_km) + open(dir / f"valid.km", "w").writelines(valid_km) + open(dir / f"test.km", "w").writelines(test_km) + print(f"train: {len(train_km)}") + print(f"valid: {len(valid_km)}") + print(f"test: {len(test_km)}") + print("done") diff --git a/examples/emotion_conversion/requirements.txt b/examples/emotion_conversion/requirements.txt new file mode 100644 index 00000000..fc94c5a5 --- /dev/null +++ b/examples/emotion_conversion/requirements.txt @@ -0,0 +1,11 @@ +scipy +einops +amfm_decompy +joblib +numba +decorator +requests +appdirs +packaging +six +sklearn diff --git a/examples/emotion_conversion/synthesize.py b/examples/emotion_conversion/synthesize.py new file mode 100644 index 00000000..327fdaf4 --- /dev/null +++ b/examples/emotion_conversion/synthesize.py @@ -0,0 +1,322 @@ +import logging +import argparse +import random +import sys +import os +import numpy as np +import torch +import soundfile as sf +import shutil +import librosa +import json +from pathlib import Path +from tqdm import tqdm +import amfm_decompy.basic_tools as basic +import amfm_decompy.pYAAPT as pYAAPT + +dir_path = os.path.dirname(__file__) +resynth_path = os.path.dirname(os.path.abspath(__file__)) + "/speech-resynthesis" +sys.path.append(resynth_path) + +from models import CodeGenerator +from inference import scan_checkpoint, load_checkpoint, generate +from emotion_models.pitch_predictor import load_ckpt as load_pitch_predictor +from emotion_models.duration_predictor import load_ckpt as load_duration_predictor +from dataset import load_audio, MAX_WAV_VALUE, parse_style, parse_speaker, EMOV_SPK2ID, EMOV_STYLE2ID + + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + handlers=[logging.FileHandler('debug.log'), logging.StreamHandler()] +) +logger = logging.getLogger(__name__) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def parse_generation_file(fname): + lines = open(fname).read() + lines = lines.split('\n') + + results = {} + for l in lines: + if len(l) == 0: + continue + + if l[0] == 'H': + parts = l[2:].split('\t') + if len(parts) == 2: + sid, utt = parts + else: + sid, _, utt = parts + sid = int(sid) + utt = [int(x) for x in utt.split()] + if sid in results: + results[sid]['H'] = utt + else: + results[sid] = {'H': utt} + elif l[0] == 'S': + sid, utt = l[2:].split('\t') + sid = int(sid) + utt = [x for x in utt.split()] + if sid in results: + results[sid]['S'] = utt + else: + results[sid] = {'S': utt} + elif l[0] == 'T': + sid, utt = l[2:].split('\t') + sid = int(sid) + utt = [int(x) for x in utt.split()] + if sid in results: + results[sid]['T'] = utt + else: + results[sid] = {'T': utt} + + for d, result in results.items(): + if 'H' not in result: + result['H'] = result['S'] + + return results + + +def get_code_to_fname(manifest, tokens): + if tokens is None: + code_to_fname = {} + with open(manifest) as f: + for line in f: + line = line.strip() + fname, code = line.split() + code = code.replace(',', ' ') + code_to_fname[code] = fname + + return code_to_fname + + with open(manifest) as f: + fnames = [l.strip() for l in f.readlines()] + root = Path(fnames[0]) + fnames = fnames[1:] + if '\t' in fnames[0]: + fnames = [x.split()[0] for x in fnames] + + with open(tokens) as f: + codes = [l.strip() for l in f.readlines()] + + code_to_fname = {} + for fname, code in zip(fnames, codes): + code = code.replace(',', ' ') + code_to_fname[code] = str(root / fname) + + return root, code_to_fname + + +def code_to_str(s): + k = ' '.join([str(x) for x in s]) + return k + + +def get_praat_f0(audio, rate=16000, interp=False): + frame_length = 20.0 + to_pad = int(frame_length / 1000 * rate) // 2 + + f0s = [] + for y in audio.astype(np.float64): + y_pad = np.pad(y.squeeze(), (to_pad, to_pad), "constant", constant_values=0) + signal = basic.SignalObj(y_pad, rate) + pitch = pYAAPT.yaapt(signal, **{'frame_length': frame_length, 'frame_space': 5.0, 'nccf_thresh1': 0.25, + 'tda_frame_length': 25.0}) + if interp: + f0s += [pitch.samp_interp[None, None, :]] + else: + f0s += [pitch.samp_values[None, None, :]] + + f0 = np.vstack(f0s) + return f0 + + +def generate_from_code(generator, h, code, spkr=None, f0=None, gst=None, device="cpu"): + batch = { + 'code': torch.LongTensor(code).to(device).view(1, -1), + } + if spkr is not None: + batch['spkr'] = spkr.to(device).unsqueeze(0) + if f0 is not None: + batch['f0'] = f0.to(device) + if gst is not None: + batch['style'] = gst.to(device) + + with torch.no_grad(): + audio, rtf = generate(h, generator, batch) + audio = librosa.util.normalize(audio / 2 ** 15) + + return audio + + +@torch.no_grad() +def synth(argv, interactive=False): + parser = argparse.ArgumentParser() + parser.add_argument('--result-path', type=Path, help='Translation Model Output', required=True) + parser.add_argument('--data', type=Path, help='a directory with the files: src.tsv, src.km, trg.tsv, trg.km, orig.tsv, orig.km') + parser.add_argument("--orig-tsv", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/data.tsv") + parser.add_argument("--orig-km", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/core_manifests/emov_16khz_km_100/data.km") + + parser.add_argument('--checkpoint-file', type=Path, help='Generator Checkpoint', required=True) + parser.add_argument('--dur-model', type=Path, help='a token duration prediction model (if tokens were deduped)') + parser.add_argument('--f0-model', type=Path, help='a f0 prediction model') + + parser.add_argument('-s', '--src-emotion', default=None) + parser.add_argument('-t', '--trg-emotion', default=None) + parser.add_argument('-N', type=int, default=10) + parser.add_argument('--split', default="test") + + parser.add_argument('--outdir', type=Path, default=Path('results')) + parser.add_argument('--orig-filename', action='store_true') + + parser.add_argument('--device', type=int, default=0) + a = parser.parse_args(argv) + + seed = 52 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if os.path.isdir(a.checkpoint_file): + config_file = os.path.join(a.checkpoint_file, 'config.json') + else: + config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') + with open(config_file) as f: + data = f.read() + json_config = json.loads(data) + h = AttrDict(json_config) + + generator = CodeGenerator(h).to(a.device) + if os.path.isdir(a.checkpoint_file): + cp_g = scan_checkpoint(a.checkpoint_file, 'g_') + else: + cp_g = a.checkpoint_file + state_dict_g = load_checkpoint(cp_g) + generator.load_state_dict(state_dict_g['generator']) + + generator.eval() + generator.remove_weight_norm() + + dur_models = { + "neutral": load_duration_predictor(f"{a.dur_model}/neutral.ckpt"), + "amused": load_duration_predictor(f"{a.dur_model}/amused.ckpt"), + "disgusted": load_duration_predictor(f"{a.dur_model}/disgusted.ckpt"), + "angry": load_duration_predictor(f"{a.dur_model}/angry.ckpt"), + "sleepy": load_duration_predictor(f"{a.dur_model}/sleepy.ckpt"), + } + logger.info(f"loaded duration prediction model from {a.dur_model}") + + f0_model = load_pitch_predictor(a.f0_model).to(a.device) + logger.info(f"loaded f0 prediction model from {a.f0_model}") + + # we need to know how to map code back to the filename + # (if we want the original files names as output) + results = parse_generation_file(a.result_path) + _, src_code_to_fname = get_code_to_fname(f'{a.data}/files.{a.split}.{a.src_emotion}', f'{a.data}/{a.split}.{a.src_emotion}') + _, tgt_code_to_fname = get_code_to_fname(f'{a.data}/files.{a.split}.{a.trg_emotion}', f'{a.data}/{a.split}.{a.trg_emotion}') + + # we need the originals (before dedup) to get the ground-truth durations + orig_tsv = open(a.orig_tsv, 'r').readlines() + orig_tsv_root, orig_tsv = orig_tsv[0].strip(), orig_tsv[1:] + orig_km = open(a.orig_km, 'r').readlines() + fname_to_idx = {orig_tsv_root + "/" + line.split("\t")[0]: i for i, line in enumerate(orig_tsv)} + + outdir = a.outdir + outdir.mkdir(parents=True, exist_ok=True) + (outdir / '0-source').mkdir(exist_ok=True) + (outdir / '1-src-tokens-src-style-src-f0').mkdir(exist_ok=True) + (outdir / '2-src-tokens-trg-style-src-f0').mkdir(exist_ok=True) + (outdir / '2.5-src-tokens-trg-style-src-f0').mkdir(exist_ok=True) + (outdir / '3-src-tokens-trg-style-pred-f0').mkdir(exist_ok=True) + (outdir / '4-gen-tokens-trg-style-pred-f0').mkdir(exist_ok=True) + (outdir / '5-target').mkdir(exist_ok=True) + + N = 0 + results = list(results.items()) + random.shuffle(results) + for i, (sid, result) in tqdm(enumerate(results)): + N += 1 + if N > a.N and a.N != -1: + break + + if '[' in result['S'][0]: + result['S'] = result['S'][1:] + if '_' in result['S'][-1]: + result['S'] = result['S'][:-1] + src_ref = src_code_to_fname[code_to_str(result['S'])] + trg_ref = tgt_code_to_fname[code_to_str(result['T'])] + + src_style, trg_style = None, None + src_spkr, trg_spkr = None, None + src_f0 = None + src_audio = (load_audio(src_ref)[0] / MAX_WAV_VALUE) * 0.95 + trg_audio = (load_audio(trg_ref)[0] / MAX_WAV_VALUE) * 0.95 + src_audio = torch.FloatTensor(src_audio).unsqueeze(0).cuda() + trg_audio = torch.FloatTensor(trg_audio).unsqueeze(0).cuda() + + src_spkr = parse_speaker(src_ref, h.multispkr) + src_spkr = src_spkr if src_spkr in EMOV_SPK2ID else random.choice(list(EMOV_SPK2ID.keys())) + src_spkr = EMOV_SPK2ID[src_spkr] + src_spkr = torch.LongTensor([src_spkr]) + trg_spkr = parse_speaker(trg_ref, h.multispkr) + trg_spkr = trg_spkr if trg_spkr in EMOV_SPK2ID else random.choice(list(EMOV_SPK2ID.keys())) + trg_spkr = EMOV_SPK2ID[trg_spkr] + trg_spkr = torch.LongTensor([trg_spkr]) + + src_style = EMOV_STYLE2ID[a.src_emotion] + src_style = torch.LongTensor([src_style]).cuda() + trg_style_str = a.trg_emotion + trg_style = EMOV_STYLE2ID[a.trg_emotion] + trg_style = torch.LongTensor([trg_style]).cuda() + + src_tokens = list(map(int, orig_km[fname_to_idx[src_ref]].strip().split(" "))) + src_tokens = torch.LongTensor(src_tokens).unsqueeze(0) + src_tokens_dur_pred = torch.LongTensor(list(map(int, result['S']))).unsqueeze(0) + src_tokens_dur_pred = dur_models[trg_style_str].inflate_input(src_tokens_dur_pred) + gen_tokens = torch.LongTensor(result['H']).unsqueeze(0) + gen_tokens = dur_models[trg_style_str].inflate_input(gen_tokens) + trg_tokens = torch.LongTensor(result['T']).unsqueeze(0) + trg_tokens = dur_models[trg_style_str].inflate_input(trg_tokens) + + src_f0 = get_praat_f0(src_audio.unsqueeze(0).cpu().numpy()) + src_f0 = torch.FloatTensor(src_f0).cuda() + + pred_src_f0 = f0_model.inference(torch.LongTensor(src_tokens).to(a.device), src_spkr, trg_style).unsqueeze(0) + pred_src_dur_pred_f0 = f0_model.inference(torch.LongTensor(src_tokens_dur_pred).to(a.device), src_spkr, trg_style).unsqueeze(0) + pred_gen_f0 = f0_model.inference(torch.LongTensor(gen_tokens).to(a.device), src_spkr, trg_style).unsqueeze(0) + pred_trg_f0 = f0_model.inference(torch.LongTensor(trg_tokens).to(a.device), src_spkr, trg_style).unsqueeze(0) + + if a.orig_filename: + path = src_code_to_fname[code_to_str(result['S'])] + sid = str(sid) + "__" + Path(path).stem + shutil.copy(src_code_to_fname[code_to_str(result['S'])], outdir / '0-source' / f'{sid}.wav') + + audio = generate_from_code(generator, h, src_tokens, spkr=src_spkr, f0=src_f0, gst=src_style, device=a.device) + sf.write(outdir / '1-src-tokens-src-style-src-f0' / f'{sid}.wav', audio, samplerate=h.sampling_rate) + + audio = generate_from_code(generator, h, src_tokens, spkr=src_spkr, f0=src_f0, gst=trg_style, device=a.device) + sf.write(outdir / '2-src-tokens-trg-style-src-f0' / f'{sid}.wav', audio, samplerate=h.sampling_rate) + + audio = generate_from_code(generator, h, src_tokens_dur_pred, spkr=src_spkr, f0=src_f0, gst=trg_style, device=a.device) + sf.write(outdir / '2.5-src-tokens-trg-style-src-f0' / f'{sid}.wav', audio, samplerate=h.sampling_rate) + + audio = generate_from_code(generator, h, src_tokens_dur_pred, spkr=src_spkr, f0=pred_src_dur_pred_f0, gst=trg_style, device=a.device) + sf.write(outdir / '3-src-tokens-trg-style-pred-f0' / f'{sid}.wav', audio, samplerate=h.sampling_rate) + + audio = generate_from_code(generator, h, gen_tokens, spkr=src_spkr, f0=pred_gen_f0, gst=trg_style, device=a.device) + sf.write(outdir / '4-gen-tokens-trg-style-pred-f0' / f'{sid}.wav', audio, samplerate=h.sampling_rate) + + shutil.copy(tgt_code_to_fname[code_to_str(result['T'])], outdir / '5-target' / f'{sid}.wav') + + logger.info("Done.") + + +if __name__ == '__main__': + synth(sys.argv[1:])