mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-09-11 17:25:31 +03:00
parent
bfd9dc6d27
commit
728b947019
63
examples/mms/MODEL_CARD.md
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# MMS Model Card
|
||||||
|
|
||||||
|
## Model details
|
||||||
|
|
||||||
|
**Organization developing the model** The FAIR team of Meta AI.
|
||||||
|
|
||||||
|
**Model version** This is version 1 of the model.
|
||||||
|
|
||||||
|
**Model type** MMS is a speech model based on the transformer architecture. The pre-trained model comes in two sizes: 300M and 1B parameters. We fine-tune the model for speech recognition and make it available in the 1B variant. We also fine-tune the 1B variant for language identification.
|
||||||
|
|
||||||
|
**License** CC BY-NC
|
||||||
|
|
||||||
|
**Where to send questions or comments about the model** Questions and comments about MMS can be sent via the [GitHub repository](https://github.com/pytorch/fairseq/tree/master/examples/mms) of the project, by opening an issue and tagging it as MMS.
|
||||||
|
|
||||||
|
## Uses
|
||||||
|
|
||||||
|
**Primary intended uses** The primary use of MMS is to perform speech processing research for many more languages and to perform tasks such as automatic speech recognition, language identification, and speech synthesis.
|
||||||
|
|
||||||
|
**Primary intended users** The primary intended users of the model are researchers in speech processing, machine learning and artificial intelligence.
|
||||||
|
|
||||||
|
**Out-of-scope use cases** Fine-tuning the pre-trained models on other labeled datasets or downstream tasks requires further risk evaluation and mitigation.
|
||||||
|
|
||||||
|
## Bias and Risks
|
||||||
|
|
||||||
|
The MMS models were pre-trained on a blend of data from different domains, including readings of the New Testament. In the paper, we describe two studies analyzing gender bias and the use of religious language which conclude that models perform equally well for both genders and that on average, there is little bias for religious language (section 8 of the paper).
|
||||||
|
|
||||||
|
# Training Details
|
||||||
|
|
||||||
|
## Training Data
|
||||||
|
|
||||||
|
MMS is pre-trained on VoxPopuli (parliamentary speech), MLS (read audiobooks), VoxLingua-107 (YouTube speech), CommonVoice (read Wikipedia text), BABEL (telephone conversations), MMS-lab-U (New Testament readings), and MMS-unlab (various read Christian texts).
|
||||||
|
Models are fine-tuned on FLEURS, VoxLingua-107, MLS, CommonVoice, and MMS-lab. We obtained the language information for MMS-lab, MMS-lab-U and MMS-unlab from our data source and did not manually verify it for every language.
|
||||||
|
|
||||||
|
## Training Procedure
|
||||||
|
|
||||||
|
Please refer to the research paper for details on this.
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
|
||||||
|
## Testing Data, Factors & Metrics
|
||||||
|
|
||||||
|
We evaluate the model on different benchmarks for the downstream tasks. The evaluation details are presented in the paper. The model's performance is measured using standard metrics such as character error rate, word error rate, and classification accuracy.
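
As a rough illustration of the error-rate metrics named above, the sketch below computes word and character error rate as an edit distance divided by the reference length. The helper names (`edit_distance`, `wer`, `cer`) are illustrative assumptions and are not part of the MMS code; the evaluation in the paper may apply text normalization that is not shown here, and the reference is assumed to be non-empty.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def wer(reference, hypothesis):
    """Word error rate: word-level edit distance over the number of reference words."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def cer(reference, hypothesis):
    """Character error rate: character-level edit distance over the reference length."""
    return edit_distance(list(reference), list(hypothesis)) / len(reference)
```

For example, `wer("this is one", "this is two")` counts one substituted word out of three reference words.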
|
||||||
|
|
||||||
|
|
||||||
|
# Citation
|
||||||
|
|
||||||
|
**BibTeX:**
|
||||||
|
|
||||||
|
```
|
||||||
|
@article{pratap2023mms,
|
||||||
|
title={Scaling Speech Technology to 1,000+ Languages},
|
||||||
|
author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
|
||||||
|
journal={arXiv},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
# Model Card Contact
|
||||||
|
|
||||||
|
Please reach out to the authors at: [vineelkpratap@meta.com](mailto:vineelkpratap@meta.com), [androstj@meta.com](mailto:androstj@meta.com), [bshi@meta.com](mailto:bshi@meta.com), [michaelauli@meta.com](mailto:michaelauli@gmail.com)
|
||||||
|
|
||||||
|
|
175
examples/mms/README.md
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
# MMS: Scaling Speech Technology to 1000+ languages
|
||||||
|
|
||||||
|
The Massively Multilingual Speech (MMS) project expands speech technology from about 100 languages to over 1,000 by building a single multilingual speech recognition model supporting over 1,100 languages (more than 10 times as many as before), language identification models able to identify over [4,000 languages](https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html) (40 times more than before), pretrained models supporting over 1,400 languages, and text-to-speech models for over 1,100 languages. Our goal is to make it easier for people to access information and to use devices in their preferred language.
|
||||||
|
|
||||||
|
You can find details in the paper [Scaling Speech Technology to 1000+ languages](https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/) and the [blog post](https://ai.facebook.com/blog/multilingual-speech-recognition-model/).
|
||||||
|
|
||||||
|
An overview of the languages covered by MMS can be found [here](https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html).
|
||||||
|
|
||||||
|
|
||||||
|
## Pretrained models
|
||||||
|
|
||||||
|
| Model | Link |
|---|---|
| MMS-300M | [download](https://dl.fbaipublicfiles.com/mms/pretraining/base_300m.pt) |
| MMS-1B | [download](https://dl.fbaipublicfiles.com/mms/pretraining/base_1b.pt) |
|
||||||
|
|
||||||
|
Example commands to finetune the pretrained models can be found [here](https://github.com/fairinternal/fairseq-py/tree/mms_release/examples/wav2vec#fine-tune-a-pre-trained-model-with-ctc).
|
||||||
|
|
||||||
|
## Finetuned models
|
||||||
|
### ASR
|
||||||
|
|
||||||
|
| Model | Languages | Dataset | Checkpoint | Supported languages |
|---|---|---|---|---|
| MMS-1B:FL102 | 102 | FLEURS | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_fl102.pt) | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_fl102_langs.html) |
| MMS-1B:L1107 | 1107 | MMS-lab | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_l1107.pt) | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_l1107_langs.html) |
| MMS-1B-all | 1162 | MMS-lab + FLEURS <br>+ CV + VP + MLS | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_all.pt) | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_all_langs.html) |
|
||||||
|
|
||||||
|
### TTS
|
||||||
|
1. Download the list of [iso codes](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) of 1107 languages.
|
||||||
|
2. Find the iso code of the target language and download the checkpoint. Each folder contains 3 files: `G_100000.pth`, `config.json`, `vocab.txt`. The `G_100000.pth` is the generator trained for 100K updates, `config.json` is the training config, `vocab.txt` is the vocabulary for the TTS model.
|
||||||
|
```
|
||||||
|
# Examples:
|
||||||
|
wget https://dl.fbaipublicfiles.com/mms/tts/eng.tar.gz # English (eng)
|
||||||
|
wget https://dl.fbaipublicfiles.com/mms/tts/azj-script_latin.tar.gz # North Azerbaijani (azj-script_latin)
|
||||||
|
```
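
The download step can also be scripted. The following is a minimal sketch, assuming the URL pattern seen in the two `wget` examples holds for every language code in the list; the helper names `tts_checkpoint_url` and `missing_checkpoint_files` are hypothetical and not part of the repository.

```python
import os

# Base URL taken from the wget examples; assumed to apply to all ISO codes.
TTS_BASE_URL = "https://dl.fbaipublicfiles.com/mms/tts"
# The three files each extracted checkpoint folder is described as containing.
EXPECTED_FILES = ("G_100000.pth", "config.json", "vocab.txt")


def tts_checkpoint_url(iso_code):
    """Build the archive URL for a code such as 'eng' or 'azj-script_latin'."""
    return f"{TTS_BASE_URL}/{iso_code}.tar.gz"


def missing_checkpoint_files(model_dir):
    """Return the expected files that are absent from an extracted checkpoint folder."""
    return [name for name in EXPECTED_FILES
            if not os.path.isfile(os.path.join(model_dir, name))]
```

After extracting an archive, an empty list from `missing_checkpoint_files("/path/to/model/eng")` suggests the folder is complete.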
|
||||||
|
|
||||||
|
### LID
|
||||||
|
|
||||||
|
| # Languages | Dataset | Model | Dictionary | Supported languages |
|---|---|---|---|---|
| 126 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l126.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l126/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l126_langs.html) |
| 256 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l256.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l256/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l256_langs.html) |
| 512 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l512.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l512/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l512_langs.html) |
| 1024 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l1024.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l1024/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l1024_langs.html) |
| 2048 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l2048.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l2048/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l2048_langs.html) |
| 4017 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l4017.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l4017/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l4017_langs.html) |
|
||||||
|
|
||||||
|
## Commands to run inference
|
||||||
|
|
||||||
|
### ASR
|
||||||
|
Run this command to transcribe one or more audio files:
|
||||||
|
```shell command
|
||||||
|
cd /path/to/fairseq-py/
|
||||||
|
python examples/mms/asr/infer/mms_infer.py --model "/path/to/asr/model" --lang lang_code --audio "/path/to/audio_1.wav" "/path/to/audio_2.wav"
|
||||||
|
```
|
||||||
|
|
||||||
|
For more advanced configuration and to calculate CER/WER, prepare a manifest folder with the following layout:
|
||||||
|
```
|
||||||
|
$ ls /path/to/manifest
|
||||||
|
dev.tsv
|
||||||
|
dev.wrd
|
||||||
|
dev.ltr
|
||||||
|
dev.uid
|
||||||
|
|
||||||
|
# dev.tsv each line contains <audio> <number_of_sample>
|
||||||
|
$ cat dev.tsv
|
||||||
|
/
|
||||||
|
/path/to/audio_1 180000
|
||||||
|
/path/to/audio_2 200000
|
||||||
|
|
||||||
|
$ cat dev.ltr
|
||||||
|
t h i s | i s | o n e |
|
||||||
|
t h i s | i s | t w o |
|
||||||
|
|
||||||
|
$ cat dev.wrd
|
||||||
|
this is one
|
||||||
|
this is two
|
||||||
|
|
||||||
|
$ cat dev.uid
|
||||||
|
audio_1
|
||||||
|
audio_2
|
||||||
|
```
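
The manifest layout above can be generated with a short script. This is a sketch under stated assumptions: the helper `write_manifest` is hypothetical, the `.tsv` columns are assumed to be tab-separated (as `mms_infer.py` writes them), and the utterance ID is assumed to be the audio file name without its extension.

```python
from pathlib import Path


def write_manifest(manifest_dir, entries, split="dev"):
    """Write <split>.tsv/.wrd/.ltr/.uid for (audio_path, num_samples, transcript) entries."""
    out = Path(manifest_dir)
    out.mkdir(parents=True, exist_ok=True)
    tsv, wrd, ltr, uid = ["/"], [], [], []  # the first .tsv line is the root directory
    for audio_path, num_samples, transcript in entries:
        words = transcript.split()
        tsv.append(f"{audio_path}\t{num_samples}")
        wrd.append(" ".join(words))
        # Letters separated by spaces, with "|" closing every word.
        ltr.append(" ".join(" ".join(word) + " |" for word in words))
        uid.append(Path(audio_path).stem)
    for suffix, lines in (("tsv", tsv), ("wrd", wrd), ("ltr", ltr), ("uid", uid)):
        (out / f"{split}.{suffix}").write_text("\n".join(lines) + "\n", encoding="utf-8")
```

Calling `write_manifest("/path/to/manifest", [("/path/to/audio_1", 180000, "this is one")])` produces files in the same shape as the listing above.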
|
||||||
|
|
||||||
|
Then run the command below:
|
||||||
|
```
|
||||||
|
lang_code=<iso_code>
|
||||||
|
|
||||||
|
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='/path/to/asr/model'" task.data='/path/to/manifest' dataset.gen_subset="${lang_code}:dev" common_eval.post_process=letter
|
||||||
|
|
||||||
|
```
|
||||||
|
Available options:
|
||||||
|
* To get the raw character-based output, set `common_eval.post_process=none`.
* To maximize GPU efficiency or avoid out-of-memory (OOM) errors, tune the `dataset.max_tokens=???` size.
* To run language model decoding, install the flashlight Python bindings using
|
||||||
|
```
|
||||||
|
git clone --recursive git@github.com:flashlight/flashlight.git
|
||||||
|
cd flashlight;
|
||||||
|
git checkout 035ead6efefb82b47c8c2e643603e87d38850076
|
||||||
|
cd bindings/python
|
||||||
|
python3 setup.py install
|
||||||
|
```
|
||||||
|
Train a [KenLM language model](https://github.com/flashlight/wav2letter/tree/main/recipes/rasr#language-model) and prepare a lexicon file in [this](https://dl.fbaipublicfiles.com/wav2letter/rasr/tutorial/lexicon.txt) format.
|
||||||
|
```
|
||||||
|
LANG=<iso> # for example - 'eng', 'azj-script_latin'
|
||||||
|
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py --config-dir=examples/mms/asr/config \
|
||||||
|
--config-name=infer_common decoding.type=kenlm distributed_training.distributed_world_size=1 \
|
||||||
|
decoding.unique_wer_file=true decoding.beam=500 decoding.beamsizetoken=50 \
|
||||||
|
task.data=<MANIFEST_FOLDER_PATH> common_eval.path='<MODEL_PATH.pt>' decoding.lexicon=<LEXICON_FILE> decoding.lmpath=<LM_FILE> \
|
||||||
|
decoding.results_path=<OUTPUT_DIR> dataset.gen_subset=${LANG}:dev decoding.lmweight=??? decoding.wordscore=???
|
||||||
|
```
|
||||||
|
We typically sweep `lmweight` in the range of 0 to 5 and `wordscore` in the range of -3 to 3. The output directory will contain the reference and hypothesis outputs from the decoder.
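
The sweep can be enumerated with a few lines of Python. This is only a sketch: the helper `sweep_overrides` and the integer grid are hypothetical choices covering the ranges mentioned above, and each printed line is meant to replace the `decoding.lmweight=??? decoding.wordscore=???` part of the decoding command.

```python
import itertools


def sweep_overrides(lmweights, wordscores):
    """Yield one pair of override strings per (lmweight, wordscore) combination."""
    for lmweight, wordscore in itertools.product(lmweights, wordscores):
        yield [f"decoding.lmweight={lmweight}", f"decoding.wordscore={wordscore}"]


# Hypothetical coarse grid; finer steps may be worth trying around the best values.
LMWEIGHTS = [0, 1, 2, 3, 4, 5]
WORDSCORES = [-3, -2, -1, 0, 1, 2, 3]

for overrides in sweep_overrides(LMWEIGHTS, WORDSCORES):
    print(" ".join(overrides))
```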
|
||||||
|
|
||||||
|
For decoding with character-based language models, use an empty lexicon file (`decoding.lexicon=`), set `decoding.unitlm=True`, and sweep over `decoding.silweight` instead of `wordscore`.
|
||||||
|
|
||||||
|
### TTS
|
||||||
|
Note: clone and install [VITS](https://github.com/jaywalnut310/vits) before running inference.
|
||||||
|
```shell script
|
||||||
|
## English TTS
|
||||||
|
$ PYTHONPATH=$PYTHONPATH:/path/to/vits python examples/mms/tts/infer.py --model-dir /path/to/model/eng \
|
||||||
|
--wav ./example.wav --txt "Expanding the language coverage of speech technology \
|
||||||
|
has the potential to improve access to information for many more people"
|
||||||
|
|
||||||
|
## Maithili TTS
|
||||||
|
$ PYTHONPATH=$PYTHONPATH:/path/to/vits python examples/mms/tts/infer.py --model-dir /path/to/model/mai \
|
||||||
|
--wav ./example.wav --txt "मुदा आइ धरि ई तकनीक सौ सं किछु बेसी भाषा तक सीमित छल जे सात हजार \
|
||||||
|
सं बेसी ज्ञात भाषाक एकटा अंश अछी"
|
||||||
|
```
|
||||||
|
`example.wav` contains synthesized audio for the language.
|
||||||
|
|
||||||
|
|
||||||
|
### LID
|
||||||
|
|
||||||
|
|
||||||
|
Prepare two files in this format:
|
||||||
|
```
|
||||||
|
#/path/to/manifest.tsv
|
||||||
|
/
|
||||||
|
/path/to/audio1.wav
|
||||||
|
/path/to/audio2.wav
|
||||||
|
/path/to/audio3.wav
|
||||||
|
|
||||||
|
# /path/to/manifest.lang
|
||||||
|
eng 1
|
||||||
|
eng 1
|
||||||
|
eng 1
|
||||||
|
```
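
The two files can be written with a small helper. This is a sketch under assumptions: `write_lid_manifest` is a hypothetical name, and each `.lang` line simply copies the `<code> 1` pattern from the example above, whose trailing `1` is not explained in this README.

```python
from pathlib import Path


def write_lid_manifest(prefix, audio_paths, lang="eng"):
    """Write <prefix>.tsv (audio list) and <prefix>.lang (one label line per audio file)."""
    tsv_lines = ["/"] + [str(path) for path in audio_paths]  # first line is the root directory
    lang_lines = [f"{lang} 1" for _ in audio_paths]          # "<code> 1", copied from the example
    Path(f"{prefix}.tsv").write_text("\n".join(tsv_lines) + "\n", encoding="utf-8")
    Path(f"{prefix}.lang").write_text("\n".join(lang_lines) + "\n", encoding="utf-8")
```

For instance, `write_lid_manifest("/path/to/manifest", ["/path/to/audio1.wav", "/path/to/audio2.wav"])` creates `/path/to/manifest.tsv` and `/path/to/manifest.lang`.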
|
||||||
|
|
||||||
|
Download the model and the corresponding dictionary file for the LID model. The following command assumes there is a file named `dict.lang.txt` in `/path/to/dict/l126/`.

Use the following command to run inference:
|
||||||
|
```shell script
|
||||||
|
$ PYTHONPATH='.' python3 examples/mms/lid/infer.py /path/to/dict/l126/ --path /path/to/models/mms1b_l126.pt \
|
||||||
|
--task audio_classification --infer-manifest /path/to/manifest.tsv --output-path <OUTDIR>
|
||||||
|
```
|
||||||
|
`<OUTDIR>/predictions.txt` will contain the predictions from the model for the audio files in `manifest.tsv`.
|
||||||
|
|
||||||
|
|
||||||
|
# License
|
||||||
|
|
||||||
|
The MMS code and model weights are released under the CC-BY-NC 4.0 license.
|
||||||
|
|
||||||
|
# Citation
|
||||||
|
|
||||||
|
**BibTeX:**
|
||||||
|
|
||||||
|
```
|
||||||
|
@article{pratap2023mms,
|
||||||
|
title={Scaling Speech Technology to 1,000+ Languages},
|
||||||
|
author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
|
||||||
|
journal={arXiv},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
32
examples/mms/asr/config/infer_common.yaml
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# @package _global_
# defaults:
#   - hydra/launcher: submitit_slurm

# @package _group_

task:
  _name: audio_finetuning
  data: null
  labels: ltr
common_eval:
  path: null
  post_process: letter
  # model_overrides: "{'task':{'multi_corpus_keys':None}}"
decoding:
  type: viterbi
  lexicon: null
  unique_wer_file: false
  results_path: null
distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 1
hydra:
  run:
    dir: ${common_eval.results_path}/${dataset.gen_subset}
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path}
    subdir: ${dataset.gen_subset}
dataset:
  max_tokens: 2_000_000
  gen_subset: dev
  required_batch_size_multiple: 1
|
3
examples/mms/asr/infer/example_infer_adapter.sh
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
lang="$1"
|
||||||
|
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='/fsx-wav2vec/androstj/exps/wav2vec/mms/v4/finetune/xl1b_d5_dfls_0_0.3_u300k__ft_on_d5_127_dbeta1/ft_smax_adp_common.seed:1__dataset.max_tokens:2880000__optimization.lr:[0.001]__optimization.max_update:4000__merged_ckpt/checkpoints/checkpoint_last.pt'" task.data=/fsx-wav2vec/androstj/dataset/v4/fl/fseq dataset.gen_subset="${lang}:${lang}/dev" common_eval.post_process=none
|
52
examples/mms/asr/infer/mms_infer.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
#!/usr/bin/env python -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import soundfile as sf
import tempfile
from pathlib import Path
import os
import subprocess
import sys
import re

def parser():
    parser = argparse.ArgumentParser(description="ASR inference script for MMS model")
    parser.add_argument("--model", type=str, help="path to ASR model", required=True)
    parser.add_argument("--audio", type=str, help="path to audio file", required=True, nargs='+')
    parser.add_argument("--lang", type=str, help="audio language", required=True)
    parser.add_argument("--format", type=str, choices=["none", "letter"], default="letter")
    return parser.parse_args()

def process(args):
    with tempfile.TemporaryDirectory() as tmpdir:
        print(">>> preparing tmp manifest dir ...", file=sys.stderr)
        tmpdir = Path(tmpdir)
        # dev.tsv: root directory on the first line, then "<audio>\t<num_samples>" per file
        with open(tmpdir / "dev.tsv", "w") as fw:
            fw.write("/\n")
            for audio in args.audio:
                nsample = sf.SoundFile(audio).frames
                fw.write(f"{audio}\t{nsample}\n")
        # dev.uid: one ID per input file, in the same order as dev.tsv
        with open(tmpdir / "dev.uid", "w") as fw:
            fw.writelines(f"{audio}\n" for audio in args.audio)
        # dev.ltr / dev.wrd: placeholder references, only needed to satisfy the data loader
        with open(tmpdir / "dev.ltr", "w") as fw:
            fw.write("d u m m y | d u m m y\n"*len(args.audio))
        with open(tmpdir / "dev.wrd", "w") as fw:
            fw.write("dummy dummy\n"*len(args.audio))
        cmd = f"""
        PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir}
        """
        print(">>> loading model & running inference ...", file=sys.stderr)
        subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL,)
        with open(tmpdir/"hypo.word") as fr:
            for ii, hypo in enumerate(fr):
                # strip the trailing "(<id>)" marker from each hypothesis line
                hypo = re.sub(r"\(\S+\)$", "", hypo).strip()
                print(f'===============\nInput: {args.audio[ii]}\nOutput: {hypo}')


if __name__ == "__main__":
    args = parser()
    process(args)
|
47
examples/mms/data_prep/README.md
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# Data Preparation
|
||||||
|
|
||||||
|
We describe the process of aligning long audio files with their transcripts and generating shorter audio segments below.
|
||||||
|
|
||||||
|
- Step 1: Download and install torchaudio using the nightly version. We have open sourced the CTC forced alignment algorithm described in our paper via [torchaudio](https://github.com/pytorch/audio/pull/3348).
|
||||||
|
```
|
||||||
|
pip install --pre torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
|
||||||
|
```
|
||||||
|
|
||||||
|
- Step 2: Download [uroman](https://github.com/isi-nlp/uroman) from GitHub. It is a universal romanizer which converts text in any script to the Latin alphabet. Use [this link](https://www.isi.edu/~ulf/uroman.html) to try their web interface.
|
||||||
|
```
|
||||||
|
git clone git@github.com:isi-nlp/uroman.git
|
||||||
|
```
|
||||||
|
|
||||||
|
- Step 3: Install a few other dependencies
|
||||||
|
```
|
||||||
|
pip install sox
|
||||||
|
pip install dataclasses
|
||||||
|
```
|
||||||
|
|
||||||
|
- Step 4: Create a text file containing the transcript for a (long) audio file. Each line in the text file will correspond to a separate audio segment that will be generated upon alignment.
|
||||||
|
|
||||||
|
Example content of the input text file :
|
||||||
|
```
|
||||||
|
Text of the desired first segment
|
||||||
|
Text of the desired second segment
|
||||||
|
Text of the desired third segment
|
||||||
|
```
|
||||||
|
|
||||||
|
- Step 5: Run forced alignment and segment the audio file into shorter segments.
|
||||||
|
```
|
||||||
|
python align_and_segment.py --audio_filepath /path/to/audio.wav --text_filepath /path/to/textfile --lang <iso> --outdir /path/to/output --uroman_path /path/to/uroman/bin
|
||||||
|
```
|
||||||
|
|
||||||
|
The above command will generate the audio segments under the output directory, one per line of the input text file. It also writes a `manifest.json` file listing the segmented audio file paths and their corresponding transcripts.
|
||||||
|
|
||||||
|
```
|
||||||
|
> head /path/to/output/manifest.json
|
||||||
|
|
||||||
|
{"audio_start_sec": 0.0, "audio_filepath": "/path/to/output/segment1.flac", "duration": 6.8, "text": "she wondered afterwards how she could have spoken with that hard serenity how she could have", "normalized_text": "she wondered afterwards how she could have spoken with that hard serenity how she could have", "uroman_tokens": "s h e w o n d e r e d a f t e r w a r d s h o w s h e c o u l d h a v e s p o k e n w i t h t h a t h a r d s e r e n i t y h o w s h e c o u l d h a v e"}
|
||||||
|
{"audio_start_sec": 6.8, "audio_filepath": "/path/to/output/segment2.flac", "duration": 5.3, "text": "gone steadily on with story after story poem after poem till", "normalized_text": "gone steadily on with story after story poem after poem till", "uroman_tokens": "g o n e s t e a d i l y o n w i t h s t o r y a f t e r s t o r y p o e m a f t e r p o e m t i l l"}
|
||||||
|
{"audio_start_sec": 12.1, "audio_filepath": "/path/to/output/segment3.flac", "duration": 5.9, "text": "allan's grip on her hands relaxed and he fell into a heavy tired sleep", "normalized_text": "allan's grip on her hands relaxed and he fell into a heavy tired sleep", "uroman_tokens": "a l l a n ' s g r i p o n h e r h a n d s r e l a x e d a n d h e f e l l i n t o a h e a v y t i r e d s l e e p"}
|
||||||
|
```
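
Since `manifest.json` holds one JSON object per line, it can be inspected with the standard library. The sketch below is illustrative only: the helpers `read_manifest`, `total_duration`, and `find_gaps` are hypothetical names, and the 0.05 s tolerance is an arbitrary assumption.

```python
import json


def read_manifest(lines):
    """Parse JSON-lines manifest entries, skipping blank lines."""
    return [json.loads(line) for line in lines if line.strip()]


def total_duration(entries):
    """Sum of the segment durations in seconds."""
    return sum(entry["duration"] for entry in entries)


def find_gaps(entries, tolerance=0.05):
    """Return (previous_end, next_start) pairs where consecutive segments do not touch."""
    gaps = []
    for prev, nxt in zip(entries, entries[1:]):
        prev_end = prev["audio_start_sec"] + prev["duration"]
        if abs(nxt["audio_start_sec"] - prev_end) > tolerance:
            gaps.append((prev_end, nxt["audio_start_sec"]))
    return gaps
```

A typical use would be `entries = read_manifest(open("/path/to/output/manifest.json"))`, followed by `total_duration(entries)` to see how much audio was segmented.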
|
||||||
|
|
||||||
|
To visualize the segmented audio files, the [Speech Data Explorer](https://github.com/NVIDIA/NeMo/tree/main/tools/speech_data_explorer) tool from the NeMo toolkit can be used.
|
||||||
|
|
||||||
|
As our alignment model outputs uroman tokens for input audio in any language, it also works with non-English audio and their corresponding transcripts.
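
To make the `uroman_tokens` field concrete, here is a minimal sketch of the character tokenization applied after romanization, mirroring the post-processing in `align_utils.py`. It assumes the input line has already been romanized (the call to the `uroman.pl` script is omitted), and the function name `to_uroman_tokens` is hypothetical.

```python
import re


def to_uroman_tokens(romanized_line):
    """Turn an already-romanized line into space-separated character tokens."""
    chars = " ".join(romanized_line.strip())      # put a space between every character
    chars = re.sub(r"\s+", " ", chars).strip()    # original word boundaries collapse away
    chars = chars.lower()
    chars = re.sub(r"[^a-z' ]", " ", chars)       # keep only a-z, apostrophe, and space
    return re.sub(r" +", " ", chars).strip()
```

Note that word boundaries are not preserved: every remaining character becomes its own token, which matches the `uroman_tokens` strings in the manifest example.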
|
187
examples/mms/data_prep/align_and_segment.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
import os
import torch
import torchaudio
import sox
import json
import argparse


from examples.mms.data_prep.text_normalization import text_normalize
from examples.mms.data_prep.align_utils import (
    get_uroman_tokens,
    time_to_frame,
    load_model_dict,
    merge_repeats,
    get_spans,
)
import torchaudio.functional as F

SAMPLING_FREQ = 16000
EMISSION_INTERVAL = 30
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def generate_emissions(model, audio_file):
    waveform, _ = torchaudio.load(audio_file)  # waveform: channels X T
    waveform = waveform.to(DEVICE)
    total_duration = sox.file_info.duration(audio_file)

    audio_sf = sox.file_info.sample_rate(audio_file)
    assert audio_sf == SAMPLING_FREQ

    emissions_arr = []
    with torch.inference_mode():
        i = 0
        # Process the audio in EMISSION_INTERVAL-second chunks, each padded with extra context
        while i < total_duration:
            segment_start_time, segment_end_time = (i, i + EMISSION_INTERVAL)

            context = EMISSION_INTERVAL * 0.1
            input_start_time = max(segment_start_time - context, 0)
            input_end_time = min(segment_end_time + context, total_duration)
            waveform_split = waveform[
                :,
                int(SAMPLING_FREQ * input_start_time) : int(
                    SAMPLING_FREQ * (input_end_time)
                ),
            ]

            model_outs, _ = model(waveform_split)
            emissions_ = model_outs[0]
            emission_start_frame = time_to_frame(segment_start_time)
            emission_end_frame = time_to_frame(segment_end_time)
            offset = time_to_frame(input_start_time)

            emissions_ = emissions_[
                :, emission_start_frame - offset : emission_end_frame - offset
            ]
            emissions_arr.append(emissions_)
            i += EMISSION_INTERVAL

    emissions = torch.cat(emissions_arr, dim=1).squeeze()
    emissions = torch.log_softmax(emissions, dim=-1)

    stride = float(waveform.size(1) * 1000 / emissions.size(0) / SAMPLING_FREQ)

    return emissions, stride


def get_alignments(
    audio_file,
    tokens,
    model,
    dictionary,
    use_star,
):
    # Generate emissions
    emissions, stride = generate_emissions(model, audio_file)
    T, N = emissions.size()
    if use_star:
        emissions = torch.cat([emissions, torch.zeros(T, 1).to(DEVICE)], dim=1)

    # Force Alignment
    if tokens:
        token_indices = [dictionary[c] for c in " ".join(tokens).split(" ") if c in dictionary]
    else:
        print(f"Empty transcript!!!!! for audio file {audio_file}")
        token_indices = []

    blank = dictionary["<blank>"]

    # torch.tensor (not the legacy torch.Tensor constructor) accepts a non-CPU device
    targets = torch.tensor(token_indices, dtype=torch.int32, device=DEVICE)
    path, _ = F.force_align(emissions, targets, blank=blank)
    path = path.to("cpu").tolist()
    segments = merge_repeats(path, {v: k for k, v in dictionary.items()})
    return segments, stride


def main(args):
    assert not os.path.exists(
        args.outdir
    ), f"Error: Output path exists already {args.outdir}"

    transcripts = []
    with open(args.text_filepath) as f:
        transcripts = [line.strip() for line in f]
    print("Read {} lines from {}".format(len(transcripts), args.text_filepath))

    norm_transcripts = [text_normalize(line.strip(), args.lang) for line in transcripts]
    tokens = get_uroman_tokens(norm_transcripts, args.uroman_path, args.lang)

    model, dictionary = load_model_dict()
    model = model.to(DEVICE)
    if args.use_star:
        dictionary["<star>"] = len(dictionary)
        tokens = ["<star>"] + tokens
        transcripts = ["<star>"] + transcripts
        norm_transcripts = ["<star>"] + norm_transcripts

    segments, stride = get_alignments(
        args.audio_filepath,
        tokens,
        model,
        dictionary,
        args.use_star,
    )
    # Get spans of each line in input text file
    spans = get_spans(tokens, segments)

    os.makedirs(args.outdir)
    with open(f"{args.outdir}/manifest.json", "w") as f:
        for i, t in enumerate(transcripts):
            span = spans[i]
            seg_start_idx = span[0].start
            seg_end_idx = span[-1].end

            output_file = f"{args.outdir}/segment{i}.flac"

            # Frame indices to seconds (stride is in milliseconds per frame)
            audio_start_sec = seg_start_idx * stride / 1000
            audio_end_sec = seg_end_idx * stride / 1000

            tfm = sox.Transformer()
            tfm.trim(audio_start_sec, audio_end_sec)
            tfm.build_file(args.audio_filepath, output_file)

            sample = {
                "audio_start_sec": audio_start_sec,
                "audio_filepath": str(output_file),
                "duration": audio_end_sec - audio_start_sec,
                "text": t,
                "normalized_text": norm_transcripts[i],
                "uroman_tokens": tokens[i],
            }
            f.write(json.dumps(sample) + "\n")

    return segments, stride


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Align and segment long audio files")
    parser.add_argument(
        "-a", "--audio_filepath", type=str, help="Path to input audio file"
    )
    parser.add_argument(
        "-t", "--text_filepath", type=str, help="Path to input text file "
    )
    parser.add_argument(
        "-l", "--lang", type=str, default="eng", help="ISO code of the language"
    )
    parser.add_argument(
        "-u", "--uroman_path", type=str, default="eng", help="Location to uroman/bin"
    )
    parser.add_argument(
        "-s",
        "--use_star",
        action="store_true",
        help="Use star at the start of transcript",
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=str,
        help="Output directory to store segmented audio files",
    )
    print("Using torch version:", torch.__version__)
    print("Using torchaudio version:", torchaudio.__version__)
    print("Using device: ", DEVICE)
    args = parser.parse_args()
    main(args)
|
176
examples/mms/data_prep/align_utils.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
import re
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import tempfile
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from torchaudio.models import wav2vec2_model
|
||||||
|
|
||||||
|
# iso codes with specialized rules in uroman
|
||||||
|
special_isos_uroman = "ara, bel, bul, deu, ell, eng, fas, grc, ell, eng, heb, kaz, kir, lav, lit, mkd, mkd2, oss, pnt, pus, rus, srp, srp2, tur, uig, ukr, yid".split(",")
|
||||||
|
special_isos_uroman = [i.strip() for i in special_isos_uroman]
|
||||||
|
|
||||||
|
def normalize_uroman(text):
|
||||||
|
text = text.lower()
|
||||||
|
text = re.sub("([^a-z' ])", " ", text)
|
||||||
|
text = re.sub(' +', ' ', text)
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def get_uroman_tokens(norm_transcripts, uroman_root_dir, iso=None):
|
||||||
|
tf = tempfile.NamedTemporaryFile()
|
||||||
|
tf2 = tempfile.NamedTemporaryFile()
|
||||||
|
with open(tf.name, "w") as f:
|
||||||
|
for t in norm_transcripts:
|
||||||
|
f.write(t + "\n")
|
||||||
|
|
||||||
|
assert os.path.exists(f"{uroman_root_dir}/uroman.pl"), "uroman not found"
|
||||||
|
cmd = f"perl {uroman_root_dir}/uroman.pl"
|
||||||
|
if iso in special_isos_uroman:
|
||||||
|
cmd += f" -l {iso} "
|
||||||
|
cmd += f" < {tf.name} > {tf2.name}"
|
||||||
|
os.system(cmd)
|
||||||
|
outtexts = []
|
||||||
|
with open(tf2.name) as f:
|
||||||
|
for line in f:
|
||||||
|
line = " ".join(line.strip())
|
||||||
|
line = re.sub(r"\s+", " ", line).strip()
|
||||||
|
outtexts.append(line)
|
||||||
|
assert len(outtexts) == len(norm_transcripts)
|
||||||
|
uromans = []
|
||||||
|
for ot in outtexts:
|
||||||
|
uromans.append(normalize_uroman(ot))
|
||||||
|
return uromans
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Segment:
|
||||||
|
label: str
|
||||||
|
start: int
|
||||||
|
end: int
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.label}: [{self.start:5d}, {self.end:5d})"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def length(self):
|
||||||
|
return self.end - self.start
|
||||||
|
|
||||||
|
|
||||||
|
def merge_repeats(path, idx_to_token_map):
|
||||||
|
i1, i2 = 0, 0
|
||||||
|
segments = []
|
||||||
|
while i1 < len(path):
|
||||||
|
while i2 < len(path) and path[i1] == path[i2]:
|
||||||
|
i2 += 1
|
||||||
|
segments.append(Segment(idx_to_token_map[path[i1]], i1, i2 - 1))
|
||||||
|
i1 = i2
|
||||||
|
return segments
|
||||||
|
|
||||||
|
|
||||||
|
def time_to_frame(time):
|
||||||
|
stride_msec = 20
|
||||||
|
frames_per_sec = 1000 / stride_msec
|
||||||
|
return int(time * frames_per_sec)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_dict():
|
||||||
|
model_path_name = "/tmp/ctc_alignment_mling_uroman_model.pt"
|
||||||
|
|
||||||
|
print("Downloading model and dictionary...")
|
||||||
|
if os.path.exists(model_path_name):
|
||||||
|
print("Model path already exists. Skipping downloading....")
|
||||||
|
else:
|
||||||
|
torch.hub.download_url_to_file(
|
||||||
|
"https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
|
||||||
|
model_path_name,
|
||||||
|
)
|
||||||
|
assert os.path.exists(model_path_name)
|
||||||
|
state_dict = torch.load(model_path_name, map_location="cpu")
|
||||||
|
|
||||||
|
model = wav2vec2_model(
|
||||||
|
extractor_mode="layer_norm",
|
||||||
|
extractor_conv_layer_config=[
|
||||||
|
(512, 10, 5),
|
||||||
|
(512, 3, 2),
|
||||||
|
(512, 3, 2),
|
||||||
|
(512, 3, 2),
|
||||||
|
(512, 3, 2),
|
||||||
|
(512, 2, 2),
|
||||||
|
(512, 2, 2),
|
||||||
|
],
|
||||||
|
extractor_conv_bias=True,
|
||||||
|
encoder_embed_dim=1024,
|
||||||
|
encoder_projection_dropout=0.0,
|
||||||
|
encoder_pos_conv_kernel=128,
|
||||||
|
encoder_pos_conv_groups=16,
|
||||||
|
encoder_num_layers=24,
|
||||||
|
encoder_num_heads=16,
|
||||||
|
encoder_attention_dropout=0.0,
|
||||||
|
encoder_ff_interm_features=4096,
|
||||||
|
encoder_ff_interm_dropout=0.1,
|
||||||
|
encoder_dropout=0.0,
|
||||||
|
encoder_layer_norm_first=True,
|
||||||
|
encoder_layer_drop=0.1,
|
||||||
|
aux_num_out=31,
|
||||||
|
)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dict_path_name = "/tmp/ctc_alignment_mling_uroman_model.dict"
|
||||||
|
if os.path.exists(dict_path_name):
|
||||||
|
print("Dictionary path already exists. Skipping downloading....")
|
||||||
|
else:
|
||||||
|
torch.hub.download_url_to_file(
|
||||||
|
"https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/dictionary.txt",
|
||||||
|
dict_path_name,
|
||||||
|
)
|
||||||
|
assert os.path.exists(dict_path_name)
|
||||||
|
dictionary = {}
|
||||||
|
with open(dict_path_name) as f:
|
||||||
|
dictionary = {l.strip(): i for i, l in enumerate(f.readlines())}
|
||||||
|
|
||||||
|
return model, dictionary
|
||||||
|
|
||||||
|
def get_spans(tokens, segments):
|
||||||
|
ltr_idx = 0
|
||||||
|
tokens_idx = 0
|
||||||
|
intervals = []
|
||||||
|
start, end = (0, 0)
|
||||||
|
sil = "<blank>"
|
||||||
|
for (seg_idx, seg) in enumerate(segments):
|
||||||
|
if tokens_idx == len(tokens):
assert seg_idx == len(segments) - 1
assert seg.label == sil
|
||||||
|
continue
|
||||||
|
cur_token = tokens[tokens_idx].split(' ')
|
||||||
|
ltr = cur_token[ltr_idx]
|
||||||
|
if seg.label == "<blank>": continue
|
||||||
|
assert seg.label == ltr
if ltr_idx == 0: start = seg_idx
|
||||||
|
if ltr_idx == len(cur_token) - 1:
|
||||||
|
ltr_idx = 0
|
||||||
|
tokens_idx += 1
|
||||||
|
intervals.append((start, seg_idx))
|
||||||
|
while tokens_idx < len(tokens) and len(tokens[tokens_idx]) == 0:
|
||||||
|
intervals.append((seg_idx, seg_idx))
|
||||||
|
tokens_idx += 1
|
||||||
|
else:
|
||||||
|
ltr_idx += 1
|
||||||
|
spans = []
|
||||||
|
for (idx, (start, end)) in enumerate(intervals):
|
||||||
|
span = segments[start:end + 1]
|
||||||
|
if start > 0:
|
||||||
|
prev_seg = segments[start - 1]
|
||||||
|
if prev_seg.label == sil:
|
||||||
|
pad_start = prev_seg.start if (idx == 0) else int((prev_seg.start + prev_seg.end)/2)
|
||||||
|
span = [Segment(sil, pad_start, span[0].start)] + span
|
||||||
|
if end+1 < len(segments):
|
||||||
|
next_seg = segments[end+1]
|
||||||
|
if next_seg.label == sil:
|
||||||
|
pad_end = next_seg.end if (idx == len(intervals) - 1) else math.floor((next_seg.start + next_seg.end) / 2)
|
||||||
|
span = span + [Segment(sil, span[-1].end, pad_end)]
|
||||||
|
spans.append(span)
|
||||||
|
return spans
|
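To make the frame bookkeeping in `merge_repeats` and `time_to_frame` concrete, here is a small sketch that re-declares both helpers as they appear above and runs them on a toy path. The token map and the path values are made up for illustration; they are not taken from the real dictionary.

```python
from dataclasses import dataclass


@dataclass
class Segment:
    label: str
    start: int
    end: int  # index of the last frame in the run (inclusive), as in merge_repeats above


def merge_repeats(path, idx_to_token_map):
    # Collapse consecutive identical token ids into one Segment per run.
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1] == path[i2]:
            i2 += 1
        segments.append(Segment(idx_to_token_map[path[i1]], i1, i2 - 1))
        i1 = i2
    return segments


def time_to_frame(time):
    # One output frame every 20 ms, i.e. 50 frames per second.
    stride_msec = 20
    frames_per_sec = 1000 / stride_msec
    return int(time * frames_per_sec)


# Hypothetical token map and frame-level path.
toy_map = {0: "<blank>", 1: "a", 2: "b"}
toy_path = [0, 0, 1, 1, 1, 0, 2]

segments = merge_repeats(toy_path, toy_map)
for seg in segments:
    print(seg.label, seg.start, seg.end)

print(time_to_frame(1.5))
```

The run of three `1` values becomes a single segment covering frames 2 to 4, which shows that `end` stores the last frame of the run rather than one past it.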
277
examples/mms/data_prep/norm_config.py
Normal file
@ -0,0 +1,277 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
colon = ":"
|
||||||
|
comma = ","
|
||||||
|
exclamation_mark = "!"
|
||||||
|
period = re.escape(".")
|
||||||
|
question_mark = re.escape("?")
|
||||||
|
semicolon = ";"
|
||||||
|
|
||||||
|
left_curly_bracket = "{"
|
||||||
|
right_curly_bracket = "}"
|
||||||
|
quotation_mark = '"'
|
||||||
|
|
||||||
|
basic_punc = (
|
||||||
|
period
|
||||||
|
+ question_mark
|
||||||
|
+ comma
|
||||||
|
+ colon
|
||||||
|
+ exclamation_mark
|
||||||
|
+ left_curly_bracket
|
||||||
|
+ right_curly_bracket
|
||||||
|
)
|
||||||
|
|
||||||
|
# General punc unicode block (0x2000-0x206F)
|
||||||
|
zero_width_space = r"\u200B"
|
||||||
|
zero_width_nonjoiner = r"\u200C"
|
||||||
|
left_to_right_mark = r"\u200E"
|
||||||
|
right_to_left_mark = r"\u200F"
|
||||||
|
left_to_right_embedding = r"\u202A"
|
||||||
|
pop_directional_formatting = r"\u202C"
|
||||||
|
|
||||||
|
# Here are some commonly ill-typed versions of apostrophe
|
||||||
|
right_single_quotation_mark = r"\u2019"
|
||||||
|
left_single_quotation_mark = r"\u2018"
|
||||||
|
|
||||||
|
# Language specific definitions
|
||||||
|
# Spanish
|
||||||
|
inverted_exclamation_mark = r"\u00A1"
|
||||||
|
inverted_question_mark = r"\u00BF"
|
||||||
|
|
||||||
|
|
||||||
|
# Hindi
|
||||||
|
hindi_danda = r"\u0964"
|
||||||
|
|
||||||
|
# Egyptian Arabic
|
||||||
|
# arabic_percent = r"\u066A"
|
||||||
|
arabic_comma = r"\u060C"
|
||||||
|
arabic_question_mark = r"\u061F"
|
||||||
|
arabic_semicolon = r"\u061B"
|
||||||
|
arabic_diacritics = r"\u064B-\u0652"
|
||||||
|
|
||||||
|
|
||||||
|
arabic_subscript_alef_and_inverted_damma = r"\u0656-\u0657"
|
||||||
|
|
||||||
|
|
||||||
|
# Chinese
|
||||||
|
full_stop = r"\u3002"
|
||||||
|
full_comma = r"\uFF0C"
|
||||||
|
full_exclamation_mark = r"\uFF01"
|
||||||
|
full_question_mark = r"\uFF1F"
|
||||||
|
full_semicolon = r"\uFF1B"
|
||||||
|
full_colon = r"\uFF1A"
|
||||||
|
full_parentheses = r"\uFF08\uFF09"
|
||||||
|
quotation_mark_horizontal = r"\u300C-\u300F"
|
||||||
|
quotation_mark_vertical = r"\uFE41-\uFE44"
|
||||||
|
title_marks = r"\u3008-\u300B"
|
||||||
|
wavy_low_line = r"\uFE4F"
|
||||||
|
ellipsis = r"\u22EF"
|
||||||
|
enumeration_comma = r"\u3001"
|
||||||
|
hyphenation_point = r"\u2027"
|
||||||
|
forward_slash = r"\uFF0F"
|
||||||
|
wavy_dash = r"\uFF5E"
|
||||||
|
box_drawings_light_horizontal = r"\u2500"
|
||||||
|
fullwidth_low_line = r"\uFF3F"
|
||||||
|
chinese_punc = (
|
||||||
|
full_stop
|
||||||
|
+ full_comma
|
||||||
|
+ full_exclamation_mark
|
||||||
|
+ full_question_mark
|
||||||
|
+ full_semicolon
|
||||||
|
+ full_colon
|
||||||
|
+ full_parentheses
|
||||||
|
+ quotation_mark_horizontal
|
||||||
|
+ quotation_mark_vertical
|
||||||
|
+ title_marks
|
||||||
|
+ wavy_low_line
|
||||||
|
+ ellipsis
|
||||||
|
+ enumeration_comma
|
||||||
|
+ hyphenation_point
|
||||||
|
+ forward_slash
|
||||||
|
+ wavy_dash
|
||||||
|
+ box_drawings_light_horizontal
|
||||||
|
+ fullwidth_low_line
|
||||||
|
)
|
||||||
|
|
||||||
|
# Armenian
|
||||||
|
armenian_apostrophe = r"\u055A"
|
||||||
|
emphasis_mark = r"\u055B"
|
||||||
|
armenian_exclamation_mark = r"\u055C"
armenian_comma = r"\u055D"
armenian_question_mark = r"\u055E"
abbreviation_mark = r"\u055F"
armenian_full_stop = r"\u0589"
armenian_punc = (
armenian_apostrophe
+ emphasis_mark
+ armenian_exclamation_mark
|
||||||
|
+ armenian_comma
|
||||||
|
+ armenian_question_mark
|
||||||
|
+ abbreviation_mark
|
||||||
|
+ armenian_full_stop
|
||||||
|
)
|
||||||
|
|
||||||
|
lesser_than_symbol = r"&lt;"
|
||||||
|
greater_than_symbol = r"&gt;"
|
||||||
|
|
||||||
|
lesser_than_sign = r"\u003c"
|
||||||
|
greater_than_sign = r"\u003e"
|
||||||
|
|
||||||
|
nbsp_written_form = r"&nbsp"
|
||||||
|
|
||||||
|
# Quotation marks
|
||||||
|
left_double_quotes = r"\u201c"
|
||||||
|
right_double_quotes = r"\u201d"
|
||||||
|
left_double_angle = r"\u00ab"
|
||||||
|
right_double_angle = r"\u00bb"
|
||||||
|
left_single_angle = r"\u2039"
|
||||||
|
right_single_angle = r"\u203a"
|
||||||
|
low_double_quotes = r"\u201e"
|
||||||
|
low_single_quotes = r"\u201a"
|
||||||
|
high_double_quotes = r"\u201f"
|
||||||
|
high_single_quotes = r"\u201b"
|
||||||
|
|
||||||
|
all_punct_quotes = (
|
||||||
|
left_double_quotes
|
||||||
|
+ right_double_quotes
|
||||||
|
+ left_double_angle
|
||||||
|
+ right_double_angle
|
||||||
|
+ left_single_angle
|
||||||
|
+ right_single_angle
|
||||||
|
+ low_double_quotes
|
||||||
|
+ low_single_quotes
|
||||||
|
+ high_double_quotes
|
||||||
|
+ high_single_quotes
|
||||||
|
+ right_single_quotation_mark
|
||||||
|
+ left_single_quotation_mark
|
||||||
|
)
|
||||||
|
mapping_quotes = (
|
||||||
|
"["
|
||||||
|
+ high_single_quotes
|
||||||
|
+ right_single_quotation_mark
|
||||||
|
+ left_single_quotation_mark
|
||||||
|
+ "]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Digits
|
||||||
|
|
||||||
|
english_digits = r"\u0030-\u0039"
|
||||||
|
bengali_digits = r"\u09e6-\u09ef"
|
||||||
|
khmer_digits = r"\u17e0-\u17e9"
|
||||||
|
devanagari_digits = r"\u0966-\u096f"
|
||||||
|
oriya_digits = r"\u0b66-\u0b6f"
|
||||||
|
extended_arabic_indic_digits = r"\u06f0-\u06f9"
|
||||||
|
kayah_li_digits = r"\ua900-\ua909"
|
||||||
|
fullwidth_digits = r"\uff10-\uff19"
|
||||||
|
malayam_digits = r"\u0d66-\u0d6f"
|
||||||
|
myanmar_digits = r"\u1040-\u1049"
|
||||||
|
roman_numeral = r"\u2170-\u2179"
|
||||||
|
nominal_digit_shapes = r"\u206f"
|
||||||
|
|
||||||
|
# Load punctuations from MMS-lab data
|
||||||
|
with open(f"{os.path.dirname(__file__)}/punctuations.lst", "r") as punc_f:
|
||||||
|
punc_list = punc_f.readlines()
|
||||||
|
|
||||||
|
punct_pattern = r""
|
||||||
|
for punc in punc_list:
|
||||||
|
# the first field in the tab-separated line is the punctuation character to be removed
|
||||||
|
punct_pattern += re.escape(punc.split("\t")[0])
|
||||||
|
|
||||||
|
shared_digits = (
|
||||||
|
english_digits
|
||||||
|
+ bengali_digits
|
||||||
|
+ khmer_digits
|
||||||
|
+ devanagari_digits
|
||||||
|
+ oriya_digits
|
||||||
|
+ extended_arabic_indic_digits
|
||||||
|
+ kayah_li_digits
|
||||||
|
+ fullwidth_digits
|
||||||
|
+ malayam_digits
|
||||||
|
+ myanmar_digits
|
||||||
|
+ roman_numeral
|
||||||
|
+ nominal_digit_shapes
|
||||||
|
)
|
||||||
|
|
||||||
|
shared_punc_list = (
|
||||||
|
basic_punc
|
||||||
|
+ all_punct_quotes
|
||||||
|
+ greater_than_sign
|
||||||
|
+ lesser_than_sign
|
||||||
|
+ inverted_question_mark
|
||||||
|
+ full_stop
|
||||||
|
+ semicolon
|
||||||
|
+ armenian_punc
|
||||||
|
+ inverted_exclamation_mark
|
||||||
|
+ arabic_comma
|
||||||
|
+ enumeration_comma
|
||||||
|
+ hindi_danda
|
||||||
|
+ quotation_mark
|
||||||
|
+ arabic_semicolon
|
||||||
|
+ arabic_question_mark
|
||||||
|
+ chinese_punc
|
||||||
|
+ punct_pattern
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
shared_mappping = {
|
||||||
|
lesser_than_symbol: "",
|
||||||
|
greater_than_symbol: "",
|
||||||
|
nbsp_written_form: "",
|
||||||
|
r"(\S+)" + mapping_quotes + r"(\S+)": r"\1'\2",
|
||||||
|
}
|
||||||
|
|
||||||
|
shared_deletion_list = (
|
||||||
|
left_to_right_mark
|
||||||
|
+ zero_width_nonjoiner
|
||||||
|
+ arabic_subscript_alef_and_inverted_damma
|
||||||
|
+ zero_width_space
|
||||||
|
+ arabic_diacritics
|
||||||
|
+ pop_directional_formatting
|
||||||
|
+ right_to_left_mark
|
||||||
|
+ left_to_right_embedding
|
||||||
|
)
|
||||||
|
|
||||||
|
norm_config = {
|
||||||
|
"*": {
|
||||||
|
"lower_case": True,
|
||||||
|
"punc_set": shared_punc_list,
|
||||||
|
"del_set": shared_deletion_list,
|
||||||
|
"mapping": shared_mappping,
|
||||||
|
"digit_set": shared_digits,
|
||||||
|
"unicode_norm": "NFKC",
|
||||||
|
"rm_diacritics" : False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#=============== Mongolian ===============#
|
||||||
|
|
||||||
|
norm_config["mon"] = norm_config["*"].copy()
|
||||||
|
# add soft hyphen to the deletion set to match with fleurs
|
||||||
|
norm_config["mon"]["del_set"] += r"\u00AD"
|
||||||
|
|
||||||
|
norm_config["khk"] = norm_config["mon"].copy()
|
||||||
|
|
||||||
|
#=============== Hebrew ===============#
|
||||||
|
|
||||||
|
norm_config["heb"] = norm_config["*"].copy()
|
||||||
|
# add "HEBREW POINT" symbols to match with fleurs
|
||||||
|
norm_config["heb"]["del_set"] += r"\u05B0-\u05BF\u05C0-\u05CF"
|
||||||
|
|
||||||
|
#=============== Thai ===============#
|
||||||
|
|
||||||
|
norm_config["tha"] = norm_config["*"].copy()
|
||||||
|
# add "Zero width joiner" symbols to match with fleurs
|
||||||
|
norm_config["tha"]["punc_set"] += r"\u200D"
|
||||||
|
|
||||||
|
#=============== Arabic ===============#
|
||||||
|
norm_config["ara"] = norm_config["*"].copy()
|
||||||
|
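# build a new mapping dict so the shared mapping used by other languages is not mutated
norm_config["ara"]["mapping"] = {**norm_config["*"]["mapping"], "ٱ": "ا"}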
|
||||||
|
norm_config["arb"] = norm_config["ara"].copy()
|
||||||
|
|
||||||
|
#=============== Javanese ===============#
|
||||||
|
norm_config["jav"] = norm_config["*"].copy()
|
||||||
|
norm_config["jav"]["rm_diacritics"] = True
|
||||||
|
|
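The per-language entries above are derived with `dict.copy()`, which makes a shallow copy: string values are rebound independently by `+=`, while a nested dictionary such as `"mapping"` stays shared unless it is replaced. The sketch below illustrates this with a toy config; the keys mirror the ones above, but the values are simplified placeholders.

```python
import copy

# Toy stand-in for norm_config["*"]; the values are simplified placeholders.
base = {"*": {"del_set": r"\u200E", "mapping": {"&lt;": ""}}}

shallow = base["*"].copy()
shallow["del_set"] += r"\u00AD"   # rebinds the string in the copy only
shallow["mapping"]["x"] = "y"     # mutates the dict that base["*"] also points to

deep = copy.deepcopy(base["*"])
deep["mapping"]["z"] = "w"        # affects only the deep copy

print(base["*"]["del_set"])            # unchanged by the += on the copy
print("x" in base["*"]["mapping"])     # True: the nested dict is shared
print("z" in base["*"]["mapping"])     # False: deepcopy isolates nested values
```

When a language needs its own mapping entries, assigning a fresh dictionary (or using `copy.deepcopy`) keeps the shared defaults untouched.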
188
examples/mms/data_prep/punctuations.lst
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
7355 INVALID UNICODE 0x81
|
||||||
|
5265 INVALID UNICODE 0x90
|
||||||
|
75 INVALID UNICODE 0x8
|
||||||
|
31 INVALID UNICODE 0x8d
|
||||||
|
3 INVALID UNICODE 0x94
|
||||||
|
2 INVALID UNICODE 0x8f
|
||||||
|
2 INVALID UNICODE 0x1a
|
||||||
|
1 INVALID UNICODE 0x9d
|
||||||
|
1 INVALID UNICODE 0x93
|
||||||
|
1 INVALID UNICODE 0x92
|
||||||
|
8647 INVALID UNICODE 0xe295
|
||||||
|
6650 INVALID UNICODE 0xf21d
|
||||||
|
6234 INVALID UNICODE 0xf62d
|
||||||
|
4815 INVALID UNICODE 0xf173
|
||||||
|
4789 INVALID UNICODE 0xe514
|
||||||
|
4409 INVALID UNICODE 0xe293
|
||||||
|
3881 INVALID UNICODE 0xf523
|
||||||
|
3788 INVALID UNICODE 0xe233
|
||||||
|
2448 INVALID UNICODE 0xf50f
|
||||||
|
2177 INVALID UNICODE 0xe232
|
||||||
|
1955 INVALID UNICODE 0xea7b
|
||||||
|
1926 INVALID UNICODE 0xf172
|
||||||
|
973 INVALID UNICODE 0xe290
|
||||||
|
972 INVALID UNICODE 0xf519
|
||||||
|
661 INVALID UNICODE 0xe292
|
||||||
|
591 INVALID UNICODE 0xe328
|
||||||
|
509 INVALID UNICODE 0xe2fa
|
||||||
|
458 INVALID UNICODE 0xe234
|
||||||
|
446 INVALID UNICODE 0xe043
|
||||||
|
419 INVALID UNICODE 0xe040
|
||||||
|
399 INVALID UNICODE 0xe2fb
|
||||||
|
387 INVALID UNICODE 0xe32b
|
||||||
|
381 INVALID UNICODE 0xe236
|
||||||
|
374 INVALID UNICODE 0xf511
|
||||||
|
314 INVALID UNICODE 0xe517
|
||||||
|
296 INVALID UNICODE 0xe2fe
|
||||||
|
293 INVALID UNICODE 0xe492
|
||||||
|
291 INVALID UNICODE 0xf52d
|
||||||
|
289 INVALID UNICODE 0xe2fc
|
||||||
|
195 INVALID UNICODE 0xf521
|
||||||
|
190 INVALID UNICODE 0xe516
|
||||||
|
182 INVALID UNICODE 0xe041
|
||||||
|
178 INVALID UNICODE 0xf529
|
||||||
|
113 INVALID UNICODE 0xe2f9
|
||||||
|
87 INVALID UNICODE 0xe2d9
|
||||||
|
78 INVALID UNICODE 0xe32a
|
||||||
|
76 INVALID UNICODE 0xe291
|
||||||
|
74 INVALID UNICODE 0xe296
|
||||||
|
66 INVALID UNICODE 0xe518
|
||||||
|
52 INVALID UNICODE 0xe32c
|
||||||
|
46 INVALID UNICODE 0xe2db
|
||||||
|
41 INVALID UNICODE 0xe231
|
||||||
|
34 INVALID UNICODE 0xf522
|
||||||
|
33 INVALID UNICODE 0xf518
|
||||||
|
32 INVALID UNICODE 0xf513
|
||||||
|
27 INVALID UNICODE 0xe32d
|
||||||
|
25 INVALID UNICODE 0xe32e
|
||||||
|
23 INVALID UNICODE 0xe06b
|
||||||
|
15 INVALID UNICODE 0xea01
|
||||||
|
12 INVALID UNICODE 0xe294
|
||||||
|
11 INVALID UNICODE 0xe203
|
||||||
|
8 INVALID UNICODE 0xf218
|
||||||
|
7 INVALID UNICODE 0xe070
|
||||||
|
7 INVALID UNICODE 0xe013
|
||||||
|
5 INVALID UNICODE 0xe2de
|
||||||
|
4 INVALID UNICODE 0xe493
|
||||||
|
3 INVALID UNICODE 0xf7e8
|
||||||
|
3 INVALID UNICODE 0xf7d0
|
||||||
|
3 INVALID UNICODE 0xe313
|
||||||
|
2 INVALID UNICODE 0xe329
|
||||||
|
2 INVALID UNICODE 0xe06d
|
||||||
|
2 INVALID UNICODE 0xe003
|
||||||
|
1 INVALID UNICODE 0xf50e
|
||||||
|
1 INVALID UNICODE 0xf171
|
||||||
|
1 INVALID UNICODE 0xe01d
|
||||||
|
71 NOMINAL DIGIT SHAPES 0x206f
|
||||||
|
3 WORD JOINER 0x2060
|
||||||
|
― 126545 HORIZONTAL BAR 0x2015
|
||||||
|
־ 1028 HEBREW PUNCTUATION MAQAF 0x5be
|
||||||
|
) 98429 RIGHT PARENTHESIS 0x29
|
||||||
|
] 27108 RIGHT SQUARE BRACKET 0x5d
|
||||||
|
⌋ 1567 RIGHT FLOOR 0x230b
|
||||||
|
〕 97 RIGHT TORTOISE SHELL BRACKET 0x3015
|
||||||
|
】 36 RIGHT BLACK LENTICULAR BRACKET 0x3011
|
||||||
|
﴾ 14 ORNATE LEFT PARENTHESIS 0xfd3e
|
||||||
|
& 170517 AMPERSAND 0x26
|
||||||
|
། 106330 TIBETAN MARK SHAD 0xf0d
|
||||||
|
። 90203 ETHIOPIC FULL STOP 0x1362
|
||||||
|
፥ 60484 ETHIOPIC COLON 0x1365
|
||||||
|
༌ 60464 TIBETAN MARK DELIMITER TSHEG BSTAR 0xf0c
|
||||||
|
။ 51567 MYANMAR SIGN SECTION 0x104b
|
||||||
|
/ 46929 SOLIDUS 0x2f
|
||||||
|
၊ 38042 MYANMAR SIGN LITTLE SECTION 0x104a
|
||||||
|
· 37985 MIDDLE DOT 0xb7
|
||||||
|
‸ 36310 CARET 0x2038
|
||||||
|
* 34793 ASTERISK 0x2a
|
||||||
|
۔ 32432 ARABIC FULL STOP 0x6d4
|
||||||
|
፤ 31906 ETHIOPIC SEMICOLON 0x1364
|
||||||
|
၏ 21519 MYANMAR SYMBOL GENITIVE 0x104f
|
||||||
|
។ 20834 KHMER SIGN KHAN 0x17d4
|
||||||
|
꓾ 15773 LISU PUNCTUATION COMMA 0xa4fe
|
||||||
|
᙮ 13473 CANADIAN SYLLABICS FULL STOP 0x166e
|
||||||
|
꤯ 12892 KAYAH LI SIGN SHYA 0xa92f
|
||||||
|
⵰ 11478 TIFINAGH SEPARATOR MARK 0x2d70
|
||||||
|
꓿ 11118 LISU PUNCTUATION FULL STOP 0xa4ff
|
||||||
|
॥ 10763 DEVANAGARI DOUBLE DANDA 0x965
|
||||||
|
؞ 10403 ARABIC TRIPLE DOT PUNCTUATION MARK 0x61e
|
||||||
|
၍ 8936 MYANMAR SYMBOL COMPLETED 0x104d
|
||||||
|
· 8431 GREEK ANO TELEIA 0x387
|
||||||
|
† 7477 DAGGER 0x2020
|
||||||
|
၌ 6632 MYANMAR SYMBOL LOCATIVE 0x104c
|
||||||
|
፣ 5719 ETHIOPIC COMMA 0x1363
|
||||||
|
៖ 5528 KHMER SIGN CAMNUC PII KUUH 0x17d6
|
||||||
|
꤮ 4791 KAYAH LI SIGN CWI 0xa92e
|
||||||
|
※ 3439 REFERENCE MARK 0x203b
|
||||||
|
፦ 2727 ETHIOPIC PREFACE COLON 0x1366
|
||||||
|
• 1749 BULLET 0x2022
|
||||||
|
¶ 1507 PILCROW SIGN 0xb6
|
||||||
|
၎ 1386 MYANMAR SYMBOL AFOREMENTIONED 0x104e
|
||||||
|
﹖ 1224 SMALL QUESTION MARK 0xfe56
|
||||||
|
; 975 GREEK QUESTION MARK 0x37e
|
||||||
|
… 827 HORIZONTAL ELLIPSIS 0x2026
|
||||||
|
% 617 PERCENT SIGN 0x25
|
||||||
|
・ 468 KATAKANA MIDDLE DOT 0x30fb
|
||||||
|
༎ 306 TIBETAN MARK NYIS SHAD 0xf0e
|
||||||
|
‡ 140 DOUBLE DAGGER 0x2021
|
||||||
|
# 137 NUMBER SIGN 0x23
|
||||||
|
@ 125 COMMERCIAL AT 0x40
|
||||||
|
፡ 121 ETHIOPIC WORDSPACE 0x1361
|
||||||
|
៚ 55 KHMER SIGN KOOMUUT 0x17da
|
||||||
|
៕ 49 KHMER SIGN BARIYOOSAN 0x17d5
|
||||||
|
﹐ 10 SMALL COMMA 0xfe50
|
||||||
|
༅ 6 TIBETAN MARK CLOSING YIG MGO SGAB MA 0xf05
|
||||||
|
༄ 6 TIBETAN MARK INITIAL YIG MGO MDUN MA 0xf04
|
||||||
|
. 2 FULLWIDTH FULL STOP 0xff0e
|
||||||
|
﹗ 2 SMALL EXCLAMATION MARK 0xfe57
|
||||||
|
﹕ 2 SMALL COLON 0xfe55
|
||||||
|
‰ 2 PER MILLE SIGN 0x2030
|
||||||
|
・ 1 HALFWIDTH KATAKANA MIDDLE DOT 0xff65
|
||||||
|
( 98504 LEFT PARENTHESIS 0x28
|
||||||
|
[ 27245 LEFT SQUARE BRACKET 0x5b
|
||||||
|
⌊ 1567 LEFT FLOOR 0x230a
|
||||||
|
〔 95 LEFT TORTOISE SHELL BRACKET 0x3014
|
||||||
|
【 36 LEFT BLACK LENTICULAR BRACKET 0x3010
|
||||||
|
﴿ 14 ORNATE RIGHT PARENTHESIS 0xfd3f
|
||||||
|
_ 4851 LOW LINE 0x5f
|
||||||
|
$ 72 DOLLAR SIGN 0x24
|
||||||
|
€ 14 EURO SIGN 0x20ac
|
||||||
|
£ 2 POUND SIGN 0xa3
|
||||||
|
~ 27462 TILDE 0x7e
|
||||||
|
= 11450 EQUALS SIGN 0x3d
|
||||||
|
| 8430 VERTICAL LINE 0x7c
|
||||||
|
− 3971 MINUS SIGN 0x2212
|
||||||
|
≫ 1904 MUCH GREATER-THAN 0x226b
|
||||||
|
≪ 1903 MUCH LESS-THAN 0x226a
|
||||||
|
+ 1450 PLUS SIGN 0x2b
|
||||||
|
< 345 FULLWIDTH LESS-THAN SIGN 0xff1c
|
||||||
|
> 344 FULLWIDTH GREATER-THAN SIGN 0xff1e
|
||||||
|
¬ 5 NOT SIGN 0xac
|
||||||
|
× 4 MULTIPLICATION SIGN 0xd7
|
||||||
|
→ 2 RIGHTWARDS ARROW 0x2192
|
||||||
|
᙭ 537 CANADIAN SYLLABICS CHI SIGN 0x166d
|
||||||
|
° 499 DEGREE SIGN 0xb0
|
||||||
|
႟ 421 MYANMAR SYMBOL SHAN EXCLAMATION 0x109f
|
||||||
|
� 192 REPLACEMENT CHARACTER 0xfffd
|
||||||
|
⌟ 54 BOTTOM RIGHT CORNER 0x231f
|
||||||
|
⌞ 54 BOTTOM LEFT CORNER 0x231e
|
||||||
|
© 2 COPYRIGHT SIGN 0xa9
|
||||||
|
40 NARROW NO-BREAK SPACE 0x202f
|
||||||
|
1 SIX-PER-EM SPACE 0x2006
|
||||||
|
˜ 40261 SMALL TILDE 0x2dc
|
||||||
|
^ 6469 CIRCUMFLEX ACCENT 0x5e
|
||||||
|
¯ 20 MACRON 0xaf
|
||||||
|
ˇ 191442 CARON 0x2c7
|
||||||
|
ⁿ 38144 SUPERSCRIPT LATIN SMALL LETTER N 0x207f
|
||||||
|
ـ 9440 ARABIC TATWEEL 0x640
|
||||||
|
ๆ 6766 THAI CHARACTER MAIYAMOK 0xe46
|
||||||
|
ៗ 3310 KHMER SIGN LEK TOO 0x17d7
|
||||||
|
々 678 IDEOGRAPHIC ITERATION MARK 0x3005
|
||||||
|
ໆ 430 LAO KO LA 0xec6
|
||||||
|
ー 319 KATAKANA-HIRAGANA PROLONGED SOUND MARK 0x30fc
|
||||||
|
ⁱ 137 SUPERSCRIPT LATIN SMALL LETTER I 0x2071
|
||||||
|
৷ 11056 BENGALI CURRENCY NUMERATOR FOUR 0x9f7
|
||||||
|
⅓ 26 VULGAR FRACTION ONE THIRD 0x2153
|
||||||
|
½ 26 VULGAR FRACTION ONE HALF 0xbd
|
||||||
|
¼ 4 VULGAR FRACTION ONE QUARTER 0xbc
|
||||||
|
⅟ 1 FRACTION NUMERATOR ONE 0x215f
|
||||||
|
⁄ 57 FRACTION SLASH 0x2044
|
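As a rough illustration of how a list like this can feed a regular expression, the sketch below turns a few entries into a character class, following the same `re.escape(line.split("\t")[0])` approach used in `norm_config.py`. It assumes the fields are tab-separated (as the comment in `norm_config.py` states) and uses an in-memory sample of three entries rather than reading the file.

```python
import re

# Three entries copied from the list, written with explicit tab separators (assumed format).
sample_lines = [
    "(\t98504\tLEFT PARENTHESIS\t0x28\n",
    ")\t98429\tRIGHT PARENTHESIS\t0x29\n",
    "…\t827\tHORIZONTAL ELLIPSIS\t0x2026\n",
]

# Keep only the first field of each line and escape it for use inside [...].
punct_pattern = "".join(re.escape(line.split("\t")[0]) for line in sample_lines)
char_class = "[" + punct_pattern + "]"

# Replace every listed character with a space.
cleaned = re.sub(char_class, " ", "wait… (really)")
print(cleaned.split())
```

Escaping each entry matters because characters such as `]`, `^` and `-` from the list would otherwise change the meaning of the character class.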
92
examples/mms/data_prep/text_normalization.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
|
from examples.mms.data_prep.norm_config import norm_config
|
||||||
|
|
||||||
|
|
||||||
|
def text_normalize(text, iso_code, lower_case=True, remove_numbers=True, remove_brackets=False):
|
||||||
|
|
||||||
|
"""Normalize text by lower-casing, removing punctuation, removing words that contain only digits, and collapsing extra spaces.

Args:
text : The string to be normalized
iso_code : ISO code of the language; selects the matching entry in norm_config and falls back to the shared "*" entry
lower_case : Boolean flag; text is lower-cased only if this flag and the config's "lower_case" setting are both true
remove_numbers : Boolean flag to specify if words containing only digits should be removed
remove_brackets : Boolean flag to specify if all parenthesized text should be removed

Returns:
normalized_text : the string after all normalization

"""
|
||||||
|
|
||||||
|
config = norm_config.get(iso_code, norm_config["*"])
|
||||||
|
|
||||||
|
for field in ["lower_case", "punc_set", "del_set", "mapping", "digit_set", "unicode_norm", "rm_diacritics"]:
|
||||||
|
if field not in config:
|
||||||
|
config[field] = norm_config["*"][field]
|
||||||
|
|
||||||
|
|
||||||
|
text = unicodedata.normalize(config["unicode_norm"], text)
|
||||||
|
|
||||||
|
# Convert to lower case
|
||||||
|
|
||||||
|
if config["lower_case"] and lower_case:
|
||||||
|
text = text.lower()
|
||||||
|
|
||||||
|
# brackets
|
||||||
|
|
||||||
|
# always remove text inside brackets that contains numbers. Usually corresponds to references such as "(Sam 23:17)"
|
||||||
|
text = re.sub(r"\([^\)]*\d[^\)]*\)", " ", text)
|
||||||
|
if remove_brackets:
|
||||||
|
text = re.sub(r"\([^\)]*\)", " ", text)
|
||||||
|
|
||||||
|
# Apply mappings
|
||||||
|
|
||||||
|
for old, new in config["mapping"].items():
|
||||||
|
text = re.sub(old, new, text)
|
||||||
|
|
||||||
|
# Replace punctuation with spaces
|
||||||
|
|
||||||
|
punct_pattern = r"[" + config["punc_set"]
|
||||||
|
|
||||||
|
punct_pattern += "]"
|
||||||
|
|
||||||
|
normalized_text = re.sub(punct_pattern, " ", text)
|
||||||
|
|
||||||
|
# remove characters in delete list
|
||||||
|
|
||||||
|
delete_pattern = r"[" + config["del_set"] + "]"

normalized_text = re.sub(delete_pattern, "", normalized_text)
|
||||||
|
|
||||||
|
# Remove words containing only digits
|
||||||
|
# We check for 3 cases a)text starts with a number b) a number is present somewhere in the middle of the text c) the text ends with a number
|
||||||
|
# For each case we use a lookaround regex pattern to check that the digit pattern is preceded and followed by whitespace; only then do we replace the numbers with a space
|
||||||
|
# The lookaround enables overlapping pattern matches to be replaced
|
||||||
|
|
||||||
|
if remove_numbers:
|
||||||
|
|
||||||
|
digits_pattern = "[" + config["digit_set"]
|
||||||
|
|
||||||
|
digits_pattern += "]+"
|
||||||
|
|
||||||
|
complete_digit_pattern = (
|
||||||
|
r"^"
|
||||||
|
+ digits_pattern
|
||||||
|
+ r"(?=\s)|(?<=\s)"
|
||||||
|
+ digits_pattern
|
||||||
|
+ r"(?=\s)|(?<=\s)"
|
||||||
|
+ digits_pattern
|
||||||
|
+ "$"
|
||||||
|
)
|
||||||
|
|
||||||
|
normalized_text = re.sub(complete_digit_pattern, " ", normalized_text)
|
||||||
|
|
||||||
|
if config["rm_diacritics"]:
|
||||||
|
from unidecode import unidecode
|
||||||
|
normalized_text = unidecode(normalized_text)
|
||||||
|
|
||||||
|
# Remove extra spaces
|
||||||
|
normalized_text = re.sub(r"\s+", " ", normalized_text).strip()
|
||||||
|
|
||||||
|
return normalized_text
|
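The digit-removal step in `text_normalize` relies on lookarounds so that neighbouring numbers can each be matched without consuming the whitespace between them. The sketch below isolates that pattern; it assumes a simplified digit set (`0-9` only) in place of the shared digit set, and `strip_number_words` is a hypothetical helper name used only for this example.

```python
import re

# Simplified stand-in for the shared digit set (ASCII digits only).
digits = r"[0-9]+"

# Same three alternatives as in text_normalize: number at the start,
# number in the middle, number at the end of the text.
pattern = (
    r"^" + digits + r"(?=\s)|(?<=\s)" + digits + r"(?=\s)|(?<=\s)" + digits + r"$"
)


def strip_number_words(text):
    # Replace digit-only words with a space, then collapse whitespace.
    text = re.sub(pattern, " ", text)
    return re.sub(r"\s+", " ", text).strip()


print(strip_number_words("12 apples and 7 pears 2024"))
print(strip_number_words("room101 stays"))
print(strip_number_words("3 4 5"))
print(strip_number_words("123"))
```

Digits glued to letters are kept because no whitespace borders them. A string that consists of a single number with no surrounding whitespace is also left unchanged by this pattern, since every alternative requires whitespace on at least one side.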
197
examples/mms/lid/infer.py
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
import torch
|
||||||
|
from fairseq.data.text_compressor import TextCompressionLevel, TextCompressor
|
||||||
|
from fairseq import checkpoint_utils, data, distributed_utils, options, tasks, utils
|
||||||
|
from fairseq.data import FileAudioDataset, AddTargetDataset, Dictionary
|
||||||
|
from fairseq.tasks.audio_classification import LabelEncoder
|
||||||
|
import copy
|
||||||
|
from tqdm import tqdm
|
||||||
|
import tempfile
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def subset_manifest(infer_manifest, veri_pair):
|
||||||
|
with open(infer_manifest) as ff, open(veri_pair) as gg, tempfile.NamedTemporaryFile(
|
||||||
|
"w", delete=False
|
||||||
|
) as ww:
|
||||||
|
fnames = ff.read().strip().split("\n")
|
||||||
|
basedir = fnames[0]
|
||||||
|
needed_fname = []
|
||||||
|
for gi in gg.read().strip().split("\n"):
|
||||||
|
_, x1, x2 = gi.split()
|
||||||
|
needed_fname.append(x1)
|
||||||
|
needed_fname.append(x2)
|
||||||
|
needed_fname = set(needed_fname)
|
||||||
|
|
||||||
|
ww.write(basedir + "\n")
|
||||||
|
for ii in range(1, len(fnames)):
|
||||||
|
x1, x2 = fnames[ii].split()
|
||||||
|
if x1 in needed_fname:
|
||||||
|
ww.write(fnames[ii] + "\n")
|
||||||
|
print(f"| subset manifest for verification: {ww.name}")
|
||||||
|
return ww.name
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_target_dataset(infer_manifest, dataset, task):
|
||||||
|
label_path = infer_manifest.replace(".tsv", ".lang")
|
||||||
|
text_compressor = TextCompressor(level=TextCompressionLevel.none)
|
||||||
|
with open(label_path, "r") as f:
|
||||||
|
labels = [text_compressor.compress(l) for l in f]
|
||||||
|
assert len(labels) == len(dataset)
|
||||||
|
|
||||||
|
process_label = LabelEncoder(task.target_dictionary)
|
||||||
|
dataset = AddTargetDataset(
|
||||||
|
dataset,
|
||||||
|
labels,
|
||||||
|
pad=task.target_dictionary.pad(),
|
||||||
|
eos=task.target_dictionary.eos(),
|
||||||
|
batch_targets=True,
|
||||||
|
process_label=process_label,
|
||||||
|
add_to_input=False,
|
||||||
|
)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def resample_data(source, padding_mask, n_sample, max_sample_len):
|
||||||
|
# source: BxT
|
||||||
|
# padding_mask: BxT
|
||||||
|
B = source.shape[0]
|
||||||
|
T = source.shape[1]
|
||||||
|
sources = []
|
||||||
|
padding_masks = []
|
||||||
|
if B == 1:
|
||||||
|
return [source], [None]
|
||||||
|
seq_len = (~padding_mask).sum(1)
|
||||||
|
for jj in range(n_sample):
|
||||||
|
new_source = source.new_zeros(B, max_sample_len)
|
||||||
|
new_padding_mask = padding_mask.new_zeros(B, max_sample_len)
|
||||||
|
for ii in range(B):
|
||||||
|
if seq_len[ii] > max_sample_len:
|
||||||
|
start = np.random.randint(0, seq_len[ii] - max_sample_len + 1)
|
||||||
|
end = start + max_sample_len
|
||||||
|
else:
|
||||||
|
start = 0
|
||||||
|
end = seq_len[ii]
|
||||||
|
new_source[ii, 0 : end - start] = source[ii, start:end]
|
||||||
|
new_padding_mask[ii, end - start :] = True
|
||||||
|
sources.append(new_source)
|
||||||
|
padding_masks.append(new_padding_mask)
|
||||||
|
return sources, padding_masks
|
||||||
|
|
||||||
|
|
||||||
|
def resample_sample(sample, n_sample, max_sample_len):
|
||||||
|
new_sources, new_padding_masks = resample_data(
|
||||||
|
sample["net_input"]["source"],
|
||||||
|
sample["net_input"]["padding_mask"],
|
||||||
|
n_sample,
|
||||||
|
max_sample_len,
|
||||||
|
)
|
||||||
|
new_samples = []
|
||||||
|
for ii in range(n_sample):
|
||||||
|
new_sample = copy.deepcopy(sample)
|
||||||
|
new_sample["net_input"]["source"] = new_sources[ii]
|
||||||
|
new_sample["net_input"]["padding_mask"] = new_padding_masks[ii]
|
||||||
|
new_samples.append(new_sample)
|
||||||
|
return new_samples
|
||||||
|
|
||||||
|
|
||||||
|
def dict_to_nparr(dd):
|
||||||
|
dict_class = []
|
||||||
|
dict_idx = []
|
||||||
|
for ii, jj in enumerate(dd.symbols):
|
||||||
|
dict_idx.append(ii)
|
||||||
|
dict_class.append(jj)
|
||||||
|
dict_idx = np.array(dict_idx)
|
||||||
|
dict_class = np.array(dict_class)
|
||||||
|
return dict_class, dict_idx
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
np.random.seed(123)
|
||||||
|
# Parse command-line arguments for generation
|
||||||
|
parser = options.get_generation_parser(default_task="audio_classification")
|
||||||
|
# parser.add_argument('--infer-merge', type=str, default='mean')
|
||||||
|
parser.add_argument("--infer-xtimes", type=int, default=1)
|
||||||
|
parser.add_argument("--infer-num-samples", type=int, default=None)
|
||||||
|
parser.add_argument("--top-k", type=int, default=3)
|
||||||
|
parser.add_argument(
|
||||||
|
"--infer-max-sample-size", type=int, default=5 * 16000
|
||||||
|
) # 5 secs
|
||||||
|
parser.add_argument("--infer-manifest", required=True, type=str)
|
||||||
|
parser.add_argument("--output-path", default="/tmp/", type=str)
|
||||||
|
|
||||||
|
args = options.parse_args_and_arch(parser)
|
||||||
|
# Setup task
|
||||||
|
# task = tasks.setup_task(args)
|
||||||
|
use_cuda = not args.cpu
|
||||||
|
|
||||||
|
# Load model & task
|
||||||
|
print("| loading model from {}".format(args.path))
|
||||||
|
arg_overrides = {
|
||||||
|
"task": {
|
||||||
|
"data": args.data
|
||||||
|
},
|
||||||
|
# 'mask_prob': 0
|
||||||
|
#'max_sample_size': sys.maxsize,
|
||||||
|
#'min_sample_size': 0,
|
||||||
|
}
|
||||||
|
state = checkpoint_utils.load_checkpoint_to_cpu(args.path, arg_overrides)
|
||||||
|
|
||||||
|
models, _model_args, task = checkpoint_utils.load_model_ensemble_and_task(
|
||||||
|
[args.path], arg_overrides=arg_overrides, task=None, state=state
|
||||||
|
)
|
||||||
|
model = models[0]
|
||||||
|
model.eval()
|
||||||
|
if use_cuda:
|
||||||
|
model.cuda()
|
||||||
|
# Load dataset
|
||||||
|
|
||||||
|
dict_class, dict_idx = dict_to_nparr(task.target_dictionary)
|
||||||
|
|
||||||
|
infer_manifest = args.infer_manifest
|
||||||
|
infer_dataset = FileAudioDataset(
|
||||||
|
infer_manifest,
|
||||||
|
sample_rate=task.cfg.sample_rate,
|
||||||
|
max_sample_size=10**10, # task.cfg.max_sample_size,
|
||||||
|
min_sample_size=1, # task.cfg.min_sample_size,
|
||||||
|
pad=True,
|
||||||
|
normalize=task.cfg.normalize,
|
||||||
|
)
|
||||||
|
# add target (if needed)
|
||||||
|
infer_dataset = wrap_target_dataset(infer_manifest, infer_dataset, task)
|
||||||
|
|
||||||
|
itr = task.get_batch_iterator(
|
||||||
|
dataset=infer_dataset,
|
||||||
|
max_sentences=1,
|
||||||
|
# max_tokens=args.max_tokens,
|
||||||
|
num_workers=4,
|
||||||
|
).next_epoch_itr(shuffle=False)
|
||||||
|
predictions = {}
|
||||||
|
with torch.no_grad():
|
||||||
|
for _, sample in tqdm(enumerate(itr)):
|
||||||
|
# resample if needed
|
||||||
|
samples = resample_sample(
|
||||||
|
sample, args.infer_xtimes, args.infer_max_sample_size
|
||||||
|
)
|
||||||
|
for sample in samples:
|
||||||
|
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
||||||
|
try:
|
||||||
|
latent = model.forward_latent(**sample["net_input"])
|
||||||
|
except Exception:
|
||||||
|
latent = None
|
||||||
|
logit = model.forward(**sample["net_input"])
|
||||||
|
logit_lsm = torch.log_softmax(logit.squeeze(), dim=-1)
|
||||||
|
scores, indices = torch.topk(logit_lsm, args.top_k, dim=-1)
|
||||||
|
scores = torch.exp(scores).to("cpu").tolist()
|
||||||
|
indices = indices.to("cpu").tolist()
|
||||||
|
assert sample["id"].numel() == 1
|
||||||
|
sample_idx = sample["id"].to("cpu").tolist()[0]
|
||||||
|
assert sample_idx not in predictions
|
||||||
|
predictions[sample_idx] = [(task.target_dictionary[int(i)], s) for s, i in zip(scores, indices)]
|
||||||
|
|
||||||
|
with open(f"{args.output_path}/predictions.txt", "w") as fo:
|
||||||
|
for idx in range(len(infer_dataset)):
|
||||||
|
fo.write(json.dumps(predictions[idx]) + "\n")
|
||||||
|
|
||||||
|
print(f"Outputs will be located at - {args.output_path}/predictions.txt")
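The script above writes one JSON line per utterance to `predictions.txt`, each line holding the top-k `(label, score)` pairs. A minimal sketch of reading that file back and keeping the best label per utterance might look like the following. The helper name `load_top1` and the example language codes are hypothetical and only serve as illustration.

```python
import json


def load_top1(lines):
    """Return the highest-scoring (label, score) pair for each prediction line."""
    top1 = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        # json.dumps turned each (label, score) tuple into a two-element list.
        candidates = json.loads(line)
        label, score = max(candidates, key=lambda pair: pair[1])
        top1.append((label, score))
    return top1


# Hypothetical file contents, in the same layout the script writes.
example = [
    json.dumps([("eng", 0.91), ("deu", 0.05), ("nld", 0.02)]),
    json.dumps([("fra", 0.60), ("ita", 0.30), ("spa", 0.05)]),
]
print(load_top1(example))
```

In practice the lines would come from `open(".../predictions.txt")`, and line `i` corresponds to sample index `i` of the manifest, since the script writes predictions in dataset order.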
|
102
examples/mms/tts/infer.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import numpy as np
|
||||||
|
import commons
|
||||||
|
import utils
|
||||||
|
import argparse
|
||||||
|
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
|
||||||
|
from models import SynthesizerTrn
|
||||||
|
from scipy.io.wavfile import write
|
||||||
|
|
||||||
|
class TextMapper(object):
|
||||||
|
def __init__(self, vocab_file):
|
||||||
|
self.symbols = [x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()]
|
||||||
|
self.SPACE_ID = self.symbols.index(" ")
|
||||||
|
self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
|
||||||
|
self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
|
||||||
|
|
||||||
|
def text_to_sequence(self, text, cleaner_names):
|
||||||
|
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||||
|
Args:
|
||||||
|
text: string to convert to a sequence
|
||||||
|
cleaner_names: names of the cleaner functions to run the text through
|
||||||
|
Returns:
|
||||||
|
List of integers corresponding to the symbols in the text
|
||||||
|
'''
|
||||||
|
sequence = []
|
||||||
|
clean_text = text.strip()
|
||||||
|
for symbol in clean_text:
|
||||||
|
symbol_id = self._symbol_to_id[symbol]
|
||||||
|
sequence += [symbol_id]
|
||||||
|
return sequence
|
||||||
|
|
||||||
|
def get_text(self, text, hps):
|
||||||
|
text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
|
||||||
|
if hps.data.add_blank:
|
||||||
|
text_norm = commons.intersperse(text_norm, 0)
|
||||||
|
text_norm = torch.LongTensor(text_norm)
|
||||||
|
return text_norm
|
||||||
|
|
||||||
|
def filter_oov(self, text):
|
||||||
|
val_chars = self._symbol_to_id
|
||||||
|
txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
|
||||||
|
print(f"text after filtering OOV: {txt_filt}")
|
||||||
|
return txt_filt
|
||||||
|
|
||||||
|
def generate():
|
||||||
|
parser = argparse.ArgumentParser(description='TTS inference')
|
||||||
|
parser.add_argument('--model-dir', type=str, help='model checkpoint dir')
|
||||||
|
parser.add_argument('--wav', type=str, help='output wav path')
|
||||||
|
parser.add_argument('--txt', type=str, help='input text')
|
||||||
|
args = parser.parse_args()
|
||||||
|
ckpt_dir, wav_path, txt = args.model_dir, args.wav, args.txt
|
||||||
|
|
||||||
|
vocab_file = f"{ckpt_dir}/vocab.txt"
|
||||||
|
config_file = f"{ckpt_dir}/config.json"
|
||||||
|
assert os.path.isfile(config_file), f"{config_file} doesn't exist"
|
||||||
|
hps = utils.get_hparams_from_file(config_file)
|
||||||
|
text_mapper = TextMapper(vocab_file)
|
||||||
|
net_g = SynthesizerTrn(
|
||||||
|
len(text_mapper.symbols),
|
||||||
|
hps.data.filter_length // 2 + 1,
|
||||||
|
hps.train.segment_size // hps.data.hop_length,
|
||||||
|
**hps.model)
|
||||||
|
net_g.cuda()
|
||||||
|
_ = net_g.eval()
|
||||||
|
|
||||||
|
g_pth = f"{ckpt_dir}/G_100000.pth"
|
||||||
|
print(f"load {g_pth}")
|
||||||
|
|
||||||
|
_ = utils.load_checkpoint(g_pth, net_g, None)
|
||||||
|
|
||||||
|
print(f"text: {txt}")
|
||||||
|
txt = txt.lower()
|
||||||
|
txt = text_mapper.filter_oov(txt)
|
||||||
|
stn_tst = text_mapper.get_text(txt, hps)
|
||||||
|
with torch.no_grad():
|
||||||
|
x_tst = stn_tst.unsqueeze(0).cuda()
|
||||||
|
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
|
||||||
|
hyp = net_g.infer(
|
||||||
|
x_tst, x_tst_lengths, noise_scale=.667,
|
||||||
|
noise_scale_w=0.8, length_scale=1.0
|
||||||
|
)[0][0,0].cpu().float().numpy()
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(os.path.abspath(wav_path)), exist_ok=True)
|
||||||
|
print(f"wav: {wav_path}")
|
||||||
|
write(wav_path, hps.data.sampling_rate, hyp)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
generate()
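The text pipeline in `generate()` lowercases the input, drops out-of-vocabulary characters, maps the remaining characters to IDs, and optionally intersperses a blank ID. The sketch below reproduces those steps without the model. It assumes `commons.intersperse` behaves like the VITS helper of the same name (blank before, between, and after the items); the function `encode` and the toy vocabulary are hypothetical.

```python
def intersperse(seq, item):
    """Place `item` before, between, and after the elements of `seq`."""
    result = [item] * (len(seq) * 2 + 1)
    result[1::2] = seq
    return result


def encode(text, symbols, add_blank=True):
    """Lowercase, filter OOV characters, map to IDs, optionally add blanks."""
    symbol_to_id = {s: i for i, s in enumerate(symbols)}
    # Same order as generate(): lower() first, then the OOV filter.
    filtered = "".join(ch for ch in text.lower() if ch in symbol_to_id)
    ids = [symbol_to_id[ch] for ch in filtered.strip()]
    return intersperse(ids, 0) if add_blank else ids


# Hypothetical vocabulary; a real vocab.txt ships with each checkpoint.
toy_symbols = ["_", " ", "a", "b", "c"]
print(encode("A cab!", toy_symbols))
```

Note that characters missing from the vocabulary are silently removed, so the synthesized audio can skip parts of the input text.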
|
@ -10,6 +10,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
import re
|
||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
@ -101,6 +102,29 @@ class InferenceProcessor:
|
|||||||
self.task = tasks.setup_task(cfg.task)
|
self.task = tasks.setup_task(cfg.task)
|
||||||
|
|
||||||
models, saved_cfg = self.load_model_ensemble()
|
models, saved_cfg = self.load_model_ensemble()
|
||||||
|
|
||||||
|
### LOAD ADAPTER ####
|
||||||
|
ckpt_obj = checkpoint_utils.load_checkpoint_to_cpu(self.cfg.common_eval.path)
|
||||||
|
if "adapter" in ckpt_obj:
|
||||||
|
target_lang = self.cfg.dataset.gen_subset.split(":")[0]
|
||||||
|
assert target_lang in ckpt_obj["adapter"]
|
||||||
|
|
||||||
|
logger.info(f">>> LOADING ADAPTER: {target_lang}")
|
||||||
|
ft_obj = ckpt_obj["adapter"][target_lang]
|
||||||
|
ft_model = ft_obj["model"]
|
||||||
|
cdevice = models[0].w2v_encoder.proj.weight.device
|
||||||
|
cdtype = models[0].w2v_encoder.proj.weight.dtype
|
||||||
|
ft_proj_out, ft_proj_in = ft_model["w2v_encoder.proj.weight"].shape
|
||||||
|
ft_proj = torch.nn.Linear(ft_proj_in, ft_proj_out, bias=True)
|
||||||
|
ft_proj.to(device=cdevice, dtype=cdtype)
|
||||||
|
models[0].w2v_encoder.proj = ft_proj
|
||||||
|
with torch.no_grad():
|
||||||
|
for kk, vv in models[0].named_parameters():
|
||||||
|
if kk in ft_model:
|
||||||
|
vv.copy_(ft_model[kk])
|
||||||
|
self.task.load_state_dict(ft_obj["task_state"])
|
||||||
|
# overwrite gen_subset with master config
|
||||||
|
self.cfg.dataset.gen_subset = re.sub(r'^[\w-]+:', saved_cfg['task']['multi_corpus_keys']+":", self.cfg.dataset.gen_subset)
|
||||||
self.models = models
|
self.models = models
|
||||||
self.saved_cfg = saved_cfg
|
self.saved_cfg = saved_cfg
|
||||||
self.tgt_dict = self.task.target_dictionary
|
self.tgt_dict = self.task.target_dictionary
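The adapter-loading hunk above reads the target language from the prefix of `gen_subset` and then swaps that prefix for the corpus key stored in the saved config. A small sketch of just the string handling, with a hypothetical helper name and made-up values, could be:

```python
import re


def split_adapter_subset(gen_subset, master_key):
    """Return (target_lang, rewritten gen_subset) for a 'lang:split' string."""
    # The language is whatever precedes the first colon.
    target_lang = gen_subset.split(":")[0]
    # Replace the leading 'lang:' prefix with the master corpus key.
    rewritten = re.sub(r"^[\w-]+:", master_key + ":", gen_subset)
    return target_lang, rewritten


# "eng:dev" and "all" are illustrative values, not taken from a real config.
print(split_adapter_subset("eng:dev", "all"))
```

A `gen_subset` without a `lang:` prefix is left unchanged by the substitution, and in that case the whole string is taken as the language name.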
|
||||||
|
@ -47,6 +47,7 @@ class RawAudioDataset(FairseqDataset):
|
|||||||
expand_adjacent: bool = False,
|
expand_adjacent: bool = False,
|
||||||
mask_dropout: float = 0,
|
mask_dropout: float = 0,
|
||||||
non_overlapping: bool = False,
|
non_overlapping: bool = False,
|
||||||
|
corpus_key=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -72,6 +73,7 @@ class RawAudioDataset(FairseqDataset):
|
|||||||
self.expand_adjacent = expand_adjacent
|
self.expand_adjacent = expand_adjacent
|
||||||
self.mask_dropout = mask_dropout
|
self.mask_dropout = mask_dropout
|
||||||
self.non_overlapping = non_overlapping
|
self.non_overlapping = non_overlapping
|
||||||
|
self.corpus_key = corpus_key
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@ -144,6 +146,8 @@ class RawAudioDataset(FairseqDataset):
|
|||||||
collated_sources[i] = self.crop_to_max_size(source, target_size)
|
collated_sources[i] = self.crop_to_max_size(source, target_size)
|
||||||
|
|
||||||
input = {"source": collated_sources}
|
input = {"source": collated_sources}
|
||||||
|
if self.corpus_key is not None:
|
||||||
|
input["corpus_key"] = [self.corpus_key] * len(sources)
|
||||||
out = {"id": torch.LongTensor([s["id"] for s in samples])}
|
out = {"id": torch.LongTensor([s["id"] for s in samples])}
|
||||||
if self.pad:
|
if self.pad:
|
||||||
input["padding_mask"] = padding_mask
|
input["padding_mask"] = padding_mask
|
||||||
|
@ -26,11 +26,13 @@ class Dictionary:
|
|||||||
eos="</s>",
|
eos="</s>",
|
||||||
unk="<unk>",
|
unk="<unk>",
|
||||||
extra_special_symbols=None,
|
extra_special_symbols=None,
|
||||||
|
add_special_symbols=True,
|
||||||
):
|
):
|
||||||
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
||||||
self.symbols = []
|
self.symbols = []
|
||||||
self.count = []
|
self.count = []
|
||||||
self.indices = {}
|
self.indices = {}
|
||||||
|
if add_special_symbols:
|
||||||
self.bos_index = self.add_symbol(bos)
|
self.bos_index = self.add_symbol(bos)
|
||||||
self.pad_index = self.add_symbol(pad)
|
self.pad_index = self.add_symbol(pad)
|
||||||
self.eos_index = self.add_symbol(eos)
|
self.eos_index = self.add_symbol(eos)
|
||||||
@ -213,7 +215,7 @@ class Dictionary:
|
|||||||
return self.unk_index
|
return self.unk_index
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, f):
|
def load(cls, f, add_special_symbols=True):
|
||||||
"""Loads the dictionary from a text file with the format:
|
"""Loads the dictionary from a text file with the format:
|
||||||
|
|
||||||
```
|
```
|
||||||
@ -222,7 +224,7 @@ class Dictionary:
|
|||||||
...
|
...
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
d = cls()
|
d = cls(add_special_symbols=add_special_symbols)
|
||||||
d.add_from_file(f)
|
d.add_from_file(f)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@ -7,3 +7,4 @@ from .wav2vec import * # noqa
|
|||||||
from .wav2vec2 import * # noqa
|
from .wav2vec2 import * # noqa
|
||||||
from .wav2vec2_asr import * # noqa
|
from .wav2vec2_asr import * # noqa
|
||||||
from .wav2vec2_laser import * # noqa
|
from .wav2vec2_laser import * # noqa
|
||||||
|
from .wav2vec2_classification import * # noqa
|
||||||
|
@ -17,6 +17,7 @@ from fairseq.data.data_utils import compute_mask_indices
|
|||||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||||
from fairseq.distributed import fsdp_wrap
|
from fairseq.distributed import fsdp_wrap
|
||||||
from fairseq.models import BaseFairseqModel, register_model
|
from fairseq.models import BaseFairseqModel, register_model
|
||||||
|
from fairseq.distributed.fully_sharded_data_parallel import FullyShardedDataParallel
|
||||||
from fairseq.modules import (
|
from fairseq.modules import (
|
||||||
Fp32GroupNorm,
|
Fp32GroupNorm,
|
||||||
Fp32LayerNorm,
|
Fp32LayerNorm,
|
||||||
@ -37,7 +38,7 @@ from .utils import pad_to_multiple
|
|||||||
|
|
||||||
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
|
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
|
||||||
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
|
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
|
||||||
LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer"])
|
LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer", "trf_adp"])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -289,6 +290,20 @@ class Wav2Vec2Config(FairseqDataclass):
|
|||||||
)
|
)
|
||||||
fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"})
|
fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"})
|
||||||
|
|
||||||
|
# Adapter num
|
||||||
|
adp_num: int = field(
|
||||||
|
default=-1
|
||||||
|
)
|
||||||
|
adp_dim: int = field(
|
||||||
|
default=64
|
||||||
|
)
|
||||||
|
adp_act_fn: str = field(
|
||||||
|
default="relu"
|
||||||
|
)
|
||||||
|
adp_trf_idx: str = field(
|
||||||
|
default="all",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_model("wav2vec2", dataclass=Wav2Vec2Config)
|
@register_model("wav2vec2", dataclass=Wav2Vec2Config)
|
||||||
class Wav2Vec2Model(BaseFairseqModel):
|
class Wav2Vec2Model(BaseFairseqModel):
|
||||||
@ -588,6 +603,7 @@ class Wav2Vec2Model(BaseFairseqModel):
|
|||||||
mask_indices=None,
|
mask_indices=None,
|
||||||
mask_channel_indices=None,
|
mask_channel_indices=None,
|
||||||
padding_count=None,
|
padding_count=None,
|
||||||
|
corpus_key=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if self.feature_grad_mult > 0:
|
if self.feature_grad_mult > 0:
|
||||||
@ -672,7 +688,9 @@ class Wav2Vec2Model(BaseFairseqModel):
|
|||||||
y = unmasked_features
|
y = unmasked_features
|
||||||
mask_indices = None
|
mask_indices = None
|
||||||
|
|
||||||
x, layer_results = self.encoder(x, padding_mask=padding_mask, layer=layer)
|
x, layer_results = self.encoder(
|
||||||
|
x, padding_mask=padding_mask, layer=layer, corpus_key=corpus_key
|
||||||
|
)
|
||||||
|
|
||||||
if features_only:
|
if features_only:
|
||||||
return {
|
return {
|
||||||
@ -774,9 +792,16 @@ class Wav2Vec2Model(BaseFairseqModel):
|
|||||||
x = self.layer_norm(x)
|
x = self.layer_norm(x)
|
||||||
return self.quantizer.forward_idx(x)
|
return self.quantizer.forward_idx(x)
|
||||||
|
|
||||||
def extract_features(self, source, padding_mask, mask=False, layer=None):
|
def extract_features(
|
||||||
|
self, source, padding_mask, mask=False, layer=None, corpus_key=None
|
||||||
|
):
|
||||||
res = self.forward(
|
res = self.forward(
|
||||||
source, padding_mask, mask=mask, features_only=True, layer=layer
|
source,
|
||||||
|
padding_mask,
|
||||||
|
mask=mask,
|
||||||
|
features_only=True,
|
||||||
|
layer=layer,
|
||||||
|
corpus_key=corpus_key,
|
||||||
)
|
)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -917,7 +942,7 @@ def make_conv_pos(e, k, g):
|
|||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
class TransformerEncoder(nn.Module):
|
||||||
def build_encoder_layer(self, args: Wav2Vec2Config):
|
def build_encoder_layer(self, args: Wav2Vec2Config, layer_idx: int):
|
||||||
if args.layer_type == "transformer":
|
if args.layer_type == "transformer":
|
||||||
layer = TransformerSentenceEncoderLayer(
|
layer = TransformerSentenceEncoderLayer(
|
||||||
embedding_dim=self.embedding_dim,
|
embedding_dim=self.embedding_dim,
|
||||||
@ -941,6 +966,40 @@ class TransformerEncoder(nn.Module):
|
|||||||
use_fp16=args.fp16,
|
use_fp16=args.fp16,
|
||||||
pos_enc_type="abs",
|
pos_enc_type="abs",
|
||||||
)
|
)
|
||||||
|
elif args.layer_type == "trf_adp":
|
||||||
|
use_adp = False
|
||||||
|
if args.adp_trf_idx == "all":
|
||||||
|
use_adp = True
|
||||||
|
else:
|
||||||
|
adp_trf_idx = list(range(*[int(g) for g in args.adp_trf_idx.split(":")]))
|
||||||
|
if layer_idx in adp_trf_idx:
|
||||||
|
use_adp = True
|
||||||
|
if use_adp:
|
||||||
|
layer = TransformerSentenceEncoderWithAdapterLayer(
|
||||||
|
embedding_dim=self.embedding_dim,
|
||||||
|
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||||
|
num_attention_heads=args.encoder_attention_heads,
|
||||||
|
dropout=self.dropout,
|
||||||
|
attention_dropout=args.attention_dropout,
|
||||||
|
activation_dropout=args.activation_dropout,
|
||||||
|
activation_fn=args.activation_fn,
|
||||||
|
layer_norm_first=args.layer_norm_first,
|
||||||
|
adapter_num=args.adp_num,
|
||||||
|
adapter_dim=args.adp_dim,
|
||||||
|
adapter_act_fn=args.adp_act_fn,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer = TransformerSentenceEncoderLayer(
|
||||||
|
embedding_dim=self.embedding_dim,
|
||||||
|
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||||
|
num_attention_heads=args.encoder_attention_heads,
|
||||||
|
dropout=self.dropout,
|
||||||
|
attention_dropout=args.attention_dropout,
|
||||||
|
activation_dropout=args.activation_dropout,
|
||||||
|
activation_fn=args.activation_fn,
|
||||||
|
layer_norm_first=args.layer_norm_first,
|
||||||
|
)
|
||||||
|
|
||||||
layer = fsdp_wrap(layer)
|
layer = fsdp_wrap(layer)
|
||||||
if args.checkpoint_activations:
|
if args.checkpoint_activations:
|
||||||
layer = checkpoint_wrapper(layer)
|
layer = checkpoint_wrapper(layer)
|
||||||
@ -991,7 +1050,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[self.build_encoder_layer(args) for _ in range(args.encoder_layers)]
|
[self.build_encoder_layer(args, ii) for ii in range(args.encoder_layers)]
|
||||||
)
|
)
|
||||||
self.layer_norm_first = args.layer_norm_first
|
self.layer_norm_first = args.layer_norm_first
|
||||||
self.layer_norm = LayerNorm(self.embedding_dim)
|
self.layer_norm = LayerNorm(self.embedding_dim)
|
||||||
@ -999,8 +1058,10 @@ class TransformerEncoder(nn.Module):
|
|||||||
|
|
||||||
self.apply(init_bert_params)
|
self.apply(init_bert_params)
|
||||||
|
|
||||||
def forward(self, x, padding_mask=None, layer=None):
|
def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
|
||||||
x, layer_results = self.extract_features(x, padding_mask, layer)
|
x, layer_results = self.extract_features(
|
||||||
|
x, padding_mask, layer, corpus_key=corpus_key
|
||||||
|
)
|
||||||
|
|
||||||
if self.layer_norm_first and layer is None:
|
if self.layer_norm_first and layer is None:
|
||||||
x = self.layer_norm(x)
|
x = self.layer_norm(x)
|
||||||
@ -1013,6 +1074,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
padding_mask=None,
|
padding_mask=None,
|
||||||
tgt_layer=None,
|
tgt_layer=None,
|
||||||
min_layer=0,
|
min_layer=0,
|
||||||
|
corpus_key=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if padding_mask is not None:
|
if padding_mask is not None:
|
||||||
@ -1043,12 +1105,29 @@ class TransformerEncoder(nn.Module):
|
|||||||
|
|
||||||
layer_results = []
|
layer_results = []
|
||||||
r = None
|
r = None
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
dropout_probability = np.random.random() if self.layerdrop > 0 else 1
|
dropout_probability = np.random.random() if self.layerdrop > 0 else 1
|
||||||
if not self.training or (dropout_probability > self.layerdrop):
|
if not self.training or (dropout_probability > self.layerdrop):
|
||||||
|
layer_check = layer
|
||||||
|
if isinstance(layer, FullyShardedDataParallel):
|
||||||
|
layer_check = layer.unwrapped_module
|
||||||
|
if (corpus_key is None) or (
|
||||||
|
not isinstance(layer_check, (
|
||||||
|
TransformerSentenceEncoderWithAdapterLayer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
x, (z, lr) = layer(
|
x, (z, lr) = layer(
|
||||||
x, self_attn_padding_mask=padding_mask, need_weights=False
|
x, self_attn_padding_mask=padding_mask, need_weights=False
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
x, (z, lr) = layer(
|
||||||
|
x,
|
||||||
|
self_attn_padding_mask=padding_mask,
|
||||||
|
need_weights=False,
|
||||||
|
corpus_key=corpus_key,
|
||||||
|
)
|
||||||
if i >= min_layer:
|
if i >= min_layer:
|
||||||
layer_results.append((x, z, lr))
|
layer_results.append((x, z, lr))
|
||||||
if i == tgt_layer:
|
if i == tgt_layer:
|
||||||
@ -1282,3 +1361,125 @@ class TransformerSentenceEncoderLayer(nn.Module):
|
|||||||
x = self.final_layer_norm(x)
|
x = self.final_layer_norm(x)
|
||||||
|
|
||||||
return x, (attn, layer_result)
|
return x, (attn, layer_result)
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterFast(nn.Module):
|
||||||
|
def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
|
||||||
|
"""
|
||||||
|
Implements adapter modules directly with 3D tensor weight as parameters
|
||||||
|
and without using ModuleList, to speed up training throughput.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.adapter_num = adapter_num
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
|
||||||
|
self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
|
||||||
|
self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
|
||||||
|
self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
|
||||||
|
|
||||||
|
self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
|
||||||
|
self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
|
||||||
|
self.act_fn = nn.Identity()
|
||||||
|
if act_fn == "relu":
|
||||||
|
self.act_fn = nn.ReLU()
|
||||||
|
elif act_fn == "gelu":
|
||||||
|
self.act_fn = nn.GELU()
|
||||||
|
elif act_fn == "selu":
|
||||||
|
self.act_fn = nn.SELU()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unsupported adapter activation function: {act_fn}")
|
||||||
|
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
for ii in range(self.adapter_num):
|
||||||
|
nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
|
||||||
|
nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
|
||||||
|
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
|
||||||
|
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||||
|
nn.init.uniform_(self.b_a[ii], -bound, bound)
|
||||||
|
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
|
||||||
|
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||||
|
nn.init.uniform_(self.b_b[ii], -bound, bound)
|
||||||
|
|
||||||
|
nn.init.ones_(self.ln_W)
|
||||||
|
nn.init.zeros_(self.ln_b)
|
||||||
|
|
||||||
|
def forward(self, x, adapter_id):
|
||||||
|
ii = adapter_id
|
||||||
|
h = x
|
||||||
|
h = F.layer_norm(h, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii])
|
||||||
|
h = F.linear(h, self.W_a[ii], self.b_a[ii])
|
||||||
|
h = self.act_fn(h)
|
||||||
|
h = F.linear(h, self.W_b[ii], self.b_b[ii])
|
||||||
|
outputs = h
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
|
||||||
|
"""
|
||||||
|
Implements a Transformer Encoder Layer with adapters used in BERT/XLM style pre-trained
|
||||||
|
models. An adapter module is added alongside the vanilla Transformer module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int = 768,
|
||||||
|
ffn_embedding_dim: int = 3072,
|
||||||
|
num_attention_heads: int = 8,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
attention_dropout: float = 0.1,
|
||||||
|
activation_dropout: float = 0.1,
|
||||||
|
activation_fn: str = "relu",
|
||||||
|
layer_norm_first: bool = False,
|
||||||
|
adapter_num=201,
|
||||||
|
adapter_dim=64,
|
||||||
|
adapter_act_fn="relu",
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
ffn_embedding_dim=ffn_embedding_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
attention_dropout=attention_dropout,
|
||||||
|
activation_dropout=activation_dropout,
|
||||||
|
activation_fn=activation_fn,
|
||||||
|
layer_norm_first=layer_norm_first,
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
self.adapter_num = adapter_num
|
||||||
|
self.adapter_dim = adapter_dim
|
||||||
|
self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
self_attn_mask: torch.Tensor = None,
|
||||||
|
self_attn_padding_mask: torch.Tensor = None,
|
||||||
|
need_weights: bool = False,
|
||||||
|
att_args=None,
|
||||||
|
corpus_key=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
x, (attn, layer_result) = super().forward(
|
||||||
|
x=x,
|
||||||
|
self_attn_mask=self_attn_mask,
|
||||||
|
self_attn_padding_mask=self_attn_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
att_args=att_args,
|
||||||
|
)
|
||||||
|
assert corpus_key is not None
|
||||||
|
assert len(set(corpus_key)) == 1, f"corpus_key items are not the same: {corpus_key}"
|
||||||
|
y = self.adapter_layer(x, corpus_key[0])
|
||||||
|
x = x + y
|
||||||
|
return x, (attn, layer_result)
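In `build_encoder_layer`, the `adp_trf_idx` option decides which encoder layers get an adapter: `"all"` selects every layer, and any other value is split on `:` and passed to `range`. The sketch below isolates that selection logic; the function name `layers_with_adapters` is hypothetical.

```python
def layers_with_adapters(adp_trf_idx, num_layers):
    """List the encoder layer indices that would be built with an adapter."""
    if adp_trf_idx == "all":
        return list(range(num_layers))
    # "start:stop" or "start:stop:step" is unpacked into range(), as in the diff.
    selected = range(*[int(g) for g in adp_trf_idx.split(":")])
    return [i for i in range(num_layers) if i in selected]


print(layers_with_adapters("all", 4))
print(layers_with_adapters("2:6", 8))
print(layers_with_adapters("0:12:4", 12))
```

Because the value goes straight into `range`, a single number such as `"6"` means "layers 0 to 5", not "layer 6 only".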
|
||||||
|
@ -28,7 +28,7 @@ from fairseq.models import (
|
|||||||
FairseqIncrementalDecoder,
|
FairseqIncrementalDecoder,
|
||||||
register_model,
|
register_model,
|
||||||
)
|
)
|
||||||
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
|
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES, LAYER_TYPE_CHOICES, AdapterFast
|
||||||
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer
|
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer
|
||||||
from fairseq.tasks import FairseqTask
|
from fairseq.tasks import FairseqTask
|
||||||
|
|
||||||
@ -178,6 +178,27 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
|
|||||||
layer_decay: float = 1
|
layer_decay: float = 1
|
||||||
|
|
||||||
|
|
||||||
|
layer_type: LAYER_TYPE_CHOICES = field(
|
||||||
|
default="transformer", metadata={"help": "layer type in encoder"}
|
||||||
|
)
|
||||||
|
# Adapter num
|
||||||
|
adp_num: int = field(
|
||||||
|
default=-1
|
||||||
|
)
|
||||||
|
adp_dim: int = field(
|
||||||
|
default=64
|
||||||
|
)
|
||||||
|
adp_act_fn: str = field(
|
||||||
|
default="relu"
|
||||||
|
)
|
||||||
|
adp_trf_idx: str = field(
|
||||||
|
default="all",
|
||||||
|
)
|
||||||
|
|
||||||
|
freeze_regex: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig):
|
class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig):
|
||||||
blank_weight: float = 0
|
blank_weight: float = 0
|
||||||
@ -416,6 +437,14 @@ class Wav2VecEncoder(FairseqEncoder):
|
|||||||
"Please check that --normalize is set or unset for both pre-training and here"
|
"Please check that --normalize is set or unset for both pre-training and here"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with open_dict(w2v_args):
|
||||||
|
args_replacement = ["checkpoint_activations", "layer_type",
|
||||||
|
"adp_num", "adp_dim",
|
||||||
|
"adp_act_fn", "adp_trf_idx"]
|
||||||
|
for _args in args_replacement:
|
||||||
|
if hasattr(cfg, _args) and getattr(cfg, _args, None) is not None:
|
||||||
|
w2v_args.model[_args] = getattr(cfg, _args, None)
|
||||||
|
|
||||||
if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
|
if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
|
||||||
with open_dict(w2v_args):
|
with open_dict(w2v_args):
|
||||||
w2v_args.model.checkpoint_activations = cfg.checkpoint_activations
|
w2v_args.model.checkpoint_activations = cfg.checkpoint_activations
|
||||||
@ -423,7 +452,6 @@ class Wav2VecEncoder(FairseqEncoder):
|
|||||||
w2v_args.task.data = cfg.data
|
w2v_args.task.data = cfg.data
|
||||||
task = tasks.setup_task(w2v_args.task, from_checkpoint=True)
|
task = tasks.setup_task(w2v_args.task, from_checkpoint=True)
|
||||||
model = task.build_model(w2v_args.model, from_checkpoint=True)
|
model = task.build_model(w2v_args.model, from_checkpoint=True)
|
||||||
|
|
||||||
model.remove_pretraining_modules()
|
model.remove_pretraining_modules()
|
||||||
d = w2v_args.model.encoder_embed_dim
|
d = w2v_args.model.encoder_embed_dim
|
||||||
else:
|
else:
|
||||||
@ -468,6 +496,9 @@ class Wav2VecEncoder(FairseqEncoder):
|
|||||||
if targ_d is not None:
|
if targ_d is not None:
|
||||||
self.proj = Linear(d, targ_d)
|
self.proj = Linear(d, targ_d)
|
||||||
|
|
||||||
|
if cfg.freeze_regex is not None:
|
||||||
|
self.freeze_regex(cfg.freeze_regex)
|
||||||
|
|
||||||
layer_decay = getattr(cfg, "layer_decay", 1)
|
layer_decay = getattr(cfg, "layer_decay", 1)
|
||||||
if layer_decay < 1:
|
if layer_decay < 1:
|
||||||
mod_encs = list(model.modality_encoders.values())
|
mod_encs = list(model.modality_encoders.values())
|
||||||
@ -491,6 +522,14 @@ class Wav2VecEncoder(FairseqEncoder):
|
|||||||
optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
|
optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
|
||||||
p.optim_overrides = optim_override
|
p.optim_overrides = optim_override
|
||||||
|
|
||||||
|
def freeze_regex(self, pattern):
|
||||||
|
unfrozen_names = []
|
||||||
|
for name, param in self.named_parameters():
|
||||||
|
if re.fullmatch(pattern, name) is not None:
|
||||||
|
param.requires_grad_(False)
|
||||||
|
else:
|
||||||
|
unfrozen_names.append(name)
|
||||||
|
|
||||||
def load_model_weights(self, state, model, cfg):
|
def load_model_weights(self, state, model, cfg):
|
||||||
if cfg.ddp_backend == "fully_sharded":
|
if cfg.ddp_backend == "fully_sharded":
|
||||||
from fairseq.distributed import FullyShardedDataParallel
|
from fairseq.distributed import FullyShardedDataParallel
|
||||||
@ -553,6 +592,8 @@ class Wav2VecEncoder(FairseqEncoder):
|
|||||||
"padding_mask": padding_mask,
|
"padding_mask": padding_mask,
|
||||||
"mask": self.apply_mask and self.training,
|
"mask": self.apply_mask and self.training,
|
||||||
}
|
}
|
||||||
|
if "corpus_key" in kwargs:
|
||||||
|
w2v_args["corpus_key"] = kwargs["corpus_key"]
|
||||||
|
|
||||||
if self.is_d2v_multi:
|
if self.is_d2v_multi:
|
||||||
w2v_args["mode"] = "AUDIO"
|
w2v_args["mode"] = "AUDIO"
|
||||||
|
348
fairseq/models/wav2vec/wav2vec2_classification.py
Normal file
@ -0,0 +1,348 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
from argparse import Namespace
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from omegaconf import II, MISSING, open_dict
|
||||||
|
|
||||||
|
from fairseq import checkpoint_utils, tasks, utils
|
||||||
|
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||||
|
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||||
|
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
|
||||||
|
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES, Wav2Vec2Config
|
||||||
|
from fairseq.models.wav2vec.wav2vec2_asr import Embedding, Linear, Wav2VecEncoder, Wav2Vec2AsrConfig
|
||||||
|
from fairseq.tasks import FairseqTask
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Wav2Vec2ClassificationConfig(Wav2Vec2AsrConfig):
|
||||||
|
latent_embed_dim: Optional[int] = field(
|
||||||
|
default=None, metadata={"help": "latent dim (encoder w2v -> latent -> class)"}
|
||||||
|
)
|
||||||
|
pooling: str = field(
|
||||||
|
default="first_token",
|
||||||
|
metadata={"help": "pooling layer choices"},
|
||||||
|
)
|
||||||
|
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
||||||
|
default="gelu", metadata={"help": "activation function to use"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model("wav2vec_classification", dataclass=Wav2Vec2ClassificationConfig)
|
||||||
|
class Wav2VecClassification(BaseFairseqModel):
|
||||||
|
# TODO: Can be shared/merged with ASR model class as w2v_encoder params are common.
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: Wav2Vec2ClassificationConfig,
|
||||||
|
w2v_encoder: BaseFairseqModel,
|
||||||
|
pooling_layer,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.cfg = cfg
|
||||||
|
self.w2v_encoder = w2v_encoder
|
||||||
|
self.pooling_layer = pooling_layer
|
||||||
|
|
||||||
|
def upgrade_state_dict_named(self, state_dict, name):
|
||||||
|
super().upgrade_state_dict_named(state_dict, name)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_model(cls, cfg: Wav2Vec2ClassificationConfig, task: FairseqTask):
|
||||||
|
"""Build a new model instance."""
|
||||||
|
w2v_encoder = Wav2VecEncoder(cfg, None)
|
||||||
|
pooling_layer = get_pooling_layer(
|
||||||
|
cfg,
|
||||||
|
w2v_encoder.w2v_model.encoder.layers[-1].embedding_dim,
|
||||||
|
len(task.target_dictionary),
|
||||||
|
len(w2v_encoder.w2v_model.encoder.layers),
|
||||||
|
)
|
||||||
|
return cls(cfg, w2v_encoder, pooling_layer)
|
||||||
|
|
||||||
|
def get_normalized_probs(self, net_output, log_probs):
|
||||||
|
"""Get normalized probabilities (or log probs) from a net's output."""
|
||||||
|
logits = net_output
|
||||||
|
|
||||||
|
if log_probs:
|
||||||
|
return utils.log_softmax(logits.float(), dim=-1)
|
||||||
|
else:
|
||||||
|
return utils.softmax(logits.float(), dim=-1)
|
||||||
|
|
||||||
|
def get_logits(self, net_output):
|
||||||
|
return net_output
|
||||||
|
|
||||||
|
def forward(self, **kwargs):
|
||||||
|
encoder_out_dict = self.w2v_encoder(**kwargs)
|
||||||
|
w2v_encoder_out = encoder_out_dict["encoder_out"] # TxBxC
|
||||||
|
w2v_encoder_padding_mask = encoder_out_dict["padding_mask"] # BxT
|
||||||
|
# w2v_encoder_layer_results = encoder_out_dict["layer_results"]
|
||||||
|
return self.pooling_layer(
|
||||||
|
last_layer_feats=w2v_encoder_out,
|
||||||
|
padding_mask=w2v_encoder_padding_mask,
|
||||||
|
# all_layer_feats=w2v_encoder_layer_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
# def forward_latent(self, **kwargs):
|
||||||
|
# encoder_out_dict = self.w2v_encoder(**kwargs)
|
||||||
|
# w2v_encoder_out = encoder_out_dict["encoder_out"]
|
||||||
|
# w2v_encoder_padding_mask = encoder_out_dict["encoder_padding_mask"]
|
||||||
|
# w2v_encoder_layer_results = encoder_out_dict["layer_results"]
|
||||||
|
# return self.pooling_layer.forward_latent(
|
||||||
|
# last_layer_feats=w2v_encoder_out,
|
||||||
|
# padding_mask=w2v_encoder_padding_mask,
|
||||||
|
# all_layer_feats=w2v_encoder_layer_results,
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
def get_pooling_layer(
|
||||||
|
cfg: Wav2Vec2ClassificationConfig,
|
||||||
|
encoder_embed_dim: int,
|
||||||
|
num_targets: int,
|
||||||
|
encoder_layers: int,
|
||||||
|
):
|
||||||
|
assert cfg.pooling == "mean", f"only 'mean' pooling is accepted here, got {cfg.pooling!r}"
|
||||||
|
if cfg.pooling == "first_token":
|
||||||
|
return FirstToken(cfg, encoder_embed_dim, num_targets)
|
||||||
|
# elif cfg.pooling == "mean":
|
||||||
|
# return MeanPooling(cfg, encoder_embed_dim, num_targets)
|
||||||
|
elif cfg.pooling == "mean":
|
||||||
|
return MeanPoolingFast(cfg, encoder_embed_dim, num_targets)
|
||||||
|
elif cfg.pooling == "mean_amsoftmax":
|
||||||
|
return MeanPoolingFastAMSoftmax(cfg, encoder_embed_dim, num_targets)
|
||||||
|
elif cfg.pooling == "max":
|
||||||
|
return MaxPoolingFast(cfg, encoder_embed_dim, num_targets)
|
||||||
|
elif cfg.pooling == "elmo":
|
||||||
|
return LayerWeightedMeanPooling(
|
||||||
|
cfg, encoder_embed_dim, num_targets, encoder_layers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"{cfg.pooling} has not been implemented yet.")
|
||||||
|
|
||||||
|
|
||||||
|
class Pooling(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: Wav2Vec2ClassificationConfig,
|
||||||
|
encoder_embed_dim: int,
|
||||||
|
num_targets: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.projection = Linear(encoder_embed_dim, num_targets)
|
||||||
|
|
||||||
|
def forward(self, last_layer_feats, **kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class FirstToken(Pooling):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, last_layer_feats, **kwargs):
|
||||||
|
return self.projection(last_layer_feats[0])  # TxBxD -> BxD: first time step
|
||||||
|
|
||||||
|
|
||||||
|
# class MeanPooling(Pooling):
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# cfg: Wav2VecClassificationConfig,
|
||||||
|
# encoder_embed_dim: int,
|
||||||
|
# num_targets: int,
|
||||||
|
# **kwargs,
|
||||||
|
# ):
|
||||||
|
# super().__init__(cfg, encoder_embed_dim, num_targets)
|
||||||
|
# self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
|
||||||
|
# self.linear = Linear(encoder_embed_dim, encoder_embed_dim)
|
||||||
|
|
||||||
|
# def forward(self, last_layer_feats, padding_mask, **kwargs):
|
||||||
|
# # last_layer_feats: [BxTxD]
|
||||||
|
# # padding_mask: [BxT]
|
||||||
|
# last_layer_feats = self.linear(self.activation_fn(last_layer_feats))
|
||||||
|
# input_lengths = (1 - padding_mask.long()).sum(-1)
|
||||||
|
# pooled_feature_list = []
|
||||||
|
# for i in range(len(last_layer_feats)):
|
||||||
|
# length = input_lengths[i]
|
||||||
|
# pooled_feature = torch.mean(last_layer_feats[i][:length], dim=0)
|
||||||
|
# pooled_feature_list.append(pooled_feature)
|
||||||
|
# return self.projection(torch.stack(pooled_feature_list))
|
||||||
|
|
||||||
|
|
||||||
|
def fn_mean(x, mask):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: TxBxD
|
||||||
|
mask: BxT
|
||||||
|
Return:
|
||||||
|
y: BxD
|
||||||
|
"""
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.t()[:, :, None]
|
||||||
|
return (x * mask).sum(0) / mask.sum(0)
|
||||||
|
else:
|
||||||
|
return x.sum(0) / x.shape[0]
|
||||||
|
|
||||||
|
|
||||||
|
class MeanPoolingFast(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: Wav2Vec2ClassificationConfig,
|
||||||
|
encoder_embed_dim: int,
|
||||||
|
num_targets: int,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
|
||||||
|
self.latent_embed_dim = (
|
||||||
|
cfg.latent_embed_dim
|
||||||
|
if cfg.latent_embed_dim is not None
|
||||||
|
else encoder_embed_dim
|
||||||
|
)
|
||||||
|
logging.debug(f"| {self.latent_embed_dim=}")
|
||||||
|
self.linear = Linear(encoder_embed_dim, self.latent_embed_dim)
|
||||||
|
self.projection = Linear(self.latent_embed_dim, num_targets)
|
||||||
|
|
||||||
|
def forward(self, last_layer_feats, padding_mask, **kwargs):
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
last_layer_feats - [TxBxD] Acoustic features
|
||||||
|
padding_mask - [BxT] Padding Mask
|
||||||
|
"""
|
||||||
|
if padding_mask is not None:
|
||||||
|
feat_mask = (~padding_mask).to(last_layer_feats.dtype)
|
||||||
|
else:
|
||||||
|
feat_mask = None
|
||||||
|
feat = self.linear(last_layer_feats)
|
||||||
|
feat = fn_mean(feat, feat_mask)
|
||||||
|
feat = self.activation_fn(feat)
|
||||||
|
return self.projection(feat)
|
||||||
|
|
||||||
|
def forward_latent(self, last_layer_feats, padding_mask, **kwargs):
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
last_layer_feats - [TxBxD] Acoustic features
|
||||||
|
padding_mask - [BxT] Padding Mask
|
||||||
|
"""
|
||||||
|
if padding_mask is not None:
|
||||||
|
feat_mask = (~padding_mask).to(last_layer_feats.dtype)
|
||||||
|
else:
|
||||||
|
feat_mask = None
|
||||||
|
feat = self.linear(last_layer_feats)
|
||||||
|
feat = fn_mean(feat, feat_mask)
|
||||||
|
return feat
|
||||||
|
|
||||||
|
|
||||||
|
class MeanPoolingFastAMSoftmax(MeanPoolingFast):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: Wav2Vec2ClassificationConfig,
|
||||||
|
encoder_embed_dim: int,
|
||||||
|
num_targets: int,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(cfg, encoder_embed_dim, num_targets, **kwargs)
|
||||||
|
self.projection = Linear(self.latent_embed_dim, num_targets, bias=False)
|
||||||
|
nn.init.xavier_normal_(self.projection.weight, gain=1)
|
||||||
|
|
||||||
|
def forward(self, last_layer_feats, padding_mask, **kwargs):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
last_layer_feats - [TxBxD] Acoustic features
|
||||||
|
padding_mask - [BxT] Padding Mask
|
||||||
|
"""
|
||||||
|
feat_mask = (~padding_mask).to(last_layer_feats.dtype)  # B,T (1 = valid frame)
|
||||||
|
feat = self.linear(last_layer_feats)  # T,B,D
|
||||||
|
feat = fn_mean(feat, feat_mask) # B,D
|
||||||
|
feat = self.activation_fn(feat)
|
||||||
|
# normalize feat
|
||||||
|
feat_norm = F.normalize(feat, p=2, dim=-1) # B,D
|
||||||
|
weight_norm = F.normalize(self.projection.weight, p=2, dim=-1).t()  # K,D -> D,K, unit norm per class
|
||||||
|
cos_fw = feat_norm @ weight_norm
|
||||||
|
return cos_fw
|
||||||
|
|
||||||
|
|
||||||
|
def fn_max(x, mask):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: TxBxD
|
||||||
|
mask: BxT
|
||||||
|
Return:
|
||||||
|
y: BxD
|
||||||
|
"""
|
||||||
|
mask = mask.t()[:, :, None].to(torch.bool)
|
||||||
|
return x.masked_fill(~mask, float("-inf")).max(0)[0]
|
||||||
|
|
||||||
|
|
||||||
|
class MaxPoolingFast(Pooling):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: Wav2Vec2ClassificationConfig,
|
||||||
|
encoder_embed_dim: int,
|
||||||
|
num_targets: int,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(cfg, encoder_embed_dim, num_targets)
|
||||||
|
self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
|
||||||
|
self.linear = Linear(encoder_embed_dim, encoder_embed_dim)
|
||||||
|
|
||||||
|
def forward(self, last_layer_feats, padding_mask, **kwargs):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
last_layer_feats - [TxBxD] Acoustic features
|
||||||
|
padding_mask - [BxT] Padding Mask
|
||||||
|
"""
|
||||||
|
feat_mask = (~padding_mask).to(last_layer_feats.dtype)
|
||||||
|
feat = self.linear(last_layer_feats)
|
||||||
|
feat = fn_max(feat, feat_mask)
|
||||||
|
feat = self.activation_fn(feat)
|
||||||
|
return self.projection(feat)
|
||||||
|
|
||||||
|
|
||||||
|
class LayerWeightedMeanPooling(MeanPoolingFast):
|
||||||
|
"""Elmo-style weighted average representation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: Wav2Vec2ClassificationConfig,
|
||||||
|
encoder_embed_dim: int,
|
||||||
|
num_targets: int,
|
||||||
|
encoder_layers: int,
|
||||||
|
):
|
||||||
|
super().__init__(cfg, encoder_embed_dim, num_targets)
|
||||||
|
self.num_layers = encoder_layers
|
||||||
|
self.weights = nn.Parameter(torch.ones(encoder_layers))
|
||||||
|
|
||||||
|
def forward(self, last_layer_feats, padding_mask, all_layer_feats):
|
||||||
|
# last_layer_feats: [BxTxD]
|
||||||
|
# padding_mask: [BxT]
|
||||||
|
if not self.training:
|
||||||
|
msg = (
|
||||||
|
f"Number of layers in input features = {len(all_layer_feats)}."
|
||||||
|
f" Expected {self.num_layers} layers."
|
||||||
|
)
|
||||||
|
assert len(all_layer_feats) == self.num_layers, msg
|
||||||
|
|
||||||
|
# Stack up all layers and reshape to (num_layers, features)
|
||||||
|
all_layer_feats_stacked = torch.stack(all_layer_feats, dim=0)
|
||||||
|
num_layers, *original_feat_shape = all_layer_feats_stacked.shape
|
||||||
|
all_layer_feats_stacked_flat = all_layer_feats_stacked.view(num_layers, -1)
|
||||||
|
|
||||||
|
# Weighted average
|
||||||
|
normalized_weights = F.softmax(self.weights, dim=-1)
|
||||||
|
weighted_avg_features = (
|
||||||
|
normalized_weights.unsqueeze(-1) * all_layer_feats_stacked_flat
|
||||||
|
).sum(dim=0)
|
||||||
|
weighted_avg_features = weighted_avg_features.view(*original_feat_shape)
|
||||||
|
|
||||||
|
# Mean Pooling on weighted average features.
|
||||||
|
return super().forward(weighted_avg_features, padding_mask)
|
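The `fn_mean` helper above averages time-major features over the frames that are not padding. The same arithmetic can be sketched without torch, using nested lists in place of tensors. The function name `masked_mean` and the toy inputs are illustrative assumptions, not part of the commit.

```python
def masked_mean(x, valid_mask):
    """Average over time, ignoring padded frames.

    x:          nested list shaped [T][B][D] (time-major, like fn_mean's input)
    valid_mask: nested list shaped [B][T], 1.0 for real frames, 0.0 for padding
    returns:    nested list shaped [B][D]
    """
    T, B, D = len(x), len(x[0]), len(x[0][0])
    out = []
    for b in range(B):
        # Number of real (non-padded) frames in utterance b.
        n_valid = sum(valid_mask[b][t] for t in range(T))
        row = [
            sum(x[t][b][d] * valid_mask[b][t] for t in range(T)) / n_valid
            for d in range(D)
        ]
        out.append(row)
    return out


# Two utterances, three frames, one feature; the last frame of the
# second utterance is padding and must not affect its mean.
x = [[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [100.0]]]
padding_mask = [[False, False, False], [False, False, True]]
valid_mask = [[0.0 if p else 1.0 for p in row] for row in padding_mask]
print(masked_mean(x, valid_mask))  # [[3.0], [3.0]]
```

As in the tensor version, an utterance whose frames are all padding has a zero denominator; this sketch does not guard against that case.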
269
fairseq/tasks/audio_classification.py
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
# Copyright (c) 2017-present, Facebook, Inc.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the LICENSE file in
|
||||||
|
# the root directory of this source tree. An additional grant of patent rights
|
||||||
|
# can be found in the PATENTS file in the same directory.
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
import itertools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from omegaconf import II, MISSING
|
||||||
|
from sklearn import metrics as sklearn_metrics
|
||||||
|
|
||||||
|
from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset
|
||||||
|
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
|
||||||
|
from fairseq.data.text_compressor import TextCompressionLevel, TextCompressor
|
||||||
|
from fairseq.dataclass import FairseqDataclass
|
||||||
|
from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask
|
||||||
|
from fairseq.tasks.audio_finetuning import label_len_fn, LabelEncoder
|
||||||
|
|
||||||
|
from .. import utils
|
||||||
|
from ..logging import metrics
|
||||||
|
from . import FairseqTask, register_task
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AudioClassificationConfig(AudioPretrainingConfig):
|
||||||
|
target_dictionary: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "override default dictionary location"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_task("audio_classification", dataclass=AudioClassificationConfig)
|
||||||
|
class AudioClassificationTask(AudioPretrainingTask):
|
||||||
|
"""Task for audio classification."""
|
||||||
|
|
||||||
|
cfg: AudioClassificationConfig
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: AudioClassificationConfig,
|
||||||
|
):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.state.add_factory("target_dictionary", self.load_target_dictionary)
|
||||||
|
logger.info(f"=== Number of labels = {len(self.target_dictionary)}")
|
||||||
|
|
||||||
|
def load_target_dictionary(self):
|
||||||
|
if self.cfg.labels:
|
||||||
|
target_dictionary = self.cfg.data
|
||||||
|
if self.cfg.target_dictionary: # override dict
|
||||||
|
target_dictionary = self.cfg.target_dictionary
|
||||||
|
dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt")
|
||||||
|
logger.info("Using dict_path : {}".format(dict_path))
|
||||||
|
return Dictionary.load(dict_path, add_special_symbols=False)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_dataset(
|
||||||
|
self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs
|
||||||
|
):
|
||||||
|
super().load_dataset(split, task_cfg, **kwargs)
|
||||||
|
task_cfg = task_cfg or self.cfg
|
||||||
|
assert task_cfg.labels is not None
|
||||||
|
text_compression_level = getattr(
|
||||||
|
TextCompressionLevel, str(self.cfg.text_compression_level)
|
||||||
|
)
|
||||||
|
data_path = self.cfg.data
|
||||||
|
if task_cfg.multi_corpus_keys is None:
|
||||||
|
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
|
||||||
|
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
|
||||||
|
text_compressor = TextCompressor(level=text_compression_level)
|
||||||
|
with open(label_path, "r") as f:
|
||||||
|
labels = [
|
||||||
|
text_compressor.compress(l)
|
||||||
|
for i, l in enumerate(f)
|
||||||
|
if i not in skipped_indices
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(labels) == len(self.datasets[split]), (
|
||||||
|
f"labels length ({len(labels)}) and dataset length "
|
||||||
|
f"({len(self.datasets[split])}) do not match"
|
||||||
|
)
|
||||||
|
|
||||||
|
process_label = LabelEncoder(self.target_dictionary)
|
||||||
|
|
||||||
|
self.datasets[split] = AddTargetDataset(
|
||||||
|
self.datasets[split],
|
||||||
|
labels,
|
||||||
|
pad=self.target_dictionary.pad(),
|
||||||
|
eos=self.target_dictionary.eos(),
|
||||||
|
batch_targets=True,
|
||||||
|
process_label=process_label,
|
||||||
|
label_len_fn=label_len_fn,
|
||||||
|
add_to_input=False,
|
||||||
|
# text_compression_level=text_compression_level,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
target_dataset_map = OrderedDict()
|
||||||
|
|
||||||
|
multi_corpus_keys = [
|
||||||
|
k.strip() for k in task_cfg.multi_corpus_keys.split(",")
|
||||||
|
]
|
||||||
|
corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}
|
||||||
|
|
||||||
|
data_keys = [k.split(":") for k in split.split(",")]
|
||||||
|
|
||||||
|
multi_corpus_sampling_weights = [
|
||||||
|
float(val.strip())
|
||||||
|
for val in task_cfg.multi_corpus_sampling_weights.split(",")
|
||||||
|
]
|
||||||
|
data_weights = []
|
||||||
|
for key, file_name in data_keys:
|
||||||
|
k = key.strip()
|
||||||
|
label_path = os.path.join(
|
||||||
|
data_path, f"{file_name.strip()}.{task_cfg.labels}"
|
||||||
|
)
|
||||||
|
skipped_indices = getattr(
|
||||||
|
self.dataset_map[split][k], "skipped_indices", set()
|
||||||
|
)
|
||||||
|
text_compressor = TextCompressor(level=text_compression_level)
|
||||||
|
with open(label_path, "r") as f:
|
||||||
|
labels = [
|
||||||
|
text_compressor.compress(l)
|
||||||
|
for i, l in enumerate(f)
|
||||||
|
if i not in skipped_indices
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(labels) == len(self.dataset_map[split][k]), (
|
||||||
|
f"labels length ({len(labels)}) and dataset length "
|
||||||
|
f"({len(self.dataset_map[split][k])}) do not match"
|
||||||
|
)
|
||||||
|
|
||||||
|
process_label = LabelEncoder(self.target_dictionary)
|
||||||
|
|
||||||
|
# TODO: Remove duplication of code from the if block above
|
||||||
|
target_dataset_map[k] = AddTargetDataset(
|
||||||
|
self.dataset_map[split][k],
|
||||||
|
labels,
|
||||||
|
pad=self.target_dictionary.pad(),
|
||||||
|
eos=self.target_dictionary.eos(),
|
||||||
|
batch_targets=True,
|
||||||
|
process_label=process_label,
|
||||||
|
label_len_fn=label_len_fn,
|
||||||
|
add_to_input=False,
|
||||||
|
# text_compression_level=text_compression_level,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]])
|
||||||
|
|
||||||
|
if len(target_dataset_map) == 1:
|
||||||
|
self.datasets[split] = list(target_dataset_map.values())[0]
|
||||||
|
else:
|
||||||
|
self.datasets[split] = MultiCorpusDataset(
|
||||||
|
target_dataset_map,
|
||||||
|
distribution=data_weights,
|
||||||
|
seed=0,
|
||||||
|
sort_indices=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def source_dictionary(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def target_dictionary(self):
|
||||||
|
"""Return the :class:`~fairseq.data.Dictionary` of target
|
||||||
|
labels."""
|
||||||
|
return self.state.target_dictionary
|
||||||
|
|
||||||
|
def train_step(self, sample, model, *args, **kwargs):
|
||||||
|
sample["target"] = sample["target"].to(dtype=torch.long)
|
||||||
|
loss, sample_size, logging_output = super().train_step(
|
||||||
|
sample, model, *args, **kwargs
|
||||||
|
)
|
||||||
|
self._log_metrics(sample, model, logging_output)
|
||||||
|
return loss, sample_size, logging_output
|
||||||
|
|
||||||
|
def valid_step(self, sample, model, criterion):
|
||||||
|
sample["target"] = sample["target"].to(dtype=torch.long)
|
||||||
|
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
|
||||||
|
self._log_metrics(sample, model, logging_output)
|
||||||
|
return loss, sample_size, logging_output
|
||||||
|
|
||||||
|
def _log_metrics(self, sample, model, logging_output):
|
||||||
|
metrics = self._inference_with_metrics(
|
||||||
|
sample,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
logging_output["_precision"] = metrics["precision"]
|
||||||
|
logging_output["_recall"] = metrics["recall"]
|
||||||
|
logging_output["_f1"] = metrics["f1"]
|
||||||
|
logging_output["_eer"] = metrics["eer"]
|
||||||
|
logging_output["_accuracy"] = metrics["accuracy"]
|
||||||
|
"""
|
||||||
|
logging_output["_correct"] = metrics["correct"]
|
||||||
|
logging_output["_total"] = metrics["total"]
|
||||||
|
|
||||||
|
def _inference_with_metrics(self, sample, model):
|
||||||
|
def _compute_eer(target_list, lprobs):
|
||||||
|
# from scipy.optimize import brentq
|
||||||
|
# from scipy.interpolate import interp1d
|
||||||
|
|
||||||
|
y_one_hot = np.eye(len(self.state.target_dictionary))[target_list]
|
||||||
|
fpr, tpr, thresholds = sklearn_metrics.roc_curve(
|
||||||
|
y_one_hot.ravel(), lprobs.ravel()
|
||||||
|
)
|
||||||
|
# Revisit the interpolation approach.
|
||||||
|
# eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
|
||||||
|
|
||||||
|
fnr = 1 - tpr
|
||||||
|
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
|
||||||
|
|
||||||
|
return eer
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
net_output = model(**sample["net_input"])
|
||||||
|
lprobs = (
|
||||||
|
model.get_normalized_probs(net_output, log_probs=True).cpu().detach()
|
||||||
|
)
|
||||||
|
target_list = sample["target"][:, 0].detach().cpu()
|
||||||
|
predicted_list = torch.argmax(lprobs, 1).detach().cpu() # B,C->B
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"correct": torch.sum(target_list == predicted_list).item(),
|
||||||
|
"total": len(target_list),
|
||||||
|
}
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def reduce_metrics(self, logging_outputs, criterion):
|
||||||
|
super().reduce_metrics(logging_outputs, criterion)
|
||||||
|
|
||||||
|
zero = torch.scalar_tensor(0.0)
|
||||||
|
correct, total = 0, 0
|
||||||
|
for log in logging_outputs:
|
||||||
|
correct += log.get("_correct", zero)
|
||||||
|
total += log.get("_total", zero)
|
||||||
|
metrics.log_scalar("_correct", correct)
|
||||||
|
metrics.log_scalar("_total", total)
|
||||||
|
|
||||||
|
if total > 0:
|
||||||
|
def _fn_accuracy(meters):
|
||||||
|
if meters["_total"].sum > 0:
|
||||||
|
return utils.item(meters["_correct"].sum / meters["_total"].sum)
|
||||||
|
return float("nan")
|
||||||
|
|
||||||
|
metrics.log_derived("accuracy", _fn_accuracy)
|
||||||
|
"""
|
||||||
|
prec_sum, recall_sum, f1_sum, acc_sum, eer_sum = 0.0, 0.0, 0.0, 0.0, 0.0
|
||||||
|
for log in logging_outputs:
|
||||||
|
prec_sum += log.get("_precision", zero).item()
|
||||||
|
recall_sum += log.get("_recall", zero).item()
|
||||||
|
f1_sum += log.get("_f1", zero).item()
|
||||||
|
acc_sum += log.get("_accuracy", zero).item()
|
||||||
|
eer_sum += log.get("_eer", zero).item()
|
||||||
|
|
||||||
|
metrics.log_scalar("avg_precision", prec_sum / len(logging_outputs))
|
||||||
|
metrics.log_scalar("avg_recall", recall_sum / len(logging_outputs))
|
||||||
|
metrics.log_scalar("avg_f1", f1_sum / len(logging_outputs))
|
||||||
|
metrics.log_scalar("avg_accuracy", acc_sum / len(logging_outputs))
|
||||||
|
metrics.log_scalar("avg_eer", eer_sum / len(logging_outputs))
|
||||||
|
"""
|
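`reduce_metrics` above sums the per-batch `_correct` and `_total` counts and derives accuracy from the two sums, so every sample carries equal weight regardless of batch size. A small sketch of that aggregation in plain Python follows. The function `aggregate_accuracy` is a hypothetical stand-in for the fairseq meter machinery, and the counts are made up.

```python
def aggregate_accuracy(logging_outputs):
    """Pool correct/total counts across batches, then divide once."""
    correct = sum(log.get("_correct", 0) for log in logging_outputs)
    total = sum(log.get("_total", 0) for log in logging_outputs)
    return correct / total if total > 0 else float("nan")


# Made-up counts: a large batch with 9/10 correct and a small one with 0/2.
logs = [{"_correct": 9, "_total": 10}, {"_correct": 0, "_total": 2}]

pooled = aggregate_accuracy(logs)
per_batch_mean = sum(log["_correct"] / log["_total"] for log in logs) / len(logs)

print(pooled)          # 0.75
print(per_batch_mean)  # 0.45
```

Averaging the per-batch accuracies instead would let the two-sample batch count as much as the ten-sample one, which is why the two numbers differ.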
@ -7,12 +7,13 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
|
||||||
import torch
|
import torch
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any, OrderedDict
|
||||||
|
|
||||||
from fairseq.data import AddTargetDataset, Dictionary, encoders
|
from fairseq.data import AddTargetDataset, Dictionary, encoders
|
||||||
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
|
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
|
||||||
@ -101,7 +102,12 @@ class AudioFinetuningConfig(AudioPretrainingConfig):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
rebuild_batches: bool = True
|
rebuild_batches: bool = True
|
||||||
|
target_dictionary: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "override default dictionary location"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@register_task("audio_finetuning", dataclass=AudioFinetuningConfig)
|
@register_task("audio_finetuning", dataclass=AudioFinetuningConfig)
|
||||||
class AudioFinetuningTask(AudioPretrainingTask):
|
class AudioFinetuningTask(AudioPretrainingTask):
|
||||||
@ -120,7 +126,11 @@ class AudioFinetuningTask(AudioPretrainingTask):
|
|||||||
|
|
||||||
def load_target_dictionary(self):
|
def load_target_dictionary(self):
|
||||||
if self.cfg.labels:
|
if self.cfg.labels:
|
||||||
dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt")
|
target_dictionary = self.cfg.data
|
||||||
|
if self.cfg.target_dictionary: # override dict
|
||||||
|
target_dictionary = self.cfg.target_dictionary
|
||||||
|
dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt")
|
||||||
|
logger.info('Using dict_path : {}'.format(dict_path))
|
||||||
return Dictionary.load(dict_path)
|
return Dictionary.load(dict_path)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -135,6 +145,7 @@ class AudioFinetuningTask(AudioPretrainingTask):
|
|||||||
TextCompressionLevel, str(self.cfg.text_compression_level)
|
TextCompressionLevel, str(self.cfg.text_compression_level)
|
||||||
)
|
)
|
||||||
data_path = self.cfg.data
|
data_path = self.cfg.data
|
||||||
|
if task_cfg.multi_corpus_keys is None:
|
||||||
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
|
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
|
||||||
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
|
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
|
||||||
text_compressor = TextCompressor(level=text_compression_level)
|
text_compressor = TextCompressor(level=text_compression_level)
|
||||||
@ -163,6 +174,55 @@ class AudioFinetuningTask(AudioPretrainingTask):
|
|||||||
add_to_input=task_cfg.get("autoregressive", False),
|
add_to_input=task_cfg.get("autoregressive", False),
|
||||||
text_compression_level=text_compression_level,
|
text_compression_level=text_compression_level,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
target_dataset_map = OrderedDict()
|
||||||
|
|
||||||
|
multi_corpus_keys = [k.strip() for k in task_cfg.multi_corpus_keys.split(",")]
|
||||||
|
corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}
|
||||||
|
|
||||||
|
data_keys = [k.split(":") for k in split.split(",")]
|
||||||
|
|
||||||
|
multi_corpus_sampling_weights = [float(val.strip()) for val in task_cfg.multi_corpus_sampling_weights.split(",")]
|
||||||
|
data_weights = []
|
||||||
|
for key, file_name in data_keys:
|
||||||
|
k = key.strip()
|
||||||
|
label_path = os.path.join(data_path, f"{file_name.strip()}.{task_cfg.labels}")
|
||||||
|
skipped_indices = getattr(self.dataset_map[split][k], "skipped_indices", set())
|
||||||
|
text_compressor = TextCompressor(level=text_compression_level)
|
||||||
|
with open(label_path, "r") as f:
|
||||||
|
labels = [
|
||||||
|
text_compressor.compress(l)
|
||||||
|
for i, l in enumerate(f)
|
||||||
|
if i not in skipped_indices
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(labels) == len(self.dataset_map[split][k]), (
|
||||||
|
f"labels length ({len(labels)}) and dataset length "
|
||||||
|
f"({len(self.dataset_map[split][k])}) do not match"
|
||||||
|
)
|
||||||
|
|
||||||
|
process_label = LabelEncoder(self.target_dictionary)
|
||||||
|
|
||||||
|
# TODO: Remove duplication of code from the if block above
|
||||||
|
target_dataset_map[k] = AddTargetDataset(
|
||||||
|
self.dataset_map[split][k],
|
||||||
|
labels,
|
||||||
|
pad=self.target_dictionary.pad(),
|
||||||
|
eos=self.target_dictionary.eos(),
|
||||||
|
batch_targets=True,
|
||||||
|
process_label=process_label,
|
||||||
|
label_len_fn=label_len_fn,
|
||||||
|
add_to_input=task_cfg.get("autoregressive", False),
|
||||||
|
text_compression_level=text_compression_level,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]])
|
||||||
|
|
||||||
|
if len(target_dataset_map) == 1:
|
||||||
|
self.datasets[split] = list(target_dataset_map.values())[0]
|
||||||
|
else:
|
||||||
|
self.datasets[split] = MultiCorpusDataset(target_dataset_map, distribution=data_weights, seed=0, sort_indices=True)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def target_dictionary(self):
|
def target_dictionary(self):
|
||||||
|
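The multi-corpus branch above expects the split name to be a comma-separated list of `key:file` pairs and looks up each key's sampling weight by its position in `multi_corpus_keys`. The sketch below reproduces only that string handling. The function `parse_multi_corpus` and the corpus names (`cv`, `vox`) are hypothetical, chosen for illustration.

```python
def parse_multi_corpus(split, multi_corpus_keys, multi_corpus_sampling_weights):
    """Map each corpus key to its file name and collect per-corpus weights."""
    keys = [k.strip() for k in multi_corpus_keys.split(",")]
    corpus_idx_map = {k: idx for idx, k in enumerate(keys)}
    weights = [float(v.strip()) for v in multi_corpus_sampling_weights.split(",")]

    files, data_weights = {}, []
    for key, file_name in (item.split(":") for item in split.split(",")):
        k = key.strip()
        files[k] = file_name.strip()
        # Weights follow the order of multi_corpus_keys, not the order in `split`.
        data_weights.append(weights[corpus_idx_map[k]])
    return files, data_weights


files, data_weights = parse_multi_corpus(
    split="cv:train_cv, vox:train_vox",
    multi_corpus_keys="vox,cv",
    multi_corpus_sampling_weights="0.25,0.75",
)
print(files)         # {'cv': 'train_cv', 'vox': 'train_vox'}
print(data_weights)  # [0.75, 0.25]
```

A key that appears in the split but not in `multi_corpus_keys` raises a `KeyError` here, and the lookup in the diff above would fail the same way.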
@ -11,8 +11,9 @@ import sys
|
|||||||
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional, OrderedDict
|
||||||
from omegaconf import MISSING, II
|
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
|
||||||
|
from omegaconf import MISSING, II, OmegaConf
|
||||||
|
|
||||||
from fairseq.data import BinarizedAudioDataset, FileAudioDataset, SubsampleDataset
|
from fairseq.data import BinarizedAudioDataset, FileAudioDataset, SubsampleDataset
|
||||||
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
|
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
|
||||||
@ -44,6 +45,12 @@ class AudioPretrainingConfig(FairseqDataclass):
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "extension of the label file to load, used for fine-tuning"},
|
metadata={"help": "extension of the label file to load, used for fine-tuning"},
|
||||||
)
|
)
|
||||||
|
multi_corpus_keys: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Comma separated names for loading multi corpus datasets"})
|
||||||
|
multi_corpus_sampling_weights: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Comma separated string of sampling weights corresponding to the multi_corpus_keys"})
|
||||||
binarized_dataset: bool = field(
|
binarized_dataset: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -121,7 +128,7 @@ class AudioPretrainingTask(FairseqTask):
|
|||||||
TextCompressionLevel, str(self.cfg.text_compression_level)
|
TextCompressionLevel, str(self.cfg.text_compression_level)
|
||||||
)
|
)
|
||||||
|
|
||||||
compute_mask = task_cfg.precompute_mask_config is not None
|
compute_mask = getattr(task_cfg, "precompute_mask_config", None) is not None
|
||||||
mask_args = {}
|
mask_args = {}
|
||||||
if compute_mask:
|
if compute_mask:
|
||||||
mask_args = task_cfg.precompute_mask_config
|
mask_args = task_cfg.precompute_mask_config
|
||||||
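The hunk above replaces direct attribute access with `getattr(..., None)`, presumably so that task configs without a `precompute_mask_config` field no longer raise. A minimal sketch of the difference, using a plain `argparse.Namespace` as a stand-in for `task_cfg`; the helper name `wants_precomputed_mask` and the `mask_prob` key are hypothetical and not part of the commit:

```python
from argparse import Namespace


def wants_precomputed_mask(task_cfg):
    # Same expression as the new line in the diff: a missing attribute
    # is treated exactly like an attribute that is explicitly None.
    return getattr(task_cfg, "precompute_mask_config", None) is not None


# Stand-in configs (assumptions for illustration only).
old_cfg = Namespace(normalize=False)  # field absent
explicit_none = Namespace(precompute_mask_config=None)  # field present, unset
new_cfg = Namespace(precompute_mask_config={"mask_prob": 0.5})  # hypothetical key

for name, cfg in [("old", old_cfg), ("none", explicit_none), ("new", new_cfg)]:
    print(name, wants_precomputed_mask(cfg))
```

With the old expression, `old_cfg.precompute_mask_config` would raise `AttributeError` on a `Namespace` that lacks the field; how an OmegaConf config behaves in that case depends on its version and struct mode, which this sketch does not cover.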
@@ -140,6 +147,7 @@ class AudioPretrainingTask(FairseqTask):
|
|||||||
**mask_args,
|
**mask_args,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if task_cfg.multi_corpus_keys is None:
|
||||||
manifest_path = os.path.join(data_path, "{}.tsv".format(split))
|
manifest_path = os.path.join(data_path, "{}.tsv".format(split))
|
||||||
|
|
||||||
self.datasets[split] = FileAudioDataset(
|
self.datasets[split] = FileAudioDataset(
|
||||||
@@ -154,6 +162,44 @@ class AudioPretrainingTask(FairseqTask):
|
|||||||
compute_mask=compute_mask,
|
compute_mask=compute_mask,
|
||||||
**mask_args,
|
**mask_args,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
dataset_map = OrderedDict()
|
||||||
|
self.dataset_map = {}
|
||||||
|
multi_corpus_keys = [k.strip() for k in task_cfg.multi_corpus_keys.split(",")]
|
||||||
|
corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}
|
||||||
|
data_keys = [k.split(":") for k in split.split(",")]
|
||||||
|
|
||||||
|
multi_corpus_sampling_weights = [float(val.strip()) for val in task_cfg.multi_corpus_sampling_weights.split(",")]
|
||||||
|
data_weights = []
|
||||||
|
|
||||||
|
for key, file_name in data_keys:
|
||||||
|
|
||||||
|
k = key.strip()
|
||||||
|
manifest_path = os.path.join(data_path, "{}.tsv".format(file_name.strip()))
|
||||||
|
|
||||||
|
# TODO: Remove duplication of code from the if block above
|
||||||
|
dataset_map[k] = FileAudioDataset(
|
||||||
|
manifest_path=manifest_path,
|
||||||
|
sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
|
||||||
|
max_sample_size=self.cfg.max_sample_size,
|
||||||
|
min_sample_size=self.cfg.min_sample_size,
|
||||||
|
pad=task_cfg.labels is not None or task_cfg.enable_padding,
|
||||||
|
normalize=task_cfg.normalize,
|
||||||
|
num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
|
||||||
|
text_compression_level=text_compression_level,
|
||||||
|
compute_mask=compute_mask,
|
||||||
|
corpus_key=corpus_idx_map[k],
|
||||||
|
**mask_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]])
|
||||||
|
|
||||||
|
self.dataset_map[split] = dataset_map
|
||||||
|
|
||||||
|
if len(dataset_map) == 1:
|
||||||
|
self.datasets[split] = list(dataset_map.values())[0]
|
||||||
|
else:
|
||||||
|
self.datasets[split] = MultiCorpusDataset(dataset_map, distribution=data_weights, seed=0, sort_indices=True)
|
||||||
|
|
||||||
if getattr(task_cfg, "subsample", 1) < 1:
|
if getattr(task_cfg, "subsample", 1) < 1:
|
||||||
self.datasets[split] = SubsampleDataset(
|
self.datasets[split] = SubsampleDataset(
|
||||||
|
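The multi-corpus branch added above reads three comma-separated strings: `multi_corpus_keys`, `multi_corpus_sampling_weights`, and a `split` of the form `key:file_name,key:file_name`. The sketch below reproduces only that string handling, without fairseq, to show what ends up in `corpus_idx_map` and `data_weights`. The function name `parse_multi_corpus_split`, the corpus names, and the `/data` path are hypothetical, and the sketch assumes both config strings are set (the diff calls `.split` on them without a `None` check).

```python
import os


def parse_multi_corpus_split(split, multi_corpus_keys,
                             multi_corpus_sampling_weights, data_path):
    # Position of each corpus name in multi_corpus_keys becomes its index.
    keys = [k.strip() for k in multi_corpus_keys.split(",")]
    corpus_idx_map = {k: idx for idx, k in enumerate(keys)}

    # Weights are positional: the i-th weight belongs to the i-th key.
    weights = [float(v.strip()) for v in multi_corpus_sampling_weights.split(",")]

    manifests = {}
    data_weights = []
    # Each split entry is "key:file_name"; whitespace around both is ignored.
    for key, file_name in (item.split(":") for item in split.split(",")):
        k = key.strip()
        manifests[k] = os.path.join(data_path, "{}.tsv".format(file_name.strip()))
        data_weights.append(weights[corpus_idx_map[k]])
    return corpus_idx_map, manifests, data_weights


# Hypothetical corpora and paths, for illustration only.
idx_map, manifests, data_weights = parse_multi_corpus_split(
    split="fra:train_fra, eng:train_eng",
    multi_corpus_keys="eng,fra",
    multi_corpus_sampling_weights="0.7,0.3",
    data_path="/data",
)
print(idx_map)
print(manifests)
print(data_weights)
```

Note that `data_weights` follows the order of the entries in `split`, not the order of `multi_corpus_keys`, so each weight stays aligned with the dataset that is inserted into the map at the same step.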