mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
Merge branch 'main' into wav2vec_readme_update
This commit is contained in:
commit
41035ae21c
@ -1,128 +0,0 @@
|
||||
# Use 2.1 for orbs
|
||||
version: 2.1
|
||||
|
||||
# -------------------------------------------------------------------------------------
|
||||
# Environments to run the jobs in
|
||||
# -------------------------------------------------------------------------------------
|
||||
gpu: &gpu
|
||||
environment:
|
||||
CUDA_VERSION: "11.2"
|
||||
machine:
|
||||
image: ubuntu-2004-cuda-11.2:202103-01
|
||||
resource_class: gpu.nvidia.medium.multi
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------------
|
||||
# Re-usable commands
|
||||
# -------------------------------------------------------------------------------------
|
||||
cache_key: &cache_key cache-key-{{ .Environment.CIRCLE_JOB }}-{{ checksum ".circleci/config.yml" }}-{{ checksum "setup.py"}}
|
||||
|
||||
install_dep_pt1_10: &install_dep_pt1_10
|
||||
- run:
|
||||
name: Install Pytorch Dependencies
|
||||
command: |
|
||||
source activate fairseq
|
||||
pip install --upgrade setuptools
|
||||
pip install torch==1.10.1+cu111 torchaudio==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
python -c 'import torch; print("Torch version:", torch.__version__)'
|
||||
|
||||
install_dep_pt1_12: &install_dep_pt1_12
|
||||
- run:
|
||||
name: Install Pytorch Dependencies
|
||||
command: |
|
||||
source activate fairseq
|
||||
pip install --upgrade setuptools
|
||||
pip install torch==1.12.1+cu116 torchaudio==0.12.1+cu116 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
python -c 'import torch; print("Torch version:", torch.__version__)'
|
||||
|
||||
install_repo: &install_repo
|
||||
- run:
|
||||
name: Install Repository
|
||||
command: |
|
||||
source activate fairseq
|
||||
python -m pip install fairscale
|
||||
python -m pip install -e '.[dev,docs]'
|
||||
python -c 'import torch; print("Torch version:", torch.__version__)'
|
||||
|
||||
run_unittests: &run_unittests
|
||||
- run:
|
||||
name: Run Unit Tests
|
||||
command: |
|
||||
source activate fairseq
|
||||
pytest tests/gpu/test_binaries_gpu.py
|
||||
|
||||
check_nvidia_driver: &check_nvidia_driver
|
||||
- run:
|
||||
name: Check NVIDIA Driver
|
||||
working_directory: ~/
|
||||
command: |
|
||||
pyenv versions
|
||||
nvidia-smi
|
||||
|
||||
create_conda_env: &create_conda_env
|
||||
- run:
|
||||
name: Install and Create Conda Environment
|
||||
command: |
|
||||
curl -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
||||
chmod +x ~/miniconda.sh
|
||||
bash ~/miniconda.sh -b -p $HOME/miniconda
|
||||
rm ~/miniconda.sh
|
||||
echo 'export PATH=$HOME/miniconda/bin:$PATH' >> $BASH_ENV
|
||||
source $BASH_ENV
|
||||
if [ ! -d ~/miniconda/envs/fairseq ]
|
||||
then
|
||||
conda create -y -n fairseq python=3.8
|
||||
fi
|
||||
source activate fairseq
|
||||
python --version
|
||||
pip install --upgrade pip
|
||||
# -------------------------------------------------------------------------------------
|
||||
# Jobs to run
|
||||
# -------------------------------------------------------------------------------------
|
||||
|
||||
jobs:
|
||||
|
||||
gpu_tests_pt1_10:
|
||||
<<: *gpu
|
||||
|
||||
working_directory: ~/fairseq-py
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
- <<: *check_nvidia_driver
|
||||
- <<: *create_conda_env
|
||||
- restore_cache:
|
||||
key: *cache_key
|
||||
- <<: *install_dep_pt1_10
|
||||
- save_cache:
|
||||
paths:
|
||||
- ~/miniconda/
|
||||
key: *cache_key
|
||||
- <<: *install_repo
|
||||
- <<: *run_unittests
|
||||
|
||||
gpu_tests_pt1_12:
|
||||
<<: *gpu
|
||||
|
||||
working_directory: ~/fairseq-py
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
- <<: *check_nvidia_driver
|
||||
- <<: *create_conda_env
|
||||
- restore_cache:
|
||||
key: *cache_key
|
||||
- <<: *install_dep_pt1_12
|
||||
- save_cache:
|
||||
paths:
|
||||
- ~/miniconda/
|
||||
key: *cache_key
|
||||
- <<: *install_repo
|
||||
- <<: *run_unittests
|
||||
|
||||
workflows:
|
||||
version: 2
|
||||
build:
|
||||
jobs:
|
||||
- gpu_tests_pt1_12
|
||||
- gpu_tests_pt1_10
|
14
.github/workflows/depreview.yml
vendored
Normal file
14
.github/workflows/depreview.yml
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
name: 'Dependency Review'
|
||||
on: [pull_request]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
dependency-review:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: 'Checkout Repository'
|
||||
uses: actions/checkout@v4
|
||||
- name: Dependency Review
|
||||
uses: actions/dependency-review-action@v4
|
@ -2,7 +2,7 @@
|
||||
|
||||
## Model details
|
||||
|
||||
**Organization developing the model** The FAIR team of Meta AI.
|
||||
**Organization developing the model** The FAIR team
|
||||
|
||||
**Model version** This is version 1 of the model.
|
||||
|
||||
|
@ -177,6 +177,10 @@ We also provide an Ipython notebook example inside `lid/tutorial` folder [ipynb]
|
||||
|
||||
MMS Adapter fine-tuning has been added to the official 🤗 Transformers examples [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition#connectionist-temporal-classification-with-adapters).
|
||||
For a more step-by-step explanation of how to fine-tune MMS, please have a look at the blog [**Fine-tuning MMS Adapter Models for Multi-Lingual ASR**](https://huggingface.co/blog/mms_adapters) on 🤗 blogs.
|
||||
|
||||
### TTS
|
||||
|
||||
For a guide on how to fine-tune MMS TTS checkpoints using the 🤗 Transformer implementation, please have a look at this [repository](https://github.com/ylacombe/finetune-hf-vits).
|
||||
|
||||
## Pretrained models
|
||||
|
||||
|
@ -14,8 +14,8 @@ We describe the process of aligning long audio files with their transcripts and
|
||||
|
||||
- Step 3: Install a few other dependencies
|
||||
```
|
||||
pip install sox
|
||||
pip install dataclasses
|
||||
apt install sox
|
||||
pip install sox dataclasses
|
||||
```
|
||||
|
||||
- Step 4: Create a text file containing the transcript for a (long) audio file. Each line in the text file will correspond to a separate audio segment that will be generated upon alignment.
|
||||
@ -29,7 +29,7 @@ We describe the process of aligning long audio files with their transcripts and
|
||||
|
||||
- Step 5: Run forced alignment and segment the audio file into shorter segments.
|
||||
```
|
||||
python align_and_segment.py --audio /path/to/audio.wav --textfile /path/to/textfile --lang <iso> --outdir /path/to/output --uroman /path/to/uroman/bin
|
||||
python align_and_segment.py --audio /path/to/audio.wav --text_filepath /path/to/textfile --lang <iso> --outdir /path/to/output --uroman /path/to/uroman/bin
|
||||
```
|
||||
|
||||
The above code will generated the audio segments under output directory based on the content of each line in the input text file. The `manifest.json` file consisting of the of segmented audio filepaths and their corresponding transcripts.
|
||||
|
@ -87,13 +87,14 @@ def get_alignments(
|
||||
blank = dictionary["<blank>"]
|
||||
|
||||
targets = torch.tensor(token_indices, dtype=torch.int32).to(DEVICE)
|
||||
input_lengths = torch.tensor(emissions.shape[0])
|
||||
target_lengths = torch.tensor(targets.shape[0])
|
||||
|
||||
|
||||
input_lengths = torch.tensor(emissions.shape[0]).unsqueeze(-1)
|
||||
target_lengths = torch.tensor(targets.shape[0]).unsqueeze(-1)
|
||||
path, _ = F.forced_align(
|
||||
emissions, targets, input_lengths, target_lengths, blank=blank
|
||||
emissions.unsqueeze(0), targets.unsqueeze(0), input_lengths, target_lengths, blank=blank
|
||||
)
|
||||
path = path.to("cpu").tolist()
|
||||
path = path.squeeze().to("cpu").tolist()
|
||||
|
||||
segments = merge_repeats(path, {v: k for k, v in dictionary.items()})
|
||||
return segments, stride
|
||||
|
||||
|
@ -6,12 +6,12 @@ We follow the recommendations of Gebru et al. (2018) and provide a datacard for
|
||||
## Motivation
|
||||
* **For what purpose was the dataset created? Was there a specific task in mind? Was there a specific gap that needed to be filled? Please provide a description.**
|
||||
The pre-training data for training the 1.1 T model was created by a union of six English language datasets, including five datasets used by RoBERTa (Liu et al 2019) and the English subset of CC 100. These purpose of creating this dataset was to pre-train the language model.
|
||||
|
||||
|
||||
* **Who created the dataset (e.g., which team, research group) and on behalf of which entity (e.g., company, institution, organization)?**
|
||||
Meta AI.
|
||||
|
||||
FAIR (Fundamental Artificial Intelligence Research)
|
||||
|
||||
* **Who funded the creation of the dataset? If there is an associated grant, please provide the name of the grantor and the grant name and number.**
|
||||
Meta AI.
|
||||
FAIR (Fundamental Artificial Intelligence Research)
|
||||
|
||||
* **Any other comments?**
|
||||
No.
|
||||
@ -183,7 +183,7 @@ No.
|
||||
## Maintenance
|
||||
|
||||
* **Who is supporting/hosting/maintaining the dataset?**
|
||||
Meta AI.
|
||||
FAIR (Fundamental Artificial Intelligence Research)
|
||||
|
||||
* **How can the owner/curator/manager of the dataset be contacted (e.g., email address)?**
|
||||
Refer to the main document.
|
||||
|
@ -2,7 +2,7 @@
|
||||
## Version 1.0.0
|
||||
|
||||
### Model developer
|
||||
Meta AI
|
||||
FAIR (Fundamental Artificial Intelligence Research)
|
||||
|
||||
### Model type
|
||||
An autoregressive English language model trained on a union of six English language models. We explore dense and sparse (MoE based) architectures in the paper.
|
||||
@ -132,7 +132,7 @@ A dataset extracted from CommonCrawl snapshots between January 2018 and December
|
||||
The 1.1T parameter model was evaluated on the StereoSet and CrowS pairs dataset for inherent bias in the model, and bias as a result of the data. Similar to StereoSet, we observe that both the dense and MoE models get worse in terms of the Stereotype Score (SS) with scale.
|
||||
|
||||
### Privacy and security
|
||||
The 1.1T model did not have any special Privacy and Security considerations. The training data and evaluation data were both public and went through standard Meta AI Privacy and licensing procedures.
|
||||
The 1.1T model did not have any special Privacy and Security considerations. The training data and evaluation data were both public and went through standard Meta privacy and licensing procedures.
|
||||
|
||||
### Transparency and control
|
||||
In the spirit of transparency and accountability we have created this model card for the 1.1T parameter model and a data card for the training data (referenced in Artetxe et al. (2021)).
|
||||
|
187
examples/mr_hubert/README.md
Normal file
187
examples/mr_hubert/README.md
Normal file
@ -0,0 +1,187 @@
|
||||
# MR-HuBERT
|
||||
|
||||
## Pre-trained models
|
||||
|
||||
### Main models
|
||||
Model | Pretraining Data | Model | Paper Reference
|
||||
|---|---|---|---
|
||||
MR-HuBERT Base (~97M) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_base/mrhubert_mono_base.pt) | mono\_base
|
||||
MR-HuBERT Base (~321M) | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_large/mrhubert_mono_large.pt) | mono\_large
|
||||
Multilingual MR-HuBERT Base (~97M) | [Voxpopuli](https://github.com/facebookresearch/voxpopuli) 100k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/multi_base/multi_base.pt) | multi\_base
|
||||
Multilingual MR-HuBERT Large (~321M) | [Voxpopuli](https://github.com/facebookresearch/voxpopuli) 100k hr | [download 400k steps](https://dl.fbaipublicfiles.com/mrhubert/multi_large/multi_large_400k.pt) or [download 600k steps](https://dl.fbaipublicfiles.com/mrhubert/multi_large/multi_large_600k.pt) | Not in the paper
|
||||
|
||||
|
||||
### Abalation models
|
||||
Model | Pretraining Data | Model | Paper Reference
|
||||
|---|---|---|---
|
||||
MR-HuBERT Base (2-4-6 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-a/b1-a.pt) | (B.1)-a
|
||||
MR-HuBERT Base (5-2-5 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-b/b1-b.pt) | (B.1)-b
|
||||
MR-HuBERT Base (6-4-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-c/b1-c.pt) | (B.1)-c
|
||||
MR-HuBERT Base (3res 3-2-2-2-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-a/b2-a.pt) | (B.2)-a
|
||||
MR-HuBERT Base (3res 2-2-4-2-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-b/b2-b.pt) | (B.2)-b
|
||||
MR-HuBERT Base (3res 2-2-2-2-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-c/b2-c.pt) | (B.2)-c
|
||||
MR-HuBERT Base (Simple sampling) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b3-a/b3-a.pt) | (B.3)-a
|
||||
MR-HuBERT Base (Single target) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b4-a/b4-a.pt) | (B.4)-a
|
||||
MR-HuBERT Base (Simple Sampling + single target) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b4-b/b4-b.pt) | (B.4)-b
|
||||
MR-HuBERT Base (Mono-resolution 20ms) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b5-a/b5-a.pt) | (B.5)-a
|
||||
MR-HuBERT Base (3-3-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b6-a/b6-a.pt) | (B.6)-a
|
||||
MR-HuBERT Base (Mono-resolution 20ms, 3-3-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b6-b/b6-b.pt) | (B.6)-b
|
||||
MR-HuBERT Base (HuBERT 20ms&40ms units) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-a/b7-a.pt) | (B.7)-a
|
||||
MR-HuBERT Base (Encodec 50Hz unit) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-b/b7-b.pt) | (B.7)-b
|
||||
MR-HuBERT Base (Encodec 50Hz units and 25Hz units) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-c/b7-c.pt) | (B.7)-c
|
||||
MR-HuBERT Base (Encodec 50Hz units stream 0&1 ) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-d/b7-d.pt) | (B.7)-d
|
||||
MR-HuBERT Large (no audio norm) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-a.pt) | (B.8)-a
|
||||
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-b/b8-b.pt) | (B.8)-b
|
||||
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-c/b8-c.pt) | (B.8)-c
|
||||
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-d/b8-d.pt) | (B.8)-d
|
||||
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-e/b8-e.pt) | (B.8)-e
|
||||
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-f/b8-f.pt) | (B.8)-f
|
||||
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-g/b8-g.pt) | (B.8)-g
|
||||
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-h/b8-h.pt) | (B.8)-h
|
||||
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-i/b8-i.pt) | (B.8)-i
|
||||
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-j/b8-j.pt) | (B.8)-j
|
||||
Multilingual MR-HuBERT Large (Simple sampling) | [Voxpopuli](https://github.com/facebookresearch/voxpopuli) 100k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/multi_large_simple/multi_large_simple.pt) | Not in paper
|
||||
MR-HuBERT xLarge (from HuBERT-base label) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_xlarge/v1.pt) | Not in paper
|
||||
MR-HuBERT xLarge (from HuBERT-large label) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_xlarge/v2.pt) | Not in paper
|
||||
|
||||
## Load a model
|
||||
```
|
||||
ckpt_path = "/path/to/the/checkpoint.pt"
|
||||
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
|
||||
model = models[0]
|
||||
```
|
||||
|
||||
## Train a new model
|
||||
|
||||
### Data preparation
|
||||
|
||||
Follow the steps in `./simple_kmeans` to create:
|
||||
- `{train,valid}.tsv` waveform list files with length information
|
||||
```
|
||||
/path/to/your/audio/files
|
||||
file1.wav\t160000
|
||||
file2.wav\t154600
|
||||
...
|
||||
filen.wav\t54362
|
||||
```
|
||||
- `{train,valid}.km` frame-aligned pseudo label files (the order is the same as wavefiles in the tsv file).
|
||||
```
|
||||
44 44 44 48 48 962 962 962 962 962 962 962 962 967 967 967 967 967 967 967 967 370 852 370 ... 18 18 745 745
|
||||
44 44 44 48 48 962 962 962 147 147 147 147 147 147 147 147 147 147 147 147 176 176 271 271 ... 27 27 745 745
|
||||
...
|
||||
44 44 44 48 962 962 962 962 962 962 377 377 377 77 77 852 696 694 433 578 578 82 740 622 ... 27 27 745 745
|
||||
```
|
||||
- `dict.km.txt` a dummy dictionary (first column is id, the second is dummy one)
|
||||
```
|
||||
0 1
|
||||
1 1
|
||||
2 1
|
||||
...
|
||||
999 1
|
||||
```
|
||||
|
||||
The `label_rate` is the same as the feature frame rate used for clustering,
|
||||
which is 100Hz for MFCC features and 50Hz for HuBERT features by default.
|
||||
|
||||
### Pre-train a MR-HuBERT model
|
||||
|
||||
Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.km`
|
||||
are saved at `/path/to/labels`, and the label rate is 100Hz.
|
||||
|
||||
To train a base model (12 layer transformer), run:
|
||||
```sh
|
||||
$ python fairseq_cli/hydra_train.py \
|
||||
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/pretrain \
|
||||
--config-name mrhubert_base_librispeech \
|
||||
task.data=/path/to/data task.label_dir=/path/to/labels \
|
||||
task.labels='["km"]' model.label_rate=100 \
|
||||
task.label_rate_ratios='[1, 2]' \
|
||||
```
|
||||
|
||||
Please see sample pre-training scripts `train.sh` for an example script.
|
||||
|
||||
### Fine-tune a MR-HuBERT model with a CTC loss
|
||||
|
||||
Suppose `{train,valid}.tsv` are saved at `/path/to/data`, and their
|
||||
corresponding character transcripts `{train,valid}.ltr` are saved at
|
||||
`/path/to/trans`. A typical ltr file is with the same order of tsv waveform files as
|
||||
```
|
||||
HOW | ARE | YOU
|
||||
...
|
||||
THANK | YOU
|
||||
```
|
||||
|
||||
To fine-tune a pre-trained MR-HuBERT model at `/path/to/checkpoint`, run
|
||||
```sh
|
||||
$ python fairseq_cli/hydra_train.py \
|
||||
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/finetune \
|
||||
--config-name base_10h \
|
||||
task.data=/path/to/data task.label_dir=/path/to/trans \
|
||||
model.w2v_path=/path/to/checkpoint
|
||||
```
|
||||
|
||||
Please see sample fine-tuning scripts `finetune.sh` for an example script.
|
||||
|
||||
### Decode a MR-HuBERT model
|
||||
|
||||
Suppose the `test.tsv` and `test.ltr` are the waveform list and transcripts of
|
||||
the split to be decoded, saved at `/path/to/data`, and the fine-tuned model is
|
||||
saved at `/path/to/checkpoint`.
|
||||
|
||||
|
||||
We support three decoding modes:
|
||||
- Viterbi decoding: greedy decoding without a language model
|
||||
- KenLM decoding: decoding with an arpa-format KenLM n-gram language model
|
||||
- Fairseq-LM deocding: decoding with a Fairseq neural language model (not fully tested)
|
||||
|
||||
|
||||
#### Viterbi decoding
|
||||
|
||||
`task.normalize` needs to be consistent with the value used during fine-tuning.
|
||||
Decoding results will be saved at
|
||||
`/path/to/experiment/directory/decode/viterbi/test`.
|
||||
|
||||
```sh
|
||||
$ python examples/speech_recognition/new/infer.py \
|
||||
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \
|
||||
--config-name infer \
|
||||
task.data=/path/to/data \
|
||||
task.normalize=[true|false] \
|
||||
decoding.exp_dir=/path/to/experiment/directory \
|
||||
common_eval.path=/path/to/checkpoint
|
||||
dataset.gen_subset=test \
|
||||
```
|
||||
|
||||
#### KenLM / Fairseq-LM decoding
|
||||
|
||||
Suppose the pronunciation lexicon and the n-gram LM are saved at
|
||||
`/path/to/lexicon` and `/path/to/arpa`, respectively. Decoding results will be
|
||||
saved at `/path/to/experiment/directory/decode/kenlm/test`.
|
||||
|
||||
```sh
|
||||
$ python examples/speech_recognition/new/infer.py \
|
||||
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \
|
||||
--config-name infer_lm \
|
||||
task.data=/path/to/data \
|
||||
task.normalize=[true|false] \
|
||||
decoding.exp_dir=/path/to/experiment/directory \
|
||||
common_eval.path=/path/to/checkpoint
|
||||
dataset.gen_subset=test \
|
||||
decoding.decoder.lexicon=/path/to/lexicon \
|
||||
decoding.decoder.lmpath=/path/to/arpa
|
||||
```
|
||||
|
||||
The command above uses the default decoding hyperparameter, which can be found
|
||||
in `examples/speech_recognition/hydra/decoder.py`. These parameters can be
|
||||
configured from the command line. For example, to search with a beam size of
|
||||
500, we can append the command above with `decoding.decoder.beam=500`.
|
||||
Important parameters include:
|
||||
- decoding.decoder.beam
|
||||
- decoding.decoder.beamthreshold
|
||||
- decoding.decoder.lmweight
|
||||
- decoding.decoder.wordscore
|
||||
- decoding.decoder.silweight
|
||||
|
||||
To decode with a Fairseq LM, you may check the usage examples in wav2vec2 or hubert examples.
|
||||
|
||||
Please see sample decoding scripts `decode.sh` for an example script.
|
30
examples/mr_hubert/config/decode/infer.yaml
Normal file
30
examples/mr_hubert/config/decode/infer.yaml
Normal file
@ -0,0 +1,30 @@
|
||||
# @package _group_
|
||||
|
||||
defaults:
|
||||
- model: null
|
||||
|
||||
hydra:
|
||||
run:
|
||||
dir: ${common_eval.results_path}/viterbi
|
||||
sweep:
|
||||
dir: ${common_eval.results_path}
|
||||
subdir: viterbi
|
||||
|
||||
task:
|
||||
_name: multires_hubert_pretraining
|
||||
single_target: true
|
||||
fine_tuning: true
|
||||
label_rate_ratios: ???
|
||||
data: ???
|
||||
normalize: false
|
||||
|
||||
decoding:
|
||||
type: viterbi
|
||||
unique_wer_file: true
|
||||
common_eval:
|
||||
results_path: ???
|
||||
path: ???
|
||||
post_process: letter
|
||||
dataset:
|
||||
max_tokens: 1100000
|
||||
gen_subset: ???
|
37
examples/mr_hubert/config/decode/infer_lm.yaml
Normal file
37
examples/mr_hubert/config/decode/infer_lm.yaml
Normal file
@ -0,0 +1,37 @@
|
||||
# @package _group_
|
||||
|
||||
defaults:
|
||||
- model: null
|
||||
|
||||
hydra:
|
||||
run:
|
||||
dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
|
||||
sweep:
|
||||
dir: ${common_eval.results_path}
|
||||
subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
|
||||
|
||||
task:
|
||||
_name: multires_hubert_pretraining
|
||||
single_target: true
|
||||
fine_tuning: true
|
||||
data: ???
|
||||
label_rate_ratios: ???
|
||||
normalize: ???
|
||||
|
||||
decoding:
|
||||
type: kenlm
|
||||
lexicon: ???
|
||||
lmpath: ???
|
||||
beamthreshold: 100
|
||||
beam: 500
|
||||
lmweight: 1.5
|
||||
wordscore: -1
|
||||
silweight: 0
|
||||
unique_wer_file: true
|
||||
common_eval:
|
||||
results_path: ???
|
||||
path: ???
|
||||
post_process: letter
|
||||
dataset:
|
||||
max_tokens: 1100000
|
||||
gen_subset: ???
|
17
examples/mr_hubert/config/decode/run/submitit_slurm.yaml
Normal file
17
examples/mr_hubert/config/decode/run/submitit_slurm.yaml
Normal file
@ -0,0 +1,17 @@
|
||||
# @package _global_
|
||||
hydra:
|
||||
launcher:
|
||||
cpus_per_task: ${distributed_training.distributed_world_size}
|
||||
gpus_per_node: ${distributed_training.distributed_world_size}
|
||||
tasks_per_node: ${hydra.launcher.gpus_per_node}
|
||||
nodes: 1
|
||||
mem_gb: 200
|
||||
timeout_min: 4320
|
||||
max_num_timeout: 50
|
||||
name: ${hydra.job.config_name}
|
||||
submitit_folder: ${hydra.sweep.dir}/submitit
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 1
|
||||
distributed_no_spawn: true
|
||||
distributed_port: 29761
|
@ -0,0 +1,17 @@
|
||||
# @package _global_
|
||||
hydra:
|
||||
launcher:
|
||||
cpus_per_task: ${distributed_training.distributed_world_size}
|
||||
gpus_per_node: ${distributed_training.distributed_world_size}
|
||||
tasks_per_node: ${hydra.launcher.gpus_per_node}
|
||||
nodes: 1
|
||||
mem_gb: 200
|
||||
timeout_min: 4320
|
||||
max_num_timeout: 50
|
||||
name: ${hydra.job.config_name}
|
||||
submitit_folder: ${hydra.sweep.dir}/submitit
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 8
|
||||
distributed_no_spawn: true
|
||||
distributed_port: 29761
|
97
examples/mr_hubert/config/finetune/base_100h.yaml
Normal file
97
examples/mr_hubert/config/finetune/base_100h.yaml
Normal file
@ -0,0 +1,97 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tblog
|
||||
seed: 1337
|
||||
|
||||
checkpoint:
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: wer
|
||||
|
||||
distributed_training:
|
||||
ddp_backend: c10d
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 8
|
||||
distributed_port: 29671
|
||||
nprocs_per_node: 8
|
||||
|
||||
task:
|
||||
_name: multires_hubert_pretraining
|
||||
data: ???
|
||||
fine_tuning: true
|
||||
label_dir: ???
|
||||
label_rate_ratios: ???
|
||||
normalize: false # must be consistent with pre-training
|
||||
labels: ["ltr"]
|
||||
single_target: true
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
||||
max_tokens: 3200000
|
||||
validate_after_updates: ${model.freeze_finetune_updates}
|
||||
validate_interval: 5
|
||||
train_subset: train_100h
|
||||
valid_subset: dev_other
|
||||
|
||||
criterion:
|
||||
_name: ctc
|
||||
zero_infinity: true
|
||||
|
||||
optimization:
|
||||
max_update: 80000
|
||||
lr: [3e-5]
|
||||
sentence_avg: true
|
||||
update_freq: [1]
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-08
|
||||
|
||||
lr_scheduler:
|
||||
_name: tri_stage
|
||||
phase_ratio: [0.1, 0.4, 0.5]
|
||||
final_lr_scale: 0.05
|
||||
|
||||
model:
|
||||
_name: multires_hubert_ctc
|
||||
multires_hubert_path: ???
|
||||
apply_mask: true
|
||||
mask_selection: static
|
||||
mask_length: 10
|
||||
mask_other: 0
|
||||
mask_prob: 0.75
|
||||
mask_channel_selection: static
|
||||
mask_channel_length: 64
|
||||
mask_channel_other: 0
|
||||
mask_channel_prob: 0.5
|
||||
layerdrop: 0.1
|
||||
dropout: 0.0
|
||||
activation_dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
feature_grad_mult: 0.0
|
||||
freeze_finetune_updates: 10000
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '-'
|
||||
item_sep: '__'
|
||||
exclude_keys:
|
||||
- run
|
||||
- task.data
|
||||
- task.label_dir
|
||||
- model.multires_hubert_path
|
||||
- dataset.train_subset
|
||||
- dataset.valid_subset
|
||||
- criterion.wer_kenlm_model
|
||||
- criterion.wer_lexicon
|
||||
run:
|
||||
dir: ???
|
||||
sweep:
|
||||
dir: ???
|
||||
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
97
examples/mr_hubert/config/finetune/base_100h_large.yaml
Normal file
97
examples/mr_hubert/config/finetune/base_100h_large.yaml
Normal file
@ -0,0 +1,97 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tblog
|
||||
seed: 1337
|
||||
|
||||
checkpoint:
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: wer
|
||||
|
||||
distributed_training:
|
||||
ddp_backend: c10d
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 8
|
||||
distributed_port: 29671
|
||||
nprocs_per_node: 8
|
||||
|
||||
task:
|
||||
_name: multires_hubert_pretraining
|
||||
data: ???
|
||||
fine_tuning: true
|
||||
label_dir: ???
|
||||
label_rate_ratios: ???
|
||||
normalize: true # must be consistent with pre-training
|
||||
labels: ["ltr"]
|
||||
single_target: true
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
||||
max_tokens: 1600000
|
||||
validate_after_updates: ${model.freeze_finetune_updates}
|
||||
validate_interval: 5
|
||||
train_subset: train_100h
|
||||
valid_subset: dev_other
|
||||
|
||||
criterion:
|
||||
_name: ctc
|
||||
zero_infinity: true
|
||||
|
||||
optimization:
|
||||
max_update: 80000
|
||||
lr: [3e-5]
|
||||
sentence_avg: true
|
||||
update_freq: [2]
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-08
|
||||
|
||||
lr_scheduler:
|
||||
_name: tri_stage
|
||||
phase_ratio: [0.1, 0.4, 0.5]
|
||||
final_lr_scale: 0.05
|
||||
|
||||
model:
|
||||
_name: multires_hubert_ctc
|
||||
multires_hubert_path: ???
|
||||
apply_mask: true
|
||||
mask_selection: static
|
||||
mask_length: 10
|
||||
mask_other: 0
|
||||
mask_prob: 0.75
|
||||
mask_channel_selection: static
|
||||
mask_channel_length: 64
|
||||
mask_channel_other: 0
|
||||
mask_channel_prob: 0.5
|
||||
layerdrop: 0.1
|
||||
dropout: 0.0
|
||||
activation_dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
feature_grad_mult: 0.0
|
||||
freeze_finetune_updates: 10000
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '-'
|
||||
item_sep: '__'
|
||||
exclude_keys:
|
||||
- run
|
||||
- task.data
|
||||
- task.label_dir
|
||||
- model.multires_hubert_path
|
||||
- dataset.train_subset
|
||||
- dataset.valid_subset
|
||||
- criterion.wer_kenlm_model
|
||||
- criterion.wer_lexicon
|
||||
run:
|
||||
dir: ???
|
||||
sweep:
|
||||
dir: ???
|
||||
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
101
examples/mr_hubert/config/finetune/base_10h.yaml
Normal file
101
examples/mr_hubert/config/finetune/base_10h.yaml
Normal file
@ -0,0 +1,101 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tblog
|
||||
seed: 1337
|
||||
|
||||
checkpoint:
|
||||
save_interval: 5
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: wer
|
||||
|
||||
distributed_training:
|
||||
ddp_backend: c10d
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 8
|
||||
distributed_port: 29671
|
||||
nprocs_per_node: 8
|
||||
|
||||
task:
|
||||
_name: multires_hubert_pretraining
|
||||
data: ???
|
||||
fine_tuning: true
|
||||
label_dir: ???
|
||||
label_rate_ratios: ???
|
||||
normalize: false # must be consistent with pre-training
|
||||
labels: ["ltr"]
|
||||
single_target: true
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
||||
max_tokens: 3200000
|
||||
validate_after_updates: ${model.freeze_finetune_updates}
|
||||
validate_interval: 5
|
||||
train_subset: train_10h
|
||||
valid_subset: dev
|
||||
|
||||
criterion:
|
||||
_name: ctc
|
||||
zero_infinity: true
|
||||
|
||||
optimization:
|
||||
max_update: 25000
|
||||
lr: [2e-5]
|
||||
sentence_avg: true
|
||||
update_freq: [1]
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-08
|
||||
|
||||
lr_scheduler:
|
||||
_name: tri_stage
|
||||
warmup_steps: 8000
|
||||
hold_steps: 0
|
||||
decay_steps: 72000
|
||||
final_lr_scale: 0.05
|
||||
|
||||
model:
|
||||
_name: multires_hubert_ctc
|
||||
multires_hubert_path: ???
|
||||
apply_mask: true
|
||||
mask_selection: static
|
||||
mask_length: 10
|
||||
mask_other: 0
|
||||
mask_prob: 0.75
|
||||
mask_channel_selection: static
|
||||
mask_channel_length: 64
|
||||
mask_channel_other: 0
|
||||
mask_channel_prob: 0.5
|
||||
layerdrop: 0.1
|
||||
dropout: 0.0
|
||||
activation_dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
feature_grad_mult: 0.0
|
||||
freeze_finetune_updates: 10000
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '-'
|
||||
item_sep: '__'
|
||||
exclude_keys:
|
||||
- run
|
||||
- task.data
|
||||
- task.label_dir
|
||||
- model.multires_hubert_path
|
||||
- dataset.train_subset
|
||||
- dataset.valid_subset
|
||||
- criterion.wer_kenlm_model
|
||||
- criterion.wer_lexicon
|
||||
run:
|
||||
dir: ???
|
||||
sweep:
|
||||
dir: ???
|
||||
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
101
examples/mr_hubert/config/finetune/base_10h_large.yaml
Normal file
101
examples/mr_hubert/config/finetune/base_10h_large.yaml
Normal file
@ -0,0 +1,101 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tblog
|
||||
seed: 1337
|
||||
|
||||
checkpoint:
|
||||
save_interval: 5
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: wer
|
||||
|
||||
distributed_training:
|
||||
ddp_backend: c10d
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 8
|
||||
distributed_port: 29671
|
||||
nprocs_per_node: 8
|
||||
|
||||
task:
|
||||
_name: multires_hubert_pretraining
|
||||
data: ???
|
||||
fine_tuning: true
|
||||
label_dir: ???
|
||||
label_rate_ratios: ???
|
||||
normalize: true # must be consistent with pre-training
|
||||
labels: ["ltr"]
|
||||
single_target: true
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
||||
max_tokens: 3200000
|
||||
validate_after_updates: ${model.freeze_finetune_updates}
|
||||
validate_interval: 5
|
||||
train_subset: train_10h
|
||||
valid_subset: dev
|
||||
|
||||
criterion:
|
||||
_name: ctc
|
||||
zero_infinity: true
|
||||
|
||||
optimization:
|
||||
max_update: 25000
|
||||
lr: [2e-5]
|
||||
sentence_avg: true
|
||||
update_freq: [1]
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-08
|
||||
|
||||
lr_scheduler:
|
||||
_name: tri_stage
|
||||
warmup_steps: 8000
|
||||
hold_steps: 0
|
||||
decay_steps: 72000
|
||||
final_lr_scale: 0.05
|
||||
|
||||
model:
|
||||
_name: multires_hubert_ctc
|
||||
multires_hubert_path: ???
|
||||
apply_mask: true
|
||||
mask_selection: static
|
||||
mask_length: 10
|
||||
mask_other: 0
|
||||
mask_prob: 0.75
|
||||
mask_channel_selection: static
|
||||
mask_channel_length: 64
|
||||
mask_channel_other: 0
|
||||
mask_channel_prob: 0.5
|
||||
layerdrop: 0.1
|
||||
dropout: 0.0
|
||||
activation_dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
feature_grad_mult: 0.0
|
||||
freeze_finetune_updates: 10000
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '-'
|
||||
item_sep: '__'
|
||||
exclude_keys:
|
||||
- run
|
||||
- task.data
|
||||
- task.label_dir
|
||||
- model.multires_hubert_path
|
||||
- dataset.train_subset
|
||||
- dataset.valid_subset
|
||||
- criterion.wer_kenlm_model
|
||||
- criterion.wer_lexicon
|
||||
run:
|
||||
dir: ???
|
||||
sweep:
|
||||
dir: ???
|
||||
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
100
examples/mr_hubert/config/finetune/base_1h.yaml
Normal file
100
examples/mr_hubert/config/finetune/base_1h.yaml
Normal file
@ -0,0 +1,100 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tblog
|
||||
seed: 1337
|
||||
|
||||
checkpoint:
|
||||
save_interval: 50
|
||||
keep_interval_updates: 1
|
||||
save_interval_updates: 1000
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: wer
|
||||
|
||||
distributed_training:
|
||||
ddp_backend: c10d
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 8
|
||||
distributed_port: 29671
|
||||
nprocs_per_node: 8
|
||||
|
||||
task:
|
||||
_name: multires_hubert_pretraining
|
||||
data: ???
|
||||
fine_tuning: true
|
||||
label_dir: ???
|
||||
label_rate_ratios: ???
|
||||
normalize: false # must be consistent with pre-training
|
||||
labels: ["ltr"]
|
||||
single_target: true
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
||||
max_tokens: 3200000
|
||||
validate_after_updates: ${model.freeze_finetune_updates}
|
||||
validate_interval: 1000
|
||||
train_subset: train_1h
|
||||
valid_subset: dev_other
|
||||
|
||||
criterion:
|
||||
_name: ctc
|
||||
zero_infinity: true
|
||||
|
||||
optimization:
|
||||
max_update: 13000
|
||||
lr: [5e-5]
|
||||
sentence_avg: true
|
||||
update_freq: [4]
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-08
|
||||
|
||||
lr_scheduler:
|
||||
_name: tri_stage
|
||||
phase_ratio: [0.1, 0.4, 0.5]
|
||||
final_lr_scale: 0.05
|
||||
|
||||
model:
|
||||
_name: multires_hubert_ctc
|
||||
multires_hubert_path: ???
|
||||
apply_mask: true
|
||||
mask_selection: static
|
||||
mask_length: 10
|
||||
mask_other: 0
|
||||
mask_prob: 0.75
|
||||
mask_channel_selection: static
|
||||
mask_channel_length: 64
|
||||
mask_channel_other: 0
|
||||
mask_channel_prob: 0.5
|
||||
layerdrop: 0.1
|
||||
dropout: 0.0
|
||||
activation_dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
feature_grad_mult: 0.0
|
||||
freeze_finetune_updates: 10000
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '-'
|
||||
item_sep: '__'
|
||||
exclude_keys:
|
||||
- run
|
||||
- task.data
|
||||
- task.label_dir
|
||||
- model.multires_hubert_path
|
||||
- dataset.train_subset
|
||||
- dataset.valid_subset
|
||||
- criterion.wer_kenlm_model
|
||||
- criterion.wer_lexicon
|
||||
run:
|
||||
dir: ???
|
||||
sweep:
|
||||
dir: ???
|
||||
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
99
examples/mr_hubert/config/finetune/base_1h_large.yaml
Normal file
99
examples/mr_hubert/config/finetune/base_1h_large.yaml
Normal file
@ -0,0 +1,99 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tblog
|
||||
seed: 1337
|
||||
|
||||
checkpoint:
|
||||
save_interval: 1000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: wer
|
||||
|
||||
distributed_training:
|
||||
ddp_backend: c10d
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 8
|
||||
distributed_port: 29671
|
||||
nprocs_per_node: 8
|
||||
|
||||
task:
|
||||
_name: multires_hubert_pretraining
|
||||
data: ???
|
||||
fine_tuning: true
|
||||
label_dir: ???
|
||||
label_rate_ratios: ???
|
||||
normalize: true # must be consistent with pre-training
|
||||
labels: ["ltr"]
|
||||
single_target: true
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
||||
max_tokens: 1280000
|
||||
validate_after_updates: ${model.freeze_finetune_updates}
|
||||
validate_interval: 5
|
||||
train_subset: train_10h
|
||||
valid_subset: dev
|
||||
|
||||
criterion:
|
||||
_name: ctc
|
||||
zero_infinity: true
|
||||
|
||||
optimization:
|
||||
max_update: 25000
|
||||
lr: [3e-4]
|
||||
sentence_avg: true
|
||||
update_freq: [5]
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-08
|
||||
|
||||
lr_scheduler:
|
||||
_name: tri_stage
|
||||
phase_ratio: [0.1, 0.4, 0.5]
|
||||
final_lr_scale: 0.05
|
||||
|
||||
model:
|
||||
_name: multires_hubert_ctc
|
||||
multires_hubert_path: ???
|
||||
apply_mask: true
|
||||
mask_selection: static
|
||||
mask_length: 10
|
||||
mask_other: 0
|
||||
mask_prob: 0.75
|
||||
mask_channel_selection: static
|
||||
mask_channel_length: 64
|
||||
mask_channel_other: 0
|
||||
mask_channel_prob: 0.5
|
||||
layerdrop: 0.1
|
||||
dropout: 0.0
|
||||
activation_dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
feature_grad_mult: 0.0
|
||||
freeze_finetune_updates: 10000
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '-'
|
||||
item_sep: '__'
|
||||
exclude_keys:
|
||||
- run
|
||||
- task.data
|
||||
- task.label_dir
|
||||
- model.multires_hubert_path
|
||||
- dataset.train_subset
|
||||
- dataset.valid_subset
|
||||
- criterion.wer_kenlm_model
|
||||
- criterion.wer_lexicon
|
||||
run:
|
||||
dir: ???
|
||||
sweep:
|
||||
dir: ???
|
||||
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
@ -0,0 +1,103 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
seed: 1337
|
||||
tensorboard_logdir: tblog
|
||||
min_loss_scale: 1e-8
|
||||
|
||||
checkpoint:
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
distributed_training:
|
||||
ddp_backend: no_c10d
|
||||
distributed_backend: 'nccl'
|
||||
distributed_world_size: 32
|
||||
distributed_port: 29671
|
||||
nprocs_per_node: 8
|
||||
find_unused_parameters: true
|
||||
|
||||
task:
|
||||
_name: multires_hubert_pretraining
|
||||
data: ???
|
||||
label_dir: ???
|
||||
labels: ???
|
||||
label_rate: ${model.label_rate}
|
||||
label_rate_ratios: ???
|
||||
sample_rate: 16000
|
||||
max_sample_size: 250000
|
||||
min_sample_size: 32000
|
||||
pad_audio: false
|
||||
random_crop: true
|
||||
normalize: false # must be consistent with extractor
|
||||
# max_keep_size: 300000
|
||||
# max_keep_size: 50000
|
||||
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
||||
max_tokens: 1000000
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
validate_interval: 5
|
||||
validate_interval_updates: 10000
|
||||
|
||||
criterion:
|
||||
_name: hubert
|
||||
pred_masked_weight: 1.0
|
||||
pred_nomask_weight: 0.0
|
||||
loss_weights: [10,]
|
||||
|
||||
optimization:
|
||||
max_update: 400000
|
||||
lr: [0.0005]
|
||||
clip_norm: 10.0
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
|
||||
lr_scheduler:
|
||||
_name: polynomial_decay
|
||||
warmup_updates: 32000
|
||||
|
||||
model:
|
||||
_name: multires_hubert
|
||||
label_rate: ???
|
||||
label_rate_ratios: ${task.label_rate_ratios}
|
||||
skip_masked: false
|
||||
skip_nomask: false
|
||||
mask_prob: 0.80
|
||||
extractor_mode: default
|
||||
conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
||||
final_dim: 256
|
||||
encoder_layers: 4
|
||||
encoder_layerdrop: 0.05
|
||||
dropout_input: 0.1
|
||||
dropout_features: 0.1
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.1
|
||||
feature_grad_mult: 0.1
|
||||
untie_final_proj: true
|
||||
activation_dropout: 0.0
|
||||
conv_adapator_kernal: 1
|
||||
use_single_target: true
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '-'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run
|
||||
- task.data
|
||||
- task.label_dir
|
||||
- common.min_loss_scale
|
||||
- common.log_interval
|
||||
- optimization.clip_norm
|
@ -0,0 +1,107 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
memory_efficient_fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
seed: 1337
|
||||
tensorboard_logdir: tblog
|
||||
|
||||
checkpoint:
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
|
||||
distributed_training:
|
||||
ddp_backend: no_c10d
|
||||
distributed_backend: 'nccl'
|
||||
distributed_world_size: 128
|
||||
distributed_port: 29671
|
||||
nprocs_per_node: 8
|
||||
find_unused_parameters: true
|
||||
|
||||
task:
|
||||
_name: multires_hubert_pretraining
|
||||
data: ???
|
||||
label_dir: ???
|
||||
labels: ???
|
||||
label_rate: ${model.label_rate}
|
||||
label_rate_ratios: ???
|
||||
sample_rate: 16000
|
||||
max_sample_size: 250000
|
||||
min_sample_size: 32000
|
||||
pad_audio: false
|
||||
random_crop: true
|
||||
normalize: true # must be consistent with extractor
|
||||
# max_keep_size: 50000
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
||||
max_tokens: 300000
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
validate_interval: 5
|
||||
validate_interval_updates: 10000
|
||||
|
||||
criterion:
|
||||
_name: hubert
|
||||
pred_masked_weight: 1.0
|
||||
pred_nomask_weight: 0.0
|
||||
loss_weights: [10,]
|
||||
|
||||
optimization:
|
||||
max_update: 400000
|
||||
lr: [0.0015]
|
||||
clip_norm: 1.0
|
||||
update_freq: [3]
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
|
||||
lr_scheduler:
|
||||
_name: polynomial_decay
|
||||
warmup_updates: 32000
|
||||
|
||||
model:
|
||||
_name: multires_hubert
|
||||
label_rate: ???
|
||||
label_rate_ratios: ${task.label_rate_ratios}
|
||||
encoder_layers: 8
|
||||
encoder_embed_dim: 1024
|
||||
encoder_ffn_embed_dim: 4096
|
||||
encoder_attention_heads: 16
|
||||
final_dim: 768
|
||||
skip_masked: false
|
||||
skip_nomask: false
|
||||
mask_prob: 0.80
|
||||
extractor_mode: layer_norm
|
||||
conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
||||
encoder_layerdrop: 0.0
|
||||
dropout_input: 0.0
|
||||
dropout_features: 0.0
|
||||
dropout: 0.0
|
||||
attention_dropout: 0.0
|
||||
layer_norm_first: true
|
||||
feature_grad_mult: 1.0
|
||||
untie_final_proj: true
|
||||
activation_dropout: 0.0
|
||||
conv_adapator_kernal: 1
|
||||
use_single_target: true
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '-'
|
||||
item_sep: '__'
|
||||
exclude_keys:
|
||||
- run
|
||||
- task.data
|
||||
run:
|
||||
dir: /checkpoint/wnhsu/w2v/hubert_final/hydra_pt
|
||||
sweep:
|
||||
dir: /checkpoint/wnhsu/w2v/hubert_final/hydra_pt
|
||||
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
20
examples/mr_hubert/config/pretrain/run/submitit_reg.yaml
Normal file
20
examples/mr_hubert/config/pretrain/run/submitit_reg.yaml
Normal file
@ -0,0 +1,20 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
launcher:
|
||||
cpus_per_task: 8
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: ${hydra.launcher.gpus_per_node}
|
||||
nodes: 4
|
||||
comment: null
|
||||
mem_gb: 384
|
||||
timeout_min: 4320
|
||||
max_num_timeout: 100
|
||||
constraint: volta32gb
|
||||
name: ${hydra.job.config_name}/${hydra.job.override_dirname}
|
||||
submitit_folder: ${hydra.sweep.dir}/submitit/%j
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 32
|
||||
distributed_port: 29671
|
||||
nprocs_per_node: 8
|
46
examples/mr_hubert/decode.sh
Executable file
46
examples/mr_hubert/decode.sh
Executable file
@ -0,0 +1,46 @@
|
||||
#!/bin/bash
|
||||
|
||||
FAIRSEQ= # Setup your fairseq directory
|
||||
|
||||
config_dir=${FAIRSEQ}/examples/mr_hubert/config
|
||||
config_name=mr_hubert_base_librispeech
|
||||
|
||||
|
||||
# Prepared Data Directory
|
||||
|
||||
data_dir=librispeech
|
||||
# -- data_dir
|
||||
# -- test.tsv
|
||||
# -- test.ltr
|
||||
# -- dict.ltr.txt
|
||||
|
||||
|
||||
exp_dir=exp # Target experiments directory (where you have your pre-trained model with checkpoint_best.pt)
|
||||
ratios="[1, 2]" # Default label rate ratios
|
||||
|
||||
_opts=
|
||||
|
||||
# If use slurm, uncomment this line and modify the job submission at
|
||||
# _opts="${_opts} hydra/launcher=submitit_slurm +hydra.launcher.partition=${your_slurm_partition} +run=submitit_reg"
|
||||
|
||||
# If want to set additional experiment tag, uncomment this line
|
||||
# _opts="${_opts} hydra.sweep.subdir=${your_experiment_tag}"
|
||||
|
||||
# If use un-normalized audio, uncomment this line
|
||||
# _opts="${_opts} task.normalize=false"
|
||||
|
||||
|
||||
|
||||
PYTHONPATH=${FAIRSEQ}
|
||||
python examples/speech_recognition/new/infer.py \
|
||||
--config-dir ${config_dir} \
|
||||
--config-name infer_multires \
|
||||
${_opts} \
|
||||
task.data=${data_dir} \
|
||||
task.label_rate_ratios='${ratios}' \
|
||||
common_eval.results_path=${exp_dir} \
|
||||
common_eval.path=${exp_dir}/checkpoint_best.pt \
|
||||
dataset.max_tokens=2000000 \
|
||||
dataset.gen_subset=test \
|
||||
dataset.skip_invalid_size_inputs_valid_test=true
|
||||
|
46
examples/mr_hubert/finetune.sh
Executable file
46
examples/mr_hubert/finetune.sh
Executable file
@ -0,0 +1,46 @@
|
||||
#!/bin/bash
|
||||
|
||||
FAIRSEQ= # Setup your fairseq directory
|
||||
|
||||
config_dir=${FAIRSEQ}/examples/mr_hubert/config
|
||||
config_name=mr_hubert_base_librispeech
|
||||
|
||||
# override configs if need
|
||||
max_tokens=3200000
|
||||
max_sample_size=1000000
|
||||
max_update=50000
|
||||
|
||||
|
||||
# Prepared Data Directory
|
||||
|
||||
data_dir=librispeech
|
||||
# -- data_dir
|
||||
# -- train.tsv
|
||||
# -- train.ltr
|
||||
# -- valid.tsv
|
||||
# -- valid.ltr
|
||||
# -- dict.ltr.txt
|
||||
|
||||
|
||||
exp_dir=exp # Target experiments directory
|
||||
ratios="[1, 2]" # Default label rate ratios
|
||||
hubert_path=/path/of/your/hubert.pt
|
||||
|
||||
_opts=
|
||||
|
||||
# If use slurm, uncomment this line and modify the job submission at
|
||||
# _opts="${_opts} hydra/launcher=submitit_slurm +hydra.launcher.partition=${your_slurm_partition} +run=submitit_reg"
|
||||
|
||||
# If want to set additional experiment tag, uncomment this line
|
||||
# _opts="${_opts} hydra.sweep.subdir=${your_experiment_tag}"
|
||||
|
||||
|
||||
python ${FAIRSEQ}/fairseq_cli/hydra_train.py \
|
||||
-m --config-dir ${config_dir} --config-name ${config_name} ${_opts} \
|
||||
task.data=${data_dir} +task.max_sample_size=${max_sample_size} \
|
||||
task.label_dir=${data_dir} \
|
||||
task.label_rate_ratios='${ratios}' \
|
||||
dataset.max_tokens=${max_tokens} \
|
||||
optimization.max_update=${max_update} \
|
||||
model.multires_hubert_path=${hubert_path} \
|
||||
hydra.sweep.dir=${exp_dir} &
|
1
examples/mr_hubert/simple_kmeans
Symbolic link
1
examples/mr_hubert/simple_kmeans
Symbolic link
@ -0,0 +1 @@
|
||||
../hubert/simple_kmeans
|
45
examples/mr_hubert/train.sh
Executable file
45
examples/mr_hubert/train.sh
Executable file
@ -0,0 +1,45 @@
|
||||
#!/bin/bash
|
||||
|
||||
FAIRSEQ= # Setup your fairseq directory
|
||||
|
||||
config_dir=${FAIRSEQ}/examples/mr_hubert/config
|
||||
config_name=mr_hubert_base_librispeech
|
||||
|
||||
# Prepared Data Directory
|
||||
data_dir=librispeech
|
||||
# -- data_dir
|
||||
# -- train.tsv
|
||||
# -- valid.tsv
|
||||
|
||||
label_dir=labels
|
||||
# -- label_dir
|
||||
# -- train.km
|
||||
# -- valid.km
|
||||
# -- dict.km.txt
|
||||
|
||||
|
||||
exp_dir=exp # Target experiments directory
|
||||
ratios="[1, 2]" # Default label rate ratios
|
||||
label_rate=50 # Base label rate
|
||||
|
||||
|
||||
_opts=
|
||||
|
||||
# If use slurm, uncomment this line and modify the job submission at
|
||||
# _opts="${_opts} hydra/launcher=submitit_slurm +hydra.launcher.partition=${your_slurm_partition} +run=submitit_reg"
|
||||
|
||||
# If want to set additional experiment tag, uncomment this line
|
||||
# _opts="${_opts} hydra.sweep.subdir=${your_experiment_tag}"
|
||||
|
||||
|
||||
python ${FAIRSEQ}/fairseq_cli/hydra_train.py \
|
||||
-m --config-dir ${config_dir} --config-name ${config_name} ${_opts} \
|
||||
task.data=${data_dir} \
|
||||
task.label_dir=${label_dir} \
|
||||
task.labels='["km"]' \
|
||||
model.label_rate=${label_rate} \
|
||||
task.label_rate_ratios='${ratios}' \
|
||||
hydra.sweep.dir=${exp_dir} &
|
||||
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# ASR-BLEU evaluation toolkit
|
||||
|
||||
This toolkit provides a set of public ASR models used for evaluation of different speech-to-speech translation systems at Meta AI. It enables easier score comparisons between different system's outputs.
|
||||
This toolkit provides a set of public ASR models used for evaluation of different speech-to-speech translation systems at FAIR. It enables easier score comparisons between different system's outputs.
|
||||
|
||||
The ASRGenerator wraps different CTC-based ASR models from HuggingFace and fairseq code bases. Torchaudio CTC decoder is built on top of it to decode given audio files.
|
||||
|
||||
@ -31,4 +31,4 @@ python compute_asr_bleu.py --lang <LANG> \
|
||||
--reference_format txt
|
||||
```
|
||||
|
||||
For more details about arguments please see the script argparser help.
|
||||
For more details about arguments please see the script argparser help.
|
||||
|
@ -1,4 +1,4 @@
|
||||
XStoryCloze consists of professional translation of the validation split of the [English StoryCloze dataset](https://cs.rochester.edu/nlp/rocstories/) (Spring 2016 version) to 10 other languages. This dataset is released by Meta AI alongside the paper [Few-shot Learning with Multilingual Generative Language Models. EMNLP 2022](https://arxiv.org/abs/2112.10668).
|
||||
XStoryCloze consists of professional translation of the validation split of the [English StoryCloze dataset](https://cs.rochester.edu/nlp/rocstories/) (Spring 2016 version) to 10 other languages. This dataset is released by FAIR (Fundamental Artificial Intelligence Research) alongside the paper [Few-shot Learning with Multilingual Generative Language Models. EMNLP 2022](https://arxiv.org/abs/2112.10668).
|
||||
|
||||
# Languages
|
||||
ru, zh (Simplified), es (Latin America), ar, hi, id, te, sw, eu, my.
|
||||
|
@ -2,7 +2,7 @@
|
||||
## Version 1.0.0
|
||||
|
||||
### Model developer
|
||||
Meta AI
|
||||
FAIR (Fundamental Artificial Intelligence Research)
|
||||
|
||||
### Model type
|
||||
A family of multilingual autoregressive language models (ranging from 564 million to 7.5 billion parameters) trained on a balanced corpus of a diverse set of languages. The language model can learn tasks from natural language descriptions and a few examples.
|
||||
@ -31,7 +31,7 @@ The model was evaluated on hate speech detection and occupation identification.
|
||||
## Metrics
|
||||
### Model performance measures
|
||||
The XGLM model was primarily evaluated on
|
||||
1. Zero shot and few shot learning by looking at per-language performance on tasks spanning commonsense reasoning (XCOPA, XWinograd), natural language inference (XNLI) and paraphrasing (PAWS-X). The model is also evaluated on XStoryCloze, a new dataset created by Meta AI.
|
||||
1. Zero shot and few shot learning by looking at per-language performance on tasks spanning commonsense reasoning (XCOPA, XWinograd), natural language inference (XNLI) and paraphrasing (PAWS-X). The model is also evaluated on XStoryCloze, a new dataset created by FAIR (Fundamental Artificial Intelligence Research).
|
||||
2. Cross lingual transfer through templates and few-shot examples.
|
||||
3. Knowledge probing - Evaluate to what extent the XGLM model can effectively store factual knowledge in different languages using the mLAMA benchmark.
|
||||
4. Translation - We report machine translation results on WMT benchmarks and a subset of FLORES-101 in the main paper.
|
||||
@ -50,7 +50,7 @@ The Cross-lingual Natural Language Inference (XNLI) corpus is the extension of t
|
||||
|
||||
### XStoryCloze
|
||||
#### Description
|
||||
A new dataset created by Meta AI along side this work by translating the validation split of the English StoryCloze dataset (Mostafazadeh et al., 2016) (Spring 2016 version) to 10 other typologically diverse languages (ru, zh Simplified, es Latin America, ar, hi, id, te, sw, eu, my).
|
||||
A new dataset created by FAIR along side this work by translating the validation split of the English StoryCloze dataset (Mostafazadeh et al., 2016) (Spring 2016 version) to 10 other typologically diverse languages (ru, zh Simplified, es Latin America, ar, hi, id, te, sw, eu, my).
|
||||
|
||||
### XCOPA (Ponti et al., 2020)
|
||||
#### Description
|
||||
@ -85,7 +85,7 @@ More details on the CC100-XL dataset can be found in the Appendix section of the
|
||||
The XGLM model was evaluated on Hate speech and bias identification datasets. For hate speech, we observe that across the 5 languages in the dataset, in context learning results are only slightly better than random (50%). Another interesting observation is that most few shot results are worse than zero-shot, which indicates that the model is not able to utilize examples using the templates described in the paper. For bias identification, the XGLM (6.7B) English only model achieves the best performance on English and Spanish, while the GPT-3 model of comparable size (6.7B) model achieves the best in French. On certain occupations (e.g. model and teacher), XGLM 6.7B En only model and GPT-3 (6.7B) have very significant bias while XGLM 7.5B is much less biased.
|
||||
|
||||
### Privacy and security
|
||||
The XGLM model did not have any special Privacy and Security considerations. The training data and evaluation data were both public and went through standard Meta AI Privacy and licensing procedures.
|
||||
The XGLM model did not have any special Privacy and Security considerations. The training data and evaluation data were both public and went through standard Meta privacy and licensing procedures.
|
||||
|
||||
### Transparency and control
|
||||
In the spirit of transparency and accountability we have created this model card and a data card for the CC100-XL which can be found in the Appendix section of the paper.
|
||||
|
@ -104,7 +104,20 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
|
||||
"checkpoint_last{}.pt".format(suffix)
|
||||
] = not cfg.no_last_checkpoints
|
||||
|
||||
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
|
||||
extra_state = {
|
||||
"train_iterator": epoch_itr.state_dict(),
|
||||
"val_loss": val_loss,
|
||||
}
|
||||
|
||||
# Going forward, different tasks could expose an API like this to dump all
|
||||
# the checkpoint worthy attributes in a dictionary which then will be
|
||||
# merged with the parent dictionary to create the "extra_state". This
|
||||
# allows for an extensible yet simple design to checkpoint task level
|
||||
# attributes
|
||||
if hasattr(trainer.task, "get_checkpoint_dict"):
|
||||
extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
|
||||
logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint")
|
||||
|
||||
if hasattr(save_checkpoint, "best"):
|
||||
extra_state.update({"best": save_checkpoint.best})
|
||||
|
||||
@ -275,6 +288,11 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
|
||||
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
|
||||
)
|
||||
epoch_itr.load_state_dict(itr_state)
|
||||
|
||||
# Preload the checkpoint for the task
|
||||
task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {})
|
||||
if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
|
||||
trainer.task.set_checkpoint_dict(task_cp_dict)
|
||||
else:
|
||||
epoch_itr = trainer.get_train_iterator(
|
||||
epoch=1, load_dataset=True, **passthrough_args
|
||||
|
@ -524,7 +524,7 @@ class EpochBatchIterator(EpochBatchIterating):
|
||||
# TODO: Below is a lazy implementation which discard the final batch regardless
|
||||
# of whether it is a full batch or not.
|
||||
|
||||
total_num_itrs = len(self.epoch_batch_sampler) - 1
|
||||
total_num_itrs = len(itr) - 1
|
||||
itr.take(total_num_itrs)
|
||||
logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}")
|
||||
|
||||
|
2
fairseq/models/multires_hubert/__init__.py
Normal file
2
fairseq/models/multires_hubert/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .multires_hubert import * # noqa
|
||||
from .multires_hubert_asr import * # noqa
|
1231
fairseq/models/multires_hubert/multires_hubert.py
Normal file
1231
fairseq/models/multires_hubert/multires_hubert.py
Normal file
File diff suppressed because it is too large
Load Diff
376
fairseq/models/multires_hubert/multires_hubert_asr.py
Normal file
376
fairseq/models/multires_hubert/multires_hubert_asr.py
Normal file
@ -0,0 +1,376 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import II, MISSING
|
||||
|
||||
from fairseq import checkpoint_utils, tasks, utils
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
|
||||
from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES
|
||||
from fairseq.tasks import FairseqTask
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiresHubertAsrConfig(FairseqDataclass):
|
||||
multires_hubert_path: str = field(
|
||||
default=MISSING, metadata={"help": "path to multires_hubert model"}
|
||||
)
|
||||
no_pretrained_weights: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "if true, does not load pretrained weights"},
|
||||
)
|
||||
dropout_input: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "dropout to apply to the input (after feat extr)"},
|
||||
)
|
||||
final_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "dropout after transformer and before final projection"},
|
||||
)
|
||||
dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "dropout probability inside hubert model"},
|
||||
)
|
||||
attention_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "dropout probability for attention weights " "inside hubert model"
|
||||
},
|
||||
)
|
||||
activation_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "dropout probability after activation in FFN " "inside hubert model"
|
||||
},
|
||||
)
|
||||
|
||||
# masking
|
||||
apply_mask: bool = field(
|
||||
default=False, metadata={"help": "apply masking during fine-tuning"}
|
||||
)
|
||||
mask_length: int = field(
|
||||
default=10, metadata={"help": "repeat the mask indices multiple times"}
|
||||
)
|
||||
mask_prob: float = field(
|
||||
default=0.5,
|
||||
metadata={
|
||||
"help": "probability of replacing a token with mask "
|
||||
"(normalized by length)"
|
||||
},
|
||||
)
|
||||
mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
||||
default="static", metadata={"help": "how to choose masks"}
|
||||
)
|
||||
mask_other: float = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "secondary mask argument "
|
||||
"(used for more complex distributions), "
|
||||
"see help in compute_mask_indices"
|
||||
},
|
||||
)
|
||||
no_mask_overlap: bool = field(
|
||||
default=False, metadata={"help": "whether to allow masks to overlap"}
|
||||
)
|
||||
|
||||
# channel masking
|
||||
mask_channel_length: int = field(
|
||||
default=10,
|
||||
metadata={"help": "length of the mask for features (channels)"},
|
||||
)
|
||||
mask_channel_prob: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "probability of replacing a feature with 0"},
|
||||
)
|
||||
mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
||||
default="static",
|
||||
metadata={"help": "how to choose mask length for channel masking"},
|
||||
)
|
||||
mask_channel_other: float = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "secondary mask argument "
|
||||
"(used for more complex distributions), "
|
||||
"see help in compute_mask_indices"
|
||||
},
|
||||
)
|
||||
no_mask_channel_overlap: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to allow channel masks to overlap"},
|
||||
)
|
||||
freeze_finetune_updates: int = field(
|
||||
default=0,
|
||||
metadata={"help": "dont finetune hubert for this many updates"},
|
||||
)
|
||||
feature_grad_mult: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "reset feature grad mult in hubert to this"},
|
||||
)
|
||||
layerdrop: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "probability of dropping a layer in hubert"},
|
||||
)
|
||||
normalize: bool = II("task.normalize")
|
||||
data: str = II("task.data")
|
||||
|
||||
# this holds the loaded hubert args
|
||||
multires_hubert_args: Any = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiresHubertCtcConfig(MultiresHubertAsrConfig):
|
||||
pass
|
||||
|
||||
|
||||
@register_model("multires_hubert_ctc", dataclass=MultiresHubertAsrConfig)
|
||||
class MultiresHubertCtc(BaseFairseqModel):
|
||||
def __init__(
|
||||
self, cfg: MultiresHubertAsrConfig, multireshubert_encoder: BaseFairseqModel
|
||||
):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.multireshubert_encoder = multireshubert_encoder
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
super().upgrade_state_dict_named(state_dict, name)
|
||||
return state_dict
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, cfg: MultiresHubertAsrConfig, task: FairseqTask):
|
||||
"""Build a new model instance."""
|
||||
multireshubert_encoder = MultiresHubertEncoder(cfg, task)
|
||||
return cls(cfg, multireshubert_encoder)
|
||||
|
||||
def get_normalized_probs(self, net_output, log_probs, sample=None):
|
||||
"""Get normalized probabilities (or log probs) from a net's output."""
|
||||
|
||||
logits = net_output["encoder_out"]
|
||||
if log_probs:
|
||||
return utils.log_softmax(logits.float(), dim=-1)
|
||||
else:
|
||||
return utils.softmax(logits.float(), dim=-1)
|
||||
|
||||
def get_logits(self, net_output):
|
||||
logits = net_output["encoder_out"]
|
||||
padding = net_output["encoder_padding_mask"]
|
||||
if padding is not None and padding.any():
|
||||
padding = padding.T
|
||||
logits[padding][..., 0] = 0
|
||||
logits[padding][..., 1:] = float("-inf")
|
||||
|
||||
return logits
|
||||
|
||||
def forward(self, **kwargs):
|
||||
x = self.multireshubert_encoder(**kwargs)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiresHubertSeq2SeqConfig(MultiresHubertAsrConfig):
|
||||
decoder_embed_dim: int = field(
|
||||
default=768, metadata={"help": "decoder embedding dimension"}
|
||||
)
|
||||
decoder_ffn_embed_dim: int = field(
|
||||
default=3072, metadata={"help": "decoder embedding dimension for FFN"}
|
||||
)
|
||||
decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
|
||||
decoder_layerdrop: float = field(
|
||||
default=0.0, metadata={"help": "decoder layerdrop chance"}
|
||||
)
|
||||
decoder_attention_heads: int = field(
|
||||
default=4, metadata={"help": "num decoder attention heads"}
|
||||
)
|
||||
decoder_learned_pos: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "use learned positional embeddings in the decoder"},
|
||||
)
|
||||
decoder_normalize_before: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "apply layernorm before each decoder block"},
|
||||
)
|
||||
no_token_positional_embeddings: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "if set, disables positional embeddings " "(outside self attention)"
|
||||
},
|
||||
)
|
||||
decoder_dropout: float = field(
|
||||
default=0.0, metadata={"help": "dropout probability in the decoder"}
|
||||
)
|
||||
decoder_attention_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "dropout probability for attention weights " "inside the decoder"
|
||||
},
|
||||
)
|
||||
decoder_activation_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "dropout probability after activation in FFN " "inside the decoder"
|
||||
},
|
||||
)
|
||||
max_target_positions: int = field(
|
||||
default=2048, metadata={"help": "max target positions"}
|
||||
)
|
||||
share_decoder_input_output_embed: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "share decoder input and output embeddings"},
|
||||
)
|
||||
|
||||
|
||||
class MultiresHubertEncoder(FairseqEncoder):
|
||||
def __init__(self, cfg: MultiresHubertAsrConfig, task):
|
||||
self.apply_mask = cfg.apply_mask
|
||||
|
||||
arg_overrides = {
|
||||
"dropout": cfg.dropout,
|
||||
"activation_dropout": cfg.activation_dropout,
|
||||
"dropout_input": cfg.dropout_input,
|
||||
"attention_dropout": cfg.attention_dropout,
|
||||
"mask_length": cfg.mask_length,
|
||||
"mask_prob": cfg.mask_prob,
|
||||
"mask_selection": cfg.mask_selection,
|
||||
"mask_other": cfg.mask_other,
|
||||
"no_mask_overlap": cfg.no_mask_overlap,
|
||||
"mask_channel_length": cfg.mask_channel_length,
|
||||
"mask_channel_prob": cfg.mask_channel_prob,
|
||||
"mask_channel_selection": cfg.mask_channel_selection,
|
||||
"mask_channel_other": cfg.mask_channel_other,
|
||||
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
|
||||
"encoder_layerdrop": cfg.layerdrop,
|
||||
"feature_grad_mult": cfg.feature_grad_mult,
|
||||
}
|
||||
|
||||
if cfg.multires_hubert_args is None:
|
||||
state = checkpoint_utils.load_checkpoint_to_cpu(
|
||||
cfg.multires_hubert_path, arg_overrides
|
||||
)
|
||||
multires_hubert_args = state.get("cfg", None)
|
||||
if multires_hubert_args is None:
|
||||
multires_hubert_args = convert_namespace_to_omegaconf(state["args"])
|
||||
cfg.multires_hubert_args = multires_hubert_args
|
||||
else:
|
||||
state = None
|
||||
multires_hubert_args = cfg.multires_hubert_args
|
||||
if isinstance(multires_hubert_args, Namespace):
|
||||
cfg.multires_hubert_args = (
|
||||
multires_hubert_args
|
||||
) = convert_namespace_to_omegaconf(multires_hubert_args)
|
||||
|
||||
assert cfg.normalize == multires_hubert_args.task.normalize, (
|
||||
"Fine-tuning works best when data normalization is the same. "
|
||||
"Please check that --normalize is set or unset for "
|
||||
"both pre-training and here"
|
||||
)
|
||||
|
||||
multires_hubert_args.task.data = cfg.data
|
||||
pretrain_task = tasks.setup_task(multires_hubert_args.task)
|
||||
if state is not None and "task_state" in state:
|
||||
# This will load the stored "dictionaries" object
|
||||
pretrain_task.load_state_dict(state["task_state"])
|
||||
else:
|
||||
pretrain_task.load_state_dict(task.state_dict())
|
||||
|
||||
model = pretrain_task.build_model(
|
||||
multires_hubert_args.model, from_checkpoint=True
|
||||
)
|
||||
if state is not None and not cfg.no_pretrained_weights:
|
||||
# set strict=False because we omit some modules
|
||||
model.load_state_dict(state["model"], strict=False)
|
||||
|
||||
model.remove_pretraining_modules()
|
||||
|
||||
super().__init__(pretrain_task.source_dictionary)
|
||||
|
||||
d = multires_hubert_args.model.encoder_embed_dim
|
||||
|
||||
self.multires_hubert_model = model
|
||||
|
||||
self.final_dropout = nn.Dropout(cfg.final_dropout)
|
||||
self.freeze_finetune_updates = cfg.freeze_finetune_updates
|
||||
self.num_updates = 0
|
||||
|
||||
if task.target_dictionary is not None:
|
||||
self.proj = Linear(d, len(task.target_dictionary))
|
||||
elif getattr(cfg, "decoder_embed_dim", d) != d:
|
||||
self.proj = Linear(d, cfg.decoder_embed_dim)
|
||||
else:
|
||||
self.proj = None
|
||||
|
||||
def set_num_updates(self, num_updates):
|
||||
"""Set the number of parameters updates."""
|
||||
super().set_num_updates(num_updates)
|
||||
self.num_updates = num_updates
|
||||
|
||||
def forward(self, source, padding_mask, tbc=True, **kwargs):
|
||||
multires_hubert_args = {
|
||||
"source": source,
|
||||
"padding_mask": padding_mask,
|
||||
"mask": self.apply_mask and self.training,
|
||||
"last_layer": True,
|
||||
}
|
||||
|
||||
ft = self.freeze_finetune_updates <= self.num_updates
|
||||
|
||||
with torch.no_grad() if not ft else contextlib.ExitStack():
|
||||
x, padding_mask = self.multires_hubert_model.extract_features(
|
||||
**multires_hubert_args
|
||||
)
|
||||
|
||||
if tbc:
|
||||
# B x T x C -> T x B x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
x = self.final_dropout(x)
|
||||
|
||||
if self.proj:
|
||||
x = self.proj(x)
|
||||
|
||||
return {
|
||||
"encoder_out": x, # T x B x C
|
||||
"encoder_padding_mask": padding_mask, # B x T
|
||||
"padding_mask": padding_mask,
|
||||
}
|
||||
|
||||
def reorder_encoder_out(self, encoder_out, new_order):
|
||||
if encoder_out["encoder_out"] is not None:
|
||||
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
|
||||
1, new_order
|
||||
)
|
||||
if encoder_out["encoder_padding_mask"] is not None:
|
||||
encoder_out["encoder_padding_mask"] = encoder_out[
|
||||
"encoder_padding_mask"
|
||||
].index_select(0, new_order)
|
||||
return encoder_out
|
||||
|
||||
def max_positions(self):
|
||||
"""Maximum input length supported by the encoder."""
|
||||
return None
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
return state_dict
|
||||
|
||||
|
||||
def Embedding(num_embeddings, embedding_dim, padding_idx):
|
||||
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
|
||||
nn.init.constant_(m.weight[padding_idx], 0)
|
||||
return m
|
||||
|
||||
|
||||
def Linear(in_features, out_features, bias=True):
|
||||
m = nn.Linear(in_features, out_features, bias)
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if bias:
|
||||
nn.init.constant_(m.bias, 0.0)
|
||||
return m
|
@ -1009,7 +1009,7 @@ class TransformerEncoder(nn.Module):
|
||||
layer = checkpoint_wrapper(layer)
|
||||
return layer
|
||||
|
||||
def __init__(self, args: Wav2Vec2Config):
|
||||
def __init__(self, args: Wav2Vec2Config, skip_pos_conv: bool = False, override_encoder_layer: int = None):
|
||||
super().__init__()
|
||||
|
||||
self.dropout = args.dropout
|
||||
@ -1045,7 +1045,8 @@ class TransformerEncoder(nn.Module):
|
||||
self.pos_conv = make_conv_block(
|
||||
self.embedding_dim, k, args.conv_pos_groups, num_layers
|
||||
)
|
||||
|
||||
elif skip_pos_conv:
|
||||
self.pos_conv = None
|
||||
else:
|
||||
self.pos_conv = make_conv_pos(
|
||||
self.embedding_dim,
|
||||
@ -1056,8 +1057,13 @@ class TransformerEncoder(nn.Module):
|
||||
else False,
|
||||
)
|
||||
|
||||
if override_encoder_layer is None:
|
||||
encoder_layers = args.encoder_layers
|
||||
else:
|
||||
encoder_layers = override_encoder_layer
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)]
|
||||
[self.build_encoder_layer(args, layer_idx=ii) for ii in range(encoder_layers)]
|
||||
)
|
||||
self.layer_norm_first = args.layer_norm_first
|
||||
self.layer_norm = LayerNorm(self.embedding_dim)
|
||||
@ -1087,9 +1093,10 @@ class TransformerEncoder(nn.Module):
|
||||
if padding_mask is not None:
|
||||
x = index_put(x, padding_mask, 0)
|
||||
|
||||
x_conv = self.pos_conv(x.transpose(1, 2))
|
||||
x_conv = x_conv.transpose(1, 2)
|
||||
x = x + x_conv
|
||||
if self.pos_conv is not None:
|
||||
x_conv = self.pos_conv(x.transpose(1, 2))
|
||||
x_conv = x_conv.transpose(1, 2)
|
||||
x = x + x_conv
|
||||
|
||||
if not self.layer_norm_first:
|
||||
x = self.layer_norm(x)
|
||||
|
204
fairseq/tasks/multires_hubert_pretraining.py
Normal file
204
fairseq/tasks/multires_hubert_pretraining.py
Normal file
@ -0,0 +1,204 @@
|
||||
# Copyright (c) 2017-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the LICENSE file in
|
||||
# the root directory of this source tree. An additional grant of patent rights
|
||||
# can be found in the PATENTS file in the same directory.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from fairseq.data import Dictionary, HubertDataset
|
||||
from fairseq.dataclass.configs import FairseqDataclass
|
||||
from fairseq.tasks import register_task
|
||||
from fairseq.tasks.fairseq_task import FairseqTask
|
||||
from omegaconf import MISSING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LabelEncoder(object):
|
||||
def __init__(self, dictionary: Dictionary) -> None:
|
||||
self.dictionary = dictionary
|
||||
|
||||
def __call__(self, label: str) -> List[str]:
|
||||
return self.dictionary.encode_line(
|
||||
label,
|
||||
append_eos=False,
|
||||
add_if_not_exist=False,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiresHubertPretrainingConfig(FairseqDataclass):
|
||||
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
|
||||
fine_tuning: bool = field(
|
||||
default=False, metadata={"help": "set to true if fine-tuning Hubert"}
|
||||
)
|
||||
labels: List[str] = field(
|
||||
default_factory=lambda: ["ltr50", "ltr25"],
|
||||
metadata={
|
||||
"help": (
|
||||
"extension of the label files to load, frame-level labels for"
|
||||
" pre-training, and sequence-level label for fine-tuning"
|
||||
)
|
||||
},
|
||||
)
|
||||
label_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "if set, looks for labels in this directory instead",
|
||||
},
|
||||
)
|
||||
label_rate: float = field(
|
||||
default=-1.0,
|
||||
metadata={"help": "label frame rate. -1.0 for sequence label"},
|
||||
)
|
||||
# label_rate: 1,2,2,5
|
||||
# (imply (1,2), (2,5))
|
||||
# if base label_rate = 50
|
||||
# (1,2), (2,5) --> label rates 50, 25, 10
|
||||
label_rate_ratios: List[int] = field(default=MISSING, metadata={"help": "tuple for label rates e.g., [(1,2), (2,5)]"})
|
||||
sample_rate: int = field(
|
||||
default=16_000,
|
||||
metadata={
|
||||
"help": "target sample rate. audio files will be up/down "
|
||||
"sampled to this rate"
|
||||
},
|
||||
)
|
||||
normalize: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
|
||||
)
|
||||
enable_padding: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "pad shorter samples instead of cropping"},
|
||||
)
|
||||
max_keep_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "exclude sample longer than this"},
|
||||
)
|
||||
max_sample_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "max sample size to crop to for batching"},
|
||||
)
|
||||
min_sample_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "min sample size to crop to for batching"},
|
||||
)
|
||||
random_crop: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "always crop from the beginning if false"},
|
||||
)
|
||||
pad_audio: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "pad audio to the longest one in the batch if true"},
|
||||
)
|
||||
|
||||
|
||||
@register_task("multires_hubert_pretraining", dataclass=MultiresHubertPretrainingConfig)
|
||||
class MultiresHubertPretrainingTask(FairseqTask):
|
||||
"""
|
||||
Multiresolution HuBERT Pretraining Task.
|
||||
The task is based on `HubertPretrainingTask` but extended to multiresolution.
|
||||
"""
|
||||
|
||||
cfg: MultiresHubertPretrainingConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: MultiresHubertPretrainingConfig,
|
||||
) -> None:
|
||||
super().__init__(cfg)
|
||||
|
||||
logger.info(f"current directory is {os.getcwd()}")
|
||||
logger.info(f"MultiresHubertPretrainingTask Config {cfg}")
|
||||
|
||||
self.cfg = cfg
|
||||
self.fine_tuning = cfg.fine_tuning
|
||||
|
||||
if cfg.fine_tuning:
|
||||
self.state.add_factory("target_dictionary", self.load_dictionaries)
|
||||
self.res_number = 1
|
||||
else:
|
||||
self.state.add_factory("dictionaries", self.load_dictionaries)
|
||||
|
||||
self.blank_symbol = "<s>"
|
||||
|
||||
@property
|
||||
def source_dictionary(self) -> Optional[Dictionary]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def target_dictionary(self) -> Optional[Dictionary]:
|
||||
return self.state.target_dictionary
|
||||
|
||||
@property
|
||||
def dictionaries(self) -> List[Dictionary]:
|
||||
return self.state.dictionaries
|
||||
|
||||
@classmethod
|
||||
def setup_task(
|
||||
cls, cfg: MultiresHubertPretrainingConfig, **kwargs
|
||||
) -> "MultiresHubertPretrainingTask":
|
||||
return cls(cfg)
|
||||
|
||||
def load_dictionaries(self):
|
||||
label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
|
||||
self.res_number = len(label_dir)
|
||||
dictionaries = [ (Dictionary.load(f"{label_dir}/dict.{label}.txt") if label is not "" else None ) for label in self.cfg.labels]
|
||||
return dictionaries[0] if self.cfg.fine_tuning else dictionaries
|
||||
|
||||
def get_label_dir(self) -> str:
|
||||
if self.cfg.label_dir is None:
|
||||
return self.cfg.data
|
||||
return self.cfg.label_dir
|
||||
|
||||
def load_dataset(self, split: str, **kwargs) -> None:
|
||||
manifest = f"{self.cfg.data}/{split}.tsv"
|
||||
dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
|
||||
pad_list = [(dict.pad() if dict is not None else None) for dict in dicts]
|
||||
eos_list = [(dict.eos() if dict is not None else None) for dict in dicts]
|
||||
procs = [LabelEncoder(dict) for dict in dicts]
|
||||
paths = [(f"{self.get_label_dir()}/{split}.{l}" if l != "" else None) for l in self.cfg.labels]
|
||||
|
||||
base_rate = self.cfg.label_rate
|
||||
self.label_rates = [base_rate]
|
||||
label_rate_ratios = self.cfg.label_rate_ratios
|
||||
self.label_rate_ratios = []
|
||||
for i in range(len(label_rate_ratios) // 2):
|
||||
|
||||
upsample_rate, downsample_rate = label_rate_ratios[i * 2], label_rate_ratios[i * 2 + 1]
|
||||
# parse label rate ratios
|
||||
self.label_rate_ratios.append((upsample_rate, downsample_rate))
|
||||
base_rate = base_rate * upsample_rate // downsample_rate
|
||||
self.label_rates.append(base_rate)
|
||||
|
||||
# hubert v1: pad_audio=True, random_crop=False;
|
||||
self.datasets[split] = HubertDataset(
|
||||
manifest,
|
||||
sample_rate=self.cfg.sample_rate,
|
||||
label_paths=paths,
|
||||
label_rates=self.label_rates,
|
||||
pad_list=pad_list,
|
||||
eos_list=eos_list,
|
||||
label_processors=procs,
|
||||
max_keep_sample_size=self.cfg.max_keep_size,
|
||||
min_keep_sample_size=self.cfg.min_sample_size,
|
||||
max_sample_size=self.cfg.max_sample_size,
|
||||
pad_audio=self.cfg.pad_audio,
|
||||
normalize=self.cfg.normalize,
|
||||
store_labels=False,
|
||||
random_crop=self.cfg.random_crop,
|
||||
)
|
||||
|
||||
def max_positions(self) -> Tuple[int, int]:
|
||||
return (sys.maxsize, sys.maxsize)
|
||||
|
||||
def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
|
||||
return indices
|
@ -11,8 +11,6 @@ import unittest
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from fairseq import checkpoint_utils
|
||||
from tests.utils import (
|
||||
create_dummy_data,
|
||||
|
172
tests/test_checkpoint_utils_for_task_level_attributes.py
Normal file
172
tests/test_checkpoint_utils_for_task_level_attributes.py
Normal file
@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env fbpython
|
||||
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import unittest
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, data
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def mock_trainer(epoch, num_updates, iterations_in_epoch):
|
||||
trainer = MagicMock()
|
||||
trainer.load_checkpoint.return_value = {
|
||||
"train_iterator": {
|
||||
"epoch": epoch,
|
||||
"iterations_in_epoch": iterations_in_epoch,
|
||||
"shuffle": False,
|
||||
},
|
||||
"FakeTask": checkpoint_dict()["FakeTask"],
|
||||
}
|
||||
trainer.get_num_updates.return_value = num_updates
|
||||
trainer.task.__class__.__name__ = "FakeTask"
|
||||
trainer.task.get_checkpoint_dict.return_value = checkpoint_dict()
|
||||
trainer.task.set_checkpoint_dict = MagicMock()
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
def checkpoint_dict():
|
||||
return {
|
||||
"FakeTask": {
|
||||
"observer_stats": {
|
||||
(
|
||||
4,
|
||||
16,
|
||||
"MovingAveragePerChannelMinMax",
|
||||
"MovingAveragePerChannelMinMax",
|
||||
): {"mod1": 1, "mod2": 2, "mod3": 3}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def mock_dict():
|
||||
d = MagicMock()
|
||||
d.pad.return_value = 1
|
||||
d.eos.return_value = 2
|
||||
d.unk.return_value = 3
|
||||
return d
|
||||
|
||||
|
||||
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
|
||||
tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
|
||||
tokens_ds = data.TokenBlockDataset(
|
||||
tokens,
|
||||
sizes=[tokens.size(-1)],
|
||||
block_size=1,
|
||||
pad=0,
|
||||
eos=1,
|
||||
include_targets=False,
|
||||
)
|
||||
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
|
||||
dataset = data.LanguagePairDataset(
|
||||
tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False
|
||||
)
|
||||
epoch_itr = data.EpochBatchIterator(
|
||||
dataset=dataset,
|
||||
collate_fn=dataset.collater,
|
||||
batch_sampler=[[i] for i in range(epoch_size)],
|
||||
)
|
||||
return trainer, epoch_itr
|
||||
|
||||
|
||||
def get_mock_cfg(finetune_from_model):
|
||||
cfg_mock = OmegaConf.create(
|
||||
{
|
||||
"checkpoint": {
|
||||
"save_dir": None,
|
||||
"optimizer_overrides": "{}",
|
||||
"reset_dataloader": False,
|
||||
"reset_meters": False,
|
||||
"reset_optimizer": False,
|
||||
"reset_lr_scheduler": False,
|
||||
"finetune_from_model": finetune_from_model,
|
||||
"model_parallel_size": 1,
|
||||
"restore_file": "checkpoint_last.pt",
|
||||
"no_save": False,
|
||||
"save_interval_updates": 0,
|
||||
"no_last_checkpoints": False,
|
||||
"keep_interval_updates": 0,
|
||||
"keep_last_epochs": 0,
|
||||
"keep_best_checkpoints": 0,
|
||||
},
|
||||
"common": {
|
||||
"model_parallel_size": 1,
|
||||
},
|
||||
}
|
||||
)
|
||||
return cfg_mock
|
||||
|
||||
|
||||
class TestCheckpointsForTaskLevelAttributes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.cfg_mock = get_mock_cfg(None)
|
||||
self.patches = {
|
||||
"os.makedirs": MagicMock(),
|
||||
"os.path.join": MagicMock(),
|
||||
"os.path.isfile": MagicMock(return_value=True),
|
||||
"os.path.isabs": MagicMock(return_value=False),
|
||||
"fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
|
||||
}
|
||||
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
|
||||
[p.start() for p in self.applied_patches]
|
||||
logging.disable(logging.CRITICAL)
|
||||
|
||||
self.trainer, self.epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
|
||||
self.trainer.get_train_iterator = MagicMock(return_value=self.epoch_itr)
|
||||
self.epoch_itr.next_epoch_itr(shuffle=False)
|
||||
|
||||
checkpoint_utils.save_checkpoint(
|
||||
self.cfg_mock.checkpoint, self.trainer, self.epoch_itr, None
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
patch.stopall()
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
def test_verify_checkpoint(self) -> None:
|
||||
cp_dict = self.trainer.task.get_checkpoint_dict()
|
||||
self.assertTrue(len(cp_dict) == 1)
|
||||
self.assertTrue("FakeTask" in cp_dict)
|
||||
self.assertTrue("observer_stats" in cp_dict["FakeTask"])
|
||||
self.assertTrue(len(cp_dict["FakeTask"]["observer_stats"]) == 1)
|
||||
self.assertTrue(
|
||||
(
|
||||
4,
|
||||
16,
|
||||
"MovingAveragePerChannelMinMax",
|
||||
"MovingAveragePerChannelMinMax",
|
||||
)
|
||||
in cp_dict["FakeTask"]["observer_stats"]
|
||||
)
|
||||
self.assertTrue(
|
||||
cp_dict["FakeTask"]["observer_stats"][
|
||||
(
|
||||
4,
|
||||
16,
|
||||
"MovingAveragePerChannelMinMax",
|
||||
"MovingAveragePerChannelMinMax",
|
||||
)
|
||||
]
|
||||
== {"mod1": 1, "mod2": 2, "mod3": 3}
|
||||
)
|
||||
|
||||
def test_load_checkpoint(self) -> None:
|
||||
with contextlib.redirect_stdout(StringIO()):
|
||||
# Now, load checkpoint to ensure the respective logic works as expected
|
||||
_, epoch_itr = checkpoint_utils.load_checkpoint(
|
||||
self.cfg_mock.checkpoint, self.trainer
|
||||
)
|
||||
|
||||
self.trainer.task.set_checkpoint_dict.assert_called_once_with(
|
||||
checkpoint_dict()["FakeTask"]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user