From 34973a94d09ecc12092a5ecc8afece5e536b7692 Mon Sep 17 00:00:00 2001 From: Jiatong <728307998@qq.com> Date: Mon, 26 Feb 2024 15:15:44 -0500 Subject: [PATCH] Multires hubert (#5363) * multires hubert core * update core codebase on multiresolution hubert * add examples * adding entries to pretrained models (not finished) * add other abalation models * add multilinugal * add decode.sh train.sh finetune.sh and update links for README.md * fix readme * clean the codebase --------- Co-authored-by: Anna Sun <13106449+annasun28@users.noreply.github.com> --- examples/mr_hubert/README.md | 187 +++ examples/mr_hubert/config/decode/infer.yaml | 30 + .../mr_hubert/config/decode/infer_lm.yaml | 37 + .../config/decode/run/submitit_slurm.yaml | 17 + .../decode/run/submitit_slurm_8gpu.yaml | 17 + .../mr_hubert/config/finetune/base_100h.yaml | 97 ++ .../config/finetune/base_100h_large.yaml | 97 ++ .../mr_hubert/config/finetune/base_10h.yaml | 101 ++ .../config/finetune/base_10h_large.yaml | 101 ++ .../mr_hubert/config/finetune/base_1h.yaml | 100 ++ .../config/finetune/base_1h_large.yaml | 99 ++ .../pretrain/mrhubert_base_librispeech.yaml | 103 ++ .../pretrain/mrhubert_large_librilight.yaml | 107 ++ .../config/pretrain/run/submitit_reg.yaml | 20 + examples/mr_hubert/decode.sh | 46 + examples/mr_hubert/finetune.sh | 46 + examples/mr_hubert/simple_kmeans | 1 + examples/mr_hubert/train.sh | 45 + fairseq/models/multires_hubert/__init__.py | 2 + .../models/multires_hubert/multires_hubert.py | 1231 +++++++++++++++++ .../multires_hubert/multires_hubert_asr.py | 376 +++++ fairseq/models/wav2vec/wav2vec2.py | 19 +- fairseq/tasks/multires_hubert_pretraining.py | 204 +++ 23 files changed, 3077 insertions(+), 6 deletions(-) create mode 100644 examples/mr_hubert/README.md create mode 100644 examples/mr_hubert/config/decode/infer.yaml create mode 100644 examples/mr_hubert/config/decode/infer_lm.yaml create mode 100644 examples/mr_hubert/config/decode/run/submitit_slurm.yaml create mode 100644 examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml create mode 100644 examples/mr_hubert/config/finetune/base_100h.yaml create mode 100644 examples/mr_hubert/config/finetune/base_100h_large.yaml create mode 100644 examples/mr_hubert/config/finetune/base_10h.yaml create mode 100644 examples/mr_hubert/config/finetune/base_10h_large.yaml create mode 100644 examples/mr_hubert/config/finetune/base_1h.yaml create mode 100644 examples/mr_hubert/config/finetune/base_1h_large.yaml create mode 100644 examples/mr_hubert/config/pretrain/mrhubert_base_librispeech.yaml create mode 100644 examples/mr_hubert/config/pretrain/mrhubert_large_librilight.yaml create mode 100644 examples/mr_hubert/config/pretrain/run/submitit_reg.yaml create mode 100755 examples/mr_hubert/decode.sh create mode 100755 examples/mr_hubert/finetune.sh create mode 120000 examples/mr_hubert/simple_kmeans create mode 100755 examples/mr_hubert/train.sh create mode 100644 fairseq/models/multires_hubert/__init__.py create mode 100644 fairseq/models/multires_hubert/multires_hubert.py create mode 100644 fairseq/models/multires_hubert/multires_hubert_asr.py create mode 100644 fairseq/tasks/multires_hubert_pretraining.py diff --git a/examples/mr_hubert/README.md b/examples/mr_hubert/README.md new file mode 100644 index 00000000..e72c09c0 --- /dev/null +++ b/examples/mr_hubert/README.md @@ -0,0 +1,187 @@ +# MR-HuBERT + +## Pre-trained models + +### Main models +Model | Pretraining Data | Model | Paper Reference +|---|---|---|--- +MR-HuBERT Base (~97M) | 
[Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_base/mrhubert_mono_base.pt) | mono\_base
+MR-HuBERT Large (~321M) | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_large/mrhubert_mono_large.pt) | mono\_large
+Multilingual MR-HuBERT Base (~97M) | [Voxpopuli](https://github.com/facebookresearch/voxpopuli) 100k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/multi_base/multi_base.pt) | multi\_base
+Multilingual MR-HuBERT Large (~321M) | [Voxpopuli](https://github.com/facebookresearch/voxpopuli) 100k hr | [download 400k steps](https://dl.fbaipublicfiles.com/mrhubert/multi_large/multi_large_400k.pt) or [download 600k steps](https://dl.fbaipublicfiles.com/mrhubert/multi_large/multi_large_600k.pt) | Not in the paper
+
+
+### Ablation models
+Model | Pretraining Data | Model | Paper Reference
+|---|---|---|---
+MR-HuBERT Base (2-4-6 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-a/b1-a.pt) | (B.1)-a
+MR-HuBERT Base (5-2-5 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-b/b1-b.pt) | (B.1)-b
+MR-HuBERT Base (6-4-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-c/b1-c.pt) | (B.1)-c
+MR-HuBERT Base (3res 3-2-2-2-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-a/b2-a.pt) | (B.2)-a
+MR-HuBERT Base (3res 2-2-4-2-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-b/b2-b.pt) | (B.2)-b
+MR-HuBERT Base (3res 2-2-2-2-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-c/b2-c.pt) | (B.2)-c
+MR-HuBERT Base (Simple sampling) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b3-a/b3-a.pt) | (B.3)-a
+MR-HuBERT Base (Single target) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b4-a/b4-a.pt) | (B.4)-a
+MR-HuBERT Base (Simple sampling + single target) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b4-b/b4-b.pt) | (B.4)-b
+MR-HuBERT Base (Mono-resolution 20ms) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b5-a/b5-a.pt) | (B.5)-a
+MR-HuBERT Base (3-3-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b6-a/b6-a.pt) | (B.6)-a
+MR-HuBERT Base (Mono-resolution 20ms, 3-3-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b6-b/b6-b.pt) | (B.6)-b
+MR-HuBERT Base (HuBERT 20ms&40ms units) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-a/b7-a.pt) | (B.7)-a
+MR-HuBERT Base (Encodec 50Hz unit) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-b/b7-b.pt) | (B.7)-b
+MR-HuBERT Base (Encodec 50Hz units and 25Hz units) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-c/b7-c.pt) | (B.7)-c
+MR-HuBERT Base (Encodec 50Hz units stream 0&1) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-d/b7-d.pt) | (B.7)-d
+MR-HuBERT Large (no audio norm) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-a.pt) | (B.8)-a
+MR-HuBERT Large (see paper) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-b/b8-b.pt) | (B.8)-b
+MR-HuBERT Large (see paper) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-c/b8-c.pt) | (B.8)-c
+MR-HuBERT Large (see paper) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-d/b8-d.pt) | (B.8)-d
+MR-HuBERT Large (see paper) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-e/b8-e.pt) | (B.8)-e
+MR-HuBERT Large (see paper) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-f/b8-f.pt) | (B.8)-f
+MR-HuBERT Large (see paper) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-g/b8-g.pt) | (B.8)-g
+MR-HuBERT Large (see paper) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-h/b8-h.pt) | (B.8)-h
+MR-HuBERT Large (see paper) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-i/b8-i.pt) | (B.8)-i
+MR-HuBERT Large (see paper) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-j/b8-j.pt) | (B.8)-j
+Multilingual MR-HuBERT Large (Simple sampling) | [Voxpopuli](https://github.com/facebookresearch/voxpopuli) 100k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/multi_large_simple/multi_large_simple.pt) | Not in paper
+MR-HuBERT xLarge (from HuBERT-base label) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_xlarge/v1.pt) | Not in paper
+MR-HuBERT xLarge (from HuBERT-large label) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_xlarge/v2.pt) | Not in paper
+
+## Load a model
+```python
+import fairseq
+
+ckpt_path = "/path/to/the/checkpoint.pt"
+models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
+model = models[0]
+```
+
+## Train a new model
+
+### Data preparation
+
+Follow the steps in `./simple_kmeans` to create the files below (a quick sanity-check sketch follows the list):
+- `{train,valid}.tsv` waveform list files with length information (in samples)
+```
+/path/to/your/audio/files
+file1.wav\t160000
+file2.wav\t154600
+...
+filen.wav\t54362
+```
+- `{train,valid}.km` frame-aligned pseudo-label files (in the same order as the waveforms in the tsv file)
+```
+44 44 44 48 48 962 962 962 962 962 962 962 962 967 967 967 967 967 967 967 967 370 852 370 ... 18 18 745 745
+44 44 44 48 48 962 962 962 147 147 147 147 147 147 147 147 147 147 147 147 176 176 271 271 ... 27 27 745 745
+...
+44 44 44 48 962 962 962 962 962 962 377 377 377 77 77 852 696 694 433 578 578 82 740 622 ... 27 27 745 745
+```
+- `dict.km.txt` a dummy dictionary (the first column is the label id, the second is a dummy count)
+```
+0 1
+1 1
+2 1
+...
+999 1
+```
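+
+As a quick sanity check of the prepared data, the hypothetical snippet below verifies that each `.tsv`/`.km` pair is aligned and builds the dummy dictionary; the paths and the number of k-means clusters (1000) are placeholders.
+```sh
+# Each .tsv has one extra header line (the audio root directory), so a split is
+# consistent when (#lines in .tsv) - 1 == (#lines in .km).
+for split in train valid; do
+  n_wav=$(($(wc -l < /path/to/data/${split}.tsv) - 1))
+  n_lab=$(wc -l < /path/to/labels/${split}.km)
+  [ "${n_wav}" -eq "${n_lab}" ] || echo "${split}: ${n_wav} waveforms but ${n_lab} label rows"
+done
+
+# Build the dummy dictionary for a 1000-cluster k-means model.
+n_clusters=1000
+for i in $(seq 0 $((n_clusters - 1))); do echo "${i} 1"; done > /path/to/labels/dict.km.txt
+```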
+
+The `label_rate` is the same as the feature frame rate used for clustering,
+which is 100Hz for MFCC features and 50Hz for HuBERT features by default.
+
+### Pre-train an MR-HuBERT model
+
+Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.km`
+are saved at `/path/to/labels`, and the label rate is 100Hz.
+
+To train a base model (12-layer transformer), run:
+```sh
+$ python fairseq_cli/hydra_train.py \
+  --config-dir /path/to/fairseq-py/examples/mr_hubert/config/pretrain \
+  --config-name mrhubert_base_librispeech \
+  task.data=/path/to/data task.label_dir=/path/to/labels \
+  task.labels='["km"]' model.label_rate=100 \
+  task.label_rate_ratios='[1, 2]'
+```
+
+Please see the sample pre-training script `train.sh` for a complete example.
+
+### Fine-tune an MR-HuBERT model with a CTC loss
+
+Suppose `{train,valid}.tsv` are saved at `/path/to/data`, and their
+corresponding character transcripts `{train,valid}.ltr` are saved at
+`/path/to/trans`. A typical `.ltr` file follows the same order as the waveforms
+in the tsv file, e.g.:
+```
+HOW | ARE | YOU
+...
+THANK | YOU
+```
+
+To fine-tune a pre-trained MR-HuBERT model at `/path/to/checkpoint`, run
+```sh
+$ python fairseq_cli/hydra_train.py \
+  --config-dir /path/to/fairseq-py/examples/mr_hubert/config/finetune \
+  --config-name base_10h \
+  task.data=/path/to/data task.label_dir=/path/to/trans \
+  model.multires_hubert_path=/path/to/checkpoint
+```
+
+Please see the sample fine-tuning script `finetune.sh` for a complete example.
+
+### Decode an MR-HuBERT model
+
+Suppose `test.tsv` and `test.ltr` are the waveform list and transcripts of
+the split to be decoded, saved at `/path/to/data`, and the fine-tuned model is
+saved at `/path/to/checkpoint`.
+
+We support three decoding modes:
+- Viterbi decoding: greedy decoding without a language model
+- KenLM decoding: decoding with an arpa-format KenLM n-gram language model
+- Fairseq-LM decoding: decoding with a Fairseq neural language model (not fully tested)
+
+#### Viterbi decoding
+
+`task.normalize` needs to be consistent with the value used during fine-tuning.
+Decoding results will be saved at
+`/path/to/experiment/directory/decode/viterbi/test`.
+
+```sh
+$ python examples/speech_recognition/new/infer.py \
+  --config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \
+  --config-name infer \
+  task.data=/path/to/data \
+  task.normalize=[true|false] \
+  decoding.exp_dir=/path/to/experiment/directory \
+  common_eval.path=/path/to/checkpoint \
+  dataset.gen_subset=test
+```
+
+#### KenLM / Fairseq-LM decoding
+
+Suppose the pronunciation lexicon and the n-gram LM are saved at
+`/path/to/lexicon` and `/path/to/arpa`, respectively. Decoding results will be
+saved at `/path/to/experiment/directory/decode/kenlm/test`.
+
+```sh
+$ python examples/speech_recognition/new/infer.py \
+  --config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \
+  --config-name infer_lm \
+  task.data=/path/to/data \
+  task.normalize=[true|false] \
+  decoding.exp_dir=/path/to/experiment/directory \
+  common_eval.path=/path/to/checkpoint \
+  dataset.gen_subset=test \
+  decoding.decoder.lexicon=/path/to/lexicon \
+  decoding.decoder.lmpath=/path/to/arpa
+```
+
+The command above uses the default decoding hyperparameters, which can be found
+in `examples/speech_recognition/hydra/decoder.py`. These parameters can be
+configured from the command line. For example, to search with a beam size of
+500, append `decoding.decoder.beam=500` to the command above, as in the sketch below.
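+
+For instance, a hypothetical KenLM run with a larger beam and re-tuned LM weights could look like the following; the override values are placeholders that should be tuned on a development set.
+```sh
+$ python examples/speech_recognition/new/infer.py \
+  --config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \
+  --config-name infer_lm \
+  task.data=/path/to/data \
+  task.normalize=[true|false] \
+  decoding.exp_dir=/path/to/experiment/directory \
+  common_eval.path=/path/to/checkpoint \
+  dataset.gen_subset=test \
+  decoding.decoder.lexicon=/path/to/lexicon \
+  decoding.decoder.lmpath=/path/to/arpa \
+  decoding.decoder.beam=500 \
+  decoding.decoder.lmweight=2.0 \
+  decoding.decoder.wordscore=-1.0
+```
+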
+Important parameters include:
+- `decoding.decoder.beam`
+- `decoding.decoder.beamthreshold`
+- `decoding.decoder.lmweight`
+- `decoding.decoder.wordscore`
+- `decoding.decoder.silweight`
+
+To decode with a Fairseq LM, see the usage examples in the wav2vec 2.0 or HuBERT examples.
+
+Please see the sample decoding script `decode.sh` for a complete example.
diff --git a/examples/mr_hubert/config/decode/infer.yaml b/examples/mr_hubert/config/decode/infer.yaml
new file mode 100644
index 00000000..eff39802
--- /dev/null
+++ b/examples/mr_hubert/config/decode/infer.yaml
@@ -0,0 +1,30 @@
+# @package _group_
+
+defaults:
+  - model: null
+
+hydra:
+  run:
+    dir: ${common_eval.results_path}/viterbi
+  sweep:
+    dir: ${common_eval.results_path}
+    subdir: viterbi
+
+task:
+  _name: multires_hubert_pretraining
+  single_target: true
+  fine_tuning: true
+  label_rate_ratios: ???
+  data: ???
+  normalize: false
+
+decoding:
+  type: viterbi
+  unique_wer_file: true
+common_eval:
+  results_path: ???
+  path: ???
+  post_process: letter
+dataset:
+  max_tokens: 1100000
+  gen_subset: ???
diff --git a/examples/mr_hubert/config/decode/infer_lm.yaml b/examples/mr_hubert/config/decode/infer_lm.yaml
new file mode 100644
index 00000000..535b9507
--- /dev/null
+++ b/examples/mr_hubert/config/decode/infer_lm.yaml
@@ -0,0 +1,37 @@
+# @package _group_
+
+defaults:
+  - model: null
+
+hydra:
+  run:
+    dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
+  sweep:
+    dir: ${common_eval.results_path}
+    subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
+
+task:
+  _name: multires_hubert_pretraining
+  single_target: true
+  fine_tuning: true
+  data: ???
+  label_rate_ratios: ???
+  normalize: ???
+
+decoding:
+  type: kenlm
+  lexicon: ???
+  lmpath: ???
+  beamthreshold: 100
+  beam: 500
+  lmweight: 1.5
+  wordscore: -1
+  silweight: 0
+  unique_wer_file: true
+common_eval:
+  results_path: ???
+  path: ???
+  post_process: letter
+dataset:
+  max_tokens: 1100000
+  gen_subset: ???
diff --git a/examples/mr_hubert/config/decode/run/submitit_slurm.yaml b/examples/mr_hubert/config/decode/run/submitit_slurm.yaml new file mode 100644 index 00000000..0b806583 --- /dev/null +++ b/examples/mr_hubert/config/decode/run/submitit_slurm.yaml @@ -0,0 +1,17 @@ +# @package _global_ +hydra: + launcher: + cpus_per_task: ${distributed_training.distributed_world_size} + gpus_per_node: ${distributed_training.distributed_world_size} + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 1 + mem_gb: 200 + timeout_min: 4320 + max_num_timeout: 50 + name: ${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir}/submitit + +distributed_training: + distributed_world_size: 1 + distributed_no_spawn: true + distributed_port: 29761 diff --git a/examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml b/examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml new file mode 100644 index 00000000..2f669f37 --- /dev/null +++ b/examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml @@ -0,0 +1,17 @@ +# @package _global_ +hydra: + launcher: + cpus_per_task: ${distributed_training.distributed_world_size} + gpus_per_node: ${distributed_training.distributed_world_size} + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 1 + mem_gb: 200 + timeout_min: 4320 + max_num_timeout: 50 + name: ${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir}/submitit + +distributed_training: + distributed_world_size: 8 + distributed_no_spawn: true + distributed_port: 29761 diff --git a/examples/mr_hubert/config/finetune/base_100h.yaml b/examples/mr_hubert/config/finetune/base_100h.yaml new file mode 100644 index 00000000..c52a118c --- /dev/null +++ b/examples/mr_hubert/config/finetune/base_100h.yaml @@ -0,0 +1,97 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? + normalize: false # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 3200000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train_100h + valid_subset: dev_other + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 80000 + lr: [3e-5] + sentence_avg: true + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? 
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/mr_hubert/config/finetune/base_100h_large.yaml b/examples/mr_hubert/config/finetune/base_100h_large.yaml new file mode 100644 index 00000000..1d0c0da3 --- /dev/null +++ b/examples/mr_hubert/config/finetune/base_100h_large.yaml @@ -0,0 +1,97 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? + normalize: true # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 1600000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train_100h + valid_subset: dev_other + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 80000 + lr: [3e-5] + sentence_avg: true + update_freq: [2] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/mr_hubert/config/finetune/base_10h.yaml b/examples/mr_hubert/config/finetune/base_10h.yaml new file mode 100644 index 00000000..25123e44 --- /dev/null +++ b/examples/mr_hubert/config/finetune/base_10h.yaml @@ -0,0 +1,101 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + save_interval: 5 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? 
+ normalize: false # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 3200000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train_10h + valid_subset: dev + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 25000 + lr: [2e-5] + sentence_avg: true + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/mr_hubert/config/finetune/base_10h_large.yaml b/examples/mr_hubert/config/finetune/base_10h_large.yaml new file mode 100644 index 00000000..65448c77 --- /dev/null +++ b/examples/mr_hubert/config/finetune/base_10h_large.yaml @@ -0,0 +1,101 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + save_interval: 5 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? + normalize: true # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 3200000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train_10h + valid_subset: dev + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 25000 + lr: [2e-5] + sentence_avg: true + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? 
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/mr_hubert/config/finetune/base_1h.yaml b/examples/mr_hubert/config/finetune/base_1h.yaml new file mode 100644 index 00000000..7459c3fc --- /dev/null +++ b/examples/mr_hubert/config/finetune/base_1h.yaml @@ -0,0 +1,100 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + save_interval: 50 + keep_interval_updates: 1 + save_interval_updates: 1000 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? + normalize: false # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 3200000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 1000 + train_subset: train_1h + valid_subset: dev_other + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [5e-5] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/mr_hubert/config/finetune/base_1h_large.yaml b/examples/mr_hubert/config/finetune/base_1h_large.yaml new file mode 100644 index 00000000..34ef4dc1 --- /dev/null +++ b/examples/mr_hubert/config/finetune/base_1h_large.yaml @@ -0,0 +1,99 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + save_interval: 1000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? 
+ normalize: true # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 1280000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train_10h + valid_subset: dev + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 25000 + lr: [3e-4] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/mr_hubert/config/pretrain/mrhubert_base_librispeech.yaml b/examples/mr_hubert/config/pretrain/mrhubert_base_librispeech.yaml new file mode 100644 index 00000000..16a35d34 --- /dev/null +++ b/examples/mr_hubert/config/pretrain/mrhubert_base_librispeech.yaml @@ -0,0 +1,103 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + seed: 1337 + tensorboard_logdir: tblog + min_loss_scale: 1e-8 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 32 + distributed_port: 29671 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: multires_hubert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + label_rate_ratios: ??? + sample_rate: 16000 + max_sample_size: 250000 + min_sample_size: 32000 + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + # max_keep_size: 300000 + # max_keep_size: 50000 + + +dataset: + num_workers: 0 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 5 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10,] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: multires_hubert + label_rate: ??? 
+ label_rate_ratios: ${task.label_rate_ratios} + skip_masked: false + skip_nomask: false + mask_prob: 0.80 + extractor_mode: default + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + final_dim: 256 + encoder_layers: 4 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + conv_adapator_kernal: 1 + use_single_target: true + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '/' + exclude_keys: + - run + - task.data + - task.label_dir + - common.min_loss_scale + - common.log_interval + - optimization.clip_norm diff --git a/examples/mr_hubert/config/pretrain/mrhubert_large_librilight.yaml b/examples/mr_hubert/config/pretrain/mrhubert_large_librilight.yaml new file mode 100644 index 00000000..423f3b25 --- /dev/null +++ b/examples/mr_hubert/config/pretrain/mrhubert_large_librilight.yaml @@ -0,0 +1,107 @@ +# @package _group_ + +common: + memory_efficient_fp16: true + log_format: json + log_interval: 200 + seed: 1337 + tensorboard_logdir: tblog + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 128 + distributed_port: 29671 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: multires_hubert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + label_rate_ratios: ??? + sample_rate: 16000 + max_sample_size: 250000 + min_sample_size: 32000 + pad_audio: false + random_crop: true + normalize: true # must be consistent with extractor + # max_keep_size: 50000 + +dataset: + num_workers: 0 + max_tokens: 300000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 5 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10,] + +optimization: + max_update: 400000 + lr: [0.0015] + clip_norm: 1.0 + update_freq: [3] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: multires_hubert + label_rate: ??? 
+ label_rate_ratios: ${task.label_rate_ratios} + encoder_layers: 8 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + final_dim: 768 + skip_masked: false + skip_nomask: false + mask_prob: 0.80 + extractor_mode: layer_norm + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + encoder_layerdrop: 0.0 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + layer_norm_first: true + feature_grad_mult: 1.0 + untie_final_proj: true + activation_dropout: 0.0 + conv_adapator_kernal: 1 + use_single_target: true + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + run: + dir: /checkpoint/wnhsu/w2v/hubert_final/hydra_pt + sweep: + dir: /checkpoint/wnhsu/w2v/hubert_final/hydra_pt + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/mr_hubert/config/pretrain/run/submitit_reg.yaml b/examples/mr_hubert/config/pretrain/run/submitit_reg.yaml new file mode 100644 index 00000000..46c979cd --- /dev/null +++ b/examples/mr_hubert/config/pretrain/run/submitit_reg.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +hydra: + launcher: + cpus_per_task: 8 + gpus_per_node: 8 + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 4 + comment: null + mem_gb: 384 + timeout_min: 4320 + max_num_timeout: 100 + constraint: volta32gb + name: ${hydra.job.config_name}/${hydra.job.override_dirname} + submitit_folder: ${hydra.sweep.dir}/submitit/%j + +distributed_training: + distributed_world_size: 32 + distributed_port: 29671 + nprocs_per_node: 8 diff --git a/examples/mr_hubert/decode.sh b/examples/mr_hubert/decode.sh new file mode 100755 index 00000000..1ff423a8 --- /dev/null +++ b/examples/mr_hubert/decode.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +FAIRSEQ= # Setup your fairseq directory + +config_dir=${FAIRSEQ}/examples/mr_hubert/config +config_name=mr_hubert_base_librispeech + + +# Prepared Data Directory + +data_dir=librispeech +# -- data_dir +# -- test.tsv +# -- test.ltr +# -- dict.ltr.txt + + +exp_dir=exp # Target experiments directory (where you have your pre-trained model with checkpoint_best.pt) +ratios="[1, 2]" # Default label rate ratios + +_opts= + +# If use slurm, uncomment this line and modify the job submission at +# _opts="${_opts} hydra/launcher=submitit_slurm +hydra.launcher.partition=${your_slurm_partition} +run=submitit_reg" + +# If want to set additional experiment tag, uncomment this line +# _opts="${_opts} hydra.sweep.subdir=${your_experiment_tag}" + +# If use un-normalized audio, uncomment this line +# _opts="${_opts} task.normalize=false" + + + +PYTHONPATH=${FAIRSEQ} +python examples/speech_recognition/new/infer.py \ + --config-dir ${config_dir} \ + --config-name infer_multires \ + ${_opts} \ + task.data=${data_dir} \ + task.label_rate_ratios='${ratios}' \ + common_eval.results_path=${exp_dir} \ + common_eval.path=${exp_dir}/checkpoint_best.pt \ + dataset.max_tokens=2000000 \ + dataset.gen_subset=test \ + dataset.skip_invalid_size_inputs_valid_test=true + diff --git a/examples/mr_hubert/finetune.sh b/examples/mr_hubert/finetune.sh new file mode 100755 index 00000000..31ba6455 --- /dev/null +++ b/examples/mr_hubert/finetune.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +FAIRSEQ= # Setup your fairseq directory + +config_dir=${FAIRSEQ}/examples/mr_hubert/config +config_name=mr_hubert_base_librispeech + +# override configs if need +max_tokens=3200000 +max_sample_size=1000000 +max_update=50000 + + +# Prepared Data Directory + 
+data_dir=librispeech +# -- data_dir +# -- train.tsv +# -- train.ltr +# -- valid.tsv +# -- valid.ltr +# -- dict.ltr.txt + + +exp_dir=exp # Target experiments directory +ratios="[1, 2]" # Default label rate ratios +hubert_path=/path/of/your/hubert.pt + +_opts= + +# If use slurm, uncomment this line and modify the job submission at +# _opts="${_opts} hydra/launcher=submitit_slurm +hydra.launcher.partition=${your_slurm_partition} +run=submitit_reg" + +# If want to set additional experiment tag, uncomment this line +# _opts="${_opts} hydra.sweep.subdir=${your_experiment_tag}" + + +python ${FAIRSEQ}/fairseq_cli/hydra_train.py \ + -m --config-dir ${config_dir} --config-name ${config_name} ${_opts} \ + task.data=${data_dir} +task.max_sample_size=${max_sample_size} \ + task.label_dir=${data_dir} \ + task.label_rate_ratios='${ratios}' \ + dataset.max_tokens=${max_tokens} \ + optimization.max_update=${max_update} \ + model.multires_hubert_path=${hubert_path} \ + hydra.sweep.dir=${exp_dir} & diff --git a/examples/mr_hubert/simple_kmeans b/examples/mr_hubert/simple_kmeans new file mode 120000 index 00000000..4f955451 --- /dev/null +++ b/examples/mr_hubert/simple_kmeans @@ -0,0 +1 @@ +../hubert/simple_kmeans \ No newline at end of file diff --git a/examples/mr_hubert/train.sh b/examples/mr_hubert/train.sh new file mode 100755 index 00000000..da561eb1 --- /dev/null +++ b/examples/mr_hubert/train.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +FAIRSEQ= # Setup your fairseq directory + +config_dir=${FAIRSEQ}/examples/mr_hubert/config +config_name=mr_hubert_base_librispeech + +# Prepared Data Directory +data_dir=librispeech +# -- data_dir +# -- train.tsv +# -- valid.tsv + +label_dir=labels +# -- label_dir +# -- train.km +# -- valid.km +# -- dict.km.txt + + +exp_dir=exp # Target experiments directory +ratios="[1, 2]" # Default label rate ratios +label_rate=50 # Base label rate + + +_opts= + +# If use slurm, uncomment this line and modify the job submission at +# _opts="${_opts} hydra/launcher=submitit_slurm +hydra.launcher.partition=${your_slurm_partition} +run=submitit_reg" + +# If want to set additional experiment tag, uncomment this line +# _opts="${_opts} hydra.sweep.subdir=${your_experiment_tag}" + + +python ${FAIRSEQ}/fairseq_cli/hydra_train.py \ + -m --config-dir ${config_dir} --config-name ${config_name} ${_opts} \ + task.data=${data_dir} \ + task.label_dir=${label_dir} \ + task.labels='["km"]' \ + model.label_rate=${label_rate} \ + task.label_rate_ratios='${ratios}' \ + hydra.sweep.dir=${exp_dir} & + + + diff --git a/fairseq/models/multires_hubert/__init__.py b/fairseq/models/multires_hubert/__init__.py new file mode 100644 index 00000000..ec36505b --- /dev/null +++ b/fairseq/models/multires_hubert/__init__.py @@ -0,0 +1,2 @@ +from .multires_hubert import * # noqa +from .multires_hubert_asr import * # noqa diff --git a/fairseq/models/multires_hubert/multires_hubert.py b/fairseq/models/multires_hubert/multires_hubert.py new file mode 100644 index 00000000..eacb29e5 --- /dev/null +++ b/fairseq/models/multires_hubert/multires_hubert.py @@ -0,0 +1,1231 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import math +import torch.nn as nn +from omegaconf import II +from fairseq.models.wav2vec.wav2vec import norm_block + +from fairseq import utils +from fairseq.data.data_utils import compute_mask_indices +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.wav2vec.wav2vec2 import ( + EXTRACTOR_MODE_CHOICES, + MASKING_DISTRIBUTION_CHOICES, + LAYER_TYPE_CHOICES, + ConvFeatureExtractionModel, + TransformerEncoder, +) +from omegaconf import II, MISSING, open_dict +from fairseq.modules import GradMultiply, LayerNorm +from fairseq.tasks.multires_hubert_pretraining import ( + MultiresHubertPretrainingConfig, + MultiresHubertPretrainingTask, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class MultiresHubertConfig(FairseqDataclass): + label_rate: float = II("task.label_rate") + # label_rate: 1,2,2,5 + # (imply (1,2), (2,5)) + # if base label_rate = 50 + # (1,2), (2,5) --> label rates 50, 25, 10 + label_rate_ratios: List[int] = field( + default=MISSING, metadata={"help": "tuple for label rates e.g., [(1,2), (2,5)]"} + ) + + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. default has a single group " + "norm with d groups in the first conv block, whereas layer_norm " + "has layer norms in every block (meant to use with normalize=True)" + }, + ) + # the blocks for each label rate + encoder_layers: int = field( + default="2", + metadata={ + "help": "num encoder layers in the each block (one sub module of the U-net)" + }, + ) + override_encoder_layers: str = field( + default="", + metadata={ + "help": "specific layer numbers for each block (one sub module of the U-net) for the training" + }, + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + layer_type: LAYER_TYPE_CHOICES = field( + default="transformer", metadata={"help": "layer type in encoder"} + ) + conv_adapator_kernal: int = field( + default=7, metadata={"help": "kernal size for conv adaptor"} + ) + use_plain_updownsample: bool = field( + default=False, metadata={"help": "whether to use plain up downsample"} + ) + + # dropouts + dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for the transformer"}, + ) + attention_dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for attention weights"}, + ) + activation_dropout: float = field( + default=0.0, + metadata={"help": "dropout probability after activation in FFN"}, + ) + encoder_layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a tarnsformer layer"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={"help": "dropout to apply to the features (after feat extr)"}, + ) + + final_dim: int = field( + 
default=0, + metadata={ + "help": "project final representations and targets to this many " + "dimensions. set to encoder_embed_dim is <= 0" + }, + ) + untie_final_proj: bool = field( + default=True, + metadata={"help": "use separate projection for each target"}, + ) + layer_norm_first: bool = field( + default=False, + metadata={"help": "apply layernorm first in the transformer"}, + ) + conv_feature_layers: str = field( + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + metadata={ + "help": "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, + metadata={"help": "multiply feature extractor var grads by this"}, + ) + use_single_target: bool = field( + default=False, + metadata={ + "help": "whether to use single data (in that case, we will compute with the fixed label rate)" + }, + ) + use_single_prediction: bool = field( + default=False, + metadata={ + "help": "if true, we will not conduct mlm prediction in low resolution in the middle" + }, + ) + use_multi_stream: bool = field( + default=False, + metadata={ + "help": "whether to use multi-stream setting (in this setting, we have multiple streams with the same resolution)" + }, + ) + + # masking + mask_length: int = field(default=10, metadata={"help": "mask length"}) + mask_prob: float = field( + default=0.65, + metadata={"help": "probability of replacing a token with mask"}, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + mask_channel_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + + latent_temp: Tuple[float, float, 
float] = field( + default=(2, 0.5, 0.999995), + metadata={"help": "legacy (to be removed)"}, + ) + + # loss computation + skip_masked: bool = field( + default=False, + metadata={"help": "skip computing losses over masked frames"}, + ) + skip_nomask: bool = field( + default=False, + metadata={"help": "skip computing losses over unmasked frames"}, + ) + + checkpoint_activations: bool = field( + default=False, + metadata={"help": "recompute activations and save memory for extra compute"}, + ) + + # FP16 optimization + required_seq_len_multiple: int = field( + default=2, + metadata={ + "help": "pad the input to encoder such that the sequence length is divisible by multiple" + }, + ) + + # Conformer + depthwise_conv_kernel_size: int = field( + default=31, + metadata={ + "help": "depthwise-conv-kernel-size for convolution in conformer layer" + }, + ) + attn_type: str = field( + default="", + metadata={"help": "if espnet use ESPNET MHA"}, + ) + pos_enc_type: str = field( + default="abs", + metadata={"help": "Positional encoding type to use in conformer"}, + ) + fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"}) + + +@register_model("multires_hubert", dataclass=MultiresHubertConfig) +class MultiresHubertModel(BaseFairseqModel): + def __init__( + self, + cfg: MultiresHubertConfig, + task_cfg: MultiresHubertPretrainingConfig, + dictionaries: List[Dictionary], + ) -> None: + super().__init__() + logger.info(f"MultiresHubertModel Config: {cfg}") + + feature_enc_layers = eval(cfg.conv_feature_layers) # noqa + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + # Estimate label rates + assert ( + cfg.label_rate_ratios != "None" + ), "without ratios, the model is exactly as the Hubert model" + self.label_rate_ratios = [] + self.base_rate = cfg.label_rate + self.label_rates = [] + self.downsample_modules = nn.ModuleList() + self.upsample_modules = nn.ModuleList() + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.use_single_target = cfg.use_single_target + self.use_single_prediction = cfg.use_single_prediction + self.use_plain_updownsample = cfg.use_plain_updownsample + + # For decide the override encoder layers, so that the layer number is not equally distributed + if cfg.override_encoder_layers != "": + self.override_encoder_layers = eval(cfg.override_encoder_layers) + assert ( + len(self.override_encoder_layers) % 2 == 1 + ), "must be odd number of layers if specify detailed layers" + assert ( + len(self.override_encoder_layers) // 2 + == len(cfg.label_rate_ratios) // 2 + ), "number of override encoder layers must match the label rate ratios information" + self.len_encoder_modules = len(self.override_encoder_layers) + else: + self.override_encoder_layers = None + self.len_encoder_modules = None + + # use different layers instead of equally distributed ones + middle_override_encoder_layer = ( + self.override_encoder_layers[self.len_encoder_modules // 2] + if self.override_encoder_layers is not None + else None + ) + skip_middle_pos_conv = False if len(cfg.label_rate_ratios) < 2 else True + + self.middle_encoder = TransformerEncoder( + cfg, + skip_pos_conv=skip_middle_pos_conv, + override_encoder_layer=middle_override_encoder_layer, + ) + + first_pos_conv = False # only 
enable pos_conv for the first encoder + raw_label_rate_ratios = cfg.label_rate_ratios + for i in range(len(raw_label_rate_ratios) // 2): + # check if have override encoder layers + if self.override_encoder_layers is not None: + override_encoder_layer = self.override_encoder_layers[i] + override_decoder_layer = self.override_encoder_layers[ + self.len_encoder_modules - 1 - i + ] + else: + override_encoder_layer, override_decoder_layer = None, None + + self.label_rate_ratios.append( + (raw_label_rate_ratios[i * 2], raw_label_rate_ratios[i * 2 + 1]) + ) + if self.use_plain_updownsample: + self.downsample_modules.append( + ConvDownsampler( + k=cfg.conv_adapator_kernal, + label_rate=( + ( + raw_label_rate_ratios[i * 2], + raw_label_rate_ratios[i * 2 + 1], + ) + ), + dropout=0.0, + channels=cfg.encoder_embed_dim, + activation=nn.GELU(), + log_compression=False, + skip_connections=True, + highway=True, + residual_scale=0.4, + ) + ) + else: + self.downsample_modules.append( + ConvAdapter( + k=cfg.conv_adapator_kernal, + label_rate=( + ( + raw_label_rate_ratios[i * 2], + raw_label_rate_ratios[i * 2 + 1], + ) + ), + dropout=0.0, + channels=cfg.encoder_embed_dim, + activation=nn.GELU(), + log_compression=False, + skip_connections=True, + highway=True, + residual_scale=0.4, + ) + ) + if not first_pos_conv: + self.encoders.append( + TransformerEncoder( + cfg, override_encoder_layer=override_encoder_layer + ) + ) # TODO(jiatong): add conformer options + first_pos_conv = True + else: + self.encoders.append( + TransformerEncoder( + cfg, + skip_pos_conv=True, + override_encoder_layer=override_encoder_layer, + ) + ) + if self.use_plain_updownsample: + self.upsample_modules.append( + ConvUpsampler( + k=cfg.conv_adapator_kernal, + label_rate=( + ( + raw_label_rate_ratios[i * 2 + 1], + raw_label_rate_ratios[i * 2], + ) + ), + dropout=0.0, + channels=cfg.encoder_embed_dim, + activation=nn.GELU(), + log_compression=False, + skip_connections=True, + highway=True, + residual_scale=0.4, + ) + ) + else: + self.upsample_modules.append( + ConvAdapter( + k=cfg.conv_adapator_kernal, + label_rate=( + ( + raw_label_rate_ratios[i * 2 + 1], + raw_label_rate_ratios[i * 2], + ) + ), + dropout=0.0, + channels=cfg.encoder_embed_dim, + activation=nn.GELU(), + log_compression=False, + skip_connections=True, + highway=True, + residual_scale=0.4, + ) + ) + self.decoders.append( + TransformerEncoder( + cfg, + skip_pos_conv=True, + override_encoder_layer=override_decoder_layer, + ) + ) + + base_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feature_ds_rates = [base_ds_rate] + running_rate = self.base_rate + + if cfg.use_single_target or cfg.use_multi_stream: + self.label_rates = self.base_rate + else: + self.label_rates.append(self.base_rate) + + for label_rate_ratio in self.label_rate_ratios: + upsample_rate, downsample_rate = label_rate_ratio + if (base_ds_rate * upsample_rate) % downsample_rate != 0: + logger.warning( + "base rate: {} cannot be ideally processed with downsample rate {}".format( + base_ds_rate, downsample_rate + ) + ) + + base_ds_rate = base_ds_rate * downsample_rate // upsample_rate + self.feature_ds_rates.append(base_ds_rate) + + if not cfg.use_single_target and not cfg.use_multi_stream: + running_rate = running_rate * upsample_rate // downsample_rate + self.label_rates.append(running_rate) + self.label_nums = len( + self.feature_ds_rates + ) # the number of labels for prediction (activate at iter 2) + + if type(self.label_rates) == float: + self.feat2tar_ratios = [ + self.feature_ds_rates[i] * 
self.label_rates / task_cfg.sample_rate + for i in range(len(self.feature_ds_rates)) + ] + else: + self.feat2tar_ratios = [ + self.feature_ds_rates[i] * self.label_rates[i] / task_cfg.sample_rate + for i in range(len(self.feature_ds_rates)) + ] + + # self.feat2tar_ratios = self.feat2tar_ratios[::-1] + + # An running example of the label rate: + # base_ds_rate = 320 + # self.label_rate_ratios = [(1, 2)] + # self.feature_ds_rates = [320, 640] + # self.label_rates = [50, 25] + # self.feat2tar_ratios = [1, 1] + + # Another running example of the label rate: + # base_ds_rate = 320 + # self.label_rate_ratios = [(1, 2)] + # self.feature_ds_rates = [320, 640] + # self.label_rates = 100 + # self.feat2tar_ratios = [4, 2] + # self.use_sinlge_target = True + + logging.info( + "ds_rates: {}, label_rates: {}, feat2tar_ratios: {}".format( + self.feature_ds_rates, self.label_rates, self.feat2tar_ratios + ) + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + self.logit_temp = cfg.logit_temp + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + + # Note(jiatong): different from hubert, we just set the final dim as encoder_embed_dim + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.layer_norm = LayerNorm(self.embed) + + self.predictor_head_num = 1 if self.use_single_prediction else self.label_nums + + self.target_glu = None + if cfg.target_glu: + self.target_glus = nn.ModuleList() + for i in range(self.predictor_head_num): + self.target_glus.append( + nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU()) + ) + + self.untie_final_proj = cfg.untie_final_proj + self.final_projs = nn.ModuleList() + + # Note(jiatong): we do not have untie cases for multires hubert + for i in range(self.predictor_head_num): + self.final_projs.append(nn.Linear(cfg.encoder_embed_dim, final_dim)) + + # modules below are not needed during fine-tuning + self.multires_classes = [] + self.label_embs_concat = nn.ParameterList() + + for i in range(self.predictor_head_num): + if self.use_single_target: + num_classes = len(dictionaries[0]) + else: + num_classes = len(dictionaries[i]) + self.multires_classes.append(num_classes) + self.label_embs_concat.append( + nn.Parameter(torch.FloatTensor(num_classes, final_dim)) + ) + nn.init.uniform_(self.label_embs_concat[i]) + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model( + cls, cfg: MultiresHubertConfig, task: MultiresHubertPretrainingTask + ): + """Build a new model instance.""" + + model = MultiresHubertModel(cfg, task.cfg, task.dictionaries) + return model + + def 
apply_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def compute_nce(self, x, pos, negs): + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + logits = logits.transpose(0, 1) # (num_x, num_cls+1) + return logits + + def forward_features(self, source: torch.Tensor) -> torch.Tensor: + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + return features + + def forward_targets( + self, + features: torch.Tensor, + target: torch.Tensor, + feat2tar_ratio: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + + feat_tsz = features.size(1) + + # skip if no target is provided + if target is None: + return features, None, None + targ_tsz = target.size(1) + if feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / feat2tar_ratio) + features = features[:, :feat_tsz] + target_inds = torch.arange(feat_tsz).float() * feat2tar_ratio + target = target[:, target_inds.long()] + return features, target + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + padding_mask = padding_mask.all(-1) + return padding_mask + + def forward( + self, + source: torch.Tensor, + target_list: Optional[List[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = True, + features_only: bool = False, + output_layer: Optional[int] = None, + ) -> Dict[str, torch.Tensor]: + """output layer is 1-based""" + features = self.forward_features(source) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask, target_list) + else: + 
x = features + mask_indices = None + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + + def align_size_sum(feat1, pad1, feat2): + assert ( + abs(feat1.size(1) - feat2.size(1)) < 10 + ), "misaligned results for feat1 and feat2 of size {} - {}".format( + feat1.size(1), feat2.size(1) + ) + common_size = min(feat1.size(1), feat2.size(1)) + + return ( + feat1[:, :common_size] + feat2[:, :common_size], + pad1[:, :common_size], + ) + + # process encoders + res_outputs = [] # final output for different resolution + multi_mask_indices = [] # mask indices for different resolution + residuals = [] # record the x in encoders + padding_masks = [] # final padding masks + # The encoder has (self.label_nums - 1) blocks + for i in range(self.label_nums - 1): + x, _ = self.encoders[i](x, padding_mask=padding_mask, layer=None) + residuals.append(x) + x, padding_mask, mask_indices = self.downsample_modules[i]( + x, padding=padding_mask, mask_indices=mask_indices + ) + + residual = self.middle_encoder(x, padding_mask=padding_mask, layer=None)[0] + x = x + residual + res_outputs.append(x) + + # process decoders + # The encoder has (self.label_nums - 1) blocks + padding_masks.append(padding_mask) + multi_mask_indices.append(mask_indices) + residuals.reverse() # NOTE(jiatong): reverse res_output to match corresponding input + for i in range(self.label_nums - 1): + x, padding_mask, mask_indices = self.upsample_modules[ + self.label_nums - 2 - i + ](x, padding=padding_mask, mask_indices=mask_indices) + x, _ = self.decoders[i](x, padding_mask=padding_mask, layer=None) + x, padding_mask = align_size_sum(x, padding_mask, residuals[i]) + res_outputs.append(x) + padding_masks.append(padding_mask) + multi_mask_indices.append(mask_indices) + + # NOTE(jiatong): need reverse of target list to allow matched target-representation + res_outputs.reverse() + padding_masks.reverse() + multi_mask_indices.reverse() + if target_list is not None: + new_target_list = [] + for i in range(self.label_nums): + if self.use_single_target: + res_outputs[i], reformat_target_list = self.forward_targets( + res_outputs[i], target_list[0], self.feat2tar_ratios[i] + ) + new_target_list.append(reformat_target_list) + else: + if target_list[i] is not None: + res_outputs[i], reformat_target_list = self.forward_targets( + res_outputs[i], target_list[i], self.feat2tar_ratios[i] + ) + new_target_list.append(reformat_target_list) + else: + # Append a None target list then it won't be used to calculate loss + new_target_list.append(None) + if padding_masks[i] is not None: + padding_masks[i] = self.forward_padding_mask( + res_outputs[i], padding_masks[i] + ) + if multi_mask_indices[i] is not None: + multi_mask_indices[i] = self.forward_padding_mask( + res_outputs[i], multi_mask_indices[i] + ) + + + if features_only: + # NOTE(jiatong): need to reverse back + res_outputs.reverse() + return { + "x": res_outputs, + "padding_mask": padding_masks[0], + "features": features, + } + + def compute_pred(proj_x, target, label_embs): + # compute logits for the i-th label set + y = torch.index_select(label_embs, 0, target.long()) + negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1) + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + # proj_x: (S, D) + # y: (S, D) + # negs: (Neg, S, D) + return self.compute_nce(proj_x, y, negs) + + logit_m_list, logit_u_list = [], [] + for j in range(self.label_nums): + if new_target_list[j] is None: + 
continue # skip empty targets + label_embs_list = self.label_embs_concat[j].split( + [self.multires_classes[j]], 0 + ) + # set the variables (after the set, the procedure is the same as hubert) + # all the elements are list with only one element (to simulate the normal hubert process) + x = res_outputs[j] + target = new_target_list[j] + padding_mask = padding_masks[j] + mask_indices = multi_mask_indices[j] + final_proj = self.final_projs[j] + + if not self.skip_masked: + masked_indices = torch.logical_and(~padding_mask, mask_indices) + proj_x_m = final_proj(x[masked_indices]) + logit_m_list.append( + compute_pred(proj_x_m, target[masked_indices], label_embs_list[0]) + ) + else: + logit_m_list.append(None) + + if not self.skip_nomask: + nomask_indices = torch.logical_and(~padding_mask, ~mask_indices) + proj_x_u = final_proj(x[nomask_indices]) + logit_u_list.append( + compute_pred(proj_x_u, target[nomask_indices], label_embs_list[0]) + ) + else: + logit_u_list.append(None) + + # if we only want one prediction, we can exit now + if self.predictor_head_num == 1: + break + + result = { + "logit_m_list": logit_m_list, + "logit_u_list": logit_u_list, + "padding_mask": padding_mask, + "features_pen": features_pen, + } + return result + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + last_layer: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + feature = res["features"] if ret_conv else res["x"] + if last_layer: + feature = feature[-1] + return feature, res["padding_mask"] + + def get_logits(self, net_output, is_masked=True): + if is_masked: + logits_list = net_output["logit_m_list"] + else: + logits_list = net_output["logit_u_list"] + logits_list = [x.float() for x in logits_list if x is not None] + return logits_list + + def get_targets(self, net_output, is_masked=True): + logits_list = self.get_logits(net_output, is_masked) + targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list] + return targets_list + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.target_glu = None + self.final_proj = None + + +class ConvAdapter(nn.Module): + """Conv adapter that combines two modules with different label rate with downsample or upsample. 
+ To allow different ratios than integer, two convs are utilized with first to upsample (numerator) + and the second to downsample (denominator)""" + + def __init__( + self, + k, + label_rate, + dropout, + channels, + activation, + log_compression=False, + skip_connections=True, + highway=True, + residual_scale=0.4, + non_affine_group_norm=False, + ): + super().__init__() + + def downsample_block(channel, k, stride): + return nn.Sequential( + # with padding (k - 1) // 2 to keep the same size + nn.Conv1d( + channel, + channel, + k, + stride=stride, + bias=False, + padding=(k - 1) // 2, + ), + nn.Dropout(p=dropout), + norm_block( + is_layer_norm=False, dim=channel, affine=not non_affine_group_norm + ), + activation, + ) + + def upsample_block(channel, k, stride): + return nn.Sequential( + # with padding (k - 1) // 2 to keep the same size + nn.ConvTranspose1d( + channel, + channel, + k, + stride=stride, + bias=False, + padding=0, # padding=(k - 1) // 2, + output_padding=(stride - 1), + ), + nn.Dropout(p=dropout), + norm_block( + is_layer_norm=False, dim=channel, affine=not non_affine_group_norm + ), + activation, + ) + + assert len(label_rate) == 2, "label_rate should be sized two to apply fusion" + # Lout =(Lin~H~R1)~Wstride~H~R2~Wpadding+dilation~W(kernel_size~H~R1)+output_padding+1 + self.upsample_conv = upsample_block(channels, k, label_rate[0]) + self.downsample_conv = downsample_block(channels, k, label_rate[1]) + + self.upsample_rate, self.downsample_rate = label_rate + self.log_compression = log_compression + self.skip_connections = skip_connections + self.highway = highway + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x, padding=None, mask_indices=None): + # Assume x1 = (B, T, C) as input + x = x.permute(0, 2, 1) + residual_before_upsample = x + x = self.upsample_conv(x) + upsample_size = x.size(2) + + # conduct upsample + if self.skip_connections: + residual_upsample = torch.repeat_interleave( + residual_before_upsample, self.upsample_rate, dim=2 + ) + upsample_size = min(upsample_size, residual_upsample.size(2)) + x = ( + x[..., :upsample_size] + residual_upsample[..., :upsample_size] + ) * self.residual_scale + + residual_before_downsample = x + x = self.downsample_conv(x) + downsample_size = x.size(2) + + if self.skip_connections: + residual_downsample = residual_before_downsample[ + ..., :: self.downsample_rate + ] + downsample_size = min(x.size(2), residual_downsample.size(2)) + x = ( + x[..., :downsample_size] + residual_downsample[..., :downsample_size] + ) * self.residual_scale + + if self.highway: + residual_after_sample = residual_upsample[..., :: self.downsample_rate] + final_size = min(x.size(2), residual_after_sample.size(2)) + x = ( + x[..., :final_size] + residual_after_sample[..., :final_size] + ) * self.residual_scale + + if self.log_compression: + x = x.abs() + x = x + 1 + x = x.log() + + x = x.permute(0, 2, 1) + + # process padding + if padding is not None: + padding = torch.repeat_interleave(padding, self.upsample_rate, dim=1) + padding = padding[..., :: self.downsample_rate] + padding = padding[..., : x.size(1)] + + # process mask indices + if mask_indices is not None: + mask_indices = torch.repeat_interleave( + mask_indices, self.upsample_rate, dim=1 + ) + mask_indices = mask_indices[..., :: self.downsample_rate] + mask_indices = mask_indices[..., : x.size(1)] + return x, padding, mask_indices + + +class ConvDownsampler(nn.Module): + """Conv downsampler that combines two modules with different label rate with downsample or upsample. 
+ To allow different ratios than integer, two convs are utilized with first to upsample (numerator) + and the second to downsample (denominator)""" + + def __init__( + self, + k, + label_rate, + dropout, + channels, + activation, + log_compression=False, + skip_connections=True, + highway=True, + residual_scale=0.4, + non_affine_group_norm=False, + ): + super().__init__() + + def downsample_block(channel, k, stride): + return nn.Sequential( + # with padding (k - 1) // 2 to keep the same size + nn.Conv1d( + channel, + channel, + k, + stride=stride, + bias=False, + padding=(k - 1) // 2, + ), + nn.Dropout(p=dropout), + norm_block( + is_layer_norm=False, dim=channel, affine=not non_affine_group_norm + ), + activation, + ) + + assert len(label_rate) == 2, "label_rate should be sized two to apply fusion" + self.downsample_conv = downsample_block(channels, k, label_rate[1]) + + upsample_rate, self.downsample_rate = label_rate + assert upsample_rate == 1, "must be 1 to perform downsample only" + self.log_compression = log_compression + self.skip_connections = skip_connections + self.highway = highway # Useless as placeholder + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x, padding=None, mask_indices=None): + # Assume x1 = (B, T, C) as input + x = x.permute(0, 2, 1) + + residual_before_downsample = x + x = self.downsample_conv(x) + downsample_size = x.size(2) + + if self.skip_connections: + residual_downsample = residual_before_downsample[ + ..., :: self.downsample_rate + ] + downsample_size = min(x.size(2), residual_downsample.size(2)) + x = ( + x[..., :downsample_size] + residual_downsample[..., :downsample_size] + ) * self.residual_scale + + if self.log_compression: + x = x.abs() + x = x + 1 + x = x.log() + + x = x.permute(0, 2, 1) + + # process padding + if padding is not None: + padding = padding[..., :: self.downsample_rate] + padding = padding[..., : x.size(1)] + + # process mask indices + if mask_indices is not None: + mask_indices = mask_indices[..., :: self.downsample_rate] + mask_indices = mask_indices[..., : x.size(1)] + return x, padding, mask_indices + + +class ConvUpsampler(nn.Module): + """Conv upsampler that combines two modules with different label rate with downsample or upsample. 
+ To allow different ratios than integer, two convs are utilized with first to upsample (numerator) + and the second to downsample (denominator)""" + + def __init__( + self, + k, + label_rate, + dropout, + channels, + activation, + log_compression=False, + skip_connections=True, + highway=True, + residual_scale=0.4, + non_affine_group_norm=False, + ): + super().__init__() + + def upsample_block(channel, k, stride): + return nn.Sequential( + # with padding (k - 1) // 2 to keep the same size + nn.ConvTranspose1d( + channel, + channel, + k, + stride=stride, + bias=False, + padding=0, # padding=(k - 1) // 2, + output_padding=(stride - 1), + ), + nn.Dropout(p=dropout), + norm_block( + is_layer_norm=False, dim=channel, affine=not non_affine_group_norm + ), + activation, + ) + + assert len(label_rate) == 2, "label_rate should be sized two to apply fusion" + # Lout =(Lin~H~R1)~Wstride~H~R2~Wpadding+dilation~W(kernel_size~H~R1)+output_padding+1 + self.upsample_conv = upsample_block(channels, k, label_rate[0]) + + self.upsample_rate, downsample_rate = label_rate + assert downsample_rate == 1, "must be 1 to perform downsample only" + self.log_compression = log_compression + self.skip_connections = skip_connections + self.highway = highway # Useless + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x, padding=None, mask_indices=None): + # Assume x1 = (B, T, C) as input + x = x.permute(0, 2, 1) + residual_before_upsample = x + x = self.upsample_conv(x) + upsample_size = x.size(2) + + # conduct upsample + if self.skip_connections: + residual_upsample = torch.repeat_interleave( + residual_before_upsample, self.upsample_rate, dim=2 + ) + upsample_size = min(upsample_size, residual_upsample.size(2)) + x = ( + x[..., :upsample_size] + residual_upsample[..., :upsample_size] + ) * self.residual_scale + + if self.log_compression: + x = x.abs() + x = x + 1 + x = x.log() + + x = x.permute(0, 2, 1) + + # process padding + if padding is not None: + padding = torch.repeat_interleave(padding, self.upsample_rate, dim=1) + padding = padding[..., : x.size(1)] + + # process mask indices + if mask_indices is not None: + mask_indices = torch.repeat_interleave( + mask_indices, self.upsample_rate, dim=1 + ) + mask_indices = mask_indices[..., : x.size(1)] + return x, padding, mask_indices diff --git a/fairseq/models/multires_hubert/multires_hubert_asr.py b/fairseq/models/multires_hubert/multires_hubert_asr.py new file mode 100644 index 00000000..2e7ad99c --- /dev/null +++ b/fairseq/models/multires_hubert/multires_hubert_asr.py @@ -0,0 +1,376 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
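The ConvAdapter, ConvDownsampler and ConvUpsampler above resample the feature sequence together with its padding mask and mask indices, so all three stay time-aligned across resolutions. A minimal sketch of the expected shapes (illustrative values only; it assumes the classes are importable exactly as defined above):

import torch
import torch.nn as nn

from fairseq.models.multires_hubert.multires_hubert import ConvAdapter

# label_rate=(1, 2): upsample by 1 (identity rate), then stride-2 downsample
adapter = ConvAdapter(
    k=3, label_rate=(1, 2), dropout=0.0, channels=768, activation=nn.GELU()
)
x = torch.randn(2, 100, 768)  # (B, T, C) features, e.g. at 50 Hz
y, pad, mask = adapter(x, padding=None, mask_indices=None)
print(y.shape)  # torch.Size([2, 50, 768]): roughly T * 1 // 2 frames, i.e. 25 Hz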
+ +import contextlib +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.nn as nn +from omegaconf import II, MISSING + +from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model +from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES +from fairseq.tasks import FairseqTask + + +@dataclass +class MultiresHubertAsrConfig(FairseqDataclass): + multires_hubert_path: str = field( + default=MISSING, metadata={"help": "path to multires_hubert model"} + ) + no_pretrained_weights: bool = field( + default=False, + metadata={"help": "if true, does not load pretrained weights"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "dropout after transformer and before final projection"}, + ) + dropout: float = field( + default=0.0, + metadata={"help": "dropout probability inside hubert model"}, + ) + attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights " "inside hubert model" + }, + ) + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " "inside hubert model" + }, + ) + + # masking + apply_mask: bool = field( + default=False, metadata={"help": "apply masking during fine-tuning"} + ) + mask_length: int = field( + default=10, metadata={"help": "repeat the mask indices multiple times"} + ) + mask_prob: float = field( + default=0.5, + metadata={ + "help": "probability of replacing a token with mask " + "(normalized by length)" + }, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose masks"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + freeze_finetune_updates: int = field( + default=0, + metadata={"help": "dont finetune hubert for this many updates"}, + ) + feature_grad_mult: float = field( + default=0.0, + metadata={"help": "reset feature grad mult in hubert to this"}, + ) + layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a layer in hubert"}, + ) + normalize: bool = II("task.normalize") + data: str = II("task.data") + + # this holds the loaded hubert args + multires_hubert_args: Any = None + + 
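For orientation, one hypothetical way this dataclass could be populated for fine-tuning (all values below are placeholders, not recommended settings; in the released recipes these fields are filled from the Hydra configs under examples/mr_hubert/config/finetune):

cfg = MultiresHubertAsrConfig(
    multires_hubert_path="/path/to/mrhubert_mono_base.pt",  # placeholder checkpoint path
    apply_mask=True,
    mask_prob=0.65,
    mask_channel_prob=0.5,
    freeze_finetune_updates=10000,
    feature_grad_mult=0.0,
)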
+@dataclass +class MultiresHubertCtcConfig(MultiresHubertAsrConfig): + pass + + +@register_model("multires_hubert_ctc", dataclass=MultiresHubertAsrConfig) +class MultiresHubertCtc(BaseFairseqModel): + def __init__( + self, cfg: MultiresHubertAsrConfig, multireshubert_encoder: BaseFairseqModel + ): + super().__init__() + self.cfg = cfg + self.multireshubert_encoder = multireshubert_encoder + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: MultiresHubertAsrConfig, task: FairseqTask): + """Build a new model instance.""" + multireshubert_encoder = MultiresHubertEncoder(cfg, task) + return cls(cfg, multireshubert_encoder) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = net_output["encoder_out"] + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def get_logits(self, net_output): + logits = net_output["encoder_out"] + padding = net_output["encoder_padding_mask"] + if padding is not None and padding.any(): + padding = padding.T + logits[padding][..., 0] = 0 + logits[padding][..., 1:] = float("-inf") + + return logits + + def forward(self, **kwargs): + x = self.multireshubert_encoder(**kwargs) + return x + + +@dataclass +class MultiresHubertSeq2SeqConfig(MultiresHubertAsrConfig): + decoder_embed_dim: int = field( + default=768, metadata={"help": "decoder embedding dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"}) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "decoder layerdrop chance"} + ) + decoder_attention_heads: int = field( + default=4, metadata={"help": "num decoder attention heads"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_normalize_before: bool = field( + default=False, + metadata={"help": "apply layernorm before each decoder block"}, + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings " "(outside self attention)" + }, + ) + decoder_dropout: float = field( + default=0.0, metadata={"help": "dropout probability in the decoder"} + ) + decoder_attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights " "inside the decoder" + }, + ) + decoder_activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " "inside the decoder" + }, + ) + max_target_positions: int = field( + default=2048, metadata={"help": "max target positions"} + ) + share_decoder_input_output_embed: bool = field( + default=False, + metadata={"help": "share decoder input and output embeddings"}, + ) + + +class MultiresHubertEncoder(FairseqEncoder): + def __init__(self, cfg: MultiresHubertAsrConfig, task): + self.apply_mask = cfg.apply_mask + + arg_overrides = { + "dropout": cfg.dropout, + "activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "mask_selection": cfg.mask_selection, + "mask_other": 
cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, + } + + if cfg.multires_hubert_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu( + cfg.multires_hubert_path, arg_overrides + ) + multires_hubert_args = state.get("cfg", None) + if multires_hubert_args is None: + multires_hubert_args = convert_namespace_to_omegaconf(state["args"]) + cfg.multires_hubert_args = multires_hubert_args + else: + state = None + multires_hubert_args = cfg.multires_hubert_args + if isinstance(multires_hubert_args, Namespace): + cfg.multires_hubert_args = ( + multires_hubert_args + ) = convert_namespace_to_omegaconf(multires_hubert_args) + + assert cfg.normalize == multires_hubert_args.task.normalize, ( + "Fine-tuning works best when data normalization is the same. " + "Please check that --normalize is set or unset for " + "both pre-training and here" + ) + + multires_hubert_args.task.data = cfg.data + pretrain_task = tasks.setup_task(multires_hubert_args.task) + if state is not None and "task_state" in state: + # This will load the stored "dictionaries" object + pretrain_task.load_state_dict(state["task_state"]) + else: + pretrain_task.load_state_dict(task.state_dict()) + + model = pretrain_task.build_model( + multires_hubert_args.model, from_checkpoint=True + ) + if state is not None and not cfg.no_pretrained_weights: + # set strict=False because we omit some modules + model.load_state_dict(state["model"], strict=False) + + model.remove_pretraining_modules() + + super().__init__(pretrain_task.source_dictionary) + + d = multires_hubert_args.model.encoder_embed_dim + + self.multires_hubert_model = model + + self.final_dropout = nn.Dropout(cfg.final_dropout) + self.freeze_finetune_updates = cfg.freeze_finetune_updates + self.num_updates = 0 + + if task.target_dictionary is not None: + self.proj = Linear(d, len(task.target_dictionary)) + elif getattr(cfg, "decoder_embed_dim", d) != d: + self.proj = Linear(d, cfg.decoder_embed_dim) + else: + self.proj = None + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward(self, source, padding_mask, tbc=True, **kwargs): + multires_hubert_args = { + "source": source, + "padding_mask": padding_mask, + "mask": self.apply_mask and self.training, + "last_layer": True, + } + + ft = self.freeze_finetune_updates <= self.num_updates + + with torch.no_grad() if not ft else contextlib.ExitStack(): + x, padding_mask = self.multires_hubert_model.extract_features( + **multires_hubert_args + ) + + if tbc: + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + x = self.final_dropout(x) + + if self.proj: + x = self.proj(x) + + return { + "encoder_out": x, # T x B x C + "encoder_padding_mask": padding_mask, # B x T + "padding_mask": padding_mask, + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + return 
encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index f8dc6a8f..0faba77f 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -1009,7 +1009,7 @@ class TransformerEncoder(nn.Module): layer = checkpoint_wrapper(layer) return layer - def __init__(self, args: Wav2Vec2Config): + def __init__(self, args: Wav2Vec2Config, skip_pos_conv: bool = False, override_encoder_layer: int = None): super().__init__() self.dropout = args.dropout @@ -1045,7 +1045,8 @@ class TransformerEncoder(nn.Module): self.pos_conv = make_conv_block( self.embedding_dim, k, args.conv_pos_groups, num_layers ) - + elif skip_pos_conv: + self.pos_conv = None else: self.pos_conv = make_conv_pos( self.embedding_dim, @@ -1056,8 +1057,13 @@ class TransformerEncoder(nn.Module): else False, ) + if override_encoder_layer is None: + encoder_layers = args.encoder_layers + else: + encoder_layers = override_encoder_layer + self.layers = nn.ModuleList( - [self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)] + [self.build_encoder_layer(args, layer_idx=ii) for ii in range(encoder_layers)] ) self.layer_norm_first = args.layer_norm_first self.layer_norm = LayerNorm(self.embedding_dim) @@ -1087,9 +1093,10 @@ class TransformerEncoder(nn.Module): if padding_mask is not None: x = index_put(x, padding_mask, 0) - x_conv = self.pos_conv(x.transpose(1, 2)) - x_conv = x_conv.transpose(1, 2) - x = x + x_conv + if self.pos_conv is not None: + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv if not self.layer_norm_first: x = self.layer_norm(x) diff --git a/fairseq/tasks/multires_hubert_pretraining.py b/fairseq/tasks/multires_hubert_pretraining.py new file mode 100644 index 00000000..cfed147c --- /dev/null +++ b/fairseq/tasks/multires_hubert_pretraining.py @@ -0,0 +1,204 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
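The wav2vec2.py change above is what lets MR-HuBERT build per-resolution Transformer blocks: skip_pos_conv drops the positional convolution in every block except the first, and override_encoder_layer gives each block its own depth. A minimal sketch of the two new keyword arguments (the default Wav2Vec2Config is used here only for illustration; MR-HuBERT passes its own config):

from fairseq.models.wav2vec.wav2vec2 import TransformerEncoder, Wav2Vec2Config

cfg = Wav2Vec2Config()
first = TransformerEncoder(cfg, override_encoder_layer=4)  # keeps pos_conv
inner = TransformerEncoder(cfg, skip_pos_conv=True, override_encoder_layer=4)
assert len(first.layers) == len(inner.layers) == 4
assert inner.pos_conv is None  # positional conv is skipped in the forward pass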
+ +import logging +import os +import sys +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from dataclasses import dataclass, field +from fairseq.data import Dictionary, HubertDataset +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary: Dictionary) -> None: + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + return self.dictionary.encode_line( + label, + append_eos=False, + add_if_not_exist=False, + ) + + +@dataclass +class MultiresHubertPretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + fine_tuning: bool = field( + default=False, metadata={"help": "set to true if fine-tuning Hubert"} + ) + labels: List[str] = field( + default_factory=lambda: ["ltr50", "ltr25"], + metadata={ + "help": ( + "extension of the label files to load, frame-level labels for" + " pre-training, and sequence-level label for fine-tuning" + ) + }, + ) + label_dir: Optional[str] = field( + default=None, + metadata={ + "help": "if set, looks for labels in this directory instead", + }, + ) + label_rate: float = field( + default=-1.0, + metadata={"help": "label frame rate. -1.0 for sequence label"}, + ) + # label_rate: 1,2,2,5 + # (imply (1,2), (2,5)) + # if base label_rate = 50 + # (1,2), (2,5) --> label rates 50, 25, 10 + label_rate_ratios: List[int] = field(default=MISSING, metadata={"help": "tuple for label rates e.g., [(1,2), (2,5)]"}) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_keep_size: Optional[int] = field( + default=None, + metadata={"help": "exclude sample longer than this"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to crop to for batching"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to crop to for batching"}, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + + +@register_task("multires_hubert_pretraining", dataclass=MultiresHubertPretrainingConfig) +class MultiresHubertPretrainingTask(FairseqTask): + """ + Multiresolution HuBERT Pretraining Task. + The task is based on `HubertPretrainingTask` but extended to multiresolution. 
+    """
+
+    cfg: MultiresHubertPretrainingConfig
+
+    def __init__(
+        self,
+        cfg: MultiresHubertPretrainingConfig,
+    ) -> None:
+        super().__init__(cfg)
+
+        logger.info(f"current directory is {os.getcwd()}")
+        logger.info(f"MultiresHubertPretrainingTask Config {cfg}")
+
+        self.cfg = cfg
+        self.fine_tuning = cfg.fine_tuning
+
+        if cfg.fine_tuning:
+            self.state.add_factory("target_dictionary", self.load_dictionaries)
+            self.res_number = 1
+        else:
+            self.state.add_factory("dictionaries", self.load_dictionaries)
+
+        self.blank_symbol = "<s>"
+
+    @property
+    def source_dictionary(self) -> Optional[Dictionary]:
+        return None
+
+    @property
+    def target_dictionary(self) -> Optional[Dictionary]:
+        return self.state.target_dictionary
+
+    @property
+    def dictionaries(self) -> List[Dictionary]:
+        return self.state.dictionaries
+
+    @classmethod
+    def setup_task(
+        cls, cfg: MultiresHubertPretrainingConfig, **kwargs
+    ) -> "MultiresHubertPretrainingTask":
+        return cls(cfg)
+
+    def load_dictionaries(self):
+        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
+        self.res_number = len(label_dir)
+        dictionaries = [
+            Dictionary.load(f"{label_dir}/dict.{label}.txt") if label != "" else None
+            for label in self.cfg.labels
+        ]
+        return dictionaries[0] if self.cfg.fine_tuning else dictionaries
+
+    def get_label_dir(self) -> str:
+        if self.cfg.label_dir is None:
+            return self.cfg.data
+        return self.cfg.label_dir
+
+    def load_dataset(self, split: str, **kwargs) -> None:
+        manifest = f"{self.cfg.data}/{split}.tsv"
+        dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
+        pad_list = [(dict.pad() if dict is not None else None) for dict in dicts]
+        eos_list = [(dict.eos() if dict is not None else None) for dict in dicts]
+        procs = [LabelEncoder(dict) for dict in dicts]
+        paths = [
+            f"{self.get_label_dir()}/{split}.{l}" if l != "" else None
+            for l in self.cfg.labels
+        ]
+
+        base_rate = self.cfg.label_rate
+        self.label_rates = [base_rate]
+        label_rate_ratios = self.cfg.label_rate_ratios
+        self.label_rate_ratios = []
+        for i in range(len(label_rate_ratios) // 2):
+            # parse label rate ratios
+            upsample_rate, downsample_rate = (
+                label_rate_ratios[i * 2],
+                label_rate_ratios[i * 2 + 1],
+            )
+            self.label_rate_ratios.append((upsample_rate, downsample_rate))
+            base_rate = base_rate * upsample_rate // downsample_rate
+            self.label_rates.append(base_rate)
+
+        # hubert v1: pad_audio=True, random_crop=False;
+        self.datasets[split] = HubertDataset(
+            manifest,
+            sample_rate=self.cfg.sample_rate,
+            label_paths=paths,
+            label_rates=self.label_rates,
+            pad_list=pad_list,
+            eos_list=eos_list,
+            label_processors=procs,
+            max_keep_sample_size=self.cfg.max_keep_size,
+            min_keep_sample_size=self.cfg.min_sample_size,
+            max_sample_size=self.cfg.max_sample_size,
+            pad_audio=self.cfg.pad_audio,
+            normalize=self.cfg.normalize,
+            store_labels=False,
+            random_crop=self.cfg.random_crop,
+        )
+
+    def max_positions(self) -> Tuple[int, int]:
+        return (sys.maxsize, sys.maxsize)
+
+    def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
+        return indices
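A worked example of the label-rate bookkeeping shared by this task and MultiresHubertModel (a standalone sketch; it assumes the default 16 kHz front end with a total convolutional stride of 320, i.e. 50 Hz base features, and label_rate_ratios = [1, 2], matching the running example in the model code):

sample_rate = 16_000
base_ds_rate = 320          # product of the feature-extractor strides
label_rate = 50.0           # frame rate of the first label set (Hz)
label_rate_ratios = [1, 2]  # flattened (upsample, downsample) pairs

label_rates = [label_rate]
feature_ds_rates = [base_ds_rate]
for i in range(len(label_rate_ratios) // 2):
    up, down = label_rate_ratios[2 * i], label_rate_ratios[2 * i + 1]
    feature_ds_rates.append(feature_ds_rates[-1] * down // up)
    label_rates.append(label_rates[-1] * up / down)

feat2tar_ratios = [d * r / sample_rate for d, r in zip(feature_ds_rates, label_rates)]
print(feature_ds_rates, label_rates, feat2tar_ratios)
# [320, 640] [50.0, 25.0] [1.0, 1.0]: one label per feature frame at both resolutions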